Skip to content

Commit

Permalink
add plot method and other minor edits
Browse files Browse the repository at this point in the history
  • Loading branch information
Christina Xu committed Mar 18, 2024
1 parent 7a2621a commit e2dc9dc
Showing 1 changed file with 56 additions and 22 deletions.
78 changes: 56 additions & 22 deletions src/trustyai/utils/extras/metrics_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import datetime as dt
import pandas as pd
import requests
import matplotlib.pyplot as plt

from trustyai.utils.api.api import TrustyAIApi

Expand Down Expand Up @@ -128,7 +129,7 @@ def print_name_mapping(self):
f"{self.trusty_url}/info/names",
json=payload,
headers=self.headers,
verify=True,
verify=self.verify,
timeout=timeout,
)
if response.status_code == 200:
Expand Down Expand Up @@ -182,27 +183,60 @@ def upload_data_to_model(self, model_name: str, json_file: str, timeout=5):
return response.text
raise RuntimeError(f"Error {response.status_code}: {response.reason}")

def get_metric_data(
self, namespace: str, metric: str, time_interval: List[str], timeout=5
):
def get_metric_data(self, metric: str, time_interval: List[str], timeout=5):
"""
Retrives metric data for a specific range in time
Retrives metric data for a specific range in time for each subcategory in data field
"""
params = {"query": f"{metric}{{namespace='{namespace}'}}{time_interval}"}
response = requests.get(
f"{self.thanos_url}/api/v1/query?",
params=params,
headers=self.headers,
verify=self.verify,
timeout=timeout,
)
if response.status_code == 200:
data_dict = json.loads(response.text)["data"]["result"][0]["values"]
metric_df = pd.DataFrame(data_dict, columns=["timestamp", metric])
metric_df["timestamp"] = metric_df["timestamp"].apply(
lambda epoch: dt.datetime.fromtimestamp(epoch).strftime(
"%Y-%m-%d %H:%M:%S"
)
metric_df = pd.DataFrame()
for subcategory in list(
self.get_model_metadata()[0]["data"]["inputSchema"]["nameMapping"].values()
):
params = {
"query": f"{metric}{{subcategory='{subcategory}'}}{time_interval}"
}

response = requests.get(
f"{self.thanos_url}/api/v1/query?",
params=params,
headers=self.headers,
verify=self.verify,
timeout=timeout,
)
return metric_df
raise RuntimeError(f"Error {response.status_code}: {response.reason}")
if response.status_code == 200:
if "timestamp" in metric_df.columns:
pass
else:
metric_df["timestamp"] = [
item[0]
for item in json.loads(response.text)["data"]["result"][0][
"values"
]
]
metric_df[subcategory] = [
item[1]
for item in json.loads(response.text)["data"]["result"][0]["values"]
]
else:
raise RuntimeError(f"Error {response.status_code}: {response.reason}")

metric_df["timestamp"] = metric_df["timestamp"].apply(
lambda epoch: dt.datetime.fromtimestamp(epoch).strftime("%Y-%m-%d %H:%M:%S")
)
return metric_df


@staticmethod
def plot_metric(metric_df: pd.DataFrame, metric: str):
"""
Plots a line for each subcategory in the pandas DataFrame returned by get_metric_request
with the timestamp on x-axis and specified metric on the y-axis
"""
plt.figure(figsize=(12, 5))
for col in metric_df.columns[1:]:
plt.plot(metric_df["timestamp"], metric_df[col])
plt.xlabel("timestamp")
plt.ylabel(metric)
plt.xticks(rotation=45)
plt.legend(metric_df.columns[1:])
plt.tight_layout()
plt.show()

0 comments on commit e2dc9dc

Please sign in to comment.