diff --git a/ai21_tokenizer/jurassic_tokenizer.py b/ai21_tokenizer/jurassic_tokenizer.py
index 026d2c8..cbeb724 100644
--- a/ai21_tokenizer/jurassic_tokenizer.py
+++ b/ai21_tokenizer/jurassic_tokenizer.py
@@ -174,14 +174,14 @@ def decode(self, token_ids: List[int], **kwargs) -> str:
         """
         Transforms token ids into text
         """
-        res_text, _ = self.decode_with_offsets(token_ids)
+        res_text, _ = self.decode_with_offsets(token_ids, **kwargs)
         return res_text
 
-    def decode_with_offsets(self, token_ids: List[int]) -> Tuple[str, List[Tuple[int, int]]]:
+    def decode_with_offsets(self, token_ids: List[int], **kwargs) -> Tuple[str, List[Tuple[int, int]]]:
         """
         Transforms token ids into text, and returns the offsets of each token as well
         """
-        start_of_line = True
+        start_of_line = kwargs.get("start_of_line", True)
         res_text = ""
         offsets = []
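For context, here is a minimal usage sketch (not part of the patch) of what the forwarded `**kwargs` enable: `start_of_line` can now be passed through `decode()` instead of only via `decode_with_offsets()`. The `Tokenizer.get_tokenizer()` factory call and the exact semantics of `start_of_line` (leading-whitespace handling for the first decoded token) are assumptions, not taken from this diff.

```python
# Usage sketch, assuming Tokenizer.get_tokenizer() returns a JurassicTokenizer
# and that start_of_line controls leading-whitespace handling when decoding.
from ai21_tokenizer import Tokenizer

tokenizer = Tokenizer.get_tokenizer()
token_ids = tokenizer.encode("hello world")

# Unchanged default: decode() behaves as before the patch (start_of_line=True).
text_default = tokenizer.decode(token_ids)

# New with this patch: the flag is forwarded through decode()'s **kwargs.
text_mid_line = tokenizer.decode(token_ids, start_of_line=False)

# decode_with_offsets() accepts the same keyword and also returns per-token offsets.
text, offsets = tokenizer.decode_with_offsets(token_ids, start_of_line=False)
```

Reading `start_of_line` from `kwargs` with a default of `True` keeps existing callers' behavior identical while letting new callers opt out.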