Skip to content

Commit

Permalink
test case for fine-tuned model
Browse files Browse the repository at this point in the history
  • Loading branch information
wgifford committed Nov 15, 2024
1 parent 24462b1 commit 2f80f4d
Showing 1 changed file with 59 additions and 0 deletions.
59 changes: 59 additions & 0 deletions services/inference/tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"ttm-r1": {"context_length": 512, "prediction_length": 96},
"ttm-1024-96-r1": {"context_length": 1024, "prediction_length": 96},
"ttm-r2": {"context_length": 512, "prediction_length": 96},
"ttm-r2-etth-finetuned": {"context_length": 512, "prediction_length": 96},
"ttm-1024-96-r2": {"context_length": 1024, "prediction_length": 96},
"ttm-1536-96-r2": {"context_length": 1536, "prediction_length": 96},
"ibm/test-patchtst": {"context_length": 512, "prediction_length": 96},
Expand Down Expand Up @@ -434,6 +435,64 @@ def test_zero_shot_forecast_inference_no_timestamp(ts_data):
print(df_out[0].head())


@pytest.mark.parametrize(
"ts_data",
[
"ttm-r2-etth-finetuned",
],
indirect=True,
)
def test_finetuned_model_inference(ts_data):
test_data, params = ts_data
id_columns = params["id_columns"]
model_id = params["model_id"]

prediction_length = 96

# test single
test_data_ = test_data[test_data[id_columns[0]] == "a"].copy()
encoded_data = encode_data(test_data_, params["timestamp_column"])

msg = {
"model_id": model_id,
"parameters": {
# "prediction_length": params["prediction_length"],
},
"schema": {
"timestamp_column": params["timestamp_column"],
"id_columns": params["id_columns"],
"target_columns": params["target_columns"],
},
"data": encoded_data,
"future_data": {},
}

out = get_inference_response(msg)
assert "Attempted to use a fine-tuned model with a different schema" in out.text

test_data_ = test_data_.drop(columns=params["timestamp_column"])
msg = {
"model_id": model_id,
"parameters": {
# "prediction_length": params["prediction_length"],
},
"schema": {
"timestamp_column": params["timestamp_column"],
"id_columns": params["id_columns"],
"target_columns": ["OT"],
"freq": "1h",
"conditional_columns": [c for c in params["target_columns"] if c != "OT"],
},
"data": encoded_data,
"future_data": {},
}

df_out = get_inference_response(msg)
assert len(df_out) == 1
assert df_out[0].shape[0] == prediction_length
print(df_out[0].head())


@pytest.mark.parametrize(
"ts_data",
[
Expand Down

0 comments on commit 2f80f4d

Please sign in to comment.