From 67f11266f2b15a187080037331d28390e6b6d954 Mon Sep 17 00:00:00 2001 From: Shubham Agarwal <66172189+shubham-s-agarwal@users.noreply.github.com> Date: Wed, 28 Feb 2024 16:29:16 +0000 Subject: [PATCH] CU-8693qx9yp Deid chunking - hugging face pipeline approach (#405) * Pushing chunking update * Update transformers_ner.py * Pushing update to config Added NER config in cat load function * Update cat.py * Updating chunking overlap * CU-8693qx9yp: Add warning for deid multiprocessing with (potentially) non-functioning chunking window * CU-8693qx9yp: Fix linting issue --------- Co-authored-by: mart-r --- medcat/cat.py | 11 ++++++++--- medcat/config_transformers_ner.py | 2 ++ medcat/ner/transformers_ner.py | 6 ++++-- medcat/utils/ner/deid.py | 30 +++++++++++++++++++++++++++--- medcat/utils/ner/model.py | 7 ++++--- tests/utils/ner/test_deid.py | 2 +- 6 files changed, 46 insertions(+), 12 deletions(-) diff --git a/medcat/cat.py b/medcat/cat.py index 9159eddd8..d36e30611 100644 --- a/medcat/cat.py +++ b/medcat/cat.py @@ -334,6 +334,7 @@ def attempt_unpack(cls, zip_path: str) -> str: def load_model_pack(cls, zip_path: str, meta_cat_config_dict: Optional[Dict] = None, + ner_config_dict: Optional[Dict] = None, load_meta_models: bool = True, load_addl_ner: bool = True) -> "CAT": """Load everything within the 'model pack', i.e. the CDB, config, vocab and any MetaCAT models @@ -346,6 +347,10 @@ def load_model_pack(cls, A config dict that will overwrite existing configs in meta_cat. e.g. meta_cat_config_dict = {'general': {'device': 'cpu'}}. Defaults to None. + ner_config_dict (Optional[Dict]): + A config dict that will overwrite existing configs in transformers ner. + e.g. ner_config_dict = {'general': {'chunking_overlap_window': 6}. + Defaults to None. load_meta_models (bool): Whether to load MetaCAT models if present (Default value True). 
load_addl_ner (bool): @@ -381,15 +386,15 @@ def load_model_pack(cls, else: vocab = None - # Find meta models in the model_pack + # Find ner models in the model_pack trf_paths = [os.path.join(model_pack_path, path) for path in os.listdir(model_pack_path) if path.startswith('trf_')] if load_addl_ner else [] addl_ner = [] for trf_path in trf_paths: - trf = TransformersNER.load(save_dir_path=trf_path) + trf = TransformersNER.load(save_dir_path=trf_path,config_dict=ner_config_dict) trf.cdb = cdb # Set the cat.cdb to be the CDB of the TRF model addl_ner.append(trf) - # Find meta models in the model_pack + # Find metacat models in the model_pack meta_paths = [os.path.join(model_pack_path, path) for path in os.listdir(model_pack_path) if path.startswith('meta_')] if load_meta_models else [] meta_cats = [] for meta_path in meta_paths: diff --git a/medcat/config_transformers_ner.py b/medcat/config_transformers_ner.py index 64435e9cb..9f3102acb 100644 --- a/medcat/config_transformers_ner.py +++ b/medcat/config_transformers_ner.py @@ -13,6 +13,8 @@ class General(MixingConfig, BaseModel): """How many characters are piped at once into the meta_cat class""" ner_aggregation_strategy: str = 'simple' """Agg strategy for HF pipeline for NER""" + chunking_overlap_window: Optional[int] = 5 + """Size of the overlap window used for chunking""" test_size: float = 0.2 last_train_on: Optional[int] = None verbose_metrics: bool = False diff --git a/medcat/ner/transformers_ner.py b/medcat/ner/transformers_ner.py index 729be4625..78b410230 100644 --- a/medcat/ner/transformers_ner.py +++ b/medcat/ner/transformers_ner.py @@ -76,9 +76,11 @@ def __init__(self, cdb, config: Optional[ConfigTransformersNER] = None, else: self.training_arguments = training_arguments - def create_eval_pipeline(self): - self.ner_pipe = pipeline(model=self.model, task="ner", tokenizer=self.tokenizer.hf_tokenizer) + + if self.config.general['chunking_overlap_window'] is None: + logger.warning("Chunking overlap window 
attribute in the config is set to None, hence chunking is disabled. Be cautious, PII data MAY BE REVEALED. To enable chunking, set the value to 0 or above.") + self.ner_pipe = pipeline(model=self.model, task="ner", tokenizer=self.tokenizer.hf_tokenizer,stride=self.config.general['chunking_overlap_window']) if not hasattr(self.ner_pipe.tokenizer, '_in_target_context_manager'): # NOTE: this will fix the DeID model(s) created before medcat 1.9.3 # though this fix may very well be unstable diff --git a/medcat/utils/ner/deid.py b/medcat/utils/ner/deid.py index 13ee5e04c..343e89ef0 100644 --- a/medcat/utils/ner/deid.py +++ b/medcat/utils/ner/deid.py @@ -34,7 +34,8 @@ - config - cdb """ -from typing import Union, Tuple, Any, List, Iterable, Optional +from typing import Union, Tuple, Any, List, Iterable, Optional, Dict +import logging from medcat.cat import CAT from medcat.utils.ner.model import NerModel @@ -42,6 +43,9 @@ from medcat.utils.ner.helpers import _deid_text as deid_text, replace_entities_in_text +logger = logging.getLogger(__name__) + + class DeIdModel(NerModel): """The DeID model. @@ -93,6 +97,25 @@ def deid_multi_texts(self, Returns: List[str]: List of deidentified documents. """ + # NOTE: we assume we're using the 1st (and generally only) + # additional NER model. + # the same assumption is made in the `train` method + chunking_overlap_window = self.cat._addl_ner[0].config.general.chunking_overlap_window + if chunking_overlap_window is not None: + logger.warning("Chunking overlap window has been set to %s. 
" + "This may cause multiprocessing to stall in certain " + "environments and/or situations and has not been " + "fully tested.", + chunking_overlap_window) + logger.warning("If the following hangs forever (i.e. doesn't finish) " + "but you still wish to run on multiple processes you can set " + "`cat._addl_ner[0].config.general.chunking_overlap_window = None` " + "and then either a) save the model on disk and load it back up, or " + "b) call `cat._addl_ner[0].create_eval_pipeline()` to recreate the pipe. " + "However, this will remove chunking from the input text, which means " + "only the first 512 tokens will be recognised and thus only the " + "first part of longer documents (those with more than 512 tokens) " + "will be deidentified. ") entities = self.cat.get_entities_multi_texts(texts, addl_info=addl_info, n_process=n_process, batch_size=batch_size) out = [] @@ -110,7 +133,7 @@ def deid_multi_texts(self, return out @classmethod - def load_model_pack(cls, model_pack_path: str) -> 'DeIdModel': + def load_model_pack(cls, model_pack_path: str, config: Optional[Dict] = None) -> 'DeIdModel': """Load DeId model from model pack. The method first loads the CAT instance. @@ -119,6 +142,7 @@ def load_model_pack(cls, model_pack_path: str) -> 'DeIdModel': valid DeId model. Args: + config: Config for DeId model pack (primarily for stride of overlap window) model_pack_path (str): The model pack path. Raises: @@ -127,7 +151,7 @@ def load_model_pack(cls, model_pack_path: str) -> 'DeIdModel': Returns: DeIdModel: The resulting DeI model. 
""" - ner_model = NerModel.load_model_pack(model_pack_path) + ner_model = NerModel.load_model_pack(model_pack_path,config=config) cat = ner_model.cat if not cls._is_deid_model(cat): raise ValueError( diff --git a/medcat/utils/ner/model.py b/medcat/utils/ner/model.py index 553fb4c65..d3ff2eb3b 100644 --- a/medcat/utils/ner/model.py +++ b/medcat/utils/ner/model.py @@ -1,4 +1,4 @@ -from typing import Any, List, Tuple, Union, Optional +from typing import Any, List, Tuple, Union, Optional, Dict from spacy.tokens import Doc @@ -94,16 +94,17 @@ def create(cls, ner: Union[TransformersNER, List[TransformersNER]]) -> 'NerModel return cls(cat) @classmethod - def load_model_pack(cls, model_pack_path: str) -> 'NerModel': + def load_model_pack(cls, model_pack_path: str,config: Optional[Dict] = None) -> 'NerModel': """Load NER model from model pack. The method first wraps the loaded CAT instance. Args: + config: Config for DeId model pack (primarily for stride of overlap window) model_pack_path (str): The model pack path. Returns: NerModel: The resulting DeI model. """ - cat = CAT.load_model_pack(model_pack_path) + cat = CAT.load_model_pack(model_pack_path,ner_config_dict=config) return cls(cat) diff --git a/tests/utils/ner/test_deid.py b/tests/utils/ner/test_deid.py index 01c9c1af3..0eed7b6da 100644 --- a/tests/utils/ner/test_deid.py +++ b/tests/utils/ner/test_deid.py @@ -41,11 +41,11 @@ def test_can_create_model(self): deid_model = deid.DeIdModel.create(ner) self.assertIsNotNone(deid_model) - def _add_model(cls): cdb = make_or_update_cdb(TRAIN_DATA) config = transformers_ner.ConfigTransformersNER() config.general['test_size'] = 0.1 # Usually set this to 0.1-0.2 + config.general['chunking_overlap_window'] = None cls.ner = transformers_ner.TransformersNER(cdb=cdb, config=config) cls.ner.training_arguments.num_train_epochs = 1 # Use 5-10 normally # As we are NOT training on a GPU that can, we'll set it to 1