Commit

test: Added new test file to validate case where model's embedding has a mismatch with tokenizer's vocab (#85)
Saibo-creator authored Aug 27, 2024
1 parent 541cab4 commit 0586750
Showing 2 changed files with 122 additions and 1 deletion.
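The scenario under test is a model whose input embedding matrix has more rows than the tokenizer's vocabulary (for example after resize_token_embeddings reserves extra slots); the new test simulates this by growing the embeddings by 10 rows. A minimal sketch of that setup, using the model id from the test file below (the variable names and printout are illustrative, not part of the commit):

# Sketch of the embedding/vocab mismatch the new test exercises
# (assumes transformers is installed; model id taken from the test file below).
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("JackFram/llama-68m")
tokenizer = AutoTokenizer.from_pretrained("JackFram/llama-68m")

# Grow the embedding matrix beyond the tokenizer's vocabulary, mirroring setUpClass.
model.resize_token_embeddings(10 + len(tokenizer))

tokenizer_vocab = len(tokenizer)
embedding_rows = model.get_input_embeddings().weight.shape[0]
print(tokenizer_vocab, embedding_rows)  # embedding_rows == tokenizer_vocab + 10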
2 changes: 1 addition & 1 deletion tests/test_hf_generation/test_generation.py
@@ -6,7 +6,7 @@
 MODEL_IDS = [
     "hf-internal-testing/tiny-random-GPTJForCausalLM",
     "JackFram/llama-68m",
-    # "hf-internal-testing/tiny-random-PhiForCausalLM",
+    "hf-internal-testing/tiny-random-PhiForCausalLM",
     "hf-internal-testing/tiny-random-gpt2",
     # "hf-internal-testing/tiny-random-BlenderbotForCausalLM",
 ]
121 changes: 121 additions & 0 deletions tests/test_hf_generation/test_generation_w_expanded_emb.py
@@ -0,0 +1,121 @@
from unittest import TestCase
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_cfg.token_grammar_recognizer import IncrementalTokenRecognizer
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

MODEL_IDS = [
    "JackFram/llama-68m",
]


def check_parentheses(generation):
    stack = []
    for char in generation:
        if char == "(":
            stack.append(char)
        elif char == ")":
            if not stack:
                return False
            stack.pop()
    return not stack


class TestGreedyDecoding(TestCase):
    @classmethod
    def setUpClass(cls):
        cls.models = {}
        cls.tokenizers = {}
        for model_id in MODEL_IDS:
            cls.models[model_id] = AutoModelForCausalLM.from_pretrained(model_id)
            cls.tokenizers[model_id] = AutoTokenizer.from_pretrained(model_id)
            cls.tokenizers[model_id].pad_token = cls.tokenizers[model_id].eos_token
            # we expand the embedding layer to simulate the case where the model has a larger embedding layer than the tokenizer
            cls.models[model_id].resize_token_embeddings(
                10 + len(cls.tokenizers[model_id])
            )

    def test_generate_only_number(self):
        # test greedy decoding with grammar constraints
        grammar_str = """
        root ::= [0-9]+
        """

        for model_id in MODEL_IDS:
            model = self.models[model_id]
            tokenizer = self.tokenizers[model_id]

            grammar = IncrementalTokenRecognizer(
                grammar_str, start_rule_name="root", tokenizer=tokenizer
            )
            grammar_processor = GrammarConstrainedLogitsProcessor(grammar)

            prefix = "This is a valid number:"

            input_ids = tokenizer(
                [prefix], add_special_tokens=False, return_tensors="pt", padding=True
            )["input_ids"]

            output = model.generate(
                input_ids,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                num_beams=1,
                max_new_tokens=40,
                top_p=0.92,
                top_k=5,
                logits_processor=[grammar_processor],
                repetition_penalty=100.0,
                early_stopping=True,
            )

            generations = tokenizer.batch_decode(
                output[:, input_ids.shape[1] :], skip_special_tokens=True
            )
            self.assertTrue(
                generations[0].isdigit(), f"generations: {generations} is not a number"
            )

    def test_generate_balanced_parenthesis(self):
        # test sampling-based decoding with grammar constraints
        grammar_str = """
        root ::= "(" root ")" | ""
        """

        for model_id in MODEL_IDS:
            model = self.models[model_id]
            tokenizer = self.tokenizers[model_id]

            grammar = IncrementalTokenRecognizer(
                grammar_str, start_rule_name="root", tokenizer=tokenizer
            )
            grammar_processor = GrammarConstrainedLogitsProcessor(grammar)

            prefix = "This is a valid json:"

            input_ids = tokenizer(
                [prefix], add_special_tokens=False, return_tensors="pt", padding=True
            )["input_ids"]

            output = model.generate(
                input_ids,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                num_beams=1,
                max_new_tokens=40,
                top_p=0.92,
                top_k=5,
                logits_processor=[grammar_processor],
                repetition_penalty=100.0,
                early_stopping=True,
            )

            generation: str = tokenizer.batch_decode(
                output[:, input_ids.shape[1] :], skip_special_tokens=True
            )[0]

            self.assertTrue(
                check_parentheses(generation),
                f"generations: {generation} is not a balanced parenthesis",
            )
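For reference, a quick usage sketch of the check_parentheses helper defined above (the import path assumes the repository root is on sys.path, which is an assumption about the test layout):

# Assumption: the test package is importable from the repository root.
from tests.test_hf_generation.test_generation_w_expanded_emb import check_parentheses

assert check_parentheses("(())")       # balanced -> True
assert not check_parentheses("(()")    # unclosed "(" -> False
assert not check_parentheses(")(")     # ")" before "(" -> False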
