diff --git a/ml_metadata/tools/mlmd_bench/BUILD b/ml_metadata/tools/mlmd_bench/BUILD index b23d38806..42a4ee4b0 100644 --- a/ml_metadata/tools/mlmd_bench/BUILD +++ b/ml_metadata/tools/mlmd_bench/BUILD @@ -31,7 +31,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_absl//absl/types:variant", - "//ml_metadata/metadata_store", + "//ml_metadata/metadata_store:metadata_store", "//ml_metadata/metadata_store:types", "//ml_metadata/proto:metadata_store_proto", "//ml_metadata/proto:metadata_store_service_proto", @@ -62,10 +62,87 @@ ml_metadata_cc_test( cc_library( name = "stats", + srcs = ["stats.cc"], hdrs = ["stats.h"], deps = [ + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/time", "//ml_metadata/metadata_store:types", + "@org_tensorflow//tensorflow/core:lib", + ], +) + +ml_metadata_cc_test( + name = "stats_test", + size = "small", + srcs = ["stats_test.cc"], + deps = [ + ":stats", + "@com_google_absl//absl/time", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "benchmark", + srcs = ["benchmark.cc"], + hdrs = ["benchmark.h"], + deps = [ + ":fill_types_workload", + ":workload", + "@com_google_absl//absl/memory", + "//ml_metadata/metadata_store:types", + "//ml_metadata/tools/mlmd_bench/proto:mlmd_bench_proto", + "@org_tensorflow//tensorflow/core:lib", + ], +) + +ml_metadata_cc_test( + name = "benchmark_test", + size = "small", + srcs = ["benchmark_test.cc"], + deps = [ + "benchmark", + "@com_google_googletest//:gtest_main", + "//ml_metadata/metadata_store:test_util", + "//ml_metadata/tools/mlmd_bench/proto:mlmd_bench_proto", + ], +) + +cc_library( + name = "thread_runner", + srcs = ["thread_runner.cc"], + hdrs = ["thread_runner.h"], + deps = [ + ":benchmark", + ":stats", + ":workload", + "//ml_metadata/metadata_store:metadata_store", + "//ml_metadata/metadata_store:metadata_store_factory", + "//ml_metadata/metadata_store:types", + "//ml_metadata/proto:metadata_store_proto", + "//ml_metadata/tools/mlmd_bench/proto:mlmd_bench_proto", + "@org_tensorflow//tensorflow/core:lib", + ], +) + +ml_metadata_cc_test( + name = "thread_runner_test", + size = "small", + srcs = ["thread_runner_test.cc"], + deps = [ + ":benchmark", + ":stats", + ":thread_runner", + ":workload", + "@com_google_googletest//:gtest_main", + "//ml_metadata/metadata_store:metadata_store", + "//ml_metadata/metadata_store:metadata_store_factory", + "//ml_metadata/metadata_store:test_util", + "//ml_metadata/proto:metadata_store_service_proto", + "//ml_metadata/tools/mlmd_bench/proto:mlmd_bench_proto", + "@org_tensorflow//tensorflow/core:lib", + "@org_tensorflow//tensorflow/core:test", ], ) @@ -75,7 +152,7 @@ cc_library( hdrs = ["util.h"], deps = [ "@com_google_absl//absl/types:variant", - "//ml_metadata/metadata_store", + "//ml_metadata/metadata_store:metadata_store", "//ml_metadata/proto:metadata_store_proto", "//ml_metadata/proto:metadata_store_service_proto", "//ml_metadata/tools/mlmd_bench/proto:mlmd_bench_proto", @@ -89,7 +166,7 @@ cc_library( deps = [ ":stats", "@com_google_absl//absl/time", - "//ml_metadata/metadata_store", + "//ml_metadata/metadata_store:metadata_store", "//ml_metadata/metadata_store:types", "@org_tensorflow//tensorflow/core:lib", ], @@ -103,7 +180,7 @@ ml_metadata_cc_test( ":workload", #"@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", - "//ml_metadata/metadata_store", + "//ml_metadata/metadata_store:metadata_store", "//ml_metadata/metadata_store:metadata_store_factory", "//ml_metadata/metadata_store:types", "//ml_metadata/proto:metadata_store_proto", diff --git a/ml_metadata/tools/mlmd_bench/benchmark.cc b/ml_metadata/tools/mlmd_bench/benchmark.cc new file mode 100644 index 000000000..42619271f --- /dev/null +++ b/ml_metadata/tools/mlmd_bench/benchmark.cc @@ -0,0 +1,55 @@ +/* Copyright 2020 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "ml_metadata/tools/mlmd_bench/benchmark.h" + +#include + +#include "absl/memory/memory.h" +#include "ml_metadata/metadata_store/types.h" +#include "ml_metadata/tools/mlmd_bench/fill_types_workload.h" +#include "ml_metadata/tools/mlmd_bench/proto/mlmd_bench.pb.h" +#include "ml_metadata/tools/mlmd_bench/workload.h" +#include "tensorflow/core/platform/logging.h" + +namespace ml_metadata { +namespace { + +// Creates the executable workload given `workload_config`. +void CreateWorkload(const WorkloadConfig& workload_config, + std::unique_ptr& workload) { + if (!workload_config.has_fill_types_config()) { + LOG(FATAL) << "Cannot find corresponding workload!"; + } + workload = absl::make_unique(FillTypes( + workload_config.fill_types_config(), workload_config.num_operations())); +} + +} // namespace + +Benchmark::Benchmark(const MLMDBenchConfig& mlmd_bench_config) { + workloads_.resize(mlmd_bench_config.workload_configs_size()); + + // For each `workload_config`, calls CreateWorkload() to create corresponding + // workload. + for (int i = 0; i < mlmd_bench_config.workload_configs_size(); ++i) { + CreateWorkload(mlmd_bench_config.workload_configs(i), workloads_[i]); + } +} + +WorkloadBase* Benchmark::workload(const int64 workload_index) { + return workloads_[workload_index].get(); +} + +} // namespace ml_metadata diff --git a/ml_metadata/tools/mlmd_bench/benchmark.h b/ml_metadata/tools/mlmd_bench/benchmark.h new file mode 100644 index 000000000..56a601aa5 --- /dev/null +++ b/ml_metadata/tools/mlmd_bench/benchmark.h @@ -0,0 +1,46 @@ +/* Copyright 2020 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef ML_METADATA_TOOLS_MLMD_BENCH_BENCHMARK_H +#define ML_METADATA_TOOLS_MLMD_BENCH_BENCHMARK_H + +#include + +#include "ml_metadata/metadata_store/types.h" +#include "ml_metadata/tools/mlmd_bench/proto/mlmd_bench.pb.h" +#include "ml_metadata/tools/mlmd_bench/workload.h" + +namespace ml_metadata { + +// Contains a list of workloads to be executed by ThreadRunner. +// The executable workloads are generated according to `mlmd_bench_config`. +class Benchmark { + public: + Benchmark(const MLMDBenchConfig& mlmd_bench_config); + ~Benchmark() = default; + + // Returns a particular executable workload given `workload_index`. + WorkloadBase* workload(int64 workload_index); + + // Returns the number of executable workloads existed inside benchmark. + int64 num_workloads() const { return workloads_.size(); } + + private: + // A list of executable workloads. + std::vector> workloads_; +}; + +} // namespace ml_metadata + +#endif // ML_METADATA_TOOLS_MLMD_BENCH_BENCHMARK_H diff --git a/ml_metadata/tools/mlmd_bench/benchmark_test.cc b/ml_metadata/tools/mlmd_bench/benchmark_test.cc new file mode 100644 index 000000000..3af45af53 --- /dev/null +++ b/ml_metadata/tools/mlmd_bench/benchmark_test.cc @@ -0,0 +1,62 @@ +/* Copyright 2020 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "ml_metadata/tools/mlmd_bench/benchmark.h" + +#include +#include "ml_metadata/metadata_store/test_util.h" +#include "ml_metadata/tools/mlmd_bench/proto/mlmd_bench.pb.h" + +namespace ml_metadata { +namespace { + +// Tests the CreateWorkload() of Benchmark class. +TEST(BenchmarkTest, CreatWorkloadTest) { + MLMDBenchConfig mlmd_bench_config = + testing::ParseTextProtoOrDie( + R"( + workload_configs: { + fill_types_config: { + update: false + specification: ARTIFACT_TYPE + num_properties: { minimum: 1 maximum: 10 } + } + num_operations: 100 + } + workload_configs: { + fill_types_config: { + update: true + specification: EXECUTION_TYPE + num_properties: { minimum: 1 maximum: 10 } + } + num_operations: 500 + } + workload_configs: { + fill_types_config: { + update: false + specification: CONTEXT_TYPE + num_properties: { minimum: 1 maximum: 10 } + } + num_operations: 300 + } + )"); + Benchmark benchmark(mlmd_bench_config); + // Checks that all workload configurations have transformed into executable + // workloads inside benchmark. + EXPECT_EQ(benchmark.num_workloads(), + mlmd_bench_config.workload_configs_size()); +} + +} // namespace +} // namespace ml_metadata diff --git a/ml_metadata/tools/mlmd_bench/stats.cc b/ml_metadata/tools/mlmd_bench/stats.cc new file mode 100644 index 000000000..f3082c163 --- /dev/null +++ b/ml_metadata/tools/mlmd_bench/stats.cc @@ -0,0 +1,86 @@ +/* Copyright 2020 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "ml_metadata/tools/mlmd_bench/stats.h" + +#include + +#include "absl/strings/str_format.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "ml_metadata/metadata_store/types.h" + +namespace ml_metadata { + +ThreadStats::ThreadStats() + : accumulated_elapsed_time_(absl::Nanoseconds(0)), + done_(0), + bytes_(0), + next_report_(100) {} + +void ThreadStats::Start() { start_ = absl::Now(); } + +void ThreadStats::Update(const OpStats& op_stats, + const int64 approx_total_done) { + bytes_ += op_stats.transferred_bytes; + accumulated_elapsed_time_ += op_stats.elapsed_time; + done_++; + static const int report_thresholds[]{1000, 5000, 10000, 50000, + 100000, 500000, 1000000}; + int threshold_index = 0; + if (approx_total_done < next_report_) { + return; + } + // Reports the current progress with `approx_total_done`. + next_report_ += report_thresholds[threshold_index] / 10; + if (next_report_ > report_thresholds[threshold_index]) { + threshold_index++; + } + std::fprintf(stderr, "... finished %lld ops%30s\r", approx_total_done, ""); + std::fflush(stderr); +} + +void ThreadStats::Stop() { finish_ = absl::Now(); } + +void ThreadStats::Merge(const ThreadStats& other) { + // Accumulates done_, bytes_ and accumulated_elapsed_time_ of each thread + // stats. + done_ += other.done(); + bytes_ += other.bytes(); + accumulated_elapsed_time_ += other.accumulated_elapsed_time(); + // Chooses the earliest start time and latest end time of each merged + // thread stats. + start_ = std::min(start_, other.start()); + finish_ = std::max(finish_, other.finish()); +} + +void ThreadStats::Report(const std::string& specification) { + std::string extra; + if (bytes_ > 0) { + // Rate is computed on actual elapsed time (latest end time minus + // earliest start time of each thread) instead of the sum of per-thread + // elapsed times. + int64 elapsed_seconds = accumulated_elapsed_time_ / absl::Seconds(1); + std::string rate = + absl::StrFormat("%6.1f KB/s", (bytes_ / 1024.0) / elapsed_seconds); + extra = rate; + } + std::fprintf( + stdout, "%-12s : %11.3f micros/op;%s%s\n", specification.c_str(), + (double)(accumulated_elapsed_time_ / absl::Microseconds(1)) / done_, + (extra.empty() ? "" : " "), extra.c_str()); + std::fflush(stdout); +} + +} // namespace ml_metadata diff --git a/ml_metadata/tools/mlmd_bench/stats.h b/ml_metadata/tools/mlmd_bench/stats.h index d85ffbc29..cdea937a5 100644 --- a/ml_metadata/tools/mlmd_bench/stats.h +++ b/ml_metadata/tools/mlmd_bench/stats.h @@ -20,13 +20,73 @@ limitations under the License. namespace ml_metadata { -// OpStats records the statics(elapsed microsecond, transferred bytes) of each +// OpStats records the statics(elapsed time, transferred bytes) of each // operation. It will be used to update the thread stats. struct OpStats { absl::Duration elapsed_time; int64 transferred_bytes; }; +// ThreadStats records the statics(start time, end time, elapsed time, total +// operations done, transferred bytes) of each thread. It will be updated by +// Opstats. Every ThreadStats of a particular workload will be merged together +// after each thread has finished execution to generate a workload stats for +// reporting the performance of current workload. +class ThreadStats { + public: + ThreadStats(); + ~ThreadStats() = default; + + // Starts the current thread stats and initializes the member variables. + void Start(); + + // Updates the current thread stats with op_stats. + void Update(const OpStats& op_stats, int64 approx_total_done); + + // Records the end time for each thread after the current thread has finished + // all the operations. + void Stop(); + + // Merges the thread stats instances into a workload stats that will be used + // for report purpose. + void Merge(const ThreadStats& other); + + // Reports the metrics of interests: microsecond per operation and total bytes + // per seconds for the current workload. + void Report(const std::string& specification); + + // Gets the start time of current thread stats. + absl::Time start() const { return start_; } + + // Gets the finish time of current thread stats. + absl::Time finish() const { return finish_; } + + // Gets the accumulated elapsed time of current thread stats. + absl::Duration accumulated_elapsed_time() const { + return accumulated_elapsed_time_; + } + + // Gets the number of total finished operations of current thread stats. + int64 done() const { return done_; } + + // Gets the total transferred bytes of current thread stats. + int64 bytes() const { return bytes_; } + + private: + // Records the start time of current thread stats. + absl::Time start_; + // Records the finish time of current thread stats. + absl::Time finish_; + // Records the accumulated elapsed time of current thread stats. + absl::Duration accumulated_elapsed_time_; + // Records the number of total finished operations of current thread stats. + int64 done_; + // Records the total transferred bytes of current thread stats. + int64 bytes_; + // Uses in Report() for console reporting. + int64 next_report_; +}; + } // namespace ml_metadata #endif // ML_METADATA_TOOLS_MLMD_BENCH_STATS_H diff --git a/ml_metadata/tools/mlmd_bench/stats_test.cc b/ml_metadata/tools/mlmd_bench/stats_test.cc new file mode 100644 index 000000000..fcd56faf4 --- /dev/null +++ b/ml_metadata/tools/mlmd_bench/stats_test.cc @@ -0,0 +1,110 @@ +/* Copyright 2020 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "ml_metadata/tools/mlmd_bench/stats.h" + +#include + +#include +#include "absl/time/clock.h" +#include "absl/time/time.h" + +namespace ml_metadata { +namespace { + +// Tests the Update() of Stats class. +TEST(ThreadStatsTest, UpdateTest) { + srand(time(NULL)); + + ThreadStats stats; + stats.Start(); + + // Prepares the list of `op_stats` to update `stats`. + std::vector op_stats_time; + std::vector op_stats_bytes; + for (int i = 0; i < 10000; ++i) { + op_stats_time.push_back(absl::Microseconds(rand() % 99999)); + op_stats_bytes.push_back(rand() % 99999); + } + + // Updates `stats` with the list of `op_stats`. + for (int64 i = 0; i < 10000; ++i) { + OpStats curr_op_stats{op_stats_time[i], op_stats_bytes[i]}; + stats.Update(curr_op_stats, i); + } + + // Since the `done_`, `bytes_` and `accumulated_elapsed_time_` are accumulated + // by each update, the finial `done_`, `bytes_` and + // `accumulated_elapsed_time_` of `stats` should be the sum of the list of + // `op_stats`. + EXPECT_EQ(stats.done(), 10000); + EXPECT_EQ(stats.accumulated_elapsed_time(), + std::accumulate(op_stats_time.begin(), op_stats_time.end(), + absl::Microseconds(0))); + EXPECT_EQ(stats.bytes(), + std::accumulate(op_stats_bytes.begin(), op_stats_bytes.end(), 0)); +} + +// Tests the Merge() of Stats class. +TEST(ThreadStatsTest, MergeTest) { + srand(time(NULL)); + + ThreadStats stats1; + stats1.Start(); + ThreadStats stats2; + stats2.Start(); + absl::Time ealiest_start_time = std::min(stats1.start(), stats2.start()); + + std::vector op_stats_time; + std::vector op_stats_bytes; + for (int i = 0; i < 10000; ++i) { + op_stats_time.push_back(absl::Microseconds(rand() % 99999)); + op_stats_bytes.push_back(rand() % 99999); + } + + // Updates the stats with the prepared list of `op_stats`. + for (int64 i = 0; i < 10000; ++i) { + OpStats curr_op_stats{op_stats_time[i], op_stats_bytes[i]}; + if (i <= 4999) { + stats1.Update(curr_op_stats, i); + } else { + stats2.Update(curr_op_stats, i); + } + } + + stats1.Stop(); + stats2.Stop(); + absl::Time latest_end_time = std::max(stats1.finish(), stats2.finish()); + + stats1.Merge(stats2); + + // Since the Merge() accumulates the `done_`, `bytes_` and + // `accumulated_elapsed_time_` of each merged stats, the final stats's + // `done_`, `bytes_` and `accumulated_elapsed_time_` should be the sum of the + // stats. + EXPECT_EQ(stats1.done(), 10000); + EXPECT_EQ(stats1.accumulated_elapsed_time(), + std::accumulate(op_stats_time.begin(), op_stats_time.end(), + absl::Microseconds(0))); + EXPECT_EQ(stats1.bytes(), + std::accumulate(op_stats_bytes.begin(), op_stats_bytes.end(), 0)); + + // In Merge(), we takes the earliest start time and latest end time as + // the start and end time of the merged stats. + EXPECT_EQ(stats1.start(), ealiest_start_time); + EXPECT_EQ(stats1.finish(), latest_end_time); +} + +} // namespace +} // namespace ml_metadata diff --git a/ml_metadata/tools/mlmd_bench/thread_runner.cc b/ml_metadata/tools/mlmd_bench/thread_runner.cc new file mode 100644 index 000000000..76b9c991f --- /dev/null +++ b/ml_metadata/tools/mlmd_bench/thread_runner.cc @@ -0,0 +1,156 @@ +/* Copyright 2020 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "ml_metadata/tools/mlmd_bench/thread_runner.h" + +#include + +#include "ml_metadata/metadata_store/metadata_store.h" +#include "ml_metadata/metadata_store/metadata_store_factory.h" +#include "ml_metadata/metadata_store/types.h" +#include "ml_metadata/proto/metadata_store.pb.h" +#include "ml_metadata/tools/mlmd_bench/benchmark.h" +#include "ml_metadata/tools/mlmd_bench/proto/mlmd_bench.pb.h" +#include "ml_metadata/tools/mlmd_bench/stats.h" +#include "ml_metadata/tools/mlmd_bench/workload.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/threadpool.h" + +namespace ml_metadata { +namespace { + +// Prepares a list of MLMD client instance(`stores`) for each thread. +tensorflow::Status PrepareStoresForThreads( + const ConnectionConfig mlmd_config, const int64 num_threads, + std::vector>& stores) { + stores.resize(num_threads); + // Each thread uses a different MLMD client instance to talk to + // the same back-end. + for (int64 i = 0; i < num_threads; ++i) { + std::unique_ptr store; + TF_RETURN_IF_ERROR(CreateMetadataStore(mlmd_config, &store)); + stores[i] = std::move(store); + } + return tensorflow::Status::OK(); +} + +// Sets up the current workload. +tensorflow::Status SetUpWorkload(const ConnectionConfig mlmd_config, + WorkloadBase* workload) { + std::unique_ptr set_up_store; + TF_RETURN_IF_ERROR(CreateMetadataStore(mlmd_config, &set_up_store)); + TF_RETURN_IF_ERROR(workload->SetUp(set_up_store.get())); + return tensorflow::Status::OK(); +} + +// Executes the current workload and updates `curr_thread_stats` with `op_stats` +// along the way. +tensorflow::Status ExecuteWorkload(const int64 work_items_start_index, + const int64 op_per_thread, + MetadataStore* curr_store, + WorkloadBase* workload, + int64& approx_total_done, + ThreadStats& curr_thread_stats) { + int64 work_items_index = work_items_start_index; + while (work_items_index < work_items_start_index + op_per_thread) { + // Each operation has a op_stats. + OpStats op_stats; + tensorflow::Status status = + workload->RunOp(work_items_index, curr_store, op_stats); + // If the error is not Abort error, break the current process. + if (!status.ok() && status.code() != tensorflow::error::ABORTED) { + TF_RETURN_IF_ERROR(status); + } + // Handles abort issues for concurrent writing to the db. + if (!status.ok()) { + continue; + } + work_items_index++; + approx_total_done++; + // Updates the current thread stats using the `op_stats`. + curr_thread_stats.Update(op_stats, approx_total_done); + } + return tensorflow::Status::OK(); +} + +// Merges all the thread stats inside `thread_stats_list` into a workload stats +// and reports the workload's performance according to that. +void MergeThreadStatsAndReport(const std::string workload_name, + ThreadStats thread_stats_list[], int64 size) { + for (int64 i = 1; i < size; ++i) { + thread_stats_list[0].Merge(thread_stats_list[i]); + } + // Reports the metrics of interests. + // TODO(briansong) Return the report as a summary proto. + thread_stats_list[0].Report(workload_name); +} + +} // namespace + +ThreadRunner::ThreadRunner(const ConnectionConfig& mlmd_config, + const int64 num_threads) + : mlmd_config_(mlmd_config), num_threads_(num_threads) {} + +// The thread runner will first loops over all the executable workloads in +// benchmark and executes them one by one. Each workload will have a +// `thread_stats_list` to record the stats of each thread when executing the +// current workload. +// During the execution, each operation will has a `op_stats` to record current +// operation statistic. Each `op_stats` will be used to update the +// `thread_stats`. +// After the each thread has finished the execution, the workload stats will be +// generated by merging all the thread stats inside the `thread_stats_list`. The +// performance of the workload will be reported according to the workload stats. +tensorflow::Status ThreadRunner::Run(Benchmark& benchmark) { + for (int i = 0; i < benchmark.num_workloads(); ++i) { + WorkloadBase* workload = benchmark.workload(i); + ThreadStats thread_stats_list[num_threads_]; + tensorflow::Status thread_status_list[num_threads_]; + TF_RETURN_IF_ERROR(SetUpWorkload(mlmd_config_, workload)); + const int64 op_per_thread = workload->num_operations() / num_threads_; + std::vector> stores; + TF_RETURN_IF_ERROR( + PrepareStoresForThreads(mlmd_config_, num_threads_, stores)); + { + // Create a thread pool for multi-thread execution. + tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), + "mlmd_bench", num_threads_); + // `approx_total_done` is used for reporting progress along the way. + int64 approx_total_done = 0; + for (int64 t = 0; t < num_threads_; ++t) { + const int64 work_items_start_index = op_per_thread * t; + ThreadStats& curr_thread_stats = thread_stats_list[t]; + MetadataStore* curr_store = stores[t].get(); + tensorflow::Status& curr_status = thread_status_list[t]; + pool.Schedule([this, op_per_thread, workload, work_items_start_index, + curr_store, &curr_thread_stats, &curr_status, + &approx_total_done]() { + curr_thread_stats.Start(); + curr_status.Update( + ExecuteWorkload(work_items_start_index, op_per_thread, curr_store, + workload, approx_total_done, curr_thread_stats)); + curr_thread_stats.Stop(); + }); + TF_RETURN_IF_ERROR(curr_status); + } + } + TF_RETURN_IF_ERROR(workload->TearDown()); + MergeThreadStatsAndReport(workload->GetName(), thread_stats_list, + num_threads_); + } + return tensorflow::Status::OK(); +} + +} // namespace ml_metadata diff --git a/ml_metadata/tools/mlmd_bench/thread_runner.h b/ml_metadata/tools/mlmd_bench/thread_runner.h new file mode 100644 index 000000000..fbe1daa28 --- /dev/null +++ b/ml_metadata/tools/mlmd_bench/thread_runner.h @@ -0,0 +1,45 @@ +/* Copyright 2020 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef ML_METADATA_TOOLS_MLMD_BENCH_THREAD_RUNNER_H +#define ML_METADATA_TOOLS_MLMD_BENCH_THREAD_RUNNER_H + +#include "ml_metadata/proto/metadata_store.pb.h" +#include "ml_metadata/tools/mlmd_bench/benchmark.h" +#include "ml_metadata/tools/mlmd_bench/proto/mlmd_bench.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace ml_metadata { + +// The ThreadRunner class is the execution component of the `mlmd_bench` +// It takes the benchmark and runs the workloads. +class ThreadRunner { + public: + ThreadRunner(const ConnectionConfig& mlmd_config, int64 num_threads); + ~ThreadRunner() = default; + + // Execution unit of `mlmd_bench`. Returns detailed error if query executions + // failed. + tensorflow::Status Run(Benchmark& benchmark); + + private: + // Connection configuration that will be used to create the MetadataStore. + const ConnectionConfig mlmd_config_; + // Number of threads for the thread runner. + const int64 num_threads_; +}; + +} // namespace ml_metadata + +#endif // ML_METADATA_TOOLS_MLMD_BENCH_THREAD_RUNNER_H diff --git a/ml_metadata/tools/mlmd_bench/thread_runner_test.cc b/ml_metadata/tools/mlmd_bench/thread_runner_test.cc new file mode 100644 index 000000000..eeef2427c --- /dev/null +++ b/ml_metadata/tools/mlmd_bench/thread_runner_test.cc @@ -0,0 +1,66 @@ +/* Copyright 2020 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "ml_metadata/tools/mlmd_bench/thread_runner.h" + +#include +#include "ml_metadata/metadata_store/metadata_store.h" +#include "ml_metadata/metadata_store/metadata_store_factory.h" +#include "ml_metadata/metadata_store/test_util.h" +#include "ml_metadata/proto/metadata_store_service.pb.h" +#include "ml_metadata/tools/mlmd_bench/benchmark.h" +#include "ml_metadata/tools/mlmd_bench/proto/mlmd_bench.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace ml_metadata { +namespace { + +void TestThreadRunner(const int num_thread) { + MLMDBenchConfig mlmd_bench_config; + mlmd_bench_config.mutable_thread_env_config()->set_num_threads(num_thread); + mlmd_bench_config.add_workload_configs()->CopyFrom( + testing::ParseTextProtoOrDie(R"( + fill_types_config: { + update: false + specification: ARTIFACT_TYPE + num_properties: { minimum: 1 maximum: 10 } + } + num_operations: 100 + )")); + // Uses a fake in-memory SQLite database for testing. + mlmd_bench_config.mutable_mlmd_config()->mutable_sqlite()->set_filename_uri( + absl::StrCat("mlmd-bench-test_", num_thread, ".db")); + Benchmark benchmark(mlmd_bench_config); + ThreadRunner runner(mlmd_bench_config.mlmd_config(), + mlmd_bench_config.thread_env_config().num_threads()); + TF_ASSERT_OK(runner.Run(benchmark)); + + std::unique_ptr store; + TF_ASSERT_OK(CreateMetadataStore(mlmd_bench_config.mlmd_config(), &store)); + + GetArtifactTypesResponse get_response; + TF_ASSERT_OK(store->GetArtifactTypes(/*request=*/{}, &get_response)); + // Checks that the workload indeed be executed by the thread_runner. + EXPECT_EQ(get_response.artifact_types_size(), + mlmd_bench_config.workload_configs()[0].num_operations()); +} + +// Tests the Run() of ThreadRunner class in single-thread mode. +TEST(ThreadRunnerTest, RunInSingleThreadTest) { TestThreadRunner(1); } + +// Tests the Run() of ThreadRunner class in multi-thread mode. +TEST(ThreadRunnerTest, RunInMultiThreadTest) { TestThreadRunner(10); } + +} // namespace +} // namespace ml_metadata