From 2f80f4dce0f680ed9ea4a2d40cb858dd142c2092 Mon Sep 17 00:00:00 2001
From: Wesley Gifford <79663411+wgifford@users.noreply.github.com>
Date: Fri, 15 Nov 2024 18:38:09 -0500
Subject: [PATCH] test case for fine-tuned model

---
 services/inference/tests/test_inference.py | 59 ++++++++++++++++++++++
 1 file changed, 59 insertions(+)

diff --git a/services/inference/tests/test_inference.py b/services/inference/tests/test_inference.py
index 1104d775..6ed2e8eb 100644
--- a/services/inference/tests/test_inference.py
+++ b/services/inference/tests/test_inference.py
@@ -15,6 +15,7 @@
     "ttm-r1": {"context_length": 512, "prediction_length": 96},
     "ttm-1024-96-r1": {"context_length": 1024, "prediction_length": 96},
     "ttm-r2": {"context_length": 512, "prediction_length": 96},
+    "ttm-r2-etth-finetuned": {"context_length": 512, "prediction_length": 96},
     "ttm-1024-96-r2": {"context_length": 1024, "prediction_length": 96},
     "ttm-1536-96-r2": {"context_length": 1536, "prediction_length": 96},
     "ibm/test-patchtst": {"context_length": 512, "prediction_length": 96},
@@ -434,6 +435,64 @@ def test_zero_shot_forecast_inference_no_timestamp(ts_data):
     print(df_out[0].head())
 
 
+@pytest.mark.parametrize(
+    "ts_data",
+    [
+        "ttm-r2-etth-finetuned",
+    ],
+    indirect=True,
+)
+def test_finetuned_model_inference(ts_data):
+    test_data, params = ts_data
+    id_columns = params["id_columns"]
+    model_id = params["model_id"]
+
+    prediction_length = 96
+
+    # test single
+    test_data_ = test_data[test_data[id_columns[0]] == "a"].copy()
+    encoded_data = encode_data(test_data_, params["timestamp_column"])
+
+    msg = {
+        "model_id": model_id,
+        "parameters": {
+            # "prediction_length": params["prediction_length"],
+        },
+        "schema": {
+            "timestamp_column": params["timestamp_column"],
+            "id_columns": params["id_columns"],
+            "target_columns": params["target_columns"],
+        },
+        "data": encoded_data,
+        "future_data": {},
+    }
+
+    out = get_inference_response(msg)
+    assert "Attempted to use a fine-tuned model with a different schema" in out.text
+
+    test_data_ = test_data_.drop(columns=params["timestamp_column"])
+    msg = {
+        "model_id": model_id,
+        "parameters": {
+            # "prediction_length": params["prediction_length"],
+        },
+        "schema": {
+            "timestamp_column": params["timestamp_column"],
+            "id_columns": params["id_columns"],
+            "target_columns": ["OT"],
+            "freq": "1h",
+            "conditional_columns": [c for c in params["target_columns"] if c != "OT"],
+        },
+        "data": encoded_data,
+        "future_data": {},
+    }
+
+    df_out = get_inference_response(msg)
+    assert len(df_out) == 1
+    assert df_out[0].shape[0] == prediction_length
+    print(df_out[0].head())
+
+
 @pytest.mark.parametrize(
     "ts_data",
     [