Skip to content

Commit

Permalink
support async function
Browse files Browse the repository at this point in the history
  • Loading branch information
Nanguage committed Jan 1, 2025
1 parent 669d71f commit a58906f
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 5 deletions.
2 changes: 1 addition & 1 deletion executor/engine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .core import Engine, EngineSetting
from .job import LocalJob, ThreadJob, ProcessJob

__version__ = '0.2.7'
__version__ = '0.2.8'

__all__ = [
'Engine', 'EngineSetting',
Expand Down
5 changes: 4 additions & 1 deletion executor/engine/job/dask.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import functools
from inspect import iscoroutinefunction

from dask.distributed import Client, LocalCluster

from .base import Job
from .utils import create_generator_wrapper
from .utils import create_generator_wrapper, run_async_func
from ..utils import PortManager


Expand Down Expand Up @@ -58,6 +59,8 @@ async def run_function(self):
"""Run job with Dask."""
client = self.engine.dask_client
func = functools.partial(self.func, *self.args, **self.kwargs)
if iscoroutinefunction(func):
func = functools.partial(run_async_func, func)
fut = client.submit(func)
self._executor = fut
result = await fut
Expand Down
7 changes: 6 additions & 1 deletion executor/engine/job/local.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from inspect import iscoroutinefunction

from .base import Job
from .utils import create_generator_wrapper


class LocalJob(Job):
async def run_function(self):
"""Run job in local thread."""
res = self.func(*self.args, **self.kwargs)
if iscoroutinefunction(self.func):
res = await self.func(*self.args, **self.kwargs)
else:
res = self.func(*self.args, **self.kwargs)
return res

async def run_generator(self):
Expand Down
7 changes: 6 additions & 1 deletion executor/engine/job/process.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import asyncio
import functools
from inspect import iscoroutinefunction

from loky.process_executor import ProcessPoolExecutor

from .base import Job
from .utils import _gen_initializer, create_generator_wrapper
from .utils import (
_gen_initializer, create_generator_wrapper, run_async_func
)


class ProcessJob(Job):
Expand Down Expand Up @@ -45,6 +48,8 @@ def release_resource(self) -> bool:
async def run_function(self):
"""Run job in process pool."""
func = functools.partial(self.func, *self.args, **self.kwargs)
if iscoroutinefunction(func):
func = functools.partial(run_async_func, func)
self._executor = ProcessPoolExecutor(1)
loop = asyncio.get_running_loop()
fut = loop.run_in_executor(self._executor, func)
Expand Down
7 changes: 6 additions & 1 deletion executor/engine/job/thread.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import asyncio
import functools
from inspect import iscoroutinefunction
from concurrent.futures import ThreadPoolExecutor

from .base import Job
from .utils import _gen_initializer, create_generator_wrapper
from .utils import (
_gen_initializer, create_generator_wrapper, run_async_func
)


class ThreadJob(Job):
Expand Down Expand Up @@ -44,6 +47,8 @@ def release_resource(self) -> bool:
async def run_function(self):
"""Run job in thread pool."""
func = functools.partial(self.func, *self.args, **self.kwargs)
if iscoroutinefunction(func):
func = functools.partial(run_async_func, func)
self._executor = ThreadPoolExecutor(1)
loop = asyncio.get_running_loop()
fut = loop.run_in_executor(self._executor, func)
Expand Down
4 changes: 4 additions & 0 deletions executor/engine/job/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,7 @@ def create_generator_wrapper(
return AsyncGeneratorWrapper(job, fut)
else:
return SyncGeneratorWrapper(job, fut)


def run_async_func(func, *args, **kwargs):
return asyncio.run(func(*args, **kwargs))
21 changes: 21 additions & 0 deletions tests/test_dask_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,24 @@ async def gen():
assert x == i
i += 1
assert job.status == "done"


@pytest.mark.asyncio
async def test_dask_async_func():
port = PortManager.find_free_port()
cluster = LocalCluster(
dashboard_address=f":{port}",
asynchronous=True,
processes=False,
)
client = Client(cluster)
engine = Engine()
engine.dask_client = client

async def async_func(x):
return x + 1

job = DaskJob(async_func, (1,))
await engine.submit_async(job)
await job.wait_until_status("done")
assert job.result() == 2
13 changes: 13 additions & 0 deletions tests/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,3 +359,16 @@ async def gen_async():
assert await g.asend(2) == 3
with pytest.raises(StopAsyncIteration):
await g.asend(3)


@pytest.mark.asyncio
async def test_async_func_job():
with Engine() as engine:
async def async_func(x):
return x + 1

for job_cls in [LocalJob, ThreadJob, ProcessJob]:
job = job_cls(async_func, (1,))
await engine.submit_async(job)
await job.wait_until_status("done")
assert job.result() == 2

0 comments on commit a58906f

Please sign in to comment.