Skip to content

Commit

Permalink
feat: allow creating JurassicTokenizer from model file handle
Browse files (browse the repository at this point in the history)
  • Loading branch information
tomeras91 committed Jan 2, 2024
1 parent 7b8348d commit 59c9c5e
Showing 1 changed file with 22 additions and 3 deletions.
25 changes: 22 additions & 3 deletions ai21_tokenizer/jurassic_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import re
from dataclasses import dataclass
from typing import List, Union, Optional, Dict, Any, Tuple
from typing import List, Union, Optional, Dict, Any, Tuple, BinaryIO

import sentencepiece as spm

Expand All @@ -19,11 +19,16 @@ class SpaceSymbol:
class JurassicTokenizer(BaseTokenizer):
def __init__(
self,
model_path: PathLike,
model_path: Optional[PathLike],
model_file_handle: Optional[BinaryIO],
config: Optional[Dict[str, Any]] = None,
):
JurassicTokenizer._assert_exactly_one(model_path, model_file_handle)

model_proto = load_binary(model_path) if model_path else model_file_handle.read()

# noinspection PyArgumentList
self._sp = spm.SentencePieceProcessor(model_proto=load_binary(model_path))
self._sp = spm.SentencePieceProcessor(model_proto=model_proto)
config = config or {}

self.pad_id = config.get("pad_id")
Expand Down Expand Up @@ -52,6 +57,20 @@ def __init__(
self._space_mode = config.get("space_mode")
self._space_tokens = self._map_space_tokens()

@classmethod
def from_file_handle(
    cls, model_file_handle: BinaryIO, config: Optional[Dict[str, Any]] = None
) -> JurassicTokenizer:
    """Alternate constructor: build a tokenizer from an open binary file handle.

    Args:
        model_file_handle: Binary file-like object holding the serialized
            model; the constructor obtains the model bytes by calling
            ``read()`` on it. This method does not close the handle.
        config: Optional tokenizer settings, forwarded unchanged to the
            constructor.

    Returns:
        A new instance of ``cls`` created with ``model_path=None``.
    """
    # model_path is passed explicitly as None so the constructor's
    # "exactly one source" validation sees only the file handle.
    tokenizer = cls(
        model_path=None,
        model_file_handle=model_file_handle,
        config=config,
    )
    return tokenizer

@staticmethod
def _assert_exactly_one(model_path: Optional[PathLike], model_file_handle: Optional[BinaryIO]) -> None:
if model_path is None and model_file_handle is None:
raise ValueError("Must provide exactly one of model_path or model_file_handle. Got none.")

if model_path is not None and model_file_handle is not None:
raise ValueError("Must provide exactly one of model_path or model_file_handle. Got both.")

def _map_space_tokens(self) -> List[SpaceSymbol]:
res = []
for count in range(32, 0, -1):
Expand Down

0 comments on commit 59c9c5e

Please sign in to comment.