Skip to content

Commit

Permalink
repo-sync-2024-11-19T11:32:29+0800 (#124)
Browse files Browse the repository at this point in the history
  • Loading branch information
oeqqwq authored Nov 19, 2024
1 parent fe3231d commit 99b1c19
Show file tree
Hide file tree
Showing 94 changed files with 1,605 additions and 467 deletions.
50 changes: 19 additions & 31 deletions .ci/accuracy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,32 +372,20 @@ def exec(self, epsilon=0.0001):

if __name__ == "__main__":
AccuracyTestCase(
service_id="bin_onehot_glm",
service_id="glm",
parties=['alice', 'bob'],
case_dir='.ci/test_data/bin_onehot_glm',
case_dir='.ci/test_data/glm',
package_name='s_model.tar.gz',
input_csv_names={'alice': 'alice.csv', 'bob': 'bob.csv'},
expect_csv_name='predict.csv',
query_ids=['1', '2', '3', '4', '5', '6', '7', '8', '9', '15'],
score_col_name='pred',
).exec()

AccuracyTestCase(
service_id="bin_onehot_glm_alice_no_feature",
parties=['alice', 'bob'],
case_dir='.ci/test_data/bin_onehot_glm_alice_no_feature',
package_name='s_model.tar.gz',
input_csv_names={'alice': 'alice.csv', 'bob': 'bob.csv'},
expect_csv_name='predict.csv',
query_ids=['1', '2', '3', '4', '5', '6', '7', '8', '9', '15'],
score_col_name='pred',
use_oss=True,
).exec()
score_col_name='pred_y',
).exec(0.1)

AccuracyTestCase(
service_id="bin_sgb",
service_id="sgb",
parties=['alice', 'bob'],
case_dir='.ci/test_data/bin_sgb',
case_dir='.ci/test_data/sgb',
package_name='s_model.tar.gz',
input_csv_names={'alice': 'alice.csv', 'bob': 'bob.csv'},
expect_csv_name='predict.csv',
Expand All @@ -406,9 +394,9 @@ def exec(self, epsilon=0.0001):
).exec()

AccuracyTestCase(
service_id="bin_sgb_alice_no_feature",
service_id="sgb_fetures_in_one_party",
parties=['alice', 'bob'],
case_dir='.ci/test_data/bin_sgb_alice_no_feature',
case_dir='.ci/test_data/fetures_in_one_party/sgb',
package_name='s_model.tar.gz',
input_csv_names={'alice': 'alice.csv', 'bob': 'bob.csv'},
expect_csv_name='predict.csv',
Expand All @@ -428,9 +416,9 @@ def exec(self, epsilon=0.0001):
).exec()

AccuracyTestCase(
service_id="xgb_alice_no_feature",
service_id="xgb_fetures_in_one_party",
parties=['alice', 'bob'],
case_dir='.ci/test_data/xgb_alice_no_feature',
case_dir='.ci/test_data/fetures_in_one_party/xgb',
package_name='s_model.tar.gz',
input_csv_names={'alice': 'alice.csv', 'bob': 'bob.csv'},
expect_csv_name='predict.csv',
Expand All @@ -450,9 +438,9 @@ def exec(self, epsilon=0.0001):
).exec()

AccuracyTestCase(
service_id="sgd_alice_no_feature",
service_id="sgd_fetures_in_one_party",
parties=['alice', 'bob'],
case_dir='.ci/test_data/sgd_alice_no_feature',
case_dir='.ci/test_data/fetures_in_one_party/sgd',
package_name='s_model.tar.gz',
input_csv_names={'alice': 'alice.csv', 'bob': 'bob.csv'},
expect_csv_name='predict.csv',
Expand All @@ -461,9 +449,9 @@ def exec(self, epsilon=0.0001):
).exec()

AccuracyTestCase(
service_id="phe_sgd",
service_id="ou_sgd",
parties=['alice', 'bob'],
case_dir='.ci/test_data/phe_sgd',
case_dir='.ci/test_data/ou_sgd',
package_name='s_model.tar.gz',
input_csv_names={'alice': 'alice.csv', 'bob': 'bob.csv'},
expect_csv_name='predict.csv',
Expand All @@ -472,9 +460,9 @@ def exec(self, epsilon=0.0001):
).exec()

AccuracyTestCase(
service_id="phe_sgd_no_feature",
service_id="ou_sgd_fetures_in_one_party",
parties=['alice', 'bob'],
case_dir='.ci/test_data/phe_sgd_no_feature',
case_dir='.ci/test_data/fetures_in_one_party/ou_sgd',
package_name='s_model.tar.gz',
input_csv_names={'alice': 'alice.csv', 'bob': 'bob.csv'},
expect_csv_name='predict.csv',
Expand All @@ -483,12 +471,12 @@ def exec(self, epsilon=0.0001):
).exec()

AccuracyTestCase(
service_id="phe_glm",
service_id="ou_glm",
parties=['alice', 'bob'],
case_dir='.ci/test_data/phe_glm',
case_dir='.ci/test_data/ou_glm',
package_name='s_model.tar.gz',
input_csv_names={'alice': 'alice.csv', 'bob': 'bob.csv'},
expect_csv_name='predict.csv',
query_ids=['1', '2', '3', '4', '5', '6', '7', '8', '9', '15'],
score_col_name='predict_score',
score_col_name='pred_y',
).exec(0.1)
19 changes: 14 additions & 5 deletions .ci/inferencer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,13 @@ async def run_process(command):
)


async def run_inferencer_example(exmaple_dir: str, result_path: str, target_path: str):
async def run_inferencer_example(
exmaple_dir: str,
result_path: str,
target_path: str,
target_col_name: str,
epsilon=0.0001,
):
print(f"====begin example: {exmaple_dir}=====")

with resources.path(
Expand All @@ -64,11 +70,11 @@ async def run_inferencer_example(exmaple_dir: str, result_path: str, target_path
target_df = pd.read_csv(target_path)

score_col = result_df['score']
pred_col = target_df['pred']
pred_col = target_df[target_col_name]

assert len(score_col) == len(pred_col)

are_close = np.isclose(score_col, pred_col, atol=0.0001)
are_close = np.isclose(score_col, pred_col, atol=epsilon)

for i, match in enumerate(are_close):
assert match, f"row {i} mismatch: {score_col[i]} != {pred_col[i]}"
Expand All @@ -79,14 +85,17 @@ async def run_inferencer_example(exmaple_dir: str, result_path: str, target_path
run_inferencer_example(
"secretflow_serving/tools/inferencer/example/normal",
"tmp/alice/score.csv",
".ci/test_data/bin_onehot_glm/predict.csv",
".ci/test_data/glm/predict.csv",
"pred_y",
0.01,
)
)

asyncio.run(
run_inferencer_example(
"secretflow_serving/tools/inferencer/example/one_party_no_feature",
"tmp/bob/score.csv",
".ci/test_data/bin_onehot_glm_alice_no_feature/predict.csv",
".ci/test_data/fetures_in_one_party/sgd/predict.csv",
"pred",
)
)
20 changes: 0 additions & 20 deletions .ci/test_data/bin_onehot_glm/alice/alice.csv

This file was deleted.

Binary file removed .ci/test_data/bin_onehot_glm/alice/s_model.tar.gz
Binary file not shown.
20 changes: 0 additions & 20 deletions .ci/test_data/bin_onehot_glm/bob/bob.csv

This file was deleted.

Binary file removed .ci/test_data/bin_onehot_glm/bob/s_model.tar.gz
Binary file not shown.
20 changes: 0 additions & 20 deletions .ci/test_data/bin_onehot_glm/predict.csv

This file was deleted.

33 changes: 0 additions & 33 deletions .ci/test_data/bin_onehot_glm_alice_no_feature/alice/alice.csv

This file was deleted.

Binary file not shown.
Loading

0 comments on commit 99b1c19

Please sign in to comment.