Skip to content

Commit

Permalink
fix context manager with sequence management
Browse files Browse the repository at this point in the history
  • Loading branch information
jrobinAV committed Feb 12, 2024
1 parent 5dbc7a2 commit a71e7b4
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 5 deletions.
11 changes: 6 additions & 5 deletions taipy/core/scenario/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,8 @@ def add_sequence(
"""
if name in self.sequences:
raise SequenceAlreadyExists(name, self.id)
self._set_sequence(name, tasks, properties, subscribers)
Notifier.publish(_make_event(self.sequences[name], EventOperation.CREATION))
seq = self._set_sequence(name, tasks, properties, subscribers)
Notifier.publish(_make_event(seq, EventOperation.CREATION))

def update_sequence(
self,
Expand All @@ -222,16 +222,16 @@ def update_sequence(
"""
if name not in self.sequences:
raise NonExistingSequence(name, self.id)
self._set_sequence(name, tasks, properties, subscribers)
Notifier.publish(_make_event(self.sequences[name], EventOperation.UPDATE))
seq = self._set_sequence(name, tasks, properties, subscribers)
Notifier.publish(_make_event(seq, EventOperation.UPDATE))

def _set_sequence(
self,
name: str,
tasks: Union[List[Task], List[TaskId]],
properties: Optional[Dict] = None,
subscribers: Optional[List[_Subscriber]] = None,
):
) -> Sequence:
_scenario = _Reloader()._reload(self._MANAGER_NAME, self)
_scenario_task_ids = set(task.id if isinstance(task, Task) else task for task in _scenario._tasks)
_sequence_task_ids: Set[TaskId] = set(task.id if isinstance(task, Task) else task for task in tasks)
Expand All @@ -253,6 +253,7 @@ def _set_sequence(
}
)
self.sequences = _sequences # type: ignore
return seq

def add_sequences(self, sequences: Dict[str, Union[List[Task], List[TaskId]]]):
"""Add multiple sequences to the scenario.
Expand Down
23 changes: 23 additions & 0 deletions tests/core/scenario/test_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,29 @@ def test_update_sequence(data_node):
assert scenario.sequences["seq_1"].properties["new_key"] == "new_value"


def test_add_rename_and_remove_sequences_within_context(data_node):
task_1 = Task("task_1", {}, print, output=[data_node])
task_2 = Task("task_2", {}, print, input=[data_node])
_TaskManagerFactory._build_manager()._set(task_1)
scenario = Scenario(config_id="scenario", tasks={task_1, task_2}, properties={})
_ScenarioManagerFactory._build_manager()._set(scenario)

with scenario as sc:
sc.add_sequence("seq_name", [task_1])
assert len(scenario.sequences) == 1
assert scenario.sequences["seq_name"].tasks == {"task_1": task_1}

with scenario as sc:
sc.update_sequence("seq_name", [task_2])
assert len(scenario.sequences) == 1
assert scenario.sequences["seq_name"].tasks == {"task_2": task_2}

with scenario as sc:
sc.remove_sequence("seq_name")
assert len(scenario.sequences) == 0



def test_add_property_to_scenario():
scenario = Scenario("foo", set(), {"key": "value"})
assert scenario.properties == {"key": "value"}
Expand Down

0 comments on commit a71e7b4

Please sign in to comment.