diff --git a/medcat/cat.py b/medcat/cat.py
index 36fc265bb..b2d3f7cb3 100644
--- a/medcat/cat.py
+++ b/medcat/cat.py
@@ -40,7 +40,7 @@
 from medcat.vocab import Vocab
 from medcat.utils.decorators import deprecated
 from medcat.ner.transformers_ner import TransformersNER
-from medcat.utils.saving.serializer import SPECIALITY_NAMES
+from medcat.utils.saving.serializer import SPECIALITY_NAMES, ONE2MANY

 logger = logging.getLogger(__name__)  # separate logger from the package-level one

@@ -356,7 +356,8 @@ def load_model_pack(cls,

         # Load the CDB
         cdb_path = os.path.join(model_pack_path, "cdb.dat")
-        has_jsons = len(glob.glob(os.path.join(model_pack_path, '*.json'))) >= len(SPECIALITY_NAMES)
+        nr_of_jsons_expected = len(SPECIALITY_NAMES) - len(ONE2MANY)
+        has_jsons = len(glob.glob(os.path.join(model_pack_path, '*.json'))) >= nr_of_jsons_expected
         json_path = model_pack_path if has_jsons else None
         logger.info('Loading model pack with %s', 'JSON format' if json_path else 'dill format')
         cdb = CDB.load(cdb_path, json_path)
diff --git a/medcat/cdb.py b/medcat/cdb.py
index a4f87edf2..44d4fd9dd 100644
--- a/medcat/cdb.py
+++ b/medcat/cdb.py
@@ -95,6 +95,7 @@ def __init__(self, config: Union[Config, None] = None) -> None:
         self._optim_params = None
         self.is_dirty = False
         self._hash: Optional[str] = None
+        self._memory_optimised_parts: Set[str] = set()

     def get_name(self, cui: str) -> str:
         """Returns preferred name if it exists, otherwise it will return
@@ -180,9 +181,13 @@ def remove_cui(self, cui: str) -> None:
         for name, cuis2status in self.name2cuis2status.items():
             if cui in cuis2status:
                 del cuis2status[cui]
-        self.snames = set()
-        for cuis in self.cui2snames.values():
-            self.snames |= cuis
+        if isinstance(self.snames, set):
+            # only rebuild `snames` if it really is a set; in a
+            # memory-optimised CDB it is a DelegatingValueSet that reads
+            # directly from cui2snames and thus needs no rebuild
+            self.snames = set()
+            for cuis in self.cui2snames.values():
+                self.snames |= cuis
         self.name2count_train = {name: len(cuis) for name, cuis in self.name2cuis.items()}
         self.is_dirty = True

@@ -561,6 +566,10 @@ def filter_by_cui(self, cuis_to_keep: Union[List[str], Set[str]]) -> None:
         This also will not remove any data from cdb.addl_info - as this field can contain data of unknown structure.

+        Note that if the CDB has been memory-optimised, filtering will undo the
+        optimisation because the dicts involved get rewritten. The optimisation
+        can, however, be performed again afterwards.
+
         Args:
             cuis_to_keep (List[str]):
                 CUIs that will be kept, the rest will be removed (not completely, look above).
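A minimal sketch of the workflow this enables (the path and CUI below are hypothetical): filtering an optimised CDB resets `_memory_optimised_parts`, so a caller simply re-runs the optimisation afterwards.

    from medcat.cdb import CDB
    from medcat.utils.memory_optimiser import perform_optimisation

    cdb = CDB.load('cdb.dat')                # hypothetical path
    perform_optimisation(cdb)                # cui2<...> dicts now delegate to cdb.cui2many
    cdb.filter_by_cui({'C0011849'})          # rewrites the dicts, undoing the optimisation
    assert not cdb._memory_optimised_parts   # optimisation state was reset
    perform_optimisation(cdb)                # safe to optimise again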
@@ -624,6 +633,8 @@ def filter_by_cui(self, cuis_to_keep: Union[List[str], Set[str]]) -> None:
         self.cui2type_ids = new_cui2type_ids
         self.cui2preferred_name = new_cui2preferred_name
         self.is_dirty = True
+        # reset memory optimisation state
+        self._memory_optimised_parts.clear()

     def make_stats(self):
         stats = {}
diff --git a/medcat/utils/memory_optimiser.py b/medcat/utils/memory_optimiser.py
new file mode 100644
index 000000000..e8328734d
--- /dev/null
+++ b/medcat/utils/memory_optimiser.py
@@ -0,0 +1,366 @@
+from typing import Any, Dict, KeysView, Iterator, List, Tuple, Union, Optional, Set
+
+from medcat.cdb import CDB
+from medcat.utils.saving.coding import EncodeableObject, PartEncoder, PartDecoder, UnsuitableObject, register_encoder_decoder
+
+
+CUI_DICT_NAMES_TO_COMBINE = [
+    "cui2names", "cui2snames", "cui2context_vectors",
+    "cui2count_train", "cui2tags", "cui2type_ids",
+    "cui2preferred_name", "cui2average_confidence",
+]
+ONE2MANY = 'cui2many'
+
+NAME_DICT_NAMES_TO_COMBINE = [
+    "name2cuis", "name2cuis2status", "name_isupper",
+]
+NAME2MANY = 'name2many'
+
+DELEGATING_DICT_IDENTIFIER = '==DELEGATING_DICT=='
+
+DELEGATING_SET_IDENTIFIER = '==DELEGATING_SET=='
+
+# these will be used in CDB._memory_optimised_parts
+CUIS_PART = 'CUIS'
+NAMES_PART = 'NAMES'
+SNAMES_PART = 'snames'
+
+
+class _KeysView:
+    def __init__(self, keys: KeysView, parent: 'DelegatingDict'):
+        self._keys = keys
+        self._parent = parent
+
+    def __iter__(self) -> Iterator[Any]:
+        for key in self._keys:
+            if key in self._parent:
+                yield key
+
+    def __len__(self) -> int:
+        return sum(1 for _ in self)
+
+
+class _ItemsView:
+    def __init__(self, parent: 'DelegatingDict') -> None:
+        self._parent = parent
+
+    def __iter__(self) -> Iterator[Any]:
+        for key in self._parent:
+            yield key, self._parent[key]
+
+    def __len__(self) -> int:
+        return len(self._parent)
+
+
+class _ValuesView:
+    def __init__(self, parent: 'DelegatingDict') -> None:
+        self._parent = parent
+
+    def __iter__(self) -> Iterator[Any]:
+        for key in self._parent:
+            yield self._parent[key]
+
+    def __len__(self) -> int:
+        return len(self._parent)
+
+
+class DelegatingDict:
+
+    def __init__(self, delegate: Dict[str, List[Any]], nr: int,
+                 nr_of_overall_items: int = 8) -> None:
+        self.delegate = delegate
+        self.nr = nr
+        self.nr_of_overall_items = nr_of_overall_items
+
+    def _generate_empty_entry(self) -> List[Any]:
+        return [None for _ in range(self.nr_of_overall_items)]
+
+    def __getitem__(self, key: str) -> Any:
+        val = self.delegate[key][self.nr]
+        if val is None:
+            raise KeyError(key)
+        return val
+
+    def get(self, key: str, default: Any = None) -> Any:
+        try:
+            return self[key]
+        except KeyError:
+            return default
+
+    def __setitem__(self, key: str, value: Any) -> None:
+        if key not in self.delegate:
+            self.delegate[key] = self._generate_empty_entry()
+        self.delegate[key][self.nr] = value
+
+    def __contains__(self, key: str) -> bool:
+        return key in self.delegate and self.delegate[key][self.nr] is not None
+
+    def keys(self) -> _KeysView:
+        return _KeysView(self.delegate.keys(), self)
+
+    def items(self) -> _ItemsView:
+        return _ItemsView(self)
+
+    def values(self) -> _ValuesView:
+        return _ValuesView(self)
+
+    def __iter__(self) -> Iterator[str]:
+        yield from self.keys()
+
+    def __len__(self) -> int:
+        return len(self.keys())
+
+    def to_dict(self) -> dict:
+        return {'delegate': None,
+                'nr': self.nr,
+                'nr_of_overall_items': self.nr_of_overall_items}
+
+    def __eq__(self, __value: object) -> bool:
+        if not isinstance(__value, DelegatingDict):
+            return False
+        return self.delegate == __value.delegate and self.nr == __value.nr
+
+    def __hash__(self) -> int:
+        # NB: the delegate dict itself is not hashable,
+        # so hash its identity instead
+        return hash((id(self.delegate), self.nr))
+
+    def __delitem__(self, key: str) -> None:
+        self[key] = None
+
+    def pop(self, key: str, default: Optional[Any] = None) -> Any:
+        if key in self:
+            item = self[key]
+            del self[key]
+            return item
+        return default
+
+
+class DelegatingValueSet:
+
+    def __init__(self, delegate: Dict[str, Set[str]]) -> None:
+        self.delegate = delegate
+
+    def update(self, other: Any) -> None:
+        # do nothing since the value will be updated in the delegate
+        pass
+
+    def __contains__(self, value: str) -> bool:
+        for cui_value in self.delegate.values():
+            if value in cui_value:
+                return True
+        return False
+
+    def to_dict(self) -> dict:
+        return {'delegate': None}
+
+
+class DelegatingDictEncoder(PartEncoder):
+
+    def try_encode(self, obj):
+        if isinstance(obj, DelegatingDict):
+            return {DELEGATING_DICT_IDENTIFIER: obj.to_dict()}
+        raise UnsuitableObject()
+
+
+class DelegatingDictDecoder(PartDecoder):
+
+    def try_decode(self, dct: dict) -> Union[dict, EncodeableObject]:
+        if DELEGATING_DICT_IDENTIFIER in dct:
+            info = dct[DELEGATING_DICT_IDENTIFIER]
+            delegate = info['delegate']
+            nr = info['nr']
+            overall = info['nr_of_overall_items']
+            return DelegatingDict(delegate, nr, overall)
+        return dct
+
+
+class DelegatingValueSetEncoder(PartEncoder):
+
+    def try_encode(self, obj):
+        if isinstance(obj, DelegatingValueSet):
+            return {DELEGATING_SET_IDENTIFIER: obj.to_dict()}
+        raise UnsuitableObject()
+
+
+class DelegatingValueSetDecoder(PartDecoder):
+
+    def try_decode(self, dct: dict) -> Union[dict, EncodeableObject]:
+        if DELEGATING_SET_IDENTIFIER in dct:
+            info = dct[DELEGATING_SET_IDENTIFIER]
+            delegate = info['delegate']
+            return DelegatingValueSet(delegate)
+        return dct
+
+
+def attempt_fix_after_load(cdb: CDB):
+    _attempt_fix_after_load(cdb, ONE2MANY, CUI_DICT_NAMES_TO_COMBINE)
+    _attempt_fix_after_load(cdb, NAME2MANY, NAME_DICT_NAMES_TO_COMBINE)
+
+
+def attempt_fix_snames_after_load(cdb: CDB, snames_attr_name: str = 'snames'):
+    snames = getattr(cdb, snames_attr_name)
+    if isinstance(snames, DelegatingValueSet) and snames.delegate is None:
+        snames = DelegatingValueSet(cdb.cui2snames)
+        setattr(cdb, snames_attr_name, snames)
+
+
+# register encoders and decoders
+register_encoder_decoder(encoder=DelegatingDictEncoder,
+                         decoder=DelegatingDictDecoder,
+                         loading_postprocessor=attempt_fix_after_load)
+register_encoder_decoder(encoder=DelegatingValueSetEncoder,
+                         decoder=DelegatingValueSetDecoder,
+                         loading_postprocessor=attempt_fix_snames_after_load)
+
+
+def _optimise(cdb: CDB, to_many_name: str, dict_names_to_combine: List[str]) -> None:
+    dicts = [getattr(cdb, dict_name)
+             for dict_name in dict_names_to_combine]
+    one2many, delegators = map_to_many(dicts)
+    for delegator, name in zip(delegators, dict_names_to_combine):
+        setattr(cdb, name, delegator)
+    setattr(cdb, to_many_name, one2many)
+    cdb.is_dirty = True
+
+
+def _optimise_snames(cdb: CDB, cui2snames: str = 'cui2snames',
+                     snames_attr: str = 'snames') -> None:
+    """Optimise the snames part of a CDB.
+
+    Args:
+        cdb (CDB): The CDB to optimise snames on.
+        cui2snames (str, optional): The name of the cui2snames dict to delegate to. Defaults to 'cui2snames'.
+        snames_attr (str, optional): The `snames` attribute name. Defaults to 'snames'.
+    """
+    delegate = getattr(cdb, cui2snames)
+    dvs = DelegatingValueSet(delegate)
+    setattr(cdb, snames_attr, dvs)
+    cdb.is_dirty = True
+
+
+def perform_optimisation(cdb: CDB, optimise_cuis: bool = True,
+                         optimise_names: bool = False,
+                         optimise_snames: bool = True) -> None:
+    """Attempts to optimise the memory footprint of the CDB.
+
+    This can perform optimisation for cui2<...> and name2<...> dicts.
+    However, by default, only the cui2many optimisation is done.
+    This is because, at the time of writing, there were not enough
+    name2<...> dicts to benefit from the optimisation.
+
+    Does so by unifying the following dicts:
+
+        cui2names (Dict[str, Set[str]]):
+            From cui to all names assigned to it. Mainly used for subsetting (maybe even only).
+        cui2snames (Dict[str, Set[str]]):
+            From cui to all sub-names assigned to it. Only used for subsetting.
+        cui2context_vectors (Dict[str, Dict[str, np.array]]):
+            From cui to a dictionary of different kinds of context vectors. Normally you would have here
+            a short and a long context vector - they are calculated separately.
+        cui2count_train (Dict[str, int]):
+            From CUI to the number of training examples seen.
+        cui2tags (Dict[str, List[str]]):
+            From CUI to a list of tags. This can be used to tag concepts, e.g. for grouping.
+        cui2type_ids (Dict[str, Set[str]]):
+            From CUI to type id (e.g. TUI in UMLS).
+        cui2preferred_name (Dict[str, str]):
+            From CUI to the preferred name for this concept.
+        cui2average_confidence (Dict[str, float]):
+            Used for dynamic thresholding. Holds the average confidence for this CUI given the training examples.
+
+        name2cuis (Dict[str, List[str]]):
+            Map from concept name to CUIs - one name can map to multiple CUIs.
+        name2cuis2status (Dict[str, Dict[str, str]]):
+            What is the status for a given name and cui pair - each name can be:
+            P - Preferred, A - Automatic (e.g. let medcat decide), N - Not common.
+        name2count_train (Dict[str, int]):
+            Counts how often a name appeared during training.
+
+    It can also delegate the `snames` set to use the various sets in `cui2snames` instead.
+
+    The dicts of each group are merged into a single dict with CUI (or name) keys,
+    holding a list with one value per pre-existing dict.
+
+    Args:
+        cdb (CDB): The CDB to modify.
+        optimise_cuis (bool, optional): Whether to optimise cui2<...> dicts. Defaults to True.
+        optimise_names (bool, optional): Whether to optimise name2<...> dicts. Defaults to False.
+        optimise_snames (bool, optional): Whether to optimise the `snames` set. Defaults to True.
+    """
+    # cui2<...> -> cui2many
+    if optimise_cuis:
+        _optimise(cdb, ONE2MANY, CUI_DICT_NAMES_TO_COMBINE)
+        cdb._memory_optimised_parts.add(CUIS_PART)
+    # name2<...> -> name2many
+    if optimise_names:
+        _optimise(cdb, NAME2MANY, NAME_DICT_NAMES_TO_COMBINE)
+        cdb._memory_optimised_parts.add(NAMES_PART)
+    if optimise_snames:
+        # delegate snames to cui2snames
+        _optimise_snames(cdb)
+        cdb._memory_optimised_parts.add(SNAMES_PART)
+
+
+def _attempt_fix_after_load(cdb: CDB, one2many_name: str, dict_names: List[str]):
+    if not hasattr(cdb, one2many_name):
+        return
+    one2many = getattr(cdb, one2many_name)
+    for dict_name in dict_names:
+        d = getattr(cdb, dict_name)
+        if not isinstance(d, DelegatingDict):
+            raise ValueError(f'Unknown type for {dict_name}: {type(d)}')
+        d.delegate = one2many
+
+
+def _unoptimise(cdb: CDB, to_many_name: str, dict_names_to_combine: List[str]):
+    # remove the one2many attribute;
+    # the delegate is still referenced by each delegating dict
+    delattr(cdb, to_many_name)
+
+    delegating_dicts: List[Dict[str, Any]] = [getattr(cdb, dict_name)
+                                              for dict_name in dict_names_to_combine]
+    for del_dict, dict_name in zip(delegating_dicts, dict_names_to_combine):
+        raw_dict = dict(del_dict.items())
+        setattr(cdb, dict_name, raw_dict)
+    cdb.is_dirty = True
+
+
+def _unoptimise_snames(cdb: CDB, cui2snames: str = 'cui2snames',
+                       snames_attr: str = 'snames') -> None:
+    # rebuild snames
+    delegate: Dict[str, Set[str]] = getattr(cdb, cui2snames)
+    snames = set()
+    for values in delegate.values():
+        snames.update(values)
+    setattr(cdb, snames_attr, snames)
+    cdb.is_dirty = True
+
+
+def unoptimise_cdb(cdb: CDB):
+    """This undoes all the (potential) memory optimisations done in `perform_optimisation`.
+
+    This method relies on `CDB._memory_optimised_parts` being up to date.
+
+    Args:
+        cdb (CDB): The CDB to work on.
+    """
+    if CUIS_PART in cdb._memory_optimised_parts:
+        _unoptimise(cdb, ONE2MANY, CUI_DICT_NAMES_TO_COMBINE)
+    if NAMES_PART in cdb._memory_optimised_parts:
+        _unoptimise(cdb, NAME2MANY, NAME_DICT_NAMES_TO_COMBINE)
+    if SNAMES_PART in cdb._memory_optimised_parts:
+        _unoptimise_snames(cdb)
+    cdb._memory_optimised_parts.clear()
+
+
+def map_to_many(dicts: List[Dict[str, Any]]) -> Tuple[Dict[str, List[Any]], List[DelegatingDict]]:
+    one2many: Dict[str, List[Any]] = {}
+    delegators: List[DelegatingDict] = []
+    for nr, d in enumerate(dicts):
+        delegator = DelegatingDict(
+            one2many, nr, nr_of_overall_items=len(dicts))
+        for key, value in d.items():
+            if key not in one2many:
+                one2many[key] = delegator._generate_empty_entry()
+            one2many[key][nr] = value
+        delegators.append(delegator)
+    return one2many, delegators
diff --git a/medcat/utils/preprocess_umls.py b/medcat/utils/preprocess_umls.py
index 9cf0ccea4..7c47f451a 100644
--- a/medcat/utils/preprocess_umls.py
+++ b/medcat/utils/preprocess_umls.py
@@ -3,7 +3,7 @@
 import pandas as pd
 import tqdm
 import os
-from typing import Dict, Set
+from typing import Dict

 _DEFAULT_COLUMNS: list = [
     "CUI",
@@ -240,7 +240,7 @@ def get_pt2ch(self) -> dict:
         cui_parent = cui_parent[cui_parent['PAUI'].notna()]

         # create dict
-        pt2ch: Dict[str, Set[str]] = {}
+        pt2ch: dict = {}
         for _, row in tqdm.tqdm(cui_parent.iterrows(), total=len(cui_parent.index)):
             cur_cui = row['CUI']
             paui = row['PAUI']
diff --git a/medcat/utils/saving/coding.py b/medcat/utils/saving/coding.py
new file mode 100644
index 000000000..c03e6816f
--- /dev/null
+++ b/medcat/utils/saving/coding.py
@@ -0,0 +1,146 @@
+from typing import Any, Protocol, runtime_checkable, List, Union, Type, Optional, Callable
+
+import json
+
+
+@runtime_checkable
+class EncodeableObject(Protocol):
+
+    def to_dict(self) -> dict:
+        """Converts the object to a dict.
+
+        Returns:
+            dict: The dict to be serialised.
+        """
+
+
+class UnsuitableObject(ValueError):
+    pass
+
+
+class PartEncoder(Protocol):
+
+    def try_encode(self, obj: object) -> Any:
+        """Try to encode an object.
+
+        Args:
+            obj (object): The object to encode.
+
+        Raises:
+            UnsuitableObject: If the object is unsuitable for encoding.
+
+        Returns:
+            Any: The encoded object.
+        """
+
+
+SET_IDENTIFIER = '==SET=='
+
+
+class SetEncoder(PartEncoder):
+    """JSONEncoder (and decoder) for sets.
+
+    Generally, JSON doesn't support serializing sets natively.
+    This encoder adds a set identifier to the data when serializing
+    and provides a method to read said identifier upon decoding."""
+
+    def try_encode(self, obj):
+        if isinstance(obj, set):
+            return {SET_IDENTIFIER: list(obj)}
+        raise UnsuitableObject()
+
+
+class PartDecoder(Protocol):
+
+    def try_decode(self, dct: dict) -> Union[dict, Any]:
+        """Try to decode the dictionary.
+
+        Args:
+            dct (dict): The dict to decode.
+
+        Returns:
+            Union[dict, Any]: The dict if unable to decode, the decoded object otherwise.
+        """
+
+
+class SetDecoder(PartDecoder):
+
+    def try_decode(self, dct: dict) -> Union[dict, set]:
+        """Decode sets from input dicts.
+
+        Args:
+            dct (dict): The input dict
+
+        Returns:
+            Union[dict, set]: The original dict if this was not a serialized set, the set otherwise
+        """
+        if SET_IDENTIFIER in dct:
+            return set(dct[SET_IDENTIFIER])
+        return dct
+
+
+PostProcessor = Callable[[Any], None]  # CDB -> None
+
+DEFAULT_ENCODERS: List[Type[PartEncoder]] = [SetEncoder, ]
+DEFAULT_DECODERS: List[Type[PartDecoder]] = [SetDecoder, ]
+LOADING_POSTPROCESSORS: List[PostProcessor] = []
+
+
+def register_encoder_decoder(encoder: Optional[Type[PartEncoder]],
+                             decoder: Optional[Type[PartDecoder]],
+                             loading_postprocessor: Optional[PostProcessor]):
+    if encoder:
+        DEFAULT_ENCODERS.append(encoder)
+    if decoder:
+        DEFAULT_DECODERS.append(decoder)
+    if loading_postprocessor:
+        LOADING_POSTPROCESSORS.append(loading_postprocessor)
+
+
+class CustomDelegatingEncoder(json.JSONEncoder):
+
+    def __init__(self, delegates: List[PartEncoder], *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self._delegates = delegates
+
+    def default(self, obj):
+        for delegator in self._delegates:
+            try:
+                return delegator.try_encode(obj)
+            except UnsuitableObject:
+                pass
+        return json.JSONEncoder.default(self, obj)
+
+    @classmethod
+    def def_inst(cls, *args, **kwargs) -> 'CustomDelegatingEncoder':
+        return cls([_cls() for _cls in DEFAULT_ENCODERS], *args, **kwargs)
+
+
+class CustomDelegatingDecoder(json.JSONDecoder):
+    _def_inst: Optional['CustomDelegatingDecoder'] = None
+
+    def __init__(self, delegates: List[PartDecoder]) -> None:
+        self._delegates = delegates
+
+    def object_hook(self, dct: dict) -> Any:
+        for delegator in self._delegates:
+            ret_val = delegator.try_decode(dct)
+            if ret_val is not dct:
+                return ret_val
+        return dct
+
+    @classmethod
+    def def_inst(cls) -> 'CustomDelegatingDecoder':
+        if cls._def_inst is None:
+            cls._def_inst = cls([_cls() for _cls in DEFAULT_DECODERS])
+        return cls._def_inst
+
+
+def default_hook(dct: dict) -> Any:
+    cdd = CustomDelegatingDecoder.def_inst()
+    return cdd.object_hook(dct)
+
+
+def default_postprocessing(cdb) -> None:
+    for pp in LOADING_POSTPROCESSORS:
+        pp(cdb)
diff --git a/medcat/utils/saving/serializer.py b/medcat/utils/saving/serializer.py
index c08124831..d82df751c 100644
--- a/medcat/utils/saving/serializer.py
+++ b/medcat/utils/saving/serializer.py
@@ -5,11 +5,13 @@
 """
 import os
 import logging
-from typing import cast, Dict, Optional, Union
+from typing import cast, Dict, Optional, Type
 import dill
 import json

 from medcat.config import Config
+from medcat.utils.saving.coding import CustomDelegatingEncoder, default_hook, default_postprocessing
+

 logger = logging.getLogger(__name__)

@@ -17,35 +19,8 @@
 __SPECIALITY_NAMES_NAME = set(
     ["name2cuis", "name2cuis2status", "name_isupper"])
 __SPECIALITY_NAMES_OTHER = set(["snames", "addl_info"])
-SPECIALITY_NAMES = __SPECIALITY_NAMES_CUI | __SPECIALITY_NAMES_NAME | __SPECIALITY_NAMES_OTHER
-
-
-class SetEncode(json.JSONEncoder):
-    """JSONEncoder (and decoder) for sets.
-
-    Generally, JSON doesn't support serializing of sets natively.
-    This encoder adds a set identifier to the data when being serialized
-    and provides a method to read said identifier upon decoding."""
-    SET_IDENTIFIER = '==SET=='
-
-    def default(self, obj):
-        if isinstance(obj, set):
-            return {SetEncode.SET_IDENTIFIER: list(obj)}
-        return json.JSONEncoder.default(self, obj)
-
-    @staticmethod
-    def set_decode(dct: dict) -> Union[dict, set]:
-        """Decode sets from input dicts.
-
-        Args:
-            dct (dict): The input dict
-
-        Returns:
-            Union[dict, set]: The original dict if this was not a serialized set, the set otherwise
-        """
-        if SetEncode.SET_IDENTIFIER in dct:
-            return set(dct[SetEncode.SET_IDENTIFIER])
-        return dct
+ONE2MANY = set(['cui2many', 'name2many'])  # these may or may not exist
+SPECIALITY_NAMES = __SPECIALITY_NAMES_CUI | __SPECIALITY_NAMES_NAME | __SPECIALITY_NAMES_OTHER | ONE2MANY


 class JsonSetSerializer:
@@ -75,7 +50,11 @@ def write(self, d: dict) -> None:
         logger.info('Writing data for "%s" into "%s"',
                     self.name, self.file_name)
         with open(self.file_name, 'w') as f:
-            json.dump(d, f, cls=SetEncode)
+            # def_inst is a classmethod, but json.dump only ever calls the
+            # provided cls, so calling it yields a suitable encoder instance
+            # (the cast just keeps the type checker happy)
+            json.dump(d, f, cls=cast(Type[json.JSONEncoder],
+                                     CustomDelegatingEncoder.def_inst))

     def read(self) -> dict:
         """Read the json file specified by this serializer.
@@ -85,7 +64,8 @@ def read(self) -> dict:
         """
         logger.info('Reading data for %s from %s', self.name, self.file_name)
         with open(self.file_name, 'r') as f:
-            data = json.load(f, object_hook=SetEncode.set_decode)
+            data = json.load(
+                f, object_hook=default_hook)
         return data


@@ -168,6 +148,8 @@ def serialize(self, cdb, overwrite: bool = False) -> None:
                 dill.dump(to_save, f)
         if self.jsons is not None:
             for name in SPECIALITY_NAMES:
+                if name not in cdb.__dict__:
+                    continue  # in case cui2many/name2many don't exist
                 self.jsons[name].write(cdb.__dict__[name])

     def deserialize(self, cdb_cls):
@@ -199,5 +181,10 @@ def deserialize(self, cdb_cls):
         # if applicable
         if self.jsons is not None:
             for name in SPECIALITY_NAMES:
+                if not os.path.exists(self.jsons[name].file_name):
+                    continue  # cui2many/name2many don't exist for a non-memory-optimised CDB
                 cdb.__dict__[name] = self.jsons[name].read()
+        # run anything that has been registered
+        # to postprocess the loaded CDB
+        default_postprocessing(cdb)
         return cdb
diff --git a/tests/utils/saving/test_coding.py b/tests/utils/saving/test_coding.py
new file mode 100644
index 000000000..c60a3b1f2
--- /dev/null
+++ b/tests/utils/saving/test_coding.py
@@ -0,0 +1,77 @@
+from medcat.utils.saving import coding
+
+import json
+
+import unittest
+
+
+class SetEncodeTests(unittest.TestCase):
+    string2sets_dict1 = {'s1': set(['v1', 'v2', 'v3']),
+                         's2': set(['u1', 'u2', 'u3'])}
+    string2sets_dict2 = {'p1': set([1, 2, 3]),
+                         'p2': set([3, 4, 5])}
+
+    def serialise(self, d: dict) -> str:
+        return json.dumps(d, cls=coding.CustomDelegatingEncoder.def_inst)
+
+    def _helper_serialises(self, d: dict):
+        s = self.serialise(d)
+        self.assertIsInstance(s, str)
+
+    def test_sets_of_strings_serialise(self):
+        self._helper_serialises(self.string2sets_dict1)
+
+    def test_sets_of_ints_serialise(self):
+        self._helper_serialises(self.string2sets_dict2)
+
+    def _helper_keys_in_json(self, d: dict):
+        s = self.serialise(d)
+        for k in d.keys():
+            with self.subTest(k):
+                self.assertIn(str(k), s)
+
+    def test_sos_keys_in_json(self):
+        self._helper_keys_in_json(self.string2sets_dict1)
+
+    def test_soi_keys_in_json(self):
+        self._helper_keys_in_json(self.string2sets_dict2)
+
+    def _helper_values_in_json(self, d: dict):
+        s = self.serialise(d)
+        for key, v in d.items():
+            for nr, el in enumerate(v):
+                with self.subTest(f"Key: {key}; Element {nr}"):
+                    self.assertIn(str(el), s)
+
+    def test_sos_values_in_json(self):
+        self._helper_values_in_json(self.string2sets_dict1)
+
+    def test_soi_values_in_json(self):
+        self._helper_values_in_json(self.string2sets_dict2)
+
+
+class SetDecodeTests(unittest.TestCase):
+
+    def deserialise(self, s: str) -> dict:
+        return json.loads(s, object_hook=coding.default_hook)
+
+    def setUp(self) -> None:
+        self.encoder = SetEncodeTests()
+        self.encoded1 = self.encoder.serialise(self.encoder.string2sets_dict1)
+        self.encoded2 = self.encoder.serialise(self.encoder.string2sets_dict2)
+
+    def test_sos_decodes(self):
+        d = self.deserialise(self.encoded1)
+        self.assertIsInstance(d, dict)
+
+    def test_soi_decodes(self):
+        d = self.deserialise(self.encoded2)
+        self.assertIsInstance(d, dict)
+
+    def test_sos_decodes_to_identical(self):
+        d = self.deserialise(self.encoded1)
+        self.assertEqual(d, self.encoder.string2sets_dict1)
+
+    def test_soi_decodes_to_identical(self):
+        d = self.deserialise(self.encoded2)
+        self.assertEqual(d, self.encoder.string2sets_dict2)
diff --git a/tests/utils/saving/test_serialization.py b/tests/utils/saving/test_serialization.py
index 6313906dc..f0cc75de1 100644
--- a/tests/utils/saving/test_serialization.py
+++ b/tests/utils/saving/test_serialization.py
@@ -9,11 +9,13 @@
 from medcat.cat import CAT
 from medcat.vocab import Vocab

-from medcat.utils.saving.serializer import JsonSetSerializer, CDBSerializer, SPECIALITY_NAMES
+from medcat.utils.saving.serializer import JsonSetSerializer, CDBSerializer, SPECIALITY_NAMES, ONE2MANY
+import medcat.utils.saving.coding as _

-class JSONSerialoizationTests(unittest.TestCase):
-    folder = os.path.join('temp', 'JSONSerialoizationTests')
+
+class JSONSerializationTests(unittest.TestCase):
+    folder = os.path.join('temp', 'JSONSerializationTests')

     def setUp(self) -> None:
         return super().setUp()
@@ -42,6 +44,11 @@ def test_round_trip(self):
         self.ser.serialize(self.cdb, overwrite=True)
         cdb = self.ser.deserialize(CDB)
         for name in SPECIALITY_NAMES:
+            if name in ONE2MANY:
+                # ignore cui2many and name2many
+                # since they don't exist if/when
+                # optimisation hasn't been done
+                continue
             with self.subTest(name):
                 orig = getattr(self.cdb, name)
                 now = getattr(cdb, name)
@@ -82,11 +89,19 @@ def test_dill_to_json(self):
         json_path = os.path.join(model_pack_folder, "*.json")
         jsons = glob.glob(json_path)
         # there is also a model_card.json,
-        self.assertGreaterEqual(len(jsons), len(SPECIALITY_NAMES))
+        # but nothing for cui2many or name2many,
+        # so we can subtract the length of ONE2MANY
+        self.assertGreaterEqual(len(jsons), len(
+            SPECIALITY_NAMES) - len(ONE2MANY))
         for json in jsons:
             with self.subTest(f'JSON {json}'):
                 if json.endswith('model_card.json'):
                     continue  # ignore model card here
+                if any(name in json for name in ONE2MANY):
+                    # ignore cui2many and name2many
+                    # since they don't exist if/when
+                    # optimisation hasn't been done
+                    continue
                 self.assertTrue(
                     any(special_name in json for special_name in SPECIALITY_NAMES))
         return model_pack_folder
@@ -128,6 +143,11 @@ def test_round_trip(self):
         self.assertEqual(cat.vocab.unigram_table, self.undertest.vocab.unigram_table)
         for name in SPECIALITY_NAMES:
+            if name in ONE2MANY:
+                # ignore cui2many and name2many
+                # since they don't exist if/when
+                # optimisation hasn't been done
+                continue
             with self.subTest(f'CDB Name {name}'):
                 self.assertEqual(cat.cdb.__dict__[
                     name], self.undertest.cdb.__dict__[name])
diff --git a/tests/utils/test_memory_optimiser.py b/tests/utils/test_memory_optimiser.py
new file mode 100644
index 000000000..5f59f5274
--- /dev/null
+++ b/tests/utils/test_memory_optimiser.py
@@ -0,0 +1,375 @@
+from medcat.utils import memory_optimiser
+
+import unittest
+import tempfile
+import os
+import shutil
+import json
+from medcat.cat import CAT
+from medcat.cdb import CDB
+from medcat.vocab import Vocab
+from medcat.utils.saving import coding
+
+
+class DelegatingDictTests(unittest.TestCase):
+    _dict = {'c1': [None, 0], 'c2': [1, None]}
+
+    def setUp(self) -> None:
+        # copy the values so that the original class-level dict remains unchanged
+        _dict = {k: v.copy() for k, v in self._dict.items()}
+        self.del_dict1 = memory_optimiser.DelegatingDict(_dict, 0, 2)
+        self.del_dict2 = memory_optimiser.DelegatingDict(_dict, 1, 2)
+        self.delegators = [self.del_dict1, self.del_dict2]
+        self.names = ['delegator 1', 'delegator 2']
+        self.expected_lens = [len(
+            [v[nr] for v in _dict.values() if v[nr] is not None]
+        ) for nr in range(len(_dict[list(_dict.keys())[0]]))]
+
+    def test_removal(self, key='c2'):
+        self.assertIn(key, self.del_dict1)
+        del self.del_dict1[key]
+        self.assertNotIn(key, self.del_dict1)
+
+    def test_pop_no_def_existing(self, key='c2'):
+        self.assertIn(key, self.del_dict1)
+        val = self.del_dict1.pop(key)
+        self.assertNotIn(key, self.del_dict1)
+        self.assertIs(val, self._dict[key][0])
+
+    def test_pop_def_non_existing(self, key='c1', def_val='DEF VAL'):
+        self.assertNotIn(key, self.del_dict1)
+        val = self.del_dict1.pop(key, def_val)
+        self.assertNotIn(key, self.del_dict1)
+        self.assertIs(val, def_val)
+
+    def test_adding_existing_key_nonexistent_value(self, key: str = 'c1'):
+        self.assertNotIn(key, self.del_dict1)
+        self.del_dict1[key] = 'value'
+        self.assertIn(key, self.del_dict1)
+
+    def test_adding_nonexisting_key(self, key: str = 'nek1'):
+        self.assertNotIn(key, self.del_dict1)
+        self.del_dict1[key] = 'value-NEW'
+        self.assertIn(key, self.del_dict1)
+
+    def test_adding_nonexisting_key_does_not_affect_other(self, key: str = 'nek2'):
+        self.assertNotIn(key, self.del_dict2)
+        self.del_dict1[key] = 'value-NEW-2'
+        self.assertNotIn(key, self.del_dict2)
+
+    def test_delegating_dict_has_correct_keys(self):
+        for delegator, exp_len, name in zip(self.delegators, self.expected_lens, self.names):
+            with self.subTest(name):
+                self.assertEqual(len(delegator.keys()), exp_len)
+
+    def test_delegating_dict_has_same_number_of_keys_and_values(self):
+        for delegator, exp_len, name in zip(self.delegators, self.expected_lens, self.names):
+            with self.subTest(name):
+                self.assertEqual(len(delegator.keys()), exp_len)
+                self.assertEqual(len(delegator.values()), exp_len)
+
+    def test_delegating_dict_has_same_number_of_items_and_iter_values(self):
+        for delegator, exp_len, name in zip(self.delegators, self.expected_lens, self.names):
+            with self.subTest(name):
+                self.assertEqual(len(delegator.items()), exp_len)
+                # __iter__ -> list -> len
+                self.assertEqual(len(list(delegator)), exp_len)
+
+    def test_delegators_do_not_have_None_values(self):
+        for delegator, name in zip(self.delegators, self.names):
+            for key, val in delegator.items():
+                with self.subTest(f"{name}: {key}"):
+                    self.assertIsNotNone(val)
+
+    def test_delegator_keys_in_original(self):
+        for delegator, name in zip(self.delegators, self.names):
+            for key in delegator.keys():
+                with self.subTest(f"{name}: {key}"):
+                    self.assertIn(key, self._dict)
+
+    def test_delegator_keys_in_container(self):
+        for delegator, name in zip(self.delegators, self.names):
+            for key in delegator.keys():
+                with self.subTest(f"{name}: {key}"):
+                    self.assertIn(key, delegator)
+
+    def test_delegator_get_gets_key(self, def_value='#DEFAULT#'):
+        for delegator, name in zip(self.delegators, self.names):
+            for key in delegator.keys():
+                with self.subTest(f"{name}: {key}"):
+                    val = delegator.get(key, def_value)
+                    self.assertIsNot(val, def_value)
+
+    def test_delegator_get_defaults_nonexistent_key(self,
+                                                    def_value='#DEFAULT#'):
+        for delegator, name in zip(self.delegators, self.names):
+            for key in self._dict.keys():
+                if key in delegator:
+                    continue
+                with self.subTest(f"{name}: {key}"):
+                    val = delegator.get(key, def_value)
+                    self.assertIs(val, def_value)
+
+
+class DelegatingDictJsonTests(unittest.TestCase):
+    _dict = {'c5': [None, 10], 'c6': [11, None]}
+
+    def setUp(self) -> None:
+        self.del_dict1 = memory_optimiser.DelegatingDict(self._dict, 0, 2)
+        self.del_dict2 = memory_optimiser.DelegatingDict(self._dict, 1, 2)
+        self.delegators = [self.del_dict1, self.del_dict2]
+        self.master_dict = {'one2many': self._dict,
+                            'part1': self.del_dict1,
+                            'part2': self.del_dict2}
+
+    def serialise_master(self) -> str:
+        return json.dumps(self.master_dict,
+                          cls=coding.CustomDelegatingEncoder.def_inst)
+
+    def deserialise(self, s: str, one2many_name='one2many') -> dict:
+        d = json.loads(s, object_hook=coding.default_hook)
+        one2many = d[one2many_name]
+        for key, value in d.items():
+            if key == one2many_name:
+                continue
+            if value.delegate is None:
+                value.delegate = one2many
+        return d
+
+    def test_dict_of_delegation_serialises(self):
+        s = self.serialise_master()
+        self.assertIsInstance(s, str)
+
+    def test_dod_ser_has_keys(self):
+        s = self.serialise_master()
+        for key in self.master_dict:
+            with self.subTest(key):
+                self.assertIn(key, s)
+
+    def test_dod_ser_one2many_has_sub_keys(self):
+        s = self.serialise_master()
+        for key in self.master_dict['one2many']:
+            with self.subTest(key):
+                self.assertIn(key, s)
+
+    def test_round_trip(self):
+        s = self.serialise_master()
+        d = self.deserialise(s)
+        self.assertIsInstance(d, dict)
+
+    def test_round_trip_equal(self):
+        s = self.serialise_master()
+        d = self.deserialise(s)
+        self.assertEqual(d, self.master_dict)
+
+
+class UnOptimisingTests(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.cdb = CDB.load(os.path.join(os.path.dirname(
+            os.path.realpath(__file__)), "..", "..", "examples", "cdb.dat"))
+
+    def test_unoptimised_cdb_does_not_have_cui2many(self):
+        self.assertFalse(hasattr(self.cdb, 'cui2many'))
+
+    def test_unoptimised_cdb_does_not_have_delegating_dicts(self):
+        for key, val in self.cdb.__dict__.items():
+            with self.subTest(key):
+                self.assertNotIsInstance(val, memory_optimiser.DelegatingDict)
+
+    def test_unoptimised_knows_has_no_optimised_parts(self):
+        self.assertFalse(self.cdb._memory_optimised_parts,
+                         "Should have empty optimised parts")
+
+    def test_simply_loaded_model_not_dirty(self):
+        self.assertFalse(self.cdb.is_dirty)
+
+
+class MemoryOptimisingTests(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.cdb = CDB.load(os.path.join(os.path.dirname(
+            os.path.realpath(__file__)), "..", "..", "examples", "cdb.dat"))
+        memory_optimiser.perform_optimisation(cls.cdb, optimise_snames=True)
+
+    def test_is_dirty(self):
+        self.assertTrue(self.cdb.is_dirty,
+                        "Should be dirty after optimisation")
+
+    def test_knows_optimised(self):
+        self.assertTrue(self.cdb._memory_optimised_parts,
+                        "Should have non-empty `_memory_optimised_parts`")
+
+    def test_knows_correct_parts_optimised(self, should_be=['CUIS', 'snames']):
+        for name in should_be:
+            with self.subTest(name):
+                self.assertIn(name, self.cdb._memory_optimised_parts)
+
+    def test_knows_incorrect_parts_NOT_optimised(self, should_not_be=['NAMES']):
+        for name in should_not_be:
+            with self.subTest(name):
+                self.assertNotIn(name, self.cdb._memory_optimised_parts)
+
+    def test_cdb_has_one2many(self, one2many_name='cui2many'):
+        self.assertTrue(hasattr(self.cdb, one2many_name))
+        one2many = getattr(self.cdb, one2many_name)
+        self.assertIsInstance(one2many, dict)
+
+    def test_cdb_has_delegating_dicts(self):
+        for dict_name in memory_optimiser.CUI_DICT_NAMES_TO_COMBINE:
+            with self.subTest(dict_name):
+                d = getattr(self.cdb, dict_name)
+                self.assertIsInstance(d, memory_optimiser.DelegatingDict)
+
+    def test_has_delegating_set(self):
+        self.assertIsInstance(
+            self.cdb.snames, memory_optimiser.DelegatingValueSet)
+
+    def test_delegating_set_has_values(self):
+        for values in self.cdb.cui2snames.values():
+            for val in values:
+                with self.subTest(f'Checking {val}'):
+                    self.assertIn(val, self.cdb.snames)
+
+
+class MemoryUnoptimisingTests(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.cdb = CDB.load(os.path.join(os.path.dirname(
+            os.path.realpath(__file__)), "..", "..", "examples", "cdb.dat"))
+
+    def test_optimisation_round_trip_cuis(self):
+        cui_dicts_before = [getattr(self.cdb, dict_name)
+                            for dict_name in memory_optimiser.CUI_DICT_NAMES_TO_COMBINE]
+        memory_optimiser.perform_optimisation(self.cdb)
+        memory_optimiser.unoptimise_cdb(self.cdb)
+        cui_dicts_after = [getattr(self.cdb, dict_name)
+                           for dict_name in memory_optimiser.CUI_DICT_NAMES_TO_COMBINE]
+        for before, after, name in zip(cui_dicts_before,
+                                       cui_dicts_after,
+                                       memory_optimiser.CUI_DICT_NAMES_TO_COMBINE):
+            with self.subTest(f'{name}'):
+                self.assertIsInstance(before, dict)
+                self.assertIsInstance(after, dict)
+                self.assertEqual(len(before), len(after))
+                self.assertEqual(before, after)
+
+    def test_optimisation_round_trip_snames(self):
+        snames_before = self.cdb.snames
+        memory_optimiser.perform_optimisation(self.cdb)
+        memory_optimiser.unoptimise_cdb(self.cdb)
+        snames_after = self.cdb.snames
+        self.assertIsInstance(snames_before, set)
+        self.assertIsInstance(snames_after, set)
+        self.assertEqual(len(snames_before), len(snames_after))
+        self.assertEqual(snames_before, snames_after)
+
+    def test_optimisation_round_trip_dirty(self):
+        memory_optimiser.perform_optimisation(self.cdb)
+        memory_optimiser.unoptimise_cdb(self.cdb)
+        self.assertTrue(self.cdb.is_dirty)
+
+    def test_optimisation_round_trip_no_optimised_parts(self):
+        memory_optimiser.perform_optimisation(self.cdb)
+        memory_optimiser.unoptimise_cdb(self.cdb)
+        self.assertFalse(self.cdb._memory_optimised_parts,
+                         "Should have no optimised parts")
+
+
+class OperationalTests(unittest.TestCase):
+    temp_folder = tempfile.TemporaryDirectory()
+    temp_cdb_path = os.path.join(temp_folder.name, 'cat.cdb')
+    json_path = temp_cdb_path.rsplit(os.path.sep, 1)[0]
+    # importing here so it's in the local namespace
+    # otherwise, all of its parts would get run again
+    from tests.test_cat import CATTests
+    test_callable_with_single_text = CATTests.test_callable_with_single_text
+    test_callable_with_single_empty_text = CATTests.test_callable_with_single_empty_text
+    test_callable_with_single_none_text = CATTests.test_callable_with_single_none_text
+    test_get_entities = CATTests.test_get_entities
+    test_get_entities_including_text = CATTests.test_get_entities_including_text
+    test_get_entities_multi_texts = CATTests.test_get_entities_multi_texts
+    test_get_entities_multi_texts_including_text = CATTests.test_get_entities_multi_texts_including_text
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.cdb = CDB.load(os.path.join(os.path.dirname(
+            os.path.realpath(__file__)), "..", "..", "examples", "cdb.dat"))
+        memory_optimiser.perform_optimisation(cls.cdb, optimise_snames=True)
+        cls.vocab = Vocab.load(os.path.join(os.path.dirname(
+            os.path.realpath(__file__)), "..", "..", "examples", "vocab.dat"))
+        cls.cdb.config.general.spacy_model = "en_core_web_md"
+        cls.cdb.config.ner.min_name_len = 2
+        cls.cdb.config.ner.upper_case_limit_len = 3
+        cls.cdb.config.general.spell_check = True
+        cls.cdb.config.linking.train_count_threshold = 10
+        cls.cdb.config.linking.similarity_threshold = 0.3
+        cls.cdb.config.linking.train = True
+        cls.cdb.config.linking.disamb_length_limit = 5
+        cls.cdb.config.general.full_unlink = True
+        cls.meta_cat_dir = os.path.join(
+            os.path.dirname(os.path.realpath(__file__)), "tmp")
+        cls.undertest = CAT(cdb=cls.cdb, config=cls.cdb.config,
+                            vocab=cls.vocab, meta_cats=[])
+        cls._linkng_filters = cls.undertest.config.linking.filters.copy_of()
+
+    @classmethod
+    def tearDownClass(cls) -> None:
+        cls.temp_folder.cleanup()
+        cls.undertest.destroy_pipe()
+        if os.path.exists(cls.meta_cat_dir):
+            shutil.rmtree(cls.meta_cat_dir)
+
+    def tearDown(self) -> None:
+        self.cdb.config.annotation_output.include_text_in_output = False
+        # need to make sure linking filters are not retained beyond a test scope
+        self.undertest.config.linking.filters = self._linkng_filters.copy_of()
+
+    def test_optimised_cdb_has_cui2many(self):
+        self.assertTrue(hasattr(self.cdb, 'cui2many'))
+
+    def test_can_be_saved_as_json(self):
+        self.cdb.save(self.temp_cdb_path, json_path=self.json_path)
+
+    def test_can_be_loaded_as_json(self):
+        self.test_can_be_saved_as_json()
+        cdb = CDB.load(self.temp_cdb_path, self.json_path)
+        self.assertEqual(self.cdb.cui2many, cdb.cui2many)
+        for del_name in memory_optimiser.CUI_DICT_NAMES_TO_COMBINE:
+            d = getattr(cdb, del_name)
+            with self.subTest(del_name):
+                self.assertIsInstance(d, memory_optimiser.DelegatingDict)
+                self.assertIs(cdb.cui2many, d.delegate)
+
+
+class DelegatingValueSetTests(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.delegate = {'a': set('abcd'),
+                         'b': set('efghij'),
+                         'c': set('lm'),  # skip k
+                         'd': set('qrst'),  # skip a bunch
+                         }
+        self.original = set(v for s in self.delegate.values() for v in s)
+
+    def test_DelegatingValueSet_constructs(self):
+        dvs = memory_optimiser.DelegatingValueSet(self.delegate)
+        self.assertIsInstance(dvs, memory_optimiser.DelegatingValueSet)
+
+    def test_DelegatingValueSet_contains_values(self):
+        dvs = memory_optimiser.DelegatingValueSet(self.delegate)
+        for v in self.original:
+            with self.subTest(f'Check: {v}'):
+                self.assertIn(v, dvs)
+
+    def test_DelegatingValueSet_contains_incorrect_values(self,
+                                                          to_check=set('kopuvwxyz')):
+        dvs = memory_optimiser.DelegatingValueSet(self.delegate)
+        for v in to_check:
+            with self.subTest(f'Check: {v}'):
+                self.assertNotIn(v, dvs)
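Taken together, a minimal end-to-end sketch of the new API (paths are hypothetical; the calls mirror the tests above):

    from medcat.cdb import CDB
    from medcat.utils.memory_optimiser import perform_optimisation, unoptimise_cdb

    cdb = CDB.load('examples/cdb.dat')           # hypothetical path
    perform_optimisation(cdb)                    # cui2<...> dicts now share one cui2many dict
    cdb.save('temp/cat.cdb', json_path='temp')   # DelegatingDicts serialise via the new coding module
    cdb2 = CDB.load('temp/cat.cdb', 'temp')      # loading postprocessors re-link the delegates
    unoptimise_cdb(cdb2)                         # restores plain dicts and the snames set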