From 9d8798ec7f27f734ccd3d49f0fa19dc62482fc29 Mon Sep 17 00:00:00 2001 From: Christina Xu Date: Mon, 18 Mar 2024 15:27:37 -0400 Subject: [PATCH] add plot method and other minor edits --- src/trustyai/utils/extras/metrics_service.py | 77 ++++++++++++++------ 1 file changed, 55 insertions(+), 22 deletions(-) diff --git a/src/trustyai/utils/extras/metrics_service.py b/src/trustyai/utils/extras/metrics_service.py index d78cb99..d824f79 100644 --- a/src/trustyai/utils/extras/metrics_service.py +++ b/src/trustyai/utils/extras/metrics_service.py @@ -5,6 +5,7 @@ import datetime as dt import pandas as pd import requests +import matplotlib.pyplot as plt from trustyai.utils.api.api import TrustyAIApi @@ -128,7 +129,7 @@ def print_name_mapping(self): f"{self.trusty_url}/info/names", json=payload, headers=self.headers, - verify=True, + verify=self.verify, timeout=timeout, ) if response.status_code == 200: @@ -182,27 +183,59 @@ def upload_data_to_model(self, model_name: str, json_file: str, timeout=5): return response.text raise RuntimeError(f"Error {response.status_code}: {response.reason}") - def get_metric_data( - self, namespace: str, metric: str, time_interval: List[str], timeout=5 - ): + def get_metric_data(self, metric: str, time_interval: List[str], timeout=5): """ - Retrives metric data for a specific range in time + Retrives metric data for a specific range in time for each subcategory in data field """ - params = {"query": f"{metric}{{namespace='{namespace}'}}{time_interval}"} - response = requests.get( - f"{self.thanos_url}/api/v1/query?", - params=params, - headers=self.headers, - verify=self.verify, - timeout=timeout, - ) - if response.status_code == 200: - data_dict = json.loads(response.text)["data"]["result"][0]["values"] - metric_df = pd.DataFrame(data_dict, columns=["timestamp", metric]) - metric_df["timestamp"] = metric_df["timestamp"].apply( - lambda epoch: dt.datetime.fromtimestamp(epoch).strftime( - "%Y-%m-%d %H:%M:%S" - ) + metric_df = pd.DataFrame() + for subcategory in list( + self.get_model_metadata()[0]["data"]["inputSchema"]["nameMapping"].values() + ): + params = { + "query": f"{metric}{{subcategory='{subcategory}'}}{time_interval}" + } + + response = requests.get( + f"{self.thanos_url}/api/v1/query?", + params=params, + headers=self.headers, + verify=self.verify, + timeout=timeout, ) - return metric_df - raise RuntimeError(f"Error {response.status_code}: {response.reason}") + if response.status_code == 200: + if "timestamp" in metric_df.columns: + pass + else: + metric_df["timestamp"] = [ + item[0] + for item in json.loads(response.text)["data"]["result"][0][ + "values" + ] + ] + metric_df[subcategory] = [ + item[1] + for item in json.loads(response.text)["data"]["result"][0]["values"] + ] + else: + raise RuntimeError(f"Error {response.status_code}: {response.reason}") + + metric_df["timestamp"] = metric_df["timestamp"].apply( + lambda epoch: dt.datetime.fromtimestamp(epoch).strftime("%Y-%m-%d %H:%M:%S") + ) + return metric_df + + @staticmethod + def plot_metric(metric_df: pd.DataFrame, metric: str): + """ + Plots a line for each subcategory in the pandas DataFrame returned by get_metric_request + with the timestamp on x-axis and specified metric on the y-axis + """ + plt.figure(figsize=(12, 5)) + for col in metric_df.columns[1:]: + plt.plot(metric_df["timestamp"], metric_df[col]) + plt.xlabel("timestamp") + plt.ylabel(metric) + plt.xticks(rotation=45) + plt.legend(metric_df.columns[1:]) + plt.tight_layout() + plt.show()