From c9a4a70d9ac89e7aca360a50270539144c433e7f Mon Sep 17 00:00:00 2001
From: Jiahao Li
Date: Fri, 14 Jun 2024 15:52:39 +0800
Subject: [PATCH] Fix regex lookahead for code input tokenization (#314)

---
 .gitignore              |  1 +
 chatglm.cpp             | 23 ++++++++++++-----------
 chatglm_cpp/__init__.py |  2 +-
 chatglm_test.cpp        | 28 +++++++++++++++++++++++++---
 tests/test_convert.py   | 20 ++++++++++++++++++--
 5 files changed, 57 insertions(+), 17 deletions(-)

diff --git a/.gitignore b/.gitignore
index 21c41c2..d35b24a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -27,3 +27,4 @@ build/
 
 # data
 /data/
+*.log
diff --git a/chatglm.cpp b/chatglm.cpp
index a86ae0b..8c94b85 100644
--- a/chatglm.cpp
+++ b/chatglm.cpp
@@ -1766,6 +1766,7 @@ TiktokenCoreBPE::TiktokenCoreBPE(std::unordered_map<std::string, int> encoder,
     : regex(std::make_unique<RE2>("(" + pattern + ")")), encoder(std::move(encoder)),
       special_tokens_encoder(std::move(special_tokens_encoder)) {
     CHATGLM_CHECK(regex->ok()) << regex->error();
+    CHATGLM_CHECK(regex->NumberOfCapturingGroups() <= 2) << "unimplemented";
 
     decoder.reserve(this->encoder.size());
     for (const auto &[token, rank] : this->encoder) {
@@ -1853,24 +1854,24 @@ std::vector<int> TiktokenCoreBPE::byte_pair_encode(const std::string &piece,
 
 std::vector<int> TiktokenCoreBPE::_encode_ordinary_native(const std::string &text) const {
     std::vector<int> ret;
-    re2::StringPiece input = text;
-    re2::StringPiece prev_input = input;
+    re2::StringPiece input(text);
+    re2::StringPiece prev_input(input);
     std::string piece;
-    while (RE2::FindAndConsume(&input, *regex, &piece)) {
-        // recover input in case of negative lookahead
-        if (prev_input.find(piece) == 0) {
-            input = prev_input.substr(piece.size());
-            prev_input = input;
-        } else {
-            std::cerr << "[WARN] chatglm.cpp: encounter unknown token\n";
+    std::string piece2;
+    while (RE2::FindAndConsume(&input, *regex, &piece, &piece2)) {
+        if (!piece2.empty()) {
+            // workaround for lookahead: capture sub group and restore input
+            auto pos = prev_input.find(piece2);
+            input = prev_input.substr(pos + piece2.length());
+            piece = std::move(piece2);
         }
-
         if (auto it = encoder.find(piece); it != encoder.end()) {
             ret.emplace_back(it->second);
         } else {
             std::vector<int> bpe_ids = byte_pair_encode(piece, encoder);
             ret.insert(ret.end(), bpe_ids.begin(), bpe_ids.end());
         }
+        prev_input = input;
     }
     return ret;
 }
@@ -1930,7 +1931,7 @@ ChatGLM4Tokenizer::ChatGLM4Tokenizer(const std::string &vocab_text) {
     observation_token_id = special_tokens_encoder.at("<|observation|>");
 
     const std::string pattern =
-        R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?:$|\s)|\s+)";
+        R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|(\s+)(?:\s)|\s+)";
 
     core_bpe = TiktokenCoreBPE(std::move(mergeable_ranks), std::move(special_tokens_encoder), pattern);
 }
diff --git a/chatglm_cpp/__init__.py b/chatglm_cpp/__init__.py
index 88f5097..f7568fd 100644
--- a/chatglm_cpp/__init__.py
+++ b/chatglm_cpp/__init__.py
@@ -6,7 +6,7 @@
 import chatglm_cpp._C as _C
 from chatglm_cpp._C import ChatMessage
 
-__version__ = "0.3.3"
+__version__ = "0.3.4"
 
 
 @dataclass
diff --git a/chatglm_test.cpp b/chatglm_test.cpp
index a448230..96590af 100644
--- a/chatglm_test.cpp
+++ b/chatglm_test.cpp
@@ -1338,6 +1338,7 @@ TEST(Pipeline, ChatGLM4) {
     Pipeline pipeline(model_path.string());
     ASSERT_TRUE(dynamic_cast<ChatGLM4Tokenizer *>(pipeline.tokenizer.get()));
     ASSERT_TRUE(dynamic_cast<ChatGLM4ForCausalLM *>(pipeline.model.get()));
+    auto tokenizer = dynamic_cast<ChatGLM4Tokenizer *>(pipeline.tokenizer.get());
 
     // const std::string system_tool_call =
     //     read_text(fs::path(__FILE__).parent_path() / "examples/system/function_call.txt");
@@ -1346,8 +1347,6 @@ TEST(Pipeline, ChatGLM4) {
 
     // tiktoken
     {
-        auto tokenizer = dynamic_cast<ChatGLM4Tokenizer *>(pipeline.tokenizer.get());
-
         // taken from:
        // https://github.com/ggerganov/llama.cpp/blob/4bfe50f741479c1df1c377260c3ff5702586719e/convert-hf-to-gguf.py#L413
         const std::string chktxt =
@@ -1372,7 +1371,30 @@
             498,  2704, 30,   364,  44,   537, 2704,  358, 3278,  1281, 432, 11, 364, 35, 498,
             1075, 1045, 15231, 30,  1205, 6,   42368, 264, 63409, 43};
 
-        std::vector<int> out_ids = tokenizer->core_bpe.encode_ordinary(chktxt);
+        const std::vector<int> out_ids = tokenizer->core_bpe.encode_ordinary(chktxt);
+        EXPECT_EQ(ref_ids, out_ids);
+    }
+    {
+        const std::string text = R"(
+```c++
+#include <stdio.h>
+
+int main() {
+    printf("hello world\n"); // say hello
+}
+```
+
+```python
+if __name__ == '__main__':
+    print('hello world') # say hello
+```
+)";
+        const std::vector<int> ref_ids = {198,   73022, 66,    22879, 1067,  366,   9661,  1339, 396,   1887, 368,
+                                          341,   262,   4100,  445,   14978, 1879,  1699,  5038, 262,   442,  1977,
+                                          23745, 198,   532,   13865, 19288, 73022, 12663, 198,  333,   1304, 606,
+                                          563,   621,   12106, 3817,  16165, 262,   1173,  492,  14978, 1879, 863,
+                                          286,   671,   1977,  23745, 198,   13865, 3989};
+        const std::vector<int> out_ids = tokenizer->core_bpe.encode_ordinary(text);
         EXPECT_EQ(ref_ids, out_ids);
     }
     // tokenizer
diff --git a/tests/test_convert.py b/tests/test_convert.py
index 668e53b..9a1e31e 100644
--- a/tests/test_convert.py
+++ b/tests/test_convert.py
@@ -803,6 +803,22 @@ def make_glm4_pipeline_data():
     chktxt = "\n \n\n \n\n\n \t \t\t \t\n  \n   \n    \n     \n🚀 (normal) 😶\u200d🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天～ ------======= нещо на Български ''''''```````\"\"\"\"......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL"
     print("tiktoken:", tokenizer.tokenizer.encode(chktxt, disallowed_special=()))
 
+    chktxt = r"""
+```c++
+#include <stdio.h>
+
+int main() {
+    printf("hello world\n"); // say hello
+}
+```
+
+```python
+if __name__ == '__main__':
+    print('hello world') # say hello
+```
+"""
+    print("tiktoken:", tokenizer.tokenizer.encode(chktxt, disallowed_special=()))
+
     # tokenizer
     inputs = tokenizer("你好")
     print(f"encode: {inputs=}")
@@ -861,8 +877,8 @@ def main():
     # make_data_baichuan7b_model()
     # make_data_baichuan13b_model()
     # make_internlm_model()
-    make_data_glm4_model()
-    # make_glm4_pipeline_data()
+    # make_data_glm4_model()
+    make_glm4_pipeline_data()
 
 
 if __name__ == "__main__":
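
Editor's note on the workaround (this commentary and the sketch below are not part of the patch): RE2 rejects lookahead assertions by design, so tiktoken's trailing-whitespace rule `\s+(?!\S)` cannot be ported directly. The patched pattern uses `(\s+)(?:\s)` instead: it consumes one whitespace character too many, captures the intended match as a sub group, and `_encode_ordinary_native` then rewinds the scan position to just past that capture; the new `NumberOfCapturingGroups() <= 2` check guards this one-sub-group scheme. A minimal self-contained sketch of the same rewind trick, with an illustrative file name and sample text of my choosing:

// lookahead_demo.cpp -- illustrative sketch only, not from the patch;
// build with: g++ lookahead_demo.cpp -lre2
#include <iostream>
#include <string>
#include <utility>

#include <re2/re2.h>

int main() {
    // RE2 has no lookahead, so "(\s+)(?:\s)" stands in for "\s+(?=\s)": it
    // consumes one whitespace char too many but captures the intended match
    // as group 2. The outer parens mirror the "(" + pattern + ")" wrapping
    // done by the TiktokenCoreBPE constructor.
    RE2 regex(R"(((\s+)(?:\s)|\s+))");
    re2::StringPiece input("int x;   \n  done");
    re2::StringPiece prev_input(input);
    std::string piece, piece2;
    while (RE2::FindAndConsume(&input, regex, &piece, &piece2)) {
        if (!piece2.empty()) {
            // The sub group matched: rewind the scan position to just past
            // it, as the patched _encode_ordinary_native does. (This assumes
            // the capture's first occurrence in prev_input is the actual
            // match, which holds for these whitespace runs.)
            auto pos = prev_input.find(piece2);
            input = prev_input.substr(pos + piece2.length());
            piece = std::move(piece2);
        }
        std::cout << "piece: \"" << piece << "\"\n";
        prev_input = input;
    }
    return 0;
}

On the sample text, the six-character whitespace run between ";" and "done" comes out as a five-character piece, leaving the last space for the following match, which is what `\s+(?=\s)` would do in a lookahead-capable engine. In the full tokenizer pattern that leftover space then attaches to the next word via the ` ?[^\s\p{L}\p{N}]+` and `[^\r\n\p{L}\p{N}]?\p{L}+` alternatives, which is why the fix changes how indented code such as the ```c++ and ```python snippets in the new tests is segmented.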