Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
felix-datatonic committed Nov 15, 2023
1 parent 7c932b9 commit 07fb10b
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 30 deletions.
13 changes: 6 additions & 7 deletions components/tests/test_lookup_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@


@mock.patch("google.cloud.aiplatform.Model")
def test_lookup_model(mock_model, tmpdir):
def test_lookup_model(mock_model, tmp_path):
"""
Assert lookup_model produces expected resource name, and that list method is
called with the correct arguments
"""

# Mock attribute and method
mock_path = tmpdir
mock_path = str(tmp_path / "model")
mock_model.resource_name = "my-model-resource-name"
mock_model.uri = mock_path
mock_model.list.return_value = [mock_model]
Expand All @@ -54,7 +54,7 @@ def test_lookup_model(mock_model, tmpdir):


@mock.patch("google.cloud.aiplatform.Model")
def test_lookup_model_when_no_models(mock_model, tmpdir):
def test_lookup_model_when_no_models(mock_model, tmp_path):
"""
Checks that when there are no models and fail_on_model_not_found = False,
lookup_model returns an empty string.
Expand All @@ -65,15 +65,14 @@ def test_lookup_model_when_no_models(mock_model, tmpdir):
location="europe-west4",
project="my-project-id",
fail_on_model_not_found=False,
model=Model(uri=str(tmpdir)),
model=Model(uri=str(tmp_path / "model")),
)

print(exported_model_resource_name)
assert exported_model_resource_name == ""


@mock.patch("google.cloud.aiplatform.Model")
def test_lookup_model_when_no_models_fail(mock_model, tmpdir):
def test_lookup_model_when_no_models_fail(mock_model, tmp_path):
"""
Checks that when there are no models and fail_on_model_not_found = True,
lookup_model raises a RuntimeError.
Expand All @@ -87,5 +86,5 @@ def test_lookup_model_when_no_models_fail(mock_model, tmpdir):
location="europe-west4",
project="my-project-id",
fail_on_model_not_found=True,
model=Model(uri=str(tmpdir)),
model=Model(uri=str(tmp_path / "model")),
)
43 changes: 23 additions & 20 deletions components/tests/test_model_batch_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
def test_model_batch_predict(
create_job,
get_job,
tmpdir,
tmp_path,
source_format,
destination_format,
source_uri,
Expand All @@ -68,24 +68,27 @@ def test_model_batch_predict(
"""
Asserts model_batch_predict successfully creates requests given different arguments.
"""
mock_model = Model(uri=tmpdir, metadata={"resourceName": ""})
gcp_resources_path = tmpdir / "gcp_resources.json"
mock_model = Model(uri=str(tmp_path / "model"), metadata={"resourceName": ""})
gcp_resources_path = tmp_path / "gcp_resources.json"

model_batch_predict(
model=mock_model,
job_display_name="",
location="",
project="",
source_uri=source_uri,
destination_uri=destination_format,
source_format=source_format,
destination_format=destination_format,
monitoring_training_dataset=monitoring_training_dataset,
monitoring_alert_email_addresses=monitoring_alert_email_addresses,
monitoring_skew_config=monitoring_skew_config,
gcp_resources=str(gcp_resources_path),
)
try:
model_batch_predict(
model=mock_model,
job_display_name="",
location="",
project="",
source_uri=source_uri,
destination_uri=destination_format,
source_format=source_format,
destination_format=destination_format,
monitoring_training_dataset=monitoring_training_dataset,
monitoring_alert_email_addresses=monitoring_alert_email_addresses,
monitoring_skew_config=monitoring_skew_config,
gcp_resources=str(gcp_resources_path),
)

create_job.assert_called_once()
get_job.assert_called_once()
assert gcp_resources_path.exists()
create_job.assert_called_once()
get_job.assert_called_once()
assert gcp_resources_path.exists()
finally:
gcp_resources_path.unlink(missing_ok=True)
5 changes: 4 additions & 1 deletion docs/PRODUCTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ cloud_schedulers_config = {
template_path = "https://<GCP region>-kfp.pkg.dev/<Project ID of test project>/vertex-pipelines/xgboost-train-pipeline/v1.2"
enable_caching = null
pipeline_parameters = {
// Add pipeline parameters which are expected by your pipeline here e.g.
// project = "my-project-id"
},
},
Expand All @@ -97,7 +99,8 @@ cloud_schedulers_config = {
template_path = "https://<GCP region>-kfp.pkg.dev/<Project ID of test project>/vertex-pipelines/xgboost-prediction-pipeline/v1.2"
enable_caching = null
pipeline_parameters = {
// TODO: add all pipeline parameters which are expected by your pipeline
// Add pipeline parameters which are expected by your pipeline here e.g.
// project = "my-project-id"
},
},
Expand Down
4 changes: 2 additions & 2 deletions pipelines/src/pipelines/utils/trigger_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ def trigger_pipeline(
enable_caching = enable_caching.lower() in true_

# For below options, we want an empty string to become None, so we add "or None"
encryption_spec_key_name = os.environ.get("VERTEX_CMEK_IDENTIFIER")
network = os.environ.get("VERTEX_NETWORK")
encryption_spec_key_name = os.environ.get("VERTEX_CMEK_IDENTIFIER") or None
network = os.environ.get("VERTEX_NETWORK") or None

# Instantiate PipelineJob object
pl = aiplatform.PipelineJob(
Expand Down

0 comments on commit 07fb10b

Please sign in to comment.