Skip to content

Commit

Permalink
Implementations for methods for machine_views and associated modules (#…
Browse files Browse the repository at this point in the history
…1429)

* initial commit for machine view adjacent modules

* Formatting

* Tests for new machine_view.cc functions

* formatting

* Minor Test correction

* formatting

* PR fixes

* PR Fixes

---------

Co-authored-by: Pietro Max Marsella <[email protected]>
  • Loading branch information
2 people authored and oOTigger committed Jul 31, 2024
1 parent 723515b commit ba586ae
Show file tree
Hide file tree
Showing 8 changed files with 216 additions and 19 deletions.
2 changes: 1 addition & 1 deletion lib/pcg/include/pcg/device_id.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace FlexFlow {

device_id_t operator+(device_id_t, size_t);

DeviceType get_device_type(device_id_t);
DeviceType get_device_type(device_id_t const &device_id);
gpu_id_t unwrap_gpu(device_id_t);
cpu_id_t unwrap_cpu(device_id_t);

Expand Down
19 changes: 18 additions & 1 deletion lib/pcg/include/pcg/machine_view.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define _FLEXFLOW_PCG_INCLUDE_PCG_MACHINE_VIEW_H

#include "pcg/cpu_id_t.dtg.h"
#include "pcg/device_id.h"
#include "pcg/device_id_t.dtg.h"
#include "pcg/device_type.dtg.h"
#include "pcg/gpu_id_t.dtg.h"
Expand All @@ -14,15 +15,31 @@
namespace FlexFlow {

std::vector<device_id_t> device_ids(MachineView const &);
std::size_t num_dims(MachineView const &);
size_t num_dims(MachineView const &);
std::size_t num_devices(MachineView const &);
DeviceType get_device_type(MachineView const &);

MachineView make_1d_machine_view(gpu_id_t start, gpu_id_t stop, int stride = 1);
MachineView make_1d_machine_view(cpu_id_t start, cpu_id_t stop, int stride = 1);
MachineView
make_1d_machine_view(device_id_t start, device_id_t stop, int stride = 1);

MachineView make_1d_machine_view(gpu_id_t start,
num_points_t num_points,
int stride = 1);
MachineView make_1d_machine_view(cpu_id_t start,
num_points_t num_points,
int stride = 1);
MachineView make_1d_machine_view(device_id_t start,
num_points_t num_points,
int stride = 1);

MachineView make_1d_machine_view(gpu_id_t start,
side_size_t interval_size,
int stride = 1);
MachineView make_1d_machine_view(cpu_id_t start,
side_size_t interval_size,
int stride = 1);
MachineView make_1d_machine_view(device_id_t start,
side_size_t interval_size,
int stride = 1);
Expand Down
5 changes: 3 additions & 2 deletions lib/pcg/include/pcg/strided_rectangle.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
namespace FlexFlow {

size_t get_num_dims(StridedRectangle const &);
StridedRectangleSide get_side_at_idx(StridedRectangle const &,
ff_dim_t const &);
StridedRectangleSide get_side_at_idx(StridedRectangle const &rect,
ff_dim_t const &idx);
num_points_t get_num_points(StridedRectangle const &rect);

} // namespace FlexFlow

Expand Down
75 changes: 66 additions & 9 deletions lib/pcg/src/pcg/machine_view.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "pcg/machine_view.h"
#include "pcg/device_id.h"
#include "pcg/strided_rectangle.dtg.h"
#include "pcg/strided_rectangle.h"
#include "pcg/strided_rectangle_side.h"

namespace FlexFlow {
Expand All @@ -8,16 +10,16 @@ std::vector<device_id_t> device_ids(MachineView const &) {
NOT_IMPLEMENTED();
}

std::size_t num_dims(MachineView const &) {
NOT_IMPLEMENTED();
std::size_t num_dims(MachineView const &mv) {
return get_num_dims(mv.rect);
}

std::size_t num_devices(MachineView const &) {
NOT_IMPLEMENTED();
size_t num_devices(MachineView const &mv) {
return get_num_points(mv.rect).unwrapped;
}

DeviceType get_device_type(MachineView const &) {
NOT_IMPLEMENTED();
DeviceType get_device_type(MachineView const &mv) {
return get_device_type(mv.start);
}

static StridedRectangle make_1d_rect(int start, int stop, int stride) {
Expand All @@ -40,18 +42,73 @@ MachineView make_1d_machine_view(cpu_id_t start, cpu_id_t stop, int stride) {
return MachineView{device_id_t{start}, rect};
}

MachineView
make_1d_machine_view(device_id_t start, device_id_t stop, int stride) {
assert(get_device_type(start) == get_device_type(stop));
if (get_device_type(start) == DeviceType::CPU) {
return make_1d_machine_view(unwrap_cpu(start), unwrap_cpu(stop), stride);
}
assert(get_device_type(start) == DeviceType::GPU);
return make_1d_machine_view(unwrap_gpu(start), unwrap_gpu(stop), stride);
}

static StridedRectangle
make_1d_rect(int start, num_points_t num_points, int stride) {
return make_1d_rect(start, start + num_points.unwrapped * stride, stride);
}

MachineView
make_1d_machine_view(cpu_id_t start, num_points_t num_points, int stride) {
StridedRectangle rect = make_1d_rect(start.cpu_index, num_points, stride);
return MachineView{device_id_t{start}, rect};
}

MachineView
make_1d_machine_view(gpu_id_t start, num_points_t num_points, int stride) {
StridedRectangle rect = make_1d_rect(start.gpu_index, num_points, stride);
return MachineView{device_id_t{start}, rect};
}

MachineView make_1d_machine_view(device_id_t start,
num_points_t num_points,
int stride) {
NOT_IMPLEMENTED();
if (get_device_type(start) == DeviceType::CPU) {
return make_1d_machine_view(unwrap_cpu(start), num_points, stride);
} else {
assert(get_device_type(start) == DeviceType::GPU);
return make_1d_machine_view(unwrap_gpu(start), num_points, stride);
}
}

MachineView make_1d_machine_view(device_id_t start,
static StridedRectangle
make_1d_rect(int start, side_size_t interval_size, int stride) {
return make_1d_rect(start, start + interval_size.unwrapped, stride);
}

MachineView make_1d_machine_view(cpu_id_t start,
side_size_t interval_size,
int stride) {
NOT_IMPLEMENTED();
StridedRectangle rect = make_1d_rect(start.cpu_index, interval_size, stride);
return MachineView{device_id_t{start}, rect};
}

MachineView make_1d_machine_view(gpu_id_t start,
side_size_t interval_size,
int stride) {
StridedRectangle rect = make_1d_rect(start.gpu_index, interval_size, stride);
return MachineView{device_id_t{start}, rect};
}
MachineView make_1d_machine_view(device_id_t start,
side_size_t interval_size,
int stride) {

if (get_device_type(start) == DeviceType::CPU) {
return make_1d_machine_view(unwrap_cpu(start), interval_size, stride);
} else {
assert(get_device_type(start) == DeviceType::GPU);
return make_1d_machine_view(unwrap_gpu(start), interval_size, stride);
}
}
MachineView make_1d_machine_view(device_id_t start, size_t interval_size) {
NOT_IMPLEMENTED();
}
Expand Down
6 changes: 4 additions & 2 deletions lib/pcg/src/pcg/strided_rectangle_side.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@

namespace FlexFlow {

StridedRectangleSide strided_side_from_size_and_stride(side_size_t,
StridedRectangleSide strided_side_from_size_and_stride(side_size_t side_size,
int stride) {
NOT_IMPLEMENTED();
assert((side_size.unwrapped % stride) == 0);
return StridedRectangleSide{num_points_t{side_size.unwrapped / stride},
stride};
}

side_size_t get_side_size(StridedRectangleSide const &s) {
Expand Down
17 changes: 13 additions & 4 deletions lib/pcg/src/strided_rectangle.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "pcg/strided_rectangle.h"
#include "op-attrs/dim_ordered/transform.h"
#include "utils/containers.h"

namespace FlexFlow {
Expand All @@ -15,12 +16,20 @@ namespace FlexFlow {
/* return idx; */
/* } */

size_t get_num_dims(StridedRectangle const &) {
NOT_IMPLEMENTED();
size_t get_num_dims(StridedRectangle const &rect) {
return rect.sides.size();
}

size_t get_side_at_idx(StridedRectangle const &) {
NOT_IMPLEMENTED();
num_points_t get_num_points(StridedRectangle const &rect) {
return num_points_t{
product(transform(rect.sides, [](StridedRectangleSide const &side) {
return side.num_points.unwrapped;
}))};
}

StridedRectangleSide get_side_at_idx(StridedRectangle const &rect,
ff_dim_t const &idx) {
return rect.sides.at(idx);
}

} // namespace FlexFlow
74 changes: 74 additions & 0 deletions lib/pcg/test/src/test_machine_view.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#include "doctest/doctest.h"
#include "pcg/machine_view.h"
#include "pcg/strided_rectangle.h"
#include "pcg/strided_rectangle_side.h"

TEST_SUITE(FF_TEST_SUITE) {
TEST_CASE("MachineView general util functions") {
StridedRectangle rect{{StridedRectangleSide{num_points_t{7}, 5},
StridedRectangleSide{num_points_t{10}, 2}}};
gpu_id_t start(1);
MachineView mv{device_id_t{start}, rect};
SUBCASE("num_dims") {
CHECK(num_dims(mv) == 2);
}
SUBCASE("num_devices") {
CHECK(num_devices(mv) == 7 * 10);
}
SUBCASE("get_device_type") {
CHECK(get_device_type(mv) == DeviceType::GPU);
}
}

TEST_CASE("MachineView make_1d_machine_view - GPU") {
StridedRectangle rect{{StridedRectangleSide{num_points_t{7}, 5}}};
device_id_t start_gpu{gpu_id_t{1}};
MachineView gpu_mv{start_gpu, rect};

SUBCASE("make_1d_machine_view(gpu_id_t start, gpu_id_t stop, int stride)") {
MachineView result =
make_1d_machine_view(start_gpu, device_id_t{gpu_id_t(1 + 7 * 5)}, 5);
MachineView correct = gpu_mv;
CHECK(result == correct);
}
SUBCASE("make_1d_machine_view(gpu_id_t start, num_points_t num_points, int "
"stride)") {
MachineView result = make_1d_machine_view(start_gpu, num_points_t{7}, 5);
MachineView correct = gpu_mv;
CHECK(result == correct);
}
SUBCASE("make_1d_machine_view(gpu_id_t start, side_size_t interval_size, "
"int stride)") {
MachineView result = make_1d_machine_view(
start_gpu, get_side_size(rect.sides.at(ff_dim_t{0})), 5);
MachineView correct = gpu_mv;
CHECK(result == correct);
}
}

TEST_CASE("MachineView make_1d_machine_view - CPU") {
StridedRectangle rect{{StridedRectangleSide{num_points_t{11}, 4}}};
device_id_t start_cpu{cpu_id_t{2}};
MachineView cpu_mv{start_cpu, rect};

SUBCASE("make_1d_machine_view(cpu_id_t start, cpu_id_t stop, int stride)") {
MachineView result =
make_1d_machine_view(start_cpu, device_id_t{cpu_id_t(2 + 11 * 4)}, 4);
MachineView correct = cpu_mv;
CHECK(result == correct);
}
SUBCASE("make_1d_machine_view(cpu_id_t start, num_points_t num_points, int "
"stride)") {
MachineView result = make_1d_machine_view(start_cpu, num_points_t{11}, 4);
MachineView correct = cpu_mv;
CHECK(result == correct);
}
SUBCASE("make_1d_machine_view(cpu_id_t start, side_size_t interval_size, "
"int stride)") {
MachineView result = make_1d_machine_view(
start_cpu, get_side_size(rect.sides.at(ff_dim_t{0})), 4);
MachineView correct = cpu_mv;
CHECK(result == correct);
}
}
}
37 changes: 37 additions & 0 deletions lib/pcg/test/src/test_strided_rectangle.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#include "doctest/doctest.h"
#include "pcg/strided_rectangle.h"
#include "pcg/strided_rectangle_side.h"

TEST_SUITE(FF_TEST_SUITE) {
TEST_CASE("get_side_size(StridedRectangleSide)") {
StridedRectangleSide side{num_points_t{7}, 5};

CHECK(get_side_size(side) == side_size_t{7 * 5});
}
TEST_CASE("strided_side_from_size_and_stride") {
StridedRectangleSide correct{num_points_t{10}, 3};
StridedRectangleSide result =
strided_side_from_size_and_stride(side_size_t{10 * 3}, 3);
CHECK(result == correct);
}

TEST_CASE("StridedRectangle - helper functions") {

StridedRectangleSide s0{num_points_t{7}, 5};
StridedRectangleSide s1{num_points_t{10}, 2};
StridedRectangleSide s2{num_points_t{8}, 1};
StridedRectangle rect{{s0, s1, s2}};

SUBCASE("get_num_dims") {
CHECK(get_num_dims(rect) == 3);
}
SUBCASE("get_num_points") {
CHECK(get_num_points(rect) == num_points_t{7 * 8 * 10});
}
SUBCASE("get_side_at_idx") {
CHECK(get_side_at_idx(rect, ff_dim_t{0}) == s0);
CHECK(get_side_at_idx(rect, ff_dim_t{1}) == s1);
CHECK(get_side_at_idx(rect, ff_dim_t{2}) == s2);
}
}
}

0 comments on commit ba586ae

Please sign in to comment.