Skip to content

Commit

Permalink
CU-8694fk90t (almost) only primitive config (CogStack#425)
Browse files Browse the repository at this point in the history
* CU-8694fk90r: Move backwards compatibility method from CDB to config utils

* CU-8694fk90r: Move weighted_average_function from config to CDB; create necessary backwards compatibility workarounds

* CU-8694fk90r: Move usage of weighted_average_function in tests

* CU-8694fk90r: Add JSON encode and decoder for re.Pattern

* CU-8694fk90r: Rebuild custom decoder if needed

* CU-8694fk90r: Add method to detect old style config

* CU-8694fk90r: Use regular json serialisation for config; Retain option to read old jsonpickled config

* CU-8694fk90r: Add test for config serialisation

* CU-8694fk90r: Make sure to fix weighted_average_function upon setting it

* CU-8694fk90t: Add missing tests for config utils

* CU-8694fk90t: Add tests for better raised exception upon old way of using weighted_average_function

* CU-8694fk90t: Fix exception type in an added test

* CU-8694fk90t: Add further tests for exception payload

* CU-8694fk90t: Add improved exceptions when using old/unsupported value of weighted_average_function in config

* CU-8694fk90t: Add typing fixes for exceptions

* CU-8694fk90t: Make custom exception derive from AttributeError to correctly handle hasattr calls
  • Loading branch information
mart-r authored May 22, 2024
1 parent e46dca8 commit 2872d5e
Show file tree
Hide file tree
Showing 9 changed files with 242 additions and 34 deletions.
38 changes: 22 additions & 16 deletions medcat/cdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,19 @@
import aiofiles
import numpy as np
from typing import Dict, Set, Optional, List, Union, cast
from functools import partial
import os

from medcat import __version__
from medcat.utils.hasher import Hasher
from medcat.utils.matutils import unitvec
from medcat.utils.ml_utils import get_lr_linking
from medcat.utils.decorators import deprecated
from medcat.config import Config, weighted_average, workers
from medcat.config import Config, workers
from medcat.utils.saving.serializer import CDBSerializer
from medcat.utils.config_utils import get_and_del_weighted_average_from_config
from medcat.utils.config_utils import default_weighted_average
from medcat.utils.config_utils import ensure_backward_compatibility
from medcat.utils.config_utils import fix_waf_lambda, attempt_fix_weighted_average_function


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -98,6 +101,7 @@ def __init__(self, config: Union[Config, None] = None) -> None:
self.vocab: Dict = {} # Vocabulary of all words ever in our cdb
self._optim_params = None
self.is_dirty = False
self._init_waf_from_config()
self._hash: Optional[str] = None
# the config hash is kept track of here so that
# the CDB hash can be re-calculated when the config changes
Expand All @@ -107,6 +111,18 @@ def __init__(self, config: Union[Config, None] = None) -> None:
self._config_hash: Optional[str] = None
self._memory_optimised_parts: Set[str] = set()

def _init_waf_from_config(self):
waf = get_and_del_weighted_average_from_config(self.config)
if waf is not None:
logger.info("Using (potentially) custom value of weighed "
"average function")
self.weighted_average_function = attempt_fix_weighted_average_function(waf)
elif hasattr(self, 'weighted_average_function'):
# keep existing
pass
else:
self.weighted_average_function = default_weighted_average

def get_name(self, cui: str) -> str:
"""Returns preferred name if it exists, otherwise it will return
the longest name assigned to the concept.
Expand Down Expand Up @@ -558,6 +574,8 @@ def load_config(self, config_path: str) -> None:
# this should be the behaviour for all newer models
self.config = cast(Config, Config.load(config_path))
logger.debug("Loaded config from CDB from %s", config_path)
# new config, potentially new weighted_average_function to read
self._init_waf_from_config()
# mark config read from file
self._config_from_file = True

Expand All @@ -582,7 +600,8 @@ def load(cls, path: str, json_path: Optional[str] = None, config_dict: Optional[
ser = CDBSerializer(path, json_path)
cdb = ser.deserialize(CDB)
cls._check_medcat_version(cdb.config.asdict())
cls._ensure_backward_compatibility(cdb.config)
fix_waf_lambda(cdb)
ensure_backward_compatibility(cdb.config, workers)

# Overwrite the config with new data
if config_dict is not None:
Expand Down Expand Up @@ -855,19 +874,6 @@ def most_similar(self,

return res

@staticmethod
def _ensure_backward_compatibility(config: Config) -> None:
# Hacky way of supporting old CDBs
weighted_average_function = config.linking.weighted_average_function
if callable(weighted_average_function) and getattr(weighted_average_function, "__name__", None) == "<lambda>":
# the following type ignoring is for mypy because it is unable to detect the signature
config.linking.weighted_average_function = partial(weighted_average, factor=0.0004) # type: ignore
if config.general.workers is None:
config.general.workers = workers()
disabled_comps = config.general.spacy_disabled_components
if 'tagger' in disabled_comps and 'lemmatizer' not in disabled_comps:
config.general.spacy_disabled_components.append('lemmatizer')

@classmethod
def _check_medcat_version(cls, config_data: Dict) -> None:
cdb_medcat_version = config_data.get('version', {}).get('medcat_version', None)
Expand Down
56 changes: 48 additions & 8 deletions medcat/config.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from datetime import datetime
from pydantic import BaseModel, Extra, ValidationError
from pydantic.fields import ModelField
from typing import List, Set, Tuple, cast, Any, Callable, Dict, Optional, Union
from typing import List, Set, Tuple, cast, Any, Callable, Dict, Optional, Union, Type
from multiprocessing import cpu_count
import logging
import jsonpickle
import json
from functools import partial
import re

from medcat.utils.hasher import Hasher
from medcat.utils.matutils import intersect_nonempty_set
from medcat.utils.config_utils import attempt_fix_weighted_average_function
from medcat.utils.config_utils import weighted_average
from medcat.utils.config_utils import weighted_average, is_old_type_config_dict
from medcat.utils.saving.coding import CustomDelegatingEncoder, default_hook


logger = logging.getLogger(__name__)
Expand All @@ -31,6 +33,7 @@ def __getitem__(self, arg: str) -> Any:
raise KeyError from e

    def __setattr__(self, arg: str, val) -> None:
        # TODO: remove this in the future when we stop supporting this in config
        # Backwards compatibility: a weighted_average_function assigned to a
        # Linking section may be an unpicklable/old-style callable, so coerce
        # it to a safe equivalent before storing.
        if isinstance(self, Linking) and arg == "weighted_average_function":
            val = attempt_fix_weighted_average_function(val)
        super().__setattr__(arg, val)
Expand Down Expand Up @@ -103,8 +106,8 @@ def save(self, save_path: str) -> None:
save_path(str): Where to save the created json file
"""
# We want to save the dict here, not the whole class
json_string = jsonpickle.encode(
{field: getattr(self, field) for field in self.fields()})
json_string = json.dumps(self.asdict(), cls=cast(Type[json.JSONEncoder],
CustomDelegatingEncoder.def_inst))

with open(save_path, 'w') as f:
f.write(json_string)
Expand Down Expand Up @@ -204,7 +207,11 @@ def load(cls, save_path: str) -> "MixingConfig":

# Read the jsonpickle string
with open(save_path) as f:
config_dict = jsonpickle.decode(f.read())
config_dict = json.load(f, object_hook=default_hook)
if is_old_type_config_dict(config_dict):
logger.warning("Loading an old type of config (jsonpickle) from '%s'",
save_path)
config_dict = jsonpickle.decode(f.read())

config.merge_config(config_dict)

Expand Down Expand Up @@ -511,9 +518,6 @@ class Linking(MixingConfig, BaseModel):
similarity calculation and will have a similarity of -1."""
always_calculate_similarity: bool = False
"""Do we want to calculate context similarity even for concepts that are not ambigous."""
weighted_average_function: Callable[..., Any] = _DEFAULT_PARTIAL
"""Weights for a weighted average
'weighted_average_function': partial(weighted_average, factor=0.02),"""
calculate_dynamic_threshold: bool = False
"""Concepts below this similarity will be ignored. Type can be static/dynamic - if dynamic each CUI has a different TH
and it is calcualted as the average confidence for that CUI * similarity_threshold. Take care that dynamic works only
Expand Down Expand Up @@ -597,3 +601,39 @@ def get_hash(self):
hasher.update(v2, length=True)
self.hash = hasher.hexdigest()
return self.hash


class UseOfOldConfigOptionException(AttributeError):
    """Raised when an option that has been moved off the config is accessed.

    Derives from AttributeError so that `hasattr` checks on config sections
    keep behaving as expected.
    """

    def __init__(self, conf_type: Type[FakeDict], arg_name: str, advice: str) -> None:
        # keep the payload around so callers/tests can inspect it
        self.conf_type = conf_type
        self.arg_name = arg_name
        self.advice = advice
        message = (f"Tried to use {conf_type.__name__}.{arg_name}. "
                   f"Advice: {advice}")
        super().__init__(message)


# NOTE: The following is for backwards compatibility and should be removed
# at some point in the future

# wrapper for functions for a better error in case of weighted_average_function
# access
def _wrapper(func, check_type: Type[FakeDict], advice: str, exp_type: Type[Exception]):
    """Wrap an accessor so failures on `weighted_average_function` raise a helpful error.

    Args:
        func: The accessor to wrap (e.g. `__getattribute__` or `__getitem__`).
        check_type: The config-section type the wrapper is installed on.
        advice: Advice text embedded in the raised exception.
        exp_type: The exception type the accessor raises on a missing key/attribute.

    Returns:
        The wrapped accessor.
    """
    def wrapper(*args, **kwargs):
        try:
            res = func(*args, **kwargs)
        except exp_type as ex:
            if ((len(args) == 2 and len(kwargs) == 0) and
                    (isinstance(args[0], check_type) and
                     args[1] == "weighted_average_function")):
                # use the passed-in type rather than hard-coding Linking so the
                # wrapper stays correct for any section it is applied to
                raise UseOfOldConfigOptionException(check_type, args[1], advice) from ex
            # bare raise preserves the original traceback
            raise
        return res
    return wrapper


# wrap Linking.__getattribute__ so that when getting weighted_average_function
# we get a nicer exception
_waf_advice = "You can use `cat.cdb.weighted_average_function` to access it directly"
Linking.__getattribute__ = _wrapper(Linking.__getattribute__, Linking, _waf_advice, AttributeError) # type: ignore
Linking.__getitem__ = _wrapper(Linking.__getitem__, Linking, _waf_advice, KeyError) # type: ignore
4 changes: 2 additions & 2 deletions medcat/linking/vector_context_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def get_context_vectors(self, entity: Span, doc: Doc, cui=None) -> Dict:

values = []
# Add left
values.extend([self.config.linking['weighted_average_function'](step) * self.vocab.vec(tkn.lower_)
values.extend([self.cdb.weighted_average_function(step) * self.vocab.vec(tkn.lower_)
for step, tkn in enumerate(tokens_left) if tkn.lower_ in self.vocab and self.vocab.vec(tkn.lower_) is not None])

if not self.config.linking['context_ignore_center_tokens']:
Expand All @@ -83,7 +83,7 @@ def get_context_vectors(self, entity: Span, doc: Doc, cui=None) -> Dict:
values.extend([self.vocab.vec(tkn.lower_) for tkn in tokens_center if tkn.lower_ in self.vocab and self.vocab.vec(tkn.lower_) is not None])

# Add right
values.extend([self.config.linking['weighted_average_function'](step) * self.vocab.vec(tkn.lower_)
values.extend([self.cdb.weighted_average_function(step) * self.vocab.vec(tkn.lower_)
for step, tkn in enumerate(tokens_right) if tkn.lower_ in self.vocab and self.vocab.vec(tkn.lower_) is not None])

if len(values) > 0:
Expand Down
51 changes: 50 additions & 1 deletion medcat/utils/config_utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,64 @@
from functools import partial
from typing import Callable
from typing import Callable, Optional, Protocol
import logging
from pydantic import BaseModel


class WAFCarrier(Protocol):
    """Structural type for objects that carry a weighted_average_function."""

    @property
    def weighted_average_function(self) -> Callable[[int], float]:
        # NOTE: was annotated Callable[[float], int] with the argument and
        # return types swapped; everywhere else in this module the function
        # maps an int step to a float weight.
        pass


logger = logging.getLogger(__name__)


def is_old_type_config_dict(d: dict) -> bool:
    """Check whether a loaded config dict looks like an old (jsonpickle) dump.

    jsonpickle output carries 'py/object' and 'py/state' keys at the top level.

    Args:
        d (dict): The loaded config dict.

    Returns:
        bool: True if the dict appears to be jsonpickle-serialised.
    """
    # dict.keys() is set-like, so subset comparison works directly
    return {'py/object', 'py/state'} <= d.keys()


def fix_waf_lambda(carrier: WAFCarrier) -> None:
    """Replace a lambda weighted_average_function with a picklable partial.

    Old models stored a lambda, which does not (de)serialise reliably; swap
    it in place for the equivalent `partial`.
    """
    waf = carrier.weighted_average_function  # type: ignore
    is_lambda = callable(waf) and getattr(waf, "__name__", None) == "<lambda>"
    if is_lambda:
        # the type ignore is for mypy, which cannot infer the partial's signature
        carrier.weighted_average_function = partial(weighted_average, factor=0.0004)  # type: ignore


# NOTE: This method is a hacky workaround. The type ignores are because I cannot
# import config here since it would produce a circular import
def ensure_backward_compatibility(config: BaseModel, workers: Callable[[], int]) -> None:
    """Patch an old-style config in place so it works with current code.

    The config is typed as a plain BaseModel (hence the type ignores) because
    importing the real Config here would create a circular import.

    Args:
        config: The (potentially old) config to patch.
        workers: Callable producing the default worker count.
    """
    linking = config.linking  # type: ignore
    if hasattr(linking, 'weighted_average_function'):
        fix_waf_lambda(linking)
    general = config.general  # type: ignore
    if general.workers is None:
        general.workers = workers()
    disabled = general.spacy_disabled_components
    if 'tagger' in disabled and 'lemmatizer' not in disabled:
        disabled.append('lemmatizer')


def get_and_del_weighted_average_from_config(config: BaseModel) -> Optional[Callable[[int], float]]:
    """Pop weighted_average_function off config.linking, if present.

    Returns:
        The removed function, or None when the config has no linking section
        or the linking section no longer carries the function (new-style config).
    """
    # EAFP: any missing attribute along the way means a new-style config
    try:
        linking = config.linking
        waf = linking.weighted_average_function
    except AttributeError:
        return None
    delattr(linking, 'weighted_average_function')
    return waf


def weighted_average(step: int, factor: float) -> float:
    """Quadratically decaying weight for a token `step` positions away, floored at 0.1."""
    weight = 1 - factor * step ** 2
    return max(0.1, weight)


def default_weighted_average(step: int) -> float:
    """weighted_average with the historical default factor of 0.0004."""
    return weighted_average(step, 0.0004)


def attempt_fix_weighted_average_function(waf: Callable[[int], float]
) -> Callable[[int], float]:
"""Attempf fix weighted_average_function.
Expand Down
32 changes: 30 additions & 2 deletions medcat/utils/saving/coding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Protocol, runtime_checkable, List, Union, Type, Optional, Callable

import json
import re


@runtime_checkable
Expand Down Expand Up @@ -35,6 +36,7 @@ def try_encode(self, obj: object) -> Any:


SET_IDENTIFIER = '==SET=='
PATTERN_IDENTIFIER = "==PATTERN=="


class SetEncoder(PartEncoder):
Expand Down Expand Up @@ -79,10 +81,34 @@ def try_decode(self, dct: dict) -> Union[dict, set]:
return dct


class PatternEncoder(PartEncoder):
    """Encoder part that serialises compiled regex patterns to JSON."""

    def try_encode(self, obj):
        # NOTE(review): only the pattern string is saved — compile-time flags
        # (e.g. re.IGNORECASE) are dropped on a round trip; confirm no stored
        # pattern relies on flags.
        if isinstance(obj, re.Pattern):
            return {PATTERN_IDENTIFIER: obj.pattern}
        raise UnsuitableObject()


class PatternDecoder(PartDecoder):
    """Decoder part that restores regex patterns serialised by PatternEncoder."""

    def try_decode(self, dct: dict) -> Union[dict, re.Pattern]:
        """Decode re.Pattern from input dicts.
        Args:
            dct (dict): The input dict
        Returns:
            Union[dict, re.Pattern]: The original dict if this was not a serialized pattern, the pattern otherwise
        """
        if PATTERN_IDENTIFIER in dct:
            return re.compile(dct[PATTERN_IDENTIFIER])
        return dct


PostProcessor = Callable[[Any], None] # CDB -> None

DEFAULT_ENCODERS: List[Type[PartEncoder]] = [SetEncoder, ]
DEFAULT_DECODERS: List[Type[PartDecoder]] = [SetDecoder, ]
DEFAULT_ENCODERS: List[Type[PartEncoder]] = [SetEncoder, PatternEncoder]
DEFAULT_DECODERS: List[Type[PartDecoder]] = [SetDecoder, PatternDecoder]
LOADING_POSTPROCESSORS: List[PostProcessor] = []


Expand Down Expand Up @@ -133,6 +159,8 @@ def object_hook(self, dct: dict) -> Any:
def def_inst(cls) -> 'CustomDelegatingDecoder':
if cls._def_inst is None:
cls._def_inst = cls([_cls() for _cls in DEFAULT_DECODERS])
elif len(cls._def_inst._delegates) < len(DEFAULT_DECODERS):
cls._def_inst = cls([_cls() for _cls in DEFAULT_DECODERS])
return cls._def_inst


Expand Down
2 changes: 1 addition & 1 deletion tests/test_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@ class TestLoadingOldWeights(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
cls.cdb = CDB.load(cls.cdb_path)
cls.wf = cls.cdb.config.linking.weighted_average_function
cls.wf = cls.cdb.weighted_average_function

def test_can_call_weights(self):
res = self.wf(step=1)
Expand Down
Loading

0 comments on commit 2872d5e

Please sign in to comment.