Merge branch 'main' into angelayi/aoti_metadata
Jack-Khuu authored Nov 5, 2024
2 parents b2b93c5 + 9480258 commit 7146029
Showing 5 changed files with 169 additions and 7 deletions.
32 changes: 32 additions & 0 deletions tokenizer/base.py
@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Abstract base class for all tokenizer classes in Python, matching the C++ interface.
"""

# Standard
from abc import ABC, abstractmethod
from typing import List


class TokenizerBase(ABC):
__doc__ = __doc__

@abstractmethod
def encode(self, s: str, *, bos: bool = False, eos: bool = False) -> List[int]:
"""Encode the given string and optionally include bos/eos tokens"""

@abstractmethod
def decode(self, ids: List[int]) -> str:
"""Decode the given token ids into a string"""

@abstractmethod
def bos_id(self) -> int:
"""The id of the begin-of-string token"""

@abstractmethod
def eos_id(self) -> int:
"""The id of the end-of-string token"""
92 changes: 92 additions & 0 deletions tokenizer/hf_tokenizer.py
@@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Standard
from typing import List, Optional
import json
import os

# Third Party
from tokenizers import Tokenizer

# Local
from .base import TokenizerBase


class HFTokenizer(TokenizerBase):
"""
Wrapper around the Huggingface `tokenizers` library for API compatibility
"""

def __init__(self, file_path: str):
# If the path is a directory, look for "tokenizer.json" which is
# standard for transformers checkpoints and also look for the
# "tokenizer_config.json" file to parse eos/bos tokens
if os.path.isdir(file_path):
tokenizer_path = os.path.join(file_path, "tokenizer.json")
tokenizer_config_path = os.path.join(file_path, "tokenizer_config.json")
else:
tokenizer_path = file_path
tokenizer_config_path = os.path.join(os.path.dirname(file_path), "tokenizer_config.json")
            # Only keep the config path if the file actually exists
            if not os.path.isfile(tokenizer_config_path):
                tokenizer_config_path = None

# Load the tokenizer itself
self._tokenizer = Tokenizer.from_file(tokenizer_path)

# If available, parse bos/eos tokens from the tokenizer config
self._bos_id, self._eos_id = None, None
if tokenizer_config_path is not None:
with open(tokenizer_config_path, "r") as handle:
tok_config = json.load(handle)
bos_token = tok_config.get("bos_token")
eos_token = tok_config.get("eos_token")
if bos_token is not None:
self._bos_id = self._tokenizer.token_to_id(bos_token)
if eos_token is not None:
self._eos_id = self._tokenizer.token_to_id(eos_token)

# If no eos/bos tokens found, go looking for them!
if None in [self._bos_id, self._eos_id]:
tok_content = json.loads(self._tokenizer.to_str())
            if self._bos_id is None:
                self._bos_id = self._look_for_special_token(
                    tok_content["added_tokens"], ["begin", "text"]
                )
            if self._eos_id is None:
                self._eos_id = self._look_for_special_token(
                    tok_content["added_tokens"], ["end", "text"]
                )

        assert None not in [self._bos_id, self._eos_id], "Unable to find BOS/EOS tokens"

    @staticmethod
    def _look_for_special_token(added_tokens: List[dict], search_strs: List[str]) -> Optional[int]:
        """Search the added_tokens list for a unique special token whose
        content contains all of the given search strings."""
        candidate_toks = added_tokens
        for search_str in search_strs:
            candidate_toks = [
                tok for tok in candidate_toks
                if tok["special"] and search_str in tok["content"]
            ]
        # Only return an id if the search narrowed to exactly one candidate
        if len(candidate_toks) == 1:
            return candidate_toks[0]["id"]

def encode(
self,
s: str,
*,
bos: bool = False,
eos: bool = False,
) -> List[int]:
res = self._tokenizer.encode(s, add_special_tokens=bos).ids
        if eos and (not res or res[-1] != self._eos_id):
            res.append(self._eos_id)
return res

def decode(self, ids: List[int]) -> str:
return self._tokenizer.decode(ids)

def bos_id(self) -> int:
return self._bos_id

def eos_id(self) -> int:
return self._eos_id
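
Usage mirrors the other backends; a quick sketch (the checkpoint path is hypothetical):

# "my-checkpoint/" is a placeholder for a directory containing
# tokenizer.json and, optionally, tokenizer_config.json.
from tokenizer.hf_tokenizer import HFTokenizer

tok = HFTokenizer("my-checkpoint/")
ids = tok.encode("hello world", bos=True, eos=True)
text = tok.decode(ids)
print(ids, text, tok.bos_id(), tok.eos_id())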
4 changes: 3 additions & 1 deletion tokenizer/tiktoken.py
@@ -23,6 +23,8 @@
import tiktoken
from tiktoken.load import load_tiktoken_bpe

from .base import TokenizerBase


logger = getLogger(__name__)

@@ -38,7 +40,7 @@ class Message(TypedDict):
Dialog = Sequence[Message]


class Tokenizer:
class Tokenizer(TokenizerBase):
"""
tokenizing and encoding/decoding text using the Tiktoken tokenizer.
"""
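With Tokenizer now implementing TokenizerBase, downstream code can type against the shared interface instead of a concrete backend; a small sketch:

# Sketch: accepts the tiktoken-backed Tokenizer, HFTokenizer, or any
# other TokenizerBase implementation.
from typing import List

from tokenizer.base import TokenizerBase


def prompt_to_ids(tokenizer: TokenizerBase, prompt: str) -> List[int]:
    return tokenizer.encode(prompt, bos=True)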
40 changes: 35 additions & 5 deletions torchchat/cli/builder.py
@@ -215,6 +215,7 @@ class TokenizerArgs:
tokenizer_path: Optional[Union[Path, str]] = None
is_sentencepiece: bool = False
is_tiktoken: bool = False
is_hf_tokenizer: bool = False
t: Optional[Any] = None

def __post_init__(self):
@@ -224,6 +225,7 @@ def __post_init__(self):
self.t = TiktokenTokenizer(model_path=str(self.tokenizer_path))
self.is_tiktoken = True
self.is_sentencepiece = False
self.is_hf_tokenizer = False
return
except:
pass
@@ -234,12 +236,25 @@ def __post_init__(self):
self.t = SentencePieceProcessor(model_file=str(self.tokenizer_path))
self.is_tiktoken = False
self.is_sentencepiece = True
self.is_hf_tokenizer = False
return
except:
pass

try:
from tokenizer.hf_tokenizer import HFTokenizer

self.t = HFTokenizer(str(self.tokenizer_path))
self.is_tiktoken = False
self.is_sentencepiece = False
self.is_hf_tokenizer = True
return
except:
pass

self.is_tiktoken = False
self.is_sentencepiece = False
self.is_hf_tokenizer = False
self.t = None
return
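
The detection order above is tiktoken → SentencePiece → HF tokenizers, and on success exactly one flag is left set. A sketch of inspecting the result (the path is hypothetical):

# TokenizerArgs is a dataclass, so constructing it runs the
# __post_init__ detection chain shown above.
args = TokenizerArgs(tokenizer_path="checkpoints/model/tokenizer.json")
print(args.is_tiktoken, args.is_sentencepiece, args.is_hf_tokenizer)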

@@ -251,16 +266,27 @@ def validate_model(
if model is None:
return

if self.is_tiktoken == self.is_sentencepiece:
if sum([self.is_tiktoken, self.is_hf_tokenizer, self.is_sentencepiece]) != 1:
raise RuntimeError(f"no tokenizer was found at {self.tokenizer_path}")

is_tiktoken = self.is_tiktoken
is_sentencepiece = self.is_sentencepiece
is_hf_tokenizer = self.is_hf_tokenizer
use_tiktoken = model.config.use_tiktoken
use_hf_tokenizer = model.config.use_hf_tokenizer
use_sentencepiece = not (use_tiktoken or use_hf_tokenizer)

if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken):
if (
(is_tiktoken and not use_tiktoken) or
(is_hf_tokenizer and not use_hf_tokenizer) or
(is_sentencepiece and not use_sentencepiece)
):
raise RuntimeError(
f"model-specified tokenizer ({tokenizer_setting_to_name(use_tiktoken)}) does not match provided tokenizer ({tokenizer_setting_to_name(is_tiktoken)}) for {model_description}"
"model-specified tokenizer ({}) does not match provided tokenizer ({}) for {}".format(
tokenizer_setting_to_name(use_tiktoken, use_hf_tokenizer),
tokenizer_setting_to_name(is_tiktoken, is_hf_tokenizer),
model_description,
)
)

return
@@ -655,5 +681,9 @@ def _initialize_model(
return model


def tokenizer_setting_to_name(tiktoken: bool = False) -> str:
return "TikToken" if tiktoken else "SentencePiece"
def tokenizer_setting_to_name(tiktoken: bool, tokenizers: bool) -> str:
if tiktoken:
return "TikToken"
if tokenizers:
return "Tokenizers"
return "SentencePiece"
8 changes: 7 additions & 1 deletion torchchat/model.py
@@ -270,7 +270,9 @@ class TransformerArgs:
norm_eps: float = 1e-5
multiple_of: int = 256
ffn_dim_multiplier: Optional[int] = None
# Select the desired tokenizer. Defaults to sentencepiece
use_tiktoken: bool = False
use_hf_tokenizer: bool = False
max_seq_length: int = 8192
rope_scaling: Optional[Dict[str, Any]] = None
# For pipeline parallel
@@ -327,12 +329,14 @@ class ModelArgs:
model_type: ModelType
transformer_args: Dict[str, Dict[str, Any]]
use_tiktoken: bool
use_hf_tokenizer: bool

def __init__(
self,
transformer_args: Dict[str, Dict[str, Any]],
model_type: ModelType = ModelType.TextOnly,
use_tiktoken: bool = False,
use_hf_tokenizer: bool = False,
) -> None:
self._sanity_check(transformer_args, model_type)

@@ -341,6 +345,7 @@ def __init__(

# Model-level attributes
self.use_tiktoken = use_tiktoken
self.use_hf_tokenizer = use_hf_tokenizer

def _sanity_check(
self,
Expand All @@ -367,7 +372,8 @@ def from_params(cls, params_path):
}

use_tiktoken = loaded_params.get("use_tiktoken", False)
return cls(transformer_args, model_type, use_tiktoken)
use_hf_tokenizer = loaded_params.get("use_hf_tokenizer", False)
return cls(transformer_args, model_type, use_tiktoken, use_hf_tokenizer)

@classmethod
def from_table(cls, name: str):
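Checkpoints opt in through their params file; from_params reads both flags with False defaults, so existing configs are unaffected. A sketch of the relevant keys (architecture fields elided):

# Keys consulted by ModelArgs.from_params for tokenizer selection; a
# real params.json also carries the transformer/architecture fields.
params_snippet = {
    "use_tiktoken": False,     # pre-existing flag, defaults to False
    "use_hf_tokenizer": True,  # new flag introduced by this change
}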
