Make system message optional #18

Merged · 2 commits · Nov 25, 2024
21 changes: 21 additions & 0 deletions README.md
@@ -26,6 +26,7 @@ just open an issue).
- **YAML Configuration**: Define your generation tasks using YAML configuration files
- **Command Line Interface**: Run generation tasks directly from the command line
- **Push to Hugging Face**: Push the generated dataset to Hugging Face Hub with automatic dataset cards and tags
- **System Message Control**: Choose whether to include system messages in the generated dataset

## Getting Started

@@ -95,6 +96,7 @@ dataset:
num_steps: 5
batch_size: 1
model_name: "ollama/mistral:latest"
sys_msg: true # Include system message in dataset (default: true)
save_as: "basic_prompt_dataset.jsonl"

# Optional Hugging Face Hub configuration
@@ -128,6 +130,7 @@ promptwright start config.yaml \
--tree-depth 3 \
--num-steps 10 \
--batch-size 2 \
--sys-msg true \
--hf-repo username/dataset-name \
--hf-token your-token \
--hf-tags tag1 --hf-tags tag2
@@ -185,6 +188,7 @@ engine = DataEngine(
model_name="ollama/llama3",
temperature=0.9,
max_retries=2,
sys_msg=True,  # Include system message in dataset (default: True)
)
)
```
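
The flag can also be overridden per run when calling `create_data`; a minimal sketch (argument values are illustrative — when `sys_msg` is omitted, the engine falls back to the `EngineArguments` default):

```python
engine.create_data(
    num_steps=5,       # number of generation steps (illustrative)
    batch_size=1,
    sys_msg=False,     # emit only user/assistant turns for this run
)
```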
@@ -218,6 +222,7 @@ make all

### Prompt Output Examples

With `sys_msg=true` (default):
```json
{
"messages": [
@@ -237,6 +242,22 @@
}
```

With `sys_msg=false`:
```json
{
"messages": [
{
"role": "user",
"content": "Create a descriptive passage about a character discovering their hidden talents."
},
{
"role": "assistant",
"content": "As she stared at the canvas, Emma's fingers hovered above the paintbrushes, as if hesitant to unleash the colors that had been locked within her. The strokes began with bold abandon, swirling blues and greens merging into a mesmerizing dance of light and shadow. With each passing moment, she felt herself becoming the art – her very essence seeping onto the canvas like watercolors in a spring storm. The world around her melted away, leaving only the vibrant symphony of color and creation."
}
]
}
```

## Model Compatibility

The library should work with most LLM models. It has been tested with the
1 change: 1 addition & 0 deletions examples/example_basic_prompt.py
@@ -29,6 +29,7 @@
model_name="ollama/mistral-nemo:latest", # Model name
temperature=0.9, # Higher temperature for more creative variations
max_retries=2, # Retry failed prompts up to 2 times
sys_msg=True,  # Include system message in dataset (default: True)
)
)

1 change: 1 addition & 0 deletions examples/example_basic_prompt.yaml
@@ -27,4 +27,5 @@ dataset:
batch_size: 1
provider: "ollama" # LLM provider
model: "mistral-nemo:latest" # Model name
sys_msg: true # Include system message in dataset (default: true)
save_as: "basic_prompt_dataset.jsonl"
12 changes: 7 additions & 5 deletions examples/example_culinary_database.yaml
@@ -1,13 +1,14 @@
# Example YAML configuration for basic prompt generation
system_prompt: "You are a helpful assistant. You provide clear and concise answers to user questions."
system_prompt: |
You are a culinary expert who documents recipes and cooking techniques.
Your entries should be detailed, precise, and include both traditional and modern cooking methods.

topic_tree:
args:
root_prompt: "You are a culinary expert who documents recipes and cooking techniques.
Your entries should be detailed, precise, and include both traditional and modern cooking methods."
root_prompt: "Global Cuisine and Cooking Techniques"
model_system_prompt: "<system_prompt_placeholder>" # Will be replaced with system_prompt
tree_degree: 5 # Different continents
tree_depth: 3 # Deeper tree for more specific topics
tree_degree: 5 # Different cuisine types
tree_depth: 3 # Specific dishes and techniques
temperature: 0.7 # Higher temperature for more creative variations
provider: "ollama" # LLM provider
model: "mistral-nemo:latest" # Model name
@@ -35,4 +36,5 @@ dataset:
batch_size: 1
provider: "ollama" # LLM provider
model: "mistral-nemo:latest" # Model name
sys_msg: true # Include system message in dataset (default: true)
save_as: "culinary_database.jsonl"
13 changes: 7 additions & 6 deletions examples/example_historic_figures.yaml
@@ -5,11 +5,11 @@ system_prompt: |

topic_tree:
args:
root_prompt: "Capital Cities of the World."
root_prompt: "Notable Historical Figures Across Different Eras and Fields"
model_system_prompt: "<system_prompt_placeholder>" # Will be replaced with system_prompt
tree_degree: 3 # Different continents
tree_depth: 2 # Deeper tree for more specific topics
temperature: 0.7 # Higher temperature for more creative variations
tree_degree: 4 # Different categories
tree_depth: 3 # Deeper tree for more specific figures
temperature: 0.6 # Balanced temperature for creativity and accuracy
provider: "ollama" # LLM provider
model: "mistral-nemo:latest" # Model name
save_as: "historical_figures_tree.jsonl"
@@ -23,7 +23,7 @@ data_engine:
system_prompt: "<system_prompt_placeholder>" # Will be replaced with system_prompt
provider: "ollama" # LLM provider
model: "mistral-nemo:latest" # Model name
temperature: 0.9 # Higher temperature for more creative variations
temperature: 0.7 # Balance between creativity and accuracy
max_retries: 2 # Retry failed prompts up to 2 times

dataset:
@@ -32,4 +32,5 @@
batch_size: 1
provider: "ollama" # LLM provider
model: "mistral-nemo:latest" # Model name
save_as: "basic_prompt_dataset.jsonl"
sys_msg: true # Include system message in dataset (default: true)
save_as: "historical_figures_database.jsonl"
13 changes: 7 additions & 6 deletions examples/example_programming_challenges.py.yaml
@@ -7,12 +7,12 @@ topic_tree:
args:
root_prompt: "Programming Challenges Across Different Difficulty Levels and Concepts"
model_system_prompt: "<system_prompt_placeholder>" # Will be replaced with system_prompt
tree_degree: 3 # Different continents
tree_depth: 2 # Deeper tree for more specific topics
temperature: 0.7 # Higher temperature for more creative variations
tree_degree: 4 # Different programming concepts
tree_depth: 2 # Various difficulty levels
temperature: 0.7 # Higher temperature for creative problem scenarios
provider: "ollama" # LLM provider
model: "mistral-nemo:latest" # Model name
save_as: "basic_prompt_topictree.jsonl"
save_as: "programming_challenges_tree.jsonl"

data_engine:
args:
@@ -27,7 +27,7 @@ data_engine:
system_prompt: "<system_prompt_placeholder>" # Will be replaced with system_prompt
provider: "ollama" # LLM provider
model: "mistral-nemo:latest" # Model name
temperature: 0.9 # Higher temperature for more creative variations
temperature: 0.8 # Higher temperature for creative problem scenarios
max_retries: 2 # Retry failed prompts up to 2 times

dataset:
@@ -36,4 +36,5 @@
batch_size: 1
provider: "ollama" # LLM provider
model: "mistral-nemo:latest" # Model name
save_as: "basic_prompt_dataset.jsonl"
sys_msg: true # Include system message in dataset (default: true)
save_as: "programming_challenges.jsonl"
1 change: 1 addition & 0 deletions examples/example_with_hf.yaml
@@ -23,6 +23,7 @@ dataset:
num_steps: 5
batch_size: 1
model_name: "ollama/mistral:latest"
sys_msg: true # Include system message in dataset (default: true)
save_as: "basic_prompt_dataset.jsonl"

# Hugging Face Hub configuration (optional)
7 changes: 7 additions & 0 deletions promptwright/cli.py
@@ -47,6 +47,11 @@ def cli():
multiple=True,
help="Additional tags for the dataset (can be specified multiple times)",
)
@click.option(
"--sys-msg",
type=bool,
help="Include system message in dataset (default: true)",
)
def start( # noqa: PLR0912
config_file: str,
topic_tree_save_as: str | None = None,
@@ -61,6 +66,7 @@
hf_repo: str | None = None,
hf_token: str | None = None,
hf_tags: list[str] | None = None,
sys_msg: bool | None = None,
) -> None:
"""Generate training data from a YAML configuration file."""
try:
@@ -150,6 +156,7 @@ def start( # noqa: PLR0912
batch_size=batch_size or dataset_params.get("batch_size", 1),
topic_tree=tree,
model_name=model_name,
sys_msg=sys_msg, # Pass sys_msg to create_data
)
except Exception as e:
handle_error(
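With the option wired through the CLI, system messages can be disabled for a run without editing the config, e.g. `promptwright start config.yaml --sys-msg false` (config file name illustrative).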
5 changes: 5 additions & 0 deletions promptwright/config.py
@@ -86,12 +86,17 @@ def get_engine_args(self, **overrides) -> EngineArguments:
# Construct full model string
args["model_name"] = construct_model_string(provider, model)

# Get sys_msg from dataset config, defaulting to True
dataset_config = self.get_dataset_config()
sys_msg = dataset_config.get("creation", {}).get("sys_msg", True)

return EngineArguments(
instructions=args.get("instructions", ""),
system_prompt=args.get("system_prompt", ""),
model_name=args["model_name"],
temperature=args.get("temperature", 0.9),
max_retries=args.get("max_retries", 2),
sys_msg=sys_msg,
)

def get_dataset_config(self) -> dict:
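The lookup above reads the flag from the dataset `creation` block, defaulting to `True` when absent; a config sketch (surrounding keys assumed from the repo's examples):

```yaml
dataset:
  creation:
    num_steps: 5
    batch_size: 1
    sys_msg: false   # engine-wide default; the --sys-msg CLI flag still overrides
  save_as: "dataset.jsonl"
```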
26 changes: 23 additions & 3 deletions promptwright/engine.py
@@ -54,6 +54,7 @@ class EngineArguments:
default_batch_size: int = 5
default_num_examples: int = 3
request_timeout: int = 30
sys_msg: bool = True # Default to True for including system message


class DataEngine:
@@ -81,7 +82,10 @@ def __init__(self, args: EngineArguments):
"malformed_responses": [],
"other_errors": [],
}
self.args.system_prompt = ENGINE_JSON_INSTRUCTIONS + self.args.system_prompt
# Store original system prompt for dataset inclusion
self.original_system_prompt = args.system_prompt
# Use ENGINE_JSON_INSTRUCTIONS only for generation prompt
self.generation_system_prompt = ENGINE_JSON_INSTRUCTIONS + args.system_prompt

def analyze_failure(self, response_content: str, error: Exception = None) -> str:
"""Analyze the failure reason for a sample."""
@@ -134,6 +138,7 @@ def create_data( # noqa: PLR0912
batch_size: int = 10,
topic_tree: TopicTree = None,
model_name: str = None,
sys_msg: bool = None, # Allow overriding sys_msg from args
):
if num_steps is None:
raise ValueError("num_steps must be specified") # noqa: TRY003
@@ -144,6 +149,9 @@
if not self.model_name:
raise ValueError("No valid model_name provided") # noqa: TRY003

# Use provided sys_msg or fall back to args.sys_msg
include_sys_msg = sys_msg if sys_msg is not None else self.args.sys_msg

data_creation_prompt = SAMPLE_GENERATION_PROMPT

tree_paths = None
@@ -204,6 +212,17 @@
response_content = r.choices[0].message.content
parsed_json = validate_json_response(response_content)

if parsed_json and include_sys_msg: # noqa: SIM102
# Add system message at the start if sys_msg is True
if "messages" in parsed_json:
parsed_json["messages"].insert(
0,
{
"role": "system",
"content": self.original_system_prompt,
},
)

if parsed_json:
samples.append(parsed_json)
else:
@@ -284,7 +303,7 @@ def build_prompt(
subtopics_list: list[str] = None,
) -> str:
prompt = data_creation_prompt.replace(
"{{{{system_prompt}}}}", self.build_system_prompt()
"{{{{system_prompt}}}}", self.generation_system_prompt
)
prompt = prompt.replace(
"{{{{instructions}}}}", self.build_custom_instructions_text()
Expand All @@ -297,7 +316,8 @@ def build_prompt(
)

def build_system_prompt(self):
return self.args.system_prompt
"""Return the original system prompt for dataset inclusion."""
return self.original_system_prompt

def build_custom_instructions_text(self) -> str:
if self.args.instructions is None:
4 changes: 2 additions & 2 deletions promptwright/hf_hub.py
@@ -48,8 +48,8 @@ def update_dataset_card(self, repo_id: str, tags: list[str] | None = None):
try:
card = DatasetCard.load(repo_id)

# Initialize tags if not a list
if not isinstance(card.data.tags, list):
# Initialize tags if not present
if not hasattr(card.data, "tags") or not isinstance(card.data.tags, list):
card.data.tags = []

# Add default promptwright tags
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "promptwright"
version = "1.1.1"
version = "1.2.1"
description = "LLM based Synthetic Data Generation"
authors = ["Luke Hinds <[email protected]>"]
readme = "README.md"