Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor of element_binary_kernels, unary kernels and embedding_kernels #1367

Closed
wants to merge 702 commits into from
Closed
Changes from 15 commits
Commits
Show all changes
702 commits
Select commit Hold shift + click to select a range
a6ed3b5
Small set of graph fixes
lockshaw Aug 28, 2023
a1bffc5
updates
wmdi Aug 28, 2023
efb0112
Address lambda comments
lockshaw Aug 28, 2023
59839b7
Format
lockshaw Aug 28, 2023
395df2f
Fix strange nccl build issue from -w flag
lockshaw Aug 29, 2023
ee9f7ca
readme for substitutions
wmdi Aug 29, 2023
af92caa
fix the cmake
lambda7xx Aug 29, 2023
dede2c7
implement the constructor in cc
lambda7xx Aug 29, 2023
c65c4ac
format
lambda7xx Aug 29, 2023
08dd3fe
format
wmdi Aug 30, 2023
82e2c2c
check substitution validity
wmdi Aug 30, 2023
8682892
Fix fmt bugs
lockshaw Aug 31, 2023
9051bf6
Merge remote-tracking branch 'refs/remotes/lockshaw/lambda-utils-test…
lockshaw Aug 31, 2023
a0cac29
Revert lambda lib/CMakeLists.txt changes
lockshaw Aug 31, 2023
ae97d59
initialize tests for substitutions
wmdi Aug 31, 2023
c1b5f4a
Fix bug in pcg build
lockshaw Aug 31, 2023
cee33f3
Merge pull request #940 from lockshaw/lambda-utils-testing
lambda7xx Sep 1, 2023
066eb1d
Bump c++ version to 17 (#1067)
lockshaw Sep 1, 2023
e0559cb
Merge remote-tracking branch 'upstream/repo-refactor' into substitutions
wmdi Sep 1, 2023
c2513c6
fix
wmdi Sep 1, 2023
c2b6b04
format
wmdi Sep 1, 2023
21b8549
remove output tensor computation
wmdi Sep 4, 2023
193b44a
start to implement the softmax_kernel
lambda7xx Sep 7, 2023
e776068
add API interface
lambda7xx Sep 7, 2023
47d7d90
softmax kernels version0.1
lambda7xx Sep 7, 2023
204d0b9
combine
KateUnger Sep 7, 2023
6c32d50
concat
KateUnger Sep 8, 2023
cf7cf05
add empty method
lambda7xx Sep 8, 2023
bd63d52
copy some old code and implement topK version 0.1
lambda7xx Sep 8, 2023
19e20b1
add API method
lambda7xx Sep 8, 2023
06e21c3
transpose version0.1
lambda7xx Sep 8, 2023
244a0ec
start to do the Repartition
lambda7xx Sep 8, 2023
6f0af2a
modify the backward and forward
lambda7xx Sep 8, 2023
c5df897
start to implement the init
lambda7xx Sep 8, 2023
c2780b2
start measure_operator_cost
lambda7xx Sep 8, 2023
a43ff0f
partition version 0.1
lambda7xx Sep 8, 2023
e3b633f
implement get_operator_attrs
wmdi Sep 8, 2023
3041dc7
get parallel operator attributes & minor fix
wmdi Sep 8, 2023
ea1e8e5
format
wmdi Sep 8, 2023
8e28dd0
concat
KateUnger Sep 11, 2023
676f247
combine
KateUnger Sep 12, 2023
98a3c67
format
KateUnger Sep 12, 2023
bc12e40
Conv 2D Op (#1112)
KateUnger Sep 13, 2023
1508050
match open graphs
wmdi Sep 13, 2023
c0d3513
conv2d typo (#1135)
KateUnger Sep 14, 2023
6b64ae6
cuda
KateUnger Sep 14, 2023
2138358
format
KateUnger Sep 14, 2023
d764545
format and cuda
KateUnger Sep 14, 2023
116d162
Merge branch 'repo-refactor' into concat_new
KateUnger Sep 14, 2023
c2bdb7a
Merge branch 'repo-refactor' of https://github.com/flexflow/FlexFlow …
KateUnger Sep 14, 2023
f315b74
delete comments
KateUnger Sep 14, 2023
faa2411
format and fix cc
KateUnger Sep 15, 2023
58a7c62
Merge remote-tracking branch 'upstream/repo-refactor' into substitutions
wmdi Sep 15, 2023
b5ce50c
format
KateUnger Sep 15, 2023
322d945
minor fix
wmdi Sep 17, 2023
4fadbcf
Implement substitutions (#1011)
wmdi Sep 19, 2023
76a86ed
update the softmax
lambda7xx Sep 24, 2023
b17dae9
Serialize jobs in per-lib-checks workflow (#1149)
goliaro Sep 24, 2023
d684286
update the topk
lambda7xx Sep 24, 2023
c114bc4
update the transpose
lambda7xx Sep 24, 2023
cecc7f8
add update and leave task to implement
lambda7xx Sep 24, 2023
06c96ad
update the reduce
lambda7xx Sep 24, 2023
675a22b
use exceptions
lambda7xx Sep 27, 2023
e4c90cf
use exceptions
lambda7xx Sep 27, 2023
986af8f
use exceptions in partition.cc
lambda7xx Sep 27, 2023
dc73c86
fix signatures and bind_arg
KateUnger Sep 28, 2023
cdd5b36
Merge branch 'repo-refactor' into no_op
reyna-abhyankar Sep 30, 2023
990276c
fix
KateUnger Sep 30, 2023
98539d4
fix type error
lambda7xx Sep 30, 2023
4359386
refine the softmax by binding new thing in init
lambda7xx Oct 3, 2023
9e17b38
leave the index
lambda7xx Oct 3, 2023
53a6b2a
format the code
lambda7xx Oct 3, 2023
5d3cee7
format repo-refactor (#1168)
KateUnger Oct 4, 2023
069e94c
Merge branch 'repo-refactor' of https://github.com/flexflow/FlexFlow …
KateUnger Oct 4, 2023
0f083a3
combine
KateUnger Oct 4, 2023
b20afc8
Merge branch 'repo-refactor' of https://github.com/flexflow/FlexFlow …
KateUnger Oct 4, 2023
0b5704c
finish concat
KateUnger Oct 4, 2023
a09b61f
Merge branch 'concat_new' of github.com:KateUnger/FlexFlow into conca…
KateUnger Oct 4, 2023
94a1bd9
concat
KateUnger Oct 4, 2023
0dfbed8
add delete
KateUnger Oct 4, 2023
2b8894b
conv2d typo fix
KateUnger Oct 4, 2023
dd9c2c3
Merge branch 'repo-refactor' of https://github.com/flexflow/FlexFlow …
KateUnger Oct 4, 2023
a79b37c
finish element_binary
KateUnger Oct 4, 2023
63627fa
element_binary
KateUnger Oct 4, 2023
b71b73e
conv2d
KateUnger Oct 4, 2023
931d8e4
Merge branch 'repo-refactor' into no_op
reyna-abhyankar Oct 5, 2023
49f811d
concat
KateUnger Oct 5, 2023
e388534
concat
KateUnger Oct 5, 2023
cdb29f2
finish concat
KateUnger Oct 5, 2023
630da82
Fix signature return
reyna-abhyankar Oct 5, 2023
16ca54c
Comment CHECK_FMTABLE
reyna-abhyankar Oct 5, 2023
2e39499
Add signature for other ops
reyna-abhyankar Oct 5, 2023
8145d0f
Merge pull request #1139 from KateUnger/no_op
reyna-abhyankar Oct 5, 2023
57edb78
Merge branch 'repo-refactor' into concat_new
reyna-abhyankar Oct 5, 2023
cf33e1e
Repo refactor ci (#1083)
lambda7xx Oct 6, 2023
b670a2c
Merge branch 'repo-refactor' into concat_new
reyna-abhyankar Oct 6, 2023
4d1bd0a
Merge branch 'repo-refactor' into element_binary
lockshaw Oct 6, 2023
76d0ee7
Merge branch 'repo-refactor' into repo-refactor-lambda-Repartition-OP
lockshaw Oct 6, 2023
bc70bdb
Merge branch 'repo-refactor' into repo-refactor-lambda-transpose-OP
lockshaw Oct 6, 2023
9a227d4
Merge branch 'repo-refactor' into repo-refactor-lambda-topK-OP
lockshaw Oct 6, 2023
6ec88c5
Merge branch 'repo-refactor' into combine_new
lockshaw Oct 6, 2023
6eb3a16
Merge pull request #1114 from KateUnger/concat_new
reyna-abhyankar Oct 6, 2023
b1806b3
Merge branch 'repo-refactor' into repo-refactor-lambda-softmax-OP
lockshaw Oct 6, 2023
5a38012
Merge branch 'repo-refactor' into conv2d
reyna-abhyankar Oct 6, 2023
7d52ca8
Merge branch 'repo-refactor' into repo-refactor-lambda-Repartition-OP
lockshaw Oct 6, 2023
aadaff7
Merge branch 'repo-refactor' into conv2d
reyna-abhyankar Oct 6, 2023
16c9c7c
Merge pull request #1169 from KateUnger/conv2d
reyna-abhyankar Oct 6, 2023
0376cae
Call fwd sig
reyna-abhyankar Oct 6, 2023
ac38e58
Remove namespace std
reyna-abhyankar Oct 6, 2023
bc35da0
Merge branch 'repo-refactor' into combine_new
reyna-abhyankar Oct 6, 2023
3dc4093
Format
reyna-abhyankar Oct 7, 2023
d926516
Merge pull request #1113 from KateUnger/combine_new
reyna-abhyankar Oct 7, 2023
95dd194
Fix signature
reyna-abhyankar Oct 7, 2023
e34245c
Merge branch 'repo-refactor' into element_binary
reyna-abhyankar Oct 7, 2023
329c0e5
Merge pull request #1136 from KateUnger/element_binary
reyna-abhyankar Oct 7, 2023
277e107
Dropout Op (#1134)
KateUnger Oct 7, 2023
dd39cc8
Flat Operator (#1137)
KateUnger Oct 7, 2023
e1b1be2
Batch Norm Op (#1110)
KateUnger Oct 7, 2023
8c424a9
fix the init_task
lambda7xx Oct 10, 2023
13a0dc5
add topk
lambda7xx Oct 10, 2023
bcd6bf4
Merge branch 'repo-refactor-lambda-topK-OP' of https://github.com/lam…
lambda7xx Oct 10, 2023
9a8c3fb
add allocator to allocate memory for index_ptr
lambda7xx Oct 10, 2023
e7dc8c0
remove the old comment
lambda7xx Oct 10, 2023
58ecd89
Merge branch 'repo-refactor-lambda-transpose-OP' of https://github.co…
lambda7xx Oct 10, 2023
89dd105
fix the input_tensor
lambda7xx Oct 10, 2023
31a55ca
start to update the kernel
lambda7xx Oct 10, 2023
1039372
fix the transpose_kernels cu
lambda7xx Oct 11, 2023
56f931a
fix the transpose op
lambda7xx Oct 11, 2023
e750c93
substitutions build
wmdi Oct 11, 2023
4daf4d9
fmt
wmdi Oct 11, 2023
77e85ad
Merge remote-tracking branch 'upstream/repo-refactor' into test-subst…
wmdi Oct 12, 2023
a44b0cb
fmt
wmdi Oct 12, 2023
561a2f7
fmt
wmdi Oct 12, 2023
8a1fb69
fix issues caused by merge
wmdi Oct 12, 2023
d096ba1
Purge MOE operators (#1177)
reyna-abhyankar Oct 12, 2023
9ef1113
Merge branch 'repo-refactor' into test-substitution
lockshaw Oct 12, 2023
7eb34b7
Cast Op (#1111)
KateUnger Oct 12, 2023
d269b40
compiler build
wmdi Oct 16, 2023
716e5b9
Merge branch 'repo-refactor' into repo-refactor-lambda-Repartition-OP
reyna-abhyankar Oct 17, 2023
1d0ef22
Update lib/runtime/src/ops/softmax.cc
reyna-abhyankar Oct 17, 2023
6c2751f
Update lib/runtime/src/ops/softmax.cc
reyna-abhyankar Oct 17, 2023
a55c54b
Update lib/runtime/src/ops/softmax.cc
reyna-abhyankar Oct 17, 2023
1771a01
Merge branch 'repo-refactor' into repo-refactor-lambda-softmax-OP
reyna-abhyankar Oct 17, 2023
e31a41d
fix
wmdi Oct 17, 2023
933981e
fix the error
lambda7xx Oct 17, 2023
24c55ef
Merge branch 'repo-refactor-lambda-Repartition-OP' of https://github.…
lambda7xx Oct 17, 2023
d1b38c8
fix the kernel
lambda7xx Oct 17, 2023
c352c52
fix the topk error and add indices
lambda7xx Oct 17, 2023
b6f4112
fix the error
lambda7xx Oct 17, 2023
66b9591
fix the conflict
lambda7xx Oct 17, 2023
f9549b6
format the code
lambda7xx Oct 17, 2023
6e29568
format the code
lambda7xx Oct 17, 2023
f46fd11
Merge branch 'test-substitution' into test-compiler
wmdi Oct 18, 2023
5b97efa
Split OP (#1107)
lambda7xx Oct 18, 2023
1610902
Merge branch 'repo-refactor' into repo-refactor-lambda-softmax-OP
reyna-abhyankar Oct 18, 2023
4d799b7
fix the typo
lambda7xx Oct 18, 2023
fa6a929
fix the typo
lambda7xx Oct 18, 2023
5ddd801
fix the typo
lambda7xx Oct 18, 2023
c62b26c
Merge branch 'repo-refactor' into repo-refactor-lambda-topK-OP
lambda7xx Oct 19, 2023
53ed023
fix the semi
lambda7xx Oct 19, 2023
4ba60f4
fix the format
lambda7xx Oct 19, 2023
f494210
Merge pull request #1106 from lambda7xx/repo-refactor-lambda-softmax-OP
lambda7xx Oct 19, 2023
d8a3276
Merge branch 'repo-refactor' into repo-refactor-lambda-transpose-OP
reyna-abhyankar Oct 19, 2023
4a5b2dc
Replicate OP (#1101)
lambda7xx Oct 19, 2023
5e227c2
Merge branch 'repo-refactor' into repo-refactor-lambda-transpose-OP
reyna-abhyankar Oct 19, 2023
c8adea0
implement some missing functions
wmdi Oct 20, 2023
ecbb20f
format
wmdi Oct 20, 2023
8993f43
Merge branch 'repo-refactor' into repo-refactor-lambda-topK-OP
reyna-abhyankar Oct 23, 2023
04a348f
Merge pull request #1117 from lambda7xx/repo-refactor-lambda-transpos…
lambda7xx Oct 25, 2023
29cf3cd
Merge branch 'repo-refactor' into repo-refactor-lambda-topK-OP
lambda7xx Oct 25, 2023
d75a8f5
Merge pull request #1116 from lambda7xx/repo-refactor-lambda-topK-OP
lambda7xx Oct 25, 2023
28f9aea
Reshape OP (#1100)
lambda7xx Nov 1, 2023
e2ba62e
Reverse OP (#1105)
lambda7xx Nov 1, 2023
f2e1ec7
Merge branch 'repo-refactor' into repo-refactor-lambda-Repartition-OP
reyna-abhyankar Nov 1, 2023
41e15ef
Merge pull request #1119 from lambda7xx/repo-refactor-lambda-Repartit…
lambda7xx Nov 1, 2023
f29461b
Pool2D OP (#1182)
lambda7xx Nov 1, 2023
ed8b4b4
Reduce OP (#1118)
lambda7xx Nov 2, 2023
f01e13b
Reduction OP (#1120)
lambda7xx Nov 2, 2023
9372819
Update submodule (#1212)
reyna-abhyankar Nov 2, 2023
0f3c7de
substitutions tests pass
wmdi Nov 5, 2023
ec0dead
fmt
wmdi Nov 5, 2023
af67e9e
Merge branch 'test-substitution' into test-compiler
wmdi Nov 8, 2023
c015efb
unity dp works
wmdi Nov 15, 2023
6211b84
format
wmdi Nov 15, 2023
2940646
Batch Matmul Op (#1023)
KateUnger Dec 31, 2023
6e6779e
improve at for OutputLabelledOpenMultiDiGraph
wmdi Jan 6, 2024
0f75405
graph get_ptr fix
wmdi Jan 6, 2024
68a1bf1
fmt
wmdi Jan 6, 2024
ebfde7a
Embedding (#1256)
reyna-abhyankar Jan 7, 2024
53a9daa
update fmt
wmdi Jan 8, 2024
75d8482
Hip kernel fix (#1178)
reyna-abhyankar Jan 17, 2024
d4d9354
remove unnecessary virtual
wmdi Jan 19, 2024
b474d8d
format
wmdi Jan 19, 2024
7a87d11
Merge pull request #1189 from wmdi/test-substitution
wmdi Jan 19, 2024
d9f1302
Merge remote-tracking branch 'upstream/repo-refactor' into test-compiler
wmdi Jan 24, 2024
fb58a99
fmt
wmdi Jan 24, 2024
02937e1
fix
wmdi Jan 24, 2024
6402ed0
add substitutions, compiler, and their unit tests to CI
wmdi Jan 25, 2024
0c45f61
disable runtime unit test
wmdi Jan 25, 2024
bf41a4b
linear operator (#1180)
lambda7xx Feb 7, 2024
49d2ff7
LayerNorm OP draft (#1186)
lambda7xx Feb 8, 2024
95fa427
minor fix
wmdi Feb 15, 2024
1f7e2b6
(not compilable) visitable issue for OptimalCostState
wmdi Feb 18, 2024
ffa7f79
first try on docs
Bob-Chen222 Feb 23, 2024
a9a6402
fix machine mapping hash & refactor dp algorithm
wmdi Feb 27, 2024
d8bbcb8
minor fix
wmdi Feb 27, 2024
09d3152
fix variant issue
wmdi Feb 28, 2024
a150d3a
fmt
wmdi Feb 28, 2024
3237169
Element Unary Op (#1257)
reyna-abhyankar Mar 7, 2024
2eb3fdf
fix
wmdi Mar 11, 2024
7598a92
fmt
wmdi Mar 11, 2024
05c8336
fix
wmdi Mar 14, 2024
71aeddb
Merge remote-tracking branch 'upstream/repo-refactor' into test-compiler
wmdi Mar 14, 2024
502b41f
Merge branch 'repo-refactor' into repo-refactor
Bob-Chen222 Mar 14, 2024
6962bc8
additional doc
Bob-Chen222 Mar 14, 2024
73d72d2
Merge branch 'repo-refactor' of https://github.com/Bob-Chen222/FlexFl…
Bob-Chen222 Mar 14, 2024
74c90bf
Remove unnecessary dependencies and allow using external installs (#1…
lockshaw Mar 16, 2024
9345400
add more unit tests
wmdi Mar 18, 2024
c0015df
fmt
wmdi Mar 18, 2024
6d28697
Merge remote-tracking branch 'origin/repo-refactor' into compiler
lockshaw Mar 22, 2024
102f5fb
Fix post-merge
lockshaw Mar 22, 2024
d6e10bb
Add shell hook for sapling development
lockshaw Mar 23, 2024
95fb4cc
changed from nullopt to std::nullopt
Mar 23, 2024
c091479
fix cast issue
wmdi Mar 23, 2024
57bd35f
Merge branch 'test-compiler' of github.com:wmdi/FlexFlow into test-co…
wmdi Mar 23, 2024
54c604a
Fix spdlog cmake issue
lockshaw Mar 24, 2024
a09e528
Merge remote-tracking branch 'refs/remotes/wmdi/test-compiler' into c…
lockshaw Mar 24, 2024
8b914cf
Re-remove submodules
lockshaw Mar 24, 2024
189f323
minor fix & fmt
wmdi Mar 24, 2024
d2eb505
upd tests name to match ci
wmdi Mar 24, 2024
371324a
Add TEST_SUITE declaration to make tests findable by ctest
lockshaw Mar 26, 2024
da74817
Remove unnecessary nix files, add utils test to ci
lockshaw Mar 26, 2024
0db60db
Fix utils tests name, format
lockshaw Mar 26, 2024
6e520bb
Merge pull request #1229 from wmdi/test-compiler
wmdi Mar 26, 2024
bf9c0c0
resolved merge conflict
Bob-Chen222 Mar 28, 2024
1b0c32e
Re-merge #1229 (#1346)
lockshaw Mar 30, 2024
c21d66e
add tutorial
Bob-Chen222 Apr 1, 2024
817065b
Merge branch 'repo-refactor' into repo-refactor
Bob-Chen222 Apr 4, 2024
041f338
Add external tl-expected via nix, add proj via flake instead of submo…
lockshaw Apr 4, 2024
318152a
Commit proj toml and update proj version (#1350)
lockshaw Apr 5, 2024
fe164cc
Merge branch 'flexflow:repo-refactor' into repo-refactor
Bob-Chen222 Apr 6, 2024
7a9aeaf
align linear and layer_norm
Bob-Chen222 Apr 6, 2024
1369b35
fixed style
Apr 12, 2024
274d497
fixed layer_norm and linear
Apr 12, 2024
6e979e6
Revert "add tutorial"
Apr 12, 2024
080ed6a
Revert "first try on docs"
Bob-Chen222 Apr 12, 2024
03451e6
Revert "additional doc"
Bob-Chen222 Apr 12, 2024
de551ce
Update operator_pattern.h
Bob-Chen222 Apr 12, 2024
2fe8015
"revert substitution"
Bob-Chen222 Apr 12, 2024
a1fcfe5
"added hip for e-kernel"
Bob-Chen222 Apr 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 6 additions & 16 deletions lib/kernels/include/kernels/concat_kernels.h
Original file line number Diff line number Diff line change
@@ -5,30 +5,20 @@
#include "kernels/device.h"

namespace FlexFlow {

class ConcatPerDeviceState : public PerDeviceOpState {
public:
ConcatPerDeviceState(FFHandler handle) : PerDeviceOpState(handle){};
int legion_axis;
char op_name[MAX_OPNAME];
};

namespace Kernels {
namespace Concat {

void init_meta(ConcatPerDeviceState *meta, int legion_axis);

void forward_kernel(ffStream_t stream,
ConcatPerDeviceState const *m,
GenericTensorAccessorW const &output,
GenericTensorAccessorR const *inputs,
int num_inputs);
std::vector<GenericTensorAccessorR> const &inputs,
int num_inputs,
ff_dim_t legion_axis);

void backward_kernel(ffStream_t stream,
ConcatPerDeviceState const *m,
GenericTensorAccessorR const &output_grad,
GenericTensorAccessorW const *input_grads,
int num_inputs);
std::vector<GenericTensorAccessorW> const &input_grads,
int num_inputs,
ff_dim_t legion_axis);

} // namespace Concat
} // namespace Kernels
74 changes: 25 additions & 49 deletions lib/kernels/src/cuda/concat_kernels.cu
Original file line number Diff line number Diff line change
@@ -13,46 +13,42 @@
* limitations under the License.
*/

#include "device.h"
#include "kernels/concat_kernels.h"
#include "kernels/cuda_helper.h"
#include "kernels/device.h"
#include <cassert>

namespace FlexFlow {

namespace Kernels {
namespace Concat {

void init_meta(ConcatPerDeviceState *m, int legion_axis) {
m->legion_axis = legion_axis;
}

void calc_blk_size(size_t &num_blocks,
size_t &blk_size,
ArrayShape const &shape,
int axis) {
req<ff_dim_t> legion_axis) {
num_blocks = 1;
blk_size = 1;
for (int d = 0; d < shape.num_dims(); d++) {
if (d <= axis) {
blk_size *= shape[d];
if (d <= legion_axis) {
blk_size *= shape[legion_dim_t(d)];
} else {
num_blocks *= shape[d];
num_blocks *= shape[legion_dim_t(d)];
}
}
}

void forward_kernel(cudaStream_t stream,
ConcatPerDeviceState const *m,
void forward_kernel(ffStream_t stream,
GenericTensorAccessorW const &output,
GenericTensorAccessorR const *inputs,
int num_inputs) {
std::vector<GenericTensorAccessorR> const &inputs,
int num_inputs,
ff_dim_t legion_axis) {
size_t num_blocks = 1, output_blk_size = 1, input_blk_sizes[MAX_NUM_INPUTS];
assert(num_inputs <= MAX_NUM_INPUTS);
calc_blk_size(num_blocks, output_blk_size, output.shape, m->legion_axis);
for (int i = 0; i < num_input; i++) {
calc_blk_size(num_blocks, output_blk_size, output.shape, legion_axis);
for (int i = 0; i < num_inputs; i++) {
size_t input_num_blocks = 1;
calc_blk_size(
input_num_blocks, input_blk_sizes[i], inputs[i].shape, m->legion_axis);
input_num_blocks, input_blk_sizes[i], inputs[i].shape, legion_axis);
assert(input_num_blocks == num_blocks);
}

@@ -66,39 +62,25 @@ void forward_kernel(cudaStream_t stream,
num_blocks,
output_blk_size,
input_blk_sizes[i]);
// printf("output = %x num_blocks=%d output_blk_size=%d
// input_blk_size[%d]=%d\n",
// output, num_blocks, output_blk_size, i, input_blk_sizes[i]);
offset += input_blk_sizes[i];
}
}

void backward_kernel(cudaStream_t stream,
ConcatPerDeviceState const *m,
void backward_kernel(ffStream_t stream,
GenericTensorAccessorR const &output_grad,
GenericTensorAccessorW const *input_grads,
int num_inputs) {
std::vector<GenericTensorAccessorW> const &input_grads,
int num_inputs,
ff_dim_t legion_axis) {
size_t num_blocks = 1, output_blk_size = 1, input_blk_sizes[MAX_NUM_INPUTS];
assert(num_inputs <= MAX_NUM_INPUTS);
switch (output_grad.domain.get_dim()) {
#define DIMFUNC(DIM) \
case DIM: { \
Rect<DIM> rect = output_grad.domain; \
calc_blk_size<DIM>(num_blocks, output_blk_size, rect, m->legion_axis); \
for (int i = 0; i < num_inputs; i++) { \
rect = input_grads[i].domain; \
coord_t input_num_blocks = 1; \
calc_blk_size<DIM>( \
input_num_blocks, input_blk_sizes[i], rect, m->legion_axis); \
assert(input_num_blocks == num_blocks); \
} \
break; \
}
LEGION_FOREACH_N(DIMFUNC)
#undef DIMFUNC
default:
fprintf(stderr, "Unsupported concat dimension number");
assert(false);

ArrayShape shape = output_grad.shape;
calc_blk_size(num_blocks, output_blk_size, shape, legion_axis);
for (int i = 0; i < num_inputs; i++) {
shape = input_grads[i].shape;
size_t input_num_blocks = 1;
calc_blk_size(input_num_blocks, input_blk_sizes[i], shape, legion_axis);
assert(input_num_blocks == num_blocks);
}

off_t offset = 0;
@@ -113,12 +95,6 @@ void backward_kernel(cudaStream_t stream,
output_blk_size);
offset += input_blk_sizes[i];
}

// Rect<2> output_rect(Point<2>(0, 0), Point<2>(output_blk_size-1, batch_size
// - 1)); Rect<2> input_rect(Point<2>(0, 0), Point<2>(input_blk_sizes[0]-1,
// batch_size - 1)); print_tensor<2, float>(output_grad - output_blk_size,
// output_rect, "[Concat:backward:output]"); print_tensor<2,
// float>(input_grads[0], input_rect, "[Concat:backward:input0]");
}

} // namespace Concat
3 changes: 2 additions & 1 deletion lib/op-attrs/include/op-attrs/ops/concat.h
Original file line number Diff line number Diff line change
@@ -10,8 +10,9 @@ namespace FlexFlow {

struct ConcatAttrs {
ff_dim_t axis;
req<int> num_inputs;
};
FF_VISITABLE_STRUCT(ConcatAttrs, axis);
FF_VISITABLE_STRUCT(ConcatAttrs, axis, num_inputs);
CHECK_VALID_OP_ATTR(ConcatAttrs);

} // namespace FlexFlow
7 changes: 4 additions & 3 deletions lib/runtime/include/runtime/config.h
Original file line number Diff line number Diff line change
@@ -104,13 +104,14 @@ struct FFConfig : public use_visitable_cmp<FFConfig> {
int python_data_loader_type = 2;
};

class FFIterationConfig {
public:
FFIterationConfig();
struct FFIterationConfig {
FFIterationConfig() = delete;
void reset();
int seq_length;
};

FF_VISITABLE_STRUCT_NONSTANDARD_CONSTRUCTION(FFIterationConfig, seq_length);

enum FieldIDs {
FID_DATA,
};
583 changes: 100 additions & 483 deletions lib/runtime/src/ops/concat.cc

Large diffs are not rendered by default.

51 changes: 1 addition & 50 deletions lib/runtime/src/ops/concat.h
Original file line number Diff line number Diff line change
@@ -2,8 +2,8 @@
#define _FLEXFLOW_CONCAT_H

#include "op-attrs/ops/concat.h"
#include "op_task_invocation.h"
#include "sim_environment.h"
#include "task_spec/op_task_invocation.h"

namespace FlexFlow {

@@ -24,55 +24,6 @@ CostMetrics
std::vector<ParallelTensorShape> const &input_shapes,
ProfilingSettings const &settings,
MachineView const &machine_view);

/* class Concat : public Op { */
/* public: */
/* using Attrs = ConcatAttrs; */

/* Concat(FFModel &model, */
/* int n, */
/* ParallelTensor const *inputs, */
/* int axis, */
/* char const *name); */
/* Concat(FFModel &model, */
/* Attrs const &attrs, */
/* std::vector<ParallelTensor> const &inputs, */
/* char const *name = nullptr); */
/* void init(FFModel const &) override; */
/* void forward(FFModel const &) override; */
/* void backward(FFModel const &) override; */
/* static Op * */
/* create_operator_from_layer(FFModel &model, */
/* Layer const *layer, */
/* std::vector<ParallelTensor> const &inputs);
*/
/* static PerDeviceOpState *init_task(Legion::Task const *task, */
/* std::vector<Legion::PhysicalRegion> const
* &regions, */
/* Legion::Context ctx, */
/* Legion::Runtime *runtime); */
/* static void forward_task(Legion::Task const *task, */
/* std::vector<Legion::PhysicalRegion> const
* &regions, */
/* Legion::Context ctx, */
/* Legion::Runtime *runtime); */
/* static void backward_task(Legion::Task const *task, */
/* std::vector<Legion::PhysicalRegion> const
* &regions, */
/* Legion::Context ctx, */
/* Legion::Runtime *runtime); */
/* bool measure_operator_cost(Simulator *sim, */
/* MachineView const &pc, */
/* CostMetrics &cost_metrics) const override; */

/* OpTaskBinding get_init_task_binding() const override; */
/* OpTaskBinding get_fwd_task_binding() const override; */
/* OpTaskBinding get_bwd_task_binding() const override; */

/* public: */
/* int legion_axis; */
/* }; */

} // namespace FlexFlow

#endif
10 changes: 9 additions & 1 deletion lib/runtime/src/serialization.h
Original file line number Diff line number Diff line change
@@ -7,9 +7,10 @@
#include "legion/legion_utilities.h"
#include "op-attrs/dim_ordered.h"
#include "utils/optional.h"
#include "utils/required.h"
#include "utils/type_traits.h"
#include "utils/variant.h"
#include "utils/visitable.h"
#include <type_traits>

namespace FlexFlow {

@@ -77,6 +78,13 @@ struct is_trivially_serializable<
typename std::enable_if<std::is_integral<T>::value>::type>
: std::true_type {};

// Delegating specialization: selected when underlying_type_t<T> is well-formed,
// in which case T is trivially serializable exactly when that underlying type
// is.
// NOTE(review): void_t and underlying_type_t are unqualified here; presumably
// they are the project's own traits (utils/type_traits.h) rather than the std::
// ones — confirm.
template <typename T>
struct is_trivially_serializable<T, void_t<underlying_type_t<T>>>
: is_trivially_serializable<underlying_type_t<T>> {};

// A req<T> wrapper is trivially serializable exactly when the wrapped T is.
template <typename T>
struct is_trivially_serializable<req<T>> : is_trivially_serializable<T> {};

template <>
struct is_trivially_serializable<half> : std::true_type {};
template <>
17 changes: 7 additions & 10 deletions lib/runtime/src/task_spec/op_task_invocation.h
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@
#include "legion.h"
#include "op_arg_ref.h"
#include "op_task_signature.h"
#include "op_tensor_spec.h"
#include "runtime/config.h"
#include "runtime/profiling.h"
#include "serialization.h"
@@ -14,6 +15,7 @@
#include "utils/bidict.h"
#include "utils/optional.h"
#include "utils/stack_map.h"
#include "variadic_tensor_ref.h"
#include <typeindex>
#include <unordered_map>
#include <unordered_set>
@@ -22,16 +24,6 @@ namespace FlexFlow {

enum class IsTrainable { YES, NO };

struct OpTensorSpec {
TensorRole role;
req<int> idx;
};
FF_VISITABLE_STRUCT(OpTensorSpec, role, idx);

OpTensorSpec input_tensor(int);
OpTensorSpec output_tensor(int);
OpTensorSpec weight_tensor(int);

using OpArgSpec = variant<ConcreteArgSpec,
IndexArgSpec,
OpArgRefSpec,
@@ -48,6 +40,11 @@ struct OpTaskBinding {
void bind(slot_id, OpTensorSpec const &);
void bind_grad(slot_id, OpTensorSpec const &);

// Overload of bind accepting a VariadicTensorRef<T>. Currently a stub: the body
// consists solely of NOT_IMPLEMENTED(), and neither `name` nor `t` is used.
template <typename T>
void bind(slot_id name, VariadicTensorRef<T> const &t) {
NOT_IMPLEMENTED();
}

template <typename T>
void bind_device_specific_arg(slot_id name, T const &t) {
NOT_IMPLEMENTED();
20 changes: 20 additions & 0 deletions lib/runtime/src/task_spec/op_tensor_spec.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#ifndef _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_OP_TENSOR_SPEC_REF_H
#define _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_OP_TENSOR_SPEC_REF_H

#include "op_task_signature.h"

namespace FlexFlow {

// Pairs a TensorRole with an integer index.
// NOTE(review): presumably idx selects the idx-th tensor of the given role for
// an operator — confirm against the users of this struct.
struct OpTensorSpec {
TensorRole role;
req<int> idx;
};
// Lists role and idx as the struct's visitable fields.
FF_VISITABLE_STRUCT(OpTensorSpec, role, idx);

// Factory declarations (no definitions in this header); each takes an int and
// returns an OpTensorSpec.
// NOTE(review): presumably each sets the role matching its name and uses the
// argument as idx — confirm against the definitions.
OpTensorSpec input_tensor(int);
OpTensorSpec output_tensor(int);
OpTensorSpec weight_tensor(int);

} // namespace FlexFlow

#endif
2 changes: 2 additions & 0 deletions lib/runtime/src/task_spec/runtime_arg_ref.h
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@

#include "arg_ref.h"
#include "device_specific.h"
#include "runtime/config.h"

namespace FlexFlow {

@@ -15,6 +16,7 @@ using RuntimeArgRefSpec = ArgRefSpec<RuntimeArgRefType>;

RuntimeArgRef<ProfilingSettings> profiling_settings();
RuntimeArgRef<DeviceSpecific<PerDeviceFFHandle>> ff_handle();
RuntimeArgRef<FFIterationConfig> iteration_config();

} // namespace FlexFlow

20 changes: 20 additions & 0 deletions lib/runtime/src/task_spec/variadic_tensor_ref.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#ifndef _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_VARIADIC_TENSOR_ARG_REF_H
#define _FLEXFLOW_RUNTIME_SRC_TASK_SPEC_VARIADIC_TENSOR_ARG_REF_H

#include "arg_ref.h"
#include "op_tensor_spec.h"

namespace FlexFlow {

// Tag type for variadic tensor references; INPUT_TENSORS is the only
// enumerator defined here.
enum class VariadicTensorRefType { INPUT_TENSORS };

// Shorthand for an ArgRef whose reference-type tag is VariadicTensorRefType and
// whose second template argument is T.
template <typename T>
using VariadicTensorRef = ArgRef<VariadicTensorRefType, T>;

// Returns a VariadicTensorRef<OpTensorSpec> tagged
// VariadicTensorRefType::INPUT_TENSORS.
//
// Declared `inline` because the definition lives in a header: without it,
// every translation unit that includes this file would emit its own external
// definition of the function, violating the one-definition rule and producing
// duplicate-symbol errors at link time.
inline VariadicTensorRef<OpTensorSpec> get_input_tensors() {
  return {VariadicTensorRefType::INPUT_TENSORS};
}

} // namespace FlexFlow

#endif