Re-enable substitutions #1471

Merged Sep 6, 2024 · 77 commits (changes shown from 71 commits)

Commits
510c2d9
Start on pcg builder
lockshaw Jun 4, 2024
7b55ed1
Add tests and some implementation for pcg builder
lockshaw Jun 4, 2024
c379efd
Add pcg tests, make dtgen constructors explicit to fix bug
lockshaw Jun 10, 2024
35fa653
Add remainder of PCG tests
lockshaw Jun 10, 2024
865a28e
Merge remote-tracking branch 'origin/repo-refactor' into pcg-builder
lockshaw Jun 10, 2024
f379539
Fix build issues in local-execution
lockshaw Jun 10, 2024
2dbb3b9
Format
lockshaw Jun 10, 2024
4050c99
Address Reyna comments, add topological_order function for PCG
lockshaw Jun 17, 2024
42c1968
Pre multidigraph refactor
lockshaw Jun 19, 2024
3be816f
Removing visitable from sp code
lockshaw Jun 21, 2024
6d68324
Add open dataflow graph, start to replace pcg dataflow graph
lockshaw Jun 23, 2024
64a3403
Start refactoring substitutions
lockshaw Jun 24, 2024
7d4c7be
Add utility functions to support pattern matching
lockshaw Jun 25, 2024
9ab9eb2
Pre-refactor inputs
lockshaw Jun 26, 2024
7ae7c65
Merge remote-tracking branch 'origin/repo-refactor' into dataflow-graph
lockshaw Jun 26, 2024
f9b129e
Fix proj url
lockshaw Jun 26, 2024
cf73f08
Get back to substitutions, now with unordered graph inputs
lockshaw Jul 7, 2024
5fd666d
Get substitutions building
lockshaw Jul 13, 2024
5f0c88a
substitutions-tests now builds
lockshaw Jul 13, 2024
3228f2d
Fix bug in filter, pass some initial substitution tests
lockshaw Jul 14, 2024
5f4cc01
Add tests for fmt::to_string, fix some substitutions bugs
lockshaw Jul 15, 2024
ad60be0
Pass initial unit tests for find_pattern_matches
lockshaw Jul 15, 2024
a972da2
Start on unit tests for pcg pattern
lockshaw Jul 15, 2024
bcf776e
Pass initial test for find_pattern_matches
lockshaw Jul 19, 2024
e28400e
Merge remote-tracking branch 'origin/repo-refactor' into dataflow-graph
lockshaw Jul 19, 2024
fe6d65d
Fix small build issue in tests
lockshaw Jul 19, 2024
e647af7
Format
lockshaw Jul 19, 2024
8b58760
Sync tests in CI with tests in proj
lockshaw Jul 19, 2024
1fafb9d
Fix minor build errors in kernels and local-execution
lockshaw Jul 19, 2024
0804314
Format
lockshaw Jul 19, 2024
dd5465c
Remove outdated code
lockshaw Jul 20, 2024
29ec5b8
More outdated code removal
lockshaw Jul 20, 2024
ff41743
More cleanup, add test for sp decomposition
lockshaw Jul 20, 2024
e71d200
Pull apart containers.h
lockshaw Jul 21, 2024
c06710c
More sp testing and fixes
lockshaw Jul 21, 2024
2f75566
Break up graph algorithms.h
lockshaw Jul 21, 2024
c81d3a4
Pre- full SP algo commit
lockshaw Jul 23, 2024
2a11c7e
Add initial implementation and tests for cbc decomposition and invers…
lockshaw Jul 23, 2024
71a9e0f
Pass test for get_inverse_line_graph
lockshaw Jul 24, 2024
25eb1db
Add new multidigraph
lockshaw Jul 24, 2024
64f1932
Fix get_inverse_line_graph to return a MultiDiGraph instead of a DiGraph
lockshaw Jul 24, 2024
31c8d17
Add tests for parallel and series reduction finding
lockshaw Jul 24, 2024
19e7e28
Add really rough implementation of valdez sp decomposition
lockshaw Jul 24, 2024
3791e86
Fix local-execution build
lockshaw Jul 25, 2024
267b72d
Add implementations and tests for applying series/parallel reductions
lockshaw Jul 25, 2024
bb2769a
Format
lockshaw Jul 26, 2024
39cb7b3
Clean up sp decomposition interface and tests
lockshaw Jul 27, 2024
ce0234d
Format
lockshaw Jul 27, 2024
3dc3ec6
Add comments for top-level substitutions functions, add proj doxygen …
lockshaw Jul 31, 2024
ee518c2
Start sketching out substitutions code
lockshaw Jul 31, 2024
f69b95a
Merge branch 'dataflow-graph' into substitutions-fix
lockshaw Jul 31, 2024
3c06b88
Fix build errors
lockshaw Aug 1, 2024
3d6f681
Add ability to permute node ids
lockshaw Aug 1, 2024
098a9d1
Cleanup and start to test new substitutions code
lockshaw Aug 4, 2024
9bd4f14
Add test case for evaluate_substitution_output
lockshaw Aug 5, 2024
101083b
Add naive isomorphism detection code
lockshaw Aug 5, 2024
9fec50c
Add graph inputs to open dataflow graph isomorphism
lockshaw Aug 6, 2024
7c60736
Add input permutation to evaluate_substitution_output
lockshaw Aug 6, 2024
cb6eab2
Fix permute_node_ids
lockshaw Aug 8, 2024
2f3d67a
Add test for permute_input_ids
lockshaw Aug 8, 2024
03cbd02
Migrate over to mutable implementation of apply_substitution
lockshaw Aug 23, 2024
4a8deae
Add fast isomorphism checking and an initial implementation of full s…
lockshaw Aug 24, 2024
0757e94
Pass initial full substitutions test
lockshaw Aug 24, 2024
ba0a174
Cleanup old isomorphism checking code
lockshaw Aug 24, 2024
4dfa403
Merge remote-tracking branch 'origin/repo-refactor' into substitution…
lockshaw Aug 24, 2024
f156f96
Fix post-merge bugs
lockshaw Aug 24, 2024
5f09298
Fix broken pcg builder test
lockshaw Aug 26, 2024
deff4f8
Format
lockshaw Aug 26, 2024
d71d24f
Reorganize code and remove some outdated code pre-code-review
lockshaw Aug 26, 2024
1a63f90
Format
lockshaw Aug 26, 2024
aecbbe6
Address review comments
lockshaw Aug 31, 2024
afc6f7f
Address missed comment
lockshaw Aug 31, 2024
7d6fadf
Remove latex dependency to avoid CI out-of-disk-space
lockshaw Aug 31, 2024
f804308
Format
lockshaw Aug 31, 2024
82c42d7
Merge remote-tracking branch 'origin/repo-refactor' into substitution…
lockshaw Sep 4, 2024
8499ed1
Fix build issues
lockshaw Sep 5, 2024
9e39e08
Fix incorrect test case
lockshaw Sep 5, 2024
6 changes: 3 additions & 3 deletions flake.lock

(Generated file; diff not rendered by default.)

1 change: 1 addition & 0 deletions flake.nix
@@ -101,6 +101,7 @@
        tl-expected
        doxygen
        lcov # for code coverage
        texliveFull
      ])
      (with proj-repo.packages.${system}; [
        proj
2 changes: 1 addition & 1 deletion lib/compiler/src/machine_mapping.cc
@@ -12,10 +12,10 @@
#include "utils/containers/contains_key.h"
#include "utils/containers/get_only.h"
#include "utils/containers/keys.h"
#include "utils/containers/merge_maps.h"
#include "utils/exception.h"
#include "utils/graph/graph_split.dtg.h"
#include "utils/graph/node/algorithms.h"
#include "utils/graph/open_dataflow_graph/algorithms.h"
#include "utils/graph/open_dataflow_graph/algorithms/get_subgraph.h"
#include "utils/graph/serial_parallel/serial_parallel_decomposition.dtg.h"
#include "utils/graph/serial_parallel/serial_parallel_decomposition.h"
1 change: 0 additions & 1 deletion lib/local-execution/src/ops/pool_2d.cc
@@ -3,7 +3,6 @@

#include "op-attrs/get_output_shapes.h"
#include "op-attrs/ops/pool_2d.h"
#include "utils/exception.decl.h"
#include "utils/exception.h"
#include "utils/hash-utils.h"

1 change: 0 additions & 1 deletion lib/local-execution/src/ops/transpose.cc
@@ -17,7 +17,6 @@
#include "kernels/transpose_kernels.h"
#include "op-attrs/get_output_shapes.h"
#include "op-attrs/ops/transpose.h"
#include "utils/exception.decl.h"

using namespace FlexFlow::Kernels::Transpose;

15 changes: 0 additions & 15 deletions lib/op-attrs/include/op-attrs/as_dot.h

This file was deleted.

2 changes: 2 additions & 0 deletions lib/op-attrs/include/op-attrs/computation_graph_op_attrs.h
@@ -2,10 +2,12 @@
#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_COMPUTATION_GRAPH_OP_ATTRS_H

#include "op-attrs/computation_graph_op_attrs.dtg.h"
#include "utils/record_formatter.h"

namespace FlexFlow {

OperatorType get_op_type(ComputationGraphOpAttrs const &);
RecordFormatter as_dot(ComputationGraphOpAttrs const &);

} // namespace FlexFlow

1 change: 1 addition & 0 deletions lib/op-attrs/include/op-attrs/dim_ordered.h
@@ -2,6 +2,7 @@
#define _FLEXFLOW_OPATTRS_INCLUDE_OPATTRS_FF_STACK_VECTOR_H

#include "op-attrs/ff_dim.dtg.h"
#include "utils/containers/count.h"
#include "utils/json.h"
#include "utils/stack_vector.h"

29 changes: 29 additions & 0 deletions lib/op-attrs/include/op-attrs/dim_ordered/enumerate.h
@@ -0,0 +1,29 @@
#ifndef _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_ENUMERATE_H
#define _FLEXFLOW_LIB_OP_ATTRS_INCLUDE_OP_ATTRS_DIM_ORDERED_ENUMERATE_H

#include "op-attrs/dim_ordered.h"
#include "utils/bidict/bidict.h"

namespace FlexFlow {

/**
* @brief Generate a map from indices to elements of \p c.
*
* @note We return a <tt>std::map</tt> to prevent mixups of \ref ff_dim_t and
* \ref legion_dim_t. Note that <tt>std::map</tt> provides ordered iteration in
* increasing order, so iterating through the result of this function should
* function as expected.
*/
template <typename T>
std::map<ff_dim_t, T> enumerate(FFOrdered<T> const &ff_ordered) {
  std::map<ff_dim_t, T> result;
  for (int raw_ff_dim : count(ff_ordered.size())) {
    ff_dim_t ff_dim = ff_dim_t{raw_ff_dim};
    result.insert({ff_dim, ff_ordered.at(ff_dim)});
  }
  return result;
}

} // namespace FlexFlow

#endif
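
As a usage sketch (hypothetical calling code, not part of this diff; it assumes FFOrdered<T> supports list initialization), the ordered iteration that the comment above promises looks like:

    #include "op-attrs/dim_ordered/enumerate.h"

    using namespace FlexFlow;

    // Hypothetical example: index three dimension sizes by ff_dim_t.
    FFOrdered<int> dims = FFOrdered<int>{10, 20, 30};
    std::map<ff_dim_t, int> indexed = enumerate(dims);
    // std::map iterates its keys in increasing order, so this visits
    // (ff_dim_t{0}, 10), (ff_dim_t{1}, 20), (ff_dim_t{2}, 30) in sequence.
    for (auto const &[ff_dim, size] : indexed) {
      // ... use ff_dim and size ...
    }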
231 changes: 5 additions & 226 deletions lib/op-attrs/include/op-attrs/get_output_shapes.h
@@ -1,236 +1,15 @@
#ifndef _FLEXFLOW_INCLUDE_OP_ATTRS_GET_OUTPUT_SHAPES_H
#define _FLEXFLOW_INCLUDE_OP_ATTRS_GET_OUTPUT_SHAPES_H

#include "op-attrs/operator_attrs.h"
#include "op-attrs/parallel_tensor_shape.h"
#include "ops/reverse.h"
#include "tensor_shape.h"
#include "utils/containers/get_only.h"
#include "utils/optional.h"
#include "op-attrs/parallel_tensor_shape.dtg.h"
#include "op-attrs/pcg_operator_attrs.dtg.h"
#include <vector>

namespace FlexFlow {

template <typename T, typename Enable = void>
struct has_unary_output_t : std::false_type {};
template <typename T, typename Enable = void>
struct has_unary_input_t : std::false_type {};
template <typename T, typename Enable = void>
struct has_binary_input_t : std::false_type {};

template <typename T, typename Enable = void>
struct has_multi_output_t : std::true_type {};
template <typename T, typename Enable = void>
struct has_multi_input_t : std::true_type {};

template <typename T>
struct has_multi_output_t<
T,
typename std::enable_if<has_unary_output_t<T>::value>::type>
: std::false_type {};

template <typename T>
struct has_multi_input_t<
T,
typename std::enable_if<(has_unary_input_t<T>::value ||
has_binary_input_t<T>::value)>::type>
: std::false_type {};

/* template <typename T, typename Enable = void> struct output_type_t { using
* type = std::vector<ParallelTensorShape>; }; */

template <typename T>
typename std::enable_if<has_unary_input_t<T>::value, bool>::type
is_valid(T const &t, std::vector<ParallelTensorShape> const &shapes) {
if (shapes.size() != 1) {
return false;
}

return is_valid(t, get_only(shapes));
}

template <typename T>
typename std::enable_if<has_binary_input_t<T>::value, bool>::type
is_valid(T const &t, std::vector<ParallelTensorShape> const &shapes) {
if (shapes.size() != 2) {
return false;
}

return is_valid(t, shapes.at(0), shapes.at(1));
}

template <typename T>
typename std::enable_if<(has_unary_input_t<T>::value &&
has_unary_output_t<T>::value),
ParallelTensorShape>::type
output_shapes(T const &t, std::vector<ParallelTensorShape> const &shapes) {
return output_shape(t, get_only(shapes));
}

template <typename T>
typename std::enable_if<(has_binary_input_t<T>::value &&
has_unary_output_t<T>::value),
std::vector<ParallelTensorShape>>::type
output_shapes(T const &t, std::vector<ParallelTensorShape> const &shapes) {
assert(shapes.size() == 2);

return {output_shape(t, shapes.at(0), shapes.at(1))};
}

TensorShape get_tensor_shape_unsafe(ParallelTensorShape const &);
std::vector<TensorShape>
get_tensor_shapes_unsafe(std::vector<ParallelTensorShape> const &);

template <typename Attrs>
TensorShape get_output_shape(Attrs const &attrs, TensorShape const &) {
NOT_IMPLEMENTED();
}

template <typename Attrs>
TensorShape get_output_shape(Attrs const &attrs,
TensorShape const &,
TensorShape const &) {
NOT_IMPLEMENTED();
}

template <typename Attrs>
TensorShape get_output_shape(Attrs const &attrs,
std::vector<TensorShape> const &) {
NOT_IMPLEMENTED();
}
template <typename Attrs>
std::vector<TensorShape> get_output_shapes(Attrs const &attrs,
TensorShape const &);
template <typename Attrs>
std::vector<TensorShape> get_output_shapes(Attrs const &attrs,
TensorShape const &,
TensorShape const &) {
NOT_IMPLEMENTED();
}
template <typename Attrs>
std::vector<TensorShape> get_output_shapes(Attrs const &attrs,
std::vector<TensorShape> const &);

ParallelTensorShape get_output_shape(MultiHeadAttentionAttrs const &,
std::vector<ParallelTensorShape> const &);
ParallelTensorShape get_output_shape(ConcatAttrs const &,
std::vector<ParallelTensorShape> const &);
ParallelTensorShape get_output_shape(DropoutAttrs const &,
ParallelTensorShape const &);
ParallelTensorShape get_output_shape(FlatAttrs const &,
ParallelTensorShape const &);
std::vector<ParallelTensorShape> get_output_shapes(GatherAttrs const &,
ParallelTensorShape const &,
ParallelTensorShape const &);
ParallelTensorShape get_output_shape(LayerNormAttrs const &,
ParallelTensorShape const &);
ParallelTensorShape get_output_shape(Pool2DAttrs const &,
ParallelTensorShape const &);
ParallelTensorShape get_output_shape(ReduceAttrs const &,
ParallelTensorShape const &);
ParallelTensorShape get_output_shape(ReverseAttrs const &,
ParallelTensorShape const &);
std::vector<ParallelTensorShape> get_output_shapes(SplitAttrs const &,
ParallelTensorShape const &);
ParallelTensorShape get_output_shape(SoftmaxAttrs const &,
ParallelTensorShape const &);
ParallelTensorShape get_output_shape(TopKAttrs const &,
ParallelTensorShape const &);
ParallelTensorShape get_output_shape(TransposeAttrs const &,
std::vector<ParallelTensorShape> const &);

struct GetOutputShapesFunctor {
GetOutputShapesFunctor(std::vector<ParallelTensorShape> const &s) : s(s) {}

std::vector<ParallelTensorShape> const &s;

template <typename T>
std::vector<ParallelTensorShape> operator()(T const &t) {
return get_output_shapes(t, s);
}
};

template <typename... Ts>
std::vector<ParallelTensorShape>
get_output_shapes(std::variant<Ts...> const &t,
std::vector<ParallelTensorShape> const &s) {
return get_output_shape(GetOutputShapesFunctor{s}, t);
}

template <typename T>
typename std::enable_if<!has_unary_output_t<T>::value, std::optional<int>>::type
get_num_outputs(T const &) {
return std::nullopt;
}

template <typename T>
typename std::enable_if<has_unary_output_t<T>::value, std::optional<int>>::type
get_num_outputs(T const &) {
return 1;
}

int get_num_outputs(SplitAttrs const &attrs);

template <typename T>
bool is_valid(T const &t, std::vector<ParallelTensorShape> const &shapes) {
auto num_outputs = get_num_outputs(t);
if (num_outputs.has_value() && shapes.size() != num_outputs.value()) {
return false;
}

for (ParallelTensorShape const &shape : shapes) {
if (!is_valid(shape)) {
return false;
}
}

return is_valid_internal(t, shapes);
}

template <typename T>
typename std::enable_if<has_unary_input_t<T>::value, bool>::type
is_valid_internal(T const &t,
std::vector<ParallelTensorShape> const &shapes) {
return is_valid_internal(t, get_only(shapes));
}

template <typename T>
typename std::enable_if<has_binary_input_t<T>::value, bool>::type
is_valid_internal(T const &t,
std::vector<ParallelTensorShape> const &shapes) {
return is_valid_internal(t, shapes.at(0), shapes.at(1));
}

bool is_valid_internal(MultiHeadAttentionAttrs const &,
std::vector<ParallelTensorShape> const &);
bool is_valid_internal(BatchMatmulAttrs const &,
ParallelTensorShape const &,
ParallelTensorShape const &);
bool is_valid_internal(CastAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(ConcatAttrs const &,
std::vector<ParallelTensorShape> const &);
bool is_valid_internal(Conv2DAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(DropoutAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(ElementBinaryAttrs const &,
ParallelTensorShape const &,
ParallelTensorShape const &);
bool is_valid_internal(ElementUnaryAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(EmbeddingAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(FlatAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(GatherAttrs const &,
ParallelTensorShape const &,
ParallelTensorShape const &);
bool is_valid_internal(LayerNormAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(LinearAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(Pool2DAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(ReduceAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(ReductionAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(RepartitionAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(ReplicateAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(ReshapeAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(SoftmaxAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(SplitAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(TopKAttrs const &, ParallelTensorShape const &);
bool is_valid_internal(TransposeAttrs const &, ParallelTensorShape const &);
std::vector<ParallelTensorShape>
    get_output_shapes(PCGOperatorAttrs const &,
                      std::vector<ParallelTensorShape> const &);

} // namespace FlexFlow
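
The net effect of this header rewrite is to drop the SFINAE trait machinery (has_unary_input_t and friends) and the per-operator overloads in favor of a single declaration that dispatches on the PCGOperatorAttrs variant. A minimal sketch of a call site under the new interface (the wrapper function and its parameter names are hypothetical):

    #include "op-attrs/get_output_shapes.h"

    using namespace FlexFlow;

    // Hypothetical call site: per-operator dispatch now happens inside
    // get_output_shapes rather than via enable_if overloads in the header.
    std::vector<ParallelTensorShape>
        infer_outputs(PCGOperatorAttrs const &attrs,
                      std::vector<ParallelTensorShape> const &input_shapes) {
      return get_output_shapes(attrs, input_shapes);
    }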
