From ba586aed2093e71ec430bc22cf8c13c8d1e3100d Mon Sep 17 00:00:00 2001 From: Marsella8 <45826022+Marsella8@users.noreply.github.com> Date: Thu, 18 Jul 2024 17:38:22 -0700 Subject: [PATCH] Implementations for methods for machine_views and associated modules (#1429) * initial commit for machine view adjacent modules * Formatting * Tests for new machine_view.cc functions * formatting * Minor Test correction * formatting * PR fixes * PR Fixes --------- Co-authored-by: Pietro Max Marsella --- lib/pcg/include/pcg/device_id.h | 2 +- lib/pcg/include/pcg/machine_view.h | 19 +++++- lib/pcg/include/pcg/strided_rectangle.h | 5 +- lib/pcg/src/pcg/machine_view.cc | 75 +++++++++++++++++++--- lib/pcg/src/pcg/strided_rectangle_side.cc | 6 +- lib/pcg/src/strided_rectangle.cc | 17 +++-- lib/pcg/test/src/test_machine_view.cc | 74 +++++++++++++++++++++ lib/pcg/test/src/test_strided_rectangle.cc | 37 +++++++++++ 8 files changed, 216 insertions(+), 19 deletions(-) create mode 100644 lib/pcg/test/src/test_machine_view.cc create mode 100644 lib/pcg/test/src/test_strided_rectangle.cc diff --git a/lib/pcg/include/pcg/device_id.h b/lib/pcg/include/pcg/device_id.h index be92be7081..1157a2932a 100644 --- a/lib/pcg/include/pcg/device_id.h +++ b/lib/pcg/include/pcg/device_id.h @@ -10,7 +10,7 @@ namespace FlexFlow { device_id_t operator+(device_id_t, size_t); -DeviceType get_device_type(device_id_t); +DeviceType get_device_type(device_id_t const &device_id); gpu_id_t unwrap_gpu(device_id_t); cpu_id_t unwrap_cpu(device_id_t); diff --git a/lib/pcg/include/pcg/machine_view.h b/lib/pcg/include/pcg/machine_view.h index 625b128d35..56abf5aa20 100644 --- a/lib/pcg/include/pcg/machine_view.h +++ b/lib/pcg/include/pcg/machine_view.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_PCG_INCLUDE_PCG_MACHINE_VIEW_H #include "pcg/cpu_id_t.dtg.h" +#include "pcg/device_id.h" #include "pcg/device_id_t.dtg.h" #include "pcg/device_type.dtg.h" #include "pcg/gpu_id_t.dtg.h" @@ -14,15 +15,31 @@ namespace FlexFlow { std::vector device_ids(MachineView const &); -std::size_t num_dims(MachineView const &); +size_t num_dims(MachineView const &); std::size_t num_devices(MachineView const &); DeviceType get_device_type(MachineView const &); MachineView make_1d_machine_view(gpu_id_t start, gpu_id_t stop, int stride = 1); MachineView make_1d_machine_view(cpu_id_t start, cpu_id_t stop, int stride = 1); +MachineView + make_1d_machine_view(device_id_t start, device_id_t stop, int stride = 1); + +MachineView make_1d_machine_view(gpu_id_t start, + num_points_t num_points, + int stride = 1); +MachineView make_1d_machine_view(cpu_id_t start, + num_points_t num_points, + int stride = 1); MachineView make_1d_machine_view(device_id_t start, num_points_t num_points, int stride = 1); + +MachineView make_1d_machine_view(gpu_id_t start, + side_size_t interval_size, + int stride = 1); +MachineView make_1d_machine_view(cpu_id_t start, + side_size_t interval_size, + int stride = 1); MachineView make_1d_machine_view(device_id_t start, side_size_t interval_size, int stride = 1); diff --git a/lib/pcg/include/pcg/strided_rectangle.h b/lib/pcg/include/pcg/strided_rectangle.h index 24ae51ac41..9c3b8eeda9 100644 --- a/lib/pcg/include/pcg/strided_rectangle.h +++ b/lib/pcg/include/pcg/strided_rectangle.h @@ -8,8 +8,9 @@ namespace FlexFlow { size_t get_num_dims(StridedRectangle const &); -StridedRectangleSide get_side_at_idx(StridedRectangle const &, - ff_dim_t const &); +StridedRectangleSide get_side_at_idx(StridedRectangle const &rect, + ff_dim_t const &idx); +num_points_t get_num_points(StridedRectangle const &rect); } // namespace FlexFlow diff --git a/lib/pcg/src/pcg/machine_view.cc b/lib/pcg/src/pcg/machine_view.cc index 00bf1296fe..c09ab1a3c9 100644 --- a/lib/pcg/src/pcg/machine_view.cc +++ b/lib/pcg/src/pcg/machine_view.cc @@ -1,5 +1,7 @@ #include "pcg/machine_view.h" +#include "pcg/device_id.h" #include "pcg/strided_rectangle.dtg.h" +#include "pcg/strided_rectangle.h" #include "pcg/strided_rectangle_side.h" namespace FlexFlow { @@ -8,16 +10,16 @@ std::vector device_ids(MachineView const &) { NOT_IMPLEMENTED(); } -std::size_t num_dims(MachineView const &) { - NOT_IMPLEMENTED(); +std::size_t num_dims(MachineView const &mv) { + return get_num_dims(mv.rect); } -std::size_t num_devices(MachineView const &) { - NOT_IMPLEMENTED(); +size_t num_devices(MachineView const &mv) { + return get_num_points(mv.rect).unwrapped; } -DeviceType get_device_type(MachineView const &) { - NOT_IMPLEMENTED(); +DeviceType get_device_type(MachineView const &mv) { + return get_device_type(mv.start); } static StridedRectangle make_1d_rect(int start, int stop, int stride) { @@ -40,18 +42,73 @@ MachineView make_1d_machine_view(cpu_id_t start, cpu_id_t stop, int stride) { return MachineView{device_id_t{start}, rect}; } +MachineView + make_1d_machine_view(device_id_t start, device_id_t stop, int stride) { + assert(get_device_type(start) == get_device_type(stop)); + if (get_device_type(start) == DeviceType::CPU) { + return make_1d_machine_view(unwrap_cpu(start), unwrap_cpu(stop), stride); + } + assert(get_device_type(start) == DeviceType::GPU); + return make_1d_machine_view(unwrap_gpu(start), unwrap_gpu(stop), stride); +} + +static StridedRectangle + make_1d_rect(int start, num_points_t num_points, int stride) { + return make_1d_rect(start, start + num_points.unwrapped * stride, stride); +} + +MachineView + make_1d_machine_view(cpu_id_t start, num_points_t num_points, int stride) { + StridedRectangle rect = make_1d_rect(start.cpu_index, num_points, stride); + return MachineView{device_id_t{start}, rect}; +} + +MachineView + make_1d_machine_view(gpu_id_t start, num_points_t num_points, int stride) { + StridedRectangle rect = make_1d_rect(start.gpu_index, num_points, stride); + return MachineView{device_id_t{start}, rect}; +} + MachineView make_1d_machine_view(device_id_t start, num_points_t num_points, int stride) { - NOT_IMPLEMENTED(); + if (get_device_type(start) == DeviceType::CPU) { + return make_1d_machine_view(unwrap_cpu(start), num_points, stride); + } else { + assert(get_device_type(start) == DeviceType::GPU); + return make_1d_machine_view(unwrap_gpu(start), num_points, stride); + } } -MachineView make_1d_machine_view(device_id_t start, +static StridedRectangle + make_1d_rect(int start, side_size_t interval_size, int stride) { + return make_1d_rect(start, start + interval_size.unwrapped, stride); +} + +MachineView make_1d_machine_view(cpu_id_t start, side_size_t interval_size, int stride) { - NOT_IMPLEMENTED(); + StridedRectangle rect = make_1d_rect(start.cpu_index, interval_size, stride); + return MachineView{device_id_t{start}, rect}; +} + +MachineView make_1d_machine_view(gpu_id_t start, + side_size_t interval_size, + int stride) { + StridedRectangle rect = make_1d_rect(start.gpu_index, interval_size, stride); + return MachineView{device_id_t{start}, rect}; } +MachineView make_1d_machine_view(device_id_t start, + side_size_t interval_size, + int stride) { + if (get_device_type(start) == DeviceType::CPU) { + return make_1d_machine_view(unwrap_cpu(start), interval_size, stride); + } else { + assert(get_device_type(start) == DeviceType::GPU); + return make_1d_machine_view(unwrap_gpu(start), interval_size, stride); + } +} MachineView make_1d_machine_view(device_id_t start, size_t interval_size) { NOT_IMPLEMENTED(); } diff --git a/lib/pcg/src/pcg/strided_rectangle_side.cc b/lib/pcg/src/pcg/strided_rectangle_side.cc index 5e7274141d..e6caf4cb86 100644 --- a/lib/pcg/src/pcg/strided_rectangle_side.cc +++ b/lib/pcg/src/pcg/strided_rectangle_side.cc @@ -3,9 +3,11 @@ namespace FlexFlow { -StridedRectangleSide strided_side_from_size_and_stride(side_size_t, +StridedRectangleSide strided_side_from_size_and_stride(side_size_t side_size, int stride) { - NOT_IMPLEMENTED(); + assert((side_size.unwrapped % stride) == 0); + return StridedRectangleSide{num_points_t{side_size.unwrapped / stride}, + stride}; } side_size_t get_side_size(StridedRectangleSide const &s) { diff --git a/lib/pcg/src/strided_rectangle.cc b/lib/pcg/src/strided_rectangle.cc index 9c8ff69b42..1c61424ab9 100644 --- a/lib/pcg/src/strided_rectangle.cc +++ b/lib/pcg/src/strided_rectangle.cc @@ -1,4 +1,5 @@ #include "pcg/strided_rectangle.h" +#include "op-attrs/dim_ordered/transform.h" #include "utils/containers.h" namespace FlexFlow { @@ -15,12 +16,20 @@ namespace FlexFlow { /* return idx; */ /* } */ -size_t get_num_dims(StridedRectangle const &) { - NOT_IMPLEMENTED(); +size_t get_num_dims(StridedRectangle const &rect) { + return rect.sides.size(); } -size_t get_side_at_idx(StridedRectangle const &) { - NOT_IMPLEMENTED(); +num_points_t get_num_points(StridedRectangle const &rect) { + return num_points_t{ + product(transform(rect.sides, [](StridedRectangleSide const &side) { + return side.num_points.unwrapped; + }))}; +} + +StridedRectangleSide get_side_at_idx(StridedRectangle const &rect, + ff_dim_t const &idx) { + return rect.sides.at(idx); } } // namespace FlexFlow diff --git a/lib/pcg/test/src/test_machine_view.cc b/lib/pcg/test/src/test_machine_view.cc new file mode 100644 index 0000000000..92a96d5e9a --- /dev/null +++ b/lib/pcg/test/src/test_machine_view.cc @@ -0,0 +1,74 @@ +#include "doctest/doctest.h" +#include "pcg/machine_view.h" +#include "pcg/strided_rectangle.h" +#include "pcg/strided_rectangle_side.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("MachineView general util functions") { + StridedRectangle rect{{StridedRectangleSide{num_points_t{7}, 5}, + StridedRectangleSide{num_points_t{10}, 2}}}; + gpu_id_t start(1); + MachineView mv{device_id_t{start}, rect}; + SUBCASE("num_dims") { + CHECK(num_dims(mv) == 2); + } + SUBCASE("num_devices") { + CHECK(num_devices(mv) == 7 * 10); + } + SUBCASE("get_device_type") { + CHECK(get_device_type(mv) == DeviceType::GPU); + } + } + + TEST_CASE("MachineView make_1d_machine_view - GPU") { + StridedRectangle rect{{StridedRectangleSide{num_points_t{7}, 5}}}; + device_id_t start_gpu{gpu_id_t{1}}; + MachineView gpu_mv{start_gpu, rect}; + + SUBCASE("make_1d_machine_view(gpu_id_t start, gpu_id_t stop, int stride)") { + MachineView result = + make_1d_machine_view(start_gpu, device_id_t{gpu_id_t(1 + 7 * 5)}, 5); + MachineView correct = gpu_mv; + CHECK(result == correct); + } + SUBCASE("make_1d_machine_view(gpu_id_t start, num_points_t num_points, int " + "stride)") { + MachineView result = make_1d_machine_view(start_gpu, num_points_t{7}, 5); + MachineView correct = gpu_mv; + CHECK(result == correct); + } + SUBCASE("make_1d_machine_view(gpu_id_t start, side_size_t interval_size, " + "int stride)") { + MachineView result = make_1d_machine_view( + start_gpu, get_side_size(rect.sides.at(ff_dim_t{0})), 5); + MachineView correct = gpu_mv; + CHECK(result == correct); + } + } + + TEST_CASE("MachineView make_1d_machine_view - CPU") { + StridedRectangle rect{{StridedRectangleSide{num_points_t{11}, 4}}}; + device_id_t start_cpu{cpu_id_t{2}}; + MachineView cpu_mv{start_cpu, rect}; + + SUBCASE("make_1d_machine_view(cpu_id_t start, cpu_id_t stop, int stride)") { + MachineView result = + make_1d_machine_view(start_cpu, device_id_t{cpu_id_t(2 + 11 * 4)}, 4); + MachineView correct = cpu_mv; + CHECK(result == correct); + } + SUBCASE("make_1d_machine_view(cpu_id_t start, num_points_t num_points, int " + "stride)") { + MachineView result = make_1d_machine_view(start_cpu, num_points_t{11}, 4); + MachineView correct = cpu_mv; + CHECK(result == correct); + } + SUBCASE("make_1d_machine_view(cpu_id_t start, side_size_t interval_size, " + "int stride)") { + MachineView result = make_1d_machine_view( + start_cpu, get_side_size(rect.sides.at(ff_dim_t{0})), 4); + MachineView correct = cpu_mv; + CHECK(result == correct); + } + } +} diff --git a/lib/pcg/test/src/test_strided_rectangle.cc b/lib/pcg/test/src/test_strided_rectangle.cc new file mode 100644 index 0000000000..ef342944de --- /dev/null +++ b/lib/pcg/test/src/test_strided_rectangle.cc @@ -0,0 +1,37 @@ +#include "doctest/doctest.h" +#include "pcg/strided_rectangle.h" +#include "pcg/strided_rectangle_side.h" + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("get_side_size(StridedRectangleSide)") { + StridedRectangleSide side{num_points_t{7}, 5}; + + CHECK(get_side_size(side) == side_size_t{7 * 5}); + } + TEST_CASE("strided_side_from_size_and_stride") { + StridedRectangleSide correct{num_points_t{10}, 3}; + StridedRectangleSide result = + strided_side_from_size_and_stride(side_size_t{10 * 3}, 3); + CHECK(result == correct); + } + + TEST_CASE("StridedRectangle - helper functions") { + + StridedRectangleSide s0{num_points_t{7}, 5}; + StridedRectangleSide s1{num_points_t{10}, 2}; + StridedRectangleSide s2{num_points_t{8}, 1}; + StridedRectangle rect{{s0, s1, s2}}; + + SUBCASE("get_num_dims") { + CHECK(get_num_dims(rect) == 3); + } + SUBCASE("get_num_points") { + CHECK(get_num_points(rect) == num_points_t{7 * 8 * 10}); + } + SUBCASE("get_side_at_idx") { + CHECK(get_side_at_idx(rect, ff_dim_t{0}) == s0); + CHECK(get_side_at_idx(rect, ff_dim_t{1}) == s1); + CHECK(get_side_at_idx(rect, ff_dim_t{2}) == s2); + } + } +}