Skip to content

Commit

Permalink
Merge pull request #965 from CitrineInformatics/use-dw-branch-root-and-version
Browse files Browse the repository at this point in the history

Address DW with branch root ID and version.
  • Loading branch information
anoto-moniz authored Sep 20, 2024
2 parents 0e454fe + ede810b commit e5355f0
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 127 deletions.
2 changes: 1 addition & 1 deletion src/citrine/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.6.0"
__version__ = "3.7.0"
35 changes: 6 additions & 29 deletions src/citrine/informatics/workflows/design_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,17 @@ class DesignWorkflow(Resource['DesignWorkflow'], Workflow, AIResourceMetadata):
predictor_id = properties.Optional(properties.UUID, 'predictor_id')
predictor_version = properties.Optional(
properties.Union([properties.Integer(), properties.String()]), 'predictor_version')
_branch_id: Optional[UUID] = properties.Optional(properties.UUID, 'branch_id')
branch_root_id: Optional[UUID] = properties.Optional(properties.UUID, 'branch_root_id')
""":Optional[UUID]: Root ID of the branch that contains this workflow."""
branch_version: Optional[int] = properties.Optional(properties.Integer, 'branch_version')
""":Optional[int]: Version number of the branch that contains this workflow."""

status_description = properties.String('status_description', serializable=False)
""":str: more detailed description of the workflow's status"""
typ = properties.String('type', default='DesignWorkflow', deserializable=False)

_branch_root_id: Optional[UUID] = properties.Optional(properties.UUID, 'branch_root_id',
serializable=False, deserializable=False)
""":Optional[UUID]: Root ID of the branch that contains this workflow."""
_branch_version: Optional[int] = properties.Optional(properties.Integer, 'branch_version',
serializable=False, deserializable=False)
""":Optional[int]: Version number of the branch that contains this workflow."""
_branch_id: Optional[UUID] = properties.Optional(properties.UUID, 'branch_id',
serializable=False)

def __init__(self,
name: str,
Expand All @@ -68,25 +67,3 @@ def design_executions(self) -> DesignExecutionCollection:
raise AttributeError('Cannot initialize execution without project reference!')
return DesignExecutionCollection(
project_id=self.project_id, session=self._session, workflow_id=self.uid)

@property
def branch_root_id(self):
"""Retrieve the root ID of the branch this workflow is on."""
return self._branch_root_id

@branch_root_id.setter
def branch_root_id(self, value):
"""Set the root ID of the branch this workflow is on."""
self._branch_root_id = value
self._branch_id = None

@property
def branch_version(self):
"""Retrieve the version of the branch this workflow is on."""
return self._branch_version

@branch_version.setter
def branch_version(self, value):
"""Set the version of the branch this workflow is on."""
self._branch_version = value
self._branch_id = None
47 changes: 10 additions & 37 deletions src/citrine/resources/design_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from citrine._rest.collection import Collection
from citrine._session import Session
from citrine.exceptions import NotFound
from citrine.informatics.workflows import DesignWorkflow
from citrine.resources.response import Response
from functools import partial
Expand All @@ -31,25 +30,6 @@ def __init__(self,
self.branch_root_id = branch_root_id
self.branch_version = branch_version

def _resolve_branch_root_and_version(self, workflow):
from citrine.resources.branch import BranchCollection

workflow_copy = deepcopy(workflow)
bc = BranchCollection(self.project_id, self.session)
branch = bc.get_by_version_id(version_id=workflow_copy._branch_id)
workflow_copy._branch_root_id = branch.root_id
workflow_copy._branch_version = branch.version
return workflow_copy

def _resolve_branch_id(self, root_id, version):
from citrine.resources.branch import BranchCollection

if root_id and version:
bc = BranchCollection(self.project_id, self.session)
branch = bc.get(root_id=root_id, version=version)
return branch.uid
return None

def register(self, model: DesignWorkflow) -> DesignWorkflow:
"""
Upload a new design workflow.
Expand Down Expand Up @@ -77,15 +57,15 @@ def register(self, model: DesignWorkflow) -> DesignWorkflow:
'project.design_workflows.register().')
raise RuntimeError(msg)
else:
# branch_id is in the body of design workflow endpoints, so it must be serialized.
# This means the collection branch_id might not match the workflow branch_id. The
# collection should win out, since the user is explicitly referencing the branch
# represented by this collection.
# To avoid modifying the parameter, and to ensure the only change is the branch_id, we
# deepcopy, modify, then register it.
# branch_root_id and branch_version are in the body of design workflow endpoints, so
# they must be serialized. This means the collection fields might not match the
# workflow fields. The collection should win out, since the user is explicitly
# referencing the branch represented by this collection.
# To avoid modifying the parameter, and to ensure the only changes are the
# branch_root_id and branch_version, we deepcopy, modify, then register it.
model_copy = deepcopy(model)
model_copy._branch_id = self._resolve_branch_id(self.branch_root_id,
self.branch_version)
model_copy.branch_root_id = self.branch_root_id
model_copy.branch_version = self.branch_version
return super().register(model_copy)

def build(self, data: dict) -> DesignWorkflow:
Expand All @@ -104,7 +84,6 @@ def build(self, data: dict) -> DesignWorkflow:
"""
workflow = DesignWorkflow.build(data)
workflow = self._resolve_branch_root_and_version(workflow)
workflow._session = self.session
workflow.project_id = self.project_id
return workflow
Expand Down Expand Up @@ -137,13 +116,6 @@ def update(self, model: DesignWorkflow) -> DesignWorkflow:
raise ValueError('Cannot update a design workflow unless its branch_root_id and '
'branch_version are set.')

try:
model._branch_id = self._resolve_branch_id(model.branch_root_id,
model.branch_version)
except NotFound:
raise ValueError('Cannot update a design workflow unless its branch_root_id and '
'branch_version exists.')

# If executions have already been done, warn about future behavior change
executions = model.design_executions.list()
if next(executions, None) is not None:
Expand Down Expand Up @@ -197,7 +169,8 @@ def _fetch_page(self,
additional_params: Optional[dict] = None,
) -> Tuple[Iterable[dict], str]:
params = additional_params or {}
params["branch"] = self._resolve_branch_id(self.branch_root_id, self.branch_version)
params["branch_root_id"] = self.branch_root_id
params["branch_version"] = self.branch_version
return super()._fetch_page(path=path,
fetch_func=fetch_func,
page=page,
Expand Down
4 changes: 2 additions & 2 deletions src/citrine/seeding/find_or_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ def create_or_update(*,
# Locally created design workflows likely won't have a branch root ID and
# branch version, but need both in order to be updated.
if isinstance(old_resource, DesignWorkflow):
new_resource._branch_root_id = old_resource.branch_root_id
new_resource._branch_version = old_resource.branch_version
new_resource.branch_root_id = old_resource.branch_root_id
new_resource.branch_version = old_resource.branch_version
return collection.update(new_resource)
else:
logger.info("Registering new module: {}".format(resource.name))
Expand Down
23 changes: 23 additions & 0 deletions tests/resources/test_branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,29 @@ def test_branch_get(session, collection, branch_path):
assert session.last_call == FakeCall(method='GET', path=branch_path, params={'page': 1, 'per_page': 1, 'root': root_id, 'version': version})


def test_branch_get_not_found(session, collection, branch_path):
    """Check that ``collection.get`` raises ``NotFound`` when the lookup returns no branches."""
    # Given: the session answers the lookup with an empty result list.
    session.set_response({"response": []})

    # When / Then: a root ID and version that match nothing raise NotFound.
    with pytest.raises(NotFound):
        collection.get(root_id=uuid.uuid4(), version=1)


def test_branch_get_by_version_id(session, collection, branch_path):
    """Check that ``get_by_version_id`` issues exactly one GET to ``<branch_path>/<version_id>``."""
    # Given: a branch payload whose 'id' is used as the version ID to fetch.
    branch_data = BranchDataFactory()
    version_id = branch_data['id']
    session.set_response(branch_data)

    # When: the return value is not inspected; this test only pins the request made.
    collection.get_by_version_id(version_id=version_id)

    # Then
    assert session.num_calls == 1
    assert session.last_call == FakeCall(method='GET', path=f"{branch_path}/{version_id}")


def test_branch_list(session, collection, branch_path):
# Given
branch_count = 5
Expand Down
64 changes: 18 additions & 46 deletions tests/resources/test_design_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def collection(branch_data, collection_without_branch) -> DesignWorkflowCollecti

@pytest.fixture
def workflow(collection, branch_data, design_workflow_dict) -> DesignWorkflow:
design_workflow_dict["branch_id"] = branch_data["id"]
design_workflow_dict["branch_root_id"] = branch_data["metadata"]["root_id"]
design_workflow_dict["branch_version"] = branch_data["metadata"]["version"]

collection.session.set_response(branch_data)
workflow = collection.build(design_workflow_dict)
Expand Down Expand Up @@ -71,12 +72,6 @@ def workflow_path(collection, workflow=None):
path = f'{path}/{workflow.uid}'
return path

def branches_path(collection, branch_id=None):
path = f'/projects/{collection.project_id}/branches'
if branch_id:
path = f'{path}/{branch_id}'
return path

def assert_workflow(actual, expected, *, include_branch=False):
assert actual.name == expected.name
assert actual.description == expected.description
Expand All @@ -86,7 +81,7 @@ def assert_workflow(actual, expected, *, include_branch=False):
assert actual.predictor_version == expected.predictor_version
assert actual.project_id == expected.project_id
if include_branch:
assert actual.branch_id == expected.branch_id
assert actual._branch_id == expected._branch_id
assert actual.branch_root_id == expected.branch_root_id
assert actual.branch_version == expected.branch_version

Expand All @@ -99,29 +94,22 @@ def test_basic_methods(workflow, collection, design_workflow_dict):
@pytest.mark.parametrize("optional_args", all_combination_lengths(OPTIONAL_ARGS))
def test_register(session, branch_data, workflow_minimal, collection, optional_args):
workflow = workflow_minimal
branch_id = branch_data['id']
branch_data_get_resp = {"response": [branch_data]}
branch_data_get_params = {
'page': 1, 'per_page': 1, 'root': str(collection.branch_root_id), 'version': collection.branch_version
}
branch_root_id = branch_data['metadata']['root_id']
branch_version = branch_data['metadata']['version']

# Set a random value for all optional args selected for this run.
for name, factory in optional_args:
setattr(workflow, name, factory())

# Given
post_dict = {**workflow.dump(), "branch_id": str(branch_id)}
session.set_responses(branch_data_get_resp, {**post_dict, 'status_description': 'status'}, branch_data)
post_dict = {**workflow.dump(), "branch_root_id": str(branch_root_id), "branch_version": branch_version}
session.set_responses({**post_dict, 'status_description': 'status'})

# When
new_workflow = collection.register(workflow)

# Then
assert session.calls == [
FakeCall(method='GET', path=branches_path(collection), params=branch_data_get_params),
FakeCall(method='POST', path=workflow_path(collection), json=post_dict),
FakeCall(method='GET', path=branches_path(collection, branch_id)),
]
assert session.calls == [FakeCall(method='POST', path=workflow_path(collection), json=post_dict)]

assert new_workflow.branch_root_id == collection.branch_root_id
assert new_workflow.branch_version == collection.branch_version
Expand All @@ -133,23 +121,18 @@ def test_register_conflicting_branches(session, branch_data, workflow, collectio
old_branch_root_id = uuid.uuid4()
workflow.branch_root_id = old_branch_root_id
assert workflow.branch_root_id != collection.branch_root_id

new_branch_root_id = str(branch_data["metadata"]["root_id"])
new_branch_version = branch_data["metadata"]["version"]

branch_data_get_resp = {"response": [branch_data]}
branch_data_get_params = {
'page': 1, 'per_page': 1, 'root': str(collection.branch_root_id), 'version': collection.branch_version
}
post_dict = {**workflow.dump(), "branch_id": str(branch_data["id"])}
session.set_responses(branch_data_get_resp, {**post_dict, 'status_description': 'status'}, branch_data)
post_dict = {**workflow.dump(), "branch_root_id": new_branch_root_id, "branch_version": new_branch_version}
session.set_responses({**post_dict, 'status_description': 'status'})

# When
new_workflow = collection.register(workflow)

# Then
assert session.calls == [
FakeCall(method='GET', path=branches_path(collection), params=branch_data_get_params),
FakeCall(method='POST', path=workflow_path(collection), json=post_dict),
FakeCall(method='GET', path=branches_path(collection, branch_data["id"])),
]
assert session.calls == [FakeCall(method='POST', path=workflow_path(collection), json=post_dict)]

assert workflow.branch_root_id == old_branch_root_id
assert new_workflow.branch_root_id == collection.branch_root_id
Expand Down Expand Up @@ -180,10 +163,10 @@ def test_delete(collection):


def test_list_archived(branch_data, workflow, collection: DesignWorkflowCollection):
branch_data_get_resp = {"response": [branch_data]}
branch_id = uuid.UUID(branch_data['id'])
branch_root_id = uuid.UUID(branch_data['metadata']['root_id'])
branch_version = branch_data['metadata']['version']

collection.session.set_responses(branch_data_get_resp, {"response": []})
collection.session.set_responses({"response": []})

lst = list(collection.list_archived(per_page=10))
assert len(lst) == 0
Expand All @@ -192,7 +175,7 @@ def test_list_archived(branch_data, workflow, collection: DesignWorkflowCollecti
assert collection.session.last_call == FakeCall(
method='GET',
path=expected_path,
params={'page': 1, 'per_page': 10, 'filter': "archived eq 'true'", 'branch': branch_id},
params={'page': 1, 'per_page': 10, 'filter': "archived eq 'true'", 'branch_root_id': branch_root_id, 'branch_version': branch_version},
json=None
)

Expand All @@ -213,17 +196,10 @@ def test_missing_project(design_workflow_dict):

def test_update(session, branch_data, workflow, collection_without_branch):
# Given
branch_data_get_resp = {"response": [branch_data]}
branch_data_get_params = {
'page': 1, 'per_page': 1, 'root': str(workflow.branch_root_id), 'version': workflow.branch_version
}

post_dict = workflow.dump()
session.set_responses(
branch_data_get_resp,
{"per_page": 1, "next": "", "response": []},
{**post_dict, 'status_description': 'status'},
branch_data
)

# When
Expand All @@ -232,20 +208,16 @@ def test_update(session, branch_data, workflow, collection_without_branch):
# Then
executions_path = f'/projects/{collection_without_branch.project_id}/design-workflows/{workflow.uid}/executions'
assert session.calls == [
FakeCall(method='GET', path=branches_path(collection_without_branch), params=branch_data_get_params),
FakeCall(method='GET', path=executions_path, params={'page': 1, 'per_page': 100}),
FakeCall(method='PUT', path=workflow_path(collection_without_branch, workflow), json=post_dict),
FakeCall(method='GET', path=branches_path(collection_without_branch, branch_data["id"])),
]
assert_workflow(new_workflow, workflow)


def test_update_failure_with_existing_execution(session, branch_data, workflow, collection_without_branch, design_execution_dict):
branch_data_get_resp = {"response": [branch_data]}
workflow.branch_root_id = uuid.uuid4()
post_dict = workflow.dump()
session.set_responses(
branch_data_get_resp,
{"per_page": 1, "next": "", "response": [design_execution_dict]},
{**post_dict, 'status_description': 'status'})

Expand Down
8 changes: 2 additions & 6 deletions tests/resources/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,19 +62,15 @@ def test_build_design_workflow(session, basic_design_workflow_data):

def test_list_workflows(session, basic_design_workflow_data):
#Given
branch_data = BranchDataFactory()
branch_data_get_resp = {"response": [branch_data]}
session.set_response(branch_data)

workflow_collection = DesignWorkflowCollection(project_id=uuid.uuid4(), session=session)
session.set_responses({'response': [basic_design_workflow_data], 'page': 1, 'per_page': 20}, branch_data)
session.set_responses({'response': [basic_design_workflow_data], 'page': 1, 'per_page': 20})

# When
workflows = list(workflow_collection.list(per_page=20))

# Then
expected_design_call = FakeCall(method='GET', path='/projects/{}/modules'.format(workflow_collection.project_id),
params={'per_page': 20, 'module_type': 'DESIGN_WORKFLOW'})
assert 2 == session.num_calls
assert 1 == session.num_calls
assert len(workflows) == 1
assert isinstance(workflows[0], DesignWorkflow)
6 changes: 0 additions & 6 deletions tests/seeding/test_find_or_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,17 +353,11 @@ def test_create_or_update_unique_found_design_workflow(session):
dw2_dict = DesignWorkflowDataFactory(branch_root_id=root_id, branch_version=version)
dw3_dict = DesignWorkflowDataFactory()
session.set_responses(
# Build (setup)
branch_data, # Find the model's branch root ID and version
# List
{"response": [branch_data]}, # Find the collection's branch version ID
{"response": [dw1_dict, dw2_dict, dw3_dict]}, # Return the design workflows
branch_data, branch_data, branch_data, # Lookup the branch root ID and version of each design workflow.
# Update
{"response": [branch_data]}, # Lookup the module's branch version ID
{"response": []}, # Check if there are any executions
dw2_dict, # Return the updated design workflow
branch_data # Lookup the updated design workflow branch root ID and version
)

collection = LocalDesignWorkflowCollection(project_id=uuid4(), session=session, branch_root_id=root_id, branch_version=version)
Expand Down

0 comments on commit e5355f0

Please sign in to comment.