From bb5d1163766405d657cf8a85731cb5864a8677b2 Mon Sep 17 00:00:00 2001 From: ml-metadata-team Date: Tue, 21 Jul 2020 10:26:22 -0700 Subject: [PATCH] internal PiperOrigin-RevId: 322388863 --- .../metadata_store/metadata_access_object.h | 34 +++ .../metadata_access_object_test.cc | 274 ++++++++++++++++++ .../metadata_store/query_config_executor.cc | 10 + .../metadata_store/query_config_executor.h | 6 + ml_metadata/metadata_store/query_executor.h | 14 + .../rdbms_metadata_access_object.cc | 19 +- .../rdbms_metadata_access_object.h | 8 + 7 files changed, 362 insertions(+), 3 deletions(-) diff --git a/ml_metadata/metadata_store/metadata_access_object.h b/ml_metadata/metadata_store/metadata_access_object.h index cdf6aac58..e166b1d4b 100644 --- a/ml_metadata/metadata_store/metadata_access_object.h +++ b/ml_metadata/metadata_store/metadata_access_object.h @@ -181,6 +181,40 @@ class MetadataAccessObject { std::vector* artifacts, std::string* next_page_token) = 0; + // Queries executions stored in the metadata source using `options`. + // `options` is the ListOperationOptions proto message defined + // in metadata_store. + // If successfull: + // 1. `executions` is updated with result set of size determined by + // max_result_size set in `options`. + // 2. `next_page_token` is populated with information necessary to fetch next + // page of results. + // RETURNS INVALID_ARGUMENT if the `options` is invalid with one of + // the cases: + // 1. order_by_field is not set or has an unspecified field. + // 2. Direction of ordering is not specified for the order_by_field. + // 3. next_page_token cannot be decoded. + virtual tensorflow::Status ListExecutions(const ListOperationOptions& options, + std::vector* executions, + std::string* next_page_token) = 0; + + // Queries contexts stored in the metadata source using `options`. + // `options` is the ListOperationOptions proto message defined + // in metadata_store. + // If successfull: + // 1. `contexts` is updated with result set of size determined by + // max_result_size set in `options`. + // 2. `next_page_token` is populated with information necessary to fetch next + // page of results. + // RETURNS INVALID_ARGUMENT if the `options` is invalid with one of + // the cases: + // 1. order_by_field is not set or has an unspecified field. + // 2. Direction of ordering is not specified for the order_by_field. + // 3. next_page_token cannot be decoded. + virtual tensorflow::Status ListContexts(const ListOperationOptions& options, + std::vector* contexts, + std::string* next_page_token) = 0; + // Queries an artifact by its type_id and name. // Returns NOT_FOUND error, if no artifact can be found. // Returns detailed INTERNAL error, if query execution fails. diff --git a/ml_metadata/metadata_store/metadata_access_object_test.cc b/ml_metadata/metadata_store/metadata_access_object_test.cc index e7b78141a..d88b9fb3b 100644 --- a/ml_metadata/metadata_store/metadata_access_object_test.cc +++ b/ml_metadata/metadata_store/metadata_access_object_test.cc @@ -1168,6 +1168,280 @@ TEST_P(MetadataAccessObjectTest, ListArtifactsWithInvalidNextPageToken) { tensorflow::error::INVALID_ARGUMENT); } +TEST_P(MetadataAccessObjectTest, ListExecutionsWithNonIdFieldOptions) { + TF_ASSERT_OK(Init()); + ExecutionType type = ParseTextProtoOrDie(R"( + name: 'test_type' + properties { key: 'property_1' value: INT } + properties { key: 'property_2' value: DOUBLE } + properties { key: 'property_3' value: STRING } + )"); + int64 type_id; + TF_ASSERT_OK(metadata_access_object_->CreateType(type, &type_id)); + + Execution sample_execution = ParseTextProtoOrDie(R"( + properties { + key: 'property_1' + value: { int_value: 3 } + } + properties { + key: 'property_2' + value: { double_value: 3.0 } + } + properties { + key: 'property_3' + value: { string_value: '3' } + } + custom_properties { + key: 'custom_property_1' + value: { string_value: '5' } + } + )"); + sample_execution.set_type_id(type_id); + const int total_stored_executions = 6; + int64 last_stored_execution_id; + + for (int i = 0; i < total_stored_executions; i++) { + TF_ASSERT_OK(metadata_access_object_->CreateExecution( + sample_execution, &last_stored_execution_id)); + } + + const int page_size = 2; + ListOperationOptions list_options = + ParseTextProtoOrDie(R"( + max_result_size: 2, + order_by_field: { field: CREATE_TIME is_asc: false } + )"); + + int64 expected_execution_id = last_stored_execution_id; + std::string next_page_token; + + do { + std::vector got_executions; + TF_ASSERT_OK(metadata_access_object_->ListExecutions( + list_options, &got_executions, &next_page_token)); + EXPECT_TRUE(got_executions.size() <= page_size); + for (const Execution& execution : got_executions) { + sample_execution.set_id(expected_execution_id--); + + EXPECT_THAT(execution, EqualsProto(sample_execution, /*ignore_fields=*/{ + "create_time_since_epoch", + "last_update_time_since_epoch"})); + } + list_options.set_next_page_token(next_page_token); + } while (!next_page_token.empty()); + + EXPECT_EQ(expected_execution_id, 0); +} + +TEST_P(MetadataAccessObjectTest, ListExecutionsWithIdFieldOptions) { + TF_ASSERT_OK(Init()); + ExecutionType type = ParseTextProtoOrDie(R"( + name: 'test_type' + properties { key: 'property_1' value: INT } + )"); + int64 type_id; + TF_ASSERT_OK(metadata_access_object_->CreateType(type, &type_id)); + + Execution sample_execution = ParseTextProtoOrDie(R"( + properties { + key: 'property_1' + value: { int_value: 3 } + } + custom_properties { + key: 'custom_property_1' + value: { string_value: '5' } + } + )"); + + sample_execution.set_type_id(type_id); + int stored_executions_count = 0; + int64 first_execution_id; + TF_ASSERT_OK(metadata_access_object_->CreateExecution(sample_execution, + &first_execution_id)); + stored_executions_count++; + + for (int i = 0; i < 6; i++) { + int64 unused_execution_id; + TF_ASSERT_OK(metadata_access_object_->CreateExecution( + sample_execution, &unused_execution_id)); + } + stored_executions_count += 6; + + const int page_size = 2; + ListOperationOptions list_options = + ParseTextProtoOrDie(R"( + max_result_size: 2, + order_by_field: { field: ID is_asc: true } + )"); + + std::string next_page_token; + int64 expected_execution_id = first_execution_id; + int seen_executions_count = 0; + do { + std::vector got_executions; + TF_ASSERT_OK(metadata_access_object_->ListExecutions( + list_options, &got_executions, &next_page_token)); + EXPECT_TRUE(got_executions.size() <= page_size); + for (const Execution& execution : got_executions) { + sample_execution.set_id(expected_execution_id++); + + EXPECT_THAT(execution, EqualsProto(sample_execution, /*ignore_fields=*/{ + "create_time_since_epoch", + "last_update_time_since_epoch"})); + seen_executions_count++; + } + list_options.set_next_page_token(next_page_token); + } while (!next_page_token.empty()); + + EXPECT_EQ(stored_executions_count, seen_executions_count); +} + +TEST_P(MetadataAccessObjectTest, ListContextsWithNonIdFieldOptions) { + TF_ASSERT_OK(Init()); + ContextType type = ParseTextProtoOrDie(R"( + name: 'test_type' + properties { key: 'property_1' value: INT } + properties { key: 'property_2' value: DOUBLE } + properties { key: 'property_3' value: STRING } + )"); + int64 type_id; + TF_ASSERT_OK(metadata_access_object_->CreateType(type, &type_id)); + + Context sample_context = ParseTextProtoOrDie(R"( + properties { + key: 'property_1' + value: { int_value: 3 } + } + properties { + key: 'property_2' + value: { double_value: 3.0 } + } + properties { + key: 'property_3' + value: { string_value: '3' } + } + custom_properties { + key: 'custom_property_1' + value: { string_value: '5' } + } + )"); + sample_context.set_type_id(type_id); + int64 last_stored_context_id; + int context_name_suffix = 0; + sample_context.set_name("list_contexts_test-1"); + TF_ASSERT_OK(metadata_access_object_->CreateContext(sample_context, + &last_stored_context_id)); + + context_name_suffix++; + sample_context.set_name("list_contexts_test-2"); + TF_ASSERT_OK(metadata_access_object_->CreateContext(sample_context, + &last_stored_context_id)); + context_name_suffix++; + sample_context.set_name("list_contexts_test-3"); + TF_ASSERT_OK(metadata_access_object_->CreateContext(sample_context, + &last_stored_context_id)); + context_name_suffix++; + + const int page_size = 2; + ListOperationOptions list_options = + ParseTextProtoOrDie(R"( + max_result_size: 2, + order_by_field: { field: CREATE_TIME is_asc: false } + )"); + + int64 expected_context_id = last_stored_context_id; + std::string next_page_token; + + do { + std::vector got_contexts; + TF_ASSERT_OK(metadata_access_object_->ListContexts( + list_options, &got_contexts, &next_page_token)); + EXPECT_TRUE(got_contexts.size() <= page_size); + for (const Context& context : got_contexts) { + sample_context.set_name( + absl::StrCat("list_contexts_test-", context_name_suffix--)); + sample_context.set_id(expected_context_id--); + EXPECT_THAT(context, EqualsProto(sample_context, /*ignore_fields=*/{ + "create_time_since_epoch", + "last_update_time_since_epoch"})); + } + list_options.set_next_page_token(next_page_token); + } while (!next_page_token.empty()); + + EXPECT_EQ(expected_context_id, 0); +} + +TEST_P(MetadataAccessObjectTest, ListContextsWithIdFieldOptions) { + TF_ASSERT_OK(Init()); + ContextType type = ParseTextProtoOrDie(R"( + name: 'test_type' + properties { key: 'property_1' value: INT } + )"); + int64 type_id; + TF_ASSERT_OK(metadata_access_object_->CreateType(type, &type_id)); + + Context sample_context = ParseTextProtoOrDie(R"( + properties { + key: 'property_1' + value: { int_value: 3 } + } + custom_properties { + key: 'custom_property_1' + value: { string_value: '5' } + } + )"); + + sample_context.set_type_id(type_id); + int stored_contexts_count = 0; + int64 first_context_id; + sample_context.set_name("list_contexts_test-1"); + TF_ASSERT_OK(metadata_access_object_->CreateContext(sample_context, + &first_context_id)); + + int64 unused_context_id; + stored_contexts_count++; + sample_context.set_name("list_contexts_test-2"); + TF_ASSERT_OK(metadata_access_object_->CreateContext(sample_context, + &unused_context_id)); + stored_contexts_count++; + sample_context.set_name("list_contexts_test-3"); + TF_ASSERT_OK(metadata_access_object_->CreateContext(sample_context, + &unused_context_id)); + stored_contexts_count++; + + const int page_size = 2; + ListOperationOptions list_options = + ParseTextProtoOrDie(R"( + max_result_size: 2, + order_by_field: { field: ID is_asc: true } + )"); + + std::string next_page_token; + int64 expected_context_id = first_context_id; + int expected_context_name_suffix = 1; + int seen_contexts_count = 0; + do { + std::vector got_contexts; + TF_ASSERT_OK(metadata_access_object_->ListContexts( + list_options, &got_contexts, &next_page_token)); + EXPECT_TRUE(got_contexts.size() <= page_size); + for (const Context& context : got_contexts) { + sample_context.set_name( + absl::StrCat("list_contexts_test-", expected_context_name_suffix++)); + sample_context.set_id(expected_context_id++); + + EXPECT_THAT(context, EqualsProto(sample_context, /*ignore_fields=*/{ + "create_time_since_epoch", + "last_update_time_since_epoch"})); + seen_contexts_count++; + } + list_options.set_next_page_token(next_page_token); + } while (!next_page_token.empty()); + + EXPECT_EQ(stored_contexts_count, seen_contexts_count); +} + TEST_P(MetadataAccessObjectTest, DefaultArtifactState) { TF_ASSERT_OK(Init()); ArtifactType type = ParseTextProtoOrDie("name: 'test_type'"); diff --git a/ml_metadata/metadata_store/query_config_executor.cc b/ml_metadata/metadata_store/query_config_executor.cc index 0f16eace1..a14ad7d24 100644 --- a/ml_metadata/metadata_store/query_config_executor.cc +++ b/ml_metadata/metadata_store/query_config_executor.cc @@ -495,4 +495,14 @@ tensorflow::Status QueryConfigExecutor::ListArtifactIDsUsingOptions( const ListOperationOptions& options, RecordSet* record_set) { return ListNodeIDsUsingOptions(options, record_set); } + +tensorflow::Status QueryConfigExecutor::ListExecutionIDsUsingOptions( + const ListOperationOptions& options, RecordSet* record_set) { + return ListNodeIDsUsingOptions(options, record_set); +} + +tensorflow::Status QueryConfigExecutor::ListContextIDsUsingOptions( + const ListOperationOptions& options, RecordSet* record_set) { + return ListNodeIDsUsingOptions(options, record_set); +} } // namespace ml_metadata diff --git a/ml_metadata/metadata_store/query_config_executor.h b/ml_metadata/metadata_store/query_config_executor.h index ac7901a23..0200121a3 100644 --- a/ml_metadata/metadata_store/query_config_executor.h +++ b/ml_metadata/metadata_store/query_config_executor.h @@ -493,6 +493,12 @@ class QueryConfigExecutor : public QueryExecutor { tensorflow::Status ListArtifactIDsUsingOptions( const ListOperationOptions& options, RecordSet* record_set) final; + tensorflow::Status ListExecutionIDsUsingOptions( + const ListOperationOptions& options, RecordSet* record_set) final; + + tensorflow::Status ListContextIDsUsingOptions( + const ListOperationOptions& options, RecordSet* record_set) final; + private: // Utility method to bind an nullable value. template diff --git a/ml_metadata/metadata_store/query_executor.h b/ml_metadata/metadata_store/query_executor.h index 0258f9ea3..8215009fa 100644 --- a/ml_metadata/metadata_store/query_executor.h +++ b/ml_metadata/metadata_store/query_executor.h @@ -418,6 +418,20 @@ class QueryExecutor { // On success `record_set` is updated with artifact IDs based on `options` virtual tensorflow::Status ListArtifactIDsUsingOptions( const ListOperationOptions& options, RecordSet* record_set) = 0; + + // List Execution IDs using `options`. + // On success `set` is updated with execution IDs based on `options` and + // `next_page_token` is updated with information for the caller to use for + // next page of results. + virtual tensorflow::Status ListExecutionIDsUsingOptions( + const ListOperationOptions& options, RecordSet* record_set) = 0; + + // List Context IDs using `options`. + // On success `set` is updated with context IDs based on `options` and + // `next_page_token` is updated with information for the caller to use for + // next page of results. + virtual tensorflow::Status ListContextIDsUsingOptions( + const ListOperationOptions& options, RecordSet* record_set) = 0; }; } // namespace ml_metadata diff --git a/ml_metadata/metadata_store/rdbms_metadata_access_object.cc b/ml_metadata/metadata_store/rdbms_metadata_access_object.cc index 79f34a3d7..8d89f1e9d 100644 --- a/ml_metadata/metadata_store/rdbms_metadata_access_object.cc +++ b/ml_metadata/metadata_store/rdbms_metadata_access_object.cc @@ -1181,10 +1181,11 @@ tensorflow::Status RDBMSMetadataAccessObject::ListNodes( TF_RETURN_IF_ERROR( executor_->ListArtifactIDsUsingOptions(updated_options, &record_set)); } else if (std::is_same::value) { - return tensorflow::errors::Unimplemented( - "ListExecutions not yet supported."); + TF_RETURN_IF_ERROR( + executor_->ListExecutionIDsUsingOptions(updated_options, &record_set)); } else if (std::is_same::value) { - return tensorflow::errors::Unimplemented("ListContexts not yet supported."); + TF_RETURN_IF_ERROR( + executor_->ListContextIDsUsingOptions(updated_options, &record_set)); } else { return tensorflow::errors::InvalidArgument( "Invalid Node passed to ListNodes"); @@ -1210,6 +1211,18 @@ tensorflow::Status RDBMSMetadataAccessObject::ListArtifacts( return ListNodes(options, artifacts, next_page_token); } +tensorflow::Status RDBMSMetadataAccessObject::ListExecutions( + const ListOperationOptions& options, std::vector* executions, + std::string* next_page_token) { + return ListNodes(options, executions, next_page_token); +} + +tensorflow::Status RDBMSMetadataAccessObject::ListContexts( + const ListOperationOptions& options, std::vector* contexts, + std::string* next_page_token) { + return ListNodes(options, contexts, next_page_token); +} + tensorflow::Status RDBMSMetadataAccessObject::FindArtifactByTypeIdAndArtifactName( const int64 type_id, const absl::string_view name, Artifact* artifact) { diff --git a/ml_metadata/metadata_store/rdbms_metadata_access_object.h b/ml_metadata/metadata_store/rdbms_metadata_access_object.h index 7fae1eaf1..939dee8e3 100644 --- a/ml_metadata/metadata_store/rdbms_metadata_access_object.h +++ b/ml_metadata/metadata_store/rdbms_metadata_access_object.h @@ -119,6 +119,14 @@ class RDBMSMetadataAccessObject : public MetadataAccessObject { std::vector* artifacts, std::string* next_page_token) final; + tensorflow::Status ListExecutions(const ListOperationOptions& options, + std::vector* executions, + std::string* next_page_token) final; + + tensorflow::Status ListContexts(const ListOperationOptions& options, + std::vector* contexts, + std::string* next_page_token) final; + tensorflow::Status FindArtifactsByTypeId( int64 artifact_type_id, std::vector* artifacts) final;