2021-07-30-Loading local pretrained weight files in AllenNLP
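AllenNLP's documentation is sparse and disorganized, so the most effective approach is still to read the source code directly (the AllenNLP version used here is 0.9.0).
If you only want to load bert-base-chinese, the configuration shown below works. With other pretrained models, such as HIT's hfl/chinese-bert-wwm-ext, it fails with an error saying that no such model exists, again because this AllenNLP version is too old.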
"model": {
    "type": "rewrite",
    "word_embedder": {
        "bert": {
            "type": "bert-pretrained",
            "pretrained_model": "bert-base-chinese",
            "top_layer_only": true,
            "requires_grad": true
        },
        "allow_unmatched_keys": true,
        "embedder_to_indexer_map": {
            "bert": [
                "bert",
                "bert-offsets",
                "bert-type-ids"
            ]
        }
    },
}
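The only option left is to download the model and load it from disk. Download link:
https://github.com/ymcui/Chinese-BERT-wwm
If you look at the source code that AllenNLP uses to load BERT models, you will find that it does not support the format of the downloaded weights; the model has to be packaged into an archive first.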
def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
    """
    Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
    Download and cache the pre-trained model file if needed.
    Params:
        pretrained_model_name_or_path: either:
            - a str with the name of a pre-trained model to load selected in the list of:
                . `bert-base-uncased`
                . `bert-large-uncased`
                . `bert-base-cased`
                . `bert-large-cased`
                . `bert-base-multilingual-uncased`
                . `bert-base-multilingual-cased`
                . `bert-base-chinese`
            - a path or url to a pretrained model archive containing:
                . `bert_config.json` a configuration file for the model
                . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
            - a path or url to a pretrained model archive containing:
                . `bert_config.json` a configuration file for the model
                . `model.chkpt` a TensorFlow checkpoint
        from_tf: should we load the weights from a locally saved TensorFlow checkpoint
        cache_dir: an optional path to a folder in which the pre-trained models will be cached.
        state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
        *inputs, **kwargs: additional input for the specific Bert class
            (ex: num_labels for BertForSequenceClassification)
    """
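Even if you already have bert_config.json and pytorch_model.bin, that is not enough: the loader will not accept them as loose files. So I wrote a small format-conversion function, which I hope is a useful reference.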
import logging
import os
import tarfile
from typing import Optional

logger = logging.getLogger(__name__)


def save_model(save_dir: str,
               weights: str,
               config_name: str,
               archive_path: Optional[str] = None) -> None:
    '''
    save_dir: directory that contains the weights
    weights: file name of the weights, e.g. pytorch_model.bin
    config_name: file name of the BERT config, e.g. bert_config.json
    archive_path: file (or directory) in which to save the converted archive
    '''
    weights_file = os.path.join(save_dir, weights)
    if not os.path.exists(weights_file):
        logger.error("weights file %s does not exist, unable to archive model", weights_file)
        return
    config_file = os.path.join(save_dir, config_name)
    if not os.path.exists(config_file):
        logger.error("config file %s does not exist, unable to archive model", config_file)
        return
    if archive_path is not None:
        archive_file = archive_path
        # If a directory was given, write model.tar.gz inside it
        if os.path.isdir(archive_file):
            archive_file = os.path.join(archive_file, "model.tar.gz")
    else:
        archive_file = os.path.join(save_dir, "model.tar.gz")
    logger.info("archiving weights and config to %s", archive_file)
    with tarfile.open(archive_file, 'w:gz') as archive:
        # Store both files under the member names listed in the docstring above
        archive.add(config_file, arcname='bert_config.json')
        archive.add(weights_file, arcname='pytorch_model.bin')
Usage example:
save_model('chinese_wwm_ext_pytorch',
           weights='pytorch_model.bin',
           config_name='bert_config.json',
           archive_path='chinese_wwm_ext_pytorch')
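Because archive_path is a directory in this call, the archive ends up at chinese_wwm_ext_pytorch/model.tar.gz. The following is a minimal sketch of two follow-up steps, not part of the original code: checking that the archive really contains the two member names from the docstring, and building the embedder entry of the config so that it points at the archive. The helper names missing_members and local_bert_embedder are hypothetical, and the sketch assumes that "pretrained_model" accepts a local path to the .tar.gz file, as the "a path or url to a pretrained model archive" line in the docstring suggests.

```python
import json
import tarfile
from typing import Dict, List

# Member names the loader's docstring says it expects inside the archive
EXPECTED_MEMBERS = ("bert_config.json", "pytorch_model.bin")


def missing_members(archive_path: str) -> List[str]:
    """Return the expected member names that are absent from the archive."""
    with tarfile.open(archive_path, "r:gz") as archive:
        names = set(archive.getnames())
    return [name for name in EXPECTED_MEMBERS if name not in names]


def local_bert_embedder(archive_path: str) -> Dict[str, object]:
    """Build the "bert" embedder entry with a local archive path.

    Assumption: "pretrained_model" may be a path to a .tar.gz archive
    instead of a model name such as bert-base-chinese.
    """
    return {
        "type": "bert-pretrained",
        "pretrained_model": archive_path,
        "top_layer_only": True,
        "requires_grad": True,
    }


# Print the config fragment; this does not touch the file system
print(json.dumps(local_bert_embedder("chinese_wwm_ext_pytorch/model.tar.gz"), indent=4))
```

An empty list from missing_members('chinese_wwm_ext_pytorch/model.tar.gz') means both files were packaged; the printed fragment can then replace the "bert" entry in the config shown at the top.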