Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Async API #598

Merged
merged 20 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion codegen/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ In some cases we may want to deviate from the WebGPU API, because well ... Pytho
Other changes include:

* Where in JS the input args are provided via a dict, we use kwargs directly. Nevertheless, some input args have subdicts (and sub-sub-dicts)
* For methods that are async in IDL, we also provide sync methods. The Async method names have an "_async" suffix.
* For methods that are async in JavaScript (i.e return a `Promise`), we provide both an asynchronous and synchronous variant, indicated by an `_async` and `_sync` suffix.

### Codegen summary

Expand Down
168 changes: 134 additions & 34 deletions codegen/apipatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def patch_properties(self, classname, i1, i2):
elif "@apidiff.hide" in pre_lines:
pass # continue as normal
old_line = self.lines[j1]
new_line = f" def {propname}(self):"
new_line = self.get_property_def(classname, propname)
if old_line != new_line:
fixme_line = " # FIXME: was " + old_line.split("def ", 1)[-1]
lines = [fixme_line, new_line]
Expand Down Expand Up @@ -241,7 +241,7 @@ def get_missing_properties(self, classname, seen_props):
if propname not in seen_props:
lines.append(" # FIXME: new prop to implement")
lines.append(" @property")
lines.append(f" def {propname}(self):")
lines.append(self.get_property_def(classname, propname))
lines.append(" raise NotImplementedError()")
lines.append("")
return lines
Expand All @@ -265,16 +265,105 @@ class IdlPatcherMixin:
def __init__(self):
super().__init__()
self.idl = get_idl_parser()
self.detect_async_props_and_methods()

def detect_async_props_and_methods(self):

self.async_idl_names = async_idl_names = {} # (sync-name, async-name)

for classname, interface in self.idl.classes.items():
for namedict in [interface.attributes, interface.functions]:
for name_idl, idl_line in namedict.items():
idl_result = idl_line.split(name_idl)[0]
if "Promise" in idl_result:
# We found an async property or method.
name_idl_base = name_idl
if name_idl.endswith("Async"):
name_idl_base = name_idl[:-5]
key = classname, name_idl_base
# Now we determine the kind
if name_idl_base != name_idl and name_idl_base in namedict:
# Has both
async_idl_names[key] = name_idl_base, name_idl
else:
# Only has async
async_idl_names[key] = None, name_idl

def get_idl_name_variants(self, classname, base_name):
"""Returns the names of an idl prop/method for its sync and async variant.
Either can be None.
"""
# Must be a base name, without the suffix
assert not base_name.lower().endswith(("sync", "async"))

key = classname, base_name
default = base_name, None
return self.async_idl_names.get(key, default)

def name2idl(self, classname, name_py):
"""Map a python propname/methodname to the idl variant.
Take async into account.
"""
if name_py == "__init__":
return "constructor"

# Get idl base name
if name_py.endswith(("_sync", "_async")):
name_idl_base = to_camel_case(name_py.rsplit("_", 1)[0])
else:
name_idl_base = to_camel_case(name_py)

def name2idl(self, name):
m = {"__init__": "constructor"}
name = m.get(name, name)
return to_camel_case(name)
# Get idl variant names
idl_sync, idl_async = self.get_idl_name_variants(classname, name_idl_base)

def name2py(self, name):
m = {"constructor": "__init__"}
name = m.get(name, name)
return to_snake_case(name)
# Triage
if idl_sync and idl_async:
if name_py.endswith("_async"):
return idl_async
elif name_py.endswith("_sync"):
return name_idl_base + "InvalidVariant"
else:
return idl_sync
elif idl_async:
if name_py.endswith("_async"):
return idl_async
elif name_py.endswith("_sync"):
return idl_async
else:
return name_idl_base + "InvalidVariant"
else: # idl_sync only
if name_py.endswith("_async"):
return name_idl_base + "InvalidVariant"
elif name_py.endswith("_sync"):
return name_idl_base + "InvalidVariant"
else:
return idl_sync

def name2py_names(self, classname, name_idl):
"""Map a idl propname/methodname to the python variants.
Take async into account. Returns a list with one or two names;
for async props/methods Python has the sync and the async variant.
"""

if name_idl == "constructor":
return ["__init__"]

# Get idl base name
name_idl_base = name_idl
if name_idl.endswith("Async"):
name_idl_base = name_idl[:-5]
name_py_base = to_snake_case(name_idl_base)

# Get idl variant names
idl_sync, idl_async = self.get_idl_name_variants(classname, name_idl_base)

if idl_sync and idl_async:
return [to_snake_case(idl_sync), name_py_base + "_async"]
elif idl_async:
return [name_py_base + "_sync", name_py_base + "_async"]
else:
assert idl_sync == name_idl_base
return [name_py_base]

def class_is_known(self, classname):
return classname in self.idl.classes
Expand All @@ -295,22 +384,28 @@ def get_class_def(self, classname):
bases = "" if not bases else f"({', '.join(bases)})"
return f"class {classname}{bases}:"

def get_property_def(self, classname, propname):
attributes = self.idl.classes[classname].attributes
name_idl = self.name2idl(classname, propname)
assert name_idl in attributes

line = "def " + to_snake_case(propname) + "(self):"
if propname.endswith("_async"):
line = "async " + line
return " " + line

def get_method_def(self, classname, methodname):
# Get the corresponding IDL line
functions = self.idl.classes[classname].functions
name_idl = self.name2idl(methodname)
if methodname.endswith("_async") and name_idl not in functions:
name_idl = self.name2idl(methodname.replace("_async", ""))
elif name_idl not in functions and name_idl + "Async" in functions:
name_idl += "Async"
idl_line = functions[name_idl]
name_idl = self.name2idl(classname, methodname)
assert name_idl in functions

# Construct preamble
preamble = "def " + to_snake_case(methodname) + "("
if "async" in methodname:
if methodname.endswith("_async"):
preamble = "async " + preamble

# Get arg names and types
idl_line = functions[name_idl]
args = idl_line.split("(", 1)[1].split(")", 1)[0].split(",")
args = [arg.strip() for arg in args if arg.strip()]
raw_defaults = [arg.partition("=")[2].strip() for arg in args]
Expand Down Expand Up @@ -361,28 +456,31 @@ def _arg_from_struct_field(self, field):
return result

def prop_is_known(self, classname, propname):
propname_idl = self.name2idl(propname)
return propname_idl in self.idl.classes[classname].attributes
attributes = self.idl.classes[classname].attributes
propname_idl = self.name2idl(classname, propname)
return propname_idl if propname_idl in attributes else None

def method_is_known(self, classname, methodname):
functions = self.idl.classes[classname].functions
name_idl = self.name2idl(methodname)
if "_async" in methodname and name_idl not in functions:
name_idl = self.name2idl(methodname.replace("_async", ""))
elif name_idl not in functions and name_idl + "Async" in functions:
name_idl += "Async"
return name_idl if name_idl in functions else None
methodname_idl = self.name2idl(classname, methodname)
return methodname_idl if methodname_idl in functions else None

def get_class_names(self):
return list(self.idl.classes.keys())

def get_required_prop_names(self, classname):
propnames_idl = self.idl.classes[classname].attributes.keys()
return [self.name2py(x) for x in propnames_idl]
attributes = self.idl.classes[classname].attributes
names = []
for name_idl in attributes.keys():
names.extend(self.name2py_names(classname, name_idl))
return names

def get_required_method_names(self, classname):
methodnames_idl = self.idl.classes[classname].functions.keys()
return [self.name2py(x) for x in methodnames_idl]
functions = self.idl.classes[classname].functions
names = []
for name_idl in functions.keys():
names.extend(self.name2py_names(classname, name_idl))
return names


class BaseApiPatcher(IdlPatcherMixin, AbstractApiPatcher):
Expand All @@ -398,14 +496,16 @@ def get_class_comment(self, classname):
return None

def get_prop_comment(self, classname, propname):
if self.prop_is_known(classname, propname):
propname_idl = self.name2idl(propname)
return " # IDL: " + self.idl.classes[classname].attributes[propname_idl]
attributes = self.idl.classes[classname].attributes
name_idl = self.prop_is_known(classname, propname)
if name_idl:
return " # IDL: " + attributes[name_idl]

def get_method_comment(self, classname, methodname):
functions = self.idl.classes[classname].functions
name_idl = self.method_is_known(classname, methodname)
if name_idl:
return " # IDL: " + self.idl.classes[classname].functions[name_idl]
return " # IDL: " + functions[name_idl]


class BackendApiPatcher(AbstractApiPatcher):
Expand Down
3 changes: 1 addition & 2 deletions codegen/idlparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ class IdlParser:
* enums: a dict mapping the (Pythonic) enum name to a dict of field-value pairs.
* structs: a dict mapping the (Pythonic) struct name to a dict of StructField
objects.
* functions: a dict mapping the (normalized) func name to the line defining the
function.
* classes: a dict mapping the (normalized) class name an Interface object.

"""

Expand Down
56 changes: 55 additions & 1 deletion codegen/tests/test_codegen_apipatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""

from codegen.utils import blacken
from codegen.apipatcher import CommentRemover, AbstractCommentInjector
from codegen.apipatcher import CommentRemover, AbstractCommentInjector, IdlPatcherMixin


def dedent(code):
Expand Down Expand Up @@ -110,6 +110,60 @@ def eggs(self):
assert code2 == code3


def test_async_api_logic():

class Object(object):
pass

class OtherIdlPatcherMixin(IdlPatcherMixin):
def __init__(self):
cls = Object()
cls.attributes = {
"prop1": "x prop1 bla",
"prop2": "Promise<x> prop2 bla",
}
cls.functions = {
"method1": "x method1 bla",
"method2": "Promise<x> method2 bla",
"method3Async": "Promise<x> method3 bla",
"method3": "x method3 bla",
}

self.idl = Object()
self.idl.classes = {"Foo": cls}

patcher = OtherIdlPatcherMixin()
patcher.detect_async_props_and_methods()

# Normal prop
assert patcher.name2idl("Foo", "prop1") == "prop1"
assert patcher.name2idl("Foo", "prop1_sync") == "prop1InvalidVariant"
assert patcher.name2idl("Foo", "prop1_async") == "prop1InvalidVariant"

# Unknow prop, name still works
assert patcher.name2idl("Foo", "prop_unknown") == "propUnknown"

# Async prop
assert patcher.name2idl("Foo", "prop2_async") == "prop2"
assert patcher.name2idl("Foo", "prop2_sync") == "prop2"
assert patcher.name2idl("Foo", "prop2") == "prop2InvalidVariant"

# Normal method
assert patcher.name2idl("Foo", "method1") == "method1"
assert patcher.name2idl("Foo", "method1_sync") == "method1InvalidVariant"
assert patcher.name2idl("Foo", "method1_async") == "method1InvalidVariant"

# Async method
assert patcher.name2idl("Foo", "method2_async") == "method2"
assert patcher.name2idl("Foo", "method2_sync") == "method2"
assert patcher.name2idl("Foo", "method2") == "method2InvalidVariant"

# Async method that also has sync variant in JS
assert patcher.name2idl("Foo", "method3_async") == "method3Async"
assert patcher.name2idl("Foo", "method3") == "method3"
assert patcher.name2idl("Foo", "method3_sync") == "method3InvalidVariant"


if __name__ == "__main__":
for func in list(globals().values()):
if callable(func) and func.__name__.startswith("test_"):
Expand Down
18 changes: 18 additions & 0 deletions codegen/tests/test_codegen_result.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
""" Test some aspects of the generated code.
"""

from codegen.files import read_file


def test_async_methods_and_props():
# Test that only and all aync methods are suffixed with '_async'

for fname in ["_classes.py", "backends/wgpu_native/_api.py"]:
code = read_file(fname)
for line in code.splitlines():
line = line.strip()
if line.startswith("def "):
assert not line.endswith("_async"), line
elif line.startswith("async def "):
name = line.split("def", 1)[1].split("(")[0].strip()
assert name.endswith("_async"), line
4 changes: 2 additions & 2 deletions docs/backends.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ It also works out of the box, because the wgpu-native DLL is shipped with wgpu-p

The wgpu_native backend provides a few extra functionalities:

.. py:function:: wgpu.backends.wgpu_native.request_device(adapter, trace_path, *, label="", required_features, required_limits, default_queue)
.. py:function:: wgpu.backends.wgpu_native.request_device_sync(adapter, trace_path, *, label="", required_features, required_limits, default_queue)

An alternative to :func:`wgpu.GPUAdapter.request_adapter`, that streams a trace
of all low level calls to disk, so the visualization can be replayed (also on other systems),
Expand Down Expand Up @@ -88,7 +88,7 @@ You must tell the adapter to create a device that supports push constants,
and you must tell it the number of bytes of push constants that you are using.
Overestimating is okay::

device = adapter.request_device(
device = adapter.request_device_sync(
required_features=["push-constants"],
required_limits={"max-push-constant-size": 256},
)
Expand Down
6 changes: 3 additions & 3 deletions docs/guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ you can obtain a device.

.. code-block:: py

adapter = wgpu.gpu.request_adapter(power_preference="high-performance")
device = adapter.request_device()
adapter = wgpu.gpu.request_adapter_sync(power_preference="high-performance")
device = adapter.request_device_sync()

The ``wgpu.gpu`` object is the API entrypoint (:class:`wgpu.GPU`). It contains just a handful of functions,
including ``request_adapter()``. The device is used to create most other GPU objects.
Expand Down Expand Up @@ -232,7 +232,7 @@ You can run your application via RenderDoc, which is able to capture a
frame, including all API calls, objects and the complete pipeline state,
and display all of that information within a nice UI.

You can use ``adapter.request_device()`` to provide a directory path
You can use ``adapter.request_device_sync()`` to provide a directory path
where a trace of all API calls will be written. This trace can then be used
to re-play your use-case elsewhere (it's cross-platform).

Expand Down
2 changes: 1 addition & 1 deletion docs/start.rst
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ You can verify whether the `"DiscreteGPU"` adapters are found:
import wgpu
import pprint

for a in wgpu.gpu.enumerate_adapters():
for a in wgpu.gpu.enumerate_adapters_sync():
pprint.pprint(a.info)

If you are using a remote frame buffer via `jupyter-rfb <https://github.com/vispy/jupyter_rfb>`_ we also recommend installing the following for optimal performance:
Expand Down
Loading