From 0ce783b7499fc94c1b15f0e00a1da3cca7e82674 Mon Sep 17 00:00:00 2001
From: Saibo-creator <53392976+Saibo-creator@users.noreply.github.com>
Date: Tue, 27 Aug 2024 15:54:04 -0700
Subject: [PATCH] Unified cli (#86)

* feat: Add CLI support for grammar-constrained generation and remove old examples

- Introduced a `generate` sub-command in the CLI for text generation with grammar constraints.
- Added options for 4-bit and 8-bit model loading using bitsandbytes.
- Removed obsolete example scripts for generating C code, JSON arrays, and relation extraction triples.
- Refactored the generation examples to include device handling.
- Updated device management within the grammar-constrained generation process.

* feat: Refactor grammar-constrained generation scripts and CLI options

- Updated the examples with new demos for JSON, code generation, semantic parsing, and Unicode support.
- Removed the remaining outdated standalone example scripts (see the file list below) and merged their functionality into `examples/demo.sh`.
- Enhanced grammar examples, including more detailed entities and relations in `cIE.ebnf`.
- Added new CLI arguments (`--no_contrast_mode`, `--save_to`) in `cli_main.py` to control contrast mode and save output to a file.
- Improved Unicode detection in `token_grammar_recognizer.py` by adding a static method to automatically detect Unicode in grammar strings.
- Removed the `unicode` argument from various recognizer classes, as it is now detected automatically.
- Added tests for Unicode detection in `test_unicode_generation.py`.

* doc: Update README with CLI example for JSON generation and improve documentation structure

- Added a command-line example for generating a valid JSON object using `transformers-cfg-cli`.
- Updated descriptions for generating JSON objects with minimal changes to HF code.
- Improved section summaries for examples using the HF pipeline API.
- Renamed an example script to better reflect its purpose and location in the project structure.

* feat: Refactor and improve documentation and CLI interface

- **README Enhancements:**
  - Improved structure, readability, and consistency across sections.
  - Updated the Quick Start section with examples for generating JSON objects using `transformers-cfg-cli`.
  - Expanded explanations for grammar use cases and clarified documentation on automatic JSON schema grammar conversion.
  - Updated the list of supported models and provided better guidance for advanced grammar debugging.
- **CLI Improvements:**
  - Simplified the CLI prompt argument by renaming `--prefix_prompt` to `--prompt`.
  - Added device selection options (`-d`, `--device`) for model execution.
- **Code Cleanup:**
  - Removed an outdated example script, consolidating its functionality within the CLI and README examples.
- **Miscellaneous:**
  - Removed outdated comments and TODOs from the CLI code.
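As a quick illustration of the API change described above, here is a minimal sketch adapted from the `generate_chinese.py` example removed in this patch; after this change the `unicode` keyword is no longer passed, because the recognizer detects Unicode in the grammar string automatically via the new `detect_unicode` static method:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Load a grammar containing non-ASCII characters
with open("examples/grammars/chinese.ebnf", "r") as file:
    grammar_str = file.read()

# No `unicode=True` needed anymore: Unicode support is enabled automatically
# when the grammar string contains non-ASCII characters.
grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
grammar_processor = GrammarConstrainedLogitsProcessor(grammar)

# Generate with the grammar constraint applied as a logits processor
input_ids = tokenizer(
    ["English: coffee, Chinese: "], add_special_tokens=False,
    return_tensors="pt", padding=True
)["input_ids"]
output = model.generate(
    input_ids,
    do_sample=False,
    max_new_tokens=20,
    logits_processor=[grammar_processor],
    repetition_penalty=1.1,
)
print(tokenizer.batch_decode(output, skip_special_tokens=True))
```

The same call pattern is what the new `transformers-cfg-cli generate` sub-command wraps.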
--- README.md | 188 ++++++++++-------- examples/demo.sh | 117 +++++++++++ examples/generate.py | 73 ------- examples/generate_arabic.py | 45 ----- examples/generate_cIE.py | 50 ----- examples/generate_c_code.py | 58 ------ examples/generate_calflow.py | 111 ----------- examples/generate_chinese.py | 45 ----- examples/generate_emoji.py | 49 ----- examples/generate_japanese.py | 49 ----- examples/generate_json_array.py | 46 ----- examples/generate_korean.py | 45 ----- examples/generate_overnight.py | 116 ----------- examples/generate_russian.py | 45 ----- examples/grammars/Information_extraction.ebnf | 8 - examples/grammars/cIE.ebnf | 10 +- tests/test_hf_generation/test_generation.py | 2 +- .../test_unicode_generation.py | 18 +- transformers_cfg/cli/cli_main.py | 182 ++++++++++++++++- transformers_cfg/generation/logits_process.py | 22 +- transformers_cfg/token_grammar_recognizer.py | 40 +++- 21 files changed, 465 insertions(+), 854 deletions(-) create mode 100644 examples/demo.sh delete mode 100644 examples/generate.py delete mode 100644 examples/generate_arabic.py delete mode 100644 examples/generate_cIE.py delete mode 100644 examples/generate_c_code.py delete mode 100644 examples/generate_calflow.py delete mode 100644 examples/generate_chinese.py delete mode 100644 examples/generate_emoji.py delete mode 100644 examples/generate_japanese.py delete mode 100644 examples/generate_json_array.py delete mode 100644 examples/generate_korean.py delete mode 100644 examples/generate_overnight.py delete mode 100644 examples/generate_russian.py delete mode 100644 examples/grammars/Information_extraction.ebnf diff --git a/README.md b/README.md index bdb38f0..ade12b2 100644 --- a/README.md +++ b/README.md @@ -3,56 +3,80 @@ ![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg) [![License](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT) -## πŸ’­Latest News +## πŸ’­ Latest News + +- **[Gemma-2 Support](https://github.com/epfl-dlab/transformers-CFG/pull/75)** β€” Thanks to @fillassuncao (2024-08-16) +- **[DeepSeek Support](https://github.com/epfl-dlab/transformers-CFG/pull/73)** (2024-07-24) +- **LLAMA-3 Support** (2024-07-08) +- **[JSON Schema as Constraint Support](examples%2Fgrammars%2Fcustom_json_grammars%2FREADME.md)** (2024-05-13) +- **[Token Masking Optimization](#efficiency)** (2024-04-25) +- **[Phi Support](https://github.com/epfl-dlab/transformers-CFG/issues/34)** (2024-04-16) +- **[Online Demo with JSON Grammar](http://saibo-creator.xyz:7860/) at HF Space** (2024-04-10) +- **Unicode (Multilingual) Grammar Support** (2024-02-29) +- **Integration with Text-Generation-WebUI** (2023-12-17) -- **Support for [Gemma-2](https://github.com/epfl-dlab/transformers-CFG/pull/75) Thanks to @fillassuncao** (2024-08-16) +We are thrilled to announce that `transformers-cfg` has been integrated into the [Text-Generation-WebUI](https://github.com/oobabooga/text-generation-webui) project, enabling users to utilize our CFG capabilities within this popular web interface for text generation. For more details, see the [relevant pull request](https://github.com/oobabooga/text-generation-webui/pull/4953). -- **Support for [DeepSeek](https://github.com/epfl-dlab/transformers-CFG/pull/73)** (2024-07-24) +## πŸš€ Introduction -- **Support for LLAMA-3** (2024-07-08) +`transformers-cfg` is an extension library for the popular Transformers library by Hugging Face, tailored for working with context-free grammars (CFG). 
This package provides additional tools and functionalities to enhance your experience with natural language processing tasks involving CFGs. -- **support [JSON Schema as constraint](examples%2Fgrammars%2Fcustom_json_grammars%2FREADME.md)**(2024-05-13) +Initially developed as a pull request to the [Hugging Face Transformers](https://github.com/huggingface/transformers) library, you can find the relevant discussion [here](https://github.com/huggingface/transformers/pull/27557). -- **[Token masking optimization](#efficiency)(** (2024-04-25) +## πŸ’» Installation -- **[Support for Phi](https://github.com/epfl-dlab/transformers-CFG/issues/34)** (2024-04-16) +- **Stable Version:** -- **Online [Demo with JSON Grammar](http://saibo-creator.xyz:7860/) at HF space** (2024-04-10) + Install the stable version of `transformers-cfg` using pip: -- **Support for Unicode(multilingual) grammars** (2024-02-29) + ```bash + pip install transformers-cfg + ``` -- **Integration with Text-Generation-WebUI** (2023-12-17) +- **Development Version:** -We are thrilled to announce that `transformers_cfg` has been used in the [Text-Generation-WebUI](https://github.com/oobabooga/text-generation-webui) project. -This integration enables users to utilize our CFG capabilities within the popular, 30.5K-starred web interface for text generation. -For more details, see [Relevent Pull Request](https://github.com/oobabooga/text-generation-webui/pull/4953) + For the latest code and updates, install directly from the GitHub repository: + ```bash + pip install git+https://github.com/epfl-dlab/transformers-CFG.git@main + ``` -## πŸš€Introduction -`transformers_cfg` is an extension library for the popular Transformers library by Hugging Face, tailored for working with context-free grammars (CFG). -This package provides additional tools and functionalities to enhance your experience with natural language processing tasks involving CFGs. + This installs the package from the `main` branch. -It was initially developed as a pull request to the [Hugging Face Transformers](https://github.com/huggingface/transformers) library. -See relevant discussion [here](https://github.com/huggingface/transformers/pull/27557). +## πŸ”§ Quick Start: Force LLM to Generate a Valid JSON Object -## πŸ’» Installation +### Command-Line Interface -- You can install the stable version of `transformers-cfg` using pip: +`transformers-cfg-cli` is a command-line tool that allows you to generate text using a model and a grammar. You can specify the model, grammar, prompts, and other parameters to generate text that conforms to the specified grammar. ```bash -pip install transformers-cfg +transformers-cfg-cli generate \ + -m "microsoft/Phi-3-mini-4k-instruct" \ + -g "examples/grammars/json.ebnf" \ + -p "This is a valid json string for http request:" \ + --use_4bit \ + --max_new_tokens 60 \ + --repetition_penalty 1.1 +# {"name":"John","age":30,"car":null} ``` -- For the latest code and updates, you can install directly from the GitHub repository: +We support also Unicode characters in the grammar: +```bash +transformers-cfg-cli generate \ + -m "microsoft/Phi-3-mini-4k-instruct" \ + -g "examples/grammars/chinese.ebnf" \ + -p "Translate the following sentence into Chinese: My neighbor is a very nice person. -> " \ + --use_4bit \ + --max_new_tokens 60 \ + --repetition_penalty 1.1 ``` -pip install git+https://github.com/epfl-dlab/transformers-CFG.git@main -``` -This will install the package directly from the `main` branch of the repository. 
-## πŸ”§QuickStart: Force LLM to generate a valid json object +`transformers-cfg-cli generate --help` provides a list of available options and arguments. + -The below example can be found in `examples/generate_json.py` +
+Click here to see an example of generating a JSON object with minimal changes to HF code, or check it out in examples/generate_json.py ```python import torch @@ -71,21 +95,18 @@ if __name__ == "__main__": tokenizer = AutoTokenizer.from_pretrained(model_id) tokenizer.pad_token = tokenizer.eos_token - model = AutoModelForCausalLM.from_pretrained(model_id).to( - device - ) # Load model to defined device + model = AutoModelForCausalLM.from_pretrained(model_id).to(device) model.generation_config.pad_token_id = model.generation_config.eos_token_id - # Load json grammar + # Load JSON grammar with open("examples/grammars/json.ebnf", "r") as file: grammar_str = file.read() grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer) grammar_processor = GrammarConstrainedLogitsProcessor(grammar) # Generate - prefix1 = "This is a valid json string for http request:" - prefix2 = "This is a valid json string for shopping cart:" - input_ids = tokenizer([prefix1, prefix2], add_special_tokens=False, return_tensors="pt", padding=True)["input_ids"] + prompts = ["This is a valid json string for http request:", "This is a valid json string for shopping cart:"] + input_ids = tokenizer(prompts, add_special_tokens=False, return_tensors="pt", padding=True)["input_ids"] output = model.generate( input_ids, @@ -94,32 +115,34 @@ if __name__ == "__main__": repetition_penalty=1.1, num_return_sequences=1, ) - # decode output + # Decode output generations = tokenizer.batch_decode(output, skip_special_tokens=True) print(generations) """ - 'This is a valid json string for http request:{ "request": { "method": "GET", "headers": [], "content": "Content","type": "application" }} - 'This is a valid json string for shopping cart:{ "name": "MyCart", "price": 0, "value": 1 } + 'This is a valid json string for http request:{ "request": { "method": "GET", "headers": [], "content": "Content","type": "application" }}' + 'This is a valid json string for shopping cart:{ "name": "MyCart", "price": 0, "value": 1 }' """ ``` -Alternatively, you can use `transformers-cfg` to perform grammar-constrained decoding with huggingface pipeline. +
-Click here to see an example, or check it out in `examples/pipeline_json.py` +Click here to see an example with HF pipeline API, or check it out in examples/pipeline_json.py ```python +from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline +from transformers_cfg.grammar_utils import IncrementalGrammarConstraint +from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor + # Load model and tokenizer tokenizer = AutoTokenizer.from_pretrained(model_id) tokenizer.pad_token = tokenizer.eos_token -# Load model to defined device model = AutoModelForCausalLM.from_pretrained(model_id).to(device) # Load grammar with open(f"examples/grammars/json.ebnf", "r") as file: grammar_str = file.read() - grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer) grammar_processor = GrammarConstrainedLogitsProcessor(grammar) @@ -142,21 +165,23 @@ generations = pipe( logits_processor=[grammar_processor], ) ``` -
+ -## πŸ’‘Why should I use transformers-CFG? +## πŸ’‘ Why Should I Use `transformers-cfg`? -- We support EBNF grammar description format -- We offer the same grammar interface as llama-cpp project, allowing you to drop-in replace llama-cpp with transformers-CFG. -- We allow you to use any of the models in the πŸ€— Transformers library, including the ones that are not supported by llama-cpp. -- We support multilingual grammars, you can use any character from any language in your grammar, e.g. δΈ­ζ–‡, ζ—₯本θͺž, ν•œκ΅­μ–΄, ΰ€Ήΰ€Ώΰ€¨ΰ₯ΰ€¦ΰ₯€, Ψ§Ω„ΨΉΨ±Ψ¨ΩŠΨ©, Χ’Χ‘Χ¨Χ™Χͺ, or emoji πŸ€—. +- **EBNF Grammar Support:** We support the Extended Backus-Naur Form (EBNF) for grammar description. +- **Seamless Integration:** Our grammar interface is compatible with the llama-cpp project, allowing you to replace llama-cpp with `transformers-cfg` easily. +- **Model Compatibility:** Use any model from the πŸ€— Transformers library, including those not supported by llama-cpp. +- **Multilingual Grammar Support:** We support grammars in multiple languages, allowing you to use characters from various languages, including δΈ­ζ–‡, ζ—₯本θͺž, ν•œκ΅­μ–΄, ΰ€Ήΰ€Ώΰ€¨ΰ₯ΰ€¦ΰ₯€, Ψ§Ω„ΨΉΨ±Ψ¨ΩŠΨ©, Χ’Χ‘Χ¨Χ™Χͺ, and emoji πŸ€—. -## πŸ€”What is grammar ? +## πŸ€” What Is a Grammar? TL;DR: Think of it as an enhanced version of regular expressions. -Here is an example of a simplified JSON grammar: +
+Here is a simple example of a JSON grammar: + ```bnf # A JSON object is the root of the grammar root ::= object @@ -174,65 +199,56 @@ string ::= '"' [a-zA-Z0-9]* '"' value ::= string | object | "true" | "false" | "null" ``` -This grammar describes the structure of a JSON object. It specifies that a JSON object is a pair of key-value pairs, where the key is a string and the value can be a string, another object, or a boolean value. +This grammar describes the structure of a JSON object. It specifies that a JSON object consists of key-value pairs, where the key is a string, and the value can be a string, another object, or a boolean value. + +You can use grammars to describe simple but useful constructs, such as valid email addresses, URLs, or phone numbers: -Grammar doesn't need to be complicated. -You can use it to describe very simple but useful things, like a valid email address, a valid URL, or phone number. ``` phone_number ::= "+" [0-9]+ ``` -You can also force it to [generate only emojis](examples/generate_emoji.py) or [generate only korean characters](examples/generate_korean.py). -``` -['Describe your feeling with emoji: πŸ™ŒπŸ™‚πŸ˜πŸ˜―πŸ˜…πŸ™πŸ™‡πŸ™ˆπŸ™ŠπŸ™‹πŸ™ƒπŸ™†πŸ™…πŸ™„πŸ™πŸ™‚πŸ™€πŸ™‰πŸ™ŽπŸ™ŠπŸ™‹πŸ™ƒπŸ™†πŸ™…πŸ™„πŸ™πŸ™‚πŸ™€πŸ™‰πŸ™ŽπŸ™ŠπŸ™‹πŸ™ƒπŸ™†', 'Write a poem with emoji: πŸ™πŸ˜πŸ™πŸ™πŸ™ŒπŸ™πŸ™πŸ™πŸ™πŸ˜πŸ˜…πŸ™πŸ™πŸ™πŸ™πŸ™πŸ™πŸ™‡πŸ™πŸ™πŸ™πŸ™πŸ™πŸ™πŸ™πŸ™πŸ™πŸ™‹πŸ™πŸ™πŸ™πŸ™πŸ™πŸ™'] -``` +
-More details can be found in this [doc from llama-cpp](https://github.com/ggerganov/llama.cpp/tree/master/grammars) -Advanced grammar debugging guide can be found [here](docs/debugging_custom_grammars.md) +For advanced grammar debugging, check out our [debugging guide](docs/debugging_custom_grammars.md). -### Automatic Grammar Generation +## Automatic JSON Schema Grammar Conversion[Experimental] -You can use custom grammars to constrain the output of a language model. -Check out the [documentation](examples%2Fgrammars%2Fcustom_json_grammars%2FREADME.md) on json schema to grammar conversion to learn how to automatically create custom grammars for complex json objects. +Learn how to automatically create custom grammars for complex JSON objects in our [documentation](examples%2Fgrammars%2Fcustom_json_grammars%2FREADME.md) on JSON schema to grammar conversion. -### Grammar Collection +## Grammar Collection -We provide a collection of grammars in the `examples/grammars` folder, which are mostly identical to the grammars in llama-cpp project. -We try to keep the grammars up-to-date with the original grammars from llama-cpp project. -But up to now, we can not yet guarantee that all grammars from llama-cpp project can be directly used in transformers-CFG. +We provide a collection of grammars in the `examples/grammars` folder, which are mostly identical to the grammars in the llama-cpp project. We try to keep these grammars up-to-date with the original project, though we cannot yet guarantee that all grammars from llama-cpp can be directly used in `transformers-cfg`. -The list of grammars contains: -- [json.ebnf](examples%2Fgrammars%2Fjson.ebnf): A grammar for generating valid json objects. -- [json_arr.ebnf](examples%2Fgrammars%2Fjson_arr.ebnf): A grammar for generating valid json arrays. -- [c.ebnf](examples%2Fgrammars%2Fc.ebnf): A grammar for generating valid C programs. -- [chess.ebnf](examples%2Fgrammars%2Fchess.ebnf): A grammar for generating valid chess moves. -- [arithmetic.ebnf](examples%2Fgrammars%2Farithmetic.ebnf): A grammar for generating valid arithmetic expressions. +Available grammars include: +- [json.ebnf](examples%2Fgrammars%2Fjson.ebnf): For generating valid JSON objects. +- [json_arr.ebnf](examples%2Fgrammars%2Fjson_arr.ebnf): For generating valid JSON arrays. +- [c.ebnf](examples%2Fgrammars%2Fc.ebnf): For generating valid C programs. +- [chess.ebnf](examples%2Fgrammars%2Fchess.ebnf): For generating valid chess moves. +- [arithmetic.ebnf](examples%2Fgrammars%2Farithmetic.ebnf): For generating valid arithmetic expressions. ## Supported Models -- [LLaMa family models](https://huggingface.co/baffo32/decapoda-research-llama-7B-hf) -- [GPT family models](https://huggingface.co/openai-community/gpt2) -- [Bloom family models](https://huggingface.co/bigscience/bloom) -- [Mistral family models](https://huggingface.co/mistralai/Mistral-7B-v0.1) -- [Falcon family models](https://huggingface.co/tiiuae/falcon-7b) +- [LLaMa Family Models](https://huggingface.co/baffo32/decapoda-research-llama-7B-hf) +- [GPT Family Models](https://huggingface.co/openai-community/gpt2) +- [Bloom Family Models](https://huggingface.co/bigscience/bloom) +- [Mistral Family Models](https://huggingface.co/mistralai/Mistral-7B-v0.1) +- [Falcon Family Models](https://huggingface.co/tiiuae/falcon-7b) - ... See [supported_models.yaml](docs%2Fsupported_models.yaml) for the full list of supported models. -As a rule of thumb, all models with the same tokenizer should naturally be supported. 
-If you find any model that is not supported, please open an issue or submit a pull request. +As a rule of thumb, all models with the same tokenizer should be naturally supported. -## Efficiency -Our update in the `transformers_cfg` library has significantly improved the performance of grammar-constrained decoding (especially for complicated grammars). +If you find any model that is not supported, please open an issue or submit a pull request. - ## Citation -**Please consider citing our work, if you found the provided resources useful.**
-``` +**Please consider citing our work if you find the provided resources useful:** + +```bibtex @inproceedings{geng-etal-2023-grammar, title = {Grammar-Constrained Decoding for Structured {NLP} Tasks without Finetuning}, author = {Geng, Saibo and Josifoski, Martin and Peyrard, Maxime and West, Robert}, @@ -246,10 +262,10 @@ Our update in the `transformers_cfg` library has significantly improved the perf } ``` - ## License + This project is licensed under the [MIT License](LICENSE). -## Acknowledgement +## Acknowledgements -This project is derived from the [torch-grammars](https://github.com/Shopify/torch-grammar) project, which was derived from the [llama-cpp](https://github.com/ggerganov/llama.cpp) project. +This project is derived from the [torch-grammars](https://github.com/Shopify/torch-grammar) project, which was itself derived from the [llama-cpp](https://github.com/ggerganov/llama.cpp) project. diff --git a/examples/demo.sh b/examples/demo.sh new file mode 100644 index 0000000..266d31c --- /dev/null +++ b/examples/demo.sh @@ -0,0 +1,117 @@ + +################ +# +# JSON generation: object and array +# +################ + +# generate json object +transformers-cfg-cli generate \ + -m "microsoft/Phi-3-mini-4k-instruct" \ + -g "examples/grammars/json.ebnf" \ + -p "This is a valid json string for http request:" \ + --use_4bit \ + --max_new_tokens 60 \ + --repetition_penalty 1.1 + +# generate json array + +transformers-cfg-cli generate \ + -m "microsoft/Phi-3-mini-4k-instruct" \ + -g "examples/grammars/json_arr.ebnf" \ + -p "Put my shopping list into a json array:" \ + --use_4bit \ + --max_new_tokens 60 \ + --repetition_penalty 1.1 + +################ +# +# Code generation: Python, C +# +################ + +# generate C code +transformers-cfg-cli generate \ + -m "microsoft/Phi-3-mini-4k-instruct" \ + -g "examples/grammars/c.ebnf" \ + -p "#include \n" \ + --use_4bit \ + --max_new_tokens 20 \ + --repetition_penalty 3.0 + +################ +# +# NLP tasks: relation extraction +# +################ + +# generate relation extraction triples +transformers-cfg-cli generate \ + -m "microsoft/Phi-3-mini-4k-instruct" \ + -g "examples/grammars/cIE.ebnf" \ + -p "Extract relations from the following sentence: RenΓ© Descartes was a French philosopher, scientist, and mathematician" \ + --use_8bit \ + --max_new_tokens 60 \ + --repetition_penalty 1.1 + + +################ +# +# Semantic parsing: CalFlow, GeoQuery, overnight, etc. +# +################ + +transformers-cfg-cli generate \ + -m "microsoft/Phi-3-mini-4k-instruct" \ + -g "examples/grammars/calflow.ebnf" \ + -p 'Generate 3 CalFlow strings: 1.(Yield (toRecipient (CurrentUser))) 2.(Yield (CreateCommitEventWrapper (CreatePreflightEventWrapper (Event.subject_? (?= "choose the meeting"))))) 3.' \ + --use_4bit \ + --max_new_tokens 60 \ + --repetition_penalty 1.1 + +transformers-cfg-cli generate \ + -m "microsoft/Phi-3-mini-4k-instruct" \ + -g "examples/grammars/geo_query.ebnf" \ + -p "Translate the following sentence into GeoQuery: What is the population of the largest city in California?" 
\ + --use_4bit \ + --max_new_tokens 60 \ + --repetition_penalty 1.1 + +transformers-cfg-cli generate \ + -m "microsoft/Phi-3-mini-4k-instruct" \ + -g "examples/grammars/overnight.ebnf" \ + -p """Translate natural language to DSL: + Q: which brick is no wider than 3 inches + A: listValue (filter (getProperty (singleton en.block) !type) (ensureNumericProperty width) <= (ensureNumericEntity 3 en.inch))) + Q: which block is above block 1 + A: (listValue (filter (filter (getProperty (singleton en.block) !type) (reverse above) = en.block.block1) above = en.block.block1)) + Q: what block is longer than 3 inches + A: """ \ + --use_4bit \ + --max_new_tokens 60 \ + --repetition_penalty 1.1 + + + +################ +# +# Unicode support, Chinese, Emoji, etc. +# +################ + +transformers-cfg-cli generate \ + -m "microsoft/Phi-3-mini-4k-instruct" \ + -g "examples/grammars/chinese.ebnf" \ + -p "Translate the following sentence into Chinese: My neighbor is a very nice person. -> " \ + --use_4bit \ + --max_new_tokens 60 \ + --repetition_penalty 1.1 + + +transformers-cfg-cli generate \ + -m "microsoft/Phi-3-mini-4k-instruct" \ + -g "examples/grammars/emoji.ebnf" \ + -p "Translate the following sentence into emoji: I am very happy today. -> " \ + --use_4bit \ + --max_new_tokens 60 \ + --repetition_penalty 1.1 diff --git a/examples/generate.py b/examples/generate.py deleted file mode 100644 index 6fddaeb..0000000 --- a/examples/generate.py +++ /dev/null @@ -1,73 +0,0 @@ -import argparse -from transformers import AutoModelForCausalLM, AutoTokenizer -from transformers_cfg.grammar_utils import IncrementalGrammarConstraint -from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor -import logging - -logging.basicConfig(level=logging.DEBUG) - - -def main(args): - - # Load model and tokenizer - tokenizer = AutoTokenizer.from_pretrained(args.model_id) - tokenizer.pad_token = tokenizer.eos_token - model = AutoModelForCausalLM.from_pretrained(args.model_id) - - # Load grammar - with open(args.grammar_file_path, "r") as file: - grammar_str = file.read() - grammar = IncrementalGrammarConstraint( - grammar_str, "root", tokenizer, unicode=False - ) - grammar_processor = GrammarConstrainedLogitsProcessor(grammar) - - # Generate - prefix = args.prefix_prompt - input_ids = tokenizer( - prefix, add_special_tokens=False, return_tensors="pt", padding=True - )["input_ids"] - - output = model.generate( - input_ids, - do_sample=False, - max_new_tokens=20, - logits_processor=[grammar_processor], - repetition_penalty=1.1, - num_return_sequences=1, - ) - # decode output - generations = tokenizer.batch_decode(output, skip_special_tokens=True) - - print(generations) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Generate text with grammar constraints." 
- ) - parser.add_argument( - "-m", - "--model_id", - type=str, - required=True, - help="Model identifier for loading the tokenizer and model", - default="gpt2", - ) - parser.add_argument( - "-g", - "--grammar_file_path", - type=str, - required=True, - help="Path to the grammar file (supports both relative and absolute paths)", - ) - parser.add_argument( - "-p", - "--prefix_prompt", - type=str, - required=True, - help="Prefix prompt for generation", - ) - - args = parser.parse_args() - main(args) diff --git a/examples/generate_arabic.py b/examples/generate_arabic.py deleted file mode 100644 index 6f89096..0000000 --- a/examples/generate_arabic.py +++ /dev/null @@ -1,45 +0,0 @@ -from transformers import AutoModelForCausalLM, AutoTokenizer -from transformers_cfg.grammar_utils import IncrementalGrammarConstraint -from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor - -import logging - -logging.basicConfig(level=logging.DEBUG) - - -def main(): - - # Load model and tokenizer - tokenizer = AutoTokenizer.from_pretrained("gpt2") # JackFram/llama-68m" - tokenizer.pad_token = tokenizer.eos_token - model = AutoModelForCausalLM.from_pretrained("gpt2") # Load model to defined device - - # Load grammar - with open("examples/grammars/arabic.ebnf", "r") as file: - grammar_str = file.read() - grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer, unicode=True) - grammar_processor = GrammarConstrainedLogitsProcessor(grammar) - - # Generate - prefix1 = "English: coffee, Arabic: " - prefix2 = "English: dog, Arabic: " - input_ids = tokenizer( - [prefix1, prefix2], add_special_tokens=False, return_tensors="pt", padding=True - )["input_ids"] - - output = model.generate( - input_ids, - do_sample=False, - max_new_tokens=20, - logits_processor=[grammar_processor], - repetition_penalty=1.1, - num_return_sequences=1, - ) - # decode output - generations = tokenizer.batch_decode(output, skip_special_tokens=True) - - print(generations) - - -if __name__ == "__main__": - main() diff --git a/examples/generate_cIE.py b/examples/generate_cIE.py deleted file mode 100644 index ef10255..0000000 --- a/examples/generate_cIE.py +++ /dev/null @@ -1,50 +0,0 @@ -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer -from transformers_cfg.grammar_utils import IncrementalGrammarConstraint -from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor - - -if __name__ == "__main__": - - # Detect if GPU is available, otherwise use CPU - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - print(f"Using device: {device}") - - # Load model and tokenizer - tokenizer = AutoTokenizer.from_pretrained("gpt2") - tokenizer.pad_token = tokenizer.eos_token - model = AutoModelForCausalLM.from_pretrained("gpt2").to( - device - ) # Load model to defined device - - # Load grammar - with open("examples/grammars/cIE.ebnf", "r") as file: - grammar_str = file.read() - grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer) - grammar_processor = GrammarConstrainedLogitsProcessor(grammar) - - # Generate - prefix1 = "This is a valid json string for http request:" - prefix2 = "This is a valid json string for shopping cart:" - input_ids = tokenizer( - [prefix1, prefix2], add_special_tokens=False, return_tensors="pt", padding=True - )["input_ids"].to( - device - ) # Move input_ids to the same device as model - - output = model.generate( - input_ids, - do_sample=False, - max_new_tokens=60, - logits_processor=[grammar_processor], - 
repetition_penalty=1.1, - num_return_sequences=1, - ) - # decode output - generations = tokenizer.batch_decode(output, skip_special_tokens=True) - print(generations) - - """ - 'This is a valid json string for http request:{ "request": { "method": "GET", "headers": [], "content": "Content","type": "application" }} - 'This is a valid json string for shopping cart:This is a valid json string for shopping cart:{ "name": "MyCart", "price": 0, "value": 1 } - """ diff --git a/examples/generate_c_code.py b/examples/generate_c_code.py deleted file mode 100644 index 30b7b16..0000000 --- a/examples/generate_c_code.py +++ /dev/null @@ -1,58 +0,0 @@ -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer -from transformers_cfg.grammar_utils import IncrementalGrammarConstraint -from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor - -############################################################ -# -# use llama to generate C code -# -############################################################ - - -def main(): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - print(f"Using device: {device}") - - model_id = "mistralai/Mistral-7B-v0.1" - - # Load model and tokenizer - tokenizer = AutoTokenizer.from_pretrained(model_id) - tokenizer.pad_token = tokenizer.eos_token - - model = AutoModelForCausalLM.from_pretrained(model_id).to( - device - ) # Load model to defined device - model.generation_config.pad_token_id = model.generation_config.eos_token_id - # Load grammar - with open("examples/grammars/c.ebnf", "r") as file: - grammar_str = file.read() - grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer) - grammar_processor = GrammarConstrainedLogitsProcessor(grammar) - # Generate - prefix1 = "#include \n" - input_ids = tokenizer( - [prefix1], add_special_tokens=False, return_tensors="pt", padding=True - )["input_ids"].to( - device - ) # Move input_ids to the same device as model - output = model.generate( - input_ids, - do_sample=False, - max_new_tokens=20, - logits_processor=[grammar_processor], - repetition_penalty=3.0, - num_return_sequences=1, - ) - # decode output - generations = tokenizer.batch_decode(output, skip_special_tokens=True) - print(generations) - - """ - #include - int thresh_f(int n){return (1-threshold);} - """ - - -if __name__ == "__main__": - main() diff --git a/examples/generate_calflow.py b/examples/generate_calflow.py deleted file mode 100644 index edf18e2..0000000 --- a/examples/generate_calflow.py +++ /dev/null @@ -1,111 +0,0 @@ -import torch -import argparse -from transformers import AutoModelForCausalLM, AutoTokenizer -from transformers_cfg.grammar_utils import IncrementalGrammarConstraint -from transformers_cfg.recognizer import StringRecognizer -from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor -from transformers_cfg.parser import parse_ebnf - - -def parse_args(): - parser = argparse.ArgumentParser(description="Generate calflow strings") - parser.add_argument( - "--model-id", - type=str, - default="/dlabdata1/llm_hub/Mistral-7B-v0.1", - help="Model ID", - ) - parser.add_argument("--device", type=str, help="Device to put the model on") - return parser.parse_args() - - -def main(): - args = parse_args() - model_id = args.model_id - - # Detect if GPU is available, otherwise use CPU - device = torch.device( - args.device or ("cuda" if torch.cuda.is_available() else "cpu") - ) - print(f"Using device: {device}") - - # Load model and tokenizer - tokenizer = 
AutoTokenizer.from_pretrained(model_id) - tokenizer.pad_token = tokenizer.eos_token - # Load model to defined device - model = AutoModelForCausalLM.from_pretrained(model_id).to(device) - - # Load grammar - with open(f"examples/grammars/calflow.ebnf", "r") as file: - grammar_str = file.read() - - grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer) - grammar_processor = GrammarConstrainedLogitsProcessor(grammar) - - # Generate - prompts = [ - 'Generate 3 CalFlow strings: 1.(Yield (toRecipient (CurrentUser))) 2.(Yield (CreateCommitEventWrapper (CreatePreflightEventWrapper (Event.subject_? (?= "choose the meeting"))))) 3.' - ] - - input_ids = tokenizer( - prompts, add_special_tokens=False, return_tensors="pt", padding=True - )["input_ids"].to( - device - ) # Move input_ids to the same device as model - - n_examples = input_ids.shape[0] - - max_new_tokens = 50 - unconstrained_output = model.generate( - input_ids, - do_sample=False, - max_new_tokens=max_new_tokens, - repetition_penalty=1.9, - num_return_sequences=1, - ) - constrained_output = model.generate( - input_ids, - do_sample=False, - max_new_tokens=max_new_tokens, - logits_processor=[grammar_processor], - repetition_penalty=1.9, - num_return_sequences=1, - ) - - # decode outputs (possibly of different lengths across decoding modes) - generations = tokenizer.batch_decode( - unconstrained_output, skip_special_tokens=True - ) + tokenizer.batch_decode(constrained_output, skip_special_tokens=True) - - parsed_grammar = parse_ebnf(grammar_str) - string_grammar = StringRecognizer( - parsed_grammar.grammar_encoding, parsed_grammar.symbol_table["root"] - ) - - print() - for i in range(n_examples): - print(f"Unconstrained: {generations[i]}") - constrained_generation = generations[i + n_examples] - print(f"Constrained: {constrained_generation}") - print( - f"The constrained generation matches the grammar: {string_grammar._accept_string(constrained_generation[len(prompts[i]):])}" - ) - print( - f"The generated prefix matches the grammar: {string_grammar._accept_prefix(constrained_generation[len(prompts[i]):])}" - ) - print() - - -if __name__ == "__main__": - main() - - -########################## -# Example output: -# -# Unconstrained: Generate 3 CalFlow strings: 1.(Yield (toRecipient (CurrentUser))) 2.(Yield (CreateCommitEventWrapper (CreatePreflightEventWrapper (Event.subject_? (?= "choose the meeting"))))) 3.((yielder) ((reciever)) (((event-type)? ("create")(("prefight" ?)))) -# ``` -# Constrained: Generate 3 CalFlow strings: 1.(Yield (toRecipient (CurrentUser))) 2.(Yield (CreateCommitEventWrapper (CreatePreflightEventWrapper (Event.subject_? 
(?= "choose the meeting"))))) 3.(Yield (Path.apply "create")) -# The constrained generation matches the grammar: True -# The generated prefix matches the grammar: True -########################## diff --git a/examples/generate_chinese.py b/examples/generate_chinese.py deleted file mode 100644 index c1cf9f3..0000000 --- a/examples/generate_chinese.py +++ /dev/null @@ -1,45 +0,0 @@ -from transformers import AutoModelForCausalLM, AutoTokenizer -from transformers_cfg.grammar_utils import IncrementalGrammarConstraint -from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor - -import logging - -logging.basicConfig(level=logging.DEBUG) - - -def main(): - - # Load model and tokenizer - tokenizer = AutoTokenizer.from_pretrained("gpt2") # JackFram/llama-68m" - tokenizer.pad_token = tokenizer.eos_token - model = AutoModelForCausalLM.from_pretrained("gpt2") # Load model to defined device - - # Load grammar - with open("examples/grammars/chinese.ebnf", "r") as file: - grammar_str = file.read() - grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer, unicode=True) - grammar_processor = GrammarConstrainedLogitsProcessor(grammar) - - # Generate - prefix1 = "English: coffee, Chinese: " - prefix2 = "English: dog, Chinese: " - input_ids = tokenizer( - [prefix1, prefix2], add_special_tokens=False, return_tensors="pt", padding=True - )["input_ids"] - - output = model.generate( - input_ids, - do_sample=False, - max_new_tokens=20, - logits_processor=[grammar_processor], - repetition_penalty=1.1, - num_return_sequences=1, - ) - # decode output - generations = tokenizer.batch_decode(output, skip_special_tokens=True) - - print(generations) - - -if __name__ == "__main__": - main() diff --git a/examples/generate_emoji.py b/examples/generate_emoji.py deleted file mode 100644 index 81059eb..0000000 --- a/examples/generate_emoji.py +++ /dev/null @@ -1,49 +0,0 @@ -from transformers import AutoModelForCausalLM, AutoTokenizer -from transformers_cfg.grammar_utils import IncrementalGrammarConstraint -from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor - -import logging - -logging.basicConfig(level=logging.DEBUG) - - -def main(): - - # Load model and tokenizer - tokenizer = AutoTokenizer.from_pretrained("gpt2") # JackFram/llama-68m" - tokenizer.pad_token = tokenizer.eos_token - model = AutoModelForCausalLM.from_pretrained("gpt2") # Load model to defined device - - # Load grammar - with open("examples/grammars/emoji.ebnf", "r") as file: - grammar_str = file.read() - grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer, unicode=True) - grammar_processor = GrammarConstrainedLogitsProcessor(grammar) - - # Generate - prefix1 = "Describe your feeling with emoji: " - prefix2 = "Write a poem with emoji: " - input_ids = tokenizer( - [prefix1, prefix2], add_special_tokens=False, return_tensors="pt", padding=True - )["input_ids"] - - output = model.generate( - input_ids, - do_sample=False, - max_new_tokens=100, - logits_processor=[grammar_processor], - repetition_penalty=1.1, - num_return_sequences=1, - ) - # decode output - generations = tokenizer.batch_decode(output, skip_special_tokens=True) - print(generations) - - """ - ['Describe your feeling with emoji: πŸ™ŒπŸ™‚πŸ˜πŸ˜―πŸ˜…πŸ™πŸ™‡πŸ™ˆπŸ™ŠπŸ™‹πŸ™ƒπŸ™†πŸ™…πŸ™„πŸ™πŸ™‚πŸ™€πŸ™‰πŸ™ŽπŸ™ŠπŸ™‹πŸ™ƒπŸ™†πŸ™…πŸ™„πŸ™πŸ™‚πŸ™€πŸ™‰πŸ™ŽπŸ™ŠπŸ™‹πŸ™ƒπŸ™†', 'Write a poem with emoji: 
πŸ™πŸ˜πŸ™πŸ™πŸ™ŒπŸ™πŸ™πŸ™πŸ™πŸ˜πŸ˜…πŸ™πŸ™πŸ™πŸ™πŸ™πŸ™πŸ™‡πŸ™πŸ™πŸ™πŸ™πŸ™πŸ™πŸ™πŸ™πŸ™πŸ™‹πŸ™πŸ™πŸ™πŸ™πŸ™πŸ™'] - - """ - - -if __name__ == "__main__": - main() diff --git a/examples/generate_japanese.py b/examples/generate_japanese.py deleted file mode 100644 index c765b99..0000000 --- a/examples/generate_japanese.py +++ /dev/null @@ -1,49 +0,0 @@ -from transformers import AutoModelForCausalLM, AutoTokenizer -from transformers_cfg.grammar_utils import IncrementalGrammarConstraint -from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor - -import logging - -logging.basicConfig(level=logging.DEBUG) - - -def main(): - - # Load model and tokenizer - tokenizer = AutoTokenizer.from_pretrained( - "JackFram/llama-68m" - ) # JackFram/llama-68m" - tokenizer.pad_token = tokenizer.eos_token - model = AutoModelForCausalLM.from_pretrained( - "JackFram/llama-68m" - ) # Load model to defined device - - # Load grammar - with open("examples/grammars/japanese.ebnf", "r") as file: - grammar_str = file.read() - grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer, unicode=True) - grammar_processor = GrammarConstrainedLogitsProcessor(grammar) - - # Generate - prefix1 = "γ“γ‚“γ«γ‘γ―δΈ–η•Œ" - prefix2 = "γ“γ‚“γ«γ‘γ―δΈ–η•Œ" - input_ids = tokenizer( - [prefix1, prefix2], add_special_tokens=False, return_tensors="pt", padding=True - )["input_ids"] - - output = model.generate( - input_ids, - do_sample=False, - max_new_tokens=20, - logits_processor=[grammar_processor], - repetition_penalty=1.1, - num_return_sequences=1, - ) - # decode output - generations = tokenizer.batch_decode(output, skip_special_tokens=True) - - print(generations) - - -if __name__ == "__main__": - main() diff --git a/examples/generate_json_array.py b/examples/generate_json_array.py deleted file mode 100644 index 9b09401..0000000 --- a/examples/generate_json_array.py +++ /dev/null @@ -1,46 +0,0 @@ -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer -from transformers_cfg.grammar_utils import IncrementalGrammarConstraint -from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor - -if __name__ == "__main__": - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - print(f"Using device: {device}") - - model_id = "mistralai/Mistral-7B-v0.1" - - # Load model and tokenizer - tokenizer = AutoTokenizer.from_pretrained(model_id) - tokenizer.pad_token = tokenizer.eos_token - - model = AutoModelForCausalLM.from_pretrained(model_id).to( - device - ) # Load model to defined device - model.generation_config.pad_token_id = model.generation_config.eos_token_id - - # Load grammar - with open("examples/grammars/json_arr.ebnf", "r") as file: - grammar_str = file.read() - grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer) - grammar_processor = GrammarConstrainedLogitsProcessor(grammar) - - # Generate - prefix1 = "This is a valid json array for student records:" - prefix2 = "This is a valid json array for shopping cart:" - input_ids = tokenizer( - [prefix1, prefix2], add_special_tokens=False, return_tensors="pt", padding=True - )["input_ids"].to( - device - ) # Move input_ids to the same device as model - - output = model.generate( - input_ids, - do_sample=False, - max_new_tokens=60, - logits_processor=[grammar_processor], - repetition_penalty=1.0, - num_return_sequences=1, - ) - # decode output - generations = tokenizer.batch_decode(output, skip_special_tokens=True) - print(generations) diff 
--git a/examples/generate_korean.py b/examples/generate_korean.py deleted file mode 100644 index fe07594..0000000 --- a/examples/generate_korean.py +++ /dev/null @@ -1,45 +0,0 @@ -from transformers import AutoModelForCausalLM, AutoTokenizer -from transformers_cfg.grammar_utils import IncrementalGrammarConstraint -from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor - -import logging - -logging.basicConfig(level=logging.DEBUG) - - -def main(): - - # Load model and tokenizer - tokenizer = AutoTokenizer.from_pretrained("gpt2") # JackFram/llama-68m" - tokenizer.pad_token = tokenizer.eos_token - model = AutoModelForCausalLM.from_pretrained("gpt2") # Load model to defined device - - # Load grammar - with open("examples/grammars/korean.ebnf", "r") as file: - grammar_str = file.read() - grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer, unicode=True) - grammar_processor = GrammarConstrainedLogitsProcessor(grammar) - - # Generate - prefix1 = "English: coffee, Korean: " - prefix2 = "English: dog, Korean: " - input_ids = tokenizer( - [prefix1, prefix2], add_special_tokens=False, return_tensors="pt", padding=True - )["input_ids"] - - output = model.generate( - input_ids, - do_sample=False, - max_new_tokens=20, - logits_processor=[grammar_processor], - repetition_penalty=1.1, - num_return_sequences=1, - ) - # decode output - generations = tokenizer.batch_decode(output, skip_special_tokens=True) - - print(generations) - - -if __name__ == "__main__": - main() diff --git a/examples/generate_overnight.py b/examples/generate_overnight.py deleted file mode 100644 index 0bcd5c0..0000000 --- a/examples/generate_overnight.py +++ /dev/null @@ -1,116 +0,0 @@ -import torch -import argparse -from transformers import AutoModelForCausalLM, AutoTokenizer -from transformers_cfg.grammar_utils import IncrementalGrammarConstraint -from transformers_cfg.recognizer import StringRecognizer -from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor -from transformers_cfg.parser import parse_ebnf - - -def parse_args(): - parser = argparse.ArgumentParser(description="Generate overnight blocks strings") - parser.add_argument( - "--model-id", - type=str, - default="/dlabdata1/llm_hub/Mistral-7B-v0.1", - help="Model ID", - ) - parser.add_argument("--device", type=str, help="Device to put the model on") - return parser.parse_args() - - -def main(): - args = parse_args() - model_id = args.model_id - - # Detect if GPU is available, otherwise use CPU - device = torch.device( - args.device or ("cuda" if torch.cuda.is_available() else "cpu") - ) - print(f"Using device: {device}") - - # Load model and tokenizer - tokenizer = AutoTokenizer.from_pretrained(model_id) - tokenizer.pad_token = tokenizer.eos_token - # Load model to defined device - model = AutoModelForCausalLM.from_pretrained(model_id).to(device) - - # Load grammar - with open(f"examples/grammars/overnight.ebnf", "r") as file: - grammar_str = file.read() - - parsed_grammar = parse_ebnf(grammar_str) - grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer) - grammar_processor = GrammarConstrainedLogitsProcessor(grammar) - - # Generate - prompts = [ - "Translate natural language to DSL:\n" - + "Q: which brick is no wider than 3 inches\n" - + "A: listValue (filter (getProperty (singleton en.block) !type) (ensureNumericProperty width) <= (ensureNumericEntity 3 en.inch)))\n" - + "Q: which block is above block 1\n" - + "A: (listValue (filter (filter (getProperty (singleton en.block) 
!type) (reverse above) = en.block.block1) above = en.block.block1))\n" - + "Q: what block is longer than 3 inches\n" - + "A: " - ] - - input_ids = tokenizer( - prompts, add_special_tokens=False, return_tensors="pt", padding=True - )["input_ids"].to( - device - ) # Move input_ids to the same device as model - - n_examples = input_ids.shape[0] - - max_new_tokens = 50 - unconstrained_output = model.generate( - input_ids, - do_sample=False, - max_new_tokens=max_new_tokens, - repetition_penalty=1.9, - num_return_sequences=1, - ) - constrained_output = model.generate( - input_ids, - do_sample=False, - max_new_tokens=max_new_tokens, - logits_processor=[grammar_processor], - repetition_penalty=1.9, - num_return_sequences=1, - ) - - parsed_grammar = parse_ebnf(grammar_str) - string_grammar = StringRecognizer( - parsed_grammar.grammar_encoding, parsed_grammar.symbol_table["root"] - ) - - # decode outputs (possibly of different lengths across decoding modes) - generations = tokenizer.batch_decode( - unconstrained_output, skip_special_tokens=True - ) + tokenizer.batch_decode(constrained_output, skip_special_tokens=True) - print() - for i in range(n_examples): - print(f"Unconstrained: {generations[i]}") - constrained_generation = generations[i + n_examples] - print(f"Constrained: {generations[i + n_examples]}") - print( - f"The constrained generation matches the grammar: {string_grammar._accept_string(constrained_generation[len(prompts[i]):])}" - ) - print( - f"The generated prefix matches the grammar: {string_grammar._accept_prefix(constrained_generation[len(prompts[i]):])}" - ) - print() - - -if __name__ == "__main__": - main() - - -########################## -# Example output: -# -# Unconstrained: how many states border colorado and border new mexico ? 1. -# - How long is the drive from denver to albuquerque? The distance between Denver, Colorado (CO) & Alburqueque New Mexico(NM). Driving directions for your road trip or vacation: Get driving -# Constrained: how many states border colorado and border new mexico ? 
answer(smallest_one(area_1(stateid('colorado')))) -# -########################## diff --git a/examples/generate_russian.py b/examples/generate_russian.py deleted file mode 100644 index 8b041ab..0000000 --- a/examples/generate_russian.py +++ /dev/null @@ -1,45 +0,0 @@ -from transformers import AutoModelForCausalLM, AutoTokenizer -from transformers_cfg.grammar_utils import IncrementalGrammarConstraint -from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor - -import logging - -logging.basicConfig(level=logging.DEBUG) - - -def main(): - - # Load model and tokenizer - tokenizer = AutoTokenizer.from_pretrained("gpt2") # JackFram/llama-68m" - tokenizer.pad_token = tokenizer.eos_token - model = AutoModelForCausalLM.from_pretrained("gpt2") # Load model to defined device - - # Load grammar - with open("examples/grammars/russian.ebnf", "r") as file: - grammar_str = file.read() - grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer, unicode=True) - grammar_processor = GrammarConstrainedLogitsProcessor(grammar) - - # Generate - prefix1 = "English: coffee, Russian: " - prefix2 = "English: dog, Russian: " - input_ids = tokenizer( - [prefix1, prefix2], add_special_tokens=False, return_tensors="pt", padding=True - )["input_ids"] - - output = model.generate( - input_ids, - do_sample=False, - max_new_tokens=20, - logits_processor=[grammar_processor], - repetition_penalty=1.1, - num_return_sequences=1, - ) - # decode output - generations = tokenizer.batch_decode(output, skip_special_tokens=True) - - print(generations) - - -if __name__ == "__main__": - main() diff --git a/examples/grammars/Information_extraction.ebnf b/examples/grammars/Information_extraction.ebnf deleted file mode 100644 index d865b06..0000000 --- a/examples/grammars/Information_extraction.ebnf +++ /dev/null @@ -1,8 +0,0 @@ -root ::= triplet ( delim triplet )* -triplet ::= "[s] " subject " [r] " predicate " [o] " object -delim ::= " [e] " -subject ::= entity -predicate ::= relation -object ::= entity -entity ::= "entity1" | "entity2" | "entity3" | "entity4" -relation ::= "relation1" | "relation2" | "relation3" | "relation4" diff --git a/examples/grammars/cIE.ebnf b/examples/grammars/cIE.ebnf index d865b06..c1815e4 100644 --- a/examples/grammars/cIE.ebnf +++ b/examples/grammars/cIE.ebnf @@ -1,8 +1,10 @@ +# This is just for illustration purposes. Depending on the actual use case, the set of entities and relations can be extended to include more entities and relations. 
root ::= triplet ( delim triplet )* triplet ::= "[s] " subject " [r] " predicate " [o] " object delim ::= " [e] " -subject ::= entity +subject ::= subject_entity predicate ::= relation -object ::= entity -entity ::= "entity1" | "entity2" | "entity3" | "entity4" -relation ::= "relation1" | "relation2" | "relation3" | "relation4" +object ::= object_entity +subject_entity ::= "Rene Descartes" | "Isaac Newton" | "Albert Einstein" | "Stephen Hawking" | "Galileo Galilei" | "Nikola Tesla" | "Leonardo da Vinci" | "Aristotle" | "Plato" | "Socrates" | "Pythagoras" | "Euclid" | "Archimedes" | "Hippocrates" | "Ptolemy" | "Nicolaus Copernicus" | "Johannes Kepler" | "Galileo Galilei" | "Isaac Newton" | "Albert Einstein" | "Stephen Hawking" | "Nikola Tesla" | "Leonardo da Vinci" | "Aristotle" +object_entity ::= "France" | "England" | "Germany" | "Italy" | "Greece" | "Egypt" | "China" | "India" | "Russia" | "USA" | "Canada" | "Brazil" | "Australia" | "Japan" | "South Africa" | "Mexico" | "Argentina" | "Spain" | "Portugal" | "Netherlands" | "Belgium" | "Sweden" | "Norway" | "Denmark" | "Finland" | "Poland" | "Czech Republic" | "Slovakia" | "Hungary" | "Romania" | "Bulgaria" | "Greece" | "Turkey" | "Iran" | "Iraq" | "Syria" +relation ::= "was born in" | "died in" | "lived in" | "worked in" | "studied in" | "invented" | "discovered" | "wrote" | "painted" | "sculpted" | "composed" | "played" | "sang" | "acted" | "directed" | "produced" | "won" | "lost" | "was awarded" | "was nominated" | "was married to" | "was divorced from" | "had children with" | "was friends with" | "was enemies with" diff --git a/tests/test_hf_generation/test_generation.py b/tests/test_hf_generation/test_generation.py index 86698a5..f40d87e 100644 --- a/tests/test_hf_generation/test_generation.py +++ b/tests/test_hf_generation/test_generation.py @@ -178,7 +178,7 @@ def test_generate_emoji(self): tokenizer = self.tokenizers[model_id] grammar = IncrementalTokenRecognizer( - grammar_str, start_rule_name="root", tokenizer=tokenizer, unicode=True + grammar_str, start_rule_name="root", tokenizer=tokenizer ) grammar_processor = GrammarConstrainedLogitsProcessor(grammar) diff --git a/tests/test_hf_generation/test_unicode_generation.py b/tests/test_hf_generation/test_unicode_generation.py index fe92e65..e7fefef 100644 --- a/tests/test_hf_generation/test_unicode_generation.py +++ b/tests/test_hf_generation/test_unicode_generation.py @@ -1,6 +1,7 @@ from unittest import TestCase from transformers import AutoModelForCausalLM, AutoTokenizer from transformers_cfg.token_grammar_recognizer import IncrementalTokenRecognizer +from transformers_cfg.token_grammar_recognizer import AbsTokenRecognizer from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor @@ -9,6 +10,21 @@ ] +class TestDetectUnicode(TestCase): + def test_detect_unicode(self): + # Test with a string containing only ASCII characters + self.assertFalse(AbsTokenRecognizer.detect_unicode("Hello, world!")) + + # Test with a string containing Unicode characters + self.assertTrue(AbsTokenRecognizer.detect_unicode("δ½ ε₯½οΌŒδΈ–η•ŒοΌ")) + + # Test with an empty string + self.assertFalse(AbsTokenRecognizer.detect_unicode("")) + + # Test with a string containing a mix of ASCII and Unicode characters + self.assertTrue(AbsTokenRecognizer.detect_unicode("Hello, RenΓ©!")) + + class TestGreedyDecoding(TestCase): @classmethod def setUpClass(cls): @@ -30,7 +46,7 @@ def test_generate_emoji(self): tokenizer = self.tokenizers[model_id] grammar = IncrementalTokenRecognizer( - 
grammar_str, start_rule_name="root", tokenizer=tokenizer, unicode=True + grammar_str, start_rule_name="root", tokenizer=tokenizer ) grammar_processor = GrammarConstrainedLogitsProcessor(grammar) diff --git a/transformers_cfg/cli/cli_main.py b/transformers_cfg/cli/cli_main.py index 2377de4..c72c7b7 100755 --- a/transformers_cfg/cli/cli_main.py +++ b/transformers_cfg/cli/cli_main.py @@ -2,6 +2,10 @@ import argparse from transformers_cfg.tokenization.utils import is_tokenizer_supported +from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig +from transformers_cfg.grammar_utils import IncrementalGrammarConstraint +from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor +import torch def parse_arguments(args=None): @@ -14,14 +18,80 @@ def parse_arguments(args=None): "model", type=str, help="The unique model name on HF hub." ) + # Sub-command: generate + generate_parser = subparsers.add_parser( + "generate", help="Generate text with grammar constraints" + ) + generate_parser.add_argument( + "-m", + "--model_id", + type=str, + required=True, + help="Model identifier for loading the tokenizer and model", + ) + generate_parser.add_argument( + "-g", + "--grammar_file_path", + type=str, + required=True, + help="Path to the grammar file", + ) + generate_parser.add_argument( + "-p", + "--prompt", + type=str, + required=True, + help="Prompt for generation", + ) + generate_parser.add_argument( + "-d", + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + choices=["cpu", "cuda"], + help="Device to run the model on", + ) + generate_parser.add_argument( + "-n", + "--max_new_tokens", + type=int, + default=20, + help="Maximum number of new tokens to generate", + ) + generate_parser.add_argument( + "--repetition_penalty", + type=float, + default=1.1, + help="Penalty for token repetition", + ) + generate_parser.add_argument( + "--use_4bit", + action="store_true", + help="Load the model in 4-bit mode using bitsandbytes", + ) + generate_parser.add_argument( + "--use_8bit", + action="store_true", + help="Load the model in 8-bit mode using bitsandbytes", + ) + + generate_parser.add_argument( + "--no_contrast_mode", + action="store_true", + help="Disable contrast mode (enabled by default)", + ) + + generate_parser.add_argument( + "--save_to", + type=str, + help="File path to save the generated text", + ) + return parser.parse_args(args) def check_model_support(model_name): # Check if the model tokenizer is supported - - # for now the only condition is that the tokenizer is in SUPPORTED_TOKENIZERS - # maybe there will be more conditions in the future if is_tokenizer_supported(model_name): print(f"{model_name} is supported") return True @@ -30,11 +100,117 @@ def check_model_support(model_name): return False +def generate_text(args): + # Load model and tokenizer + tokenizer = AutoTokenizer.from_pretrained(args.model_id) + tokenizer.pad_token = tokenizer.eos_token + + # Load the model with bitsandbytes if 8bit or 4bit flag is set + if args.use_8bit or args.use_4bit: + try: + pass + except ImportError: + raise ImportError( + "You need to install bitsandbytes to use 8-bit or 4-bit modes. Install it with `pip install bitsandbytes`." 
+            )
+
+        bnb_config = BitsAndBytesConfig(
+            load_in_8bit=args.use_8bit,
+            load_in_4bit=args.use_4bit,
+            bnb_4bit_compute_dtype=torch.bfloat16,
+        )
+
+        model = AutoModelForCausalLM.from_pretrained(
+            args.model_id, quantization_config=bnb_config, device_map="auto"
+        )
+    else:
+        model = AutoModelForCausalLM.from_pretrained(args.model_id).to(args.device)
+
+    # set special tokens in generation config
+    model.generation_config.pad_token_id = tokenizer.pad_token_id
+
+    inputs = tokenizer(
+        args.prompt, add_special_tokens=False, return_tensors="pt", padding=True
+    )
+    input_ids = inputs["input_ids"].to(args.device)
+    attention_mask = inputs["attention_mask"].to(args.device)
+
+    # Load grammar
+    with open(args.grammar_file_path, "r") as file:
+        grammar_str = file.read()
+    grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
+    grammar_processor = GrammarConstrainedLogitsProcessor(grammar)
+
+    # Generate with grammar constraints
+    constrained_output = model.generate(
+        input_ids,
+        attention_mask=attention_mask,
+        do_sample=False,
+        max_new_tokens=args.max_new_tokens,
+        logits_processor=[grammar_processor],
+        repetition_penalty=args.repetition_penalty,
+        num_return_sequences=1,
+    )
+
+    # remove prefix from the output
+    constrained_output = constrained_output[:, len(input_ids[0]) :]
+
+    constrained_generations = tokenizer.batch_decode(
+        constrained_output, skip_special_tokens=True
+    )
+
+    # print prompt first in color
+    print("\033[92m" + "Prompt: " + args.prompt + "\033[0m")
+
+    # Store results for optional file output
+    result = f"Prompt: {args.prompt}\n\n"
+
+    # Generate without grammar constraints (if contrast mode is enabled)
+    if not args.no_contrast_mode:
+        unconstrained_output = model.generate(
+            input_ids,
+            attention_mask=attention_mask,
+            do_sample=False,
+            max_new_tokens=args.max_new_tokens,
+            repetition_penalty=args.repetition_penalty,
+            num_return_sequences=1,
+        )
+        # remove prefix from the output
+        unconstrained_output = unconstrained_output[:, len(input_ids[0]) :]
+        unconstrained_generations = tokenizer.batch_decode(
+            unconstrained_output, skip_special_tokens=True
+        )
+
+        # Print results in different colors
+        print("\033[91m" + "Unconstrained Generation:" + "\033[0m")
+        result += "Unconstrained Generation:\n"
+        for generation in unconstrained_generations:
+            print(generation)
+            result += generation + "\n"
+
+    print("\033[94m" + "Constrained Generation:" + "\033[0m")
+    result += "Constrained Generation:\n"
+    for generation in constrained_generations:
+        print(generation)
+        result += generation + "\n"
+
+    # Save to file if save_to is provided
+    if args.save_to:
+        with open(args.save_to, "w") as f:
+            f.write(result)
+        print(f"\nResults saved to {args.save_to}")
+
+
 def main(args=None):
     args = parse_arguments(args)
+
     if args.command == "check":
         check_model_support(args.model)
+    elif args.command == "generate":
+        generate_text(args)
 
 
 if __name__ == "__main__":
     main()
+
+# TODO, add support for device selection for parsing
diff --git a/transformers_cfg/generation/logits_process.py b/transformers_cfg/generation/logits_process.py
index a09a773..24b1028 100644
--- a/transformers_cfg/generation/logits_process.py
+++ b/transformers_cfg/generation/logits_process.py
@@ -15,11 +15,12 @@
 
 
 class GrammarConstrainedLogitsProcessor(LogitsProcessor):
-    def __init__(self, grammar_constraint, valid_token_start_idx=None):
+    def __init__(self, grammar_constraint, valid_token_start_idx=None, device=None):
         self.last_size = None
         self.grammar_constraint = grammar_constraint
         self.batch_parsing_states = None
-        self.valid_token_start_idx = None
+        self.valid_token_start_idx = valid_token_start_idx
+        self.device = device
 
     def mask_logits(self, logits, device):
         masked_logits = logits.clone()
@@ -29,16 +30,22 @@ def mask_logits(self, logits, device):
         acceptance = self.grammar_constraint.batch_filter_vocab(
             self.batch_parsing_states, device
         )
-
+        # if the logits size of the model is more than the tokenizer vocab
         # we artificially expand the acceptance tensor and block everything
         # beyond the tokenizer vocab size
 
         acceptance_vocab_size = acceptance.shape[-1]
         masked_logits_vocab_size = masked_logits.shape[-1]
         if masked_logits_vocab_size != acceptance_vocab_size:
-            assert acceptance_vocab_size < masked_logits_vocab_size, "impossible for tokenizer vocab to be less than model vocab"
+            assert (
+                acceptance_vocab_size < masked_logits_vocab_size
+            ), "impossible for tokenizer vocab to be larger than model vocab"
             vocab_size_diff = masked_logits_vocab_size - acceptance_vocab_size
-            false_tensor = torch.zeros((*acceptance.shape[:-1], vocab_size_diff), dtype=torch.bool, device=device)
+            false_tensor = torch.zeros(
+                (*acceptance.shape[:-1], vocab_size_diff),
+                dtype=torch.bool,
+                device=device,
+            )
             acceptance = torch.cat((acceptance, false_tensor), dim=-1)
 
         # acceptance is a tensor of shape (batch_size, vocab_size)
@@ -70,14 +77,13 @@ def mask_logits(self, logits, device):
         masked_logits[~acceptance] = -math.inf
         return masked_logits
 
-    # TODO: batching
-    def process_logits(self, input_ids, scores,device=None):
+    def process_logits(self, input_ids, scores):
         """
         :param input_ids:
         :param scores:
         :return:
         """
-        if device is None:
+        if (device := self.device) is None:  # use the configured device if provided
             device = scores.device
         # we dynamically create stacks at the first call, so that we know the batch size and beam size
         if self.batch_parsing_states is None:
diff --git a/transformers_cfg/token_grammar_recognizer.py b/transformers_cfg/token_grammar_recognizer.py
index dad204c..7f71736 100644
--- a/transformers_cfg/token_grammar_recognizer.py
+++ b/transformers_cfg/token_grammar_recognizer.py
@@ -16,11 +16,18 @@
 
 
 class AbsTokenRecognizer(ABC):
-    def __init__(self, grammar_str, tokenizer, start_rule_name="root", unicode=False, trie=None,homomorphism=None):
+    def __init__(
+        self,
+        grammar_str,
+        tokenizer,
+        start_rule_name="root",
+        trie=None,
+        homomorphism=None,
+    ):
         parsed_grammar = parse_ebnf(grammar_str)
         grammar_encoding = parsed_grammar.grammar_encoding
         self.start_rule_id = parsed_grammar.symbol_table.get(start_rule_name)
-        self.use_unicode = unicode
+        self.use_unicode = self.detect_unicode(grammar_str)
         self.eos_token_id = tokenizer.eos_token_id
         self.tokenizer = tokenizer
 
@@ -95,10 +102,24 @@ def validate_and_set_eos_acceptance(self, acceptance: torch.Tensor) -> torch.Ten
     def accept_token_ids(self, token_ids, stacks) -> bool:
         """Accept a list of token IDs according to the grammar rules."""
         raise NotImplementedError
-
+
+    @staticmethod
+    def detect_unicode(text: str) -> bool:
+        # check if the text contains any unicode characters
+        return any(ord(char) > 127 for char in text)
+
+
 class IncrementalTokenRecognizer(AbsTokenRecognizer):
-    def __init__(self, grammar_str, start_rule_name, tokenizer, unicode=False, trie=None,homomorphism=None):
-        super().__init__(grammar_str, tokenizer, start_rule_name, unicode, trie=trie,homomorphism=homomorphism)
+    def __init__(
+        self, grammar_str, start_rule_name, tokenizer, trie=None, homomorphism=None
+    ):
+        super().__init__(
+            grammar_str,
+            tokenizer,
+            start_rule_name,
+            trie=trie,
+            homomorphism=homomorphism,
+        )
         self.last_size = None
 
     def _update_state_with_token_id(
@@ -334,8 +355,8 @@ def check_token_acceptance_in_trie(trie_node, stacks, grammar, eos_token_id, acc
 
 
 class NonIncrementalTokenSeqRecognizer(IncrementalTokenRecognizer):
-    def __init__(self, grammar_str, start_rule_name, tokenizer, unicode=False):
-        super().__init__(grammar_str, start_rule_name, tokenizer, unicode)
+    def __init__(self, grammar_str, start_rule_name, tokenizer):
+        super().__init__(grammar_str, start_rule_name, tokenizer)
 
     def update_state_with_batch_token_seqs(
         self, input_ids, batch_parsing_states, valid_token_start_idx=None
@@ -391,10 +412,7 @@ def update_state_with_batch_token_seqs(
     tokenizer = AutoTokenizer.from_pretrained("gpt2")
 
     tokenRecognizer = IncrementalTokenRecognizer(
-        grammar_str=input_text,
-        start_rule_name="root",
-        tokenizer=tokenizer,
-        unicode=True,
+        grammar_str=input_text, start_rule_name="root", tokenizer=tokenizer
     )
 
     japanese = "γƒˆγƒͺγƒΌγƒ "  # "こんにけは"