Skip to content

Commit

Permalink
Allow sending label descriptions to the LLM
Browse files Browse the repository at this point in the history
  • Loading branch information
rajasbansal committed Dec 11, 2024
1 parent f43a21d commit e5e3b6b
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 4 deletions.
10 changes: 9 additions & 1 deletion src/autolabel/labeler.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,12 +265,17 @@ async def arun(
example_template = self.config.example_template()
toEmbed = example_template.format_map(defaultdict(str, chunk))
selected_labels_map = {}
selected_labels_desc_map = {}
for attribute in self.config.attributes():
attribute_name = attribute.get("name")
label_selector = self.label_selector_map.get(attribute_name)
if label_selector:
selected_labels = label_selector.select_labels(toEmbed)
(
selected_labels,
selected_labels_desc,
) = label_selector.select_labels(toEmbed)
selected_labels_map[attribute_name] = selected_labels
selected_labels_desc_map[attribute_name] = selected_labels_desc
if self.example_selector:
examples = self.example_selector.select_examples(
safe_serialize_to_string(chunk),
Expand All @@ -288,6 +293,9 @@ async def arun(
selected_labels_map=selected_labels_map
if self.label_selector_map
else None,
selected_labels_desc_map=selected_labels_desc_map
if self.label_selector_map
else None,
max_input_tokens=self.llm.max_context_length,
get_num_tokens=self.llm.get_num_tokens,
)
Expand Down
23 changes: 20 additions & 3 deletions src/autolabel/tasks/attribute_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(self, config: AutolabelConfig) -> None:
def _construct_attribute_json(
self,
selected_labels_map: Dict[str, List[str]] = None,
selected_labels_desc_map: Dict[str, Dict[str, str]] = None,
) -> Tuple[str, Dict]:
"""
This function is used to construct the attribute json string for the output guidelines.
Expand Down Expand Up @@ -107,6 +108,20 @@ def _construct_attribute_json(
attribute_options = selected_labels_map[attribute_name]
attribute_desc += f"\nOptions:\n{','.join(attribute_options)}"

attribute_options_desc = attribute_dict.get("options_desc", {})
if (
selected_labels_desc_map
and attribute_name in selected_labels_desc_map
):
attribute_options_desc = selected_labels_desc_map[attribute_name]
attribute_options_desc = {
k: v for k, v in attribute_options_desc.items() if v is not None
}
if attribute_options_desc:
attribute_desc += "\nDescription for each option:"
for k, v in attribute_options_desc.items():
attribute_desc += f"\n{k}: {v}"

output_json[attribute_name] = attribute_desc

if (
Expand Down Expand Up @@ -182,6 +197,7 @@ def construct_prompt(
max_input_tokens: Optional[int] = None,
get_num_tokens: Optional[Callable] = None,
selected_labels_map: Dict[str, List[str]] = None,
selected_labels_desc_map: Dict[str, Dict[str, str]] = None,
**kwargs,
) -> Tuple[str, str]:
fmt_task_guidelines = self.task_guidelines
Expand Down Expand Up @@ -211,6 +227,7 @@ def construct_prompt(

attribute_json, output_schema = self._construct_attribute_json(
selected_labels_map=selected_labels_map,
selected_labels_desc_map=selected_labels_desc_map,
)
output_guidelines = (
self.output_guidelines
Expand Down Expand Up @@ -389,9 +406,9 @@ def parse_llm_response(
original_attr_labels,
),
)
llm_label[attribute["name"]] = (
self.config.label_separator().join(filtered_attr_labels)
)
llm_label[
attribute["name"]
] = self.config.label_separator().join(filtered_attr_labels)
if len(filtered_attr_labels) != len(original_attr_labels):
logger.warning(
f"Attribute {attr_label} from the LLM response {llm_label} is not in the labels list. Filtered list: {filtered_attr_labels}",
Expand Down

0 comments on commit e5e3b6b

Please sign in to comment.