Task-adaptive Pretraining을 사용하면 성능향상에 도움이 된다는 연구결과가 있다. 학습할 데이터를 이용해 pretraining을 진행하는 것은 좋은 효과를 가지고 올 수 있음.
# train.py
def mlm_pretrain(config, model, loader, n_epochs, epoch, device):
model.train()
for step, batch in enumerate(loader):
input_ids, segment_ids, input_masks, gating_ids, target_ids, guids = [b.to(device) if not isinstance(b, list) else b for b in batch]
logits, labels = model.forward_pretrain(input_ids, tokenizer, config, device)
loss = loss_fnc_pretrain(logits.view(-1, config.vocab_size), labels.view(-1))
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
optimizer.zero_grad()
if step % 100 == 0:
print('[%d/%d] [%d/%d] %f' % (epoch, n_epochs, step, len(loader), loss.item()))
def forward_pretrain(self, input_ids, tokenizer, config, device):
input_ids, labels = self.mask_tokens(input_ids, tokenizer, config, device)
encoder_outputs, _ = self.encoder(input_ids=input_ids)
mlm_logits = self.mlm_head(encoder_outputs)
return mlm_logits, labels