diff --git a/medcat/cat.py b/medcat/cat.py
index b2d3f7cb3..a439d9301 100644
--- a/medcat/cat.py
+++ b/medcat/cat.py
@@ -375,7 +375,8 @@ def load_model_pack(cls,
             vocab = None
 
         # Find meta models in the model_pack
-        trf_paths = [os.path.join(model_pack_path, path) for path in os.listdir(model_pack_path) if path.startswith('trf_')] if load_addl_ner else []
+        trf_paths = [os.path.join(model_pack_path, path)
+                     for path in os.listdir(model_pack_path) if path.startswith('trf_')] if load_addl_ner else []
         addl_ner = []
         for trf_path in trf_paths:
             trf = TransformersNER.load(save_dir_path=trf_path)
@@ -383,7 +384,8 @@ def load_model_pack(cls,
             addl_ner.append(trf)
 
         # Find meta models in the model_pack
-        meta_paths = [os.path.join(model_pack_path, path) for path in os.listdir(model_pack_path) if path.startswith('meta_')] if load_meta_models else []
+        meta_paths = [os.path.join(model_pack_path, path)
+                      for path in os.listdir(model_pack_path) if path.startswith('meta_')] if load_meta_models else []
         meta_cats = []
         for meta_path in meta_paths:
             meta_cats.append(MetaCAT.load(save_dir_path=meta_path,
@@ -821,15 +823,24 @@ def add_and_train_concept(self,
             \*\*other:
                 Refer to medcat.cat.cdb.CDB.add_concept
         """
-        names = prepare_name(name, self.pipe.spacy_nlp, {}, self.config)
+        names = prepare_name(name, self.pipe.spacy_nlp, {},
+                             self.config)
         # Only if not negative, otherwise do not add the new name if in fact it should not be detected
         if do_add_concept and not negative:
-            self.cdb.add_concept(cui=cui, names=names, ontologies=ontologies, name_status=name_status, type_ids=type_ids, description=description,
+            self.cdb.add_concept(cui=cui, names=names,
+                                 ontologies=ontologies,
+                                 name_status=name_status,
+                                 type_ids=type_ids,
+                                 description=description,
                                  full_build=full_build)
 
         if spacy_entity is not None and spacy_doc is not None:
             # Train Linking
-            self.linker.context_model.train(cui=cui, entity=spacy_entity, doc=spacy_doc, negative=negative, names=names)  # type: ignore
+            self.linker.context_model.train(cui=cui,
+                                            entity=spacy_entity,
+                                            doc=spacy_doc,
+                                            negative=negative,
+                                            names=names)  # type: ignore
 
         if not negative and devalue_others:
             # Find all cuis
diff --git a/medcat/meta_cat.py b/medcat/meta_cat.py
index d92e6ea61..7f84faed5 100644
--- a/medcat/meta_cat.py
+++ b/medcat/meta_cat.py
@@ -84,6 +84,12 @@ def get_model(self, embeddings: Optional[Tensor]) -> nn.Module:
         if config.model['model_name'] == 'lstm':
             from medcat.utils.meta_cat.models import LSTM
             model = LSTM(embeddings, config)
+        elif config.model['model_name'] == 'roberta':
+            from medcat.utils.meta_cat.models import RoBERTaForMetaAnnotation
+            model = RoBERTaForMetaAnnotation(config)
+        elif config.model['model_name'] == 'bert':
+            from medcat.utils.meta_cat.models import BertForMetaAnnotation
+            model = BertForMetaAnnotation(config)
         else:
             raise ValueError("Unknown model name %s" % config.model['model_name'])
 
@@ -342,6 +348,9 @@ def load(cls, save_dir_path: str, config_dict: Optional[Dict] = None) -> "MetaCA
         elif config.general['tokenizer_name'] == 'bert-tokenizer':
             from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBERT
             tokenizer = TokenizerWrapperBERT.load(save_dir_path)
+        elif config.general['tokenizer_name'] == 'roberta-tokenizer':
+            from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperRoBERTa
+            tokenizer = TokenizerWrapperRoBERTa.load(save_dir_path)
 
         # Create meta_cat
         meta_cat = cls(tokenizer=tokenizer, embeddings=None, config=config)
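Note (not part of the patch): both new branches above are driven purely by fields on the MetaCAT config, so selecting the RoBERTa path is a configuration change. Below is a minimal, hedged sketch of the intended wiring in Python, assuming the default ConfigMetaCAT layout and a locally available `roberta-base` checkpoint; the field names follow the diff, but this is illustrative only, not a tested end-to-end example.

# Hypothetical usage sketch; names taken from the diff above.
from medcat.meta_cat import MetaCAT, ConfigMetaCAT
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperRoBERTa
from transformers import RobertaTokenizerFast

config = ConfigMetaCAT()
config.general['tokenizer_name'] = 'roberta-tokenizer'  # routes MetaCAT.load() to TokenizerWrapperRoBERTa
config.model['model_name'] = 'roberta'                  # routes get_model() to RoBERTaForMetaAnnotation

# Wrap a HuggingFace fast tokenizer and hand it to MetaCAT, mirroring the
# cls(tokenizer=..., embeddings=None, config=config) call shown in the diff.
tokenizer = TokenizerWrapperRoBERTa(RobertaTokenizerFast.from_pretrained('roberta-base'))
meta_cat = MetaCAT(tokenizer=tokenizer, embeddings=None, config=config)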
diff --git a/medcat/tokenizers/meta_cat_tokenizers.py b/medcat/tokenizers/meta_cat_tokenizers.py
index fd6b26b30..d2b6f9074 100644
--- a/medcat/tokenizers/meta_cat_tokenizers.py
+++ b/medcat/tokenizers/meta_cat_tokenizers.py
@@ -3,7 +3,7 @@
 from typing import List, Dict, Optional, Union, overload
 from tokenizers import Tokenizer, ByteLevelBPETokenizer
 from transformers.models.bert.tokenization_bert_fast import BertTokenizerFast
-
+from transformers.models.roberta.tokenization_roberta_fast import RobertaTokenizerFast
 
 
 class TokenizerWrapperBase(ABC):
@@ -201,3 +201,70 @@ def token_to_id(self, token: str) -> Union[int, List[int]]:
     def get_pad_id(self) -> Optional[int]:
         self.hf_tokenizers = self.ensure_tokenizer()
         return self.hf_tokenizers.pad_token_id
+
+class TokenizerWrapperRoBERTa(TokenizerWrapperBase):
+    """Wrapper around a huggingface RoBERTa tokenizer so that it works with the
+    MetaCAT models.
+
+    Args:
+        transformers.models.roberta.tokenization_roberta_fast.RobertaTokenizerFast:
+            A huggingface Fast RoBERTa tokenizer.
+    """
+    name = 'roberta-tokenizer'
+
+    def __init__(self, hf_tokenizers: Optional[RobertaTokenizerFast] = None) -> None:
+        super().__init__(hf_tokenizers)
+
+    @overload
+    def __call__(self, text: str) -> Dict: ...
+
+    @overload
+    def __call__(self, text: List[str]) -> List[Dict]: ...
+
+    def __call__(self, text: Union[str, List[str]]) -> Union[Dict, List[Dict]]:
+        self.hf_tokenizers = self.ensure_tokenizer()
+        if isinstance(text, str):
+            result = self.hf_tokenizers.encode_plus(text, return_offsets_mapping=True,
+                                                    add_special_tokens=False)
+
+            return {'offset_mapping': result['offset_mapping'],
+                    'input_ids': result['input_ids'],
+                    'tokens': self.hf_tokenizers.convert_ids_to_tokens(result['input_ids']),
+                    }
+        elif isinstance(text, list):
+            results = self.hf_tokenizers._batch_encode_plus(text, return_offsets_mapping=True,
+                                                            add_special_tokens=False)
+            output = []
+            for ind in range(len(results['input_ids'])):
+                output.append({'offset_mapping': results['offset_mapping'][ind],
+                               'input_ids': results['input_ids'][ind],
+                               'tokens': self.hf_tokenizers.convert_ids_to_tokens(results['input_ids'][ind]),
+                               })
+            return output
+        else:
+            raise Exception("Unsupported input type, supported: text/list, but got: {}".format(type(text)))
+
+    def save(self, dir_path: str) -> None:
+        self.hf_tokenizers = self.ensure_tokenizer()
+        path = os.path.join(dir_path, self.name)
+        self.hf_tokenizers.save_pretrained(path)
+
+    @classmethod
+    def load(cls, dir_path: str, **kwargs) -> "TokenizerWrapperRoBERTa":
+        tokenizer = cls()
+        path = os.path.join(dir_path, cls.name)
+        tokenizer.hf_tokenizers = RobertaTokenizerFast.from_pretrained(path, **kwargs)
+
+        return tokenizer
+
+    def get_size(self) -> int:
+        self.hf_tokenizers = self.ensure_tokenizer()
+        return len(self.hf_tokenizers.vocab)
+
+    def token_to_id(self, token: str) -> Union[int, List[int]]:
+        self.hf_tokenizers = self.ensure_tokenizer()
+        return self.hf_tokenizers.convert_tokens_to_ids(token)
+
+    def get_pad_id(self) -> Optional[int]:
+        self.hf_tokenizers = self.ensure_tokenizer()
+        return self.hf_tokenizers.pad_token_id
\ No newline at end of file
diff --git a/medcat/utils/meta_cat/models.py b/medcat/utils/meta_cat/models.py
index 7b1398e20..b08757b73 100644
--- a/medcat/utils/meta_cat/models.py
+++ b/medcat/utils/meta_cat/models.py
@@ -4,6 +4,7 @@
 from torch import nn, Tensor
 from torch.nn import CrossEntropyLoss
 from transformers import BertPreTrainedModel, BertModel, BertConfig
+from transformers import RobertaPreTrainedModel, RobertaModel, RobertaConfig
 from transformers.modeling_outputs import TokenClassifierOutput
 from medcat.meta_cat import ConfigMetaCAT
 
@@ -147,3 +148,74 @@ def forward(
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
+
+class RoBERTaForMetaAnnotation(RobertaPreTrainedModel):
+
+    _keys_to_ignore_on_load_unexpected: List[str] = [r"pooler"]  # type: ignore
+
+    def __init__(self, config: RobertaConfig) -> None:
+        super().__init__(config)
+        self.num_labels = config.num_labels
+
+        self.bert = RobertaModel(config, add_pooling_layer=False)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+        self.init_weights()  # type: ignore
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        center_positions: Optional[Any] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> TokenClassifierOutput:
+        """labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+            Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
+            1]``.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict  # type: ignore
+
+        outputs = self.bert(  # type: ignore
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]  # (batch_size, sequence_length, hidden_size)
+
+        row_indices = torch.arange(0, sequence_output.size(0)).long()
+        sequence_output = sequence_output[row_indices, center_positions, :]
+
+        sequence_output = self.dropout(sequence_output)
+        logits = self.classifier(sequence_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            # Only keep active parts of the loss
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
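Note (not part of the patch): the only difference between this head and a standard HuggingFace token-classification head is the centre-token gather before the classifier. Rather than classifying every token, forward() keeps one hidden state per example, selected by `center_positions`, and the loss is therefore a plain CrossEntropyLoss over per-example logits. A small, self-contained Python sketch of that indexing with made-up shapes, assuming `center_positions` is a 1-D tensor holding one token index per batch element:

# Illustrative sketch of the centre-token gather used in forward() above.
import torch

batch_size, seq_len, hidden_size = 2, 6, 4
sequence_output = torch.randn(batch_size, seq_len, hidden_size)  # stand-in for the encoder output
center_positions = torch.tensor([1, 4])  # index of the annotated (centre) token in each sequence

row_indices = torch.arange(0, sequence_output.size(0)).long()
center_states = sequence_output[row_indices, center_positions, :]

# One vector per example is passed through dropout and the linear classifier.
assert center_states.shape == (batch_size, hidden_size)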