Pushing bug fix for metacat #487

Merged: 4 commits, Sep 5, 2024
7 changes: 3 additions & 4 deletions medcat/meta_cat.py
@@ -257,20 +257,19 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data
         category_value2id = g_config['category_value2id']
         if not category_value2id:
             # Encode the category values
-            data_undersampled, full_data, category_value2id = encode_category_values(data,
+            full_data, data_undersampled, category_value2id = encode_category_values(data,
                                                                                      category_undersample=self.config.model.category_undersample)
             g_config['category_value2id'] = category_value2id
         else:
             # We already have everything, just get the data
-            data_undersampled, full_data, category_value2id = encode_category_values(data,
+            full_data, data_undersampled, category_value2id = encode_category_values(data,
                                                                                      existing_category_value2id=category_value2id,
                                                                                      category_undersample=self.config.model.category_undersample)
             g_config['category_value2id'] = category_value2id
         # Make sure the config number of classes is the same as the one found in the data
         if len(category_value2id) != self.config.model['nclasses']:
             logger.warning(
-                "The number of classes set in the config is not the same as the one found in the data: {} vs {}".format(
-                    self.config.model['nclasses'], len(category_value2id)))
+                "The number of classes set in the config is not the same as the one found in the data: %d vs %d", self.config.model['nclasses'], len(category_value2id))
             logger.warning("Auto-setting the nclasses value in config and rebuilding the model.")
             self.config.model['nclasses'] = len(category_value2id)
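The substance of the fix is in the unpacking order: `encode_category_values` now returns the full data first and the undersampled data second, and `train_raw` is updated to match. Unpacking in the old order silently swapped the two datasets. A minimal sketch of that failure mode, with a toy `encode` standing in for the real function (the function and its naive undersampling here are illustrative, not MedCAT's API):

```python
# Toy stand-in for encode_category_values, illustrating the unpacking bug.
def encode(data, keep_per_class=1):
    value2id = {v: i for i, v in enumerate(sorted({label for _, label in data}))}
    full = [(text, value2id[label]) for text, label in data]
    seen: dict = {}
    undersampled = []
    for text, label_id in full:
        if seen.get(label_id, 0) < keep_per_class:
            undersampled.append((text, label_id))
            seen[label_id] = seen.get(label_id, 0) + 1
    return full, undersampled, value2id  # full first, undersampled second

data = [("a", "yes"), ("b", "yes"), ("c", "no")]

# Matching the new return order keeps the names truthful:
full_data, data_undersampled, value2id = encode(data)
assert len(full_data) == 3 and len(data_undersampled) == 2

# The old unpacking order would label the undersampled set as "full",
# so later training steps would run on the wrong split:
data_undersampled_bug, full_data_bug, _ = encode(data)
assert len(full_data_bug) == 2  # "full" data is actually the undersampled set
```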
14 changes: 8 additions & 6 deletions medcat/utils/meta_cat/data_utils.py
@@ -166,12 +166,12 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
             Name of class that should be used to undersample the data (for 2 phase learning)
 
     Returns:
         dict:
-            New underesampled data (for 2 phase learning) with integers inplace of strings for category values
-        dict:
             New data with integers inplace of strings for category values.
         dict:
-            Map rom category value to ID for all categories in the data.
+            New undersampled data (for 2 phase learning) with integers inplace of strings for category values
+        dict:
+            Map from category value to ID for all categories in the data.
     """
     data = list(data)
     if existing_category_value2id is not None:
@@ -194,7 +194,7 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
         for k in keys_ls:
             category_value2id_[k] = len(category_value2id_)
 
-        logger.warning("Labels found with 0 data; updates made\nFinal label encoding mapping:", category_value2id_)
+        logger.warning("Labels found with 0 data; updates made\nFinal label encoding mapping: %s", category_value2id_)
         category_value2id = category_value2id_
 
     for c in category_values:
@@ -210,6 +210,8 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
     for i in range(len(data)):
         if data[i][2] in category_value2id.values():
             label_data_[data[i][2]] = label_data_[data[i][2]] + 1
+
+    logger.info("Original label_data: %s", label_data_)
     # Undersampling data
     if category_undersample is None or category_undersample == '':
         min_label = min(label_data_.values())
@@ -232,9 +234,9 @@ def encode_category_values(data: Dict, existing_category_value2id: Optional[Dict
     for i in range(len(data_undersampled)):
         if data_undersampled[i][2] in category_value2id.values():
             label_data[data_undersampled[i][2]] = label_data[data_undersampled[i][2]] + 1
-    logger.info(f"Updated label_data: {label_data}")
+    logger.info("Updated label_data: %s", label_data)
 
-    return data_undersampled, data, category_value2id
+    return data, data_undersampled, category_value2id
 
 
 def json_to_fake_spacy(data: Dict, id2text: Dict) -> Iterable:
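For context on the hunks above: the undersampling counts samples per label and then caps every class, either at the smallest class's count (the default branch) or at the count of the class named by `category_undersample`. A rough sketch of that cap selection, simplified from the partial hunks shown here, so treat the helper name and shapes as illustrative:

```python
from typing import Dict, Optional

def undersample_cap(label_counts: Dict[int, int],
                    category_value2id: Dict[str, int],
                    category_undersample: Optional[str] = None) -> int:
    """Pick the per-class sample cap used when undersampling."""
    if category_undersample is None or category_undersample == '':
        # Default: the smallest class's count caps every class.
        return min(label_counts.values())
    # Otherwise the named class's count becomes the cap.
    return label_counts[category_value2id[category_undersample]]

counts = {0: 500, 1: 40, 2: 120}                     # label id -> sample count
print(undersample_cap(counts, {}))                   # 40
print(undersample_cap(counts, {"Past": 2}, "Past"))  # 120
```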
4 changes: 2 additions & 2 deletions medcat/utils/meta_cat/ml_utils.py
@@ -201,7 +201,7 @@ def train_model(model: nn.Module, data: List, config: ConfigMetaCAT, save_dir_pa
     y_ = [x[2] for x in train_data]
     class_weights = compute_class_weight(class_weight="balanced", classes=np.unique(y_), y=y_)
     config.train['class_weights'] = class_weights.tolist()
-    logger.info(f"Class weights computed: {class_weights}")
+    logger.info("Class weights computed: %s", class_weights)
 
     class_weights = torch.FloatTensor(class_weights).to(device)
     if config.train['loss_funct'] == 'cross_entropy':
@@ -259,7 +259,7 @@ def initialize_model(classifier, data_, batch_size_, lr_, epochs=4):
 
     # Total number of training steps
     total_steps = int((len(data_) / batch_size_) * epochs)
-    logger.info('Total steps for optimizer: {}'.format(total_steps))
+    logger.info('Total steps for optimizer: %d', total_steps)
 
     # Set up the learning rate scheduler
     scheduler_ = get_linear_schedule_with_warmup(optimizer_,
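A pattern running through this PR: log calls move from f-strings and `.format()` to %-style arguments. The stdlib `logging` module defers interpolation until a handler actually emits the record, so messages at suppressed levels cost almost nothing, and a malformed value can no longer raise at call time. A small self-contained illustration:

```python
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

class_weights = [0.7, 1.8, 2.4]

# Eager: the f-string is built even though INFO is disabled here.
logger.info(f"Class weights computed: {class_weights}")

# Lazy: logging stores the template and args; formatting only happens
# if a handler actually emits the record.
logger.info("Class weights computed: %s", class_weights)
```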
2 changes: 1 addition & 1 deletion medcat/utils/meta_cat/models.py
@@ -91,7 +91,7 @@ def __init__(self, config):
         super(BertForMetaAnnotation, self).__init__()
         _bertconfig = AutoConfig.from_pretrained(config.model.model_variant,num_hidden_layers=config.model['num_layers'])
         if config.model['input_size'] != _bertconfig.hidden_size:
-            logger.warning(f"\nInput size for {config.model.model_variant} model should be {_bertconfig.hidden_size}, provided input size is {config.model['input_size']} Input size changed to {_bertconfig.hidden_size}")
+            logger.warning("Input size for %s model should be %d, provided input size is %d. Input size changed to %d", config.model.model_variant, _bertconfig.hidden_size, config.model['input_size'], _bertconfig.hidden_size)
 
         bert = BertModel.from_pretrained(config.model.model_variant, config=_bertconfig)
         self.config = config
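The models.py change only rewrites the warning into the same lazy style, but the guard it sits in is worth spelling out: a BERT checkpoint's embedding width is fixed by its `hidden_size`, so a mismatched `input_size` in the MetaCAT config has to yield to the checkpoint. A minimal sketch of that check, assuming `transformers` is installed; the variant name and stale value are just examples:

```python
# Sketch of the input-size guard, outside the surrounding model class.
from transformers import AutoConfig

model_variant = "bert-base-uncased"  # example stand-in for config.model.model_variant
bertconfig = AutoConfig.from_pretrained(model_variant, num_hidden_layers=2)

input_size = 300  # e.g. a stale value carried over from an LSTM-based config
if input_size != bertconfig.hidden_size:
    # The checkpoint dictates the width (768 for bert-base), so follow it.
    input_size = bertconfig.hidden_size
```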