
Commit

working on adding roberta
bramiozo committed Aug 15, 2023
1 parent 9711554 commit 5ea7c82
Showing 4 changed files with 165 additions and 6 deletions.
21 changes: 16 additions & 5 deletions medcat/cat.py
@@ -375,15 +375,17 @@ def load_model_pack(cls,
vocab = None

# Find meta models in the model_pack
trf_paths = [os.path.join(model_pack_path, path) for path in os.listdir(model_pack_path) if path.startswith('trf_')] if load_addl_ner else []
trf_paths = [os.path.join(model_pack_path, path)
for path in os.listdir(model_pack_path) if path.startswith('trf_')] if load_addl_ner else []
addl_ner = []
for trf_path in trf_paths:
trf = TransformersNER.load(save_dir_path=trf_path)
trf.cdb = cdb # Set the cat.cdb to be the CDB of the TRF model
addl_ner.append(trf)

# Find meta models in the model_pack
meta_paths = [os.path.join(model_pack_path, path) for path in os.listdir(model_pack_path) if path.startswith('meta_')] if load_meta_models else []
meta_paths = [os.path.join(model_pack_path, path)
for path in os.listdir(model_pack_path) if path.startswith('meta_')] if load_meta_models else []
meta_cats = []
for meta_path in meta_paths:
meta_cats.append(MetaCAT.load(save_dir_path=meta_path,
@@ -821,15 +823,24 @@ def add_and_train_concept(self,
\*\*other:
Refer to medcat.cat.cdb.CDB.add_concept
"""
names = prepare_name(name, self.pipe.spacy_nlp, {}, self.config)
names = prepare_name(name, self.pipe.spacy_nlp, {},
self.config)
# Only if not negative, otherwise do not add the new name if in fact it should not be detected
if do_add_concept and not negative:
self.cdb.add_concept(cui=cui, names=names, ontologies=ontologies, name_status=name_status, type_ids=type_ids, description=description,
self.cdb.add_concept(cui=cui, names=names,
ontologies=ontologies,
name_status=name_status,
type_ids=type_ids,
description=description,
full_build=full_build)

if spacy_entity is not None and spacy_doc is not None:
# Train Linking
self.linker.context_model.train(cui=cui, entity=spacy_entity, doc=spacy_doc, negative=negative, names=names) # type: ignore
self.linker.context_model.train(cui=cui,
entity=spacy_entity,
doc=spacy_doc,
negative=negative,
names=names) # type: ignore

if not negative and devalue_others:
# Find all cuis
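For reference, a minimal, hypothetical call of the method touched in the second hunk above. The model pack path, CUI and concept name are placeholders, not values taken from this commit.

# Hypothetical usage sketch for CAT.add_and_train_concept; the model pack
# path, CUI and concept name below are placeholders.
from medcat.cat import CAT

cat = CAT.load_model_pack('<medcat_model_pack>.zip')

# Positive example: the new name is added to the CDB for this CUI. Linker
# context training additionally requires spacy_doc and spacy_entity (see above).
cat.add_and_train_concept(cui='C0000001',
                          name='kidney failure',
                          negative=False)
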
9 changes: 9 additions & 0 deletions medcat/meta_cat.py
@@ -84,6 +84,12 @@ def get_model(self, embeddings: Optional[Tensor]) -> nn.Module:
if config.model['model_name'] == 'lstm':
from medcat.utils.meta_cat.models import LSTM
model = LSTM(embeddings, config)
elif config.model['model_name'] == 'roberta':
from medcat.utils.meta_cat.models import RoBERTaForMetaAnnotation
model = RoBERTaForMetaAnnotation(config)
elif config.model['model_name'] == 'bert':
from medcat.utils.meta_cat.models import BertForMetaAnnotation
model = BertForMetaAnnotation(config)
else:
raise ValueError("Unknown model name %s" % config.model['model_name'])

@@ -342,6 +348,9 @@ def load(cls, save_dir_path: str, config_dict: Optional[Dict] = None) -> "MetaCAT":
elif config.general['tokenizer_name'] == 'bert-tokenizer':
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBERT
tokenizer = TokenizerWrapperBERT.load(save_dir_path)
elif config.general['tokenizer_name'] == 'roberta-tokenizer':
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperRoBERTa
tokenizer = TokenizerWrapperRoBERTa.load(save_dir_path)

# Create meta_cat
meta_cat = cls(tokenizer=tokenizer, embeddings=None, config=config)
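Taken together, these two hunks let a MetaCAT instance be pointed at the new RoBERTa components through its config. A minimal sketch, assuming ConfigMetaCAT lives in medcat.config_meta_cat and that its sub-sections accept dict-style assignment the same way they are read above:

# Hypothetical configuration sketch; only the two keys checked by the new
# branches above are set, everything else stays at its defaults.
from medcat.config_meta_cat import ConfigMetaCAT

config = ConfigMetaCAT()
config.model['model_name'] = 'roberta'                  # picked up in MetaCAT.get_model
config.general['tokenizer_name'] = 'roberta-tokenizer'  # picked up in MetaCAT.load
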
69 changes: 68 additions & 1 deletion medcat/tokenizers/meta_cat_tokenizers.py
@@ -3,7 +3,7 @@
from typing import List, Dict, Optional, Union, overload
from tokenizers import Tokenizer, ByteLevelBPETokenizer
from transformers.models.bert.tokenization_bert_fast import BertTokenizerFast

from transformers.models.roberta.tokenization_roberta_fast import RobertaTokenizerFast

class TokenizerWrapperBase(ABC):

@@ -201,3 +201,70 @@ def token_to_id(self, token: str) -> Union[int, List[int]]:
def get_pad_id(self) -> Optional[int]:
self.hf_tokenizers = self.ensure_tokenizer()
return self.hf_tokenizers.pad_token_id

class TokenizerWrapperRoBERTa(TokenizerWrapperBase):
"""Wrapper around a huggingface BERT tokenizer so that it works with the
MetaCAT models.
Args:
transformers.models.bert.tokenization_bert_fast.BertTokenizerFast:
A huggingface Fast BERT.
"""
name = 'roberta-tokenizer'

def __init__(self, hf_tokenizers: Optional[RobertaTokenizerFast] = None) -> None:
super().__init__(hf_tokenizers)

@overload
def __call__(self, text: str) -> Dict: ...

@overload
def __call__(self, text: List[str]) -> List[Dict]: ...

def __call__(self, text: Union[str, List[str]]) -> Union[Dict, List[Dict]]:
self.hf_tokenizers = self.ensure_tokenizer()
if isinstance(text, str):
result = self.hf_tokenizers.encode_plus(text, return_offsets_mapping=True,
add_special_tokens=False)

return {'offset_mapping': result['offset_mapping'],
'input_ids': result['input_ids'],
'tokens': self.hf_tokenizers.convert_ids_to_tokens(result['input_ids']),
}
elif isinstance(text, list):
results = self.hf_tokenizers._batch_encode_plus(text, return_offsets_mapping=True,
add_special_tokens=False)
output = []
for ind in range(len(results['input_ids'])):
output.append({'offset_mapping': results['offset_mapping'][ind],
'input_ids': results['input_ids'][ind],
'tokens': self.hf_tokenizers.convert_ids_to_tokens(results['input_ids'][ind]),
})
return output
else:
raise Exception("Unsuported input type, supported: text/list, but got: {}".format(type(text)))

def save(self, dir_path: str) -> None:
self.hf_tokenizers = self.ensure_tokenizer()
path = os.path.join(dir_path, self.name)
self.hf_tokenizers.save_pretrained(path)

@classmethod
def load(cls, dir_path: str, **kwargs) -> "TokenizerWrapperRoBERTa":
tokenizer = cls()
path = os.path.join(dir_path, cls.name)
tokenizer.hf_tokenizers = RobertaTokenizerFast.from_pretrained(path, **kwargs)

return tokenizer

def get_size(self) -> int:
self.hf_tokenizers = self.ensure_tokenizer()
return len(self.hf_tokenizers.vocab)

def token_to_id(self, token: str) -> Union[int, List[int]]:
self.hf_tokenizers = self.ensure_tokenizer()
return self.hf_tokenizers.convert_tokens_to_ids(token)

def get_pad_id(self) -> Optional[int]:
self.hf_tokenizers = self.ensure_tokenizer()
return self.hf_tokenizers.pad_token_id
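A short, hypothetical usage sketch for the new wrapper; 'roberta-base' is only an example checkpoint and the directory name is a placeholder:

# Hypothetical usage of TokenizerWrapperRoBERTa.
from transformers import RobertaTokenizerFast
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperRoBERTa

wrapper = TokenizerWrapperRoBERTa(RobertaTokenizerFast.from_pretrained('roberta-base'))

out = wrapper("Patient denies chest pain.")
print(out['tokens'])          # sub-word tokens
print(out['input_ids'])       # matching token ids
print(out['offset_mapping'])  # (start, end) character offsets per token

wrapper.save('./tok_dir')                          # writes to ./tok_dir/roberta-tokenizer
wrapper = TokenizerWrapperRoBERTa.load('./tok_dir')
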
72 changes: 72 additions & 0 deletions medcat/utils/meta_cat/models.py
@@ -4,6 +4,7 @@
from torch import nn, Tensor
from torch.nn import CrossEntropyLoss
from transformers import BertPreTrainedModel, BertModel, BertConfig
from transformers import RobertaPreTrainedModel, RobertaModel, RobertaConfig
from transformers.modeling_outputs import TokenClassifierOutput
from medcat.meta_cat import ConfigMetaCAT

@@ -147,3 +148,74 @@ def forward(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

class RoBERTaForMetaAnnotation(RobertaPreTrainedModel):

_keys_to_ignore_on_load_unexpected: List[str] = [r"pooler"] # type: ignore

def __init__(self, config: RobertaConfig) -> None:
super().__init__(config)
self.num_labels = config.num_labels

self.roberta = RobertaModel(config, add_pooling_layer=False)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)

self.init_weights() # type: ignore

def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
center_positions: Optional[Any] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> TokenClassifierOutput:
"""labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1]``.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict # type: ignore

outputs = self.roberta( # type: ignore
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

sequence_output = outputs[0] # (batch_size, sequence_length, hidden_size)

row_indices = torch.arange(0, sequence_output.size(0)).long()
sequence_output = sequence_output[row_indices, center_positions, :]

sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)

loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
# Only keep active parts of the loss
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output

return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
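
A minimal, hypothetical smoke test for the new classification head; the config values and the random batch are illustrative only, not taken from this commit:

# Hypothetical smoke test for RoBERTaForMetaAnnotation; config values and the
# random batch are illustrative.
import torch
from transformers import RobertaConfig
from medcat.utils.meta_cat.models import RoBERTaForMetaAnnotation

config = RobertaConfig(num_labels=2)                      # e.g. a binary meta-annotation task
model = RoBERTaForMetaAnnotation(config)

input_ids = torch.randint(0, config.vocab_size, (4, 32))  # batch of 4 sequences, 32 tokens each
attention_mask = torch.ones_like(input_ids)
center_positions = torch.tensor([5, 7, 3, 12])            # index of the concept token per sequence
labels = torch.tensor([0, 1, 1, 0])

out = model(input_ids=input_ids,
            attention_mask=attention_mask,
            center_positions=center_positions,
            labels=labels,
            return_dict=True)
print(out.loss, out.logits.shape)                         # logits: (batch_size, num_labels)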
