diff --git a/test/int/nnc/parallel.tests.c b/test/int/nnc/parallel.tests.c index ae49d605b..2dde11328 100644 --- a/test/int/nnc/parallel.tests.c +++ b/test/int/nnc/parallel.tests.c @@ -166,10 +166,10 @@ TEST_CASE("schedule symbolic graph to data parallel with broadcast and reduce") ccv_nnc_graph_free(graph); ccv_nnc_tensor_arena_free(tensor_arena); ccv_nnc_graph_exec_arena_free(graph_exec_arena); - REQUIRE_TENSOR_EQ(np_updated[0], updated[0], "updated params should be equal"); - REQUIRE_TENSOR_EQ(np_updated[1], updated[1], "updated params should be equal"); - REQUIRE_TENSOR_EQ(np_updated[2], updated[2], "updated params should be equal"); - REQUIRE_TENSOR_EQ(np_updated[3], updated[3], "updated params should be equal"); + REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, np_updated[0]->data.f32, updated[0]->data.f32, 8 * 3 * 5 * 5, 1e-4, "updated params should be equal"); + REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, np_updated[1]->data.f32, updated[1]->data.f32, 8, 1e-5, "updated params should be equal"); + REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, np_updated[2]->data.f32, updated[2]->data.f32, 8 * 8 * 5 * 5, 1e-4, "updated params should be equal"); + REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, np_updated[3]->data.f32, updated[3]->data.f32, 8, 1e-4, "updated params should be equal"); ccv_nnc_tensor_free(cpu_input); ccv_nnc_tensor_free(cpu_fit); ccv_nnc_tensor_free(np_updated[0]); @@ -345,10 +345,10 @@ TEST_CASE("schedule symbolic graph to data parallel with allreduce") ccv_nnc_graph_free(graph); ccv_nnc_tensor_arena_free(tensor_arena); ccv_nnc_graph_exec_arena_free(graph_exec_arena); - REQUIRE_TENSOR_EQ(np_updated[0], updated[0], "updated params should be equal"); - REQUIRE_TENSOR_EQ(np_updated[1], updated[1], "updated params should be equal"); - REQUIRE_TENSOR_EQ(np_updated[2], updated[2], "updated params should be equal"); - REQUIRE_TENSOR_EQ(np_updated[3], updated[3], "updated params should be equal"); + REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, np_updated[0]->data.f32, updated[0]->data.f32, 8 * 3 * 5 * 5, 1e-4, "updated params should be equal"); + REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, np_updated[1]->data.f32, updated[1]->data.f32, 8, 1e-5, "updated params should be equal"); + REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, np_updated[2]->data.f32, updated[2]->data.f32, 8 * 8 * 5 * 5, 1e-4, "updated params should be equal"); + REQUIRE_ARRAY_EQ_WITH_TOLERANCE(float, np_updated[3]->data.f32, updated[3]->data.f32, 8, 1e-4, "updated params should be equal"); ccv_nnc_tensor_free(cpu_input); ccv_nnc_tensor_free(cpu_fit); ccv_nnc_tensor_free(np_updated[0]);