CU-8695d4www pydantic 2 #476

Open
wants to merge 46 commits into base: master
b0b3d43
CU-8695d4www: Bump pydantic requirement to 2.6+
mart-r Aug 12, 2024
cb0104f
CU-8695d4www: Update methods to use pydantic2 based ones
mart-r Aug 12, 2024
e806d54
CU-8695d4www: Update methods to use pydantic2 based ones [part 2]
mart-r Aug 12, 2024
ea7e04a
CU-8695d4www: Use identifier based config when setting last train dat…
mart-r Aug 12, 2024
3879fe5
CU-8695d4www: Use pydantic2-based model validation
mart-r Aug 12, 2024
960e405
CU-8695d4www: Add workarounds for pydantic1 methods
mart-r Aug 12, 2024
10a7a58
CU-8695d4www: Add missing utils module for pydantic1 methods
mart-r Aug 13, 2024
080ae71
Revert "CU-8695d4www: Bump pydantic requirement to 2.6+"
mart-r Aug 13, 2024
b86135a
CU-8695d4www: [TEMP] Add type-ingores to pydantic2-based methods for …
mart-r Aug 13, 2024
0eb9f76
CU-8695d4www: Make pydantic2-requires getattribute wrapper only apply…
mart-r Aug 13, 2024
0e9fe91
CU-8695d4www: Fix missin model dump getter abstraction
mart-r Aug 13, 2024
0cb31ee
CU-8695d4www: Fix missin model dump getter abstraction (in CAT)
mart-r Aug 13, 2024
a7aab98
CU-8695d4www: Update tests for pydantic 1 and 2 support
mart-r Aug 13, 2024
897df2d
Revert "CU-8695d4www: [TEMP] Add type-ingores to pydantic2-based meth…
mart-r Aug 13, 2024
1bbe88e
Reapply "CU-8695d4www: Bump pydantic requirement to 2.6+"
mart-r Aug 13, 2024
cc7c2ce
CU-8695d4www: Allow both pydantic 1 and 2
mart-r Aug 13, 2024
0ee1a8a
CU-8695d4www: Deprecated pydantic utils for removal in 1.15
mart-r Aug 13, 2024
a89e680
CU-8695d4www: Allow usage of specified deprecated method(s) during tests
mart-r Aug 13, 2024
825628e
CU-8695d4www: Allow usage of pydantic 1-2 workaround methods during t…
mart-r Aug 13, 2024
927f807
CU-8695d4www: Add documentation for argument allowing usage during te…
mart-r Aug 13, 2024
fadc7d1
CU-8695d4www: Fix allowing deprecation during test time
mart-r Aug 13, 2024
b1b11ce
CU-8695d4www: Fix model dump getting in regression checker
mart-r Aug 14, 2024
e30ca16
Revert "CU-8695d4www: Fix allowing deprecation during test time"
mart-r Aug 15, 2024
0c5b7ca
Revert "CU-8695d4www: Add documentation for argument allowing usage d…
mart-r Aug 15, 2024
6c76acc
Revert "CU-8695d4www: Allow usage of pydantic 1-2 workaround methods …
mart-r Aug 15, 2024
a4b2ea0
Revert "CU-8695d4www: Allow usage of specified deprecated method(s) d…
mart-r Aug 15, 2024
414f70a
Revert "CU-8695d4www: Deprecated pydantic utils for removal in 1.15"
mart-r Aug 15, 2024
ecc54ab
CU-8695d4www: Add comment regarding pydantic backwards compatiblity w…
mart-r Aug 21, 2024
b160295
CU-8695d4www: Add pydantic 1 check to GHA workflow
mart-r Aug 21, 2024
6c6881a
Merge branch 'master' into CU-8695d4www-pydantic-2
mart-r Aug 29, 2024
b5ddf91
Merge branch 'master' into CU-8695d4www-pydantic-2
mart-r Aug 29, 2024
23d03c7
CU-8695d4www: Fix usage of pydantic-1 based dict method in regression…
mart-r Aug 29, 2024
8777256
CU-8695d4www: Fix usage of pydantic-1 based dict method in regression…
mart-r Aug 29, 2024
44e470a
CU-8695d4www: New workflow step to install and run mypy on pydantic 1
mart-r Aug 29, 2024
9eab8f0
CU-8695d4www: Add type ignore comments to pydantic2 versions in versi…
mart-r Aug 29, 2024
3d19cd3
Merge branch 'master' into CU-8695d4www-pydantic-2
mart-r Aug 30, 2024
ebe17e0
Merge branch 'master' into CU-8695d4www-pydantic-2
mart-r Oct 15, 2024
6746b34
CU-8695d4www: Update pydantic requirement to 2.0+ only
mart-r Oct 15, 2024
b7f895e
CU-8695d4www: Update to pydantic 2 ONLY
mart-r Oct 15, 2024
3fe2c47
CU-869671bn4: Update mypy dev requirement to be less than 1.12
mart-r Oct 15, 2024
65d653f
CU-869671bn4: Fix model fields in config
mart-r Oct 15, 2024
11b1c7a
CU-869671bn4: Fix stats helper method - use correct type adapter
mart-r Oct 15, 2024
bc5458b
CU-869671bn4: Fix some model type issues
mart-r Oct 15, 2024
95d294e
CU-869671bn4: Line up with previous model dump methods
mart-r Oct 15, 2024
4e716ae
CU-869671bn4: Fix overwriting model dump methods
mart-r Oct 15, 2024
2834c27
CU-869671bn4: Remove pydantic1 workflow step
mart-r Oct 15, 2024
8 changes: 7 additions & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
@@ -31,6 +31,13 @@ jobs:
- name: Lint
run: |
flake8 medcat
- name: Pydantic 1 check
# NOTE: the following will look for uses of the pydantic1-specific .dict() method and .__fields__ attribute.
# If any are found (other than those annotated for pydantic1 backwards compatibility), a non-zero exit
# code is returned, which will halt the workflow and print out the offending parts
run: |
grep "\.__fields__" medcat -rI | grep -v "# 4pydantic1 - backwards compatibility" | tee /dev/stderr | test $(wc -l) -eq 0
grep "\.dict(" medcat -rI | grep -v "# 4pydantic1 - backwards compatibility" | tee /dev/stderr | test $(wc -l) -eq 0
- name: Test
run: |
all_files=$(git ls-files | grep '^tests/.*\.py$' | grep -v '/__init__\.py$' | sed 's/\.py$//' | sed 's/\//./g')
@@ -52,7 +59,6 @@ jobs:
repo: context.repo.repo
});
core.setOutput('latest_version', latestRelease.data.tag_name);

- name: Make sure there's no deprecated methods that should be removed.
# only run this for master -> production PR. I.e just before doing a release.
if: github.event.pull_request.base.ref == 'main' && github.event.pull_request.head.ref == 'production'
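The new workflow step above is a plain grep pipeline. The same check can be sketched in Python — the function name is my own, and this sketch only scans `*.py` files, whereas the grep version (`-rI`) walks every non-binary file:

```python
import re
from pathlib import Path

# Patterns that indicate pydantic-1-specific API usage.
PYDANTIC1_PATTERNS = (re.compile(r"\.__fields__"), re.compile(r"\.dict\("))
# Lines carrying this marker are intentionally kept for backwards compatibility.
ALLOW_MARKER = "# 4pydantic1 - backwards compatibility"


def find_pydantic1_usage(root: str) -> list:
    """Return (file, line_no, line) tuples for unannotated pydantic-1 usage."""
    offending = []
    for path in Path(root).rglob("*.py"):
        for no, line in enumerate(path.read_text().splitlines(), start=1):
            if ALLOW_MARKER in line:
                continue  # explicitly annotated, allowed
            if any(p.search(line) for p in PYDANTIC1_PATTERNS):
                offending.append((str(path), no, line.strip()))
    return offending
```

Like the workflow step, a CI wrapper would exit non-zero when the returned list is non-empty.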
2 changes: 1 addition & 1 deletion install_requires.txt
@@ -19,6 +19,6 @@
'xxhash>=3.0.0' # allow later versions, tested with 3.1.0
'blis>=0.7.5,<1.0.0' # allow later versions, tested with 0.7.9, avoid 1.0.0 (depends on numpy 2)
'click>=8.0.4' # allow later versions, tested with 8.1.3
'pydantic>=1.10.0,<2.0' # for spacy compatibility; avoid 2.0 due to breaking changes
'pydantic>=2.0.0,<3.0' # avoid next major release
"humanfriendly~=10.0" # for human readable file / RAM sizes
"peft>=0.8.2"
2 changes: 1 addition & 1 deletion medcat/cat.py
@@ -590,7 +590,7 @@ def _print_stats(self,

def _init_ckpts(self, is_resumed, checkpoint):
if self.config.general.checkpoint.steps is not None or checkpoint is not None:
checkpoint_config = CheckpointConfig(**self.config.general.checkpoint.dict())
checkpoint_config = CheckpointConfig(**self.config.general.checkpoint.model_dump())
checkpoint_manager = CheckpointManager('cat_train', checkpoint_config)
if is_resumed:
# TODO: probably remove is_resumed mark and always resume if a checkpoint is provided,
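The change above is one instance of the mechanical renames applied throughout this PR: pydantic 1's `.dict()` becomes `.model_dump()`, `.__fields__` becomes `.model_fields`, and the inner `class Config` becomes `model_config`. A minimal sketch, assuming pydantic 2 is installed (`CheckpointLike` is an illustrative stand-in, not the real `CheckpointConfig`):

```python
from typing import Optional
from pydantic import BaseModel, ConfigDict


class CheckpointLike(BaseModel):
    # pydantic 2 replaces the inner `class Config` with model_config
    model_config = ConfigDict(extra='allow', validate_assignment=True)

    steps: Optional[int] = None
    output_dir: str = 'checkpoints'


cp = CheckpointLike(steps=100)
# pydantic 1 -> 2 renames used in this PR:
#   .dict()      -> .model_dump()
#   .__fields__  -> .model_fields
dumped = cp.model_dump()                        # plain dict of field values
field_names = set(CheckpointLike.model_fields)  # fields now live on the class
```

`CheckpointConfig(**checkpoint.model_dump())` in the hunk above is exactly this pattern: dump one model and splat it into another.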
41 changes: 21 additions & 20 deletions medcat/config.py
@@ -1,6 +1,5 @@
from datetime import datetime
from pydantic import BaseModel, Extra, ValidationError
from pydantic.fields import ModelField
from pydantic import BaseModel, ValidationError
from typing import List, Set, Tuple, cast, Any, Callable, Dict, Optional, Union, Type, Literal
from multiprocessing import cpu_count
import logging
@@ -125,7 +124,7 @@ def merge_config(self, config_dict: Dict) -> None:
attr = None # new attribute
value = config_dict[key]
if isinstance(value, BaseModel):
value = value.dict()
value = value.model_dump()
if isinstance(attr, MixingConfig):
attr.merge_config(value)
else:
@@ -177,7 +176,7 @@ def rebuild_re(self) -> None:
def _calc_hash(self, hasher: Optional[Hasher] = None) -> Hasher:
if hasher is None:
hasher = Hasher()
for _, v in cast(BaseModel, self).dict().items():
for _, v in cast(BaseModel, self).model_dump().items():
if isinstance(v, MixingConfig):
v._calc_hash(hasher)
else:
@@ -189,7 +188,7 @@ def get_hash(self, hasher: Optional[Hasher] = None):
return hasher.hexdigest()

def __str__(self) -> str:
return str(cast(BaseModel, self).dict())
return str(cast(BaseModel, self).model_dump())

@classmethod
def load(cls, save_path: str) -> "MixingConfig":
@@ -238,15 +237,15 @@ def asdict(self) -> Dict[str, Any]:
Returns:
Dict[str, Any]: The dictionary associated with this config
"""
return cast(BaseModel, self).dict()
return cast(BaseModel, self).model_dump()

def fields(self) -> Dict[str, ModelField]:
def fields(self) -> dict:
"""Get the fields associated with this config.

Returns:
Dict[str, ModelField]: The dictionary of the field names and fields
dict: The dictionary of the field names and fields
"""
return cast(BaseModel, self).__fields__
return cast(BaseModel, self).model_fields


class VersionInfo(MixingConfig, BaseModel):
@@ -272,7 +271,7 @@ class VersionInfo(MixingConfig, BaseModel):
"""Which version of medcat was used to build the CDB"""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


@@ -290,7 +289,7 @@ class CDBMaker(MixingConfig, BaseModel):
"""Minimum number of letters required in a name to be accepted for a concept"""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


@@ -303,7 +302,7 @@ class AnnotationOutput(MixingConfig, BaseModel):
include_text_in_output: bool = False

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


@@ -317,7 +316,7 @@ class CheckPoint(MixingConfig, BaseModel):
"""When training the maximum checkpoints will be kept on the disk"""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


@@ -354,7 +353,7 @@ class General(MixingConfig, BaseModel):

NB! For these changes to take effect, the pipe would need to be recreated."""
checkpoint: CheckPoint = CheckPoint()
usage_monitor = UsageMonitor()
usage_monitor: UsageMonitor = UsageMonitor()
"""Checkpointing config"""
log_level: int = logging.INFO
"""Logging config for everything | 'tagger' can be disabled, but will cause a drop in performance"""
@@ -395,7 +394,7 @@ class General(MixingConfig, BaseModel):
reliable due to not taking into account all the details of the changes."""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


@@ -424,7 +423,7 @@ class Preprocessing(MixingConfig, BaseModel):
NB! For these changes to take effect, the pipe would need to be recreated."""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


@@ -444,7 +443,7 @@ class Ner(MixingConfig, BaseModel):
"""Try reverse word order for short concepts (2 words max), e.g. heart disease -> disease heart"""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


@@ -579,7 +578,7 @@ class Linking(MixingConfig, BaseModel):
"""If true when the context of a concept is calculated (embedding) the words making that concept are not taken into account"""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


@@ -600,7 +599,7 @@ class Config:
# this if for word_skipper and punct_checker which would otherwise
# not have a validator
arbitrary_types_allowed = True
extra = Extra.allow
extra = 'allow'
validate_assignment = True

def __init__(self, *args, **kwargs):
@@ -618,7 +617,7 @@ def rebuild_re(self) -> None:
# Override
def get_hash(self):
hasher = Hasher()
for k, v in self.dict().items():
for k, v in self.model_dump().items():
if k in ['hash', ]:
# ignore hash
continue
@@ -674,4 +673,6 @@ def wrapper(*args, **kwargs):
# we get a nicer exception
_waf_advice = "You can use `cat.cdb.weighted_average_function` to access it directly"
Linking.__getattribute__ = _wrapper(Linking.__getattribute__, Linking, _waf_advice, AttributeError) # type: ignore
if hasattr(Linking, '__getattr__'):
Linking.__getattr__ = _wrapper(Linking.__getattr__, Linking, _waf_advice, AttributeError) # type: ignore
Linking.__getitem__ = _wrapper(Linking.__getitem__, Linking, _waf_advice, KeyError) # type: ignore
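The hunk above wraps `Linking`'s attribute-access dunders so a removed attribute raises a helpful error, guarding the `__getattr__` wrap with `hasattr` because pydantic only defines `__getattr__` in some configurations. A standalone sketch of the same pattern (all names here are illustrative, not MedCAT's actual `_wrapper`):

```python
def _wrap_access(orig_access, attr_name, advice, exc_type):
    """Intercept access to a removed attribute and raise a helpful error."""
    def wrapper(self, name):
        if name == attr_name:
            raise exc_type(f"'{attr_name}' was removed. {advice}")
        return orig_access(self, name)
    return wrapper


class LinkingLike:
    filters: str = 'default'
    filters = 'default'


LinkingLike.__getattribute__ = _wrap_access(  # type: ignore
    LinkingLike.__getattribute__,
    'weighted_average_function',
    'You can use `cat.cdb.weighted_average_function` to access it directly',
    AttributeError,
)
# Only wrap __getattr__ when the class actually defines one, mirroring the
# hasattr guard in the diff above:
if hasattr(LinkingLike, '__getattr__'):
    LinkingLike.__getattr__ = _wrap_access(  # type: ignore
        LinkingLike.__getattr__, 'weighted_average_function', '...', AttributeError)
```

Ordinary attributes still resolve through the original `__getattribute__`, while the removed name fails fast with migration advice instead of a bare `AttributeError`.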
12 changes: 6 additions & 6 deletions medcat/config_meta_cat.py
@@ -1,5 +1,5 @@
from typing import Dict, Any
from medcat.config import MixingConfig, BaseModel, Optional, Extra
from medcat.config import MixingConfig, BaseModel, Optional


class General(MixingConfig, BaseModel):
@@ -65,7 +65,7 @@ class General(MixingConfig, BaseModel):
Otherwise defaults to doc._.ents or doc.ents per the annotate_overlapping settings"""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


@@ -169,7 +169,7 @@ class Model(MixingConfig, BaseModel):
"""If set to True center positions will be ignored when calculating representation"""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


@@ -191,7 +191,7 @@ class Train(MixingConfig, BaseModel):
"""If set only this CUIs will be used for training"""
auto_save_model: bool = True
"""Should do model be saved during training for best results"""
last_train_on: Optional[int] = None
last_train_on: Optional[float] = None
"""When was the last training run"""
metric: Dict[str, str] = {'base': 'weighted avg', 'score': 'f1-score'}
"""What metric should be used for choosing the best model"""
@@ -206,7 +206,7 @@ class Train(MixingConfig, BaseModel):
"""Focal Loss hyperparameter - determines importance the loss gives to hard-to-classify examples"""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


@@ -217,5 +217,5 @@ class ConfigMetaCAT(MixingConfig, BaseModel):
train: Train = Train()

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True
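The `last_train_on` annotation change from `Optional[int]` to `Optional[float]` above matches what the training code actually stores. A short illustration (the pydantic point in the comment is my reading of v2's stricter validation, not something asserted by the diff):

```python
from datetime import datetime

# datetime.timestamp() returns seconds since the epoch as a float with
# sub-second precision, so Optional[float] is the accurate annotation.
# Under pydantic 2's stricter validation, assigning such a non-integral
# float to an Optional[int] field would typically be rejected rather than
# silently coerced as pydantic 1 did.
ts = datetime.now().timestamp()
```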
8 changes: 4 additions & 4 deletions medcat/config_rel_cat.py
@@ -1,6 +1,6 @@
import logging
from typing import Dict, Any, List
from medcat.config import MixingConfig, BaseModel, Optional, Extra
from medcat.config import MixingConfig, BaseModel, Optional


class General(MixingConfig, BaseModel):
@@ -89,7 +89,7 @@ class Model(MixingConfig, BaseModel):
"""If set to True center positions will be ignored when calculating representation"""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


@@ -116,7 +116,7 @@ class Train(MixingConfig, BaseModel):
"""Should the model be saved during training for best results"""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


@@ -127,5 +127,5 @@ class ConfigRelCAT(MixingConfig, BaseModel):
train: Train = Train()

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True
8 changes: 4 additions & 4 deletions medcat/config_transformers_ner.py
@@ -1,4 +1,4 @@
from medcat.config import MixingConfig, BaseModel, Optional, Extra
from medcat.config import MixingConfig, BaseModel, Optional


class General(MixingConfig, BaseModel):
@@ -16,11 +16,11 @@ class General(MixingConfig, BaseModel):
chunking_overlap_window: Optional[int] = 5
"""Size of the overlap window used for chunking"""
test_size: float = 0.2
last_train_on: Optional[int] = None
last_train_on: Optional[float] = None
verbose_metrics: bool = False

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


@@ -29,5 +29,5 @@ class ConfigTransformersNER(MixingConfig, BaseModel):
general: General = General()

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True
6 changes: 3 additions & 3 deletions medcat/meta_cat.py
@@ -114,8 +114,8 @@ def get_hash(self) -> str:
"""
hasher = Hasher()
# Set last_train_on if None
if self.config.train['last_train_on'] is None:
self.config.train['last_train_on'] = datetime.now().timestamp()
if self.config.train.last_train_on is None:
self.config.train.last_train_on = datetime.now().timestamp()

hasher.update(self.config.get_hash())
return hasher.hexdigest()
@@ -310,7 +310,7 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data
# Save everything now
self.save(save_dir_path=save_dir_path)

self.config.train['last_train_on'] = datetime.now().timestamp()
self.config.train.last_train_on = datetime.now().timestamp()
return report

def eval(self, json_path: str) -> Dict:
6 changes: 3 additions & 3 deletions medcat/ner/transformers_ner.py
@@ -103,8 +103,8 @@ def get_hash(self) -> str:
"""
hasher = Hasher()
# Set last_train_on if None
if self.config.general['last_train_on'] is None:
self.config.general['last_train_on'] = datetime.now().timestamp()
if self.config.general.last_train_on is None:
self.config.general.last_train_on = datetime.now().timestamp()

hasher.update(self.config.get_hash())
return hasher.hexdigest()
@@ -242,7 +242,7 @@ def train(self,
trainer.train() # type: ignore

# Save the training time
self.config.general['last_train_on'] = datetime.now().timestamp() # type: ignore
self.config.general.last_train_on = datetime.now().timestamp() # type: ignore

# Save everything
self.save(save_dir_path=os.path.join(self.training_arguments.output_dir, 'final_model'))
2 changes: 1 addition & 1 deletion medcat/utils/regression/checking.py
@@ -411,7 +411,7 @@ def to_dict(self) -> dict:
d = {}
for case in self.cases:
d[case.name] = case.to_dict()
d['meta'] = self.metadata.dict()
d['meta'] = self.metadata.model_dump()
fix_np_float64(d['meta'])

return d
4 changes: 2 additions & 2 deletions medcat/utils/regression/regression_checker.py
@@ -118,8 +118,8 @@ def main(model_pack_dir: Path, test_suite_file: Path,
examples_strictness = Strictness[examples_strictness_str]
if jsonpath:
logger.info('Writing to %s', str(jsonpath))
jsonpath.write_text(json.dumps(res.dict(strictness=examples_strictness),
indent=jsonindent))
dumped = res.model_dump(strictness=examples_strictness)
jsonpath.write_text(json.dumps(dumped, indent=jsonindent))
else:
logger.info(res.get_report(phrases_separately=phrases,
hide_empty=hide_empty, examples_strictness=examples_strictness,