From a21d96dcb935c37ebd0a413124c4caecb26dbbb0 Mon Sep 17 00:00:00 2001 From: asafg Date: Thu, 20 Jun 2024 07:11:35 +0300 Subject: [PATCH] fix: Added backwards compatibility for jamba tokenizer --- ai21_tokenizer/tokenizer_factory.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ai21_tokenizer/tokenizer_factory.py b/ai21_tokenizer/tokenizer_factory.py index 30713e1..c7e60cd 100644 --- a/ai21_tokenizer/tokenizer_factory.py +++ b/ai21_tokenizer/tokenizer_factory.py @@ -12,7 +12,8 @@ class PreTrainedTokenizers: J2_TOKENIZER = "j2-tokenizer" - JAMBA_INSTRUCT_TOKENIZER = "jamba-tokenizer" + JAMBA_INSTRUCT_TOKENIZER = "jamba-instruct-tokenizer" + JAMBA_TOKENIZER = "jamba-tokenizer" class TokenizerFactory: @@ -26,7 +27,10 @@ def get_tokenizer( cls, tokenizer_name: str = PreTrainedTokenizers.J2_TOKENIZER, ) -> BaseTokenizer: - if tokenizer_name == PreTrainedTokenizers.JAMBA_INSTRUCT_TOKENIZER: + if ( + tokenizer_name == PreTrainedTokenizers.JAMBA_INSTRUCT_TOKENIZER + or tokenizer_name == PreTrainedTokenizers.JAMBA_TOKENIZER + ): return JambaInstructTokenizer(model_path=JAMBA_TOKENIZER_HF_PATH, cache_dir=os.getenv(_ENV_CACHE_DIR_KEY)) if tokenizer_name == PreTrainedTokenizers.J2_TOKENIZER: