Commit: docs
sbordt committed Jul 4, 2024
1 parent 2c7f98b commit 06eed47

Showing 6 changed files with 69 additions and 24 deletions.
2 changes: 1 addition & 1 deletion docs/api_reference.rst
@@ -26,7 +26,7 @@ Interface to the LLM
--------------------

.. automodule:: t2ebm.llm
:members: AbstractChatModel, OpenAIChatModel, openai_setup, setup, chat_completion
:members: AbstractChatModel, OpenAIChatModel, openai_setup, chat_completion
:show-inheritance:


41 changes: 32 additions & 9 deletions t2ebm/functions.py
@@ -1,5 +1,5 @@
"""
TalkToEBM: A Natural Language Interface to Explainable Boosting Machines
TalkToEBM: A Natural Language Interface to Explainable Boosting Machines.
"""

import inspect
@@ -24,7 +24,7 @@
###################################################################################################


def feature_importances_to_text(ebm):
def feature_importances_to_text(ebm: Union[ExplainableBoostingClassifier, ExplainableBoostingRegressor]):
"""Convert the feature importances of an EBM to text.
Args:
ebm (Union[ExplainableBoostingClassifier, ExplainableBoostingRegressor]): The EBM.
Returns:
str: Textual representation of the feature importances.
"""
feature_importances = ""
for feature_idx, feature_name in enumerate(ebm.feature_names_in_):
feature_importances += (
@@ -45,14 +53,18 @@ def describe_graph(
num_sentences: int = 7,
**kwargs,
):
"""Ask the LLM to describe a graph from an EBM, using chain-of-thought reasoning.
"""Ask the LLM to describe a graph. Uses chain-of-thought reasoning.
The function accepts additional keyword arguments that are passed to extract_graph, graph_to_text, and describe_graph_cot.
This function accepts arbitrary keyword arguments that are passed to the corresponding lower-level functions.
Args:
llm (Union[AbstractChatModel, str]): The LLM.
ebm (Union[ExplainableBoostingClassifier, ExplainableBoostingRegressor]): The EBM.
feature_index (int): The index of the feature to describe.
num_sentences (int, optional): The desired number of sentences for the description. Defaults to 7.
:param ebm:
:param feature_index:
:param kwargs: see llm_describe_graph
:return: A summary of the graph in at most num_sentences sentences.
Returns:
str: The description of the graph.
"""

# llm setup
@@ -91,7 +103,18 @@ def describe_ebm(
num_sentences: int = 30,
**kwargs,
):
"""Ask the LLM to describe the LLM in at most {num_sentences} sentences."""
"""Ask the LLM to describe an EBM.
The function accepts additional keyword arguments that are passed to extract_graph, graph_to_text, and describe_graph_cot.
Args:
llm (Union[AbstractChatModel, str]): The LLM.
ebm (Union[ExplainableBoostingClassifier, ExplainableBoostingRegressor]): The EBM.
num_sentences (int, optional): The desired number of sentences for the description. Defaults to 30.
Returns:
str: The description of the EBM.
"""

# llm setup
llm = t2ebm.llm.setup(llm)
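
For context, a minimal sketch of how the functions documented above in t2ebm/functions.py could be called. The dataset, the model name, and the direct import from t2ebm.functions are illustrative assumptions, not part of this commit:

from interpret.glassbox import ExplainableBoostingClassifier
from sklearn.datasets import load_breast_cancer

from t2ebm.functions import describe_ebm, describe_graph, feature_importances_to_text

# Train an EBM; the breast cancer dataset is used purely for illustration.
X, y = load_breast_cancer(return_X_y=True, as_frame=True)
ebm = ExplainableBoostingClassifier()
ebm.fit(X, y)

# Plain-text feature importances, no LLM call involved.
print(feature_importances_to_text(ebm))

# LLM-generated descriptions; the model name is a placeholder for whatever
# openai_setup accepts in your environment (an OpenAI API key is required).
graph_description = describe_graph("gpt-4o-mini", ebm, feature_index=0, num_sentences=7)
model_description = describe_ebm("gpt-4o-mini", ebm, num_sentences=30)
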
10 changes: 8 additions & 2 deletions t2ebm/graphs.py
@@ -1,5 +1,6 @@
from dataclasses import dataclass
import typing
from typing import Union

import matplotlib.pyplot as plt
import numpy as np
@@ -8,6 +9,11 @@

from interpret.glassbox._ebm._utils import convert_to_intervals

from interpret.glassbox import (
ExplainableBoostingClassifier,
ExplainableBoostingRegressor,
)

from t2ebm.utils import num_tokens_from_string_

###################################################################################################
@@ -30,7 +36,7 @@ class EBMGraph:


def extract_graph(
ebm,
ebm : Union[ExplainableBoostingClassifier, ExplainableBoostingRegressor],
feature_index :int,
normalization="none",
use_feature_bounds=True,
@@ -199,7 +205,7 @@ def graph_to_text(
The function takes care of a variety of different formatting issues that can arise in the process of converting a graph to text.
Args:
graph (EBMGraph): _description_
graph (EBMGraph): The graph.
include_description (bool, optional): Whether to include a short descriptive preamble that describes the graph to the LLM. Defaults to True.
feature_format (_type_, optional): The format of the feature (continuous, categorical, boolean). Defaults to None (auto-detect).
x_axis_precision (_type_, optional): The precision of the values on the x-axis. Defaults to None.
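
A short sketch of the two helpers whose docstrings change here, extract_graph and graph_to_text; the regression dataset is illustrative only:

from interpret.glassbox import ExplainableBoostingRegressor
from sklearn.datasets import load_diabetes

from t2ebm.graphs import extract_graph, graph_to_text

# Fit a regression EBM on a toy dataset.
X, y = load_diabetes(return_X_y=True, as_frame=True)
ebm = ExplainableBoostingRegressor()
ebm.fit(X, y)

# Extract the shape function of the first feature as an EBMGraph ...
graph = extract_graph(ebm, feature_index=0)

# ... and serialize it to the textual format that is later sent to the LLM.
print(graph_to_text(graph, include_description=True))
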
6 changes: 3 additions & 3 deletions t2ebm/llm.py
@@ -1,7 +1,7 @@
"""
TalkToEBM structures conversations in a generic message format that can be executed with different LLMs.
TalkToEBM structures conversations in a generic OpenAI message format that can be executed with different LLMs.
To use a custom LLM, simply implement AbstractChatModel.
We interface the LLM via the simple class AbstractChatModel. To use your own LLM, simply implement the chat_completion method in a subclass.
"""

from dataclasses import dataclass
@@ -119,7 +119,7 @@ def openai_setup(model: str, azure: bool = False, *args, **kwargs):


def setup(model: Union[AbstractChatModel, str]):
"""Setup for a chat model. If the input is a string, we assume that it is the name of an OpenAI model."""
"""Setup a chat model. If the input is a string, we assume that it is the name of an OpenAI model."""
if isinstance(model, str):
model = openai_setup(model)
return model
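
A minimal sketch of the custom-LLM path described by the revised module docstring. The exact chat_completion signature is not shown in this diff, so the arguments below are an assumption to be checked against t2ebm/llm.py:

from t2ebm.llm import AbstractChatModel, setup

class EchoChatModel(AbstractChatModel):
    # Assumed signature: OpenAI-style messages plus generation parameters.
    def chat_completion(self, messages, temperature, max_tokens):
        # Replace this stub with a call to your own model or endpoint.
        last_user_msg = next(
            m["content"] for m in reversed(messages) if m["role"] == "user"
        )
        return "(echo) " + last_user_msg

# setup() passes AbstractChatModel instances through unchanged; a string is
# treated as an OpenAI model name and routed through openai_setup, which
# requires an API key.
llm = setup(EchoChatModel())
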
32 changes: 24 additions & 8 deletions t2ebm/prompts.py
@@ -1,12 +1,19 @@
"""
Prompts that ask the LLM to perform tasks with Graphs and EBMs.
Functions either return a string or a sequene of messages / desired responses in the openai message format.
Functions either return a string or a sequence of messages / desired responses in the OpenAI message format.
"""


def graph_system_msg(expert_description="an expert statistician and data scientist"):
"""Instruct the LLM to work with the graphs of a GAM."""
"""A system message that instructs the LLM to work with the graphs of an EBM.
Args:
expert_description (str, optional): Description of the expert that we want the LLM to be. Defaults to "an expert statistician and data scientist".
Returns:
str: The system message.
"""
return f"You are {expert_description}. You interpret global explanations produced by a Generalized Additive Model (GAM). You answer all questions to the best of your ability, relying on the graphs provided by the user, any other information you are given, and your knowledge about the real world."


@@ -18,12 +25,15 @@ def describe_graph(
):
"""Prompt the LLM to describe a graph. This is intended to be the first prompt in a conversation about a graph.
:param task_description: The final user message that instructs the LLM (default: 'Please describe the general pattern of the graph.')
:param y_axis_description: description of the outcome
:param dataset_description: description of the dataset
Args:
graph (str): The graph to describe (in JSON format, obtained from graph_to_text).
graph_description (str, optional): Additional description of the graph (e.g. "The y-axis of the graph depicts the probability of success."). Defaults to "".
dataset_description (str, optional): Additional description of the dataset (e.g. "The dataset is a Pneumonia dataset collected by [...]"). Defaults to "".
task_description (str, optional): A final prompt to instruct the LLM. Defaults to "Please describe the general pattern of the graph.".
:return: str
"""
Returns:
str: The prompt to describe the graph.
"""
prompt = """Below is the graph of a Generalized Additive Model (GAM). The graph is presented as a JSON object with keys representing the x-axis and values representing the y-axis. For continuous features, the keys are intervals that represent ranges where the function predicts the same value. For categorical features, each key represents a possible value that the feature can take.
The graph is provided in the following format:
@@ -52,7 +62,8 @@
def describe_graph_cot(graph, num_sentences=7, **kwargs):
"""Use chain-of-thought reasoning to elicit a description of a graph in at most {num_sentences} sentences.
Return: messages in openai format.
Returns:
Messages in OpenAI format.
"""
return [
{"role": "system", "content": graph_system_msg()},
@@ -78,6 +89,11 @@ def summarize_ebm(
dataset_description="",
num_sentences: int = None,
):
"""Prompt the LLM to summarize a Generalized Additive Model (GAM).
Returns:
Messages in OpenAI format.
"""
messages = [
{
"role": "system",
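
A brief sketch of how these prompt builders might be called. The hand-written graph string is purely illustrative; in practice it comes from t2ebm.graphs.graph_to_text:

from t2ebm import prompts

# Toy stand-in for the output of graph_to_text.
graph_as_text = '{"(0.0, 10.0)": 0.12, "(10.0, 20.0)": -0.31}'

system_msg = prompts.graph_system_msg()  # a single system-message string
single_prompt = prompts.describe_graph(
    graph_as_text,
    graph_description="The y-axis depicts the contribution to the predicted outcome.",
    dataset_description="A pneumonia dataset (illustrative).",
)

# Chain-of-thought variant: returns a list of messages in the OpenAI format.
cot_messages = prompts.describe_graph_cot(graph_as_text, num_sentences=7)
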
2 changes: 1 addition & 1 deletion t2ebm/version.py
@@ -1 +1 @@
__version__ = "0.1.1"
__version__ = "0.1.2"
