R & W accessor changes, minimize code bloat
oOTigger committed Nov 5, 2024
1 parent 7106dec commit 51c3eb7
Showing 38 changed files with 330 additions and 456 deletions.
154 changes: 64 additions & 90 deletions lib/kernels/include/kernels/accessor.h
@@ -13,54 +13,36 @@ namespace FlexFlow {
 
 struct Allocator;
 
-class GenericTensorAccessorW {
+class GenericTensorAccessorR {
 public:
   template <DataType DT>
-  typename data_type_enum_to_class<DT>::type *get() const {
+  typename data_type_enum_to_class<DT>::type const *get() const {
     if (this->data_type == DT) {
-      return static_cast<real_type_t<DT> *>(this->ptr);
+      return static_cast<real_type_t<DT> const *>(this->ptr);
     } else {
       throw mk_runtime_error(fmt::format(
           "Invalid access data type ({} != {})", this->data_type, DT));
     }
   }
 
-  int32_t *get_int32_ptr() const;
-  int64_t *get_int64_ptr() const;
-  float *get_float_ptr() const;
-  double *get_double_ptr() const;
-  half *get_half_ptr() const;
+  int32_t const *get_int32_ptr() const;
+  int64_t const *get_int64_ptr() const;
+  float const *get_float_ptr() const;
+  double const *get_double_ptr() const;
+  half const *get_half_ptr() const;
 
-  GenericTensorAccessorW() = delete;
+  GenericTensorAccessorR() = delete;
 
-  GenericTensorAccessorW(DataType data_type,
+  GenericTensorAccessorR(DataType data_type,
                          ArrayShape const &shape,
-                         void *ptr,
+                         void const *ptr,
                          DeviceType device_type);
 
-  bool operator==(GenericTensorAccessorW const &) const;
-  bool operator!=(GenericTensorAccessorW const &) const;
-
-  template <DataType DT, typename... Indices>
-  real_type_t<DT> &at(Indices... indices) {
-    if (this->device_type != DeviceType::CPU) {
-      throw mk_runtime_error("Calling at() on non-CPU allocated tensor");
-    }
-    if (this->data_type != DT) {
-      throw mk_runtime_error(fmt::format(
-          "Invalid access data type ({} != {})", this->data_type, DT));
-    }
-
-    using T = real_type_t<DT>;
-
-    T *data_ptr = static_cast<T *>(this->ptr);
-    size_t offset = calculate_index_offset({static_cast<size_t>(indices)...});
-
-    return data_ptr[offset];
-  }
+  bool operator==(GenericTensorAccessorR const &) const;
+  bool operator!=(GenericTensorAccessorR const &) const;
 
-  template <DataType DT, typename... Indices>
-  real_type_t<DT> const &at(Indices... indices) const {
+  template <DataType DT>
+  real_type_t<DT> const &at(std::vector<size_t> const &indices) const {
     if (this->device_type != DeviceType::CPU) {
       throw mk_runtime_error("Calling at() on non-CPU allocated tensor");
     }
@@ -72,15 +54,15 @@ class GenericTensorAccessorW {
     using T = real_type_t<DT>;
 
     T const *data_ptr = static_cast<T const *>(this->ptr);
-    size_t offset = calculate_index_offset({static_cast<size_t>(indices)...});
+    size_t offset = calculate_index_offset(indices);
 
     return data_ptr[offset];
   }
 
 public:
   DataType data_type;
   ArrayShape shape;
-  void *ptr;
+  void const *ptr;
   DeviceType device_type;
 
 private:
@@ -90,43 +72,62 @@ class GenericTensorAccessorW {
              decltype(device_type) const &>
       tie() const;
 
-  size_t calculate_index_offset(
-      std::initializer_list<size_t> const &indices) const;
+  size_t calculate_index_offset(std::vector<size_t> const &indices) const;
 };
 
-std::string format_as(GenericTensorAccessorW const &);
-std::ostream &operator<<(std::ostream &, GenericTensorAccessorW const &);
+std::string format_as(GenericTensorAccessorR const &);
+std::ostream &operator<<(std::ostream &, GenericTensorAccessorR const &);
 
-class GenericTensorAccessorR {
+class GenericTensorAccessorW {
 public:
   template <DataType DT>
-  typename data_type_enum_to_class<DT>::type const *get() const {
+  typename data_type_enum_to_class<DT>::type *get() const {
     if (this->data_type == DT) {
-      return static_cast<real_type_t<DT> const *>(this->ptr);
+      return static_cast<real_type_t<DT> *>(this->ptr);
     } else {
       throw mk_runtime_error(fmt::format(
           "Invalid access data type ({} != {})", this->data_type, DT));
     }
   }
 
-  int32_t const *get_int32_ptr() const;
-  int64_t const *get_int64_ptr() const;
-  float const *get_float_ptr() const;
-  double const *get_double_ptr() const;
-  half const *get_half_ptr() const;
+  int32_t *get_int32_ptr() const;
+  int64_t *get_int64_ptr() const;
+  float *get_float_ptr() const;
+  double *get_double_ptr() const;
+  half *get_half_ptr() const;
 
-  GenericTensorAccessorR() = delete;
+  GenericTensorAccessorW() = delete;
 
-  GenericTensorAccessorR(DataType data_type,
+  GenericTensorAccessorW(DataType data_type,
                          ArrayShape const &shape,
-                         void const *ptr,
+                         void *ptr,
                          DeviceType device_type);
 
-  bool operator==(GenericTensorAccessorR const &) const;
-  bool operator!=(GenericTensorAccessorR const &) const;
+  bool operator==(GenericTensorAccessorW const &) const;
+  bool operator!=(GenericTensorAccessorW const &) const;
 
+  operator GenericTensorAccessorR() const;
+
+  template <DataType DT>
+  real_type_t<DT> &at(std::vector<size_t> const &indices) {
+    if (this->device_type != DeviceType::CPU) {
+      throw mk_runtime_error("Calling at() on non-CPU allocated tensor");
+    }
+    if (this->data_type != DT) {
+      throw mk_runtime_error(fmt::format(
+          "Invalid access data type ({} != {})", this->data_type, DT));
+    }
+
+    using T = real_type_t<DT>;
+
+    T *data_ptr = static_cast<T *>(this->ptr);
+    size_t offset = calculate_index_offset(indices);
+
+    return data_ptr[offset];
+  }
+
-  template <DataType DT, typename... Indices>
-  real_type_t<DT> const &at(Indices... indices) const {
+  template <DataType DT>
+  real_type_t<DT> const &at(std::vector<size_t> const &indices) const {
     if (this->device_type != DeviceType::CPU) {
       throw mk_runtime_error("Calling at() on non-CPU allocated tensor");
     }
@@ -138,15 +139,15 @@ class GenericTensorAccessorR {
     using T = real_type_t<DT>;
 
     T const *data_ptr = static_cast<T const *>(this->ptr);
-    size_t offset = calculate_index_offset({static_cast<size_t>(indices)...});
+    size_t offset = calculate_index_offset(indices);
 
     return data_ptr[offset];
   }
 
 public:
   DataType data_type;
   ArrayShape shape;
-  void const *ptr;
+  void *ptr;
   DeviceType device_type;
 
 private:
@@ -156,27 +157,11 @@ class GenericTensorAccessorR {
              decltype(device_type) const &>
       tie() const;
 
-  size_t calculate_index_offset(
-      std::initializer_list<size_t> const &indices) const;
+  size_t calculate_index_offset(std::vector<size_t> const &indices) const;
 };
 
-std::string format_as(GenericTensorAccessorR const &);
-std::ostream &operator<<(std::ostream &, GenericTensorAccessorR const &);
-
-int32_t *get_int32_ptr(GenericTensorAccessorW const &);
-int64_t *get_int64_ptr(GenericTensorAccessorW const &);
-float *get_float_ptr(GenericTensorAccessorW const &);
-double *get_double_ptr(GenericTensorAccessorW const &);
-half *get_half_ptr(GenericTensorAccessorW const &);
-std::vector<int32_t *>
-    get_int32_ptrs(std::vector<GenericTensorAccessorW> const &);
-std::vector<int64_t *>
-    get_int64_ptrs(std::vector<GenericTensorAccessorW> const &);
-std::vector<float *>
-    get_float_ptrs(std::vector<GenericTensorAccessorW> const &);
-std::vector<double *>
-    get_double_ptrs(std::vector<GenericTensorAccessorW> const &);
-std::vector<half *> get_half_ptrs(std::vector<GenericTensorAccessorW> const &);
+std::string format_as(GenericTensorAccessorW const &);
+std::ostream &operator<<(std::ostream &, GenericTensorAccessorW const &);
 
 static_assert(is_fmtable<req<DataType> const &>::value, "");
 
@@ -241,29 +226,18 @@ std::vector<real_type_t<DT> const *>
 GenericTensorAccessorR read_only_accessor_from_write_accessor(
     GenericTensorAccessorW const &write_accessor);
 
-bool is_shape_and_dtype_equal(GenericTensorAccessorW const &acc1,
-                              GenericTensorAccessorW const &acc2);
-
-bool shape_and_dtype_matches(GenericTensorAccessorW const &accessor,
-                             ArrayShape const &expected_shape,
-                             DataType const &expected_dtype);
 bool is_shape_and_dtype_equal(GenericTensorAccessorR const &acc1,
                               GenericTensorAccessorR const &acc2);
 
 bool shape_and_dtype_matches(GenericTensorAccessorR const &accessor,
                              ArrayShape const &expected_shape,
                              DataType const &expected_dtype);
 
 std::pair<ArrayShape, DataType>
     get_shape_and_datatype(GenericTensorAccessorR const &accessor);
 std::pair<ArrayShape, DataType>
     get_shape_and_datatype(GenericTensorAccessorW const &accessor);
 
-void transfer_data_between_accessors(
-    GenericTensorAccessorW &dst_accessor,
-    GenericTensorAccessorR const &src_accessor);
-
-void transfer_data_between_accessors(
-    GenericTensorAccessorW &dst_accessor,
-    GenericTensorAccessorW const &src_accessor);
+void copy_accessor_data_to_l_from_r(GenericTensorAccessorW &dst_accessor,
+                                    GenericTensorAccessorR const &src_accessor);
 
 GenericTensorAccessorR
     copy_tensor_accessor_r(GenericTensorAccessorR const &src_accessor,
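
A note on the at() change in this file: moving from the variadic at(Indices...) to at(std::vector<size_t> const &) gives one template instantiation per data type instead of one per combination of data type and index arity, which appears to be the "minimize code bloat" of the commit title. Below is a minimal standalone sketch of vector-based element access, assuming row-major layout; the stride math is illustrative only, since the real logic sits behind calculate_index_offset, which this header only declares.

    #include <cstddef>
    #include <stdexcept>
    #include <vector>

    // Toy analog of the accessor pattern above: a shape plus flat CPU storage,
    // indexed through a runtime vector of indices (row-major, outermost first).
    struct ToyAccessor {
      std::vector<std::size_t> dims;
      std::vector<float> data;

      float &at(std::vector<std::size_t> const &indices) {
        if (indices.size() != dims.size()) {
          throw std::runtime_error("rank mismatch");
        }
        std::size_t offset = 0;
        for (std::size_t i = 0; i < indices.size(); i++) {
          if (indices[i] >= dims[i]) {
            throw std::runtime_error("index out of bounds");
          }
          offset = offset * dims[i] + indices[i]; // row-major accumulation
        }
        return data[offset];
      }
    };

    int main() {
      ToyAccessor acc{{2, 3}, std::vector<float>(6, 0.0f)};
      acc.at({1, 2}) = 4.0f; // same call shape for any tensor rank
      return acc.at({1, 2}) == 4.0f ? 0 : 1;
    }

The loop also shows why a vector suffices: the offset is a pure runtime computation, so nothing about it benefits from being templated on the number of indices.
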
8 changes: 2 additions & 6 deletions lib/kernels/include/kernels/cast_kernels.h
@@ -8,15 +8,11 @@ namespace FlexFlow::Kernels::Cast {
 
 void forward_kernel(ffStream_t stream,
                     GenericTensorAccessorR const &input,
-                    GenericTensorAccessorW const &output,
-                    DataType input_type,
-                    DataType output_type);
+                    GenericTensorAccessorW const &output);
 
 void backward_kernel(ffStream_t stream,
                      GenericTensorAccessorR const &input,
-                     GenericTensorAccessorW const &output,
-                     DataType input_type,
-                     DataType output_type);
+                     GenericTensorAccessorW const &output);
 
 } // namespace FlexFlow::Kernels::Cast
 
8 changes: 2 additions & 6 deletions lib/kernels/include/kernels/cast_kernels_cpu.h
@@ -7,14 +7,10 @@
 namespace FlexFlow::Kernels::Cast {
 
 void cpu_forward_kernel(GenericTensorAccessorR const &input,
-                        GenericTensorAccessorW const &output,
-                        DataType input_type,
-                        DataType output_type);
+                        GenericTensorAccessorW const &output);
 
 void cpu_backward_kernel(GenericTensorAccessorR const &input,
-                         GenericTensorAccessorW const &output,
-                         DataType input_type,
-                         DataType output_type);
+                         GenericTensorAccessorW const &output);
 
 } // namespace FlexFlow::Kernels::Cast
 
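
The DataType parameters dropped in these two cast headers are recoverable from the accessors, which carry their element type in a data_type field. A standalone sketch of that idea, using simplified stand-in types rather than the real FlexFlow definitions (the actual kernels presumably dispatch through machinery like the DataTypeDispatch2 helper changed below):

    #include <cstddef>
    #include <cstdint>

    // Simplified stand-ins: each accessor already knows its element type,
    // so the cast kernel needs no separate DataType arguments.
    enum class DataType { INT32, FLOAT };

    struct AccessorR {
      DataType data_type;
      void const *ptr;
      std::size_t num_elements;
    };

    struct AccessorW {
      DataType data_type;
      void *ptr;
      std::size_t num_elements;
    };

    void cpu_forward_kernel(AccessorR const &input, AccessorW const &output) {
      // Dispatch on the types the accessors themselves report.
      if (input.data_type == DataType::INT32 &&
          output.data_type == DataType::FLOAT) {
        auto const *src = static_cast<std::int32_t const *>(input.ptr);
        auto *dst = static_cast<float *>(output.ptr);
        for (std::size_t i = 0; i < input.num_elements; i++) {
          dst[i] = static_cast<float>(src[i]);
        }
      }
      // ... remaining (input, output) type pairs elided ...
    }

    int main() {
      std::int32_t in[3] = {1, 2, 3};
      float out[3] = {};
      cpu_forward_kernel({DataType::INT32, in, 3}, {DataType::FLOAT, out, 3});
      return out[2] == 3.0f ? 0 : 1;
    }

Dropping the extra parameters also removes a class of bugs where a passed DataType disagrees with the accessor's actual element type.
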
10 changes: 5 additions & 5 deletions lib/kernels/include/kernels/datatype_dispatch.h
@@ -34,15 +34,15 @@ struct DataTypeDispatch1 {
     template <typename... Args,
               typename Out = decltype(std::declval<F<DataType::FLOAT>>()(
                   std::declval<Args>()...))>
-    Out operator()(Args... args) const {
+    Out operator()(Args &&...args) const {
       return F<DT>{}(std::forward<Args>(args)...);
     }
   };
 
   template <typename... Args,
             typename Out = decltype(std::declval<F<DataType::FLOAT>>()(
                 std::declval<Args>()...))>
-  Out operator()(DataType data_type, Args... args) {
+  Out operator()(DataType data_type, Args &&...args) {
     return dispatch<Type1Dispatch>(data_type, std::forward<Args>(args)...);
   }
 };
@@ -55,21 +55,21 @@ struct DataTypeDispatch2 {
     template <DataType OT>
     struct OutputType {
       template <typename... Args>
-      void operator()(Args... args) const {
+      void operator()(Args &&...args) const {
         F<IT, OT>{}(std::forward<Args>(args)...);
       }
     };
 
     template <typename... Args>
-    void operator()(DataType output_type, Args... args) const {
+    void operator()(DataType output_type, Args &&...args) const {
       dispatch<OutputType>(output_type, std::forward<Args>(args)...);
     }
   };
 
   template <typename... Args>
   void operator()(DataType input_data_type,
                   DataType output_data_type,
-                  Args... args) {
+                  Args &&...args) {
     dispatch<InputType>(
         input_data_type, output_data_type, std::forward<Args>(args)...);
   }
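
The Args... to Args &&... change in both dispatch structs turns each parameter pack into forwarding references, so the std::forward calls that were already in the bodies become meaningful: lvalue arguments pass through by reference and rvalues by move, instead of every argument being copied into a by-value pack. A self-contained comparison of the two signatures (a generic wrapper for illustration, not the FlexFlow dispatch machinery):

    #include <iostream>
    #include <utility>

    // Copy counter: observes whether an argument is copied on the way through.
    struct Counter {
      int copies = 0;
      Counter() = default;
      Counter(Counter const &other) : copies(other.copies + 1) {}
    };

    // Forwarding-reference pack, as in the signatures after this change.
    template <typename F, typename... Args>
    decltype(auto) call_forwarding(F &&f, Args &&...args) {
      return std::forward<F>(f)(std::forward<Args>(args)...);
    }

    // By-value pack, as before the change: each argument is copied first.
    template <typename F, typename... Args>
    decltype(auto) call_by_value(F f, Args... args) {
      return f(std::forward<Args>(args)...);
    }

    int main() {
      Counter c;
      auto copies_seen = [](Counter const &x) { return x.copies; };
      std::cout << call_forwarding(copies_seen, c) << "\n"; // prints 0
      std::cout << call_by_value(copies_seen, c) << "\n";   // prints 1
      return 0;
    }

For kernel arguments like accessors and scalars the saved copies are small, but the forwarding form is also what permits move-only argument types to pass through dispatch at all.
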
5 changes: 4 additions & 1 deletion lib/kernels/include/kernels/managed_per_device_ff_handle.h
@@ -7,7 +7,10 @@ namespace FlexFlow {
 
 struct ManagedPerDeviceFFHandle {
 public:
-  ManagedPerDeviceFFHandle();
+  ManagedPerDeviceFFHandle() = delete;
+
+  ManagedPerDeviceFFHandle(size_t workSpaceSize,
+                           bool allowTensorOpMathConversion);
 
   ManagedPerDeviceFFHandle(ManagedPerDeviceFFHandle const &) = delete;
   ManagedPerDeviceFFHandle &
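
With the default constructor deleted, a handle can no longer be created in an unconfigured state; callers must choose the workspace size and the tensor-op math-conversion policy up front. A standalone sketch of the pattern (this struct mirrors the shape of the change, not the real class, and the values in main are made up):

    #include <cstddef>

    struct ManagedHandle {
      ManagedHandle() = delete; // no unconfigured handles

      ManagedHandle(std::size_t workspace_size, bool allow_tensor_op_math)
          : workspace_size(workspace_size),
            allow_tensor_op_math(allow_tensor_op_math) {}

      // Non-copyable, matching the deleted copy operations in the diff.
      ManagedHandle(ManagedHandle const &) = delete;
      ManagedHandle &operator=(ManagedHandle const &) = delete;

      std::size_t workspace_size;
      bool allow_tensor_op_math;
    };

    int main() {
      // ManagedHandle h;                      // ill-formed: default ctor deleted
      ManagedHandle handle{1024 * 1024, true}; // illustrative values
      return handle.allow_tensor_op_math ? 0 : 1;
    }
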