Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gsm8k bootstrap #10

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 4 additions & 61 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,62 +1,5 @@
aiohttp==3.9.5
aiosignal==1.3.1
annotated-types==0.6.0
anthropic==0.25.4
anyio==4.3.0
asttokens==2.4.1
attrs==23.2.0
black==24.4.0
certifi==2024.2.2
charset-normalizer==3.3.2
click==8.1.7
distro==1.9.0
docstring-parser==0.15
executing==2.0.1
filelock==3.13.4
frozenlist==1.4.1
fsspec==2024.3.1
h11==0.14.0
httpcore==1.0.5
httpx==0.27.0
huggingface-hub==0.22.2
idna==3.7
iniconfig==2.0.0
inline-snapshot==0.8.0
instructor==1.4.0
langsmith==0.1.48
markdown-it-py==3.0.0
mdurl==0.1.2
multidict==6.0.5
mypy-extensions==1.0.0
numpy==1.26.4
openai==1.40.1
orjson==3.10.1
packaging==24.0
pandas==2.2.2
pathspec==0.12.1
platformdirs==4.2.0
pluggy==1.4.0
pydantic==2.7.0
pydantic-core==2.18.1
pygments==2.17.2
pytest==8.1.1
pytest-asyncio==0.23.6
pytest-asyncio-cooperative==0.36.0
python-dateutil==2.9.0.post0
pytz==2024.1
pyyaml==6.0.1
requests==2.31.0
rich==13.7.1
shellingham==1.5.4
six==1.16.0
sniffio==1.3.1
tenacity==8.2.3
tokenizers==0.19.1
toml==0.10.2
tqdm==4.66.2
typer==0.12.3
types-toml==0.10.8.20240310
typing-extensions==4.11.0
tzdata==2024.1
urllib3==2.2.1
yarl==1.9.4
instructor[anthropic]==1.4.0
matplotlib==3.9.2
seaborn==0.13.2
scipy==1.14.1
123 changes: 123 additions & 0 deletions scripts/visualise_gsm8k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import numpy as np
import os
import json
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from scipy import stats
from glob import glob


def extract_scores(json_path: str):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets load from braintrust?

with open(json_path, "r") as file:
data = json.load(file)
for item in data:
yield int(item["scores"]["ExactMatch"])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bool?



def bootstrap_batch(data, num_samples, sample_size):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sample_Size should be len(data)

data = list(data)
return [
np.mean(np.random.choice(data, size=sample_size, replace=True))
for _ in range(num_samples)
]


def generate_kde_plot(results: list[dict], visualisation_path: str, sample_size: int):
accuracies = [item["scores"] for item in results]
kdes = [stats.gaussian_kde(accuracy) for accuracy in accuracies]

x_range = np.linspace(
min([min(accuracy) for accuracy in accuracies]),
max([max(accuracy) for accuracy in accuracies]),
100,
).tolist()
kdes = [kde(x_range) for kde in kdes]

# Plot the KDEs
plt.figure(figsize=(10, 6))
for i, (kde, item) in enumerate(zip(kdes, results)):
file_name = os.path.basename(item["file_name"])
plt.plot(
x_range,
kde,
label="Correct Answer" if "correct-answer.json" in file_name else "Answer",
)
plt.title(f"Kernel Density Estimation for bootstrap sample size {sample_size}")
plt.xlabel("Accuracy")
plt.ylabel("Density")
plt.legend()
plt.grid(True)
plt.savefig(visualisation_path)


def generate_boxplot(results: list[dict], visualisation_path: str):
df = pd.DataFrame(
[
{"File": item["file_name"], "Accuracy": score}
for item in results
for score in item["scores"]
]
)
# Create a mapping for the labels
df["File"] = df["File"].apply(
lambda x: "Correct Answer" if "correct-answer.json" in x else "Answer"
)

# Sort the dataframe to ensure consistent ordering
df = df.sort_values("File")

plt.figure(figsize=(10, 6))
sns.set_style("whitegrid")

# Create the box plot
sns.boxplot(x="File", y="Accuracy", data=df)

# Customize the plot
plt.title("Accuracy Distribution by File", fontsize=16)
plt.xlabel("File", fontsize=12)
plt.ylabel("Accuracy", fontsize=12)
plt.xticks(rotation=45)

# Add individual data points
sns.stripplot(x="File", y="Accuracy", data=df, color="black", size=4, alpha=0.5)

# Adjust layout and save the plot
plt.tight_layout()
plt.savefig(visualisation_path)


def compute_statistics(results: list[dict]):
for item in results:
mean = np.mean(item["scores"])
std = np.std(item["scores"])
var = np.var(item["scores"])
yield {
"file_name": item["file_name"],
"mean": mean,
"std": std,
"var": var,
}


if __name__ == "__main__":
BOOTSTRAP_SAMPLES = 1000
BOOTSTRAP_SAMPLE_SIZE = 200
RESULTS_FILE = "./bootstrap_results.jsonl"
DATA_DIR = "./scripts/data/raw"
np.random.seed(42)

# Read in the .json files and transform them into an input-output pair
results = [
{
"file_name": f,
"scores": bootstrap_batch(
extract_scores(f), BOOTSTRAP_SAMPLES, BOOTSTRAP_SAMPLE_SIZE
),
}
for f in glob(f"{DATA_DIR}/*.json")
]

generate_boxplot(results, "./boxplot.png")
generate_kde_plot(results, "./kde.png", BOOTSTRAP_SAMPLE_SIZE)
print(pd.DataFrame(compute_statistics(results)))