Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

增加uie-base-torch模型 #62

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions compare_uie_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""
File Description:
Compare PPaddle UIE-base and Torch UIE-base effect
Author: liushu
Mail: [email protected]
Created Time: 2022/11/07
ark_nlp@https://github.com/xiangking/ark-nlp
"""
import os
import sys
from collections import defaultdict
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# create paddle uie
from paddlenlp import Taskflow

os.environ['CUDA_VISIBLE_DEVICES'] = '3'
schema = ['时间', '选手', '赛事名称']
_text = "2月8日上午北京冬奥会自由式滑雪女子大跳台决赛中中国选手谷爱凌以188.25分获得金牌!"
ie = Taskflow('information_extraction', schema=schema)
print(ie(_text))
# create torch uie
from ark_nlp.model.ie.prompt_uie import Tokenizer, PromptUIEConfig, PromptUIE, PromptUIEPredictor

torch_model_path = 'uie-base-torch' # 因为ark库的原因,需要将uie转为bert类型模型
tokenizer = Tokenizer(vocab=torch_model_path, max_seq_len=512) # Paddle UIE-base中默认max_seq_len=512
config = PromptUIEConfig.from_pretrained(torch_model_path)
dl_module = PromptUIE.from_pretrained(torch_model_path, config=config)
ner_predictor_instance = PromptUIEPredictor(dl_module, tokenizer)
entities = []
tmp_entity = {}
for prompt_type in schema:
for entity in ner_predictor_instance.predict_one_sample([_text, prompt_type]):
if prompt_type not in tmp_entity:
tmp_entity[prompt_type] = [{
'text': entity['entity'],
'start': entity['start_idx'],
'end': entity['end_idx'],
}]
else:
tmp_entity[prompt_type].append({
'text': entity['entity'],
'start': entity['start_idx'],
'end': entity['end_idx'],
})
entities.append(tmp_entity)
print(entities)
114 changes: 114 additions & 0 deletions convert_uie.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
#!/usr/bin/env python
# encoding: utf-8
"""
File Description:
uie model conversion based on paddlenlp repository
official repo: https://github.com/PaddlePaddle/PaddleNLP/tree/develop/model_zoo/uie
Author: liushu
Mail: [email protected]
Created Time: 2022/11/07
"""
import collections
import os
import json
import paddle.fluid.dygraph as D
import torch
from paddle import fluid
from transformers import BertTokenizer, BertModel

def build_params_map(attention_num=12):
    """
    build params map from paddle-paddle's ERNIE to transformer's BERT

    :param attention_num: number of transformer encoder layers to map
    :return: OrderedDict of paddle parameter name -> torch/BERT parameter name
    """
    weight_map = collections.OrderedDict({
        'encoder.embeddings.word_embeddings.weight': "bert.embeddings.word_embeddings.weight",
        'encoder.embeddings.position_embeddings.weight': "bert.embeddings.position_embeddings.weight",
        'encoder.embeddings.token_type_embeddings.weight': "bert.embeddings.token_type_embeddings.weight",
        'encoder.embeddings.task_type_embeddings.weight': "bert.embeddings.task_type_embeddings.weight",
        'encoder.embeddings.layer_norm.weight': 'bert.embeddings.LayerNorm.gamma',
        'encoder.embeddings.layer_norm.bias': 'bert.embeddings.LayerNorm.beta',
    })
    # Per-layer (paddle suffix, BERT suffix) pairs.  LayerNorm weight/bias map
    # to the old-style gamma/beta names expected by transformers' BERT loader.
    per_layer = [
        ('self_attn.q_proj.weight', 'attention.self.query.weight'),
        ('self_attn.q_proj.bias', 'attention.self.query.bias'),
        ('self_attn.k_proj.weight', 'attention.self.key.weight'),
        ('self_attn.k_proj.bias', 'attention.self.key.bias'),
        ('self_attn.v_proj.weight', 'attention.self.value.weight'),
        ('self_attn.v_proj.bias', 'attention.self.value.bias'),
        ('self_attn.out_proj.weight', 'attention.output.dense.weight'),
        ('self_attn.out_proj.bias', 'attention.output.dense.bias'),
        ('norm1.weight', 'attention.output.LayerNorm.gamma'),
        ('norm1.bias', 'attention.output.LayerNorm.beta'),
        ('linear1.weight', 'intermediate.dense.weight'),
        ('linear1.bias', 'intermediate.dense.bias'),
        ('linear2.weight', 'output.dense.weight'),
        ('linear2.bias', 'output.dense.bias'),
        ('norm2.weight', 'output.LayerNorm.gamma'),
        ('norm2.bias', 'output.LayerNorm.beta'),
    ]
    # add attention layers
    for layer_idx in range(attention_num):
        for src_suffix, dst_suffix in per_layer:
            src = f'encoder.encoder.layers.{layer_idx}.{src_suffix}'
            weight_map[src] = f'bert.encoder.layer.{layer_idx}.{dst_suffix}'
    # add pooler, MLM head and the UIE start/end pointer heads
    weight_map.update(
        {
            'encoder.pooler.dense.weight': 'bert.pooler.dense.weight',
            'encoder.pooler.dense.bias': 'bert.pooler.dense.bias',
            'cls.predictions.transform.weight': 'cls.predictions.transform.dense.weight',
            'cls.predictions.transform.bias': 'cls.predictions.transform.dense.bias',
            'cls.predictions.layer_norm.weight': 'cls.predictions.transform.LayerNorm.gamma',
            'cls.predictions.layer_norm.bias': 'cls.predictions.transform.LayerNorm.beta',
            'cls.predictions.decoder_bias': 'cls.predictions.bias',
            'linear_start.weight': 'start_linear.weight',
            'linear_start.bias': 'start_linear.bias',
            'linear_end.weight': 'end_linear.weight',
            'linear_end.bias': 'end_linear.bias'
        }
    )
    return weight_map


def extract_and_convert(input_dir, output_dir):
    """Convert a paddle UIE checkpoint into a torch/BERT-style checkpoint.

    Writes ``config.json``, ``vocab.txt`` and ``pytorch_model.bin`` to
    *output_dir*.

    :param input_dir: directory holding paddle ``model_config.json``,
        ``vocab.txt`` and ``model_state.pdparams``.
    :param output_dir: destination directory; created if missing.
    :return: None
    """
    # exist_ok avoids the check-then-create race of the original
    # os.path.exists() + makedirs() pair.
    os.makedirs(output_dir, exist_ok=True)
    print('=' * 20 + 'save config file' + '=' * 20)
    # FIX: the original used json.load(open(...)) / json.dump(..., open(...)),
    # leaking file handles; with-blocks close them deterministically.
    with open(os.path.join(input_dir, 'model_config.json'), 'rt', encoding='utf-8') as f:
        config = json.load(f)
    if 'init_args' in config:
        config = config['init_args'][0]
        del config['init_class']
    config['layer_norm_eps'] = 1e-5
    config['model_type'] = 'bert'
    config['architectures'] = ["BertModel"]  # or 'BertModel'
    config['intermediate_size'] = 4 * config['hidden_size']
    with open(os.path.join(output_dir, 'config.json'), 'wt', encoding='utf-8') as f:
        json.dump(config, f, indent=4)
    print('=' * 20 + 'save vocab file' + '=' * 20)
    with open(os.path.join(input_dir, 'vocab.txt'), 'rt', encoding='utf-8') as f:
        # paddle vocab lines may carry extra tab-separated columns; keep the token only
        words = [line.split('\t')[0] for line in f.read().splitlines()]
    with open(os.path.join(output_dir, 'vocab.txt'), 'wt', encoding='utf-8') as f:
        f.writelines(word + "\n" for word in words)
    print('=' * 20 + 'extract weights' + '=' * 20)
    state_dict = collections.OrderedDict()
    weight_map = build_params_map(attention_num=config['num_hidden_layers'])
    with fluid.dygraph.guard():
        paddle_paddle_params, _ = D.load_dygraph(os.path.join(input_dir, 'model_state.pdparams'))
        for weight_name, weight_value in paddle_paddle_params.items():
            # paddle stores Linear weights transposed relative to torch nn.Linear,
            # so transpose every mapped projection/pooler/head weight matrix.
            if 'weight' in weight_name and (
                    'encoder.encoder' in weight_name or 'encoder.pooler' in weight_name
                    or 'cls.' in weight_name or 'linear_start' in weight_name
                    or 'linear_end' in weight_name):
                weight_value = weight_value.transpose()
            if weight_name not in weight_map:
                print('=' * 20, '[SKIP]', weight_name, '=' * 20)
                continue
            state_dict[weight_map[weight_name]] = torch.FloatTensor(weight_value)
            print(weight_name, '->', weight_map[weight_name], weight_value.shape)
    torch.save(state_dict, os.path.join(output_dir, "pytorch_model.bin"))


if __name__ == '__main__':
    # Convert the paddle checkpoint in ./uie-base into a torch/BERT-style
    # checkpoint in ./uie-base-torch (paths are relative to the CWD).
    extract_and_convert('./uie-base', './uie-base-torch')