diff --git a/medcat/cdb.py b/medcat/cdb.py index 6ae15d3f5..477b1bf94 100644 --- a/medcat/cdb.py +++ b/medcat/cdb.py @@ -6,7 +6,6 @@ import aiofiles import numpy as np from typing import Dict, Set, Optional, List, Union, cast -from functools import partial import os from medcat import __version__ @@ -14,8 +13,12 @@ from medcat.utils.matutils import unitvec from medcat.utils.ml_utils import get_lr_linking from medcat.utils.decorators import deprecated -from medcat.config import Config, weighted_average, workers +from medcat.config import Config, workers from medcat.utils.saving.serializer import CDBSerializer +from medcat.utils.config_utils import get_and_del_weighted_average_from_config +from medcat.utils.config_utils import default_weighted_average +from medcat.utils.config_utils import ensure_backward_compatibility +from medcat.utils.config_utils import fix_waf_lambda, attempt_fix_weighted_average_function logger = logging.getLogger(__name__) @@ -98,6 +101,7 @@ def __init__(self, config: Union[Config, None] = None) -> None: self.vocab: Dict = {} # Vocabulary of all words ever in our cdb self._optim_params = None self.is_dirty = False + self._init_waf_from_config() self._hash: Optional[str] = None # the config hash is kept track of here so that # the CDB hash can be re-calculated when the config changes @@ -107,6 +111,18 @@ def __init__(self, config: Union[Config, None] = None) -> None: self._config_hash: Optional[str] = None self._memory_optimised_parts: Set[str] = set() + def _init_waf_from_config(self): + waf = get_and_del_weighted_average_from_config(self.config) + if waf is not None: + logger.info("Using (potentially) custom value of weighed " + "average function") + self.weighted_average_function = attempt_fix_weighted_average_function(waf) + elif hasattr(self, 'weighted_average_function'): + # keep existing + pass + else: + self.weighted_average_function = default_weighted_average + def get_name(self, cui: str) -> str: """Returns preferred name if it exists, otherwise it will return the longest name assigned to the concept. @@ -558,6 +574,8 @@ def load_config(self, config_path: str) -> None: # this should be the behaviour for all newer models self.config = cast(Config, Config.load(config_path)) logger.debug("Loaded config from CDB from %s", config_path) + # new config, potentially new weighted_average_function to read + self._init_waf_from_config() # mark config read from file self._config_from_file = True @@ -582,7 +600,8 @@ def load(cls, path: str, json_path: Optional[str] = None, config_dict: Optional[ ser = CDBSerializer(path, json_path) cdb = ser.deserialize(CDB) cls._check_medcat_version(cdb.config.asdict()) - cls._ensure_backward_compatibility(cdb.config) + fix_waf_lambda(cdb) + ensure_backward_compatibility(cdb.config, workers) # Overwrite the config with new data if config_dict is not None: @@ -855,19 +874,6 @@ def most_similar(self, return res - @staticmethod - def _ensure_backward_compatibility(config: Config) -> None: - # Hacky way of supporting old CDBs - weighted_average_function = config.linking.weighted_average_function - if callable(weighted_average_function) and getattr(weighted_average_function, "__name__", None) == "": - # the following type ignoring is for mypy because it is unable to detect the signature - config.linking.weighted_average_function = partial(weighted_average, factor=0.0004) # type: ignore - if config.general.workers is None: - config.general.workers = workers() - disabled_comps = config.general.spacy_disabled_components - if 'tagger' in disabled_comps and 'lemmatizer' not in disabled_comps: - config.general.spacy_disabled_components.append('lemmatizer') - @classmethod def _check_medcat_version(cls, config_data: Dict) -> None: cdb_medcat_version = config_data.get('version', {}).get('medcat_version', None) diff --git a/medcat/config.py b/medcat/config.py index 88c4aad14..ba8aaef91 100644 --- a/medcat/config.py +++ b/medcat/config.py @@ -1,17 +1,19 @@ from datetime import datetime from pydantic import BaseModel, Extra, ValidationError from pydantic.fields import ModelField -from typing import List, Set, Tuple, cast, Any, Callable, Dict, Optional, Union +from typing import List, Set, Tuple, cast, Any, Callable, Dict, Optional, Union, Type from multiprocessing import cpu_count import logging import jsonpickle +import json from functools import partial import re from medcat.utils.hasher import Hasher from medcat.utils.matutils import intersect_nonempty_set from medcat.utils.config_utils import attempt_fix_weighted_average_function -from medcat.utils.config_utils import weighted_average +from medcat.utils.config_utils import weighted_average, is_old_type_config_dict +from medcat.utils.saving.coding import CustomDelegatingEncoder, default_hook logger = logging.getLogger(__name__) @@ -31,6 +33,7 @@ def __getitem__(self, arg: str) -> Any: raise KeyError from e def __setattr__(self, arg: str, val) -> None: + # TODO: remove this in the future when we stop stupporting this in config if isinstance(self, Linking) and arg == "weighted_average_function": val = attempt_fix_weighted_average_function(val) super().__setattr__(arg, val) @@ -103,8 +106,8 @@ def save(self, save_path: str) -> None: save_path(str): Where to save the created json file """ # We want to save the dict here, not the whole class - json_string = jsonpickle.encode( - {field: getattr(self, field) for field in self.fields()}) + json_string = json.dumps(self.asdict(), cls=cast(Type[json.JSONEncoder], + CustomDelegatingEncoder.def_inst)) with open(save_path, 'w') as f: f.write(json_string) @@ -204,7 +207,11 @@ def load(cls, save_path: str) -> "MixingConfig": # Read the jsonpickle string with open(save_path) as f: - config_dict = jsonpickle.decode(f.read()) + config_dict = json.load(f, object_hook=default_hook) + if is_old_type_config_dict(config_dict): + logger.warning("Loading an old type of config (jsonpickle) from '%s'", + save_path) + config_dict = jsonpickle.decode(f.read()) config.merge_config(config_dict) @@ -511,9 +518,6 @@ class Linking(MixingConfig, BaseModel): similarity calculation and will have a similarity of -1.""" always_calculate_similarity: bool = False """Do we want to calculate context similarity even for concepts that are not ambigous.""" - weighted_average_function: Callable[..., Any] = _DEFAULT_PARTIAL - """Weights for a weighted average - 'weighted_average_function': partial(weighted_average, factor=0.02),""" calculate_dynamic_threshold: bool = False """Concepts below this similarity will be ignored. Type can be static/dynamic - if dynamic each CUI has a different TH and it is calcualted as the average confidence for that CUI * similarity_threshold. Take care that dynamic works only @@ -597,3 +601,39 @@ def get_hash(self): hasher.update(v2, length=True) self.hash = hasher.hexdigest() return self.hash + + +class UseOfOldConfigOptionException(AttributeError): + + def __init__(self, conf_type: Type[FakeDict], arg_name: str, advice: str) -> None: + super().__init__(f"Tried to use {conf_type.__name__}.{arg_name}. " + f"Advice: {advice}") + self.conf_type = conf_type + self.arg_name = arg_name + self.advice = advice + + +# NOTE: The following is for backwards compatibility and should be removed +# at some point in the future + +# wrapper for functions for a better error in case of weighted_average_function +# access +def _wrapper(func, check_type: Type[FakeDict], advice: str, exp_type: Type[Exception]): + def wrapper(*args, **kwargs): + try: + res = func(*args, **kwargs) + except exp_type as ex: + if ((len(args) == 2 and len(kwargs) == 0) and + (isinstance(args[0], check_type) and + args[1] == "weighted_average_function")): + raise UseOfOldConfigOptionException(Linking, args[1], advice) from ex + raise ex + return res + return wrapper + + +# wrap Linking.__getattribute__ so that when getting weighted_average_function +# we get a nicer exceptio +_waf_advice = "You can use `cat.cdb.weighted_average_function` to access it directly" +Linking.__getattribute__ = _wrapper(Linking.__getattribute__, Linking, _waf_advice, AttributeError) # type: ignore +Linking.__getitem__ = _wrapper(Linking.__getitem__, Linking, _waf_advice, KeyError) # type: ignore diff --git a/medcat/linking/vector_context_model.py b/medcat/linking/vector_context_model.py index 7c4c11a69..e4875c32f 100644 --- a/medcat/linking/vector_context_model.py +++ b/medcat/linking/vector_context_model.py @@ -71,7 +71,7 @@ def get_context_vectors(self, entity: Span, doc: Doc, cui=None) -> Dict: values = [] # Add left - values.extend([self.config.linking['weighted_average_function'](step) * self.vocab.vec(tkn.lower_) + values.extend([self.cdb.weighted_average_function(step) * self.vocab.vec(tkn.lower_) for step, tkn in enumerate(tokens_left) if tkn.lower_ in self.vocab and self.vocab.vec(tkn.lower_) is not None]) if not self.config.linking['context_ignore_center_tokens']: @@ -83,7 +83,7 @@ def get_context_vectors(self, entity: Span, doc: Doc, cui=None) -> Dict: values.extend([self.vocab.vec(tkn.lower_) for tkn in tokens_center if tkn.lower_ in self.vocab and self.vocab.vec(tkn.lower_) is not None]) # Add right - values.extend([self.config.linking['weighted_average_function'](step) * self.vocab.vec(tkn.lower_) + values.extend([self.cdb.weighted_average_function(step) * self.vocab.vec(tkn.lower_) for step, tkn in enumerate(tokens_right) if tkn.lower_ in self.vocab and self.vocab.vec(tkn.lower_) is not None]) if len(values) > 0: diff --git a/medcat/utils/config_utils.py b/medcat/utils/config_utils.py index 1aafbf3f1..92ea111ed 100644 --- a/medcat/utils/config_utils.py +++ b/medcat/utils/config_utils.py @@ -1,15 +1,64 @@ from functools import partial -from typing import Callable +from typing import Callable, Optional, Protocol import logging +from pydantic import BaseModel + + +class WAFCarrier(Protocol): + + @property + def weighted_average_function(self) -> Callable[[float], int]: + pass logger = logging.getLogger(__name__) +def is_old_type_config_dict(d: dict) -> bool: + if set(('py/object', 'py/state')) <= set(d.keys()): + return True + return False + + +def fix_waf_lambda(carrier: WAFCarrier) -> None: + weighted_average_function = carrier.weighted_average_function # type: ignore + if callable(weighted_average_function) and getattr(weighted_average_function, "__name__", None) == "": + # the following type ignoring is for mypy because it is unable to detect the signature + carrier.weighted_average_function = partial(weighted_average, factor=0.0004) # type: ignore + + +# NOTE: This method is a hacky workaround. The type ignores are because I cannot +# import config here since it would produce a circular import +def ensure_backward_compatibility(config: BaseModel, workers: Callable[[], int]) -> None: + # Hacky way of supporting old CDBs + if hasattr(config.linking, 'weighted_average_function'): # type: ignore + fix_waf_lambda(config.linking) # type: ignore + if config.general.workers is None: # type: ignore + config.general.workers = workers() # type: ignore + disabled_comps = config.general.spacy_disabled_components # type: ignore + if 'tagger' in disabled_comps and 'lemmatizer' not in disabled_comps: + config.general.spacy_disabled_components.append('lemmatizer') # type: ignore + + +def get_and_del_weighted_average_from_config(config: BaseModel) -> Optional[Callable[[int], float]]: + if not hasattr(config, 'linking'): + return None + linking = config.linking + if not hasattr(linking, 'weighted_average_function'): + return None + waf = linking.weighted_average_function + delattr(linking, 'weighted_average_function') + return waf + + def weighted_average(step: int, factor: float) -> float: return max(0.1, 1 - (step ** 2 * factor)) +def default_weighted_average(step: int) -> float: + return weighted_average(step, factor=0.0004) + + def attempt_fix_weighted_average_function(waf: Callable[[int], float] ) -> Callable[[int], float]: """Attempf fix weighted_average_function. diff --git a/medcat/utils/saving/coding.py b/medcat/utils/saving/coding.py index 81a8420aa..89f9c0651 100644 --- a/medcat/utils/saving/coding.py +++ b/medcat/utils/saving/coding.py @@ -1,6 +1,7 @@ from typing import Any, Protocol, runtime_checkable, List, Union, Type, Optional, Callable import json +import re @runtime_checkable @@ -35,6 +36,7 @@ def try_encode(self, obj: object) -> Any: SET_IDENTIFIER = '==SET==' +PATTERN_IDENTIFIER = "==PATTERN==" class SetEncoder(PartEncoder): @@ -79,10 +81,34 @@ def try_decode(self, dct: dict) -> Union[dict, set]: return dct +class PatternEncoder(PartEncoder): + + def try_encode(self, obj): + if isinstance(obj, re.Pattern): + return {PATTERN_IDENTIFIER: obj.pattern} + raise UnsuitableObject() + + +class PatternDecoder(PartDecoder): + + def try_decode(self, dct: dict) -> Union[dict, re.Pattern]: + """Decode re.Patttern from input dicts. + + Args: + dct (dict): The input dict + + Returns: + Union[dict, set]: The original dict if this was not a serialized pattern, the pattern otherwise + """ + if PATTERN_IDENTIFIER in dct: + return re.compile(dct[PATTERN_IDENTIFIER]) + return dct + + PostProcessor = Callable[[Any], None] # CDB -> None -DEFAULT_ENCODERS: List[Type[PartEncoder]] = [SetEncoder, ] -DEFAULT_DECODERS: List[Type[PartDecoder]] = [SetDecoder, ] +DEFAULT_ENCODERS: List[Type[PartEncoder]] = [SetEncoder, PatternEncoder] +DEFAULT_DECODERS: List[Type[PartDecoder]] = [SetDecoder, PatternDecoder] LOADING_POSTPROCESSORS: List[PostProcessor] = [] @@ -133,6 +159,8 @@ def object_hook(self, dct: dict) -> Any: def def_inst(cls) -> 'CustomDelegatingDecoder': if cls._def_inst is None: cls._def_inst = cls([_cls() for _cls in DEFAULT_DECODERS]) + elif len(cls._def_inst._delegates) < len(DEFAULT_DECODERS): + cls._def_inst = cls([_cls() for _cls in DEFAULT_DECODERS]) return cls._def_inst diff --git a/tests/test_cat.py b/tests/test_cat.py index ce1b62d98..deaea6ccb 100644 --- a/tests/test_cat.py +++ b/tests/test_cat.py @@ -712,7 +712,7 @@ class TestLoadingOldWeights(unittest.TestCase): @classmethod def setUpClass(cls) -> None: cls.cdb = CDB.load(cls.cdb_path) - cls.wf = cls.cdb.config.linking.weighted_average_function + cls.wf = cls.cdb.weighted_average_function def test_can_call_weights(self): res = self.wf(step=1) diff --git a/tests/test_config.py b/tests/test_config.py index ce6ed76eb..bfd440a78 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -2,6 +2,7 @@ import pickle import tempfile from medcat.config import Config, MixingConfig, VersionInfo, General, LinkingFilters +from medcat.config import UseOfOldConfigOptionException, Linking from pydantic import ValidationError import os @@ -208,6 +209,13 @@ def test_config_hash_recalc_same_changed(self): h2 = config.get_hash() self.assertEqual(h1, h2) + def test_can_save_load(self): + config = Config() + with tempfile.NamedTemporaryFile() as file: + config.save(file.name) + config2 = Config.load(file.name) + self.assertEqual(config, config2) + class ConfigLinkingFiltersTests(unittest.TestCase): @@ -228,5 +236,36 @@ def test_not_allow_empty_dict_for_cuis_exclude(self): LinkingFilters(cuis_exclude={}) +class BackwardsCompatibilityTests(unittest.TestCase): + + def setUp(self) -> None: + self.config = Config() + + def test_use_weighted_average_function_identifier_nice_error(self): + with self.assertRaises(UseOfOldConfigOptionException): + self.config.linking.weighted_average_function(0) + + def test_use_weighted_average_function_dict_nice_error(self): + with self.assertRaises(UseOfOldConfigOptionException): + self.config.linking['weighted_average_function'](0) + + +class BackwardsCompatibilityWafPayloadTests(unittest.TestCase): + arg = 'weighted_average_function' + + @classmethod + def setUpClass(cls) -> None: + cls.config = Config() + with cls.assertRaises(cls, UseOfOldConfigOptionException) as cls.context: + cls.config.linking.weighted_average_function(0) + cls.raised = cls.context.exception + + def test_exception_has_correct_conf_type(self): + self.assertIs(self.raised.conf_type, Linking) + + def test_exception_has_correct_arg(self): + self.assertEqual(self.raised.arg_name, self.arg) + + if __name__ == '__main__': unittest.main() diff --git a/tests/utils/saving/test_serialization.py b/tests/utils/saving/test_serialization.py index c2c44da16..6e883bad2 100644 --- a/tests/utils/saving/test_serialization.py +++ b/tests/utils/saving/test_serialization.py @@ -117,10 +117,6 @@ def test_round_trip(self): # The spacy model has full path in the loaded model, thus won't be equal cat.config.general.spacy_model = os.path.basename( cat.config.general.spacy_model) - # There can also be issues with loading the config.linking.weighted_average_function from file - # This should be fixed with newer models, - # but the example model is older, so has the older functionalitys - cat.config.linking.weighted_average_function = self.undertest.config.linking.weighted_average_function self.assertEqual(cat.config.asdict(), self.undertest.config.asdict()) self.assertEqual(cat.cdb.config, self.undertest.cdb.config) self.assertEqual(len(cat.vocab.vocab), len(self.undertest.vocab.vocab)) diff --git a/tests/utils/test_config_utils.py b/tests/utils/test_config_utils.py new file mode 100644 index 000000000..713d15bb0 --- /dev/null +++ b/tests/utils/test_config_utils.py @@ -0,0 +1,50 @@ +from medcat.config import Config +from medcat.utils.saving.coding import default_hook, CustomDelegatingEncoder +from medcat.utils import config_utils +import json + +import unittest + +OLD_STYLE_DICT = {'py/object': 'medcat.config.VersionInfo', + 'py/state': { + '__dict__': { + 'history': ['0c0de303b6dc0020',], + 'meta_cats': [], + 'cdb_info': { + 'Number of concepts': 785910, + 'Number of names': 2480049, + 'Number of concepts that received training': 378746, + 'Number of seen training examples in total': 1863973060, + 'Average training examples per concept': { + 'py/reduce': [{'py/function': 'numpy.core.multiarray.scalar'},] + } + }, + 'performance': {'ner': {}, 'meta': {}}, + 'description': 'No description', + 'id': 'ff4f4e00bc97de58', + 'last_modified': '26 April 2024', + 'location': None, + 'ontology': ['ONTOLOGY1'], + 'medcat_version': '1.10.2' + }, + '__fields_set__': { + 'py/set': ['id', 'ontology', 'description', 'history', + 'location', 'medcat_version', 'last_modified', + 'meta_cats', 'cdb_info', 'performance'] + }, + '__private_attribute_values__': {} + } + } + + +NEW_STYLE_DICT = json.loads(json.dumps(Config().asdict(), cls=CustomDelegatingEncoder.def_inst), + object_hook=default_hook) + + +class ConfigUtilsTests(unittest.TestCase): + + def test_identifies_old_style_dict(self): + self.assertTrue(config_utils.is_old_type_config_dict(OLD_STYLE_DICT)) + + def test_identifies_new_style_dict(self): + self.assertFalse(config_utils.is_old_type_config_dict(NEW_STYLE_DICT))