Merge branch 'master' of github.com:bramiozo/MedCAT
bramiozo committed Mar 5, 2024
2 parents 588688f + 67f1126 commit 82ad983
Showing 9 changed files with 149 additions and 19 deletions.
17 changes: 10 additions & 7 deletions medcat/cat.py
@@ -334,6 +334,7 @@ def attempt_unpack(cls, zip_path: str) -> str:
def load_model_pack(cls,
zip_path: str,
meta_cat_config_dict: Optional[Dict] = None,
ner_config_dict: Optional[Dict] = None,
load_meta_models: bool = True,
load_addl_ner: bool = True) -> "CAT":
"""Load everything within the 'model pack', i.e. the CDB, config, vocab and any MetaCAT models
@@ -346,6 +347,10 @@ def load_model_pack(cls,
A config dict that will overwrite existing configs in meta_cat.
e.g. meta_cat_config_dict = {'general': {'device': 'cpu'}}.
Defaults to None.
ner_config_dict (Optional[Dict]):
A config dict that will overwrite existing configs in transformers ner.
e.g. ner_config_dict = {'general': {'chunking_overlap_window': 6}}.
Defaults to None.
load_meta_models (bool):
Whether to load MetaCAT models if present (Default value True).
load_addl_ner (bool):
@@ -381,18 +386,16 @@ def load_model_pack(cls,
else:
vocab = None

# Find meta models in the model_pack
trf_paths = [os.path.join(model_pack_path, path)
for path in os.listdir(model_pack_path) if path.startswith('trf_')] if load_addl_ner else []
# Find ner models in the model_pack
trf_paths = [os.path.join(model_pack_path, path) for path in os.listdir(model_pack_path) if path.startswith('trf_')] if load_addl_ner else []
addl_ner = []
for trf_path in trf_paths:
trf = TransformersNER.load(save_dir_path=trf_path)
trf = TransformersNER.load(save_dir_path=trf_path, config_dict=ner_config_dict)
trf.cdb = cdb # Set the cat.cdb to be the CDB of the TRF model
addl_ner.append(trf)

# Find meta models in the model_pack
meta_paths = [os.path.join(model_pack_path, path)
for path in os.listdir(model_pack_path) if path.startswith('meta_')] if load_meta_models else []
# Find metacat models in the model_pack
meta_paths = [os.path.join(model_pack_path, path) for path in os.listdir(model_pack_path) if path.startswith('meta_')] if load_meta_models else []
meta_cats = []
for meta_path in meta_paths:
meta_cats.append(MetaCAT.load(save_dir_path=meta_path,
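The new ner_config_dict argument mirrors meta_cat_config_dict: it is forwarded to TransformersNER.load and overwrites matching keys in the packed NER config. A minimal usage sketch (the model pack filename is a placeholder):

from medcat.cat import CAT

# Placeholder pack name; applies to any pack that contains a trf_ model.
cat = CAT.load_model_pack(
    "deid_model_pack.zip",
    ner_config_dict={'general': {'chunking_overlap_window': 6}},
)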
3 changes: 1 addition & 2 deletions medcat/config.py
@@ -1,8 +1,7 @@
from datetime import datetime
from pydantic import BaseModel, Extra, ValidationError
from pydantic.dataclasses import Any, Callable, Dict, Optional, Union
from pydantic.fields import ModelField
from typing import List, Set, Tuple, cast
from typing import List, Set, Tuple, cast, Any, Callable, Dict, Optional, Union
from multiprocessing import cpu_count
import logging
import jsonpickle
2 changes: 2 additions & 0 deletions medcat/config_transformers_ner.py
@@ -13,6 +13,8 @@ class General(MixingConfig, BaseModel):
"""How many characters are piped at once into the meta_cat class"""
ner_aggregation_strategy: str = 'simple'
"""Agg strategy for HF pipeline for NER"""
chunking_overlap_window: Optional[int] = 5
"""Size of the overlap window used for chunking"""
test_size: float = 0.2
last_train_on: Optional[int] = None
verbose_metrics: bool = False
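A short sketch of setting the new option on a fresh config, in the same style the updated test below uses (the cdb variable is assumed to be a previously built CDB):

from medcat.ner import transformers_ner

config = transformers_ner.ConfigTransformersNER()
config.general['chunking_overlap_window'] = 6  # tokens of overlap between adjacent chunks
ner = transformers_ner.TransformersNER(cdb=cdb, config=config)  # cdb assumed to exist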
6 changes: 4 additions & 2 deletions medcat/ner/transformers_ner.py
@@ -76,9 +76,11 @@ def __init__(self, cdb, config: Optional[ConfigTransformersNER] = None,
else:
self.training_arguments = training_arguments


def create_eval_pipeline(self):
self.ner_pipe = pipeline(model=self.model, task="ner", tokenizer=self.tokenizer.hf_tokenizer)

if self.config.general['chunking_overlap_window'] is None:
logger.warning("Chunking overlap window attribute in the config is set to None, hence chunking is disabled. Be cautious, PII data MAY BE REVEALED. To enable chunking, set the value to 0 or above.")
self.ner_pipe = pipeline(model=self.model, task="ner", tokenizer=self.tokenizer.hf_tokenizer,stride=self.config.general['chunking_overlap_window'])
if not hasattr(self.ner_pipe.tokenizer, '_in_target_context_manager'):
# NOTE: this will fix the DeID model(s) created before medcat 1.9.3
# though this fix may very well be unstable
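When stride is set, the Hugging Face token-classification pipeline splits texts longer than the model's maximum length into overlapping chunks (sharing stride tokens) rather than truncating, which is why a None value risks leaving PII beyond the first 512 tokens unprocessed and thus unredacted. A minimal standalone sketch with a placeholder checkpoint name:

from transformers import pipeline

# "some-ner-checkpoint" is a placeholder for any token-classification model.
ner_pipe = pipeline(task="ner", model="some-ner-checkpoint",
                    aggregation_strategy="simple", stride=5)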
30 changes: 27 additions & 3 deletions medcat/utils/ner/deid.py
@@ -34,14 +34,18 @@
- config
- cdb
"""
from typing import Union, Tuple, Any, List, Iterable, Optional
from typing import Union, Tuple, Any, List, Iterable, Optional, Dict
import logging

from medcat.cat import CAT
from medcat.utils.ner.model import NerModel

from medcat.utils.ner.helpers import _deid_text as deid_text, replace_entities_in_text


logger = logging.getLogger(__name__)


class DeIdModel(NerModel):
"""The DeID model.
@@ -93,6 +97,25 @@ def deid_multi_texts(self,
Returns:
List[str]: List of deidentified documents.
"""
# NOTE: we assume we're using the 1st (and generally only)
# additional NER model.
# the same assumption is made in the `train` method
chunking_overlap_window = self.cat._addl_ner[0].config.general.chunking_overlap_window
if chunking_overlap_window is not None:
logger.warning("Chunking overlap window has been set to %s. "
"This may cause multiprocessing to stall in certain"
"environments and/or situations and has not been"
"fully tested.",
chunking_overlap_window)
logger.warning("If the following hangs forever (i.e doesn't finish) "
"but you still wish to run on multiple processes you can set "
"`cat._addl_ner[0].config.general.chunking_overlap_window = None` "
"and then either a) save the model on disk and load it back up, or "
" b) call `cat._addl_ner[0].create_eval_pipeline()` to recreate the pipe. "
"However, this will remove chunking from the input text, which means "
"only the first 512 tokens will be recognised and thus only the "
"first part of longer documents (those with more than 512) tokens"
"will be deidentified. ")
entities = self.cat.get_entities_multi_texts(texts, addl_info=addl_info,
n_process=n_process, batch_size=batch_size)
out = []
@@ -110,7 +133,7 @@ def deid_multi_texts(self,
return out

@classmethod
def load_model_pack(cls, model_pack_path: str) -> 'DeIdModel':
def load_model_pack(cls, model_pack_path: str, config: Optional[Dict] = None) -> 'DeIdModel':
"""Load DeId model from model pack.
The method first loads the CAT instance.
@@ -119,6 +142,7 @@ def load_model_pack(cls, model_pack_path: str) -> 'DeIdModel':
valid DeId model.
Args:
config (Optional[Dict]): Config for the DeId model pack (primarily for the stride of the chunking overlap window).
model_pack_path (str): The model pack path.
Raises:
@@ -127,7 +151,7 @@ def load_model_pack(cls, model_pack_path: str) -> 'DeIdModel':
Returns:
DeIdModel: The resulting DeId model.
"""
ner_model = NerModel.load_model_pack(model_pack_path)
ner_model = NerModel.load_model_pack(model_pack_path, config=config)
cat = ner_model.cat
if not cls._is_deid_model(cat):
raise ValueError(
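The workaround described in the warning, as a runnable sketch (the pack name and input text are placeholders):

from medcat.utils.ner.deid import DeIdModel

deid_model = DeIdModel.load_model_pack("deid_model_pack.zip")
# Disable chunking, then rebuild the pipeline in place (option b from the warning above).
deid_model.cat._addl_ner[0].config.general.chunking_overlap_window = None
deid_model.cat._addl_ner[0].create_eval_pipeline()
anonymised = deid_model.deid_multi_texts(["Patient John Smith attended ..."], n_process=2)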
7 changes: 4 additions & 3 deletions medcat/utils/ner/model.py
@@ -1,4 +1,4 @@
from typing import Any, List, Tuple, Union, Optional
from typing import Any, List, Tuple, Union, Optional, Dict

from spacy.tokens import Doc

@@ -94,16 +94,17 @@ def create(cls, ner: Union[TransformersNER, List[TransformersNER]]) -> 'NerModel
return cls(cat)

@classmethod
def load_model_pack(cls, model_pack_path: str) -> 'NerModel':
def load_model_pack(cls, model_pack_path: str, config: Optional[Dict] = None) -> 'NerModel':
"""Load NER model from model pack.
The method first wraps the loaded CAT instance.
Args:
config (Optional[Dict]): Config for the model pack (primarily for the stride of the chunking overlap window).
model_pack_path (str): The model pack path.
Returns:
NerModel: The resulting NER model.
"""
cat = CAT.load_model_pack(model_pack_path)
cat = CAT.load_model_pack(model_pack_path, ner_config_dict=config)
return cls(cat)
37 changes: 36 additions & 1 deletion medcat/utils/preprocess_snomed.py
@@ -35,6 +35,32 @@ def get_all_children(sctid, pt2ch):
return result


def get_direct_refset_mapping(in_dict: dict) -> dict:
"""This method uses the output from Snomed.map_snomed2icd10 or
Snomed.map_snomed2opcs4 and removes the metadata and maps each
SNOMED CUI to the prioritised list of the target ontology CUIs.
The input dict is expected to be in the following format:
- Keys are SnomedCT CUIs
- The values are lists of dictionaries, each list item (at least)
- Has a key 'code' that specifies the target ontology CUI
- Has a key 'mapPriority' that specifies the priority
Args:
in_dict (dict): The input dict.
Returns:
dict: The map from SNOMED CUI to a priority-ordered list of target ontology CUIs.
"""
ret_dict = dict()
for k, vals in in_dict.items():
# sort such that highest priority values are first
svals = sorted(vals, key=lambda el: el['mapPriority'], reverse=True)
# only keep the code / CUI
ret_dict[k] = [v['code'] for v in svals]
return ret_dict


class Snomed:
"""
Pre-process SNOMED CT release files.
@@ -53,6 +79,15 @@ def __init__(self, data_path, uk_ext=False, uk_drug_ext=False):
self.release = data_path[-16:-8]
self.uk_ext = uk_ext
self.uk_drug_ext = uk_drug_ext
self.opcs_refset_id = "1126441000000105"
if ((self.uk_ext or self.uk_drug_ext) and
# using lexicographical comparison below
# e.g "20240101" > "20231122" results in True
# yet "20231121" > "20231122" reults in False
len(self.release) == len("20231122") and self.release >= "20231122"):
# NOTE for UK extensions starting from 20231122 the
# OPCS4 refset ID seems to be different
self.opcs_refset_id = '1382401000000109'

def to_concept_df(self):
"""
@@ -398,7 +433,7 @@ def _map_snomed2refset(self):
mapping_df = pd.concat(dfs2merge)
del dfs2merge
if self.uk_ext or self.uk_drug_ext:
opcs_df = mapping_df[mapping_df['refsetId'] == '1126441000000105']
opcs_df = mapping_df[mapping_df['refsetId'] == self.opcs_refset_id]
icd10_df = mapping_df[mapping_df['refsetId']
== '999002271000000101']
return icd10_df, opcs_df
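get_direct_refset_mapping is intended to post-process the refset maps produced above; a sketch under the assumption that map_snomed2icd10 returns the documented {SCUI: [{'code': ..., 'mapPriority': ...}, ...]} shape (the release directory name is a placeholder):

from medcat.utils import preprocess_snomed

snomed = preprocess_snomed.Snomed("SnomedCT_UKClinicalRF2_PRODUCTION_20231122T000001Z", uk_ext=True)
sct2icd10 = snomed.map_snomed2icd10()
direct = preprocess_snomed.get_direct_refset_mapping(sct2icd10)
# direct: {SCUI: [highest-priority code, ..., lowest-priority code]}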
2 changes: 1 addition & 1 deletion tests/utils/ner/test_deid.py
@@ -41,11 +41,11 @@ def test_can_create_model(self):
deid_model = deid.DeIdModel.create(ner)
self.assertIsNotNone(deid_model)


def _add_model(cls):
cdb = make_or_update_cdb(TRAIN_DATA)
config = transformers_ner.ConfigTransformersNER()
config.general['test_size'] = 0.1 # Usually set this to 0.1-0.2
config.general['chunking_overlap_window'] = None
cls.ner = transformers_ner.TransformersNER(cdb=cdb, config=config)
cls.ner.training_arguments.num_train_epochs = 1 # Use 5-10 normally
# As we are NOT training on a GPU that can, we'll set it to 1
64 changes: 64 additions & 0 deletions tests/utils/test_preprocess_snomed.py
@@ -0,0 +1,64 @@
from typing import Dict
from medcat.utils import preprocess_snomed

import unittest


EXAMPLE_REFSET_DICT: Dict = {
'SCUI1': [
{'code': 'TCUI1', 'mapPriority': '1'},
{'code': 'TCUI2', 'mapPriority': '2'},
{'code': 'TCUI3', 'mapPriority': '3'},
]
}

# in order from highest priority to lowest
EXPECTED_DIRECT_MAPPINGS = {"SCUI1": ['TCUI3', 'TCUI2', 'TCUI1']}

EXAMPLE_REFSET_DICT_WITH_EXTRAS = dict(
(k, [dict(v, otherKey=f"val-{k}") for v in vals]) for k, vals in EXAMPLE_REFSET_DICT.items())

EXAMPLE_REFSET_DICT_NO_PRIORITY = dict(
(k, [{ik: iv for ik, iv in v.items() if ik != 'mapPriority'} for v in vals]) for k, vals in EXAMPLE_REFSET_DICT.items()
)

EXAMPLE_REFSET_DICT_NO_CODE = dict(
(k, [{ik: iv for ik, iv in v.items() if ik != 'code'} for v in vals]) for k, vals in EXAMPLE_REFSET_DICT.items()
)


class DirectMappingTest(unittest.TestCase):

def test_example_gets_direct_mappings(self):
res = preprocess_snomed.get_direct_refset_mapping(EXAMPLE_REFSET_DICT)
self.assertEqual(res, EXPECTED_DIRECT_MAPPINGS)

def test_example_w_extras_gets_direct_mappings(self):
res = preprocess_snomed.get_direct_refset_mapping(EXAMPLE_REFSET_DICT_WITH_EXTRAS)
self.assertEqual(res, EXPECTED_DIRECT_MAPPINGS)

def test_example_no_priority_fails(self):
with self.assertRaises(KeyError):
preprocess_snomed.get_direct_refset_mapping(EXAMPLE_REFSET_DICT_NO_PRIORITY)

def test_example_no_code_fails(self):
with self.assertRaises(KeyError):
preprocess_snomed.get_direct_refset_mapping(EXAMPLE_REFSET_DICT_NO_CODE)

EXAMPLE_SNOMED_PATH_OLD = "SnomedCT_InternationalRF2_PRODUCTION_20220831T120000Z"
EXAMPLE_SNOMED_PATH_NEW = "SnomedCT_UKClinicalRF2_PRODUCTION_20231122T000001Z"


class TestSnomedVersionsOPCS4(unittest.TestCase):

def test_old_gets_old_OPCS4_mapping_nonuk_ext(self):
snomed = preprocess_snomed.Snomed(EXAMPLE_SNOMED_PATH_OLD, uk_ext=False)
self.assertEqual(snomed.opcs_refset_id, "1126441000000105")

def test_old_gets_old_OPCS4_mapping_uk_ext(self):
snomed = preprocess_snomed.Snomed(EXAMPLE_SNOMED_PATH_OLD, uk_ext=True)
self.assertEqual(snomed.opcs_refset_id, "1126441000000105")

def test_new_gets_new_OPCS4_mapping_uk_ext(self):
snomed = preprocess_snomed.Snomed(EXAMPLE_SNOMED_PATH_NEW, uk_ext=True)
self.assertEqual(snomed.opcs_refset_id, "1382401000000109")
