
Fix regex lookahead for code input tokenization (#314)
li-plus authored Jun 14, 2024
1 parent 6d671d2 commit c9a4a70
Showing 5 changed files with 57 additions and 17 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -27,3 +27,4 @@ build/
 
 # data
 /data/
+*.log
23 changes: 12 additions & 11 deletions chatglm.cpp
@@ -1766,6 +1766,7 @@ TiktokenCoreBPE::TiktokenCoreBPE(std::unordered_map<std::string, int> encoder,
     : regex(std::make_unique<RE2>("(" + pattern + ")")), encoder(std::move(encoder)),
       special_tokens_encoder(std::move(special_tokens_encoder)) {
     CHATGLM_CHECK(regex->ok()) << regex->error();
+    CHATGLM_CHECK(regex->NumberOfCapturingGroups() <= 2) << "unimplemented";
 
     decoder.reserve(this->encoder.size());
     for (const auto &[token, rank] : this->encoder) {
@@ -1853,24 +1854,24 @@ std::vector<int> TiktokenCoreBPE::byte_pair_encode(const std::string &piece,
 std::vector<int> TiktokenCoreBPE::_encode_ordinary_native(const std::string &text) const {
     std::vector<int> ret;
-    re2::StringPiece input = text;
-    re2::StringPiece prev_input = input;
+    re2::StringPiece input(text);
+    re2::StringPiece prev_input(input);
     std::string piece;
-    while (RE2::FindAndConsume(&input, *regex, &piece)) {
-        // recover input in case of negative lookahead
-        if (prev_input.find(piece) == 0) {
-            input = prev_input.substr(piece.size());
-            prev_input = input;
-        } else {
-            std::cerr << "[WARN] chatglm.cpp: encounter unknown token\n";
-        }
+    std::string piece2;
+    while (RE2::FindAndConsume(&input, *regex, &piece, &piece2)) {
+        if (!piece2.empty()) {
+            // workaround for lookahead: capture sub group and restore input
+            auto pos = prev_input.find(piece2);
+            input = prev_input.substr(pos + piece2.length());
+            piece = std::move(piece2);
+        }
 
         if (auto it = encoder.find(piece); it != encoder.end()) {
             ret.emplace_back(it->second);
         } else {
             std::vector<int> bpe_ids = byte_pair_encode(piece, encoder);
             ret.insert(ret.end(), bpe_ids.begin(), bpe_ids.end());
         }
         prev_input = input;
     }
     return ret;
 }
@@ -1930,7 +1931,7 @@ ChatGLM4Tokenizer::ChatGLM4Tokenizer(const std::string &vocab_text) {
     observation_token_id = special_tokens_encoder.at("<|observation|>");
 
     const std::string pattern =
-        R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?:$|\s)|\s+)";
+        R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|(\s+)(?:\s)|\s+)";
     core_bpe = TiktokenCoreBPE(std::move(mergeable_ranks), std::move(special_tokens_encoder), pattern);
 }

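The substance of the fix is the pair of chatglm.cpp hunks above. RE2 intentionally supports no lookahead assertions, so the `\s+(?!\S)` branch of tiktoken's pattern cannot be written directly; the old approximation `\s+(?:$|\s)` kept the extra whitespace character inside the emitted token, so tokenizing indentation-heavy code input drifted away from the reference tiktoken output (#314). The new pattern instead captures only the intended run in a sub group, `(\s+)(?:\s)`, and `_encode_ordinary_native` rewinds the input to just past that group before continuing, which is also why the constructor now checks `NumberOfCapturingGroups() <= 2`: group 1 is the whole wrapped pattern, group 2 is the whitespace run. Here is a minimal standalone sketch of the trick; the toy pattern and `main` are illustrative assumptions, not code from this commit:

```cpp
// sketch.cpp - standalone demo of the capture-group workaround above.
// Assumes RE2 is available; build with e.g.: g++ sketch.cpp -lre2
#include <re2/re2.h>

#include <iostream>
#include <string>
#include <utility>

int main() {
    // Toy pattern (hypothetical, much smaller than the GLM4 one):
    //   ` ?\p{L}+`      a word with an optional leading space
    //   `(\s+)(?:\s)`   emulates the lookahead `\s+(?!\S)`
    //   `\s+`           any remaining whitespace
    // As in TiktokenCoreBPE, the whole pattern is wrapped in one group,
    // so group 1 is the full match and group 2 is the whitespace run.
    RE2 regex(R"(( ?\p{L}+|(\s+)(?:\s)|\s+))");

    re2::StringPiece input("a   b");
    re2::StringPiece prev_input(input);
    std::string piece;
    std::string piece2;
    while (RE2::FindAndConsume(&input, regex, &piece, &piece2)) {
        if (!piece2.empty()) {
            // The match consumed one whitespace char too many: rewind the
            // input to just past the captured sub group and emit only it.
            auto pos = prev_input.find(piece2);
            input = prev_input.substr(pos + piece2.length());
            piece = std::move(piece2);
        }
        std::cout << '[' << piece << ']';
        prev_input = input;
    }
    std::cout << '\n'; // prints: [a][  ][ b]
}
```

Without the rewind, the whitespace branch would swallow the whole run and yield `[a][   ][b]`; with it, the final space is re-scanned and attaches to the next word, giving `[a][  ][ b]`, which is what tiktoken's `\s+(?!\S)` produces.
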
2 changes: 1 addition & 1 deletion chatglm_cpp/__init__.py
@@ -6,7 +6,7 @@
 import chatglm_cpp._C as _C
 from chatglm_cpp._C import ChatMessage
 
-__version__ = "0.3.3"
+__version__ = "0.3.4"
 
 
 @dataclass
28 changes: 25 additions & 3 deletions chatglm_test.cpp
@@ -1338,6 +1338,7 @@ TEST(Pipeline, ChatGLM4) {
     Pipeline pipeline(model_path.string());
     ASSERT_TRUE(dynamic_cast<ChatGLM4Tokenizer *>(pipeline.tokenizer.get()));
     ASSERT_TRUE(dynamic_cast<ChatGLM4ForCausalLM *>(pipeline.model.get()));
+    auto tokenizer = dynamic_cast<ChatGLM4Tokenizer *>(pipeline.tokenizer.get());
 
     // const std::string system_tool_call =
     //     read_text(fs::path(__FILE__).parent_path() / "examples/system/function_call.txt");
@@ -1346,8 +1347,6 @@
 
     // tiktoken
     {
-        auto tokenizer = dynamic_cast<ChatGLM4Tokenizer *>(pipeline.tokenizer.get());
-
         // taken from:
         // https://github.com/ggerganov/llama.cpp/blob/4bfe50f741479c1df1c377260c3ff5702586719e/convert-hf-to-gguf.py#L413
         const std::string chktxt =
@@ -1372,7 +1371,30 @@ TEST(Pipeline, ChatGLM4) {
             498, 2704, 30, 364, 44, 537, 2704, 358, 3278, 1281, 432, 11, 364, 35,
             498, 1075, 1045, 15231, 30, 1205, 6, 42368, 264, 63409, 43};
 
-        std::vector<int> out_ids = tokenizer->core_bpe.encode_ordinary(chktxt);
+        const std::vector<int> out_ids = tokenizer->core_bpe.encode_ordinary(chktxt);
         EXPECT_EQ(ref_ids, out_ids);
     }
+    {
+        const std::string text = R"(
+```c++
+#include <iostream>
+int main() {
+    printf("hello world\n"); // say hello
+}
+```
+
+```python
+if __name__ == '__main__':
+    print('hello world') # say hello
+```
+)";
+        const std::vector<int> ref_ids = {198, 73022, 66, 22879, 1067, 366, 9661, 1339, 396, 1887, 368,
+                                          341, 262, 4100, 445, 14978, 1879, 1699, 5038, 262, 442, 1977,
+                                          23745, 198, 532, 13865, 19288, 73022, 12663, 198, 333, 1304, 606,
+                                          563, 621, 12106, 3817, 16165, 262, 1173, 492, 14978, 1879, 863,
+                                          286, 671, 1977, 23745, 198, 13865, 3989};
+        const std::vector<int> out_ids = tokenizer->core_bpe.encode_ordinary(text);
+        EXPECT_EQ(ref_ids, out_ids);
+    }
     // tokenizer
20 changes: 18 additions & 2 deletions tests/test_convert.py
@@ -803,6 +803,22 @@ def make_glm4_pipeline_data():
     chktxt = "\n \n\n \n\n\n \t \t\t \t\n \n \n \n \n🚀 (normal) 😶\u200d🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````\"\"\"\"......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL"
     print("tiktoken:", tokenizer.tokenizer.encode(chktxt, disallowed_special=()))
 
+    chktxt = r"""
+```c++
+#include <iostream>
+int main() {
+    printf("hello world\n"); // say hello
+}
+```
+
+```python
+if __name__ == '__main__':
+    print('hello world') # say hello
+```
+"""
+    print("tiktoken:", tokenizer.tokenizer.encode(chktxt, disallowed_special=()))
+
     # tokenizer
     inputs = tokenizer("你好")
     print(f"encode: {inputs=}")
@@ -861,8 +877,8 @@ def main():
     # make_data_baichuan7b_model()
     # make_data_baichuan13b_model()
     # make_internlm_model()
-    make_data_glm4_model()
-    # make_glm4_pipeline_data()
+    # make_data_glm4_model()
+    make_glm4_pipeline_data()
 
 
 if __name__ == "__main__":
