From f122f240988adc858d3d264327993526eed92b28 Mon Sep 17 00:00:00 2001 From: Mart Ratas Date: Tue, 17 Sep 2024 14:30:11 +0300 Subject: [PATCH] CU-8695pvhfe fix usage monitoring for multiprocessing (#488) * CU-8695pvhfe: Rename a test class * CU-8695pvhfe: Add tests for multiprocessing usage monitoring * CU-8695pvhfe: Fix usage monitor for multiprocessing. When using CAT.multiprocessing_batch_char_size (CAT._multiprocessing_batch and CAT._mp_cons internally), flush the usage monitor at the end of multiprocessing method. When using CAT.get_entities_multi_texts or CAT.multiprocessing_batch_docs_size (uses the former internally), add logging of usage to output * CU-8695pvhfe: Fix remaining issues with usage monitor for multiprocessing. Avoid checking length of (potentially) non-existent strings. Avoid early iteration of generator. --- medcat/cat.py | 23 +++++++++++++- tests/test_cat.py | 47 +++++++++++++++++++++++++--- tests/utils/test_usage_monitoring.py | 2 +- 3 files changed, 66 insertions(+), 6 deletions(-) diff --git a/medcat/cat.py b/medcat/cat.py index 621a2e83..707dbd7f 100644 --- a/medcat/cat.py +++ b/medcat/cat.py @@ -1127,11 +1127,29 @@ def get_entities_multi_texts(self, self.pipe.set_error_handler(self._pipe_error_handler) try: texts_ = self._get_trimmed_texts(texts) + if self.config.general.usage_monitor.enabled: + input_lengths: List[Tuple[int, int]] = [] + for orig_text, trimmed_text in zip(texts, texts_): + if orig_text is None or trimmed_text is None: + l1, l2 = 0, 0 + else: + l1 = len(orig_text) + l2 = len(trimmed_text) + input_lengths.append((l1, l2)) docs = self.pipe.batch_multi_process(texts_, n_process, batch_size) - for doc in tqdm(docs, total=len(texts_)): + for doc_nr, doc in tqdm(enumerate(docs), total=len(texts_)): doc = None if doc.text.strip() == '' else doc out.append(self._doc_to_out(doc, only_cui, addl_info, out_with_text=True)) + if self.config.general.usage_monitor.enabled: + l1, l2 = input_lengths[doc_nr] + if doc is None: + nents = 0 + 
elif self.config.general.show_nested_entities: + nents = len(doc._.ents) # type: ignore + else: + nents = len(doc.ents) # type: ignore + self.usage_monitor.log_inference(l1, l2, nents) # Currently spaCy cannot mark which pieces of texts failed within the pipe so be this workaround, # which also assumes texts are different from each others. @@ -1637,6 +1655,9 @@ def _mp_cons(self, in_q: Queue, out_list: List, min_free_memory: float, logger.warning("PID: %s failed one document in _mp_cons, running will continue normally. \n" + "Document length in chars: %s, and ID: %s", pid, len(str(text)), i_text) logger.warning(str(e)) + if self.config.general.usage_monitor.enabled: + # NOTE: This is in another process, so need to explicitly flush + self.usage_monitor._flush_logs() sleep(2) def _add_nested_ent(self, doc: Doc, _ents: List[Span], _ent: Union[Dict, Span]) -> None: diff --git a/tests/test_cat.py b/tests/test_cat.py index 4c237f58..17cdd281 100644 --- a/tests/test_cat.py +++ b/tests/test_cat.py @@ -2,6 +2,8 @@ import os import sys import time +from typing import Callable +from functools import partial import unittest from unittest.mock import mock_open, patch import tempfile @@ -595,18 +597,55 @@ def test_get_entities_gets_monitored(self, contents = f.readline() self.assertTrue(contents) + def assert_gets_usage_monitored(self, data_processor: Callable[[None], None], exp_logs: int = 1): + # clear usage monitor buffer + self.undertest.usage_monitor.log_buffer.clear() + data_processor() + file = self.undertest.usage_monitor.log_file + if os.path.exists(file): + with open(file) as f: + content = f.readlines() + content += self.undertest.usage_monitor.log_buffer + else: + content = self.undertest.usage_monitor.log_buffer + self.assertTrue(content) + self.assertEqual(len(content), exp_logs) + def test_get_entities_logs_usage(self, text="The dog is sitting outside the house."): # clear usage monitor buffer - self.undertest.usage_monitor.log_buffer.clear() - 
self.undertest.get_entities(text) - self.assertTrue(self.undertest.usage_monitor.log_buffer) - self.assertEqual(len(self.undertest.usage_monitor.log_buffer), 1) + self.assert_gets_usage_monitored(partial(self.undertest.get_entities, text), 1) line = self.undertest.usage_monitor.log_buffer[0] # the 1st element is the input text length input_text_length = line.split(",")[1] self.assertEqual(str(len(text)), input_text_length) + TEXT4MP_USAGE = [ + ("ID1", "Text with house and dog one"), + ("ID2", "Text with house and dog two"), + ("ID3", "Text with house and dog three"), + ("ID4", "Text with house and dog four"), + ("ID5", "Text with house and dog five"), + ("ID6", "Text with house and dog siz"), + ("ID7", "Text with house and dog seven"), + ("ID8", "Text with house and dog eight"), + ] + + def test_mp_batch_char_size_logs_usage(self): + all_text = self.TEXT4MP_USAGE + proc = partial(self.undertest.multiprocessing_batch_char_size, all_text, nproc=2) + self.assert_gets_usage_monitored(proc, len(all_text)) + + def test_mp_get_multi_texts_logs_usage(self): + all_text = self.TEXT4MP_USAGE + proc = partial(self.undertest.get_entities_multi_texts, all_text, n_process=2) + self.assert_gets_usage_monitored(proc, len(all_text)) + + def test_mp_batch_docs_size_logs_usage(self): + all_text = self.TEXT4MP_USAGE + proc = partial(self.undertest.multiprocessing_batch_docs_size, all_text, nproc=2) + self.assert_gets_usage_monitored(proc, len(all_text)) + def test_simple_hashing_is_faster(self): self.undertest.config.general.simple_hash = False st = time.perf_counter() diff --git a/tests/utils/test_usage_monitoring.py b/tests/utils/test_usage_monitoring.py index 936cde37..b47345bf 100644 --- a/tests/utils/test_usage_monitoring.py +++ b/tests/utils/test_usage_monitoring.py @@ -89,7 +89,7 @@ def test_some_in_file(self): self.assertEqual(len(lines), self.expected_in_file) -class UMT(UsageMonitorBaseTests): +class UsageMonitoringAutoTests(UsageMonitorBaseTests): ENABLED_DICT = { 
"MEDCAT_USAGE_LOGS": "True", "MEDCAT_USAGE_LOGS_LOCATION": "."