diff --git a/README.md b/README.md index bf30735..0526a65 100644 --- a/README.md +++ b/README.md @@ -11,9 +11,10 @@ the ability to generate prompt led synthetic datasets. Promptwright was inspired by the [redotvideo/pluto](https://github.com/redotvideo/pluto), in fact it started as a fork, but ended up largely being a re-write, to allow dataset generation -against a local LLM model, as opposed to OpenAI where costs can be prohibitively expensive. +against a local LLM model. -The library interfaces with Ollama, making it easy to just pull a model and run Promptwright. +The library interfaces with Ollama, making it easy to just pull a model and run +Promptwright. ## Features @@ -54,12 +55,23 @@ To run an example: 4. Set the `model_name` in the chosen example file to the model you have downloaded. ```python - engine = LocalDataEngine( - args=LocalEngineArguments( + + tree = TopicTree( + args=TopicTreeArguments( + root_prompt="Creative Writing Prompts", + model_system_prompt=system_prompt, + tree_degree=5, # Increase degree for more prompts + tree_depth=4, # Increase depth for more prompts + temperature=0.9, # Higher temperature for more creative variations + model_name="ollama/llama3" # Set the model name here + ) + ) + engine = DataEngine( + args=EngineArguments( instructions="Generate creative writing prompts and example responses.", system_prompt="You are a creative writing instructor providing writing prompts and example responses.", - model_name="llama3.2:latest", - temperature=0.9, # Higher temperature for more creative variations + model_name="ollama/llama3", + temperature=0.9, max_retries=2, ``` 5. Run your chosen example file: @@ -89,47 +101,34 @@ To run an example: } ``` -### Library Overview - -#### Classes - -- **Dataset**: A class for managing generated datasets. -- **LocalDataEngine**: The main engine responsible for interacting with the LLM client and generating datasets. -- **LocalEngineArguments**: A configuration class that defines the instructions, system prompt, model name temperature, retries, and prompt templates used for generating data. -- **OllamaClient**: A client class for interacting with the Ollama API -- **HFUploader**: A utility class for uploading datasets to Hugging Face (pass in the path to the dataset and token). - -### Troubleshooting - -If you encounter any errors while running the script, here are a few common troubleshooting steps: - -1. **Restart Ollama**: - ```bash - killall ollama && ollama serve - ``` - -2. **Verify Model Installation**: - ```bash - ollama pull {model_name} - ``` - -3. **Check Ollama Logs**: - Inspect the logs for any error messages that might provide more context on - what went wrong, these can be found in the `~/.ollama/logs` directory. - ## Model Compatibility The library should work with most LLM models. It has been tested with the following models so far: -- **LLaMA3**: The library is designed to work with the LLaMA model, specifically -the `llama3:latest` model. -- **Mistral**: The library is compatible with the Mistral model, which is a fork -of the GPT-3 model. +- **Mistral** +- **LLaMA3** +- **Qwen2.5** + +## Unpredictable Behavior + +The library is designed to generate synthetic data based on the prompts and instructions +provided. The quality of the generated data is dependent on the quality of the prompts +and the model used. The library does not guarantee the quality of the generated data.
+ +Large Language Models can sometimes generate unpredictable or inappropriate +content and the authors of this library are not responsible for the content generated by the models. We recommend reviewing the generated data before using it +in any production environment. -If you test anymore, please make a pull request to update this list! +Large Language Models can also fail to follow the JSON formatting behavior +defined by the prompt, and may generate invalid JSON. This +is a known issue with the underlying model and not the library. We handle these +errors by retrying the generation process and filtering out invalid JSON. The +failure rate is low, but it can happen. We report on each failure within a final +summary. -### Contributing +## Contributing If something here could be improved, please open an issue or submit a pull request. diff --git a/coverage.xml b/coverage.xml index 1d3db5d..42f30ce 100644 --- a/coverage.xml +++ b/coverage.xml diff --git a/examples/__init__.py b/examples/__init__.py index 34a5496..d22f91b 100644 --- a/examples/__init__.py +++ b/examples/__init__.py @@ -1,3 +1,3 @@ -from . import coding, creative_writing, science +from . 
import example_coding, example_creative_writing, example_science -__all__ = ["coding", "creative_writing", "science"] +__all__ = ["example_coding", "example_creative_writing", "example_science"] diff --git a/examples/basic_prompt.py b/examples/basic_prompt.py deleted file mode 100644 index 576db1a..0000000 --- a/examples/basic_prompt.py +++ /dev/null @@ -1,99 +0,0 @@ -import os -import sys - -# Add the parent directory to sys.path -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from promptwright import LocalDataEngine, LocalEngineArguments - - -def main(): - print("promptwright - Synthetic Dataset Generation") - print("==========================================") - - # First verify Ollama connection - engine = LocalDataEngine( - args=LocalEngineArguments( - instructions="Generate a simple test response.", - system_prompt="You are a helpful assistant.", - model_name="mistral:latest", - temperature=0.7, - max_retries=2, - prompt_template="""Return this exact JSON structure with a simple question and answer: - { - "messages": [ - { - "role": "system", - "content": "You are a helpful assistant. You provide clear and concise answers to user questions." - }, - { - "role": "user", - "content": "What is 2+2?" - }, - { - "role": "assistant", - "content": "2+2 equals 4." - } - ] - }""", - ) - ) - - try: - # Test Ollama connection - print("\nTesting Ollama connection...") - models = engine.llm_client.list_local_models() - print(f"Available models: {[m['name'] for m in models]}") - - # Generate a single test sample with system message consistency - # Amend the number of steps and batch size as needed, the current values - # are set for a single test sample - print("\nGenerating test sample(s)...") - dataset = engine.create_data(num_steps=1, batch_size=1, topic_tree=None) - - # Ensure consistency by adding the system message explicitly - for data in dataset: - if not any(msg.get("role") == "system" for msg in data.get("messages", [])): - data["messages"].insert( - 0, - { - "role": "system", - "content": "You are a helpful assistant. You provide clear and concise answers to user questions.", - }, - ) - - if len(dataset) > 0: - print("\nTest successful! Starting main generation...") - - # Ask for confirmation before proceeding with full generation - response = input("\nProceed with full generation? (y/n): ") - if response.lower() == "y": - dataset = engine.create_data(num_steps=100, batch_size=10, topic_tree=None) - - # Ensure system message consistency in all data - for data in dataset: - if not any(msg.get("role") == "system" for msg in data.get("messages", [])): - data["messages"].insert( - 0, - { - "role": "system", - "content": "You are a helpful assistant. You provide clear and concise answers to user questions.", - }, - ) - - dataset.save("full_dataset.jsonl") - - else: - print("\nError: Test generation failed") - - except Exception as e: - print(f"\nError encountered: {str(e)}") - print("\nTroubleshooting steps:") - print("1. Try restarting Ollama: 'killall ollama && ollama serve'") - print("2. Verify model is installed: 'ollama pull $model_name'") - print("3. 
Check Ollama logs for errors") - raise - - -if __name__ == "__main__": - main() diff --git a/examples/coding.py b/examples/coding.py deleted file mode 100644 index 7dfbc04..0000000 --- a/examples/coding.py +++ /dev/null @@ -1,94 +0,0 @@ -import os -import sys - -# Add the parent directory to sys.path -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from promptwright import LocalDataEngine, LocalEngineArguments - - -def main(): - print("promptwright - Programming Dataset Generation") - print("=====================================") - - engine = LocalDataEngine( - args=LocalEngineArguments( - instructions="Generate a simple programming question and answer.", - system_prompt="You are a programming expert.", - model_name="llama3.2", - temperature=0.7, - max_retries=2, - prompt_template="""Return this exact JSON structure with a programming question and answer: - { - "messages": [ - { - "role": "system", - "content": "You are a programming expert. You provide clear code examples and explanations." - }, - { - "role": "user", - "content": "How do I print Hello World in Python?" - }, - { - "role": "assistant", - "content": "To print Hello World in Python, use this code:\\n```python\\nprint('Hello World')\\n```\\nThis will display the text Hello World in the console." - } - ] - }""", - ) - ) - - try: - # Test Ollama connection - print("\nTesting Ollama connection...") - models = engine.llm_client.list_local_models() - print(f"Available models: {[m['name'] for m in models]}") - - # Generate a single test sample - print("\nGenerating test sample...") - dataset = engine.create_data(num_steps=1, batch_size=1, topic_tree=None) - - # Ensure consistency by adding the system message explicitly - for data in dataset: - if not any(msg.get("role") == "system" for msg in data.get("messages", [])): - data["messages"].insert( - 0, - { - "role": "system", - "content": "You are a programming instructor. You provide clear code examples and explanations.", - }, - ) - - if len(dataset) > 0: - print("\nTest successful! Starting main generation...") - response = input("\nProceed with full generation? (y/n): ") - if response.lower() == "y": - dataset = engine.create_data(num_steps=100, batch_size=10, topic_tree=None) - - # Ensure system message consistency - for data in dataset: - if not any(msg.get("role") == "system" for msg in data.get("messages", [])): - data["messages"].insert( - 0, - { - "role": "system", - "content": "You are a programming instructor. You provide clear code examples and explanations.", - }, - ) - - dataset.save("programming_dataset.jsonl") - print(f"\nSaved {len(dataset)} programming Q&A pairs to programming_dataset.jsonl") - else: - print("\nError: Test generation failed") - - except Exception as e: - print(f"\nError encountered: {str(e)}") - print("\nTroubleshooting steps:") - print("1. Try restarting Ollama: 'killall ollama && ollama serve'") - print("2. Verify model is installed: 'ollama pull llama3.2'") - print("3. 
Check Ollama logs for errors") - raise - - -if __name__ == "__main__": - main() diff --git a/examples/creative_writing.py b/examples/creative_writing.py deleted file mode 100644 index 3fb0c77..0000000 --- a/examples/creative_writing.py +++ /dev/null @@ -1,97 +0,0 @@ -import os -import sys - -# Add the parent directory to sys.path -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from promptwright import LocalDataEngine, LocalEngineArguments - - -def main(): - print("promptwright - Creative Writing Dataset Generation") - print("=====================================") - - # Initialize engine with creative writing-specific parameters - engine = LocalDataEngine( - args=LocalEngineArguments( - instructions="Generate creative writing prompts and example responses.", - system_prompt="You are a creative writing instructor providing writing prompts and example responses.", - model_name="llama3.2", - temperature=0.9, # Higher temperature for more creative variations - max_retries=2, - prompt_template="""Return this exact JSON structure with a writing prompt and creative response: - { - "messages": [ - { - "role": "system", - "content": "You are a creative writing instructor providing writing prompts and example responses." - }, - { - "role": "user", - "content": "Write a short story about finding an unexpected letter." - }, - { - "role": "assistant", - "content": "The envelope was weathered, its corners soft from age. Sarah found it tucked behind the loose baseboard she'd been meaning to fix, addressed in her grandmother's distinctive cursive. The postmark read 1962—this letter had waited sixty years to be found. With trembling fingers, she broke the seal..." - } - ] - }""", - ) - ) - - try: - # Test Ollama connection - print("\nTesting Ollama connection...") - models = engine.llm_client.list_local_models() - print(f"Available models: {[m['name'] for m in models]}") - - # Generate a single test sample - print("\nGenerating writing test sample...") - dataset = engine.create_data(num_steps=1, batch_size=1, topic_tree=None) - - # Ensure consistency by adding the system message explicitly - for data in dataset: - if not any(msg.get("role") == "system" for msg in data.get("messages", [])): - data["messages"].insert( - 0, - { - "role": "system", - "content": "You are a creative writing instructor providing writing prompts and example responses.", - }, - ) - - if len(dataset) > 0: - print("\nTest successful! Starting main generation...") - response = input("\nProceed with full writing dataset generation? (y/n): ") - if response.lower() == "y": - dataset = engine.create_data(num_steps=100, batch_size=10, topic_tree=None) - - # Ensure system message consistency - for data in dataset: - if not any(msg.get("role") == "system" for msg in data.get("messages", [])): - data["messages"].insert( - 0, - { - "role": "system", - "content": "You are a creative writing instructor providing writing prompts and example responses.", - }, - ) - - dataset.save("writing_dataset.jsonl") - print( - f"\nSaved {len(dataset)} writing prompts and responses to writing_dataset.jsonl" - ) - else: - print("\nError: Test generation failed") - - except Exception as e: - print(f"\nError encountered: {str(e)}") - print("\nTroubleshooting steps:") - print("1. Try restarting Ollama: 'killall ollama && ollama serve'") - print("2. Verify model is installed: 'ollama pull mistral:latest'") - print("3. 
Check Ollama logs for errors") - raise - - -if __name__ == "__main__": - main() diff --git a/examples/example_basic_prompt.py b/examples/example_basic_prompt.py new file mode 100644 index 0000000..07df970 --- /dev/null +++ b/examples/example_basic_prompt.py @@ -0,0 +1,44 @@ +import os +import sys + +# Add the parent directory to sys.path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from promptwright import DataEngine, EngineArguments, TopicTree, TopicTreeArguments + +system_prompt = ( + "You are a helpful assistant. You provide clear and concise answers to user questions." +) + +tree = TopicTree( + args=TopicTreeArguments( + root_prompt="Capital Cities of the World.", + model_system_prompt=system_prompt, + tree_degree=3, # Different continents + tree_depth=2, # Deeper tree for more specific topics + temperature=0.7, # Higher temperature for more creative variations + model_name="ollama/mistral:latest", # Model name + ) +) + +tree.build_tree() +tree.save("basic_prompt_topictree.jsonl") + +engine = DataEngine( + args=EngineArguments( + instructions="Please provide training examples with questions about capital cities of the world.", # Instructions for the model + system_prompt=system_prompt, # System prompt for the model + model_name="ollama/mistral:latest", # Model name + temperature=0.9, # Higher temperature for more creative variations + max_retries=2, # Retry failed prompts up to 2 times + ) +) + +dataset = engine.create_data( + num_steps=5, + batch_size=1, + topic_tree=tree, + model_name="ollama/mistral:latest", +) + +dataset.save("basic_prompt_dataset.jsonl") diff --git a/examples/example_creative_writing.py b/examples/example_creative_writing.py new file mode 100644 index 0000000..e54820a --- /dev/null +++ b/examples/example_creative_writing.py @@ -0,0 +1,43 @@ +import os +import sys + +# Add the parent directory to sys.path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from promptwright import DataEngine, EngineArguments, TopicTree, TopicTreeArguments + +system_prompt = """You are a creative writing instructor providing writing prompts and example responses. If you use apostrophes in your prompts, make sure to escape them with a backslash. For example, use 'don\'t' instead of 'don't'. Respond only with valid JSON. 
Do not write an introduction or summary.""" + +tree = TopicTree( + args=TopicTreeArguments( + root_prompt="Creative Writing Prompts", + model_system_prompt=system_prompt, + tree_degree=5, + tree_depth=4, + temperature=0.7, + model_name="ollama/llama3", + ) +) + +tree.build_tree() +tree.save("numpy_topictree.jsonl") + +# Initialize engine with creative writing-specific parameters +engine = DataEngine( + args=EngineArguments( + instructions="Generate creative writing prompts and example responses.", # Instructions for the model + system_prompt=system_prompt, # System prompt for the model + model_name="ollama/llama3", # Model name + temperature=0.7, # Higher temperature for creative writing + max_retries=3, # Retry failed prompts up to 3 times + ) +) + +dataset = engine.create_data( + num_steps=10, + batch_size=5, + topic_tree=tree, + model_name="ollama/llama3.2", +) + +dataset.save("creative_writing.jsonl") diff --git a/examples/example_culinary_database.py b/examples/example_culinary_database.py new file mode 100644 index 0000000..49005f4 --- /dev/null +++ b/examples/example_culinary_database.py @@ -0,0 +1,48 @@ +import os +import sys + +# Add the parent directory to sys.path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from promptwright import DataEngine, EngineArguments, TopicTree, TopicTreeArguments + +system_prompt = """You are a culinary expert who documents recipes and cooking techniques. +Your entries should be detailed, precise, and include both traditional and modern cooking methods.""" + +tree = TopicTree( + args=TopicTreeArguments( + root_prompt="Global Cuisine and Cooking Techniques", # Root prompt for the tree + model_system_prompt=system_prompt, # System prompt for the model + tree_degree=5, # Different cuisine types + tree_depth=3, # Specific dishes and techniques + temperature=0.7, # Balanced temperature for creativity and precision + model_name="ollama/llama3", # Model name + ) +) + +tree.build_tree() +tree.save("culinary_techniques_tree.jsonl") + +engine = DataEngine( + args=EngineArguments( + instructions="""Create detailed recipe and technique entries that include: + - Ingredient lists with possible substitutions + - Step-by-step instructions + - Critical technique explanations + - Common mistakes to avoid + - Storage and serving suggestions + - Cultural context and history""", # Instructions for the model + system_prompt=system_prompt, # System prompt for the model + model_name="ollama/llama3", # Model name + temperature=0.1, # Balance between creativity and precision + max_retries=3, # Retry failed prompts up to 3 times + ) +) + +dataset = engine.create_data( + num_steps=15, # Generate 15 entries + batch_size=2, # Generate 2 entries at a time + topic_tree=tree, +) + +dataset.save("culinary_database.jsonl") diff --git a/examples/example_historic_figures.py b/examples/example_historic_figures.py new file mode 100644 index 0000000..f6068cc --- /dev/null +++ b/examples/example_historic_figures.py @@ -0,0 +1,35 @@ +import os +import sys + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from promptwright import DataEngine, EngineArguments, TopicTree, TopicTreeArguments + +system_prompt = """You are a knowledgeable historian who creates detailed, accurate biographical entries. 
+Each entry should include: birth/death dates, major achievements, historical impact, and interesting anecdotes.""" + +tree = TopicTree( + args=TopicTreeArguments( + root_prompt="Notable Historical Figures Across Different Eras and Fields", + model_system_prompt=system_prompt, + tree_degree=4, # More branches for different categories + tree_depth=3, # Deeper tree for more specific figures + temperature=0.6, # Balanced temperature for creativity and accuracy + model_name="ollama/llama3", # Model name + ) +) + +tree.build_tree() +tree.save("historical_figures_tree.jsonl") + +engine = DataEngine( + args=EngineArguments( + instructions="""Generate biographical entries for historical figures. + Include lesser-known details and focus on their lasting impact. + Each entry should be engaging while maintaining historical accuracy.""", # Instructions for the model + system_prompt=system_prompt, # System prompt for the model + model_name="ollama/llama3", # Model name + temperature=0.7, # Balance between creativity and accuracy + max_retries=3, # Retry failed generations up to 3 times + ) +) diff --git a/examples/example_programming_challenges.py b/examples/example_programming_challenges.py new file mode 100644 index 0000000..ccf4832 --- /dev/null +++ b/examples/example_programming_challenges.py @@ -0,0 +1,47 @@ +import os +import sys + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from promptwright import DataEngine, EngineArguments, TopicTree, TopicTreeArguments + +system_prompt = """You are an expert programming instructor who creates engaging coding challenges. +Each challenge should test specific programming concepts while remaining accessible and educational.""" + +tree = TopicTree( + args=TopicTreeArguments( + root_prompt="Programming Challenges Across Different Difficulty Levels and Concepts", # Root prompt for the tree + model_system_prompt=system_prompt, # System prompt for the model + tree_degree=4, # Different programming concepts + tree_depth=2, # Various difficulty levels + temperature=0.7, # Higher temperature for creative problem scenarios + model_name="ollama/llama3", # Model name + ) +) + +tree.build_tree() +tree.save("programming_challenges_tree.jsonl") + +engine = DataEngine( + args=EngineArguments( + instructions="""Generate programming challenges that include: + - Problem description + - Input/Output examples + - Constraints and edge cases + - Hint system (progressive hints) + - Solution approach discussion + - Time/Space complexity requirements""", # Instructions for the model + system_prompt=system_prompt, # System prompt for the model + model_name="ollama/llama3", # Model name + temperature=0.8, # Higher temperature for creative problem scenarios + max_retries=3, # Retry failed generations up to 3 times + ) +) + +dataset = engine.create_data( + num_steps=6, + batch_size=2, + topic_tree=tree, +) + +dataset.save("programming_challenges.jsonl") diff --git a/examples/push_to_hf_hub.py b/examples/example_push_to_hf_hub.py similarity index 93% rename from examples/push_to_hf_hub.py rename to examples/example_push_to_hf_hub.py index 1e62b4f..b131aa8 100644 --- a/examples/push_to_hf_hub.py +++ b/examples/example_push_to_hf_hub.py @@ -9,7 +9,7 @@ def main(): print("promptwright - Uploading to Hugging Face Hub") - print("==========================================") + print("============================================") dataset_file = "my_dataset.jsonl" diff --git a/examples/example_redteam_scenarios_dataset.py 
b/examples/example_redteam_scenarios_dataset.py new file mode 100644 index 0000000..4612f75 --- /dev/null +++ b/examples/example_redteam_scenarios_dataset.py @@ -0,0 +1,45 @@ +import os +import sys + +# Add the parent directory to sys.path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from promptwright import DataEngine, EngineArguments, TopicTree, TopicTreeArguments + +system_prompt = """You are a red team exercise designer creating advanced cybersecurity training scenarios. +Focus on creating realistic, detailed scenarios that help organizations improve their security posture. +All scenarios must emphasize legal compliance and ethical considerations.""" + +tree = TopicTree( + args=TopicTreeArguments( + root_prompt="Advanced Red Team Exercise Scenarios and Methodologies", # Root prompt for the tree + model_system_prompt=system_prompt, + tree_degree=4, # Different attack vectors + tree_depth=4, # Detailed scenario branches + temperature=0.6, # Balanced for creativity within constraints + model_name="ollama/qwen2.5", # Model name + ) +) + +tree.build_tree() +tree.save("redteam_scenarios_tree.jsonl") + +engine = DataEngine( + args=EngineArguments( + instructions="""Generate red team training scenarios. Each scenario should be realistic and focus on improving +organizational security posture while maintaining ethical standards.""", # Instructions for the model + system_prompt=system_prompt, # System prompt for the model + model_name="ollama/llama3", # Model name + temperature=0.6, # Balanced for creativity within constraints + max_retries=3, # Retry failed prompts up to 3 times + ) +) + +dataset = engine.create_data( + num_steps=12, # Generate 12 scenarios + batch_size=2, # Generate 2 scenarios at a time + topic_tree=tree, # Use the red team scenarios tree + model_name="ollama/qwen2.5", # Use the Qwen model +) + +dataset.save("redteam_scenarios_dataset.jsonl") diff --git a/examples/example_scientific_experiments.py b/examples/example_scientific_experiments.py new file mode 100644 index 0000000..382dd00 --- /dev/null +++ b/examples/example_scientific_experiments.py @@ -0,0 +1,48 @@ +import os +import sys + +# Add the parent directory to sys.path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from promptwright import DataEngine, EngineArguments, TopicTree, TopicTreeArguments + +system_prompt = """You are a scientific database curator specializing in experimental procedures and results. +Your role is to document experiments clearly, including methodology, materials, observations, and conclusions.""" + +tree = TopicTree( + args=TopicTreeArguments( + root_prompt="Groundbreaking Scientific Experiments Throughout History", # Root prompt for the tree + model_system_prompt=system_prompt, # System prompt for the model + tree_degree=3, # Branch into different scientific fields + tree_depth=3, # Go deeper into specific experiments + temperature=0.6, # Lower temperature for more precise content + model_name="ollama/llama3", + ) +) + +tree.build_tree() +tree.save("science_experiments_tree.jsonl") # Save the generated topic tree to a file + +engine = DataEngine( + args=EngineArguments( + instructions="""Create detailed experimental procedure entries. 
+ Each entry should include: + - Required materials and equipment + - Step-by-step methodology + - Expected results and observations + - Common pitfalls and troubleshooting + - Safety considerations""", # Instructions for the model + system_prompt=system_prompt, # System prompt for the model + model_name="ollama/llama3", # Model name + temperature=0.4, # Lower temperature for more precise content + max_retries=3, # Retry failed prompts up to 3 times + ) +) + +dataset = engine.create_data( + num_steps=8, # Number of steps to generate + batch_size=2, # Batch size for each step + topic_tree=tree, # Topic tree used to guide the generation +) + +dataset.save("scientific_experiments.jsonl") # Save the generated dataset to a file diff --git a/examples/science.py b/examples/science.py deleted file mode 100644 index cb16978..0000000 --- a/examples/science.py +++ /dev/null @@ -1,95 +0,0 @@ -import os -import sys - -# Add the parent directory to sys.path -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from promptwright import LocalDataEngine, LocalEngineArguments - - -def main(): - print("promptwright - Scientific Dataset Generation") - print("=====================================") - - # Initialize engine with science-specific parameters - engine = LocalDataEngine( - args=LocalEngineArguments( - instructions="Generate scientific Q&A pairs covering physics, chemistry, and biology.", - system_prompt="You are a science educator providing accurate, clear explanations.", - model_name="llama3.2", - temperature=0.7, - max_retries=2, - prompt_template="""Return this exact JSON structure with a scientific question and detailed answer: - { - "messages": [ - { - "role": "system", - "content": "You are a science educator providing accurate, clear explanations." - }, - { - "role": "user", - "content": "What causes the seasons on Earth?" - }, - { - "role": "assistant", - "content": "Seasons are caused by Earth's tilted axis of rotation (23.5 degrees) as it orbits the Sun. This tilt means different hemispheres receive varying amounts of direct sunlight throughout the year, leading to seasonal changes in temperature and daylight hours." - } - ] - }""", - ) - ) - - try: - # Test Ollama connection - print("\nTesting Ollama connection...") - models = engine.llm_client.list_local_models() - print(f"Available models: {[m['name'] for m in models]}") - - # Generate a single test sample - print("\nGenerating science test sample...") - dataset = engine.create_data(num_steps=1, batch_size=1, topic_tree=None) - - # Ensure consistency by adding the system message explicitly - for data in dataset: - if not any(msg.get("role") == "system" for msg in data.get("messages", [])): - data["messages"].insert( - 0, - { - "role": "system", - "content": "You are a science educator providing accurate, clear explanations.", - }, - ) - - if len(dataset) > 0: - print("\nTest successful! Starting main generation...") - response = input("\nProceed with full science dataset generation? 
(y/n): ") - if response.lower() == "y": - dataset = engine.create_data(num_steps=100, batch_size=10, topic_tree=None) - - # Ensure system message consistency - for data in dataset: - if not any(msg.get("role") == "system" for msg in data.get("messages", [])): - data["messages"].insert( - 0, - { - "role": "system", - "content": "You are a science educator providing accurate, clear explanations.", - }, - ) - - dataset.save("science_dataset.jsonl") - print(f"\nSaved {len(dataset)} science Q&A pairs to science_dataset.jsonl") - else: - print("\nError: Test generation failed") - - except Exception as e: - print(f"\nError encountered: {str(e)}") - print("\nTroubleshooting steps:") - print("1. Try restarting Ollama: 'killall ollama && ollama serve'") - print("2. Verify model is installed: 'ollama pull mistral:latest'") - print("3. Check Ollama logs for errors") - raise - - -if __name__ == "__main__": - main() diff --git a/promptwright/__init__.py b/promptwright/__init__.py index 2d20e32..7932045 100644 --- a/promptwright/__init__.py +++ b/promptwright/__init__.py @@ -1,18 +1,16 @@ # promptwright/__init__.py from .dataset import Dataset -from .engine import LocalDataEngine, LocalEngineArguments +from .engine import DataEngine, EngineArguments from .hf_hub import HFUploader -from .ollama_client import OllamaClient -from .topic_tree import LocalTopicTree, LocalTopicTreeArguments +from .topic_tree import TopicTree, TopicTreeArguments __version__ = "0.1.0" __all__ = [ - "LocalTopicTree", - "LocalTopicTreeArguments", - "LocalDataEngine", - "LocalEngineArguments", + "TopicTree", + "TopicTreeArguments", + "DataEngine", + "EngineArguments", "Dataset", - "OllamaClient", "HFUploader", ] diff --git a/promptwright/dataset.py b/promptwright/dataset.py index ef97b74..cd6e4f9 100644 --- a/promptwright/dataset.py +++ b/promptwright/dataset.py @@ -3,15 +3,37 @@ class Dataset: - """A class to handle training datasets for local LLM fine-tuning. - - This class manages collections of training samples, providing functionality - to load, save, validate, and manipulate the dataset. + """ + A class to represent a dataset consisting of samples, where each sample contains messages with specific roles. + Methods: + __init__(): + Initialize an empty dataset. + from_jsonl(file_path: str) -> "Dataset": + Create a Dataset instance from a JSONL file. + from_list(sample_list: list[dict]) -> "Dataset": + Create a Dataset instance from a list of samples. + validate_sample(sample: dict) -> bool: + Validate if a sample has the correct format. + add_samples(samples: list[dict]) -> tuple[list[dict], list[str]]: + Add multiple samples to the dataset and return any failures. + remove_linebreaks_and_spaces(input_string: str) -> str: + Clean up a string by removing extra whitespace and normalizing linebreaks. + save(save_path: str): + Save the dataset to a JSONL file. + __len__() -> int: + Get the number of samples in the dataset. + __getitem__(idx: int) -> dict: + Get a sample from the dataset by index. + filter_by_role(role: str) -> list[dict]: + Filter samples to only include messages with a specific role. + get_statistics() -> dict: + Calculate basic statistics about the dataset. 
""" def __init__(self): """Initialize an empty dataset.""" self.samples = [] + self.failed_samples = [] @classmethod def from_jsonl(cls, file_path: str) -> "Dataset": @@ -30,7 +52,7 @@ def from_jsonl(cls, file_path: str) -> "Dataset": if cls.validate_sample(sample): instance.samples.append(sample) else: - print(f"Warning: Invalid sample found and skipped: {sample}") + instance.failed_samples.append(sample) return instance @@ -49,7 +71,7 @@ def from_list(cls, sample_list: list[dict]) -> "Dataset": if cls.validate_sample(sample): instance.samples.append(sample) else: - print(f"Warning: Invalid sample skipped: {sample}") + instance.failed_samples.append(sample) return instance @@ -78,17 +100,27 @@ def validate_sample(sample: dict) -> bool: return True - def add_samples(self, samples: list[dict]): - """Add multiple samples to the dataset. + def add_samples(self, samples: list[dict]) -> tuple[list[dict], list[str]]: + """Add multiple samples to the dataset and return any failures. Args: samples: List of dictionaries containing the samples to add. + + Returns: + tuple: (list of failed samples, list of failure descriptions) """ + failed_samples = [] + failure_descriptions = [] + for sample in samples: if self.validate_sample(sample): self.samples.append(sample) else: - print(f"Warning: Invalid sample, not added: {sample}") + failed_samples.append(sample) + failure_descriptions.append(f"Invalid sample format: {sample}") + self.failed_samples.append(sample) + + return failed_samples, failure_descriptions @staticmethod def remove_linebreaks_and_spaces(input_string: str) -> str: @@ -118,10 +150,6 @@ def save(self, save_path: str): f.write(clean_json + "\n") print(f"Saved dataset to {save_path}") - print("\nYou can now use this dataset for fine-tuning with various platforms:") - print("- Ollama: Use 'ollama create' with the dataset") - print("- LocalAI: Import directly into your local instance") - print("- Other platforms that support JSONL format for fine-tuning") def __len__(self) -> int: """Get the number of samples in the dataset. 
diff --git a/promptwright/engine.py b/promptwright/engine.py index f8517d2..4013af7 100644 --- a/promptwright/engine.py +++ b/promptwright/engine.py @@ -1,224 +1,291 @@ import json +import math import random -import time +import re from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any + +import litellm from tqdm import tqdm from .dataset import Dataset -from .ollama_client import OllamaClient +from .prompts import ENGINE_JSON_INSTRUCTIONS, SAMPLE_GENERATION_PROMPT +from .topic_tree import TopicTree # Handle circular import for type hints if TYPE_CHECKING: - from .topic_tree import LocalTopicTree + from .topic_tree import TopicTree + + +def validate_json_response(json_str: str, schema: dict[str, Any] | None = None) -> dict | None: + """Validate and clean JSON response from LLM.""" + try: + json_match = re.search(r"(?s)\{.*\}", json_str) + if not json_match: + return None + + cleaned_json = json_match.group(0) + cleaned_json = re.sub(r"```json\s*|\s*```", "", cleaned_json) + + parsed = json.loads(cleaned_json) + + if schema is not None: + # Schema validation could be added here + pass + else: + return parsed + except (json.JSONDecodeError, ValueError): + return None @dataclass -class LocalEngineArguments: +class EngineArguments: instructions: str system_prompt: str - model_name: str + model_name: str # Required field prompt_template: str | None = None example_data: Dataset | None = None - ollama_base_url: str = "http://localhost:11434" - temperature: float = 0.7 + temperature: float = 0.2 max_retries: int = 3 default_batch_size: int = 5 default_num_examples: int = 3 - request_timeout: int = 30 # Added timeout parameter - - -class LocalDataEngine: - DEFAULT_PROMPT_TEMPLATE = """Generate a JSON object in this exact format: - { - "messages": [ - { - "role": "user", - "content": "" - }, - { - "role": "assistant", - "content": "" - } - ] - }""" - - def __init__(self, args: LocalEngineArguments): - if not args.model_name: - raise ValueError("model_name unspecified in LocalEngineArguments") # noqa: TRY003 + request_timeout: int = 30 + +class DataEngine: + def __init__(self, args: EngineArguments): + if ( + not args.model_name + or not isinstance(args.model_name, str) + or not args.model_name.strip() + ): + raise ValueError("model_name must be a non-empty string in EngineArguments") # noqa: TRY003 + + self.model_name = args.model_name.strip() # Store model_name as instance variable self.args = args self.dataset = Dataset() - self.llm_client = OllamaClient(base_url=args.ollama_base_url) - self.failed_samples = [] # Track failed attempts + self.failed_samples = [] + self.failure_analysis = { + "json_parsing_errors": [], + "invalid_schema": [], + "api_errors": [], + "empty_responses": [], + "malformed_responses": [], + "other_errors": [], + } + self.args.system_prompt = ENGINE_JSON_INSTRUCTIONS + self.args.system_prompt + + def analyze_failure(self, response_content: str, error: Exception = None) -> str: + """Analyze the failure reason for a sample.""" + if error: + error_str = str(error) + if "schema" in error_str.lower(): + return "invalid_schema" + if any( + api_err in error_str.lower() for api_err in ["timeout", "rate limit", "connection"] + ): + return "api_errors" + return "other_errors" + + if not response_content or response_content.isspace(): + return "empty_responses" + + # Check if response seems to be attempting JSON but failing + if any(char in response_content for char in "{}[]"): + return "json_parsing_errors" + return 
"malformed_responses" + + def summarize_failures(self) -> dict: + """Generate a summary of all failures.""" + summary = { + "total_failures": len(self.failed_samples), + "failure_types": {k: len(v) for k, v in self.failure_analysis.items()}, + "failure_examples": {}, + } + + # Add example failures for each category + for category, failures in self.failure_analysis.items(): + if failures: + # Get up to 3 examples for each category + examples = failures[:3] + summary["failure_examples"][category] = [ + str(ex)[:200] + "..." if len(str(ex)) > 200 else str(ex) # noqa: PLR2004 + for ex in examples + ] + return summary def create_data( # noqa: PLR0912 - self, # noqa: PLR0912 - num_steps: int | None = None, - batch_size: int | None = None, - num_example_demonstrations: int | None = None, - topic_tree: Optional["LocalTopicTree"] = None, - ) -> Dataset: - """Generate training data with improved reliability.""" - batch_size = min(batch_size or self.args.default_batch_size, 5) - - if num_steps is None and topic_tree is None: - raise ValueError("Must specify either num_steps or provide a topic_tree") # noqa: TRY003 + self, + num_steps: int = None, + num_example_demonstrations: int = 3, + batch_size: int = 10, + topic_tree: TopicTree = None, + model_name: str = None, + ): + if num_steps is None: + raise ValueError("num_steps must be specified") # noqa: TRY003 + + # Use instance model_name as fallback if none provided + self.model_name = model_name.strip() if model_name else self.model_name + + if not self.model_name: + raise ValueError("No valid model_name provided") # noqa: TRY003 + + data_creation_prompt = SAMPLE_GENERATION_PROMPT + + tree_paths = None + if topic_tree is not None: + tree_paths = topic_tree.tree_paths + total_paths = len(tree_paths) + required_samples = num_steps * batch_size + + if required_samples > total_paths: + raise ValueError( # noqa: TRY003 + f"Required samples ({required_samples}) exceeds available tree paths ({total_paths})" + ) # noqa: TRY003 + + tree_paths = random.sample(tree_paths, required_samples) + num_steps = math.ceil(len(tree_paths) / batch_size) total_samples = num_steps * batch_size - success_count = 0 - consecutive_failures = 0 - - print("\nStarting generation:") - print(f"Target: {total_samples} samples") - print(f"Model: {self.args.model_name}") - print(f"Batch size: {batch_size}") + print(f"Generating dataset using model {self.model_name}") + print(f"Generating dataset in {num_steps} steps, with batch size {batch_size}") - start_time = time.time() - last_save_time = start_time + # Enable JSON schema validation + litellm.enable_json_schema_validation = True try: - with tqdm(total=total_samples, desc="Generating samples") as pbar: + with tqdm(total=total_samples, desc="Progress") as pbar: for step in range(num_steps): - for batch_item in range(batch_size): - sample_success = False - retries = 0 - - while not sample_success and retries < self.args.max_retries: - try: - # Generate prompt - prompt = self.build_prompt( - num_example_demonstrations=(num_example_demonstrations or 0), - subtopics_list=None, - ) - - # Get model response with timeout - response = self.llm_client.generate_completion( - prompt=prompt, - model=self.args.model_name, - system_prompt=self.args.system_prompt, - temperature=self.args.temperature, - ) - - # Parse and validate - sample = json.loads(response.content) - if self._validate_sample(sample): - self.dataset.add_samples([sample]) - success_count += 1 - consecutive_failures = 0 - sample_success = True - pbar.update(1) + prompts = [] + 
start_idx = step * batch_size + + for i in range(batch_size): + path = None + if tree_paths: + current_idx = start_idx + i + if current_idx < len(tree_paths): + path = tree_paths[current_idx] + else: + break + + sample_prompt = self.build_prompt( + data_creation_prompt=data_creation_prompt, + num_example_demonstrations=num_example_demonstrations, + subtopics_list=path, + ) + prompts.append(sample_prompt) + + for attempt in range(self.args.max_retries): + try: + responses = litellm.batch_completion( + model=self.model_name, + messages=[[{"role": "user", "content": p}] for p in prompts], + temperature=self.args.temperature, + ) + + samples = [] + for r in responses: + response_content = r.choices[0].message.content + parsed_json = validate_json_response(response_content) + + if parsed_json: + samples.append(parsed_json) else: - retries += 1 + self.failed_samples.append(response_content) + failure_type = self.analyze_failure(response_content) + self.failure_analysis[failure_type].append(response_content) - except Exception as e: - retries += 1 - consecutive_failures += 1 - self.failed_samples.append( - {"step": step, "batch_item": batch_item, "error": str(e)} + if samples: + failed_samples, failure_descriptions = self.dataset.add_samples( + samples ) - - if consecutive_failures >= 5: # noqa: PLR2004 - print( - f"\nToo many consecutive failures ({consecutive_failures}). Saving progress..." - ) - self.save_dataset( - f"emergency_save_{success_count}_samples.jsonl" - ) - print("Consider:") - print("1. Checking Ollama status") - print("2. Restarting Ollama") - print("3. Using a different model") - return self.dataset - - # Save progress every 5 minutes or 10 successful samples - current_time = time.time() - if (current_time - last_save_time) > 300 or success_count % 10 == 0: # noqa: PLR2004 - last_save_time = current_time - - # Show progress statistics - elapsed = current_time - start_time - rate = success_count / elapsed - eta = (total_samples - success_count) / rate if rate > 0 else 0 - - print("\nProgress update:") - print(f"Samples generated: {success_count}/{total_samples}") - print(f"Generation rate: {rate:.2f} samples/second") - print(f"Estimated time remaining: {eta/60:.1f} minutes") + if failed_samples: + for sample, desc in zip( + failed_samples, failure_descriptions, strict=True + ): + self.failed_samples.append(sample) + self.failure_analysis["invalid_schema"].append(desc) + pbar.update(len(samples) - len(failed_samples)) + break # Success - exit retry loop + + except Exception as e: + if attempt == self.args.max_retries - 1: + print(f"Failed after {self.args.max_retries} attempts: {str(e)}") + self.failed_samples.append(str(e)) + failure_type = self.analyze_failure(str(e), error=e) + self.failure_analysis[failure_type].append(str(e)) + else: + print(f"Attempt {attempt + 1} failed: {str(e)}") except KeyboardInterrupt: print("\nGeneration interrupted by user.") + self.print_failure_summary() self.save_dataset("interrupted_dataset.jsonl") + return self.dataset except Exception as e: print(f"\nUnexpected error: {str(e)}") + self.print_failure_summary() self.save_dataset("error_dataset.jsonl") raise - finally: - # Save failure log if there were any failures - if self.failed_samples: - with open("generation_failures.json", "w") as f: - json.dump(self.failed_samples, f, indent=2) - - total_duration = time.time() - start_time - print("\nGeneration complete:") - print(f"Total samples: {success_count}/{total_samples}") - print(f"Success rate: {(success_count/total_samples)*100:.1f}%") - 
print(f"Total time: {total_duration/60:.1f} minutes") - print(f"Average speed: {success_count/total_duration:.2f} samples/second") - + print(f"Successfully Generated {len(self.dataset)} samples.") + self.print_failure_summary() return self.dataset + def print_failure_summary(self): + """Print a detailed summary of all failures.""" + summary = self.summarize_failures() + + print("\n=== Failure Analysis Summary ===") + print(f"Total Failed Samples: {summary['total_failures']}") + print("\nFailure Types Breakdown:") + for failure_type, count in summary["failure_types"].items(): + if count > 0: + print(f"\n{failure_type.replace('_', ' ').title()}: {count}") + if failure_type in summary["failure_examples"]: + print("Example failures:") + for i, example in enumerate(summary["failure_examples"][failure_type], 1): + print(f" {i}. {example}") + print("\n=============================") + def build_prompt( - self, num_example_demonstrations: int, subtopics_list: list[str] = None + self, + data_creation_prompt: str, + num_example_demonstrations: int, + subtopics_list: list[str] = None, ) -> str: - """Build a minimal, focused prompt.""" - components = [] - base_prompt = ( - self.args.prompt_template if self.args.prompt_template else self.DEFAULT_PROMPT_TEMPLATE + prompt = data_creation_prompt.replace("{{{{system_prompt}}}}", self.build_system_prompt()) + prompt = prompt.replace("{{{{instructions}}}}", self.build_custom_instructions_text()) + prompt = prompt.replace( + "{{{{examples}}}}", self.build_examples_text(num_example_demonstrations) ) - components.append(base_prompt) + return prompt.replace("{{{{subtopics}}}}", self.build_subtopics_text(subtopics_list)) - if self.args.instructions: - components.append(f"\nRequirements: {self.args.instructions}") + def build_system_prompt(self): + return self.args.system_prompt - if subtopics_list: - components.append(f"\nTopic: {' -> '.join(subtopics_list)}") + def build_custom_instructions_text(self) -> str: + if self.args.instructions is None: + return "" + return f"\nHere are additional instructions:\n\n{self.args.instructions}\n\n" - # Only add examples if specifically requested - if self.args.example_data and num_example_demonstrations > 0: - components.append("\nExamples:") - examples = random.sample( - self.args.example_data.samples, - min(num_example_demonstrations, len(self.args.example_data.samples)), - ) - for ex in examples: - components.append(json.dumps(ex)) + def build_examples_text(self, num_example_demonstrations: int): + if self.args.example_data is None or num_example_demonstrations == 0: + return "" - return "\n".join(components) + examples = random.sample(self.args.example_data.samples, num_example_demonstrations) + examples_text = "Here are output examples:\n\n" + examples_text += "\n".join(f"Example {i+1}: \n\n{ex}\n" for i, ex in enumerate(examples)) + return f"\nHere are output examples:\n\n{examples_text}\n\n" - def _validate_sample(self, sample: dict) -> bool: # noqa: PLR0911 - """Validate sample format.""" - try: - if "messages" not in sample: - return False - - for msg in sample["messages"]: - if not all(key in msg for key in ["role", "content"]): - return False - if msg["role"] not in ["user", "assistant", "system"]: - return False - if not isinstance(msg["content"], str): - return False - if not msg["content"].strip(): - return False - - return True # noqa: TRY300 - except Exception: - return False + def build_subtopics_text(self, subtopic_list: list[str]): + if subtopic_list is None: + return "" + return f"\nLastly, the topic of 
the training data should be related to the following subtopics: {' -> '.join(subtopic_list)}" def save_dataset(self, save_path: str): """Save the dataset to a file.""" diff --git a/promptwright/hf_hub.py b/promptwright/hf_hub.py index 5b028be..00be43e 100644 --- a/promptwright/hf_hub.py +++ b/promptwright/hf_hub.py @@ -1,5 +1,3 @@ -# uploader.py - from datasets import load_dataset from huggingface_hub import login from huggingface_hub.utils import HfHubHTTPError, RepositoryNotFoundError @@ -7,7 +5,25 @@ class HFUploader: """ - A class to handle uploading datasets in JSONL format to the Hugging Face Hub. + HFUploader is a class for uploading datasets to the Hugging Face Hub. + + Methods + ------- + __init__(hf_token) + + push_to_hub(hf_dataset_repo, jsonl_file_path) + + Parameters + ---------- + hf_dataset_repo : str + The repository name in the format 'username/dataset_name'. + jsonl_file_path : str + Path to the JSONL file. + + Returns + ------- + dict + A dictionary containing the status and a message. """ def __init__(self, hf_token): @@ -31,13 +47,8 @@ def push_to_hub(self, hf_dataset_repo, jsonl_file_path): dict: A dictionary containing the status and a message. """ try: - # Login to Hugging Face Hub login(token=self.hf_token) - - # Load the dataset from the JSONL file dataset = load_dataset("json", data_files={"train": jsonl_file_path}) - - # Push the dataset to the Hugging Face Hub dataset.push_to_hub(hf_dataset_repo, token=self.hf_token) except RepositoryNotFoundError: diff --git a/promptwright/ollama_client.py b/promptwright/ollama_client.py deleted file mode 100644 index ca1d72b..0000000 --- a/promptwright/ollama_client.py +++ /dev/null @@ -1,114 +0,0 @@ -import json - -from dataclasses import dataclass - -import requests - - -@dataclass -class LLMResponse: - content: str - total_duration: int - prompt_eval_count: int - eval_count: int - - -class OllamaClient: - def __init__(self, base_url: str = "http://localhost:11434"): - self.base_url = base_url.rstrip("/") - - def generate_completion( - self, - prompt: str, - model: str = "llama2", - system_prompt: str | None = None, - temperature: float = 0.7, - ) -> LLMResponse: - """Generate completion using the Ollama API.""" - - data = { - "model": model, - "prompt": prompt, - "stream": False, # Disable streaming for simpler handling - "format": "json", # Request JSON format - "options": { - "temperature": temperature, - "num_predict": 1000, - "stop": ["\n\n", "```"], # Stop tokens to prevent extra content - }, - } - - if system_prompt: - data["system"] = system_prompt - - url = f"{self.base_url}/api/generate" - - try: - print("\nSending request to Ollama...") - response = requests.post(url, json=data, timeout=30) - response.raise_for_status() - - result = response.json() - - # Debug output - print(f"Raw response: {result.get('response', '')[:500]}...") - - if not result.get("response"): - raise ValueError("Empty response from model") # noqa: TRY003 - - # Try to parse the response as JSON - try: - json_content = json.loads(result["response"]) - # If successful, convert back to string with proper formatting - content = json.dumps(json_content) - except json.JSONDecodeError: - # If not valid JSON, try to extract JSON from the response - content = self._extract_json(result["response"]) - - return LLMResponse( - content=content, - total_duration=result.get("total_duration", 0), - prompt_eval_count=result.get("prompt_eval_count", 0), - eval_count=result.get("eval_count", 0), - ) - - except requests.exceptions.Timeout: - raise 
TimeoutError("Request to Ollama timed out") # noqa: B904, TRY003 - except requests.exceptions.RequestException as e: - raise Exception(f"Request failed: {str(e)}") # noqa: B904, TRY002, TRY003 - - def _extract_json(self, text: str) -> str: - """Extract JSON object from text.""" - try: - # Find the first opening brace - start = text.find("{") - if start == -1: - raise ValueError("No JSON object found in response") # noqa: TRY301, TRY003 - - # Keep track of braces - count = 0 - for i, char in enumerate(text[start:]): - if char == "{": - count += 1 - elif char == "}": - count -= 1 - if count == 0: - # Found complete JSON object - json_str = text[start : start + i + 1] - # Validate it's proper JSON - json.loads(json_str) - return json_str - - raise ValueError("No complete JSON object found") # noqa: TRY301, TRY003 - - except Exception as e: - raise ValueError(f"Failed to extract JSON: {str(e)}") # noqa: B904, TRY002, TRY003 - - def list_local_models(self) -> list[dict]: - """List available models.""" - try: - response = requests.get(f"{self.base_url}/api/tags", timeout=5) - response.raise_for_status() - return response.json().get("models", []) - except Exception as e: - raise Exception(f"Failed to list models: {str(e)}") # noqa: B904, TRY002, TRY003 diff --git a/promptwright/prompts.py b/promptwright/prompts.py new file mode 100644 index 0000000..22f6950 --- /dev/null +++ b/promptwright/prompts.py @@ -0,0 +1,163 @@ +SAMPLE_GENERATION_PROMPT = """I want to train a large language model and you should help me generate training data for it. Here is the system prompt of the model that tells it what it should be able to do: + + +{{{{system_prompt}}}} + + +You should now generate three training samples for the model. Each training sample should consist of a JSON object with the field "messages", which is a list of messages alternating between user and assistant roles. The first message must always be from the user, and the last one from the assistant. Depending on the use case of the system prompt, there may be multiple user and assistant messages. The format for each training sample must strictly follow this format: + +{ + "messages": [ + { + "role": "user", + "content": "" + }, + { + "role": "assistant", + "content": "" + } + ] +} + +It is crucial that you respond only with valid JSON. Do not include any introductions, explanations, summaries, or additional text that is not part of the JSON object. Any non-JSON content will be considered incorrect. If you encounter issues generating valid JSON, please retry or provide a default response. + +Here are additional inputs to guide you: + +{{{{instructions}}}} +{{{{examples}}}} +{{{{subtopics}}}} + +Now, generate a single training sample in the JSON format specified above. Respond only with valid JSON.""" + +TREE_GENERATION_PROMPT = """I want to train a large language model and I am using another, bigger large language model to generate training data for this. However, if we always ask the bigger model to generate training data with the same prompt, it will end up generating very repetitive training samples. Therefore, we will slightly modify our prompt for each sampling procedure according to some aspects. For instance, when asking the model to generate news articles, we could modify the prompt to let the model tell news articles about particular topics, such as business or politics. To further generate training data, we will do this recursively, and generate submodifications to the prompt. 
For instance, within the domain of business, we could adapt the prompt to generate news about the stock market or business scandals, and within politics, we could ask the model to generate articles for subtopics like elections or climate policy. We do this recursively, and therefore, we get a tree-like structure of topics. +Your job is the following: I will give you a path of nodes down the topic tree - you should then come up with a list of new subtopics for this given node and return it as a python list. Here are a few examples of what your outputs should look like, related to the news example I just gave you: + +Example 1: +node path: "News Topics" -> "Sports" -> "Football" +desired number of subtopics: 5 +subtopics: ["college football", "football stadiums", "health consequences football", "Seattle Seahawks", "football sponsorships"] + + +Example 2: +node path: "News Topics" -> "Entertainment" -> "Movies" -> "Star Portraits" +desired number of subtopics: 8 +subtopics: ["Tom Hanks", "Meryl Streep", "Leonardo DiCaprio", "Jennifer Lawrence", "Denzel Washington", "Charlize Theron", "Robert Downey Jr.", "Emma Stone"] + + +Here are three new examples, this time for generating smalltalk topics for a friendly chat assistant: + +Example 1: +node path: "Small Talk Topics" +desired number of subtopics: 7 +subtopics: ["weather", "weekend plans", "hobbies", "family", "books", "food", "music"] + +Example 2: +node path: "Small Talk Topics" -> "Family" +desired number of subtopics: 5 +subtopics: ["parents", "grandparents", "siblings", "family traditions", "family vacations"] + +Example 3: +node path: "Small Talk Topics" -> "Hobbies" -> "Cooking" +desired number of subtopics: 6 +subtopics: ["recipes", "asian food", "favourite dishes", "cookbooks", "kitchen gadgets", "vegan cooking"] + + +Here is a description / the system prompt for the model we want to train: + + +{{{{system_prompt}}}} + + + +Here is your topic input. When generating subtopics, remain somewhat vague. Things can only be tangentially related and they don't have to be interpreted in a single way. Importantly, make sure that the subtopics fit the system prompt, if one was supplied: +node path: {{{{subtopics_list}}}} +desired number of subtopics: {{{{num_subtopics}}}} + +Now return the subtopics as a python list, and return it in just one line, not multiple ones. Don't return anything else.""" + +TREE_JSON_INSTRUCTIONS = """When listing subtopics, format your response as a valid JSON array of strings. +Example: ["topic 1", "topic 2", "topic 3"] +1. Use double quotes for strings +2. Use square brackets for the array +3. Separate items with commas +4. Do not include any text before or after the JSON array +5. Ensure all JSON syntax is valid +""" + +OLD_ENGINE_JSON_INSTRUCTIONS = """Your response **must be valid JSON** that can be parsed by `json.loads()`. Follow these rules precisely: + +1. **Double Quotes Only**: Use double quotes (`"`) around all string values, including keys. +2. **No Extra Text**: Do not include any text before or after the JSON block. Ensure the output is **only JSON**. +3. **Valid Syntax**: Check that all JSON syntax is correct: + - Every key-value pair should be separated by a colon. + - Separate each item in an array or object with a comma, except for the last item. +4. **No Trailing Commas**: Ensure there are no trailing commas in arrays or objects. +5. **Number Formatting**: Ensure numbers are formatted correctly (e.g., no leading zeroes unless the number is decimal). +6. 
**Boolean & Null Values**: Use lowercase `true`, `false`, and `null` as valid JSON values. +7. **Final Validation**: Your response will be parsed as JSON. Any syntax errors will cause a failure, so check carefully. + +**Important**: The entire response must be **valid JSON**, with no explanations, comments, or text outside of the JSON structure. +""" + +ENGINE_JSON_INSTRUCTIONS = """You are an expert JSON builder designed to assist with a wide range of tasks. + +Your response **must be valid JSON** that can be parsed by `json.loads()`. Follow these rules precisely: + +1. **Double Quotes Only**: Use double quotes (`"`) around all string values, including keys. +2. **No Extra Text**: Do not include any text before or after the JSON block. Ensure the output is **only JSON**. +3. **Valid Syntax**: Check that all JSON syntax is correct: + - Every key-value pair should be separated by a colon. + - Separate each item in an array or object with a comma, except for the last item. +4. **No Trailing Commas**: Ensure there are no trailing commas in arrays or objects. +5. **Number Formatting**: Ensure numbers are formatted correctly (e.g., no leading zeroes unless the number is decimal). +6. **Boolean & Null Values**: Use lowercase `true`, `false`, and `null` as valid JSON values. +7. **Final Validation**: Your response will be parsed as JSON. Any syntax errors will cause a failure, so check carefully. + +**Important**: The entire response must be **valid JSON**, with no explanations, comments, or text outside of the JSON structure. + +**JSON Structure**: +```json +{ + "messages": [ + { + "role": "user", + "content": "" + }, + { + "role": "assistant", + "content": "" + } + ] +} +``` + +**JSON Examples**: +```json +{ + "messages": [ + { + "role": "user", + "content": "Hey, how are you today?" + }, + { + "role": "assistant", + "content": "I'm good thanks, how are you?" + } + ] +}, +{ + "messages": [ + { + "role": "user", + "content": "What color is the sky?" + }, + { + "role": "assistant", + "content": "The sky is blue." + } + ] +} +``` + +All of Assistant's communication is performed using this JSON format. 
+""" diff --git a/promptwright/topic_tree.py b/promptwright/topic_tree.py index bf1c0a7..19df16a 100644 --- a/promptwright/topic_tree.py +++ b/promptwright/topic_tree.py @@ -1,169 +1,281 @@ -import ast import json +import re +import time +import warnings from dataclasses import dataclass +from typing import Any -from .ollama_client import OllamaClient +import litellm + +from .prompts import TREE_GENERATION_PROMPT, TREE_JSON_INSTRUCTIONS +from .utils import extract_list + +warnings.filterwarnings("ignore", message="Pydantic serializer warnings:.*") + + +def validate_and_clean_response(response_text: str) -> str | list[str] | None: + """Clean and validate the response from the LLM.""" + try: + # First try to extract a JSON array if present + json_match = re.search(r"\[.*\]", response_text, re.DOTALL) + if json_match: + cleaned_json = json_match.group(0) + # Remove any markdown code block markers + cleaned_json = re.sub(r"```json\s*|\s*```", "", cleaned_json) + return json.loads(cleaned_json) + + # If no JSON array found, fall back to extract_list + topics = extract_list(response_text) + if topics: + return [topic.strip() for topic in topics if topic.strip()] + return None # noqa: TRY300 + except (json.JSONDecodeError, ValueError) as e: + print(f"Error parsing response: {str(e)}") + return None @dataclass -class LocalTopicTreeArguments: +class TopicTreeArguments: + """ + A class to represent the arguments for constructing a topic tree. + + Attributes: + root_prompt (str): The initial prompt to start the topic tree. + model_system_prompt (str): The system prompt for the model. + tree_degree (int): The branching factor of the tree. + tree_depth (int): The depth of the tree. + model_name (str): The name of the model to be used. + """ + root_prompt: str - model_system_prompt: str = None + model_system_prompt: str = "" tree_degree: int = 10 tree_depth: int = 3 - ollama_base_url: str = "http://localhost:11434" - model_name: str = "llama3.2" + model_name: str = "ollama/llama3" + temperature: float = 0.2 + + +class TopicTreeValidator: + """ + TopicTreeValidator validates and calculates unique paths in a tree structure. 
+ """ + + def __init__(self, tree_degree: int, tree_depth: int): + self.tree_degree = tree_degree + self.tree_depth = tree_depth + + def calculate_paths(self) -> int: + """Calculate total number of paths in the tree.""" + return self.tree_degree**self.tree_depth + + def validate_configuration(self, num_steps: int, batch_size: int) -> dict[str, Any]: + """Validates tree configuration and provides recommendations if invalid.""" + total_requested_paths = num_steps * batch_size + total_tree_paths = self.calculate_paths() + + print(f"Total tree paths available: {total_tree_paths}") + print(f"Total requested paths: {total_requested_paths}") + + if total_requested_paths > total_tree_paths: + print("Warning: The requested paths exceed the available tree paths.") + recommendation = { + "valid": False, + "suggested_num_steps": total_tree_paths // batch_size, + "suggested_batch_size": total_tree_paths // num_steps, + "total_tree_paths": total_tree_paths, + "total_requested_paths": total_requested_paths, + } + print("Recommended configurations to fit within the tree paths:") + print(f" - Reduce num_steps to: {recommendation['suggested_num_steps']} or") + print(f" - Reduce batch_size to: {recommendation['suggested_batch_size']} or") + print(" - Increase tree_depth or tree_degree to provide more paths.") + return recommendation + + return { + "valid": True, + "total_tree_paths": total_tree_paths, + "total_requested_paths": total_requested_paths, + } + + +class TopicTree: + """A class to represent and build a hierarchical topic tree.""" + def __init__(self, args: TopicTreeArguments): + """Initialize the TopicTree with the given arguments.""" + if not args.model_name: + raise ValueError("model_name must be specified in TopicTreeArguments") # noqa: TRY003 + json_instructions = TREE_JSON_INSTRUCTIONS -class LocalTopicTree: - def __init__(self, args: LocalTopicTreeArguments): self.args = args + self.system_prompt = json_instructions + args.model_system_prompt + self.temperature = args.temperature + self.model_name = args.model_name + self.tree_degree = args.tree_degree + self.tree_depth = args.tree_depth self.tree_paths = [] - self.llm_client = OllamaClient(base_url=args.ollama_base_url) + self.failed_generations = [] + + def build_tree(self, model_name: str = None) -> None: + """Build the complete topic tree.""" + if model_name: + self.model_name = model_name + + print(f"Building the topic tree with model: {self.model_name}") - def _extract_list_from_response(self, response: str) -> list[str]: - """Extract a Python list from the response text, with multiple fallback methods.""" - # First, try to find a list in JSON format - try: - # Try to parse as JSON first - data = json.loads(response) - if isinstance(data, list): - return data - # If it's a JSON object, look for a list value - for value in data.values(): - if isinstance(value, list): - return value - except json.JSONDecodeError: - pass - - # Second, try to find and parse a Python list literal try: - # Find content between square brackets - start = response.find("[") - end = response.rfind("]") - if start != -1 and end != -1: - list_str = response[start : end + 1] - return ast.literal_eval(list_str) - except (SyntaxError, ValueError): - pass - - # Third, try to split by commas if the response looks like a comma-separated list - if "," in response and "[" not in response and "{" not in response: + self.tree_paths = self.build_subtree( + [self.args.root_prompt], + self.system_prompt, + self.args.tree_degree, + self.args.tree_depth, + 
model_name=self.model_name, + ) + + print(f"Tree building complete. Generated {len(self.tree_paths)} paths.") + if self.failed_generations: + print(f"Warning: {len(self.failed_generations)} subtopic generations failed.") + + except Exception as e: + print(f"Error building tree: {str(e)}") + if self.tree_paths: + print("Saving partial tree...") + self.save("partial_tree.jsonl") + raise + + def get_subtopics( + self, system_prompt: str, node_path: list[str], num_subtopics: int + ) -> list[str]: + """Generate subtopics with improved error handling and validation.""" + print(f"Generating {num_subtopics} subtopics for: {' -> '.join(node_path)}") + + prompt = TREE_GENERATION_PROMPT + prompt = prompt.replace("{{{{system_prompt}}}}", system_prompt if system_prompt else "") + prompt = prompt.replace("{{{{subtopics_list}}}}", " -> ".join(node_path)) + prompt = prompt.replace("{{{{num_subtopics}}}}", str(num_subtopics)) + + max_retries = 3 + retries = 0 + last_error = "No error recorded" + + while retries < max_retries: try: - items = [item.strip().strip("\"'") for item in response.split(",")] - if items: - return items + response = litellm.completion( + model=self.model_name, + max_tokens=1000, + temperature=self.temperature, + base_url="http://localhost:11434", + messages=[{"role": "user", "content": prompt}], + ) + + subtopics = validate_and_clean_response(response.choices[0].message.content) + + if subtopics and len(subtopics) > 0: + # Validate and clean each subtopic + cleaned_subtopics = [] + for topic in subtopics: + if isinstance(topic, str): + # Keep more special characters but ensure JSON safety + cleaned_topic = topic.strip() + if cleaned_topic: + cleaned_subtopics.append(cleaned_topic) + + if len(cleaned_subtopics) >= num_subtopics: + return cleaned_subtopics[:num_subtopics] + + last_error = "Insufficient valid subtopics generated" + print(f"Attempt {retries + 1}: {last_error}. Retrying...") + except Exception as e: - print(f"Error parsing comma-separated list: {str(e)}") + last_error = str(e) + print( + f"Error generating subtopics (attempt {retries + 1}/{max_retries}): {last_error}" + ) - raise ValueError(f"Could not extract list from response: {response}") # noqa: TRY003 + retries += 1 + if retries < max_retries: + time.sleep(2**retries) # Exponential backoff - def build_tree(self): - """Build the topic tree.""" - print( - f"\nBuilding topic tree with degree {self.args.tree_degree} and depth {self.args.tree_depth}" + # If all retries failed, generate default subtopics and log the failure + default_subtopics = [f"subtopic_{i+1}_for_{node_path[-1]}" for i in range(num_subtopics)] + self.failed_generations.append( + {"path": node_path, "attempts": retries, "last_error": last_error} ) - self.tree_paths = self.build_subtree( - [self.args.root_prompt], - self.args.model_system_prompt, - self.args.tree_degree, - self.args.tree_depth, + print( + f"Failed to generate valid subtopics after {max_retries} attempts. Using default subtopics." ) - print(f"Tree building complete. 
Generated {len(self.tree_paths)} paths.") + return default_subtopics def build_subtree( - self, node_path: list[str], system_prompt: str, tree_degree: int, subtree_depth: int + self, + node_path: list[str], + system_prompt: str, + tree_degree: int, + subtree_depth: int, + model_name: str, ) -> list[list[str]]: - """Build a subtree recursively.""" - print(f"Building subtree for path: {' -> '.join(node_path)}") + """Build a subtree with improved error handling and validation.""" + # Convert any non-string elements to strings + node_path = [str(node) if not isinstance(node, str) else node for node in node_path] + print(f"Building topic subtree: {' -> '.join(node_path)}") if subtree_depth == 0: return [node_path] - try: - subnodes = self.get_subtopics( - system_prompt=system_prompt, node_path=node_path, num_subtopics=tree_degree - ) + subnodes = self.get_subtopics(system_prompt, node_path, tree_degree) - if not subnodes: - print(f"Warning: No subtopics generated for path: {' -> '.join(node_path)}") - return [node_path] - - updated_node_paths = [node_path + [sub] for sub in subnodes] - result = [] + # Clean and validate subnodes + cleaned_subnodes = [] + for subnode in subnodes: + try: + if isinstance(subnode, dict | list): + cleaned_subnodes.append(json.dumps(subnode)) + else: + cleaned_subnodes.append(str(subnode)) + except Exception as e: + print(f"Error cleaning subnode: {str(e)}") + continue - for path in updated_node_paths: + result = [] + for subnode in cleaned_subnodes: + try: + new_path = node_path + [subnode] result.extend( - self.build_subtree(path, system_prompt, tree_degree, subtree_depth - 1) + self.build_subtree( + new_path, system_prompt, tree_degree, subtree_depth - 1, model_name + ) ) - return result # noqa: TRY300 - - except Exception as e: - print(f"Error building subtree for path {' -> '.join(node_path)}: {str(e)}") - return [node_path] - - def get_subtopics( - self, system_prompt: str, node_path: list[str], num_subtopics: int - ) -> list[str]: - """Get subtopics for a given node.""" - prompt = f"""Generate exactly {num_subtopics} subtopics about: {' -> '.join(node_path)} - -Requirements: -1. Return ONLY a Python list of strings -2. Each subtopic should be short and focused -3. No explanations or additional text -4. No numbered bullets or formatting -5. 
No nested lists or dictionaries - -Example output format: -["Subtopic 1", "Subtopic 2", "Subtopic 3"] + except Exception as e: + print(f"Error building subtree for {subnode}: {str(e)}") + continue -Generate {num_subtopics} subtopics:""" + return result + def save(self, save_path: str) -> None: + """Save the topic tree to a file.""" try: - response = self.llm_client.generate_completion( - prompt=prompt, - model=self.args.model_name, - system_prompt=system_prompt - or "You are a helpful assistant that generates lists of subtopics.", - temperature=0.7, - ) - - subtopics = self._extract_list_from_response(response.content) - - # Validate and clean the subtopics - cleaned_subtopics = [] - for topic in subtopics: - if isinstance(topic, str): - # Remove any quotes, brackets, or list formatting - cleaned = topic.strip(" []\"'") - if cleaned: - cleaned_subtopics.append(cleaned) + with open(save_path, "w") as f: + for path in self.tree_paths: + f.write(json.dumps({"path": path}) + "\n") - # Ensure we have the right number of subtopics - if len(cleaned_subtopics) < num_subtopics: - print( - f"Warning: Only generated {len(cleaned_subtopics)} subtopics instead of {num_subtopics}" - ) + # Save failed generations if any + if self.failed_generations: + failed_path = save_path.replace(".jsonl", "_failed.jsonl") + with open(failed_path, "w") as f: + for failure in self.failed_generations: + f.write(json.dumps(failure) + "\n") + print(f"Failed generations saved to {failed_path}") - return cleaned_subtopics[:num_subtopics] + print(f"Topic tree saved to {save_path}") + print(f"Total paths: {len(self.tree_paths)}") except Exception as e: - print(f"Error generating subtopics: {str(e)}") - print(f"Response content: {getattr(response, 'content', 'No content available')}") + print(f"Error saving topic tree: {str(e)}") raise - def save(self, save_path: str): - """Save the topic tree to a file.""" - with open(save_path, "w") as f: - for path in self.tree_paths: - f.write(json.dumps({"path": path}) + "\n") - print(f"\nTopic tree saved to {save_path}") - print(f"Total paths: {len(self.tree_paths)}") - - def print_tree(self): + def print_tree(self) -> None: """Print the topic tree in a readable format.""" - print("\nTopic Tree Structure:") + print("Topic Tree Structure:") for path in self.tree_paths: print(" -> ".join(path)) diff --git a/promptwright/utils.py b/promptwright/utils.py new file mode 100644 index 0000000..e0bb0ff --- /dev/null +++ b/promptwright/utils.py @@ -0,0 +1,98 @@ +import ast +import json +import re + + +def extract_list(input_string: str): + """ + Extracts a Python list from a given input string. + + This function attempts to parse the input string as JSON. If that fails, + it searches for the first Python list within the string by identifying + the opening and closing brackets. If a list is found, it is evaluated + safely to ensure it is a valid Python list. + + Args: + input_string (str): The input string potentially containing a Python list. + + Returns: + list: The extracted Python list if found and valid, otherwise an empty list. + + Raises: + None: This function handles its own exceptions and does not raise any. 
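+
+    Example:
+        >>> extract_list('["recipes", "asian food", "vegan cooking"]')
+        ['recipes', 'asian food', 'vegan cooking']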
+ """ + try: + return json.loads(input_string) + except json.JSONDecodeError: + print("Failed to parse the input string as JSON.") + + start = input_string.find("[") + if start == -1: + print("No Python list found in the input string.") + return [] + + count = 0 + for i, char in enumerate(input_string[start:]): + if char == "[": + count += 1 + elif char == "]": + count -= 1 + if count == 0: + end = i + start + 1 + break + else: + print("No matching closing bracket found.") + return [] + + found_list_str = input_string[start:end] + found_list = safe_literal_eval(found_list_str) + if found_list is None: + print("Failed to parse the list due to syntax issues.") + return [] + + return found_list + + +def remove_linebreaks_and_spaces(input_string): + """ + Remove line breaks and extra spaces from the input string. + + This function replaces all whitespace characters (including line breaks) + with a single space and then ensures that there are no consecutive spaces + in the resulting string. + + Args: + input_string (str): The string from which to remove line breaks and extra spaces. + + Returns: + str: The processed string with line breaks and extra spaces removed. + """ + no_linebreaks = re.sub(r"\s+", " ", input_string) + return " ".join(no_linebreaks.split()) + + +def safe_literal_eval(list_string: str): + """ + Safely evaluate a string containing a Python literal expression. + + This function attempts to evaluate a string containing a Python literal + expression using `ast.literal_eval`. If a `SyntaxError` or `ValueError` + occurs, it tries to sanitize the string by replacing problematic apostrophes + with the actual right single quote character and attempts the evaluation again. + + Args: + list_string (str): The string to be evaluated. + + Returns: + The result of the evaluated string if successful, otherwise `None`. 
+ """ + try: + return ast.literal_eval(list_string) + except (SyntaxError, ValueError): + # Replace problematic apostrophes with the actual right single quote character + sanitized_string = re.sub(r"(\w)'(\w)", r"\1’\2", list_string) + try: + return ast.literal_eval(sanitized_string) + except (SyntaxError, ValueError): + print("Failed to parse the list due to syntax issues.") + return None diff --git a/requirements.txt b/requirements.txt index 6447806..f947b8c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,9 @@ -certifi==2024.8.30 +certifi==2023.11.17 # Changed to be compatible with litellm's requirement charset-normalizer==3.4.0 idna==3.10 requests==2.32.3 tqdm==4.66.5 urllib3==2.2.3 huggingface-hub==0.26.0 -datasets==3.0.2 \ No newline at end of file +datasets==3.0.2 +litellm==1.7.12 diff --git a/setup.py b/setup.py index 46527b7..bf25937 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ version="0.1.5", packages=find_packages(), install_requires=[ - "certifi==2024.8.30", + "certifi>=2023.7.22,<2024.0.0", # Updated to be compatible with litellm "charset-normalizer==3.4.0", "idna==3.10", "requests==2.32.3", @@ -16,6 +16,7 @@ "urllib3==2.2.3", "huggingface-hub==0.26.0", "datasets==3.0.2", + "litellm==1.7.12", ], extras_require={ "dev": [ diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index e596fe3..0000000 --- a/tests/conftest.py +++ /dev/null @@ -1,90 +0,0 @@ -import time - -from collections.abc import Generator - -import pytest -import requests - - -def is_ollama_ready(base_url: str = "http://localhost:11434", timeout: float = 1.0) -> bool: - """Check if Ollama server is responsive.""" - try: - response = requests.get(f"{base_url}/api/tags", timeout=timeout) - return response.status_code == 200 # noqa: TRY300, PLR2004 - except requests.RequestException: - return False - - -def check_model_available(model_name: str, base_url: str = "http://localhost:11434") -> bool: - """Check if specific model is available.""" - try: - response = requests.get(f"{base_url}/api/tags") # noqa: S113 - models = [model["name"] for model in response.json().get("models", [])] - return model_name in models # noqa: TRY300 - except requests.RequestException: - return False - - -@pytest.fixture(scope="session") -def ensure_ollama() -> Generator[None, None, None]: - """Ensure Ollama is running and the required model is available.""" - # Check if Ollama is running - max_retries = 5 - retry_delay = 2 - - print("\nChecking Ollama availability...") - for attempt in range(max_retries): - if is_ollama_ready(): - break - if attempt < max_retries - 1: - print(f"Ollama not ready, retrying in {retry_delay} seconds...") - time.sleep(retry_delay) - else: - pytest.skip("Ollama server is not available") - - # Check if required model is available - model_name = "llama3:latest" - if not check_model_available(model_name): - pytest.skip(f"{model_name} model is not available. 
Please run 'ollama pull {model_name}'") - - yield - - -@pytest.fixture(scope="session") -def model_name() -> str: - """Provide the model name for tests.""" - return "llama3:latest" - - -@pytest.fixture(scope="session") -def ollama_base_url() -> str: - """Provide the base URL for Ollama service.""" - return "http://localhost:11434" - - -# Add mock fixtures for unit tests -@pytest.fixture -def mock_ollama_response(): - """Mock Ollama API response.""" - return { - "model": "llama3:latest", - "response": '{"messages": [{"role": "user", "content": "test"}]}', - "total_duration": 1000000, - "prompt_eval_count": 10, - "eval_count": 20, - } - - -@pytest.fixture -def mock_ollama_client(mock_ollama_response): - """Mock OllamaClient.""" - from unittest.mock import Mock - - client = Mock() - client.generate_completion.return_value = Mock( - content=mock_ollama_response["response"], - total_duration=mock_ollama_response["total_duration"], - prompt_eval_count=mock_ollama_response["prompt_eval_count"], - eval_count=mock_ollama_response["eval_count"], - ) - return client diff --git a/tests/test_engine.py b/tests/test_engine.py index 523c17c..5c4d8b7 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -1,56 +1,112 @@ +from unittest.mock import MagicMock, patch + import pytest -from promptwright import LocalDataEngine, LocalEngineArguments # Updated import +from promptwright.engine import DataEngine, Dataset, EngineArguments -def test_engine_initialization(): - """Test LocalDataEngine initialization.""" - args = LocalEngineArguments( +@pytest.fixture +def engine_args(): + return EngineArguments( instructions="Test instructions", system_prompt="Test system prompt", - model_name="llama3:latest", + model_name="test-model", + prompt_template=None, + example_data=None, + temperature=0.7, + max_retries=3, + default_batch_size=5, + default_num_examples=3, + request_timeout=30, ) - engine = LocalDataEngine(args) - assert engine.args == args - assert len(engine.dataset) == 0 + +@pytest.fixture +def data_engine(engine_args): + return DataEngine(engine_args) + + +def test_engine_initialization(engine_args): + engine = DataEngine(engine_args) + assert engine.args == engine_args + assert isinstance(engine.dataset, Dataset) assert engine.failed_samples == [] -@pytest.mark.usefixtures("mock_ollama_client") -def test_engine_create_data(mock_ollama_client): - """Test create_data method.""" - args = LocalEngineArguments( - instructions="Test instructions", - system_prompt="Test system prompt", - model_name="llama3:latest", - ) +def test_create_data_no_steps(data_engine): + with pytest.raises(ValueError, match="num_steps must be specified"): + data_engine.create_data() - engine = LocalDataEngine(args) - engine.llm_client = mock_ollama_client - dataset = engine.create_data(num_steps=1, batch_size=1) - assert len(dataset) == 1 +@patch("promptwright.engine.litellm.batch_completion") +def test_create_data_success(mock_batch_completion, data_engine): + # Mock valid JSON responses to match the expected structure for 10 samples + mock_batch_completion.return_value = [ + MagicMock( + choices=[ + MagicMock( + message=MagicMock( + content='{"messages": [{"role": "user", "content": "example"}, {"role": "assistant", "content": "response"}]}' + ) + ) + ] + ) + ] * 10 # Mock 10 responses to match the batch size + topic_tree = MagicMock() + topic_tree.tree_paths = [ + "path1", + "path2", + "path3", + "path4", + "path5", + "path6", + "path7", + "path8", + "path9", + "path10", + ] -def test_engine_validation(): - """Test sample 
validation in engine.""" - args = LocalEngineArguments( - instructions="Test instructions", - system_prompt="Test system prompt", - model_name="llama3:latest", - ) + # Define a constant for the expected number of samples + expected_num_samples = 10 + + # Generate the data + dataset = data_engine.create_data(num_steps=1, batch_size=10, topic_tree=topic_tree) + + # Assert that the dataset contains exactly the expected number of samples + assert len(dataset.samples) == expected_num_samples + + +def test_build_prompt(data_engine): + prompt = data_engine.build_prompt("Test prompt", 3, ["subtopic1", "subtopic2"]) + assert "{{system_prompt}}" not in prompt + assert "{{instructions}}" not in prompt + assert "{{examples}}" not in prompt + assert "{{subtopics}}" not in prompt + + +def test_build_system_prompt(data_engine): + system_prompt = data_engine.build_system_prompt() + assert system_prompt == data_engine.args.system_prompt + + +def test_build_custom_instructions_text(data_engine): + instructions_text = data_engine.build_custom_instructions_text() + assert "" in instructions_text + assert data_engine.args.instructions in instructions_text + + +def test_build_examples_text_no_examples(data_engine): + examples_text = data_engine.build_examples_text(3) + assert examples_text == "" - engine = LocalDataEngine(args) - valid_sample = { - "messages": [ - {"role": "user", "content": "test"}, - {"role": "assistant", "content": "response"}, - ] - } +def test_build_subtopics_text(data_engine): + subtopics_text = data_engine.build_subtopics_text(["subtopic1", "subtopic2"]) + assert "subtopic1 -> subtopic2" in subtopics_text - invalid_sample = {"messages": [{"role": "invalid", "content": ""}]} - assert engine._validate_sample(valid_sample) is True - assert engine._validate_sample(invalid_sample) is False +@patch.object(Dataset, "save") +def test_save_dataset(mock_save, data_engine): + data_engine.save_dataset("test_path.jsonl") + mock_save.assert_called_once_with("test_path.jsonl") diff --git a/tests/test_ollama_client.py b/tests/test_ollama_client.py deleted file mode 100644 index eb0f2b8..0000000 --- a/tests/test_ollama_client.py +++ /dev/null @@ -1,49 +0,0 @@ -from unittest.mock import Mock - -import pytest -import requests - -from promptwright import OllamaClient - - -def test_ollama_client_initialization(): - """Test OllamaClient initialization.""" - client = OllamaClient() - assert client.base_url == "http://localhost:11434" - - client = OllamaClient("http://custom:1234") - assert client.base_url == "http://custom:1234" - - -def test_generate_completion(mock_ollama_response, mocker): - """Test generate_completion method.""" - # Setup mock - mock_post = mocker.patch("requests.post") - mock_post.return_value = Mock() - mock_post.return_value.json.return_value = mock_ollama_response - mock_post.return_value.raise_for_status = Mock() - - client = OllamaClient() - response = client.generate_completion(prompt="Test prompt", model="llama3:latest") - - assert response.content is not None - assert isinstance(response.content, str) - mock_post.assert_called_once() - - # Verify the request - call_args = mock_post.call_args - assert call_args is not None - args, kwargs = call_args - assert args[0].endswith("/generate") # Verify endpoint - assert kwargs["json"]["prompt"] == "Test prompt" # Verify prompt - assert kwargs["json"]["model"] == "llama3:latest" # Verify model - - -def test_generate_completion_error(mocker): - """Test generate_completion error handling.""" - # Setup mock to raise Timeout - _mock_post = 
mocker.patch("requests.post", side_effect=requests.exceptions.Timeout()) - - client = OllamaClient() - with pytest.raises(TimeoutError): - client.generate_completion(prompt="Test prompt", model="llama3:latest")
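
For reference, a minimal sketch of the topic-tree flow added in this patch. The argument values are illustrative only, and it assumes a local Ollama server is running with the default `ollama/llama3` model already pulled:

```python
# Illustrative sketch only; class and argument names come from promptwright/topic_tree.py above.
from promptwright.topic_tree import TopicTree, TopicTreeArguments, TopicTreeValidator

args = TopicTreeArguments(
    root_prompt="Creative Writing Prompts",
    model_system_prompt="You are a creative writing instructor.",
    tree_degree=3,
    tree_depth=2,
    temperature=0.2,
    model_name="ollama/llama3",
)

# Check that the planned run fits inside the tree before spending any model calls.
validator = TopicTreeValidator(tree_degree=args.tree_degree, tree_depth=args.tree_depth)
print(validator.validate_configuration(num_steps=3, batch_size=3))

tree = TopicTree(args)
tree.build_tree()              # recursively queries the model for subtopics
tree.save("topic_tree.jsonl")  # writes one {"path": [...]} object per line
```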