Skip to content

Commit

Permalink
Merge pull request #655 from Avaiga/feature/#490-change-submit-API
Browse files Browse the repository at this point in the history
feature/#490 submit() functions return a Submission object
  • Loading branch information
toan-quach authored Jan 17, 2024
2 parents 970fb3b + 02be686 commit 7fa780c
Show file tree
Hide file tree
Showing 32 changed files with 437 additions and 269 deletions.
3 changes: 2 additions & 1 deletion taipy/core/_entity/submittable.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from ..common._utils import _Subscriber
from ..data.data_node import DataNode
from ..job.job import Job
from ..submission.submission import Submission
from ..task.task import Task
from ._dag import _DAG

Expand All @@ -42,7 +43,7 @@ def submit(
force: bool = False,
wait: bool = False,
timeout: Optional[Union[float, int]] = None,
):
) -> Submission:
raise NotImplementedError

def get_inputs(self) -> Set[DataNode]:
Expand Down
14 changes: 9 additions & 5 deletions taipy/core/_orchestrator/_abstract_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
# specific language governing permissions and limitations under the License.

from abc import abstractmethod
from typing import Callable, Iterable, List, Optional, Union
from typing import Callable, Iterable, Optional, Union

from .._entity.submittable import Submittable
from ..job.job import Job
from ..submission.submission import Submission
from ..task.task import Task


Expand All @@ -28,12 +30,13 @@ def initialize(cls):
@abstractmethod
def submit(
cls,
sequence,
submittable: Submittable,
callbacks: Optional[Iterable[Callable]],
force: bool = False,
wait: bool = False,
timeout: Optional[Union[float, int]] = None,
) -> List[Job]:
**properties,
) -> Submission:
raise NotImplementedError

@classmethod
Expand All @@ -45,10 +48,11 @@ def submit_task(
force: bool = False,
wait: bool = False,
timeout: Optional[Union[float, int]] = None,
) -> Job:
**properties,
) -> Submission:
raise NotImplementedError

@classmethod
@abstractmethod
def cancel_job(cls, job):
def cancel_job(cls, job: Job):
raise NotImplementedError
31 changes: 23 additions & 8 deletions taipy/core/_orchestrator/_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from multiprocessing import Lock
from queue import Queue
from time import sleep
from typing import Callable, Iterable, List, Optional, Set, Union
from typing import Callable, Dict, Iterable, List, Optional, Set, Union

from taipy.config.config import Config
from taipy.logger._taipy_logger import _TaipyLogger
Expand All @@ -25,6 +25,7 @@
from ..job.job import Job
from ..job.job_id import JobId
from ..submission._submission_manager_factory import _SubmissionManagerFactory
from ..submission.submission import Submission
from ..task.task import Task
from ._abstract_orchestrator import _AbstractOrchestrator

Expand All @@ -38,6 +39,7 @@ class _Orchestrator(_AbstractOrchestrator):
blocked_jobs: List = []
lock = Lock()
__logger = _TaipyLogger._get_logger()
_submission_entities: Dict[str, Submission] = {}

@classmethod
def initialize(cls):
Expand All @@ -51,7 +53,8 @@ def submit(
force: bool = False,
wait: bool = False,
timeout: Optional[Union[float, int]] = None,
) -> List[Job]:
**properties,
) -> Submission:
"""Submit the given `Scenario^` or `Sequence^` for an execution.
Parameters:
Expand All @@ -63,14 +66,17 @@ def submit(
finished in asynchronous mode.
timeout (Union[float, int]): The optional maximum number of seconds to wait for the jobs to be finished
before returning.
**properties (dict[str, any]): A keyworded variable length list of additional arguments.
Returns:
The created Jobs.
"""
submission = _SubmissionManagerFactory._build_manager()._create(
submittable.id, # type: ignore
submittable._ID_PREFIX, # type: ignore
getattr(submittable, "config_id", None),
**properties,
)
cls._submission_entities[submission.id] = submission
jobs = []
tasks = submittable._get_sorted_tasks()
with cls.lock:
Expand All @@ -81,7 +87,7 @@ def submit(
task,
submission.id,
submission.entity_id,
callbacks=itertools.chain([submission._update_submission_status], callbacks or []),
callbacks=itertools.chain([cls._update_submission_status], callbacks or []),
force=force, # type: ignore
)
)
Expand All @@ -92,7 +98,7 @@ def submit(
else:
if wait:
cls._wait_until_job_finished(jobs, timeout=timeout)
return jobs
return submission

@classmethod
def submit_task(
Expand All @@ -102,7 +108,8 @@ def submit_task(
force: bool = False,
wait: bool = False,
timeout: Optional[Union[float, int]] = None,
) -> Job:
**properties,
) -> Submission:
"""Submit the given `Task^` for an execution.
Parameters:
Expand All @@ -113,17 +120,21 @@ def submit_task(
in asynchronous mode.
timeout (Union[float, int]): The optional maximum number of seconds to wait for the job
to be finished before returning.
**properties (dict[str, any]): A keyworded variable length list of additional arguments.
Returns:
The created `Job^`.
"""
submission = _SubmissionManagerFactory._build_manager()._create(task.id, task._ID_PREFIX, task.config_id)
submission = _SubmissionManagerFactory._build_manager()._create(
task.id, task._ID_PREFIX, task.config_id, **properties
)
submit_id = submission.id
cls._submission_entities[submission.id] = submission
with cls.lock:
job = cls._lock_dn_output_and_create_job(
task,
submit_id,
submission.entity_id,
itertools.chain([submission._update_submission_status], callbacks or []),
itertools.chain([cls._update_submission_status], callbacks or []),
force,
)
jobs = [job]
Expand All @@ -134,7 +145,11 @@ def submit_task(
else:
if wait:
cls._wait_until_job_finished(job, timeout=timeout)
return job
return submission

@classmethod
def _update_submission_status(cls, job: Job):
cls._submission_entities[job.submit_id]._update_submission_status(job)

@classmethod
def _lock_dn_output_and_create_job(
Expand Down
2 changes: 1 addition & 1 deletion taipy/core/job/_job_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ def _create(
force=force,
version=version,
)
cls._set(job)
Notifier.publish(_make_event(job, EventOperation.CREATION))
job._on_status_change(*callbacks)
cls._set(job)
return job

@classmethod
Expand Down
10 changes: 6 additions & 4 deletions taipy/core/job/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import traceback
from datetime import datetime
from typing import Any, Callable, List, Optional
from typing import TYPE_CHECKING, Any, Callable, List, Optional

from taipy.logger._taipy_logger import _TaipyLogger

Expand All @@ -23,10 +23,12 @@
from .._version._version_manager_factory import _VersionManagerFactory
from ..common._utils import _fcts_to_dict
from ..notification.event import Event, EventEntityType, EventOperation, _make_event
from ..task.task import Task
from .job_id import JobId
from .status import Status

if TYPE_CHECKING:
from ..task.task import Task


def _run_callbacks(fn):
def __run_callbacks(job):
Expand Down Expand Up @@ -58,7 +60,7 @@ class Job(_Entity, _Labeled):
_MANAGER_NAME = "job"
_ID_PREFIX = "JOB"

def __init__(self, id: JobId, task: Task, submit_id: str, submit_entity_id: str, force=False, version=None):
def __init__(self, id: JobId, task: "Task", submit_id: str, submit_entity_id: str, force=False, version=None):
self.id = id
self._task = task
self._force = force
Expand Down Expand Up @@ -146,7 +148,7 @@ def stacktrace(self, val):
def version(self):
return self._version

def __contains__(self, task: Task):
def __contains__(self, task: "Task"):
return self.task.id == task.id

def __lt__(self, other):
Expand Down
17 changes: 13 additions & 4 deletions taipy/core/scenario/_scenario_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from ..job.job import Job
from ..notification import EventEntityType, EventOperation, Notifier, _make_event
from ..submission._submission_manager_factory import _SubmissionManagerFactory
from ..submission.submission import Submission
from ..task._task_manager_factory import _TaskManagerFactory
from .scenario import Scenario
from .scenario_id import ScenarioId
Expand Down Expand Up @@ -205,7 +206,8 @@ def _submit(
wait: bool = False,
timeout: Optional[Union[float, int]] = None,
check_inputs_are_ready: bool = True,
) -> List[Job]:
**properties,
) -> Submission:
scenario_id = scenario.id if isinstance(scenario, Scenario) else scenario
scenario = cls._get(scenario_id)
if scenario is None:
Expand All @@ -215,13 +217,20 @@ def _submit(
if check_inputs_are_ready:
_warn_if_inputs_not_ready(scenario.get_inputs())

jobs = (
submission = (
_TaskManagerFactory._build_manager()
._orchestrator()
.submit(scenario, callbacks=scenario_subscription_callback, force=force, wait=wait, timeout=timeout)
.submit(
scenario,
callbacks=scenario_subscription_callback,
force=force,
wait=wait,
timeout=timeout,
**properties,
)
)
Notifier.publish(_make_event(scenario, EventOperation.SUBMISSION))
return jobs
return submission

@classmethod
def __get_status_notifier_callbacks(cls, scenario: Scenario) -> List:
Expand Down
10 changes: 6 additions & 4 deletions taipy/core/scenario/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from ..job.job import Job
from ..notification import Event, EventEntityType, EventOperation, Notifier, _make_event
from ..sequence.sequence import Sequence
from ..submission.submission import Submission
from ..task.task import Task
from ..task.task_id import TaskId
from .scenario_id import ScenarioId
Expand Down Expand Up @@ -492,7 +493,8 @@ def submit(
force: bool = False,
wait: bool = False,
timeout: Optional[Union[float, int]] = None,
) -> List[Job]:
**properties,
) -> Submission:
"""Submit this scenario for execution.
All the `Task^`s of the scenario will be submitted for execution.
Expand All @@ -505,13 +507,13 @@ def submit(
asynchronous mode.
timeout (Union[float, int]): The optional maximum number of seconds to wait for the jobs to be finished
before returning.
**properties (dict[str, any]): A keyworded variable length list of additional arguments.
Returns:
A list of created `Job^`s.
A `Submission^` containing the information of the submission.
"""
from ._scenario_manager_factory import _ScenarioManagerFactory

return _ScenarioManagerFactory._build_manager()._submit(self, callbacks, force, wait, timeout)
return _ScenarioManagerFactory._build_manager()._submit(self, callbacks, force, wait, timeout, **properties)

def export(
self,
Expand Down
17 changes: 13 additions & 4 deletions taipy/core/sequence/_sequence_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from ..scenario.scenario import Scenario
from ..scenario.scenario_id import ScenarioId
from ..submission._submission_manager_factory import _SubmissionManagerFactory
from ..submission.submission import Submission
from ..task._task_manager_factory import _TaskManagerFactory
from ..task.task import Task, TaskId
from .sequence import Sequence
Expand Down Expand Up @@ -309,7 +310,8 @@ def _submit(
wait: bool = False,
timeout: Optional[Union[float, int]] = None,
check_inputs_are_ready: bool = True,
) -> List[Job]:
**properties,
) -> Submission:
sequence_id = sequence.id if isinstance(sequence, Sequence) else sequence
sequence = cls._get(sequence_id)
if sequence is None:
Expand All @@ -319,13 +321,20 @@ def _submit(
if check_inputs_are_ready:
_warn_if_inputs_not_ready(sequence.get_inputs())

jobs = (
submission = (
_TaskManagerFactory._build_manager()
._orchestrator()
.submit(sequence, callbacks=sequence_subscription_callback, force=force, wait=wait, timeout=timeout)
.submit(
sequence,
callbacks=sequence_subscription_callback,
force=force,
wait=wait,
timeout=timeout,
**properties,
)
)
Notifier.publish(_make_event(sequence, EventOperation.SUBMISSION))
return jobs
return submission

@classmethod
def _exists(cls, entity_id: str) -> bool:
Expand Down
9 changes: 6 additions & 3 deletions taipy/core/sequence/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from ..exceptions.exceptions import NonExistingTask
from ..job.job import Job
from ..notification.event import Event, EventEntityType, EventOperation, _make_event
from ..submission.submission import Submission
from ..task.task import Task
from ..task.task_id import TaskId
from .sequence_id import SequenceId
Expand Down Expand Up @@ -225,7 +226,8 @@ def submit(
force: bool = False,
wait: bool = False,
timeout: Optional[Union[float, int]] = None,
) -> List[Job]:
**properties,
) -> Submission:
"""Submit the sequence for execution.
All the `Task^`s of the sequence will be submitted for execution.
Expand All @@ -238,12 +240,13 @@ def submit(
in asynchronous mode.
timeout (Union[float, int]): The maximum number of seconds to wait for the jobs to be finished before
returning.
**properties (dict[str, any]): A keyworded variable length list of additional arguments.
Returns:
A list of created `Job^`s.
A `Submission^` containing the information of the submission.
"""
from ._sequence_manager_factory import _SequenceManagerFactory

return _SequenceManagerFactory._build_manager()._submit(self, callbacks, force, wait, timeout)
return _SequenceManagerFactory._build_manager()._submit(self, callbacks, force, wait, timeout, **properties)

def get_label(self) -> str:
"""Returns the sequence simple label prefixed by its owner label.
Expand Down
2 changes: 2 additions & 0 deletions taipy/core/submission/_submission_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def _entity_to_model(cls, submission: Submission) -> _SubmissionModel:
entity_type=submission.entity_type,
entity_config_id=submission._entity_config_id,
job_ids=[job.id if isinstance(job, Job) else JobId(str(job)) for job in list(submission._jobs)],
properties=submission._properties.data.copy(),
creation_date=submission._creation_date.isoformat(),
submission_status=submission._submission_status,
version=submission._version,
Expand All @@ -40,6 +41,7 @@ def _model_to_entity(cls, model: _SubmissionModel) -> Submission:
entity_config_id=model.entity_config_id,
id=SubmissionId(model.id),
jobs=model.job_ids,
properties=model.properties,
creation_date=datetime.fromisoformat(model.creation_date),
submission_status=model.submission_status,
version=model.version,
Expand Down
Loading

0 comments on commit 7fa780c

Please sign in to comment.