-
Notifications
You must be signed in to change notification settings - Fork 156
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c473ce6
commit c58bdeb
Showing
98 changed files
with
33,368 additions
and
63 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
""" | ||
################################################################################################## | ||
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved. | ||
# Filename : __init__.py | ||
# Abstract : | ||
# Current Version: 1.0.0 | ||
# Date : 2022-05-06 | ||
################################################################################################## | ||
""" | ||
from .datasets import * | ||
from .models import * | ||
from .core import * | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
""" | ||
################################################################################################## | ||
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved. | ||
# Filename : __init__.py | ||
# Abstract : | ||
# Current Version: 1.0.0 | ||
# Date : 2022-05-06 | ||
################################################################################################## | ||
""" | ||
from .evaluation import eval_ner_f1 | ||
from .converters import SpanConverter, TransformersConverter | ||
|
||
__all__ = [ | ||
'eval_ner_f1', | ||
'SpanConverter', | ||
'TransformersConverter' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
""" | ||
################################################################################################## | ||
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved. | ||
# Filename : __init__.py | ||
# Abstract : | ||
# Current Version: 1.0.0 | ||
# Date : 2022-05-06 | ||
################################################################################################## | ||
""" | ||
from .transformers_converter import TransformersConverter | ||
from .span_converter import SpanConverter | ||
|
||
|
||
__all__ = ['TransformersConverter', 'SpanConverter'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
""" | ||
################################################################################################## | ||
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved. | ||
# Filename : base_converter.py | ||
# Abstract : | ||
# Current Version: 1.0.0 | ||
# Date : 2022-05-06 | ||
################################################################################################## | ||
""" | ||
from abc import ABCMeta, abstractmethod | ||
|
||
|
||
class BaseConverter(metaclass=ABCMeta): | ||
""" Base converter, Convert between text, index and tensor for NER pipeline. | ||
""" | ||
@abstractmethod | ||
def convert_text2id(self, results): | ||
""" Convert token to ids. | ||
Args: | ||
results (dict): A dict must containing the token key: | ||
- tokens (list]): Tokens list. | ||
Returns: | ||
dict: corresponding ids | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def convert_pred2entities(self, preds, masks, **kwargs): | ||
""" Gets entities from preds. | ||
Args: | ||
preds (list): Sequence of preds. | ||
masks (Tensor): The valid part is 1 and the invalid part is 0. | ||
Returns: | ||
list: List of entities. | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def convert_entity2label(self, labels): | ||
""" Convert labeled entities to ids. | ||
Args: | ||
labels (list): eg:['B-PER', 'I-PER'] | ||
Returns: | ||
dict: corresponding labels | ||
""" | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
""" | ||
################################################################################################## | ||
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved. | ||
# Filename : span_converter.py | ||
# Abstract : | ||
# Current Version: 1.0.0 | ||
# Date : 2022-05-06 | ||
################################################################################################## | ||
""" | ||
from seqeval.scheme import Tokens, IOBES | ||
from davarocr.davar_common.core import CONVERTERS | ||
from .transformers_converter import TransformersConverter | ||
|
||
|
||
@CONVERTERS.register_module() | ||
class SpanConverter(TransformersConverter): | ||
"""Span converter, converter for span model. | ||
""" | ||
def _generate_labelid_dict(self): | ||
label2id_dict = {label: i for i, label in enumerate(['O'] + self.label_list)} | ||
id2label_dict = {value: key for key, value in label2id_dict.items()} | ||
return label2id_dict, id2label_dict | ||
|
||
|
||
def _extract_subjects(self, seq): | ||
"""Get entities from label sequence | ||
""" | ||
entities = [(t.to_tuple()[1], t.to_tuple()[2], t.to_tuple()[3]) for t in Tokens(seq, IOBES).entities] | ||
return entities | ||
|
||
|
||
def convert_entity2label(self, labels): | ||
"""Convert labeled entities to ids. | ||
Args: | ||
labels (list): eg:['B-PER', 'I-PER'] | ||
Returns: | ||
dict: corresponding ids | ||
""" | ||
labels = self._labels_convert(labels, self.only_label_first_subword) | ||
cls_token_at_end=self.cls_token_at_end | ||
pad_on_left = self.pad_on_left | ||
label2id = self.label2id_dict | ||
subjects = self._extract_subjects(labels)#get entities | ||
start_ids = [0] * len(labels) | ||
end_ids = [0] * len(labels) | ||
subjects_id = [] | ||
for subject in subjects: | ||
label = subject[0] | ||
start = subject[1] | ||
end = subject[2] | ||
|
||
#set label for span | ||
start_ids[start] = label2id[label] | ||
end_ids[end-1] = label2id[label]#the true position is end-1 | ||
subjects_id.append((label2id[label], start, end)) | ||
|
||
# Account for [CLS] and [SEP] with "- 2". | ||
special_tokens_count = 2 | ||
if len(labels) > self.max_len - special_tokens_count: | ||
start_ids = start_ids[: (self.max_len - special_tokens_count)] | ||
end_ids = end_ids[: (self.max_len - special_tokens_count)] | ||
|
||
#add sep | ||
start_ids += [0] | ||
end_ids += [0] | ||
if cls_token_at_end: | ||
#add [CLS] at end | ||
start_ids += [0] | ||
end_ids += [0] | ||
else: | ||
#add [CLS] at begin | ||
start_ids = [0]+ start_ids | ||
end_ids = [0]+ end_ids | ||
padding_length = self.max_len - len(labels) - 2 | ||
if pad_on_left: | ||
#pad on left | ||
start_ids = ([0] * padding_length) + start_ids | ||
end_ids = ([0] * padding_length) + end_ids | ||
else: | ||
#pad on right | ||
start_ids += ([0] * padding_length) | ||
end_ids += ([0] * padding_length) | ||
res = dict(start_positions=start_ids, end_positions=end_ids) | ||
return res | ||
|
||
def convert_pred2entities(self, preds, masks, **kwargs): | ||
"""Gets entities from preds. | ||
Args: | ||
preds (list): Sequence of preds. | ||
masks (tensor): The valid part is 1 and the invalid part is 0. | ||
Returns: | ||
list: List of [[[entity_type, | ||
entity_start, entity_end]]]. | ||
""" | ||
id2label = self.id2label | ||
pred_entities = [] | ||
for pred in preds: | ||
entities = [] | ||
entity = [0, 0, 0] | ||
for tag in pred: | ||
entity[0] = id2label[tag[0]] | ||
entity[1] = tag[1] - 1 | ||
entity[2] = tag[2] - 1 | ||
entities.append(entity.copy()) | ||
pred_entities.append(entities.copy()) | ||
tokens_index = [index.cpu().numpy().tolist()[0] for index in kwargs['tokens_index']] | ||
pred_entities = [self._labels_convert_ori(pred_entity, tokens_index) for pred_entity in pred_entities] | ||
return pred_entities |
Oops, something went wrong.