diff --git a/src/posit/connect/client.py b/src/posit/connect/client.py index 6ef22c21..10b02c45 100644 --- a/src/posit/connect/client.py +++ b/src/posit/connect/client.py @@ -5,6 +5,9 @@ from requests import Response, Session from typing_extensions import TYPE_CHECKING, overload +from posit.connect.environments import Environment +from posit.connect.packages import Package + from . import hooks, me from .auth import Auth from .config import Config @@ -298,7 +301,7 @@ def oauth(self) -> OAuth: @property @requires(version="2024.11.0") def packages(self) -> Packages: - return _PaginatedResourceSequence(self._ctx, "v1/packages", uid="name") + return _PaginatedResourceSequence[Package](self._ctx, "v1/packages", uid="name") @property def vanities(self) -> Vanities: @@ -311,7 +314,7 @@ def system(self) -> System: @property @requires(version="2023.05.0") def environments(self) -> Environments: - return _ResourceSequence(self._ctx, "v1/environments") + return _ResourceSequence[Environment](self._ctx, "v1/environments") def __del__(self): """Close the session when the Client instance is deleted.""" diff --git a/src/posit/connect/resources.py b/src/posit/connect/resources.py index 42a59a77..50811bf4 100644 --- a/src/posit/connect/resources.py +++ b/src/posit/connect/resources.py @@ -3,7 +3,7 @@ import posixpath import warnings from abc import ABC -from typing import ItemsView, cast +from typing import ItemsView, Type, cast from typing_extensions import ( TYPE_CHECKING, @@ -92,41 +92,42 @@ def update(self, **attributes): # type: ignore[reportIncompatibleMethodOverride super().update(**result) -T = TypeVar("T", bound=Resource) +_T = TypeVar("_T", bound=Resource) +_T_co = TypeVar("_T_co", bound=Resource, covariant=True) -class ResourceFactory(Protocol): - def __call__(self, ctx: Context, path: str, **attributes) -> Resource: ... +class ResourceFactory(Protocol[_T_co]): + def __call__(self, ctx: Context, path: str, **attributes: Any) -> _T_co: ... -class ResourceSequence(Protocol[T]): +class ResourceSequence(Protocol[_T]): @overload - def __getitem__(self, index: SupportsIndex, /) -> T: ... + def __getitem__(self, index: SupportsIndex, /) -> _T: ... @overload - def __getitem__(self, index: slice, /) -> List[T]: ... + def __getitem__(self, index: slice, /) -> List[_T]: ... def __len__(self) -> int: ... - def __iter__(self) -> Iterator[T]: ... + def __iter__(self) -> Iterator[_T]: ... def __str__(self) -> str: ... def __repr__(self) -> str: ... -class _ResourceSequence(Sequence[T], ResourceSequence[T]): +class _ResourceSequence(Sequence[_T], ResourceSequence[_T]): def __init__( self, ctx: Context, path: str, - factory: ResourceFactory = _Resource, + factory: ResourceFactory[_T] | None = None, uid: str = "guid", ): self._ctx = ctx self._path = path self._uid = uid - self._factory = factory + self._factory = factory or cast(ResourceFactory[_T], _Resource) def __getitem__(self, index): return list(self.fetch())[index] @@ -134,7 +135,7 @@ def __getitem__(self, index): def __len__(self) -> int: return len(list(self.fetch())) - def __iter__(self) -> Iterator[T]: + def __iter__(self) -> Iterator[_T]: return iter(self.fetch()) def __str__(self) -> str: @@ -143,32 +144,34 @@ def __str__(self) -> str: def __repr__(self) -> str: return repr(self.fetch()) - def create(self, **attributes: Any) -> T: + def create(self, **attributes: Any) -> _T: response = self._ctx.client.post(self._path, json=attributes) result = response.json() uid = result[self._uid] path = posixpath.join(self._path, uid) - return cast(T, self._factory(self._ctx, path, **result)) + resource = self._factory(self._ctx, path, **result) + return resource - def fetch(self, **conditions) -> Iterable[T]: + def fetch(self, **conditions: Any) -> Iterable[_T]: response = self._ctx.client.get(self._path, params=conditions) results = response.json() - resources: List[T] = [] + resources: List[_T] = [] for result in results: uid = result[self._uid] path = posixpath.join(self._path, uid) - resource = cast(T, self._factory(self._ctx, path, **result)) + resource = self._factory(self._ctx, path, **result) resources.append(resource) return resources - def find(self, *args: str) -> T: + def find(self, *args: str) -> _T: path = posixpath.join(self._path, *args) response = self._ctx.client.get(path) result = response.json() - return cast(T, self._factory(self._ctx, path, **result)) + resource = self._factory(self._ctx, path, **result) + return resource - def find_by(self, **conditions) -> T | None: + def find_by(self, **conditions: Any) -> _T | None: """ Find the first record matching the specified conditions. @@ -183,12 +186,12 @@ def find_by(self, **conditions) -> T | None: Optional[T] The first record matching the conditions, or `None` if no match is found. """ - collection: Iterable[T] = self.fetch(**conditions) + collection = self.fetch(**conditions) return next((v for v in collection if v.items() >= conditions.items()), None) -class _PaginatedResourceSequence(_ResourceSequence[T]): - def fetch(self, **conditions) -> Iterator[T]: +class _PaginatedResourceSequence(_ResourceSequence[_T]): + def fetch(self, **conditions: Any) -> Iterable[_T]: paginator = Paginator(self._ctx, self._path, dict(**conditions)) for page in paginator.fetch_pages(): resources = [] @@ -196,6 +199,7 @@ def fetch(self, **conditions) -> Iterator[T]: for result in results: uid = result[self._uid] path = posixpath.join(self._path, uid) - resource = cast(T, self._factory(self._ctx, path, **result)) + resource = self._factory(self._ctx, path, **result) + resources.append(resource) yield from resources