Skip to content

Commit

Permalink
graph docs
Browse files Browse the repository at this point in the history
  • Loading branch information
sbordt committed Jun 21, 2024
1 parent 47eaa2e commit 637c1f5
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 34 deletions.
18 changes: 9 additions & 9 deletions docs/api_reference.rst
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
API Reference
=============

High-level API
High-Level API
--------------

.. automodule:: t2ebm
:members: describe_graph, describe_ebm, feature_importances_to_text,
:members: describe_graph, describe_ebm, feature_importances_to_text
:show-inheritance:

Graphs
------
Extract graphs from EBM's and convert them to text
--------------------------------------------------

.. automodule:: t2ebm.graphs
:members: EBMGraph, extract_graph
:members: EBMGraph, extract_graph, simplify_graph, plot_graph, graph_to_text, text_to_graph
:show-inheritance:

Prompts
-------
Prompt templates
----------------

.. automodule:: t2ebm.prompts
:members: graph_system_msg, describe_graph, describe_graph_cot, summarize_ebm
:show-inheritance:

Interace to the LLM
-------------------
Interface to the LLM
--------------------

.. automodule:: t2ebm.llm
:members: AbstractChatModel, OpenAIChatModel, openai_setup, setup, chat_completion
Expand Down
68 changes: 43 additions & 25 deletions t2ebm/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
###################################################################################################


# a low-level datastructure for the graphs of explainable boosting machines
@dataclass
class EBMGraph:
"""A datastructure for the graphs of Explainable Boosting Machines.
"""
feature_name: str
feature_type: str
x_vals: typing.List[
Expand All @@ -30,23 +31,28 @@ class EBMGraph:

def extract_graph(
ebm,
feature_index,
feature_index :int,
normalization="none",
use_feature_bounds=True,
):
"""Extract a graph from an Explainable Boosting Machine.
"""Extract the graph of a feature from an Explainable Boosting Machine.
This is a low-level function. It does not return the final format in which the graph is presented to the LLM.
The purpose of this function is to extract the graph from the interals of the EBM and return it in a format that is easy to work with.
:param ebm:
:param feature_index:
:param normalization: how to normalize the graph. possible values are: 'mean', 'min', 'none'
:param use_feature_bounds: if True, the first and last bin are min and max value of the feature stored in the EBM. If false, the first and last value are -inf and inf, respectively.
:return: EBMGraph
"""
Args:
ebm (_type_): The EBM.
feature_index (int): The index of the feature in the EBM.
normalization (str, optional): How to normalize the graph (shift on the y-axis). Possible values are: 'mean', 'min', 'none'. Defaults to "none".
use_feature_bounds (bool, optional): If True, the first and last x-axis bins are the min and max values of the feature stored in the EBM. If false, the first and last value are -inf and inf, respectively. Defaults to True.
Raises:
Exception: If an error occurs.
Returns:
EBMGraph: The graph.
"""
# read the variables from the ebm
feature_name = ebm.feature_names_in_[feature_index]
feature_type = ebm.feature_types_in_[feature_index]
Expand Down Expand Up @@ -90,15 +96,13 @@ def extract_graph(
def simplify_graph(graph: EBMGraph, min_variation_per_cent: float = 0.0):
"""Simplifies a graph. Removes redundant (flat) bins from the graph.
With min_variation_per_cent>0 (default 0.0), the function simplifies the graph by removing bins
that correspond to a less that min_variation_per_cent change in the score, considering the overal min/max difference of score for the feature as 100%.
this can be useful to keep a query within the context limit. Empirically, removing changes of less than 2% simplifies graphs a lot
in terms of the number of bins/tokens, but visually we can hardly see the difference.
Args:
graph (EBMGraph): The graph.
min_variation_per_cent (float, optional): Parameter that controlls the degree of simplification. If min_variation_per_cent>0, the function simplifies the graph by removing bins that correspond to a less that min_variation_per_cent change in the score, considering the overal min/max difference of score for the feature as 100%. Defaults to 0.0.
:param bins:
:param scores:
:return: EBMGraph. A new simplified graph.
"""
Returns:
EBMGraph: The simplified graph.
"""
assert graph.feature_type == "continuous", "Can only simplify continuous graphs."
x_vals, scores, stds = graph.x_vals, graph.scores, graph.stds
total_variation = np.max(scores) - np.min(scores)
Expand Down Expand Up @@ -128,6 +132,11 @@ def simplify_graph(graph: EBMGraph, min_variation_per_cent: float = 0.0):


def plot_graph(graph: EBMGraph):
"""Plot a graph.
Args:
graph (EBMGraph): The graph.
"""
x_vals, scores, stds = graph.x_vals, graph.scores, graph.stds
if graph.feature_type == "continuous":
x, y, y_lower, y_upper = [], [], [], []
Expand Down Expand Up @@ -163,7 +172,7 @@ def plot_graph(graph: EBMGraph):


def xy_to_json_(x_vals, y_vals):
"""convert a sequence of x_vals and y_vals to a json string"""
"""Convert a sequence of x_vals and y_vals to a json string"""
# continuous features
if isinstance(x_vals[0], tuple):
return (
Expand All @@ -185,17 +194,26 @@ def graph_to_text(
confidence_level=0.95,
max_tokens=3000,
):
"""Convert a graph to a textual representation that can be passed to a LLM.
"""Convert an EBMGraph to text. This is the text that we then pass to the LLM.
This function takes care of all the different formatting issues that can arise in this process.
The function takes care of a variety of different formatting issues that can arise in the process of converting a graph to text.
The user can explicitly specify the format of the feature (continuous, cateorical, boolean), as well as the precision of the values on the x-axis and y-axis. If the user does not specify these values, the function will try to infer them from the graph.
Args:
graph (EBMGraph): _description_
include_description (bool, optional): Whether to include a short descriptive preamble that describes the graph to the LLM. Defaults to True.
feature_format (_type_, optional): The format of the feature (continuous, cateorical, boolean). Defaults to None (auto-detect).
x_axis_precision (_type_, optional): The precision of the values on the x-axis. Defaults to None.
y_axis_precision (str, optional): The precision of the values on the x-axis. Defaults to "auto".
confidence_bounds (bool, optional): Whether to inlcude confidence bounds. Defaults to True.
confidence_level (float, optional): The desired confidence level of the bounds. Defaults to 0.95.
max_tokens (int, optional): The maximum number of tokens that the textual description of the graph can have. The function simplifies the graph so to fit into this token limit. Defaults to 3000.
By default, this functions adds a short descriptive text that describes the graph to the LLM.
Raises:
Exception: If an error occurs.
This function simplifies the graph so that the textual description length is at most {max_tokens} GPT-4 tokens.
Returns:
str: The textual representation of the graph.
"""

# a simple auto-detect for boolean feautres
try:
if (
Expand Down Expand Up @@ -315,7 +333,7 @@ def graph_to_text(
raise Exception(
f"The graph for feature {graph.feature_name} of type"
f" {graph.feature_type} requires {total_tokens} tokens even at"
" a simplification level of 10\%. This graph is too complex to"
" a simplification level of 10%. This graph is too complex to"
" be passed to the LLM within the loken limit of"
f" {max_tokens} tokens."
)
Expand Down

0 comments on commit 637c1f5

Please sign in to comment.