AllenNLP: Loading Local Pretrained Weight Files

Posted by Weihao Zeng on July 30, 2021


AllenNLP's documentation is both sparse and messy; the most effective approach is still to read the source directly (the AllenNLP version here is 0.9.0).

If you just want to load bert-base-chinese, the configuration shown below works. But with other pretrained models, for example HIT's hfl/chinese-bert-wwm-ext, it errors out saying no such model exists, again because the AllenNLP version is too old.

	"model": {
		"type": "rewrite",
		"word_embedder": {
			"bert": {
				"type": "bert-pretrained",
				"pretrained_model": "bert-base-chinese",
				"top_layer_only": true,
				"requires_grad": true
			},
			"allow_unmatched_keys": true,
			"embedder_to_indexer_map": {
				"bert": [
					"bert",
					"bert-offsets",
					"bert-type-ids"
				]
			}
		},
	}
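For context, the embedder_to_indexer_map above assumes the dataset reader registers a matching bert-pretrained token indexer under the key bert, which is what produces the bert, bert-offsets and bert-type-ids arrays. A minimal sketch of such an indexer is below; the do_lowercase and use_starting_offsets values are my assumptions, not part of the original config:

	"dataset_reader": {
		"token_indexers": {
			"bert": {
				"type": "bert-pretrained",
				"pretrained_model": "bert-base-chinese",
				"do_lowercase": false,
				"use_starting_offsets": true
			}
		}
	}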

With no other option, I had to download the model locally and load it from there. Model download link:

https://github.com/ymcui/Chinese-BERT-wwm

Looking at the AllenNLP source code that loads BERT models, you will find it does not support the weight format we just downloaded: the model has to be packaged into an archive first.

    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        """
        Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        Params:
            pretrained_model_name_or_path: either:
                - a str with the name of a pre-trained model to load selected in the list of:
                    . `bert-base-uncased`
                    . `bert-large-uncased`
                    . `bert-base-cased`
                    . `bert-large-cased`
                    . `bert-base-multilingual-uncased`
                    . `bert-base-multilingual-cased`
                    . `bert-base-chinese`
                - a path or url to a pretrained model archive containing:
                    . `bert_config.json` a configuration file for the model
                    . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
                - a path or url to a pretrained model archive containing:
                    . `bert_config.json` a configuration file for the model
                    . `model.chkpt` a TensorFlow checkpoint
            from_tf: should we load the weights from a locally saved TensorFlow checkpoint
            cache_dir: an optional path to a folder in which the pre-trained models will be cached.
            state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
            *inputs, **kwargs: additional input for the specific Bert class
                (ex: num_labels for BertForSequenceClassification)
        """

Even if you have bert_config.json and pytorch_model.bin, it is no use; they simply are not accepted. So I went ahead and wrote some format-conversion code, which I hope can serve as a reference.

import os
import logging
import tarfile

logger = logging.getLogger(__name__)


def save_model(save_dir: str,
               weights: str,
               config_name: str,
               archive_path: str = None) -> None:
    '''
    save_dir: directory containing the downloaded weights
    weights: weight file name, e.g. pytorch_model.bin
    config_name: BERT config file name, e.g. bert_config.json
    archive_path: file or directory path for the converted archive
    '''
    weights_file = os.path.join(save_dir, weights)
    if not os.path.exists(weights_file):
        logger.error("weights file %s does not exist, unable to archive model", weights_file)
        return
    config_file = os.path.join(save_dir, config_name)
    if not os.path.exists(config_file):
        logger.error("config file %s does not exist, unable to archive model", config_file)
        return
    if archive_path is not None:
        archive_file = archive_path
        if os.path.isdir(archive_file):
            archive_file = os.path.join(archive_file, "model.tar.gz")
    else:
        archive_file = os.path.join(save_dir, "model.tar.gz")
    logger.info("archiving weights and config to %s", archive_file)
    with tarfile.open(archive_file, 'w:gz') as archive:
        archive.add(config_file, arcname='bert_config.json')
        archive.add(weights_file, arcname='pytorch_model.bin')

Usage example:

save_model('chinese_wwm_ext_pytorch',
           weights='pytorch_model.bin',
           config_name='bert_config.json',
           archive_path='chinese_wwm_ext_pytorch')
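Once model.tar.gz exists, the from_pretrained docstring quoted above suggests the archive path can be used directly as pretrained_model in the config; the indexer side still needs the plain vocab.txt from the downloaded folder. A sketch of how the snippets might look with local paths (the file layout under chinese_wwm_ext_pytorch/ is my assumption):

	"dataset_reader": {
		"token_indexers": {
			"bert": {
				"type": "bert-pretrained",
				"pretrained_model": "chinese_wwm_ext_pytorch/vocab.txt"
			}
		}
	},
	"model": {
		"word_embedder": {
			"bert": {
				"type": "bert-pretrained",
				"pretrained_model": "chinese_wwm_ext_pytorch/model.tar.gz",
				"top_layer_only": true,
				"requires_grad": true
			}
		}
	}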