Skip to content

Commit

Permalink
fix: Added value error raise if tokenizer name not found
Browse files Browse the repository at this point in the history
  • Loading branch information
asafgardin committed Mar 27, 2024
1 parent 6340fc0 commit 6010ff4
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 15 deletions.
21 changes: 6 additions & 15 deletions ai21_tokenizer/tokenizer_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,6 @@ class PreTrainedTokenizers:
JAMBA_TOKENIZER = "jamba-tokenizer"


_TOKENIZER_MODEL_MAP = {
PreTrainedTokenizers.JAMBA_TOKENIZER: "huggingface-tokenizer-url-placeholder",
PreTrainedTokenizers.J2_TOKENIZER: _LOCAL_RESOURCES_PATH / PreTrainedTokenizers.J2_TOKENIZER,
}


class TokenizerFactory:
"""
Factory class to create AI21 tokenizer
Expand All @@ -29,12 +23,13 @@ class TokenizerFactory:

@classmethod
def get_tokenizer(cls, tokenizer_name: str = PreTrainedTokenizers.J2_TOKENIZER) -> BaseTokenizer:
model_path = cls._model_path(tokenizer_name)

if tokenizer_name == PreTrainedTokenizers.JAMBA_TOKENIZER:
return cls._create_jamaba_tokenizer(model_path)
return cls._create_jamaba_tokenizer("<huggingface-tokenizer-url-placeholder>")

return cls._create_jurassic_tokenizer(model_path)
if tokenizer_name == PreTrainedTokenizers.J2_TOKENIZER:
return cls._create_jurassic_tokenizer(_LOCAL_RESOURCES_PATH / PreTrainedTokenizers.J2_TOKENIZER)

raise ValueError(f"Tokenizer {tokenizer_name} is not supported")

@classmethod
def _create_jamaba_tokenizer(cls, model_path: str) -> JambaInstructTokenizer:
Expand All @@ -43,9 +38,5 @@ def _create_jamaba_tokenizer(cls, model_path: str) -> JambaInstructTokenizer:
return JambaInstructTokenizer(model_path=model_path, cache_dir=_MODEL_CACHE_DIR)

@classmethod
def _create_jurassic_tokenizer(cls, model_path: str) -> JurassicTokenizer:
def _create_jurassic_tokenizer(cls, model_path: PathLike) -> JurassicTokenizer:
return JurassicTokenizer(model_path=model_path)

@classmethod
def _model_path(cls, tokenizer_name: str) -> PathLike:
return _TOKENIZER_MODEL_MAP[tokenizer_name]
5 changes: 5 additions & 0 deletions tests/test_tokenizer_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,8 @@ def test_tokenizer_factory__get_tokenizer(

assert tokenizer is not None
assert isinstance(tokenizer, expected_tokenizer_instance)


def test_tokenizer__when_tokenizer_name_is_not_supported__should_raise_value_error() -> None:
with pytest.raises(ValueError):
Tokenizer.get_tokenizer(tokenizer_name="unsupported")

0 comments on commit 6010ff4

Please sign in to comment.