Add multi-objective global memory search algorithm (flexflow#493)
* Initial change of search procedure with memory consideration (flexflow#278)

* [Memory] Add necessary types to support memory search. WIP.

* [Memory] Implement modified DP search algorithm with memory cost. Missing base solution. WIP.

* [Memory] Complete all changes to the search procedure to support multi-objective search with global memory.

A refactor of the search procedure is planned for the future.

* Add a line to export the clang compilation database, but do not enable it.

* [Memory] Save some work

* [Memory] Allow different run time cost factor

* Update format

* Update format again

* Resolve compile error due to merge conflict

* Sync the changes again (flexflow#296)

* Save some more expressive logging

* Update format

* [Memory] Correct memory cost calculation

* Fix the build with CUDA_TOOLKIT_ROOT_DIR

* [Memory] Update calculation of memory cost

* Add logs folder to gitignore

* Improve dot graph representation

* [Dot] Update dot graph representation

* Move changes

* Quick fix to avoid bert segfault

* Grid search of lambda

* [WIP] Update

* [Interface] Add --memory-search argument

* [Memory] Update memory search

* [Interface] Save -ll:fsize info

* [WIP] Save per-device memory change

* Finalize per-device max memory threshold

* Update format

* Update comments to prepare for merging

* [WIP] Experiments to clear the caches

* Fixed a memory calculation bug

* Update minor issues

* Update based on review comments

* Remove unnecessary include

* Update based on review

* Update based on review

* Factor out lambda helper functions

* Fix a bug due to moving lambda function out

* Fix memory leak of the cached_simulator

---------

Co-authored-by: Gabriele Oliaro <[email protected]>
Co-authored-by: Colin Unger <[email protected]>
Co-authored-by: Zhihao Jia <[email protected]>
4 people authored Feb 21, 2023
1 parent c122eb2 commit 2c4d257
Showing 18 changed files with 1,439 additions and 62 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -178,3 +178,6 @@ train-images-idx3-ubyte.gz
train-labels-idx1-ubyte.gz
train-images-idx3-ubyte
train-labels-idx1-ubyte

# Logs
logs/
2 changes: 2 additions & 0 deletions include/flexflow/config.h
@@ -117,6 +117,7 @@ class FFConfig {
int epochs, batchSize, printFreq;
// int inputHeight, inputWidth;
int numNodes, cpusPerNode, workersPerNode;
float device_mem; // The device (GPU) memory threshold; given by -ll:fsize
float learningRate, weightDecay;
size_t workSpaceSize;
Legion::Context lg_ctx;
@@ -155,6 +156,7 @@ class FFConfig {
int base_optimize_threshold;
bool enable_control_replication;
int python_data_loader_type;
bool perform_memory_search{false};
};

class FFIterationConfig {
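The config.h hunks above only declare the new device_mem and perform_memory_search fields. A minimal, hypothetical sketch of how they might be filled while scanning command-line arguments follows; the actual parsing code is not part of this diff, so the function name, flag handling, and units below are assumptions based on the commit messages about --memory-search and -ll:fsize.

#include <cstdlib>
#include <cstring>

#include "flexflow/config.h"

// Hypothetical sketch only: the real argument-parsing code is not shown in
// this commit's hunks.
void parse_memory_args_sketch(FlexFlow::FFConfig &config, int argc, char **argv) {
  for (int i = 1; i < argc; i++) {
    if (!std::strcmp(argv[i], "-ll:fsize")) {
      // Legion's -ll:fsize gives the per-GPU framebuffer size in MB.
      config.device_mem = static_cast<float>(std::atof(argv[++i]));
    } else if (!std::strcmp(argv[i], "--memory-search")) {
      config.perform_memory_search = true;
    }
  }
}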
63 changes: 63 additions & 0 deletions include/flexflow/graph.h
@@ -17,6 +17,7 @@
#define _FLEXFLOW_GRAPH_H_
#include "flexflow/basic_graph.h"
#include "flexflow/graph_structures.h"
#include "flexflow/memory_optimization.h"
#include "flexflow/model.h"
#include "flexflow/utils/dot/dot_file.h"
#include "flexflow/utils/recursive_logger.h"
@@ -114,6 +115,31 @@ struct GraphCostResult {
friend std::ostream &operator<<(std::ostream &, GraphCostResult const &);
};

/**
* @brief Holds the cost information of a PCG.
*/
struct GraphCostResultWithMemory {
float cost; ///< Run time cost
MemoryUsage mem_cost; ///< Memory usage
///< Corresponding machine views (device placement views)
std::unordered_map<Node, MachineView> views;

/**
* @brief Get the multi-objective cost that combines the run time and memory
* cost.
*
* @return float Numerical value to represent the overall cost
*/
float get_multi_obj_cost() const;

static GraphCostResultWithMemory invalid();

bool operator<(GraphCostResultWithMemory const &other) const;

friend std::ostream &operator<<(std::ostream &,
GraphCostResultWithMemory const &);
};

template <typename T>
T sequence_cost(T const &first, T const &second);

@@ -157,6 +183,12 @@ class SearchHelper {
NodeAssignment const &source,
NodeAssignment const &sink,
MachineResource const &resources) const;
/**
* @brief Starting point to get parallel split time cost.
*
* @tparam T float or GraphCostResult (or GraphCostResultWithMemory in memory
* optimization)
*/
template <typename T>
T find_optimal_nonsequence_graph_time(Graph const *g,
NodeAssignment const &source,
@@ -200,6 +232,20 @@ class SearchHelper {
template <typename T>
void add_operator_cost(NodeAssignment const &, float, T *) const;

template <typename T>
void add_sink_node_costs(NodeAssignment const &sink,
CostMetrics metrics,
T *result) const;

/**
* @brief Add run time cost and memory cost of the operator to the graph cost.
* This is a temp workaround and should be refactored eventually.
*/
void add_operator_cost_with_memory(NodeAssignment const &node,
float node_run_time_cost,
MemoryUsage node_mem_cost,
GraphCostResultWithMemory *cost) const;

template <typename T>
float get_cost(T const &) const;

@@ -209,6 +255,8 @@
public:
mutable std::unique_ptr<RecursiveLogger> logger;

void clear_cache();

private:
template <typename T>
T execute_nonsequence_split(std::unique_ptr<Graph> const &first_graph,
@@ -260,6 +308,7 @@ class Graph {
Graph subgraph(std::unordered_set<Node> const &nodes) const;
void contract_out_node(Node const &);
float optimal_cost() const;
float optimal_cost_with_memory(float run_time_cost_factor) const;
std::unordered_map<Node, MachineView> optimal_views() const;
void remove_input_nodes();
void duplicate_input_node(Node const &);
@@ -335,6 +384,20 @@ struct GraphOptimizeResult {
friend std::ostream &operator<<(std::ostream &, GraphOptimizeResult const &);
};

/**
* @brief Hold the optimization results with memory information.
*/
struct GraphOptimizeResultWithMemory {
tl::optional<Graph> graph; ///< Optimized PCG
float cost; ///< Run time cost
MemoryUsage mem_cost; ///< Memory usage
///< Corresponding machine views (device placement views)
std::unordered_map<Node, MachineView> views;

friend std::ostream &operator<<(std::ostream &,
GraphOptimizeResultWithMemory const &);
};

namespace Utils {
template <>
struct GraphStructure<FlexFlow::PCG::Graph> {
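GraphCostResultWithMemory::get_multi_obj_cost is only declared in the header above; its definition lives in the source files of this commit. As a rough sketch of the idea, assuming a linear blend weighted by the run_time_cost_factor from MemoryOptimConfig (the committed implementation may scale or normalize the terms differently):

#include "flexflow/graph.h"

// Illustrative sketch only (not the committed implementation): blend the run
// time objective and the memory objective with a factor lambda in [0, 1].
float multi_obj_cost_sketch(FlexFlow::PCG::GraphCostResultWithMemory const &r,
                            float lambda /* assumed run_time_cost_factor */) {
  return lambda * r.cost + (1.0f - lambda) * r.mem_cost.num;
}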
107 changes: 107 additions & 0 deletions include/flexflow/memory_optimization.h
@@ -0,0 +1,107 @@
/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef _FLEXFLOW_MEMORY_OPTIMIZATION_H_
#define _FLEXFLOW_MEMORY_OPTIMIZATION_H_

#include <cassert>
#include <string>

namespace FlexFlow {

enum class MemoryUsageType {
// Use global memory of a PCG as the measure of memory usage. No device
// mapping consideration.
GLOBAL,

// Use the max of peak per-device memory usage among devices as the measure.
// Need associated device mapping views.
PER_DEVICE_MAX,
};

enum class MemorySearchAlgo {
// Multiple objective DP search. Combine memory cost and run time cost into
// one single cost function and add a factor to balance them.
MULTI_OBJECTIVE,
};

/**
* @brief Config class to control memory optimizations. This should be put into
* config.h and be stored in FFConfig. But for easy turnaround, put this here
* for now.
*/
class MemoryOptimConfig {
public:
MemoryUsageType mem_usage_type; ///< How to represent memory cost
MemorySearchAlgo mem_search_algo; ///< How to search for the optimal schedule
float run_time_cost_factor; ///< The weight factor of run time cost in the
///< overall cost function; used in
///< MULTI_OBJECTIVE algorithm
///< Valid range is [0, 1], inclusive

MemoryOptimConfig()
: mem_usage_type{MemoryUsageType::GLOBAL},
mem_search_algo{MemorySearchAlgo::MULTI_OBJECTIVE},
run_time_cost_factor{0.5} {}
MemoryOptimConfig(float factor)
: mem_usage_type{MemoryUsageType::GLOBAL},
mem_search_algo{MemorySearchAlgo::MULTI_OBJECTIVE},
run_time_cost_factor{factor} {}
};

/**
* @brief Hold the result (including memory information) of a graph_optimize on
* a PCG.
*/
class MemorySearchResult {
public:
float run_time_cost{};
float memory_cost{};
float search_time{};
///< The max of per-device memory usage among all devices
float max_per_device_mem_all_deivces = 0.0;
};

namespace PCG {

/**
* @brief Class to hold memory usage information of a (sub-)PCG.
*/
class MemoryUsage {
public:
MemoryUsageType usage_type; ///< What "num" means
float num; ///< Numerical value of the memory usage

MemoryUsage() : usage_type{MemoryUsageType::GLOBAL}, num{0.0} {}
MemoryUsage(MemoryUsageType _usage_type, float _num)
: usage_type{_usage_type}, num{_num} {}

std::string to_string() const;

MemoryUsage &operator+=(MemoryUsage const &rhs);

/**
* @brief Combine the memory usage of two PCGs flexibly based on
* MemoryUsageType.
*/
friend MemoryUsage operator+(MemoryUsage lhs, MemoryUsage const &rhs);

friend std::ostream &operator<<(std::ostream &s, MemoryUsage const &usage);
};

} // namespace PCG
} // namespace FlexFlow

#endif // _FLEXFLOW_MEMORY_OPTIMIZATION_H_
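The operator+ declared above is documented as combining the memory usage of two sub-PCGs "flexibly based on MemoryUsageType", but its body is not in this header. One plausible reading, sketched under the assumption that GLOBAL usage adds up while PER_DEVICE_MAX keeps the larger peak (the committed implementation may treat overlapping devices differently):

#include <algorithm>
#include <cassert>

#include "flexflow/memory_optimization.h"

// Illustrative sketch only: sum the usage when tracking global memory, take
// the larger peak when tracking the per-device maximum.
FlexFlow::PCG::MemoryUsage
    combine_usage_sketch(FlexFlow::PCG::MemoryUsage const &lhs,
                         FlexFlow::PCG::MemoryUsage const &rhs) {
  using FlexFlow::MemoryUsageType;
  assert(lhs.usage_type == rhs.usage_type);
  float combined = (lhs.usage_type == MemoryUsageType::GLOBAL)
                       ? lhs.num + rhs.num
                       : std::max(lhs.num, rhs.num);
  return FlexFlow::PCG::MemoryUsage(lhs.usage_type, combined);
}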
13 changes: 13 additions & 0 deletions include/flexflow/model.h
@@ -17,6 +17,7 @@
#include "accessor.h"
#include "config.h"
#include "device.h"
#include "flexflow/memory_optimization.h"
#include "flexflow/node.h"
#include "flexflow/operator_params.h"
#include "flexflow/utils/hash_utils.h"
@@ -784,6 +785,13 @@ class FFModel {
bool only_data_parallel,
std::unique_ptr<PCG::Graph> &best_graph,
std::unordered_map<PCG::Node, MachineView> &optimal_view);
void graph_optimize(size_t budget,
bool only_data_parallel,
std::unique_ptr<PCG::Graph> &best_graph,
std::unordered_map<PCG::Node, MachineView> &optimal_view,
bool perform_memory_search,
MemoryOptimConfig new_config,
MemorySearchResult &search_result);
void mcmc_optimize(std::map<Op const *, ParallelConfig> &best,
size_t budget,
float alpha,
@@ -821,6 +829,11 @@ class FFModel {
public:
void set_iteration_config_sequence_length(int seq_length);

/**
* @brief Clear the cache of the GraphSearchHelper and SearchHelper.
*/
void clear_graph_search_cache();

public:
size_t op_global_guid, layer_global_guid;
size_t tensor_global_guid, parallel_tensor_global_guid, node_global_guid;
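The new graph_optimize overload above returns timing and memory details through a MemorySearchResult. A hedged usage sketch of the "grid search of lambda" mentioned in the commit message follows; the budget, the factor grid, and the scoring are assumptions for illustration, not values taken from this commit.

#include <algorithm>
#include <limits>
#include <memory>
#include <unordered_map>

#include "flexflow/model.h"

// Illustrative sketch: try several run-time cost factors (lambda) and track
// the best blended objective. Budget and the factor grid are placeholders.
void memory_search_sketch(FlexFlow::FFModel &model) {
  using namespace FlexFlow;
  float best_score = std::numeric_limits<float>::max();
  for (float lambda : {0.25f, 0.5f, 0.75f}) {
    std::unique_ptr<PCG::Graph> best_graph;
    std::unordered_map<PCG::Node, MachineView> optimal_views;
    MemorySearchResult result;
    model.graph_optimize(/*budget=*/10,
                         /*only_data_parallel=*/false,
                         best_graph,
                         optimal_views,
                         /*perform_memory_search=*/true,
                         MemoryOptimConfig(lambda),
                         result);
    float score = lambda * result.run_time_cost +
                  (1.0f - lambda) * result.memory_cost; // assumed scoring
    best_score = std::min(best_score, score);
    model.clear_graph_search_cache(); // clear search caches between runs
  }
}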
7 changes: 4 additions & 3 deletions include/flexflow/parallel_tensor.h
@@ -63,9 +63,10 @@ struct ParallelDim {
return false;
}

int size = 0; // Actual size of the tensor dimension
int degree = UNKNOWN_DEGREE; // Degree of sharding
int parallel_idx = UNKNOWN_INDEX; // Runtime information: unique index of
// each sharded (parallel) dimension
bool is_replica_dim = false;
};

11 changes: 10 additions & 1 deletion include/flexflow/simulator.h
@@ -53,10 +53,17 @@ class FFModel;
*/
struct CostMetrics {
/**
* @brief Return the sum of inputs_memory, outputs_memory, and weights_memory
* recorded in this CostMetrics.
*/
size_t total_memory() const;

/**
* @brief Return the sum of memory recorded in this CostMetrics, in MB
* instead of bytes.
*/
float total_memory_in_mb() const;

/**
* @brief Get the incremental difference between the total memory in
* CostMetrics and sim->offset.
@@ -76,6 +83,8 @@
// 2. we call Simulator::free_all before measuring an operator
// Therefore, the current memory usage of an operator is (size_t)sim->offset
size_t inputs_memory = 0, outputs_memory = 0, weights_memory = 0;
///< Total memory usage of the operator, accounting for parallelization over devices
size_t op_total_mem = 0;
};

class Device {
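total_memory and total_memory_in_mb are declared in the CostMetrics struct above but defined elsewhere in this commit. A minimal sketch that follows their doc comments; the byte-to-MB divisor is an assumption and the committed code may use a different constant:

#include "flexflow/simulator.h"

// Illustrative sketch, mirroring the doc comments above: total memory is the
// sum of input, output, and weight memory; the MB variant rescales it.
size_t total_memory_sketch(FlexFlow::CostMetrics const &m) {
  return m.inputs_memory + m.outputs_memory + m.weights_memory;
}

float total_memory_in_mb_sketch(FlexFlow::CostMetrics const &m) {
  return static_cast<float>(total_memory_sketch(m)) / (1024.0f * 1024.0f); // assumed divisor
}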