refactor: Moved config load to the init
asafgardin committed Mar 27, 2024
1 parent 7e4c8a9 commit 9993dcd
Showing 4 changed files with 51 additions and 20 deletions.
25 changes: 22 additions & 3 deletions ai21_tokenizer/jurassic_tokenizer.py
@@ -2,12 +2,16 @@
 
 import re
 from dataclasses import dataclass
+from pathlib import Path
 from typing import List, Union, Optional, Dict, Any, Tuple, BinaryIO
 
 import sentencepiece as spm
 
 from ai21_tokenizer.base_tokenizer import BaseTokenizer
-from ai21_tokenizer.utils import load_binary, is_number, PathLike
+from ai21_tokenizer.utils import load_binary, is_number, PathLike, load_json
+
+_MODEL_EXTENSION = ".model"
+_MODEL_CONFIG_FILENAME = "config.json"
 
 
 @dataclass
@@ -25,11 +29,11 @@ def __init__(
     ):
         self._validate_init(model_path=model_path, model_file_handle=model_file_handle)
 
-        model_proto = load_binary(model_path) if model_path else model_file_handle.read()
+        model_proto = load_binary(self._get_model_file(model_path)) if model_path else model_file_handle.read()
 
         # noinspection PyArgumentList
         self._sp = spm.SentencePieceProcessor(model_proto=model_proto)
-        config = config or {}
+        config = self._get_config(model_path=model_path, config=config)
 
         self.pad_id = config.get("pad_id")
         self.unk_id = config.get("unk_id")
@@ -57,6 +61,21 @@ def __init__(
         self._space_mode = config.get("space_mode")
         self._space_tokens = self._map_space_tokens()
 
+    def _get_model_file(self, model_path: PathLike) -> PathLike:
+        model_path = Path(model_path)
+
+        if model_path.is_dir():
+            return model_path / f"{model_path.name}{_MODEL_EXTENSION}"
+
+        return model_path
+
+    def _get_config(self, model_path: Optional[PathLike], config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
+        if model_path is not None and Path(model_path).is_dir():
+            config_path = Path(model_path) / _MODEL_CONFIG_FILENAME
+            return load_json(config_path)
+
+        return config or {}
+
     def _validate_init(self, model_path: Optional[PathLike], model_file_handle: Optional[BinaryIO]) -> None:
         if model_path is None and model_file_handle is None:
             raise ValueError("Must provide exactly one of model_path or model_file_handle. Got none.")
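In practice this gives JurassicTokenizer two call styles; a minimal sketch, assuming the bundled j2-tokenizer resources exist on disk (the directory path below is illustrative, not taken from the diff):

    from pathlib import Path

    from ai21_tokenizer.jurassic_tokenizer import JurassicTokenizer

    # Hypothetical location of the packaged resources
    resources = Path("ai21_tokenizer/resources/j2-tokenizer")

    # Directory form: _get_model_file resolves "<dir>/<dir name>.model" and
    # _get_config loads "<dir>/config.json" alongside it.
    tokenizer = JurassicTokenizer(model_path=resources)

    # File form (backwards compatible): point at the .model file directly;
    # _get_config then falls back to the passed-in config (or an empty dict).
    tokenizer = JurassicTokenizer(model_path=resources / "j2-tokenizer.model")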
32 changes: 17 additions & 15 deletions ai21_tokenizer/tokenizer_factory.py
@@ -1,15 +1,12 @@
 import os
 from pathlib import Path
-from typing import Dict, Any
 
 from ai21_tokenizer.base_tokenizer import BaseTokenizer
 from ai21_tokenizer.jamaba_tokenizer import JambaTokenizer
 from ai21_tokenizer.jurassic_tokenizer import JurassicTokenizer
-from ai21_tokenizer.utils import load_json
+from ai21_tokenizer.utils import PathLike
 
 _LOCAL_RESOURCES_PATH = Path(__file__).parent / "resources"
-_MODEL_EXTENSION = ".model"
-_MODEL_CONFIG_FILENAME = "config.json"
 _MODEL_CACHE_DIR = _LOCAL_RESOURCES_PATH / "cache"
 
 
@@ -18,6 +15,12 @@ class PreTrainedTokenizers:
     JAMBA_TOKENIZER = "jamba-tokenizer"
 
 
+_TOKENIZER_MODEL_MAP = {
+    PreTrainedTokenizers.JAMBA_TOKENIZER: "huggingface-tokenizer-url-placeholder",
+    PreTrainedTokenizers.J2_TOKENIZER: _LOCAL_RESOURCES_PATH / PreTrainedTokenizers.J2_TOKENIZER,
+}
+
+
 class TokenizerFactory:
     """
     Factory class to create AI21 tokenizer
@@ -29,21 +32,20 @@ def get_tokenizer(cls, tokenizer_name: str = PreTrainedTokenizers.J2_TOKENIZER)
         model_path = cls._model_path(tokenizer_name)
 
         if tokenizer_name == PreTrainedTokenizers.JAMBA_TOKENIZER:
-            os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"  # Disable Huggingface advice warning
-            return JambaTokenizer(model_path=model_path, cache_dir=_MODEL_CACHE_DIR)
+            return cls._create_jamaba_tokenizer(model_path)
 
-        config = cls._get_config(tokenizer_name)
-        return JurassicTokenizer(model_path=model_path, config=config)
+        return cls._create_jurassic_tokenizer(model_path)
 
     @classmethod
-    def _tokenizer_dir(cls, tokenizer_name: str) -> Path:
-        return _LOCAL_RESOURCES_PATH / tokenizer_name
+    def _create_jamaba_tokenizer(cls, model_path: str) -> JambaTokenizer:
+        os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"  # Disable Huggingface advice warning
+
+        return JambaTokenizer(model_path=model_path, cache_dir=_MODEL_CACHE_DIR)
 
     @classmethod
-    def _model_path(cls, tokenizer_name: str) -> Path:
-        return cls._tokenizer_dir(tokenizer_name) / f"{tokenizer_name}{_MODEL_EXTENSION}"
+    def _create_jurassic_tokenizer(cls, model_path: str) -> JurassicTokenizer:
+        return JurassicTokenizer(model_path=model_path)
 
     @classmethod
-    def _get_config(cls, tokenizer_name: str) -> Dict[str, Any]:
-        config_path = cls._tokenizer_dir(tokenizer_name) / _MODEL_CONFIG_FILENAME
-        return load_json(config_path)
+    def _model_path(cls, tokenizer_name: str) -> PathLike:
+        return _TOKENIZER_MODEL_MAP[tokenizer_name]
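With the map in place, the factory reduces to a name lookup plus delegation; a short usage sketch inferred from the diff above (not an official example):

    from ai21_tokenizer.tokenizer_factory import PreTrainedTokenizers, TokenizerFactory

    # J2: _model_path returns the local resources directory, and
    # JurassicTokenizer now reads config.json from that directory itself.
    j2 = TokenizerFactory.get_tokenizer(PreTrainedTokenizers.J2_TOKENIZER)

    # Jamba: _model_path returns the mapped model identifier; the factory sets
    # TRANSFORMERS_NO_ADVISORY_WARNINGS and builds a JambaTokenizer with the
    # shared cache directory.
    jamba = TokenizerFactory.get_tokenizer(PreTrainedTokenizers.JAMBA_TOKENIZER)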
4 changes: 2 additions & 2 deletions tests/conftest.py
@@ -2,7 +2,7 @@
 
 import pytest
 
-from ai21_tokenizer import Tokenizer
+from ai21_tokenizer import Tokenizer, PreTrainedTokenizers
 from ai21_tokenizer.jurassic_tokenizer import JurassicTokenizer
 
 
@@ -13,7 +13,7 @@ def resources_path() -> Path:
 
 @pytest.fixture(scope="session")
 def tokenizer() -> JurassicTokenizer:
-    jurassic_tokenizer = Tokenizer.get_tokenizer()
+    jurassic_tokenizer = Tokenizer.get_tokenizer(tokenizer_name=PreTrainedTokenizers.J2_TOKENIZER)
 
     if isinstance(jurassic_tokenizer, JurassicTokenizer):
         return jurassic_tokenizer
10 changes: 10 additions & 0 deletions tests/test_jurassic_tokenizer.py
@@ -180,3 +180,13 @@ def test_tokenizer__(
         JurassicTokenizer(model_file_handle=model_file_handle, model_path=model_path, config={})
 
     assert error.value.args[0] == expected_error_message
+
+
+def test_init__when_model_path_is_a_file__should_support_backwards_compatability():
+    text = "Hello world!"
+    tokenizer = JurassicTokenizer(model_path=_LOCAL_RESOURCES_PATH / "j2-tokenizer.model")
+
+    encoded = tokenizer.encode(text)
+    decoded = tokenizer.decode(encoded)
+
+    assert decoded == text
