Commit 2bc49d5

Split Simulator memory_requirement of operators into more detailed parts for input, output, and weight tensors (flexflow#297)

* [Memory] Split memory_requirement into more detailed parts

* [Memory] Change the way to measure input and output tensor memory usage

* [Memory] Update based on review comments

* [Memory] Use a single helper function to get operator memory usage

* Better to make the function const, per review suggestion

* Update details per review suggestion
eric-zheng authored Aug 23, 2022
1 parent 11df9db commit 2bc49d5
Showing 29 changed files with 201 additions and 43 deletions.
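To make the new scheme concrete, here is a self-contained miniature of the accounting pattern that the operator diffs below follow. ToySimulator and ToyCostMetrics are simplified stand-ins invented for this sketch, not FlexFlow's actual classes; the real code uses Simulator::allocate/free_all and the CostMetrics struct from include/flexflow/simulator.h.

// Illustrative only: miniature of the per-tensor memory accounting scheme.
#include <cassert>
#include <cstddef>
#include <cstdio>

struct ToyCostMetrics {
  size_t inputs_memory = 0, outputs_memory = 0, weights_memory = 0;
  size_t total_memory() const {
    return inputs_memory + outputs_memory + weights_memory;
  }
  // Bytes allocated since the last category was charged.
  size_t total_mem_diff_from(long sim_offset) const {
    return (size_t)sim_offset - total_memory();
  }
};

struct ToySimulator {
  long offset = 0; // running total of bytes handed out since free_all()
  void free_all() { offset = 0; }
  void *allocate(size_t count, size_t elem_size) {
    offset += (long)(count * elem_size); // bump counter; no real memory here
    return (void *)1;                    // non-NULL placeholder
  }
};

int main() {
  ToySimulator sim;
  ToyCostMetrics cost_metrics;

  sim.free_all();
  // Charge each allocation to a category right after making it.
  sim.allocate(1024, sizeof(float)); // input tensor
  cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim.offset);
  sim.allocate(512, sizeof(float));  // output tensor
  cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim.offset);
  sim.allocate(256, sizeof(float));  // weight tensor
  cost_metrics.weights_memory += cost_metrics.total_mem_diff_from(sim.offset);

  assert(cost_metrics.total_memory() == (size_t)sim.offset);
  printf("inputs=%zu outputs=%zu weights=%zu\n", cost_metrics.inputs_memory,
         cost_metrics.outputs_memory, cost_metrics.weights_memory);
  return 0;
}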
2 changes: 2 additions & 0 deletions .gitignore
@@ -2,6 +2,8 @@
/.tools/
/python/flexflow_python
/python/flexflow/core/legion_cffi.py
python/flexflow/core/flexflow_cffi_header.py
python/flexflow/core/legion_cffi_header.py
*.pb.cc
*.pb.h
*.o
30 changes: 26 additions & 4 deletions include/flexflow/simulator.h
@@ -48,13 +48,35 @@ class TransposeMeta;
class Op;
class FFModel;

/**
* @brief Costs of an operator.
*/
struct CostMetrics {
CostMetrics()
: forward_time(0.0f), backward_time(0.0f), sync_time(0.0f),
memory_requirement(0) {}
/**
* @brief Return the sum of the memory usage recorded in this CostMetrics.
*/
size_t total_memory() const;

/**
* @brief Get the incremental memory usage: the difference between
* sim->offset and the total memory already recorded in this CostMetrics.
* @details This makes it easy to attribute the bytes of the most recent
* Simulator::allocate calls to a single category (inputs, outputs, or
* weights).
*
* @param sim_offset Simulator->offset
* @return size_t The incremental memory usage in bytes
*/
size_t total_mem_diff_from(off_t sim_offset) const;

public:
float forward_time, backward_time;
float sync_time;
size_t memory_requirement;
/// Bytes of memory usage of different parts
// Assume:
// 1. all memory allocations use Simulator::allocate
// 2. we call Simulator::free_all before measuring an operator
// Therefore, the current memory usage of an operator is (size_t)sim->offset
size_t inputs_memory, outputs_memory, weights_memory;
};

class Device {
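Given how the operator diffs below use the new fields, the two helpers declared above can plausibly be implemented as follows; this is a sketch, and the actual definitions elsewhere in the repository may differ.

size_t CostMetrics::total_memory() const {
  return inputs_memory + outputs_memory + weights_memory;
}

size_t CostMetrics::total_mem_diff_from(off_t sim_offset) const {
  // sim->offset is the running total of Simulator::allocate calls since the
  // last free_all(), so the bytes not yet attributed to any category are the
  // difference from total_memory().
  return (size_t)sim_offset - total_memory();
}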
4 changes: 3 additions & 1 deletion src/ops/aggregate.cc
@@ -409,7 +409,9 @@ bool Aggregate::measure_operator_cost(Simulator *sim,
// TODO: implement
cost_metrics.forward_time = 0.0f;
cost_metrics.backward_time = 0.0f;
cost_metrics.memory_requirement = 0;
cost_metrics.inputs_memory = 0;
cost_metrics.outputs_memory = 0;
cost_metrics.weights_memory = 0;
return false;
}

4 changes: 3 additions & 1 deletion src/ops/aggregate_spec.cc
@@ -386,7 +386,9 @@ bool AggregateSpec::measure_operator_cost(Simulator *sim,
// TODO: implement
cost_metrics.forward_time = 0.0f;
cost_metrics.backward_time = 0.0f;
cost_metrics.memory_requirement = 0;
cost_metrics.inputs_memory = 0;
cost_metrics.outputs_memory = 0;
cost_metrics.weights_memory = 0;
return false;
}

14 changes: 13 additions & 1 deletion src/ops/attention.cc
@@ -762,9 +762,14 @@ bool MultiHeadAttention::measure_operator_cost(
(float const *)sim->allocate(sub_key.get_volume(), DT_FLOAT);
float const *value_ptr =
(float const *)sim->allocate(sub_value.get_volume(), DT_FLOAT);
float const *weight_ptr = (float const *)sim->allocate(num_weights, DT_FLOAT);
cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset);

float *output_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT);
assert(output_ptr != NULL);
cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset);

float const *weight_ptr = (float const *)sim->allocate(num_weights, DT_FLOAT);
cost_metrics.weights_memory += cost_metrics.total_mem_diff_from(sim->offset);

assert(m->profiling == false);

@@ -780,10 +785,17 @@
(float *)sim->allocate(sub_key.get_volume(), DT_FLOAT);
float *value_grad_ptr =
(float *)sim->allocate(sub_value.get_volume(), DT_FLOAT);
cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset);

float *weight_grad_ptr = (float *)sim->allocate(num_weights, DT_FLOAT);
cost_metrics.weights_memory +=
cost_metrics.total_mem_diff_from(sim->offset);

float *output_grad_ptr =
(float *)sim->allocate(sub_output.get_volume(), DT_FLOAT);
assert(output_grad_ptr != NULL);
cost_metrics.outputs_memory +=
cost_metrics.total_mem_diff_from(sim->offset);

backward = [&] {
backward_kernel_wrapper(m,
7 changes: 7 additions & 0 deletions src/ops/batch_matmul.cc
@@ -528,8 +528,11 @@ bool BatchMatmul::measure_operator_cost(Simulator *sim,
float *b_ptr = (float *)sim->allocate(sub_input1.get_volume(), DT_FLOAT);
assert(b_ptr != NULL);
float *c_ptr = NULL;
cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset);

float *out_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT);
assert(out_ptr != NULL);
cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset);

int m = input1_c;
int n = input0_r;
@@ -548,9 +551,13 @@
float *b_grad_ptr =
(float *)sim->allocate(sub_input1.get_volume(), DT_FLOAT);
float *c_grad_ptr = NULL;
cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset);

float *out_grad_ptr =
(float *)sim->allocate(sub_output.get_volume(), DT_FLOAT);
assert(out_grad_ptr != NULL);
cost_metrics.outputs_memory +=
cost_metrics.total_mem_diff_from(sim->offset);

backward = [&] {
backward_kernel_wrapper(meta,
12 changes: 12 additions & 0 deletions src/ops/batch_norm.cc
@@ -248,12 +248,17 @@ bool BatchNorm::measure_operator_cost(Simulator *sim,
sim->free_all();
float *input_ptr = (float *)sim->allocate(sub_input.get_volume(), DT_FLOAT);
assert(input_ptr != NULL);
cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset);

float *output_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT);
assert(output_ptr != NULL);
cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset);

float *bias_ptr = (float *)sim->allocate(output_c, DT_FLOAT);
assert(bias_ptr != NULL);
float *scale_ptr = (float *)sim->allocate(output_c, DT_FLOAT);
assert(scale_ptr != NULL);
cost_metrics.weights_memory += cost_metrics.total_mem_diff_from(sim->offset);

std::function<void()> forward, backward;
forward = [&] {
@@ -263,13 +268,20 @@ bool BatchNorm::measure_operator_cost(Simulator *sim,
float *input_grad_ptr =
(float *)sim->allocate(sub_input.get_volume(), DT_FLOAT);
assert(input_grad_ptr != NULL);
cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset);

float *output_grad_ptr =
(float *)sim->allocate(sub_output.get_volume(), DT_FLOAT);
assert(output_grad_ptr != NULL);
cost_metrics.outputs_memory +=
cost_metrics.total_mem_diff_from(sim->offset);

float *scale_grad_ptr = (float *)sim->allocate(output_c, DT_FLOAT);
assert(scale_grad_ptr != NULL);
float *bias_grad_ptr = (float *)sim->allocate(output_c, DT_FLOAT);
assert(bias_grad_ptr != NULL);
cost_metrics.weights_memory +=
cost_metrics.total_mem_diff_from(sim->offset);

backward = [&] {
backward_kernel(m,
4 changes: 3 additions & 1 deletion src/ops/cache.cc
@@ -278,7 +278,9 @@ bool Cache::measure_operator_cost(Simulator *sim,
// TODO: implement
cost_metrics.forward_time = 0.0f;
cost_metrics.backward_time = 0.0f;
cost_metrics.memory_requirement = 0;
cost_metrics.inputs_memory = 0;
cost_metrics.outputs_memory = 0;
cost_metrics.weights_memory = 0;
return false;
}

3 changes: 3 additions & 0 deletions src/ops/cast.cc
@@ -316,6 +316,9 @@ bool Cast::measure_operator_cost(Simulator *sim,
// Assume cast has no cost
cost_metrics.forward_time = 0.0f;
cost_metrics.backward_time = 0.0f;
cost_metrics.inputs_memory = 0;
cost_metrics.outputs_memory = 0;
cost_metrics.weights_memory = 0;
return true;
}

8 changes: 8 additions & 0 deletions src/ops/concat.cc
@@ -374,7 +374,11 @@ bool Concat::measure_operator_cost(Simulator *sim,
(float *)sim->allocate(sub_inputs[i].get_volume(), DT_FLOAT);
out_of_memory = out_of_memory || (input_ptrs[i] == NULL);
}
cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset);

float *output_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT);
cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset);

out_of_memory = out_of_memory || (output_ptr == NULL);
if (out_of_memory) {
cost_metrics.forward_time = Simulator::MAXIMUM_TASK_RUN_TIME;
@@ -406,8 +410,12 @@
(float *)sim->allocate(sub_inputs[i].get_volume(), DT_FLOAT);
out_of_memory = out_of_memory || (input_grad_ptrs[i] == NULL);
}
cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset);
float *output_grad_ptr =
(float *)sim->allocate(sub_output.get_volume(), DT_FLOAT);
cost_metrics.outputs_memory +=
cost_metrics.total_mem_diff_from(sim->offset);

out_of_memory = out_of_memory || (output_grad_ptr == NULL);
if (out_of_memory) {
cost_metrics.forward_time = Simulator::MAXIMUM_TASK_RUN_TIME;
7 changes: 0 additions & 7 deletions src/ops/conv_2d.cpp
@@ -552,13 +552,6 @@ bool Conv2D::measure_operator_cost(Simulator *sim,
float* bias_ptr = (float*)sim->allocate(output_c, DT_FLOAT);
assert(bias_ptr != NULL);

// compute memory usage
// Assume:
// 1. all memory allocations use Simulator::allocate
// 2. we call Simulator::free_all before measure an operator
// Therefore, the memory usage of an operator is sim->offset
cost_metrics.memory_requirement = (size_t)sim->offset;

// select forward algorithm
{
const int reqAlgCnt = 8;
12 changes: 5 additions & 7 deletions src/ops/conv_2d.cu
@@ -557,20 +557,18 @@ bool Conv2D::measure_operator_cost(Simulator *sim,
sim->free_all();
float *input_ptr = (float *)sim->allocate(sub_input.get_volume(), DT_FLOAT);
assert(input_ptr != NULL);
cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset);

float *output_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT);
assert(output_ptr != NULL);
cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset);

float *weight_ptr = (float *)sim->allocate(
(size_t)output_c * input_c * kernel_h * kernel_w / groups, DT_FLOAT);
assert(weight_ptr != NULL);
float *bias_ptr = (float *)sim->allocate(output_c, DT_FLOAT);
assert(bias_ptr != NULL);

// compute memory usage
// Assume:
// 1. all memory allocations use Simulator::allocate
// 2. we call Simulator::free_all before measure an operator
// Therefore, the memory usage of an operator is sim->offset
cost_metrics.memory_requirement = (size_t)sim->offset;
cost_metrics.weights_memory += cost_metrics.total_mem_diff_from(sim->offset);

// select forward algorithm
{
8 changes: 8 additions & 0 deletions src/ops/dropout.cc
@@ -364,8 +364,11 @@ bool Dropout::measure_operator_cost(Simulator *sim,
sim->free_all();
float *input_ptr = (float *)sim->allocate(sub_input.get_volume(), DT_FLOAT);
assert(input_ptr != NULL);
cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset);

float *output_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT);
assert(output_ptr != NULL);
cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset);

assert(m->profiling == false);

@@ -375,9 +378,14 @@
float *input_grad_ptr =
(float *)sim->allocate(sub_input.get_volume(), DT_FLOAT);
assert(input_grad_ptr != NULL);
cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset);

float *output_grad_ptr =
(float *)sim->allocate(sub_output.get_volume(), DT_FLOAT);
assert(output_grad_ptr != NULL);
cost_metrics.outputs_memory +=
cost_metrics.total_mem_diff_from(sim->offset);

backward = [&] {
backward_kernel_wrapper(m, output_grad_ptr, input_grad_ptr);
};
8 changes: 8 additions & 0 deletions src/ops/element_binary.cc
@@ -647,13 +647,16 @@ bool ElementBinary::measure_operator_cost(Simulator *sim,
assert(input1_ptr != NULL);
float *input2_ptr = (float *)sim->allocate(sub_input2.get_volume(), DT_FLOAT);
assert(input2_ptr != NULL);
cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset);

float *output_ptr = NULL;
if (inplace_a) {
output_ptr = input1_ptr;
} else {
output_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT);
}
assert(output_ptr != NULL);
cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset);

assert(m->profiling == false);

@@ -668,6 +671,8 @@ bool ElementBinary::measure_operator_cost(Simulator *sim,
float *input2_grad_ptr =
(float *)sim->allocate(sub_input2.get_volume(), DT_FLOAT);
assert(input2_grad_ptr != NULL);
cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset);

float *output_grad_ptr = NULL;
if (inplace_a) {
output_grad_ptr = input1_grad_ptr;
@@ -676,6 +681,9 @@ bool ElementBinary::measure_operator_cost(Simulator *sim,
(float *)sim->allocate(sub_output.get_volume(), DT_FLOAT);
}
assert(output_grad_ptr != NULL);
cost_metrics.outputs_memory +=
cost_metrics.total_mem_diff_from(sim->offset);

backward = [&] {
backward_kernel_wrapper(m,
output_grad_ptr,
8 changes: 8 additions & 0 deletions src/ops/element_unary.cc
@@ -539,13 +539,16 @@ bool ElementUnary::measure_operator_cost(Simulator *sim,
sim->free_all();
float *input_ptr = (float *)sim->allocate(sub_input.get_volume(), DT_FLOAT);
assert(input_ptr != NULL);
cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset);

float *output_ptr = NULL;
if (inplace) {
output_ptr = input_ptr;
} else {
output_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT);
}
assert(output_ptr != NULL);
cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset);

assert(m->profiling == false);

@@ -557,6 +560,8 @@ bool ElementUnary::measure_operator_cost(Simulator *sim,
float *input_grad_ptr =
(float *)sim->allocate(sub_input.get_volume(), DT_FLOAT);
assert(input_grad_ptr != NULL);
cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset);

float *output_grad_ptr = NULL;
if (inplace) {
output_grad_ptr = input_grad_ptr;
@@ -565,6 +570,9 @@ bool ElementUnary::measure_operator_cost(Simulator *sim,
(float *)sim->allocate(sub_output.get_volume(), DT_FLOAT);
}
assert(output_grad_ptr != NULL);
cost_metrics.outputs_memory +=
cost_metrics.total_mem_diff_from(sim->offset);

backward = [&] {
backward_kernel_wrapper(m,
input_ptr,
13 changes: 13 additions & 0 deletions src/ops/embedding.cc
@@ -623,11 +623,16 @@ bool Embedding::measure_operator_cost(Simulator *sim,
bool out_of_memory = false;
int64_t *input_ptr =
(int64_t *)sim->allocate(sub_input.get_volume(), DT_INT64);
cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset);

out_of_memory = out_of_memory || (input_ptr == NULL);
float *output_ptr = (float *)sim->allocate(sub_output.get_volume(), DT_FLOAT);
out_of_memory = out_of_memory || (output_ptr == NULL);
cost_metrics.outputs_memory += cost_metrics.total_mem_diff_from(sim->offset);

float *weight_ptr =
(float *)sim->allocate(num_entries * out_channels, DT_FLOAT);
cost_metrics.weights_memory += cost_metrics.total_mem_diff_from(sim->offset);
out_of_memory = out_of_memory || (weight_ptr == NULL);
if (out_of_memory) {
cost_metrics.forward_time = Simulator::MAXIMUM_TASK_RUN_TIME;
@@ -657,13 +662,21 @@ if (sim->computationMode == COMP_MODE_TRAINING) {
if (sim->computationMode == COMP_MODE_TRAINING) {
float *weight_grad_ptr =
(float *)sim->allocate(num_entries * out_channels, DT_FLOAT);
cost_metrics.weights_memory +=
cost_metrics.total_mem_diff_from(sim->offset);
out_of_memory = out_of_memory || (weight_grad_ptr == NULL);

float *output_grad_ptr =
(float *)sim->allocate(sub_output.get_volume(), DT_FLOAT);
cost_metrics.outputs_memory +=
cost_metrics.total_mem_diff_from(sim->offset);
out_of_memory = out_of_memory || (output_grad_ptr == NULL);

int64_t *input_grad_ptr =
(int64_t *)sim->allocate(sub_input.get_volume(), DT_INT64);
cost_metrics.inputs_memory += cost_metrics.total_mem_diff_from(sim->offset);
out_of_memory = out_of_memory || (input_grad_ptr == NULL);

if (out_of_memory) {
cost_metrics.forward_time = Simulator::MAXIMUM_TASK_RUN_TIME;
cost_metrics.backward_time = Simulator::MAXIMUM_TASK_RUN_TIME;
(Diffs for the remaining changed files are not shown here.)
