Skip to content

Commit

Permalink
feat: add factory support to resource sequences
Browse files Browse the repository at this point in the history
  • Loading branch information
tdstein committed Dec 17, 2024
1 parent f4c5445 commit 39e1eb0
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 26 deletions.
7 changes: 5 additions & 2 deletions src/posit/connect/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from requests import Response, Session
from typing_extensions import TYPE_CHECKING, overload

from posit.connect.environments import Environment
from posit.connect.packages import Package

from . import hooks, me
from .auth import Auth
from .config import Config
Expand Down Expand Up @@ -298,7 +301,7 @@ def oauth(self) -> OAuth:
@property
@requires(version="2024.11.0")
def packages(self) -> Packages:
return _PaginatedResourceSequence(self._ctx, "v1/packages", uid="name")
return _PaginatedResourceSequence[Package](self._ctx, "v1/packages", uid="name")

@property
def vanities(self) -> Vanities:
Expand All @@ -311,7 +314,7 @@ def system(self) -> System:
@property
@requires(version="2023.05.0")
def environments(self) -> Environments:
return _ResourceSequence(self._ctx, "v1/environments")
return _ResourceSequence[Environment](self._ctx, "v1/environments")

def __del__(self):
"""Close the session when the Client instance is deleted."""
Expand Down
52 changes: 28 additions & 24 deletions src/posit/connect/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import posixpath
import warnings
from abc import ABC
from typing import ItemsView, cast
from typing import ItemsView, Type, cast

from typing_extensions import (
TYPE_CHECKING,
Expand Down Expand Up @@ -92,49 +92,50 @@ def update(self, **attributes): # type: ignore[reportIncompatibleMethodOverride
super().update(**result)


T = TypeVar("T", bound=Resource)
_T = TypeVar("_T", bound=Resource)
_T_co = TypeVar("_T_co", bound=Resource, covariant=True)


class ResourceFactory(Protocol):
def __call__(self, ctx: Context, path: str, **attributes) -> Resource: ...
class ResourceFactory(Protocol[_T_co]):
def __call__(self, ctx: Context, path: str, **attributes: Any) -> _T_co: ...


class ResourceSequence(Protocol[T]):
class ResourceSequence(Protocol[_T]):
@overload
def __getitem__(self, index: SupportsIndex, /) -> T: ...
def __getitem__(self, index: SupportsIndex, /) -> _T: ...

@overload
def __getitem__(self, index: slice, /) -> List[T]: ...
def __getitem__(self, index: slice, /) -> List[_T]: ...

def __len__(self) -> int: ...

def __iter__(self) -> Iterator[T]: ...
def __iter__(self) -> Iterator[_T]: ...

def __str__(self) -> str: ...

def __repr__(self) -> str: ...


class _ResourceSequence(Sequence[T], ResourceSequence[T]):
class _ResourceSequence(Sequence[_T], ResourceSequence[_T]):
def __init__(
self,
ctx: Context,
path: str,
factory: ResourceFactory = _Resource,
factory: ResourceFactory[_T] | None = None,
uid: str = "guid",
):
self._ctx = ctx
self._path = path
self._uid = uid
self._factory = factory
self._factory = factory or cast(ResourceFactory[_T], _Resource)

def __getitem__(self, index):
return list(self.fetch())[index]

def __len__(self) -> int:
return len(list(self.fetch()))

def __iter__(self) -> Iterator[T]:
def __iter__(self) -> Iterator[_T]:
return iter(self.fetch())

def __str__(self) -> str:
Expand All @@ -143,32 +144,34 @@ def __str__(self) -> str:
def __repr__(self) -> str:
return repr(self.fetch())

def create(self, **attributes: Any) -> T:
def create(self, **attributes: Any) -> _T:
response = self._ctx.client.post(self._path, json=attributes)
result = response.json()
uid = result[self._uid]
path = posixpath.join(self._path, uid)
return cast(T, self._factory(self._ctx, path, **result))
resource = self._factory(self._ctx, path, **result)
return resource

def fetch(self, **conditions) -> Iterable[T]:
def fetch(self, **conditions: Any) -> Iterable[_T]:
response = self._ctx.client.get(self._path, params=conditions)
results = response.json()
resources: List[T] = []
resources: List[_T] = []
for result in results:
uid = result[self._uid]
path = posixpath.join(self._path, uid)
resource = cast(T, self._factory(self._ctx, path, **result))
resource = self._factory(self._ctx, path, **result)
resources.append(resource)

return resources

def find(self, *args: str) -> T:
def find(self, *args: str) -> _T:
path = posixpath.join(self._path, *args)
response = self._ctx.client.get(path)
result = response.json()
return cast(T, self._factory(self._ctx, path, **result))
resource = self._factory(self._ctx, path, **result)
return resource

def find_by(self, **conditions) -> T | None:
def find_by(self, **conditions: Any) -> _T | None:
"""
Find the first record matching the specified conditions.
Expand All @@ -183,19 +186,20 @@ def find_by(self, **conditions) -> T | None:
Optional[T]
The first record matching the conditions, or `None` if no match is found.
"""
collection: Iterable[T] = self.fetch(**conditions)
collection = self.fetch(**conditions)
return next((v for v in collection if v.items() >= conditions.items()), None)


class _PaginatedResourceSequence(_ResourceSequence[T]):
def fetch(self, **conditions) -> Iterator[T]:
class _PaginatedResourceSequence(_ResourceSequence[_T]):
def fetch(self, **conditions: Any) -> Iterable[_T]:
paginator = Paginator(self._ctx, self._path, dict(**conditions))
for page in paginator.fetch_pages():
resources = []
results = page.results
for result in results:
uid = result[self._uid]
path = posixpath.join(self._path, uid)
resource = cast(T, self._factory(self._ctx, path, **result))
resource = self._factory(self._ctx, path, **result)

resources.append(resource)
yield from resources

0 comments on commit 39e1eb0

Please sign in to comment.