diff --git a/langkit/__init__.py b/langkit/__init__.py
index a0c131ba..ed0deec1 100644
--- a/langkit/__init__.py
+++ b/langkit/__init__.py
@@ -1,6 +1,6 @@
 from dataclasses import dataclass, field
 from typing import Dict, List
-
+from .extract import extract
 import importlib.resources as resources
 
 
@@ -67,4 +67,4 @@ def package_version(package: str = __package__) -> str:
 
 
 __version__ = package_version()
-__ALL__ = [__version__, LangKitConfig]
+__ALL__ = [__version__, LangKitConfig, extract]
diff --git a/langkit/extract.py b/langkit/extract.py
new file mode 100644
index 00000000..059fd187
--- /dev/null
+++ b/langkit/extract.py
@@ -0,0 +1,35 @@
+import pandas as pd
+from typing import Any, Dict, Optional, Union
+from whylogs.experimental.core.udf_schema import udf_schema, UdfSchema
+
+
+def extract(
+    data: Union[pd.DataFrame, Dict[str, Any]],
+    schema: Optional[UdfSchema] = None,
+):
+    """Apply the UDFs of ``schema`` to ``data`` and return the enhanced data.
+
+    Args:
+        data: A pandas DataFrame, or a single row given as a dictionary.
+        schema: UDF schema to apply. When ``None``, ``udf_schema()`` is
+            called to obtain one.
+
+    Returns:
+        For DataFrame input, the DataFrame returned by
+        ``schema.apply_udfs(pandas=data)``; for dictionary input, the row
+        returned by ``schema.apply_udfs(row=data)``.
+
+    Raises:
+        ValueError: If ``data`` is neither a DataFrame nor a dictionary.
+    """
+    if schema is None:
+        schema = udf_schema()
+    if isinstance(data, pd.DataFrame):
+        df_enhanced, _ = schema.apply_udfs(pandas=data)
+        return df_enhanced
+    elif isinstance(data, dict):
+        _, row_enhanced = schema.apply_udfs(row=data)
+        return row_enhanced
+    raise ValueError(
+        f"Extract: data of type {type(data)} is invalid: supported input types are pandas dataframe or dictionary"
+    )
diff --git a/langkit/response_hallucination.py b/langkit/response_hallucination.py
index 6c2b85d6..b47e2c7f 100644
--- a/langkit/response_hallucination.py
+++ b/langkit/response_hallucination.py
@@ -255,7 +255,12 @@ def init(llm: LLMInvocationParams, num_samples=1):
 def response_hallucination(text):
     series_result = []
     for prompt, response in zip(text[_prompt], text[_response]):
-        result: ConsistencyResult = checker.consistency_check(prompt, response)
+        if checker is not None:
+            result: ConsistencyResult = checker.consistency_check(prompt, response)
+        else:
+            raise Exception(
+                "Response Hallucination: you need to call init() before using this function"
+            )
         series_result.append(result.final_score)
     return series_result
 
diff --git a/langkit/tests/test_extract.py b/langkit/tests/test_extract.py
new file mode 100644
index 00000000..100b3b1d
--- /dev/null
+++ b/langkit/tests/test_extract.py
@@ -0,0 +1,48 @@
+import langkit
+import pandas as pd
+from whylogs.experimental.core.udf_schema import UdfSchema, UdfSpec
+
+
+def test_extract_pandas():
+    from langkit import textstat
+
+    textstat.init()
+    df = pd.DataFrame({"prompt": ["I love you", "I hate you"]})
+    enhanced_df = langkit.extract(data=df)
+    assert "prompt.flesch_reading_ease" in enhanced_df.columns
+
+
+def test_extract_row():
+    from langkit import regexes
+
+    regexes.init()
+    row = {"prompt": "I love you", "response": "address: 123 Main St."}
+    enhanced_row = langkit.extract(data=row)
+    assert enhanced_row.get("response.has_patterns") == "mailing address"
+    assert not enhanced_row.get("prompt.has_patterns")
+
+
+def test_extract_light_metrics():
+    from langkit import light_metrics
+
+    light_metrics.init()
+
+    row = {"prompt": "I love you", "response": "address: 123 Main St."}
+    enhanced_row = langkit.extract(row)
+    assert enhanced_row.get("response.has_patterns") == "mailing address"
+    assert not enhanced_row.get("prompt.has_patterns")
+    assert "prompt.flesch_reading_ease" in enhanced_row.keys()
+
+
+def test_extract_with_custom_schema():
+    schema = UdfSchema(
+        udf_specs=[
+            UdfSpec(
+                column_names=["prompt"],
+                udfs={"prompt.customfeature": lambda x: x["prompt"]},
+            )
+        ],
+    )
+    row = {"prompt": "I love you", "response": "address: 123 Main St."}
+    enhanced_row = langkit.extract(row, schema=schema)
+    assert enhanced_row.get("prompt.customfeature") == "I love you"