当前位置: 代码迷 >> 综合 >> Huggingface-4.8.2自定义训练
  详细解决方案

Huggingface-4.8.2自定义训练

热度:88   发布时间:2023-12-24 13:38:24.0

Huggingface走到4.8.2这个版本,已经有了很好的封装。训练一个语言网络只需要调用Trainer.train(...)即可完成。如果要根据自己的需求修改训练的过程,比如自定义loss,输出梯度,直接修改huggingface的源码显然是不可取的了。好在huggingface提供了相应的接口,让我们可以深入到训练过程中,加入自定义的内容。根据官方的教程,有两种推荐的方法:

  1. 重载trainer中的方法,将其修改为我们需要的内容。比如trainer.compute_loss()这个函数,它定义了如何计算loss,我们只需要修改其中的逻辑,就可以自定义loss的计算。
  2. 使用callbacks。callbacks可以查看训练过程中一些关键变量的值,并根据其状态做出相应的决策,比如early stop。

关于trainer和callbacks这两个的官方文档分别是这里和这里,这两个方法都可以很优雅地修改原有的逻辑。但个人感觉重载trainer的方法是一种更灵活也更强大的方法。callbacks其实只能查看提供的一些变量,并且也只是查看,不能做出修改。而重载方法可以定义任意的全新的函数。接下来给出这两种方法的两个例子。

重载方法

在官方给的教程中是一个重载compute loss的例子,这里给一个不一样的,定义trainging_step的例子,代码如下:

class PrintGradientTrainer(Trainer):def training_step(self, model, inputs):model.train()inputs = self._prepare_inputs(inputs)loss = self.compute_loss(model, inputs)loss.backward()# ------------------------new added codes.--------------------------for name, param in model.named_parameters():if param.requires_grad:if param.grad is not None:print("{}, gradient: {}".format(name, param.grad.mean()))else:print("{} has not gradient".format(name))# ------------------------new added codes.--------------------------return loss.detach()# originally the Trainer() is called
#trainer = Trainer(
#    model=model, args=training_args, train_dataset=small_train_dataset, #eval_dataset=small_eval_dataset,
#    tokenizer=tokenizer, data_collator=data_collator
#)# Now call the new defined PrintGradientTrainer()
trainer = PrintGradientTrainer(model=model, args=training_args, train_dataset=small_train_dataset, eval_dataset=small_eval_dataset,tokenizer=tokenizer, data_collator=data_collator
)trainer.train()

只给出了关键部分的代码,其他的就按照正常写即可。

Callbacks

这个方法也需要定义一个原本的TrainerCallback的子类,然后重载原有的空的callbacks方法。代码实例如下,这个例子打出了现在是第几个epoch。

class MyCallback(TrainerCallback):def on_step_begin(self, args, state, control, **kwargs):print("train step start")control.should_log = Falsecontrol.should_evaluate = Falsecontrol.should_save = Falseprint('---------------------------------------',state.epoch)# return self.call_event("on_step_begin", args, state, control)
trainer = PrintGradientTrainer(model=model, args=training_args, train_dataset=small_train_dataset, eval_dataset=small_eval_dataset,tokenizer=tokenizer, data_collator=data_collator,callbacks=[MyCallback()]
)

在定义trainer的时候,给callbacks加入自己定义的类就可以了。