Skip to content

Commit

Permalink
internal
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 322388863
  • Loading branch information
ml-metadata-team authored and tf-metadata-team committed Jul 21, 2020
1 parent 1a0e05e commit bb5d116
Show file tree
Hide file tree
Showing 7 changed files with 362 additions and 3 deletions.
34 changes: 34 additions & 0 deletions ml_metadata/metadata_store/metadata_access_object.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,40 @@ class MetadataAccessObject {
std::vector<Artifact>* artifacts,
std::string* next_page_token) = 0;

// Queries executions stored in the metadata source using `options`.
// `options` is the ListOperationOptions proto message defined
// in metadata_store.
// If successfull:
// 1. `executions` is updated with result set of size determined by
// max_result_size set in `options`.
// 2. `next_page_token` is populated with information necessary to fetch next
// page of results.
// RETURNS INVALID_ARGUMENT if the `options` is invalid with one of
// the cases:
// 1. order_by_field is not set or has an unspecified field.
// 2. Direction of ordering is not specified for the order_by_field.
// 3. next_page_token cannot be decoded.
virtual tensorflow::Status ListExecutions(const ListOperationOptions& options,
std::vector<Execution>* executions,
std::string* next_page_token) = 0;

// Queries contexts stored in the metadata source using `options`.
// `options` is the ListOperationOptions proto message defined
// in metadata_store.
// If successfull:
// 1. `contexts` is updated with result set of size determined by
// max_result_size set in `options`.
// 2. `next_page_token` is populated with information necessary to fetch next
// page of results.
// RETURNS INVALID_ARGUMENT if the `options` is invalid with one of
// the cases:
// 1. order_by_field is not set or has an unspecified field.
// 2. Direction of ordering is not specified for the order_by_field.
// 3. next_page_token cannot be decoded.
virtual tensorflow::Status ListContexts(const ListOperationOptions& options,
std::vector<Context>* contexts,
std::string* next_page_token) = 0;

// Queries an artifact by its type_id and name.
// Returns NOT_FOUND error, if no artifact can be found.
// Returns detailed INTERNAL error, if query execution fails.
Expand Down
274 changes: 274 additions & 0 deletions ml_metadata/metadata_store/metadata_access_object_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1168,6 +1168,280 @@ TEST_P(MetadataAccessObjectTest, ListArtifactsWithInvalidNextPageToken) {
tensorflow::error::INVALID_ARGUMENT);
}

TEST_P(MetadataAccessObjectTest, ListExecutionsWithNonIdFieldOptions) {
TF_ASSERT_OK(Init());
ExecutionType type = ParseTextProtoOrDie<ExecutionType>(R"(
name: 'test_type'
properties { key: 'property_1' value: INT }
properties { key: 'property_2' value: DOUBLE }
properties { key: 'property_3' value: STRING }
)");
int64 type_id;
TF_ASSERT_OK(metadata_access_object_->CreateType(type, &type_id));

Execution sample_execution = ParseTextProtoOrDie<Execution>(R"(
properties {
key: 'property_1'
value: { int_value: 3 }
}
properties {
key: 'property_2'
value: { double_value: 3.0 }
}
properties {
key: 'property_3'
value: { string_value: '3' }
}
custom_properties {
key: 'custom_property_1'
value: { string_value: '5' }
}
)");
sample_execution.set_type_id(type_id);
const int total_stored_executions = 6;
int64 last_stored_execution_id;

for (int i = 0; i < total_stored_executions; i++) {
TF_ASSERT_OK(metadata_access_object_->CreateExecution(
sample_execution, &last_stored_execution_id));
}

const int page_size = 2;
ListOperationOptions list_options =
ParseTextProtoOrDie<ListOperationOptions>(R"(
max_result_size: 2,
order_by_field: { field: CREATE_TIME is_asc: false }
)");

int64 expected_execution_id = last_stored_execution_id;
std::string next_page_token;

do {
std::vector<Execution> got_executions;
TF_ASSERT_OK(metadata_access_object_->ListExecutions(
list_options, &got_executions, &next_page_token));
EXPECT_TRUE(got_executions.size() <= page_size);
for (const Execution& execution : got_executions) {
sample_execution.set_id(expected_execution_id--);

EXPECT_THAT(execution, EqualsProto(sample_execution, /*ignore_fields=*/{
"create_time_since_epoch",
"last_update_time_since_epoch"}));
}
list_options.set_next_page_token(next_page_token);
} while (!next_page_token.empty());

EXPECT_EQ(expected_execution_id, 0);
}

TEST_P(MetadataAccessObjectTest, ListExecutionsWithIdFieldOptions) {
TF_ASSERT_OK(Init());
ExecutionType type = ParseTextProtoOrDie<ExecutionType>(R"(
name: 'test_type'
properties { key: 'property_1' value: INT }
)");
int64 type_id;
TF_ASSERT_OK(metadata_access_object_->CreateType(type, &type_id));

Execution sample_execution = ParseTextProtoOrDie<Execution>(R"(
properties {
key: 'property_1'
value: { int_value: 3 }
}
custom_properties {
key: 'custom_property_1'
value: { string_value: '5' }
}
)");

sample_execution.set_type_id(type_id);
int stored_executions_count = 0;
int64 first_execution_id;
TF_ASSERT_OK(metadata_access_object_->CreateExecution(sample_execution,
&first_execution_id));
stored_executions_count++;

for (int i = 0; i < 6; i++) {
int64 unused_execution_id;
TF_ASSERT_OK(metadata_access_object_->CreateExecution(
sample_execution, &unused_execution_id));
}
stored_executions_count += 6;

const int page_size = 2;
ListOperationOptions list_options =
ParseTextProtoOrDie<ListOperationOptions>(R"(
max_result_size: 2,
order_by_field: { field: ID is_asc: true }
)");

std::string next_page_token;
int64 expected_execution_id = first_execution_id;
int seen_executions_count = 0;
do {
std::vector<Execution> got_executions;
TF_ASSERT_OK(metadata_access_object_->ListExecutions(
list_options, &got_executions, &next_page_token));
EXPECT_TRUE(got_executions.size() <= page_size);
for (const Execution& execution : got_executions) {
sample_execution.set_id(expected_execution_id++);

EXPECT_THAT(execution, EqualsProto(sample_execution, /*ignore_fields=*/{
"create_time_since_epoch",
"last_update_time_since_epoch"}));
seen_executions_count++;
}
list_options.set_next_page_token(next_page_token);
} while (!next_page_token.empty());

EXPECT_EQ(stored_executions_count, seen_executions_count);
}

TEST_P(MetadataAccessObjectTest, ListContextsWithNonIdFieldOptions) {
TF_ASSERT_OK(Init());
ContextType type = ParseTextProtoOrDie<ContextType>(R"(
name: 'test_type'
properties { key: 'property_1' value: INT }
properties { key: 'property_2' value: DOUBLE }
properties { key: 'property_3' value: STRING }
)");
int64 type_id;
TF_ASSERT_OK(metadata_access_object_->CreateType(type, &type_id));

Context sample_context = ParseTextProtoOrDie<Context>(R"(
properties {
key: 'property_1'
value: { int_value: 3 }
}
properties {
key: 'property_2'
value: { double_value: 3.0 }
}
properties {
key: 'property_3'
value: { string_value: '3' }
}
custom_properties {
key: 'custom_property_1'
value: { string_value: '5' }
}
)");
sample_context.set_type_id(type_id);
int64 last_stored_context_id;
int context_name_suffix = 0;
sample_context.set_name("list_contexts_test-1");
TF_ASSERT_OK(metadata_access_object_->CreateContext(sample_context,
&last_stored_context_id));

context_name_suffix++;
sample_context.set_name("list_contexts_test-2");
TF_ASSERT_OK(metadata_access_object_->CreateContext(sample_context,
&last_stored_context_id));
context_name_suffix++;
sample_context.set_name("list_contexts_test-3");
TF_ASSERT_OK(metadata_access_object_->CreateContext(sample_context,
&last_stored_context_id));
context_name_suffix++;

const int page_size = 2;
ListOperationOptions list_options =
ParseTextProtoOrDie<ListOperationOptions>(R"(
max_result_size: 2,
order_by_field: { field: CREATE_TIME is_asc: false }
)");

int64 expected_context_id = last_stored_context_id;
std::string next_page_token;

do {
std::vector<Context> got_contexts;
TF_ASSERT_OK(metadata_access_object_->ListContexts(
list_options, &got_contexts, &next_page_token));
EXPECT_TRUE(got_contexts.size() <= page_size);
for (const Context& context : got_contexts) {
sample_context.set_name(
absl::StrCat("list_contexts_test-", context_name_suffix--));
sample_context.set_id(expected_context_id--);
EXPECT_THAT(context, EqualsProto(sample_context, /*ignore_fields=*/{
"create_time_since_epoch",
"last_update_time_since_epoch"}));
}
list_options.set_next_page_token(next_page_token);
} while (!next_page_token.empty());

EXPECT_EQ(expected_context_id, 0);
}

TEST_P(MetadataAccessObjectTest, ListContextsWithIdFieldOptions) {
TF_ASSERT_OK(Init());
ContextType type = ParseTextProtoOrDie<ContextType>(R"(
name: 'test_type'
properties { key: 'property_1' value: INT }
)");
int64 type_id;
TF_ASSERT_OK(metadata_access_object_->CreateType(type, &type_id));

Context sample_context = ParseTextProtoOrDie<Context>(R"(
properties {
key: 'property_1'
value: { int_value: 3 }
}
custom_properties {
key: 'custom_property_1'
value: { string_value: '5' }
}
)");

sample_context.set_type_id(type_id);
int stored_contexts_count = 0;
int64 first_context_id;
sample_context.set_name("list_contexts_test-1");
TF_ASSERT_OK(metadata_access_object_->CreateContext(sample_context,
&first_context_id));

int64 unused_context_id;
stored_contexts_count++;
sample_context.set_name("list_contexts_test-2");
TF_ASSERT_OK(metadata_access_object_->CreateContext(sample_context,
&unused_context_id));
stored_contexts_count++;
sample_context.set_name("list_contexts_test-3");
TF_ASSERT_OK(metadata_access_object_->CreateContext(sample_context,
&unused_context_id));
stored_contexts_count++;

const int page_size = 2;
ListOperationOptions list_options =
ParseTextProtoOrDie<ListOperationOptions>(R"(
max_result_size: 2,
order_by_field: { field: ID is_asc: true }
)");

std::string next_page_token;
int64 expected_context_id = first_context_id;
int expected_context_name_suffix = 1;
int seen_contexts_count = 0;
do {
std::vector<Context> got_contexts;
TF_ASSERT_OK(metadata_access_object_->ListContexts(
list_options, &got_contexts, &next_page_token));
EXPECT_TRUE(got_contexts.size() <= page_size);
for (const Context& context : got_contexts) {
sample_context.set_name(
absl::StrCat("list_contexts_test-", expected_context_name_suffix++));
sample_context.set_id(expected_context_id++);

EXPECT_THAT(context, EqualsProto(sample_context, /*ignore_fields=*/{
"create_time_since_epoch",
"last_update_time_since_epoch"}));
seen_contexts_count++;
}
list_options.set_next_page_token(next_page_token);
} while (!next_page_token.empty());

EXPECT_EQ(stored_contexts_count, seen_contexts_count);
}

TEST_P(MetadataAccessObjectTest, DefaultArtifactState) {
TF_ASSERT_OK(Init());
ArtifactType type = ParseTextProtoOrDie<ArtifactType>("name: 'test_type'");
Expand Down
10 changes: 10 additions & 0 deletions ml_metadata/metadata_store/query_config_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -495,4 +495,14 @@ tensorflow::Status QueryConfigExecutor::ListArtifactIDsUsingOptions(
const ListOperationOptions& options, RecordSet* record_set) {
return ListNodeIDsUsingOptions<Artifact>(options, record_set);
}

tensorflow::Status QueryConfigExecutor::ListExecutionIDsUsingOptions(
const ListOperationOptions& options, RecordSet* record_set) {
return ListNodeIDsUsingOptions<Execution>(options, record_set);
}

tensorflow::Status QueryConfigExecutor::ListContextIDsUsingOptions(
const ListOperationOptions& options, RecordSet* record_set) {
return ListNodeIDsUsingOptions<Context>(options, record_set);
}
} // namespace ml_metadata
6 changes: 6 additions & 0 deletions ml_metadata/metadata_store/query_config_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,12 @@ class QueryConfigExecutor : public QueryExecutor {
tensorflow::Status ListArtifactIDsUsingOptions(
const ListOperationOptions& options, RecordSet* record_set) final;

tensorflow::Status ListExecutionIDsUsingOptions(
const ListOperationOptions& options, RecordSet* record_set) final;

tensorflow::Status ListContextIDsUsingOptions(
const ListOperationOptions& options, RecordSet* record_set) final;

private:
// Utility method to bind an nullable value.
template <typename T>
Expand Down
14 changes: 14 additions & 0 deletions ml_metadata/metadata_store/query_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,20 @@ class QueryExecutor {
// On success `record_set` is updated with artifact IDs based on `options`
virtual tensorflow::Status ListArtifactIDsUsingOptions(
const ListOperationOptions& options, RecordSet* record_set) = 0;

// List Execution IDs using `options`.
// On success `set` is updated with execution IDs based on `options` and
// `next_page_token` is updated with information for the caller to use for
// next page of results.
virtual tensorflow::Status ListExecutionIDsUsingOptions(
const ListOperationOptions& options, RecordSet* record_set) = 0;

// List Context IDs using `options`.
// On success `set` is updated with context IDs based on `options` and
// `next_page_token` is updated with information for the caller to use for
// next page of results.
virtual tensorflow::Status ListContextIDsUsingOptions(
const ListOperationOptions& options, RecordSet* record_set) = 0;
};

} // namespace ml_metadata
Expand Down
Loading

0 comments on commit bb5d116

Please sign in to comment.