fix: cr comments
miri-bar committed Jun 18, 2024
1 parent fdbe9a8 commit 9fe658f
Showing 4 changed files with 31 additions and 42 deletions.
6 changes: 5 additions & 1 deletion ai21_tokenizer/base_tokenizer.py
@@ -1,6 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Union
import asyncio


class BaseTokenizer(ABC):
@@ -82,7 +83,7 @@ class AsyncBaseTokenizer(ABC):
"""
Base class for tokenizers.
This class defines the interface for tokenization operations such as encoding, decoding,
This class defines the interface for async tokenization operations such as encoding, decoding,
converting tokens to IDs, and converting IDs to tokens.
"""

@@ -151,3 +152,6 @@ def vocab_size(self) -> int:
int: The size of the vocabulary.
"""
pass

async def _make_async_call(self, callback_func, **kwargs):
return await asyncio.get_running_loop().run_in_executor(executor=None, func=lambda: callback_func(**kwargs))
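
For reference, here is a standalone sketch of the pattern the new _make_async_call helper centralizes: a blocking callable is handed to the default executor so the event loop stays responsive. The tokenize function and the module-level make_async_call name below are illustrative stand-ins, not part of the package.

import asyncio
from typing import List


def tokenize(text: str) -> List[int]:
    # Stand-in for a blocking call such as Tokenizer.from_pretrained or
    # SentencePiece encoding.
    return [ord(ch) for ch in text]


async def make_async_call(callback_func, **kwargs):
    # Same body as the _make_async_call helper added to AsyncBaseTokenizer above.
    return await asyncio.get_running_loop().run_in_executor(
        executor=None, func=lambda: callback_func(**kwargs)
    )


async def main() -> None:
    token_ids = await make_async_call(tokenize, text="hello")
    print(token_ids)  # [104, 101, 108, 108, 111]


if __name__ == "__main__":
    asyncio.run(main())
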
35 changes: 12 additions & 23 deletions ai21_tokenizer/jamba_instruct_tokenizer.py
@@ -1,6 +1,5 @@
from __future__ import annotations

import asyncio
import tempfile
from pathlib import Path
from typing import Union, List, Optional, cast
@@ -70,7 +69,7 @@ class AsyncJambaInstructTokenizer(BaseJambaInstructTokenizer, AsyncBaseTokenizer

def __init__(self):
raise ValueError(
"Create object with context manager only.Use either AsyncJambaInstructTokenizer.create or "
"Do not create AsyncJambaInstructTokenizer directly. Use either AsyncJambaInstructTokenizer.create or "
"Tokenizer.get_async_tokenizer"
)

@@ -97,34 +96,26 @@ async def create(
async def encode(self, text: str, **kwargs) -> List[int]:
if not self._tokenizer:
await self._init_tokenizer()
return await asyncio.get_running_loop().run_in_executor(
executor=None, func=lambda: self._encode(text=text, **kwargs)
)

return await self._make_async_call(callback_func=self._encode, text=text, **kwargs)

async def decode(self, token_ids: List[int], **kwargs) -> str:
if not self._tokenizer:
await self._init_tokenizer()
return await asyncio.get_running_loop().run_in_executor(
executor=None, func=lambda: self._decode(token_ids=token_ids, **kwargs)
)

return await self._make_async_call(callback_func=self._decode, token_ids=token_ids, **kwargs)

async def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
if not self._tokenizer:
await self._init_tokenizer()
return await asyncio.get_running_loop().run_in_executor(
None,
self._convert_tokens_to_ids,
tokens,
)

return await self._make_async_call(callback_func=self._convert_tokens_to_ids, tokens=tokens)

async def convert_ids_to_tokens(self, token_ids: Union[int, List[int]], **kwargs) -> Union[str, List[str]]:
if not self._tokenizer:
await self._init_tokenizer()
return await asyncio.get_running_loop().run_in_executor(
None,
self._convert_ids_to_tokens,
token_ids,
)

return await self._make_async_call(callback_func=self._convert_ids_to_tokens, token_ids=token_ids, **kwargs)

@property
def vocab_size(self) -> int:
@@ -139,8 +130,8 @@ async def _init_tokenizer(self):
if self._is_cached(self._cache_dir):
self._tokenizer = await self._load_from_cache(self._cache_dir / _TOKENIZER_FILE)
else:
tokenizer_from_pretrained = await asyncio.get_running_loop().run_in_executor(
None, Tokenizer.from_pretrained, self._model_path
tokenizer_from_pretrained = await self._make_async_call(
callback_func=Tokenizer.from_pretrained, identifier=self._model_path
)

tokenizer = cast(
@@ -152,7 +143,5 @@ async def _init_tokenizer(self):
self._tokenizer = tokenizer

async def _load_from_cache(self, cache_file: Path) -> Tokenizer:
tokenizer_from_file = await asyncio.get_running_loop().run_in_executor(
None, Tokenizer.from_file, str(cache_file)
)
tokenizer_from_file = await self._make_async_call(callback_func=Tokenizer.from_file, path=str(cache_file))
return cast(Tokenizer, tokenizer_from_file)
26 changes: 12 additions & 14 deletions ai21_tokenizer/jurassic_tokenizer.py
@@ -1,6 +1,5 @@
from __future__ import annotations

import asyncio
from dataclasses import dataclass
from typing import List, Union, Optional, Dict, Any, Tuple, BinaryIO

@@ -80,7 +79,7 @@ def from_file_path(cls, model_path: PathLike, config: Optional[Dict[str, Any]] =
class AsyncJurassicTokenizer(BaseJurassicTokenizer, AsyncBaseTokenizer):
def __init__(self):
raise ValueError(
"Create object with context manager only.Use either AsyncJurassicTokenizer.create or "
"Do not create AsyncJurassicTokenizer directly.Use either AsyncJurassicTokenizer.create or "
"Tokenizer.get_async_tokenizer"
)

@@ -121,40 +120,39 @@ async def encode(self, text: str, **kwargs) -> List[int]:
"""
if not self._sp:
await self._aload_model_proto()
return await asyncio.get_running_loop().run_in_executor(
executor=None, func=lambda: self._encode_wrapper(text=text, **kwargs)
)

return await self._make_async_call(callback_func=self._encode_wrapper, text=text, **kwargs)

async def decode(self, token_ids: List[int], **kwargs) -> str:
"""
Transforms token ids into text
"""
if not self._sp:
await self._aload_model_proto()
return await asyncio.get_running_loop().run_in_executor(
executor=None, func=lambda: self._decode_wrapper(token_ids=token_ids, **kwargs)
)

return await self._make_async_call(callback_func=self._decode_wrapper, token_ids=token_ids, **kwargs)

async def decode_with_offsets(self, token_ids: List[int], **kwargs) -> Tuple[str, List[Tuple[int, int]]]:
"""
Transforms token ids into text, and returns the offsets of each token as well
"""
if not self._sp:
await self._aload_model_proto()
return await asyncio.get_running_loop().run_in_executor(
executor=None, func=lambda: self._decode_with_offsets(token_ids=token_ids, **kwargs)
)

return await self._make_async_call(callback_func=self._decode_with_offsets, token_ids=token_ids, **kwargs)

async def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
if not self._sp:
await self._aload_model_proto()
return await asyncio.get_running_loop().run_in_executor(None, self._convert_tokens_to_ids, tokens)

return await self._make_async_call(callback_func=self._convert_tokens_to_ids, tokens=tokens)

async def convert_ids_to_tokens(self, token_ids: Union[int, List[int]], **kwargs) -> Union[str, List[str]]:
if not self._sp:
await self._aload_model_proto()
return await asyncio.get_running_loop().run_in_executor(
executor=None, func=lambda: self._convert_ids_to_tokens_wrapper(token_ids=token_ids, **kwargs)

return await self._make_async_call(
callback_func=self._convert_ids_to_tokens_wrapper, token_ids=token_ids, **kwargs
)

@classmethod
6 changes: 2 additions & 4 deletions ai21_tokenizer/tokenizer_factory.py
@@ -40,15 +40,13 @@ async def get_async_tokenizer(
tokenizer_name: str = PreTrainedTokenizers.J2_TOKENIZER,
) -> AsyncBaseTokenizer:
if tokenizer_name == PreTrainedTokenizers.JAMBA_INSTRUCT_TOKENIZER:
jamba_tokenizer = await AsyncJambaInstructTokenizer.create(
return await AsyncJambaInstructTokenizer.create(
model_path=JAMBA_TOKENIZER_HF_PATH, cache_dir=os.getenv(_ENV_CACHE_DIR_KEY)
)
return jamba_tokenizer

if tokenizer_name == PreTrainedTokenizers.J2_TOKENIZER:
jurassic_tokenizer = await AsyncJurassicTokenizer.create(
return await AsyncJurassicTokenizer.create(
model_path=_LOCAL_RESOURCES_PATH / PreTrainedTokenizers.J2_TOKENIZER
)
return jurassic_tokenizer

raise ValueError(f"Tokenizer {tokenizer_name} is not supported")
