pytorch transformers ....

huggingface模型下载

transformers的预训练模型下载到本地特定位置,默认是在~/.cache/huggingface/transformers

1
model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir="...")

想知道transformers的模型都是什么结构的,比如bert模型:

1
transformers/models/bert/__init__.py

这里可以看到导入了

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from .modeling_bert import (
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
BertForMaskedLM,
BertForMultipleChoice,
BertForNextSentencePrediction,
BertForPreTraining,
BertForQuestionAnswering,
BertForSequenceClassification,
BertForTokenClassification,
BertLayer,
BertLMHeadModel,
BertModel,
BertPreTrainedModel,
load_tf_weights_in_bert,
)

然后点进去就可以看了,可以看他们的forward函数等

Trainer

Trainer提供了训练、验证、预测的功能,可以通过继承Trainer并覆写其中一些方法来自定义。

  • compute_loss()

计算损失的函数,compute_loss(self, model, inputs, return_outputs=False),该函数执行forward和loss计算。

  • 参数 compute_metrics(eval_pred: EvalPrediction)

该参数指定的函数在tranier.evaluate()时会调用,该函数的参数是EvalPrediction类型,必须返回一个字典,类型是string-> value。比如

1
2
3
4
{
'accuracy': 0.98,
'sensitivity': 0.65
}

EvalPrediction包括self.predictionsself.label_ids,以及可能有的self.inputs

  • 参数metric_for_best_model

这是个字符串,并且必须是compute_metrics返回的字典的一个key,对应上例就是只能是'accuracy'或者'sensitivity'


pytorch transformers ....
https://jcdu.top/2022/05/19/pytorch transformers/
作者
horizon86
发布于
2022年5月19日
许可协议