diff --git a/README.md b/README.md
index e44c0e8..52f7032 100644
--- a/README.md
+++ b/README.md
@@ -26,6 +26,7 @@ just open an issue).
 - **YAML Configuration**: Define your generation tasks using YAML configuration files
 - **Command Line Interface**: Run generation tasks directly from the command line
 - **Push to Hugging Face**: Push the generated dataset to Hugging Face Hub with automatic dataset cards and tags
+- **System Message Control**: Choose whether to include system messages in the generated dataset
 
 ## Getting Started
 
@@ -95,6 +96,7 @@ dataset:
     num_steps: 5
     batch_size: 1
     model_name: "ollama/mistral:latest"
+    sys_msg: true  # Include system message in dataset (default: true)
   save_as: "basic_prompt_dataset.jsonl"
 
 # Optional Hugging Face Hub configuration
@@ -128,6 +130,7 @@ promptwright start config.yaml \
   --tree-depth 3 \
   --num-steps 10 \
   --batch-size 2 \
+  --sys-msg true \
   --hf-repo username/dataset-name \
   --hf-token your-token \
   --hf-tags tag1 --hf-tags tag2
@@ -185,6 +188,7 @@ engine = DataEngine(
         model_name="ollama/llama3",
         temperature=0.9,
         max_retries=2,
+        sys_msg=True,  # Include system message in dataset (default: True)
     )
 )
 ```
@@ -218,6 +222,7 @@ make all
 
 ### Prompt Output Examples
 
+With sys_msg=true (default):
 ```json
 {
   "messages": [
@@ -237,6 +242,22 @@ make all
 }
 ```
 
+With sys_msg=false:
+```json
+{
+  "messages": [
+    {
+      "role": "user",
+      "content": "Create a descriptive passage about a character discovering their hidden talents."
+    },
+    {
+      "role": "assistant",
+      "content": "As she stared at the canvas, Emma's fingers hovered above the paintbrushes, as if hesitant to unleash the colors that had been locked within her. The strokes began with bold abandon, swirling blues and greens merging into a mesmerizing dance of light and shadow. With each passing moment, she felt herself becoming the art – her very essence seeping onto the canvas like watercolors in a spring storm. The world around her melted away, leaving only the vibrant symphony of color and creation."
+    }
+  ]
+}
+```
+
 ## Model Compatibility
 
 The library should work with most LLM models. It has been tested with the
diff --git a/examples/example_basic_prompt.py b/examples/example_basic_prompt.py
index 6638d31..a6c6e6c 100644
--- a/examples/example_basic_prompt.py
+++ b/examples/example_basic_prompt.py
@@ -29,6 +29,7 @@
         model_name="ollama/mistral-nemo:latest",  # Model name
         temperature=0.9,  # Higher temperature for more creative variations
         max_retries=2,  # Retry failed prompts up to 2 times
+        sys_msg=True,  # Include system message in dataset (default: True)
     )
 )
 
diff --git a/examples/example_basic_prompt.yaml b/examples/example_basic_prompt.yaml
index 28a231c..be12d73 100644
--- a/examples/example_basic_prompt.yaml
+++ b/examples/example_basic_prompt.yaml
@@ -27,4 +27,5 @@ dataset:
     batch_size: 1
     provider: "ollama"  # LLM provider
     model: "mistral-nemo:latest"  # Model name
+    sys_msg: true  # Include system message in dataset (default: true)
   save_as: "basic_prompt_dataset.jsonl"
diff --git a/examples/example_culinary_database.yaml b/examples/example_culinary_database.yaml
index c669978..347e493 100644
--- a/examples/example_culinary_database.yaml
+++ b/examples/example_culinary_database.yaml
@@ -1,13 +1,14 @@
 # Example YAML configuration for basic prompt generation
-system_prompt: "You are a helpful assistant. You provide clear and concise answers to user questions."
+system_prompt: |
+  You are a culinary expert who documents recipes and cooking techniques.
+  Your entries should be detailed, precise, and include both traditional and modern cooking methods.
 
 topic_tree:
   args:
-    root_prompt: "You are a culinary expert who documents recipes and cooking techniques.
-Your entries should be detailed, precise, and include both traditional and modern cooking methods."
+    root_prompt: "Global Cuisine and Cooking Techniques"
     model_system_prompt: ""  # Will be replaced with system_prompt
-    tree_degree: 5  # Different continents
-    tree_depth: 3  # Deeper tree for more specific topics
+    tree_degree: 5  # Different cuisine types
+    tree_depth: 3  # Specific dishes and techniques
     temperature: 0.7  # Higher temperature for more creative variations
     provider: "ollama"  # LLM provider
     model: "mistral-nemo:latest"  # Model name
@@ -35,4 +36,5 @@ dataset:
     batch_size: 1
     provider: "ollama"  # LLM provider
     model: "mistral-nemo:latest"  # Model name
+    sys_msg: true  # Include system message in dataset (default: true)
   save_as: "culinary_database.jsonl"
diff --git a/examples/example_historic_figures.yaml b/examples/example_historic_figures.yaml
index 7f2d2cb..f8d85df 100644
--- a/examples/example_historic_figures.yaml
+++ b/examples/example_historic_figures.yaml
@@ -5,11 +5,11 @@ system_prompt: |
 
 topic_tree:
   args:
-    root_prompt: "Capital Cities of the World."
+    root_prompt: "Notable Historical Figures Across Different Eras and Fields"
     model_system_prompt: ""  # Will be replaced with system_prompt
-    tree_degree: 3  # Different continents
-    tree_depth: 2  # Deeper tree for more specific topics
-    temperature: 0.7  # Higher temperature for more creative variations
+    tree_degree: 4  # Different categories
+    tree_depth: 3  # Deeper tree for more specific figures
+    temperature: 0.6  # Balanced temperature for creativity and accuracy
     provider: "ollama"  # LLM provider
     model: "mistral-nemo:latest"  # Model name
   save_as: "historical_figures_tree.jsonl"
@@ -23,7 +23,7 @@ data_engine:
     system_prompt: ""  # Will be replaced with system_prompt
     provider: "ollama"  # LLM provider
     model: "mistral-nemo:latest"  # Model name
-    temperature: 0.9  # Higher temperature for more creative variations
+    temperature: 0.7  # Balance between creativity and accuracy
     max_retries: 2  # Retry failed prompts up to 2 times
 
 dataset:
@@ -32,4 +32,5 @@
     batch_size: 1
     provider: "ollama"  # LLM provider
     model: "mistral-nemo:latest"  # Model name
-  save_as: "basic_prompt_dataset.jsonl"
+    sys_msg: true  # Include system message in dataset (default: true)
+  save_as: "historical_figures_database.jsonl"
diff --git a/examples/example_programming_challenges.py.yaml b/examples/example_programming_challenges.py.yaml
index 46ce3aa..f4926be 100644
--- a/examples/example_programming_challenges.py.yaml
+++ b/examples/example_programming_challenges.py.yaml
@@ -7,12 +7,12 @@ topic_tree:
   args:
     root_prompt: "Programming Challenges Across Different Difficulty Levels and Concepts"
     model_system_prompt: ""  # Will be replaced with system_prompt
-    tree_degree: 3  # Different continents
-    tree_depth: 2  # Deeper tree for more specific topics
-    temperature: 0.7  # Higher temperature for more creative variations
+    tree_degree: 4  # Different programming concepts
+    tree_depth: 2  # Various difficulty levels
+    temperature: 0.7  # Higher temperature for creative problem scenarios
     provider: "ollama"  # LLM provider
     model: "mistral-nemo:latest"  # Model name
-  save_as: "basic_prompt_topictree.jsonl"
+  save_as: "programming_challenges_tree.jsonl"
 
 data_engine:
   args:
@@ -27,7 +27,7 @@ data_engine:
     system_prompt: ""  # Will be replaced with system_prompt
     provider: "ollama"  # LLM provider
     model: "mistral-nemo:latest"  # Model name
-    temperature: 0.9  # Higher temperature for more creative variations
+    temperature: 0.8  # Higher temperature for creative problem scenarios
    max_retries: 2  # Retry failed prompts up to 2 times
 
 dataset:
@@ -36,4 +36,5 @@
     batch_size: 1
     provider: "ollama"  # LLM provider
     model: "mistral-nemo:latest"  # Model name
-  save_as: "basic_prompt_dataset.jsonl"
+    sys_msg: true  # Include system message in dataset (default: true)
+  save_as: "programming_challenges.jsonl"
diff --git a/examples/example_with_hf.yaml b/examples/example_with_hf.yaml
index 02306d9..fe354cb 100644
--- a/examples/example_with_hf.yaml
+++ b/examples/example_with_hf.yaml
@@ -23,6 +23,7 @@ dataset:
     num_steps: 5
     batch_size: 1
     model_name: "ollama/mistral:latest"
+    sys_msg: true  # Include system message in dataset (default: true)
   save_as: "basic_prompt_dataset.jsonl"
 
 # Hugging Face Hub configuration (optional)
diff --git a/promptwright/cli.py b/promptwright/cli.py
index b0f42d1..f742f76 100644
--- a/promptwright/cli.py
+++ b/promptwright/cli.py
@@ -47,6 +47,11 @@ def cli():
     multiple=True,
     help="Additional tags for the dataset (can be specified multiple times)",
 )
+@click.option(
+    "--sys-msg",
+    type=bool,
+    help="Include system message in dataset (default: true)",
+)
 def start(  # noqa: PLR0912
     config_file: str,
     topic_tree_save_as: str | None = None,
@@ -61,6 +66,7 @@ def start(  # noqa: PLR0912
     hf_repo: str | None = None,
     hf_token: str | None = None,
     hf_tags: list[str] | None = None,
+    sys_msg: bool | None = None,
 ) -> None:
     """Generate training data from a YAML configuration file."""
     try:
@@ -150,6 +156,7 @@ def start(  # noqa: PLR0912
                 batch_size=batch_size or dataset_params.get("batch_size", 1),
                 topic_tree=tree,
                 model_name=model_name,
+                sys_msg=sys_msg,  # Pass sys_msg to create_data
             )
         except Exception as e:
             handle_error(
diff --git a/promptwright/config.py b/promptwright/config.py
index 9c6a7d4..716f7f5 100644
--- a/promptwright/config.py
+++ b/promptwright/config.py
@@ -86,12 +86,17 @@ def get_engine_args(self, **overrides) -> EngineArguments:
         # Construct full model string
         args["model_name"] = construct_model_string(provider, model)
 
+        # Get sys_msg from dataset config, defaulting to True
+        dataset_config = self.get_dataset_config()
+        sys_msg = dataset_config.get("creation", {}).get("sys_msg", True)
+
         return EngineArguments(
             instructions=args.get("instructions", ""),
             system_prompt=args.get("system_prompt", ""),
             model_name=args["model_name"],
             temperature=args.get("temperature", 0.9),
             max_retries=args.get("max_retries", 2),
+            sys_msg=sys_msg,
         )
 
     def get_dataset_config(self) -> dict:
diff --git a/promptwright/engine.py b/promptwright/engine.py
index b71b39f..8111259 100644
--- a/promptwright/engine.py
+++ b/promptwright/engine.py
@@ -54,6 +54,7 @@ class EngineArguments:
     default_batch_size: int = 5
     default_num_examples: int = 3
     request_timeout: int = 30
+    sys_msg: bool = True  # Default to True for including system message
 
 
 class DataEngine:
@@ -81,7 +82,10 @@ def __init__(self, args: EngineArguments):
             "malformed_responses": [],
             "other_errors": [],
         }
-        self.args.system_prompt = ENGINE_JSON_INSTRUCTIONS + self.args.system_prompt
+        # Store original system prompt for dataset inclusion
+        self.original_system_prompt = args.system_prompt
+        # Use ENGINE_JSON_INSTRUCTIONS only for generation prompt
+        self.generation_system_prompt = ENGINE_JSON_INSTRUCTIONS + args.system_prompt
 
     def analyze_failure(self, response_content: str, error: Exception = None) -> str:
         """Analyze the failure reason for a sample."""
@@ -134,6 +138,7 @@ def create_data(  # noqa: PLR0912
         batch_size: int = 10,
         topic_tree: TopicTree = None,
         model_name: str = None,
+        sys_msg: bool = None,  # Allow overriding sys_msg from args
     ):
         if num_steps is None:
             raise ValueError("num_steps must be specified")  # noqa: TRY003
@@ -144,6 +149,9 @@ def create_data(  # noqa: PLR0912
         if not self.model_name:
             raise ValueError("No valid model_name provided")  # noqa: TRY003
 
+        # Use provided sys_msg or fall back to args.sys_msg
+        include_sys_msg = sys_msg if sys_msg is not None else self.args.sys_msg
+
         data_creation_prompt = SAMPLE_GENERATION_PROMPT
 
         tree_paths = None
@@ -204,6 +212,17 @@ def create_data(  # noqa: PLR0912
                         response_content = r.choices[0].message.content
                         parsed_json = validate_json_response(response_content)
 
+                        if parsed_json and include_sys_msg:  # noqa: SIM102
+                            # Add system message at the start if sys_msg is True
+                            if "messages" in parsed_json:
+                                parsed_json["messages"].insert(
+                                    0,
+                                    {
+                                        "role": "system",
+                                        "content": self.original_system_prompt,
+                                    },
+                                )
+
                         if parsed_json:
                             samples.append(parsed_json)
                         else:
@@ -284,7 +303,7 @@ def build_prompt(
         subtopics_list: list[str] = None,
     ) -> str:
         prompt = data_creation_prompt.replace(
-            "{{{{system_prompt}}}}", self.build_system_prompt()
+            "{{{{system_prompt}}}}", self.generation_system_prompt
         )
         prompt = prompt.replace(
             "{{{{instructions}}}}", self.build_custom_instructions_text()
@@ -297,7 +316,8 @@ def build_prompt(
         )
 
     def build_system_prompt(self):
-        return self.args.system_prompt
+        """Return the original system prompt for dataset inclusion."""
+        return self.original_system_prompt
 
     def build_custom_instructions_text(self) -> str:
         if self.args.instructions is None:
diff --git a/promptwright/hf_hub.py b/promptwright/hf_hub.py
index e3e2f76..e186e2c 100644
--- a/promptwright/hf_hub.py
+++ b/promptwright/hf_hub.py
@@ -48,8 +48,8 @@ def update_dataset_card(self, repo_id: str, tags: list[str] | None = None):
         try:
             card = DatasetCard.load(repo_id)
 
-            # Initialize tags if not a list
-            if not isinstance(card.data.tags, list):
+            # Initialize tags if not present
+            if not hasattr(card.data, "tags") or not isinstance(card.data.tags, list):
                 card.data.tags = []
 
             # Add default promptwright tags
diff --git a/pyproject.toml b/pyproject.toml
index 4ebdfb3..0ef3a03 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "promptwright"
-version = "1.1.1"
+version = "1.2.1"
 description = "LLM based Synthetic Data Generation"
 authors = ["Luke Hinds "]
 readme = "README.md"
diff --git a/tests/test_cli.py b/tests/test_cli.py
index cdeeb12..87ebf33 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -23,6 +23,40 @@ def sample_yaml_content():
     """Sample YAML content for testing."""
     return """
 system_prompt: "Test system prompt"
+topic_tree:
+  args:
+    root_prompt: "Test root prompt"
+    model_system_prompt: ""
+    tree_degree: 3
+    tree_depth: 2
+    temperature: 0.7
+    provider: "test"
+    model: "model"
+  save_as: "test_tree.jsonl"
+data_engine:
+  args:
+    instructions: "Test instructions"
+    system_prompt: ""
+    provider: "test"
+    model: "model"
+    temperature: 0.9
+    max_retries: 2
+dataset:
+  creation:
+    num_steps: 5
+    batch_size: 1
+    provider: "test"
+    model: "model"
+    sys_msg: true
+  save_as: "test_dataset.jsonl"
+"""
+
+
+@pytest.fixture
+def sample_yaml_content_no_sys_msg():
+    """Sample YAML content without sys_msg setting."""
+    return """
+system_prompt: "Test system prompt"
 topic_tree:
   args:
     root_prompt: "Test root prompt"
@@ -65,6 +99,20 @@ def sample_config_file(sample_yaml_content):
         os.unlink(temp_path)
 
 
+@pytest.fixture
+def sample_config_file_no_sys_msg(sample_yaml_content_no_sys_msg):
+    """Create a temporary config file without sys_msg setting."""
+    with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
+        f.write(sample_yaml_content_no_sys_msg)
+        temp_path = f.name
+
+    yield temp_path
+
+    # Cleanup
+    if os.path.exists(temp_path):
+        os.unlink(temp_path)
+
+
 def test_cli_help(cli_runner):
     """Test CLI help command."""
     result = cli_runner.invoke(cli, ["--help"])
@@ -77,6 +125,7 @@ def test_start_help(cli_runner):
     result = cli_runner.invoke(cli, ["start", "--help"])
     assert result.exit_code == 0
     assert "Generate training data from a YAML configuration file" in result.output
+    assert "--sys-msg" in result.output
 
 
 @patch("promptwright.cli.TopicTree")
@@ -109,6 +158,66 @@ def test_start_command_basic(
     mock_dataset.save.assert_called_once()
 
 
+@patch("promptwright.cli.TopicTree")
+@patch("promptwright.cli.DataEngine")
+def test_start_command_with_sys_msg_override(
+    mock_data_engine, mock_topic_tree, cli_runner, sample_config_file
+):
+    """Test start command with sys_msg override."""
+    # Setup mocks
+    mock_tree_instance = Mock()
+    mock_engine_instance = Mock()
+    mock_dataset = Mock()
+
+    mock_topic_tree.return_value = mock_tree_instance
+    mock_data_engine.return_value = mock_engine_instance
+    mock_engine_instance.create_data.return_value = mock_dataset
+
+    # Run command with sys_msg override
+    result = cli_runner.invoke(
+        cli,
+        [
+            "start",
+            sample_config_file,
+            "--sys-msg",
+            "false",
+        ],
+    )
+
+    # Verify command executed successfully
+    assert result.exit_code == 0
+
+    # Verify create_data was called with sys_msg=False
+    args, kwargs = mock_engine_instance.create_data.call_args
+    assert kwargs["sys_msg"] is False
+
+
+@patch("promptwright.cli.TopicTree")
+@patch("promptwright.cli.DataEngine")
+def test_start_command_default_sys_msg(
+    mock_data_engine, mock_topic_tree, cli_runner, sample_config_file_no_sys_msg
+):
+    """Test start command with default sys_msg behavior."""
+    # Setup mocks
+    mock_tree_instance = Mock()
+    mock_engine_instance = Mock()
+    mock_dataset = Mock()
+
+    mock_topic_tree.return_value = mock_tree_instance
+    mock_data_engine.return_value = mock_engine_instance
+    mock_engine_instance.create_data.return_value = mock_dataset
+
+    # Run command without sys_msg override
+    result = cli_runner.invoke(cli, ["start", sample_config_file_no_sys_msg])
+
+    # Verify command executed successfully
+    assert result.exit_code == 0
+
+    # Verify create_data was called with default sys_msg (should be None to use engine default)
+    args, kwargs = mock_engine_instance.create_data.call_args
+    assert "sys_msg" not in kwargs or kwargs["sys_msg"] is None
+
+
 @patch("promptwright.cli.TopicTree")
 @patch("promptwright.cli.DataEngine")
 def test_start_command_with_overrides(
@@ -148,6 +257,8 @@ def test_start_command_with_overrides(
             "10",
             "--batch-size",
             "2",
+            "--sys-msg",
+            "false",
         ],
     )
 
@@ -169,6 +280,7 @@ def test_start_command_with_overrides(
     assert kwargs["num_steps"] == 10  # noqa: PLR2004
     assert kwargs["batch_size"] == 2  # noqa: PLR2004
     assert kwargs["model_name"] == "override/model"
+    assert kwargs["sys_msg"] is False
 
 
 def test_start_command_missing_config(cli_runner):
diff --git a/tests/test_config.py b/tests/test_config.py
index 9d0ff08..15caec3 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -14,6 +14,46 @@
 @pytest.fixture
 def sample_config_dict():
     """Sample configuration dictionary for testing."""
+    return {
+        "system_prompt": "Test system prompt",
+        "topic_tree": {
+            "args": {
+                "root_prompt": "Test root prompt",
+                "model_system_prompt": "",
+                "tree_degree": 3,
+                "tree_depth": 2,
+                "temperature": 0.7,
+                "provider": "test",
+                "model": "model",
+            },
+            "save_as": "test_tree.jsonl",
+        },
+        "data_engine": {
+            "args": {
+                "instructions": "Test instructions",
+                "system_prompt": "",
+                "provider": "test",
+                "model": "model",
+                "temperature": 0.9,
+                "max_retries": 2,
+            }
+        },
+        "dataset": {
+            "creation": {
+                "num_steps": 5,
+                "batch_size": 1,
+                "provider": "test",
+                "model": "model",
+                "sys_msg": True,
+            },
+            "save_as": "test_dataset.jsonl",
+        },
+    }
+
+
+@pytest.fixture
+def sample_config_dict_no_sys_msg():
+    """Sample configuration dictionary without sys_msg setting."""
     return {
         "system_prompt": "Test system prompt",
         "topic_tree": {
@@ -64,6 +104,20 @@ def sample_yaml_file(sample_config_dict):
         os.unlink(temp_path)
 
 
+@pytest.fixture
+def sample_yaml_file_no_sys_msg(sample_config_dict_no_sys_msg):
+    """Create a temporary YAML file without sys_msg setting."""
+    with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
+        yaml.dump(sample_config_dict_no_sys_msg, f)
+        temp_path = f.name
+
+    yield temp_path
+
+    # Cleanup
+    if os.path.exists(temp_path):
+        os.unlink(temp_path)
+
+
 def test_load_from_yaml(sample_yaml_file, sample_config_dict):
     """Test loading configuration from YAML file."""
     config = PromptWrightConfig.from_yaml(sample_yaml_file)
@@ -99,6 +153,16 @@ def test_get_engine_args(sample_yaml_file):
     assert args.model_name == "test/model"
     assert args.temperature == 0.9  # noqa: PLR2004
     assert args.max_retries == 2  # noqa: PLR2004
+    assert args.sys_msg is True  # Default from dataset config
+
+
+def test_get_engine_args_no_sys_msg(sample_yaml_file_no_sys_msg):
+    """Test getting EngineArguments without sys_msg setting."""
+    config = PromptWrightConfig.from_yaml(sample_yaml_file_no_sys_msg)
+    args = config.get_engine_args()
+
+    assert isinstance(args, EngineArguments)
+    assert args.sys_msg is True  # Default value when not specified
 
 
 def test_get_topic_tree_args_with_overrides(sample_yaml_file):
@@ -133,6 +197,15 @@ def test_get_dataset_config(sample_yaml_file, sample_config_dict):
     dataset_config = config.get_dataset_config()
 
     assert dataset_config == sample_config_dict["dataset"]
+    assert dataset_config["creation"]["sys_msg"] is True
+
+
+def test_get_dataset_config_no_sys_msg(sample_yaml_file_no_sys_msg):
+    """Test getting dataset configuration without sys_msg setting."""
+    config = PromptWrightConfig.from_yaml(sample_yaml_file_no_sys_msg)
+    dataset_config = config.get_dataset_config()
+
+    assert "sys_msg" not in dataset_config["creation"]
 
 
 def test_missing_yaml_file():
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index 401418a..59b6cbe 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -24,6 +24,72 @@ def test_dataset_validation():
     assert Dataset.validate_sample(invalid_sample) is False
 
 
+def test_dataset_validation_with_system_message():
+    """Test sample validation with system message."""
+    from promptwright import Dataset
+
+    valid_sample = {
+        "messages": [
+            {"role": "system", "content": "system prompt"},
+            {"role": "user", "content": "test"},
+            {"role": "assistant", "content": "response"},
+        ]
+    }
+
+    assert Dataset.validate_sample(valid_sample) is True
+
+
+def test_dataset_validation_system_message_order():
+    """Test sample validation with system message in different positions."""
+    from promptwright import Dataset
+
+    # System message should be valid in any position
+    valid_sample_start = {
+        "messages": [
+            {"role": "system", "content": "system prompt"},
+            {"role": "user", "content": "test"},
+            {"role": "assistant", "content": "response"},
+        ]
+    }
+
+    valid_sample_middle = {
+        "messages": [
+            {"role": "user", "content": "test"},
+            {"role": "system", "content": "system prompt"},
+            {"role": "assistant", "content": "response"},
+        ]
+    }
+
+    valid_sample_end = {
+        "messages": [
+            {"role": "user", "content": "test"},
+            {"role": "assistant", "content": "response"},
+            {"role": "system", "content": "system prompt"},
+        ]
+    }
+
+    assert Dataset.validate_sample(valid_sample_start) is True
+    assert Dataset.validate_sample(valid_sample_middle) is True
+    assert Dataset.validate_sample(valid_sample_end) is True
+
+
+def test_dataset_validation_multiple_system_messages():
+    """Test sample validation with multiple system messages."""
+    from promptwright import Dataset
+
+    # Multiple system messages should be valid
+    valid_sample = {
+        "messages": [
+            {"role": "system", "content": "system prompt 1"},
+            {"role": "system", "content": "system prompt 2"},
+            {"role": "user", "content": "test"},
+            {"role": "assistant", "content": "response"},
+        ]
+    }
+
+    assert Dataset.validate_sample(valid_sample) is True
+
+
 def test_dataset_add_samples():
     """Test adding samples to dataset."""
     from promptwright import Dataset
@@ -50,6 +116,35 @@ def test_dataset_add_samples():
     assert dataset[0] == samples[0]
 
 
+def test_dataset_add_samples_with_system_messages():
+    """Test adding samples with system messages to dataset."""
+    from promptwright import Dataset
+
+    dataset = Dataset()
+
+    samples = [
+        {
+            "messages": [
+                {"role": "system", "content": "system prompt"},
+                {"role": "user", "content": "test1"},
+                {"role": "assistant", "content": "response1"},
+            ]
+        },
+        {
+            "messages": [
+                {"role": "system", "content": "system prompt"},
+                {"role": "user", "content": "test2"},
+                {"role": "assistant", "content": "response2"},
+            ]
+        },
+    ]
+
+    dataset.add_samples(samples)
+    assert len(dataset) == 2  # noqa: PLR2004
+    assert dataset[0] == samples[0]
+    assert dataset[0]["messages"][0]["role"] == "system"
+
+
 def test_dataset_filter_by_role():
     """Test filtering samples by role."""
     from promptwright import Dataset
@@ -70,3 +165,32 @@ def test_dataset_filter_by_role():
     user_messages = dataset.filter_by_role("user")
     assert len(user_messages) == 1
     assert user_messages[0]["messages"][0]["content"] == "test1"
+
+    system_messages = dataset.filter_by_role("system")
+    assert len(system_messages) == 1
+    assert system_messages[0]["messages"][0]["content"] == "sys"
+
+
+def test_dataset_get_statistics():
+    """Test getting dataset statistics."""
+    from promptwright import Dataset
+
+    dataset = Dataset()
+
+    samples = [
+        {
+            "messages": [
+                {"role": "system", "content": "sys"},
+                {"role": "user", "content": "test1"},
+                {"role": "assistant", "content": "response1"},
+            ]
+        }
+    ]
+
+    dataset.add_samples(samples)
+    stats = dataset.get_statistics()
+
+    assert stats["total_samples"] == 1
+    assert stats["avg_messages_per_sample"] == 3  # noqa: PLR2004
+    assert "system" in stats["role_distribution"]
+    assert stats["role_distribution"]["system"] == 1 / 3  # noqa: PLR2004
diff --git a/tests/test_engine.py b/tests/test_engine.py
index 5c4d8b7..75a40c5 100644
--- a/tests/test_engine.py
+++ b/tests/test_engine.py
@@ -77,6 +77,103 @@ def test_create_data_success(mock_batch_completion, data_engine):
     assert len(dataset.samples) == expected_num_samples
 
 
+@patch("promptwright.engine.litellm.batch_completion")
+def test_create_data_with_sys_msg_default(mock_batch_completion, data_engine):
+    # Mock valid JSON response
+    mock_batch_completion.return_value = [
+        MagicMock(
+            choices=[
+                MagicMock(
+                    message=MagicMock(
+                        content='{"messages": [{"role": "user", "content": "example"}, {"role": "assistant", "content": "response"}]}'
+                    )
+                )
+            ]
+        )
+    ]
+
+    topic_tree = MagicMock()
+    topic_tree.tree_paths = ["path1"]
+
+    # Generate data with default sys_msg (True)
+    dataset = data_engine.create_data(num_steps=1, batch_size=1, topic_tree=topic_tree)
+
+    # Verify system message is included
+    assert len(dataset.samples) == 1
+    assert len(dataset.samples[0]["messages"]) == 3  # noqa: PLR2004
+    assert dataset.samples[0]["messages"][0]["role"] == "system"
+    assert (
+        dataset.samples[0]["messages"][0]["content"] == data_engine.args.system_prompt
+    )
+
+
+@patch("promptwright.engine.litellm.batch_completion")
+def test_create_data_without_sys_msg(mock_batch_completion, data_engine):
+    # Mock valid JSON response
+    mock_batch_completion.return_value = [
+        MagicMock(
+            choices=[
+                MagicMock(
+                    message=MagicMock(
+                        content='{"messages": [{"role": "user", "content": "example"}, {"role": "assistant", "content": "response"}]}'
+                    )
+                )
+            ]
+        )
+    ]
+
+    topic_tree = MagicMock()
+    topic_tree.tree_paths = ["path1"]
+
+    # Generate data with sys_msg=False
+    dataset = data_engine.create_data(
+        num_steps=1, batch_size=1, topic_tree=topic_tree, sys_msg=False
+    )
+
+    # Verify system message is not included
+    assert len(dataset.samples) == 1
+    assert len(dataset.samples[0]["messages"]) == 2  # noqa: PLR2004
+    assert dataset.samples[0]["messages"][0]["role"] == "user"
+
+
+@patch("promptwright.engine.litellm.batch_completion")
+def test_create_data_sys_msg_override(mock_batch_completion):
+    # Create engine with sys_msg=False
+    args = EngineArguments(
+        instructions="Test instructions",
+        system_prompt="Test system prompt",
+        model_name="test-model",
+        sys_msg=False,  # Default to False
+    )
+    engine = DataEngine(args)
+
+    # Mock valid JSON response
+    mock_batch_completion.return_value = [
+        MagicMock(
+            choices=[
+                MagicMock(
+                    message=MagicMock(
+                        content='{"messages": [{"role": "user", "content": "example"}, {"role": "assistant", "content": "response"}]}'
+                    )
+                )
+            ]
+        )
+    ]
+
+    topic_tree = MagicMock()
+    topic_tree.tree_paths = ["path1"]
+
+    # Override sys_msg=False with True in create_data
+    dataset = engine.create_data(
+        num_steps=1, batch_size=1, topic_tree=topic_tree, sys_msg=True
+    )
+
+    # Verify system message is included despite engine default
+    assert len(dataset.samples) == 1
+    assert len(dataset.samples[0]["messages"]) == 3  # noqa: PLR2004
+    assert dataset.samples[0]["messages"][0]["role"] == "system"
+
+
 def test_build_prompt(data_engine):
     prompt = data_engine.build_prompt("Test prompt", 3, ["subtopic1", "subtopic2"])
     assert "{{system_prompt}}" not in prompt