feat: Jamba instruct tokenizer #84

Merged
merged 24 commits
Mar 28, 2024
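This PR adds a JambaInstructTokenizer built on the `tokenizers` package and exposes it through the existing TokenizerFactory. A minimal usage sketch based on the API in the diffs below (illustrative only, not part of the changeset):

from ai21_tokenizer import Tokenizer, PreTrainedTokenizers

# Obtain the new tokenizer through the factory; tokenizer.json is pulled from
# the Hugging Face Hub on first use and cached locally afterwards.
tokenizer = Tokenizer.get_tokenizer(PreTrainedTokenizers.JAMBA_INSTRUCT_TOKENIZER)

ids = tokenizer.encode("Hello, Jamba!")
print(ids)                    # list of token ids
print(tokenizer.decode(ids))  # round-trips back to text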
Changes from all commits
24 commits
e6fe0c8
feat: added huggingface dependency to project
asafgardin Mar 26, 2024
fe91a26
feat: Added jamba tokenizer
asafgardin Mar 26, 2024
5f41587
test: Added tests to factory
asafgardin Mar 26, 2024
247e68c
refactor: Removed unnecessary const
asafgardin Mar 26, 2024
b9c0158
refactor: Removed redundant code
asafgardin Mar 26, 2024
7dba113
docs: Added comment
asafgardin Mar 26, 2024
7e37fd0
refactor: Added import to init
asafgardin Mar 26, 2024
7e4c8a9
fix: path
asafgardin Mar 26, 2024
9993dcd
refactor: Moved config load to the init
asafgardin Mar 27, 2024
a78053c
fix: Removed python 3.7 support in tests
asafgardin Mar 27, 2024
5dff07f
refactor: Renamed file
asafgardin Mar 27, 2024
2e6fabf
test: Added skip
asafgardin Mar 27, 2024
dba73dd
test: Removed 3.12 (testing)
asafgardin Mar 27, 2024
f28164d
fix: python support change
asafgardin Mar 27, 2024
b8dfc84
fix: Added JambaInstructTokenizer to main init
asafgardin Mar 27, 2024
5a2f628
fix: Changed condition
asafgardin Mar 27, 2024
6340fc0
refactor: Moved methods
asafgardin Mar 27, 2024
6010ff4
fix: Added value error raise if tokenizer name not found
asafgardin Mar 27, 2024
6465b78
fix: uses tokenizer instead
asafgardin Mar 28, 2024
8897733
refactor: used tokenizers package instead of transformers
asafgardin Mar 28, 2024
5ccd655
fix: Added docstring and organized file
asafgardin Mar 28, 2024
7755a1a
test: Added unittests to jamba instruct tokenizer
asafgardin Mar 28, 2024
9d193c9
feat: Used hf path to tokenizer
asafgardin Mar 28, 2024
262afe8
fix: CR
asafgardin Mar 28, 2024
12 changes: 10 additions & 2 deletions ai21_tokenizer/__init__.py
@@ -1,8 +1,16 @@
from ai21_tokenizer.base_tokenizer import BaseTokenizer
+from ai21_tokenizer.jamba_instruct_tokenizer import JambaInstructTokenizer
from ai21_tokenizer.jurassic_tokenizer import JurassicTokenizer
-from ai21_tokenizer.tokenizer_factory import TokenizerFactory as Tokenizer
+from ai21_tokenizer.tokenizer_factory import TokenizerFactory as Tokenizer, PreTrainedTokenizers
from .version import VERSION

__version__ = VERSION

-__all__ = ["Tokenizer", "JurassicTokenizer", "BaseTokenizer", "__version__"]
+__all__ = [
+    "Tokenizer",
+    "JurassicTokenizer",
+    "BaseTokenizer",
+    "__version__",
+    "PreTrainedTokenizers",
+    "JambaInstructTokenizer",
+]
78 changes: 78 additions & 0 deletions ai21_tokenizer/jamba_instruct_tokenizer.py
@@ -0,0 +1,78 @@
from __future__ import annotations

import os
import tempfile
from pathlib import Path
from typing import Union, List, cast, Optional

from tokenizers import Tokenizer

from ai21_tokenizer import BaseTokenizer
from ai21_tokenizer.utils import PathLike

_TOKENIZER_FILE = "tokenizer.json"
_DEFAULT_MODEL_CACHE_DIR = Path(tempfile.gettempdir()) / "jamba_instruct"


class JambaInstructTokenizer(BaseTokenizer):
    _tokenizer: Tokenizer

    def __init__(
        self,
        model_path: str,
        cache_dir: Optional[PathLike] = None,
    ):
        """
        Args:
            model_path: str
                The identifier of a Model on the Hugging Face Hub, that contains a tokenizer.json file
            cache_dir: Optional[PathLike]
                The directory to cache the tokenizer.json file.
                If not provided, the default cache directory will be used
        """
        self._tokenizer = self._init_tokenizer(model_path=model_path, cache_dir=cache_dir or _DEFAULT_MODEL_CACHE_DIR)

    def _init_tokenizer(self, model_path: PathLike, cache_dir: PathLike) -> Tokenizer:
        if self._is_cached(cache_dir):
            return self._load_from_cache(cache_dir / _TOKENIZER_FILE)

        tokenizer = cast(
            Tokenizer,
            Tokenizer.from_pretrained(model_path),
        )
        self._cache_tokenizer(tokenizer, cache_dir)

        return tokenizer

    def _is_cached(self, cache_dir: PathLike) -> bool:
        return Path(cache_dir).exists() and _TOKENIZER_FILE in os.listdir(cache_dir)

    def _load_from_cache(self, cache_file: Path) -> Tokenizer:
        return cast(Tokenizer, Tokenizer.from_file(str(cache_file)))

    def _cache_tokenizer(self, tokenizer: Tokenizer, cache_dir: PathLike) -> None:
        # create cache directory for caching the tokenizer and save it
        Path(cache_dir).mkdir(parents=True, exist_ok=True)
        tokenizer.save(str(cache_dir / _TOKENIZER_FILE))

    def encode(self, text: str, **kwargs) -> List[int]:
        return self._tokenizer.encode(text, **kwargs).ids

    def decode(self, token_ids: List[int], **kwargs) -> str:
        return self._tokenizer.decode(token_ids, **kwargs)

    def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
        if isinstance(tokens, str):
            return self._tokenizer.token_to_id(tokens)

        return [self._tokenizer.token_to_id(token) for token in tokens]

    def convert_ids_to_tokens(self, token_ids: Union[int, List[int]], **kwargs) -> Union[str, List[str]]:
        if isinstance(token_ids, int):
            return self._tokenizer.id_to_token(token_ids)

        return [self._tokenizer.id_to_token(token_id) for token_id in token_ids]

    @property
    def vocab_size(self) -> int:
        return self._tokenizer.get_vocab_size()
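A short sketch of using the class directly (the cache directory below is an arbitrary path chosen for the example; any Hugging Face Hub repo that ships a tokenizer.json works, per the docstring above):

from pathlib import Path

from ai21_tokenizer import JambaInstructTokenizer

# "ai21labs/Jamba-v0.1" is the Hub path used by the factory in this PR;
# cache_dir is given as a Path to match the PathLike joins above.
tokenizer = JambaInstructTokenizer(
    model_path="ai21labs/Jamba-v0.1",
    cache_dir=Path("/tmp/jamba_tokenizer_cache"),
)

ids = tokenizer.encode("Tokenize without a transformers dependency.")
text = tokenizer.decode(ids)
print(len(ids), text, tokenizer.vocab_size)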
25 changes: 22 additions & 3 deletions ai21_tokenizer/jurassic_tokenizer.py
@@ -2,12 +2,16 @@

import re
from dataclasses import dataclass
+from pathlib import Path
from typing import List, Union, Optional, Dict, Any, Tuple, BinaryIO

import sentencepiece as spm

from ai21_tokenizer.base_tokenizer import BaseTokenizer
-from ai21_tokenizer.utils import load_binary, is_number, PathLike
+from ai21_tokenizer.utils import load_binary, is_number, PathLike, load_json

+_MODEL_EXTENSION = ".model"
+_MODEL_CONFIG_FILENAME = "config.json"


@dataclass
@@ -25,11 +29,11 @@ def __init__(
    ):
        self._validate_init(model_path=model_path, model_file_handle=model_file_handle)

-        model_proto = load_binary(model_path) if model_path else model_file_handle.read()
+        model_proto = load_binary(self._get_model_file(model_path)) if model_path else model_file_handle.read()

        # noinspection PyArgumentList
        self._sp = spm.SentencePieceProcessor(model_proto=model_proto)
-        config = config or {}
+        config = self._get_config(model_path=model_path, config=config)

        self.pad_id = config.get("pad_id")
        self.unk_id = config.get("unk_id")
@@ -64,6 +68,21 @@ def _validate_init(self, model_path: Optional[PathLike], model_file_handle: Opti
        if model_path is not None and model_file_handle is not None:
            raise ValueError("Must provide exactly one of model_path or model_file_handle. Got both.")

+    def _get_model_file(self, model_path: PathLike) -> PathLike:
+        model_path = Path(model_path)
+
+        if model_path.is_dir():
+            return model_path / f"{model_path.name}{_MODEL_EXTENSION}"
+
+        return model_path
+
+    def _get_config(self, model_path: Optional[PathLike], config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
+        if model_path and Path(model_path).is_dir():
+            config_path = model_path / _MODEL_CONFIG_FILENAME
+            return load_json(config_path)
+
+        return config or {}
+
    def _map_space_tokens(self) -> List[SpaceSymbol]:
        res = []
        for count in range(32, 0, -1):
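With this change JurassicTokenizer can also be pointed at a directory; a sketch under the layout implied by _get_model_file and _get_config (the directory name and files are hypothetical):

from pathlib import Path

from ai21_tokenizer import JurassicTokenizer

# Assumed layout for a local directory named "j2-tokenizer":
#   j2-tokenizer/j2-tokenizer.model  <- SentencePiece model, "<dir name>.model"
#   j2-tokenizer/config.json         <- tokenizer config (pad_id, unk_id, ...)
tokenizer = JurassicTokenizer(model_path=Path("resources/j2-tokenizer"))
print(tokenizer.encode("Hello"))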
37 changes: 11 additions & 26 deletions ai21_tokenizer/tokenizer_factory.py
@@ -1,22 +1,18 @@
+import os
from pathlib import Path
-from typing import Dict, Any

from ai21_tokenizer.base_tokenizer import BaseTokenizer
+from ai21_tokenizer.jamba_instruct_tokenizer import JambaInstructTokenizer
from ai21_tokenizer.jurassic_tokenizer import JurassicTokenizer
-from ai21_tokenizer.utils import load_json

_LOCAL_RESOURCES_PATH = Path(__file__).parent / "resources"
-_MODEL_EXTENSION = ".model"
-_MODEL_CONFIG_FILENAME = "config.json"
+_ENV_CACHE_DIR_KEY = "AI21_TOKENIZER_CACHE_DIR"
+JAMABA_TOKENIZER_HF_PATH = "ai21labs/Jamba-v0.1"


class PreTrainedTokenizers:
    J2_TOKENIZER = "j2-tokenizer"
-
-
-_PRETRAINED_MODEL_NAMES = [
-    PreTrainedTokenizers.J2_TOKENIZER,
-]
+    JAMBA_INSTRUCT_TOKENIZER = "jamba-instruct-tokenizer"


class TokenizerFactory:
@@ -25,23 +21,12 @@ class TokenizerFactory:
    Currently supports only J2-Tokenizer
    """

-    _tokenizer_name = PreTrainedTokenizers.J2_TOKENIZER
-
    @classmethod
-    def get_tokenizer(cls) -> BaseTokenizer:
-        config = cls._get_config(cls._tokenizer_name)
-        model_path = cls._model_path(cls._tokenizer_name)
-        return JurassicTokenizer(model_path=model_path, config=config)
+    def get_tokenizer(cls, tokenizer_name: str = PreTrainedTokenizers.J2_TOKENIZER) -> BaseTokenizer:
+        if tokenizer_name == PreTrainedTokenizers.JAMBA_INSTRUCT_TOKENIZER:
+            return JambaInstructTokenizer(model_path=JAMABA_TOKENIZER_HF_PATH, cache_dir=os.getenv(_ENV_CACHE_DIR_KEY))

-    @classmethod
-    def _tokenizer_dir(cls, tokenizer_name: str) -> Path:
-        return _LOCAL_RESOURCES_PATH / tokenizer_name
+        if tokenizer_name == PreTrainedTokenizers.J2_TOKENIZER:
+            return JurassicTokenizer(_LOCAL_RESOURCES_PATH / PreTrainedTokenizers.J2_TOKENIZER)

-    @classmethod
-    def _model_path(cls, tokenizer_name: str) -> Path:
-        return cls._tokenizer_dir(tokenizer_name) / f"{tokenizer_name}{_MODEL_EXTENSION}"
-
-    @classmethod
-    def _get_config(cls, tokenizer_name: str) -> Dict[str, Any]:
-        config_path = cls._tokenizer_dir(tokenizer_name) / _MODEL_CONFIG_FILENAME
-        return load_json(config_path)
+        raise ValueError(f"Tokenizer {tokenizer_name} is not supported")