
Cu 8694gzbn3 k fold metrics debug #31

Closed
wants to merge 78 commits into from
35a66f3
CU-8694gzbud: Add context manager that is able to snapshot CDB state
mart-r May 8, 2024
df8459b
CU-8694gzbud: Add tests to snapshotting CDB state
mart-r May 8, 2024
20a403e
CU-8694gzbud: Refactor tests for CDB state snapshotting
mart-r May 8, 2024
557602b
CU-8694gzbud: Remove use of deprecated method in CDB utils and use no…
mart-r May 8, 2024
16e875b
CU-8694gzbud: Add tests for training and CDB state capturing
mart-r May 8, 2024
0294fea
CU-8694gzbud: Small refactor in tests
mart-r May 8, 2024
cc2bcf5
CU-8694gzbud: Add option to save state on disk
mart-r May 8, 2024
8b19609
CU-8694gzbud: Add debug logging output when saving state on disk
mart-r May 8, 2024
3017067
CU-8694gzbud: Remove unused import
mart-r May 8, 2024
de74d0c
CU-8694gzbud: Add tests for disk-based state save
mart-r May 9, 2024
4a269ab
CU-8694gzbud: Move CDB state code to its own module
mart-r May 9, 2024
bfcff96
CU-8694gzbud: Remove unused import
mart-r May 9, 2024
c0ba2b9
CU-8694gzbud: Add doc strings to methods
mart-r May 9, 2024
da585f5
CU-8694gzbx4: Small optimisation for stats
mart-r May 9, 2024
50ae4d7
CU-8694gzbx4: Add MCTExport related module
mart-r May 10, 2024
b51117d
CU-8694gzbx4: Add MCTExport related tests
mart-r May 10, 2024
c9944d3
CU-8694gzbx4: Add code for k-fold statistics
mart-r May 10, 2024
bad5936
CU-8694gzbx4: Add tests for k-fold statistics
mart-r May 10, 2024
988905c
CU-8694gzbx4: Add test-MCT export with fake concepts
mart-r May 10, 2024
b8c4ff4
CU-8694gzbx4: Fix a doc string
mart-r May 10, 2024
bfdb4ca
CU-8694gzbx4: Fix types in MCT export module
mart-r May 10, 2024
df2bde7
CU-8694gzbx4: Fix types in k-fold module
mart-r May 10, 2024
08a2d0d
CU-8694gzbx4: Remove accidentally committed test class
mart-r May 10, 2024
05eb171
CU-8694gzbn3: Add missing test helper file
mart-r May 11, 2024
383c929
CU-8694gzbn3: Remove whitespace change from otherwise unchanged file
mart-r May 11, 2024
48ce03e
CU-8694gzbn3: Allow 5 minutes longer for tests
mart-r May 11, 2024
316ec76
CU-8694gzbn3: Move to python 3.8-compatible typed dict
mart-r May 11, 2024
716e939
CU-8694gzbn3: Add more time for tests in workflow (now 30 minutes)
mart-r May 12, 2024
5be04d2
CU-8694gzbn3: Add more time for tests in workflow (now 45 minutes)
mart-r May 13, 2024
9459b8d
CU-8694gzbn3: Update test-pypi timeout to 45 minutes
mart-r May 13, 2024
3618b9c
CU-8694gzbn3: Remove timeout from unit tests in main workflow
mart-r May 13, 2024
94bce56
CU-8694gzbn3: Make tests stop upon first failure
mart-r May 13, 2024
666c013
CU-8694gzbn3: Fix test stop upon first failure (arg/option order)
mart-r May 13, 2024
8c5c0d0
CU-8694gzbn3: Remove debug code and old comments
mart-r May 13, 2024
e6debce
CU-8694gzbn3: Remove all timeouts from main workflow
mart-r May 13, 2024
abd8d0b
CU-8694gzbn3: Remove more old / useless comments in tests
mart-r May 14, 2024
03531da
CU-8694gzbn3: Add debug output when running k-fold tests to see where…
mart-r May 14, 2024
12c519a
CU-8694gzbn3: Add debug output when ANY tests to see where it may be …
mart-r May 14, 2024
9b02925
CU-8694gzbn3: Remove explicit debug output from k-fold test cases
mart-r May 14, 2024
faaf7fb
CU-8694gzbn3: Remove timeouts from DEID tests in case they're the one…
mart-r May 14, 2024
c720472
CU-8694gzbn3: Disable/skip some KFOLD tests
mart-r May 14, 2024
58d026a
CU-8694gzbn3: Disable/skip some MCT Export tests
mart-r May 14, 2024
68d0f73
CU-8694gzbn3: REMOVE all tests
mart-r May 14, 2024
64ded76
CU-8694gzbn3: Remove dependence on tests
mart-r May 14, 2024
359eff7
CU-8694gzbn3: Skip CDB state tests
mart-r May 14, 2024
513ae55
Revert "CU-8694gzbn3: Remove dependence on tests"
mart-r May 14, 2024
1fd324d
Revert "CU-8694gzbn3: REMOVE all tests"
mart-r May 14, 2024
37814ef
Revert "CU-8694gzbn3: Disable/skip some MCT Export tests"
mart-r May 14, 2024
840c443
Revert "CU-8694gzbn3: Disable/skip some KFOLD tests"
mart-r May 14, 2024
6824dc6
Revert "CU-8694gzbn3: Remove timeouts from DEID tests in case they're…
mart-r May 15, 2024
3eb46aa
Revert "CU-8694gzbn3: Remove explicit debug output from k-fold test c…
mart-r May 15, 2024
78d6032
Revert "CU-8694gzbn3: Add debug output when ANY tests to see where it…
mart-r May 15, 2024
69efaba
Revert "CU-8694gzbn3: Add debug output when running k-fold tests to s…
mart-r May 15, 2024
91e25a6
Revert "CU-8694gzbn3: Remove all timeouts from main workflow"
mart-r May 15, 2024
7f1c461
Revert "CU-8694gzbn3: Fix test stop upon first failure (arg/option or…
mart-r May 15, 2024
fdfdc27
Revert "CU-8694gzbn3: Make tests stop upon first failure"
mart-r May 15, 2024
75b8357
Revert "CU-8694gzbn3: Remove timeout from unit tests in main workflow"
mart-r May 15, 2024
77fae58
CU-8694gzbn3: Improve state copy code in CDB state tests
mart-r May 15, 2024
df5c4ef
CU-8694gzbn3: Re-allow in-memory state tests
mart-r May 15, 2024
8625bae
Revert "CU-8694gzbn3: Re-allow in-memory state tests"
mart-r May 15, 2024
972069b
CU-8694gzbn3: Allow on-disk state tests
mart-r May 15, 2024
d3fa875
CU-8694gzbn3: Add back an in-memory state tests
mart-r May 15, 2024
d5d1770
CU-8694gzbn3: Fix a CDB state test issue
mart-r May 15, 2024
e737878
CU-8694gzbn3: Allow CDB state tests to run
mart-r May 15, 2024
5d2dc29
CU-8694gzbn3: Remove on disk CDB state saves
mart-r May 15, 2024
fd6889c
CU-8694gzbn3: Add on-disk state save tests but without the tests-spec…
mart-r May 15, 2024
550a9a8
CU-8694gzbn3: Add some debug output for on-disk state save test setup
mart-r May 15, 2024
11d183b
CU-8694gzbn3: Add more debug output for on-disk state save test setup
mart-r May 15, 2024
e3ac178
CU-8694gzbn3: Add even more debug output for on-disk state save test …
mart-r May 15, 2024
ad0d7f1
CU-8694gzbn3: Only test CDB state
mart-r May 15, 2024
689d6c0
CU-8694gzbn3: Run all CDB state tests
mart-r May 15, 2024
d02a792
CU-8694gzbn3: Run all new tests
mart-r May 15, 2024
f4f9997
CU-8694gzbn3: Separate the running of most tests from the new ones
mart-r May 15, 2024
e2e2094
CU-8694gzbn3: Remove old (unused) medmentions tests(?)
mart-r May 15, 2024
28d008d
CU-8694gzbn3: Remove old (unused) archived tests
mart-r May 15, 2024
b940fb3
CU-8694gzbn3: Fix running of non-new tests (I think)
mart-r May 15, 2024
df5155e
CU-8694gzbn3: Split tests into 2 sets
mart-r May 16, 2024
61add5b
CU-8694gzbn3: Add timeout to the two halves of the tests
mart-r May 16, 2024
10 changes: 8 additions & 2 deletions .github/workflows/main.yml
@@ -33,7 +33,13 @@ jobs:
flake8 medcat
- name: Test
run: |
timeout 17m python -m unittest discover
all_files=$(git ls-files | grep '^tests/.*\.py$' | grep -v '/__init__\.py$' | sed 's/\.py$//' | sed 's/\//./g')
num_files=$(echo "$all_files" | wc -l)
midpoint=$((num_files / 2))
first_half_nl=$(echo "$all_files" | head -n $midpoint)
second_half_nl=$(echo "$all_files" | tail -n +$(($midpoint + 1)))
timeout 25m python -m unittest ${first_half_nl[@]}
timeout 25m python -m unittest ${second_half_nl[@]}

publish-to-test-pypi:

@@ -43,7 +49,7 @@
github.event_name == 'push' &&
startsWith(github.ref, 'refs/tags') != true
runs-on: ubuntu-20.04
timeout-minutes: 20
timeout-minutes: 45
concurrency: publish-to-test-pypi
needs: [build]

286 changes: 286 additions & 0 deletions medcat/stats/kfold.py
@@ -0,0 +1,286 @@
from typing import Protocol, Tuple, List, Dict, Optional, Set, Iterator, Union, Callable, cast, Any

from copy import deepcopy

from medcat.utils.checkpoint import Checkpoint
from medcat.utils.cdb_state import captured_state_cdb

from medcat.stats.stats import get_stats
from medcat.stats.mctexport import MedCATTrainerExport, MedCATTrainerExportProject
from medcat.stats.mctexport import MedCATTrainerExportDocument, MedCATTrainerExportAnnotation
from medcat.stats.mctexport import count_all_annotations, count_all_docs
from medcat.stats.mctexport import iter_anns, iter_docs, MedCATTrainerExportProjectInfo



class CDBLike(Protocol):
pass


class CATLike(Protocol):

@property
def cdb(self) -> CDBLike:
pass

def train_supervised_raw(self,
data: Dict[str, List[Dict[str, dict]]],
reset_cui_count: bool = False,
nepochs: int = 1,
print_stats: int = 0,
use_filters: bool = False,
terminate_last: bool = False,
use_overlaps: bool = False,
use_cui_doc_limit: bool = False,
test_size: float = 0,
devalue_others: bool = False,
use_groups: bool = False,
never_terminate: bool = False,
train_from_false_positives: bool = False,
extra_cui_filter: Optional[Set] = None,
retain_extra_cui_filter: bool = False,
checkpoint: Optional[Checkpoint] = None,
retain_filters: bool = False,
is_resumed: bool = False) -> Tuple:
pass
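These Protocol definitions rely on structural typing: any object that provides a matching `cdb` property and `train_supervised_raw` method is accepted as a `CATLike`, with no inheritance from the protocol required. A minimal, self-contained sketch of the idea (the `Greeter`/`English` names are illustrative, not part of medcat):

```python
from typing import Protocol


class Greeter(Protocol):
    def greet(self) -> str:
        ...


class English:
    # never declared as a Greeter subclass; it matches by structure alone
    def greet(self) -> str:
        return "hello"


def salute(g: Greeter) -> str:
    return g.greet()


print(salute(English()))  # hello
```

This is why the k-fold code can be exercised against the real `CAT` class without importing it: the protocol only pins down the two members the module actually calls.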


class FoldCreator:

def __init__(self, mct_export: MedCATTrainerExport, nr_of_folds: int,
use_annotations: bool) -> None:
self.mct_export = mct_export
self.nr_of_folds = nr_of_folds
self.use_annotations = use_annotations
self._targets: Union[
Iterator[Tuple[MedCATTrainerExportProjectInfo, MedCATTrainerExportDocument, MedCATTrainerExportAnnotation]],
Iterator[Tuple[MedCATTrainerExportProjectInfo, MedCATTrainerExportDocument]]
]
self._adder: Union[
Callable[[MedCATTrainerExportProject, MedCATTrainerExportDocument, MedCATTrainerExportAnnotation], None],
Callable[[MedCATTrainerExportProject, MedCATTrainerExportDocument], None]
]
if self.use_annotations:
self.total = count_all_annotations(self.mct_export)
self._targets = iter_anns(self.mct_export)
self._adder = self._add_target_ann
else:
self.total = count_all_docs(self.mct_export)
self._targets = iter_docs(self.mct_export)
self._adder = self._add_target_doc
self.per_fold = self._init_per_fold()

def _add_target_doc(self, project: MedCATTrainerExportProject,
doc: MedCATTrainerExportDocument) -> None:
project['documents'].append(doc)

def _find_or_add_doc(self, project: MedCATTrainerExportProject, orig_doc: MedCATTrainerExportDocument
) -> MedCATTrainerExportDocument:
for existing_doc in project['documents']:
if existing_doc['name'] == orig_doc['name']:
return existing_doc
new_doc: MedCATTrainerExportDocument = deepcopy(orig_doc)
new_doc['annotations'].clear()
project['documents'].append(new_doc)
return new_doc

def _add_target_ann(self, project: MedCATTrainerExportProject, orig_doc: MedCATTrainerExportDocument,
ann: MedCATTrainerExportAnnotation) -> None:
cur_doc: MedCATTrainerExportDocument = self._find_or_add_doc(project, orig_doc)
cur_doc['annotations'].append(ann)

def _init_per_fold(self) -> List[int]:
per_fold = [self.total // self.nr_of_folds for _ in range(self.nr_of_folds)]
total = sum(per_fold)
if total < self.total:
per_fold[-1] += self.total - total
if any(pf <= 0 for pf in per_fold):
raise ValueError(f"Failed to calculate per-fold items. Got: {per_fold}")
return per_fold
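The sizing logic above gives every fold `total // nr_of_folds` items and pushes the remainder onto the last fold. A standalone sketch of the same arithmetic (the `per_fold_sizes` helper is illustrative, not part of the module):

```python
from typing import List


def per_fold_sizes(total: int, nr_of_folds: int) -> List[int]:
    # each fold gets the integer share; the remainder goes to the last fold
    per_fold = [total // nr_of_folds for _ in range(nr_of_folds)]
    leftover = total - sum(per_fold)
    if leftover:
        per_fold[-1] += leftover
    if any(pf <= 0 for pf in per_fold):
        # e.g. fewer items than folds
        raise ValueError(f"Failed to calculate per-fold items. Got: {per_fold}")
    return per_fold


print(per_fold_sizes(10, 3))  # [3, 3, 4]
```

Note that the remainder is at most `nr_of_folds - 1`, so the last fold is never more than `nr_of_folds - 1` items larger than the others.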

def _create_fold(self, fold_nr: int) -> MedCATTrainerExport:
per_fold = self.per_fold[fold_nr]
cur_fold: MedCATTrainerExport = {
'projects': []
}
cur_project: Optional[MedCATTrainerExportProject] = None
included = 0
for target in self._targets:
(proj_name, proj_id, proj_cuis, proj_tuis), *target_info = target
if not cur_project or cur_project['name'] != proj_name:
# first or new project
cur_project = cast(MedCATTrainerExportProject, {
'name': proj_name,
'id': proj_id,
'cuis': proj_cuis,
'documents': [],
})
# NOTE: Some MCT exports don't declare TUIs
if proj_tuis is not None:
cur_project['tuis'] = proj_tuis
cur_fold['projects'].append(cur_project)
self._adder(cur_project, *target_info)
included += 1
if included == per_fold:
break
if included > per_fold:
raise ValueError("Got a larger fold than expected. "
f"Expected {per_fold}, got {included}")
return cur_fold


def create_folds(self) -> List[MedCATTrainerExport]:
return [
self._create_fold(fold_nr) for fold_nr in range(self.nr_of_folds)
]


def get_per_fold_metrics(cat: CATLike, folds: List[MedCATTrainerExport],
*args, **kwargs) -> List[Tuple]:
metrics = []
for fold_nr, cur_fold in enumerate(folds):
others = list(folds)
others.pop(fold_nr)
with captured_state_cdb(cat.cdb):
for other in others:
cat.train_supervised_raw(cast(Dict[str, Any], other), *args, **kwargs)
stats = get_stats(cat, cast(Dict[str, Any], cur_fold), do_print=False)
metrics.append(stats)
return metrics
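`get_per_fold_metrics` implements leave-one-fold-out evaluation: for each fold it snapshots the CDB, trains on the other `k-1` folds, scores the held-out fold, and lets the context manager roll the state back. A toy, self-contained sketch of that loop (the `ToyModel` and `captured_state` stand-ins are illustrative, not the real medcat classes):

```python
from contextlib import contextmanager
from copy import deepcopy


class ToyModel:
    def __init__(self):
        self.seen = []          # stands in for trained CDB state

    def train(self, fold):
        self.seen.extend(fold)

    def evaluate(self, fold):
        # counts held-out items that leaked into training
        return sum(1 for item in fold if item in self.seen)


@contextmanager
def captured_state(model):
    snapshot = deepcopy(model.seen)
    try:
        yield
    finally:
        model.seen = snapshot   # roll back to the base model


def per_fold_metrics(model, folds):
    metrics = []
    for i, held_out in enumerate(folds):
        others = folds[:i] + folds[i + 1:]
        with captured_state(model):
            for other in others:
                model.train(other)
            metrics.append(model.evaluate(held_out))
    return metrics


folds = [[1, 2], [3, 4], [5, 6]]
model = ToyModel()
print(per_fold_metrics(model, folds))  # [0, 0, 0]: held-out items were never trained on
print(model.seen)                      # []: base state restored after every fold
```

The rollback in the `finally` block is the key property: each fold's training starts from the same base model, so the k metrics are comparable.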


def _update_all_weighted_average(joined: List[Dict[str, Tuple[int, float]]],
single: List[Dict[str, float]], cui2count: Dict[str, int]) -> None:
if len(joined) != len(single):
raise ValueError(f"Incompatible lists. Joined {len(joined)} and single {len(single)}")
for j, s in zip(joined, single):
_update_one_weighted_average(j, s, cui2count)


def _update_one_weighted_average(joined: Dict[str, Tuple[int, float]],
one: Dict[str, float],
cui2count: Dict[str, int]) -> None:
for k in one:
if k not in joined:
joined[k] = (0, 0)
prev_w, prev_val = joined[k]
new_w, new_val = cui2count[k], one[k]
total_w = prev_w + new_w
total_val = (prev_w * prev_val + new_w * new_val) / total_w
joined[k] = (total_w, total_val)
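The function above keeps a running weighted mean per key: the stored tuple is `(accumulated_weight, current_mean)`, and each update folds in a new value weighted by its CUI count. A standalone sketch with a worked update (the helper name mirrors the module's private function but is redefined here for illustration):

```python
from typing import Dict, Tuple


def update_one_weighted_average(joined: Dict[str, Tuple[int, float]],
                                one: Dict[str, float],
                                cui2count: Dict[str, int]) -> None:
    for k in one:
        if k not in joined:
            joined[k] = (0, 0)
        prev_w, prev_val = joined[k]
        new_w, new_val = cui2count[k], one[k]
        total_w = prev_w + new_w
        # standard incremental weighted mean
        total_val = (prev_w * prev_val + new_w * new_val) / total_w
        joined[k] = (total_w, total_val)


joined: Dict[str, Tuple[int, float]] = {}
update_one_weighted_average(joined, {"C1": 0.5}, {"C1": 2})
update_one_weighted_average(joined, {"C1": 1.0}, {"C1": 2})
# (2 * 0.5 + 2 * 1.0) / 4 = 0.75
print(joined)  # {'C1': (4, 0.75)}
```

Weighting by CUI count means a fold in which a concept occurs often contributes proportionally more to that concept's averaged precision, recall, and F1.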


def _update_all_add(joined: List[Dict[str, int]], single: List[Dict[str, int]]) -> None:
if len(joined) != len(single):
raise ValueError(f"Incompatible number of stuff: {len(joined)} vs {len(single)}")
for j, s in zip(joined, single):
for k, v in s.items():
j[k] = j.get(k, 0) + v


def _merge_examples(all_examples: Dict, cur_examples: Dict) -> None:
for ex_type, ex_dict in cur_examples.items():
if ex_type not in all_examples:
all_examples[ex_type] = {}
per_type_examples = all_examples[ex_type]
for ex_cui, cui_examples_list in ex_dict.items():
if ex_cui not in per_type_examples:
per_type_examples[ex_cui] = []
per_type_examples[ex_cui].extend(cui_examples_list)
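`_merge_examples` is a two-level dict merge: example lists are grouped by example type (fp/fn/tp) and then by CUI, with lists extended rather than replaced. The same behaviour can be sketched compactly with `setdefault` (illustrative, not the module's exact code):

```python
def merge_examples(all_examples: dict, cur_examples: dict) -> None:
    for ex_type, ex_dict in cur_examples.items():
        # first level: example type (e.g. 'fp', 'fn', 'tp')
        per_type = all_examples.setdefault(ex_type, {})
        for cui, examples in ex_dict.items():
            # second level: CUI; extend so earlier folds' examples survive
            per_type.setdefault(cui, []).extend(examples)


merged: dict = {}
merge_examples(merged, {"fp": {"C1": ["a"]}})
merge_examples(merged, {"fp": {"C1": ["b"], "C2": ["c"]}})
print(merged)  # {'fp': {'C1': ['a', 'b'], 'C2': ['c']}}
```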


def get_metrics_mean(metrics: List[Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dict, Dict]]
) -> Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dict, Dict]:
"""Get the mean of the provided metrics.

Args:
metrics (List[Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dict, Dict]]): The metrics.

Returns:
fps (dict):
False positives for each CUI.
fns (dict):
False negatives for each CUI.
tps (dict):
True positives for each CUI.
cui_prec (dict):
Precision for each CUI.
cui_rec (dict):
Recall for each CUI.
cui_f1 (dict):
F1 for each CUI.
cui_counts (dict):
Number of occurrences for each CUI.
examples (dict):
Examples for each of the fp, fn, tp. Format will be examples['fp']['cui'][<list_of_examples>].
"""
# additives
all_fps: Dict[str, int] = {}
all_fns: Dict[str, int] = {}
all_tps: Dict[str, int] = {}
# weighted-averages
all_cui_prec: Dict[str, Tuple[int, float]] = {}
all_cui_rec: Dict[str, Tuple[int, float]] = {}
all_cui_f1: Dict[str, Tuple[int, float]] = {}
# additive
all_cui_counts: Dict[str, int] = {}
# combined
all_additives = [
all_fps, all_fns, all_tps, all_cui_counts
]
all_weighted_averages = [
all_cui_prec, all_cui_rec, all_cui_f1
]
# examples
all_examples: dict = {}
for current in metrics:
cur_wa: list = list(current[3:-2])
cur_counts = current[-2]
_update_all_weighted_average(all_weighted_averages, cur_wa, cur_counts)
# update ones that just need to be added up
cur_adds = list(current[:3]) + [cur_counts]
_update_all_add(all_additives, cur_adds)
# merge examples
cur_examples = current[-1]
_merge_examples(all_examples, cur_examples)
cui_prec: Dict[str, float] = {}
cui_rec: Dict[str, float] = {}
cui_f1: Dict[str, float] = {}
final_wa = [
cui_prec, cui_rec, cui_f1
]
# just remove the weight / count
for df, d in zip(final_wa, all_weighted_averages):
for k, v in d.items():
df[k] = v[1] # only the value, ignore the weight
return (all_fps, all_fns, all_tps, final_wa[0], final_wa[1], final_wa[2],
all_cui_counts, all_examples)


def get_k_fold_stats(cat: CATLike, mct_export_data: MedCATTrainerExport, k: int = 3,
use_annotations: bool = False, *args, **kwargs) -> Tuple:
"""Get the k-fold stats for the model with the specified data.

First this will split the MCT export into `k` folds. You can do
this either per document or per-annotation.

For each of the `k` folds, it will start from the base model,
train it with the other `k-1` folds and record the metrics.
After that the base model state is restored before doing the next fold.
After all the folds have been done, the metrics are averaged.

Args:
cat (CATLike): The model pack.
mct_export_data (MedCATTrainerExport): The MCT export.
k (int): The number of folds. Defaults to 3.
use_annotations (bool): Whether to use annotations or docs. Defaults to False (docs).
*args: Arguments passed to the `CAT.train_supervised_raw` method.
**kwargs: Keyword arguments passed to the `CAT.train_supervised_raw` method.

Returns:
Tuple: The averaged metrics.
"""
creator = FoldCreator(mct_export_data, k, use_annotations=use_annotations)
folds = creator.create_folds()
per_fold_metrics = get_per_fold_metrics(cat, folds, *args, **kwargs)
return get_metrics_mean(per_fold_metrics)
69 changes: 69 additions & 0 deletions medcat/stats/mctexport.py
@@ -0,0 +1,69 @@
from typing import List, Iterator, Tuple, Any, Optional
from typing_extensions import TypedDict


class MedCATTrainerExportAnnotation(TypedDict):
start: int
end: int
cui: str
value: str


class MedCATTrainerExportDocument(TypedDict):
name: str
id: Any
last_modified: str
text: str
annotations: List[MedCATTrainerExportAnnotation]


class MedCATTrainerExportProject(TypedDict):
name: str
id: Any
cuis: str
tuis: Optional[str]
documents: List[MedCATTrainerExportDocument]


MedCATTrainerExportProjectInfo = Tuple[str, Any, str, Optional[str]]
"""The project name, project ID, CUIs str, and TUIs str"""


class MedCATTrainerExport(TypedDict):
projects: List[MedCATTrainerExportProject]


def iter_projects(export: MedCATTrainerExport) -> Iterator[MedCATTrainerExportProject]:
yield from export['projects']


def iter_docs(export: MedCATTrainerExport
) -> Iterator[Tuple[MedCATTrainerExportProjectInfo, MedCATTrainerExportDocument]]:
for project in iter_projects(export):
info: MedCATTrainerExportProjectInfo = (
project['name'], project['id'], project['cuis'], project.get('tuis', None)
)
for doc in project['documents']:
yield info, doc


def iter_anns(export: MedCATTrainerExport
) -> Iterator[Tuple[MedCATTrainerExportProjectInfo, MedCATTrainerExportDocument, MedCATTrainerExportAnnotation]]:
for proj_info, doc in iter_docs(export):
for ann in doc['annotations']:
yield proj_info, doc, ann


def count_all_annotations(export: MedCATTrainerExport) -> int:
return len(list(iter_anns(export)))


def count_all_docs(export: MedCATTrainerExport) -> int:
return len(list(iter_docs(export)))
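The iterators above flatten the nested export: `iter_docs` yields `(project_info, document)` pairs and `iter_anns` additionally descends into each document's annotations, which is what the `count_all_*` helpers rely on. A minimal fake export shows the shapes involved (the data and the locally redefined iterators mirror the module above for a self-contained run):

```python
export = {
    "projects": [
        {"name": "P1", "id": 1, "cuis": "", "documents": [
            {"name": "d1", "id": 1, "last_modified": "", "text": "x",
             "annotations": [
                 {"start": 0, "end": 1, "cui": "C1", "value": "x"}]},
            {"name": "d2", "id": 2, "last_modified": "", "text": "y",
             "annotations": []},
        ]},
    ]
}


def iter_docs(export):
    for project in export["projects"]:
        # 'tuis' may be absent, hence .get with a None default
        info = (project["name"], project["id"], project["cuis"],
                project.get("tuis", None))
        for doc in project["documents"]:
            yield info, doc


def iter_anns(export):
    for info, doc in iter_docs(export):
        for ann in doc["annotations"]:
            yield info, doc, ann


print(len(list(iter_docs(export))))  # 2 documents
print(len(list(iter_anns(export))))  # 1 annotation
```

Because annotations are reached only through their documents, each yielded annotation carries both its document and its project info, which the fold creator uses to rebuild a valid export per fold.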


# def stich_exports(exports: List[MedCATTrainerExport]) -> MedCATTrainerExport:
# stiched: MedCATTrainerExport = {"projects": []}
# for export in exports:
# pass
# pass