Commit

updates in the base tokenizer
Hk669 committed May 29, 2024
1 parent e011cd1 commit e6799d6
Showing 2 changed files with 58 additions and 2 deletions.
6 changes: 4 additions & 2 deletions bpetokenizer/base.py
@@ -61,11 +61,11 @@ def render_token(t: bytes) -> str:
 class Tokenizer:
     """A Base class for the tokenizer, used for training and encoding/decoding the text without special tokens."""
 
-    def __init__(self):
+    def __init__(self, special_tokens=None):
         self.merges = {}
         self.pattern = "" # the regex pattern
         self.compiled_pattern = re.compile(self.pattern) if self.pattern else ""
-        self.special_tokens = {}
+        self.special_tokens = special_tokens if special_tokens else {}
         self.vocab = self._build_vocab() if self.merges else {}
 
     def _build_vocab(self) -> dict:
@@ -176,6 +176,7 @@ def load(self, file_name, mode="json"):

     def encode(self, texts):
         """Method to encode the text to ids."""
+        assert texts
         text_bytes = texts.encode("utf-8") # raw bytes string
         ids = list(map(int, text_bytes))
         while len(ids) >= 2:
@@ -206,6 +207,7 @@ def train(self, texts, vocab_size, verbose=False, min_frequency=2):
         min_frequency: int (the minimum frequency of the pair to be merged and added into the vocab as a new token)
         """
         assert vocab_size >= 256
+        assert texts
         num_merges = vocab_size - 256
 
         tokens = texts.encode("utf-8")
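Taken together, the base.py changes let callers pass a special-token map to the base Tokenizer at construction time and make encode and train reject empty input. A minimal usage sketch, assuming the package is importable as in the tests below (the training text and vocab size are copied from tests/test_base.py; the empty-input behavior follows from the new assert statements):

    from bpetokenizer import Tokenizer

    special_tokens = {"<|start|>": 1001, "<|end|>": 1002}
    text = "<|start|>This is a test text for training the vocab of the tokenizer<|end|>"

    tokenizer = Tokenizer(special_tokens=special_tokens)    # new keyword argument
    tokenizer.train(text, vocab_size=270, min_frequency=0)  # 256 byte tokens + 14 merges

    ids = tokenizer.encode(text)
    assert tokenizer.decode(ids) == text

    # Empty input now trips the new guards in encode() and train():
    # tokenizer.encode("") raises AssertionError, as does tokenizer.train("", vocab_size=270)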
54 changes: 54 additions & 0 deletions tests/test_base.py
@@ -0,0 +1,54 @@
+import pytest
+from bpetokenizer import Tokenizer
+
+
+@pytest.fixture()
+def tokenizer():
+    text = "<|start|>This is a test text for training the vocab of the tokenizer<|end|>"
+    special_tokens = {
+        "<|start|>" : 1001,
+        "<|end|>": 1002
+    }
+    tokenizer = Tokenizer(special_tokens=special_tokens)
+    tokenizer.train(text, vocab_size=270, min_frequency=0)
+    return tokenizer
+
+def test_train():
+    text = "<|start|>This is a test text for training the vocab of the tokenizer<|end|>"
+    special_tokens = {
+        "<|start|>" : 1001,
+        "<|end|>": 1002
+    }
+    tokenizer = Tokenizer(special_tokens=special_tokens)
+    tokenizer.train(text, vocab_size=270, min_frequency=0)
+    assert tokenizer.encode(text)
+    assert len(tokenizer.vocab) == 270
+    assert len(tokenizer.merges) == 270 - 256
+    assert tokenizer.decode(tokenizer.encode(text)) == text
+
+def test_encode(tokenizer):
+    """Test encoding with different text lengths and special tokens."""
+
+    # Test with short text
+    short_text = "hello"
+    encoded_short = tokenizer.encode(short_text)
+    assert len(encoded_short) > 0  # Encoded text should not be empty
+
+    # Test with long text
+    long_text = "This is a very long text to test the tokenizer's encoding capabilities."
+    encoded_long = tokenizer.encode(long_text)
+    assert len(encoded_long) > 0  # Encoded text should not be empty
+
+    # Test with special tokens
+    special_text = "<|start|>This has special tokens<|end|>"
+    tokenizer.train(special_text, vocab_size=260, min_frequency=0)
+    encoded_special = tokenizer.encode(special_text)
+    assert all(t in tokenizer.vocab for t in encoded_special)  # All tokens should be in vocab
+
+
+def test_decode(tokenizer):
+    """Test decoding functionality with different encoded inputs."""
+
+    encoded_text = [1, 2, 3]
+    decoded_text = tokenizer.decode(encoded_text)
+    assert len(decoded_text) > 0
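A quick way to run just this module, assuming pytest is installed and the repository root is the working directory (equivalent to running pytest tests/test_base.py -q on the command line):

    import pytest

    # Run only the new base-tokenizer tests; exit code 0 means all of them passed.
    exit_code = pytest.main(["-q", "tests/test_base.py"])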
