Skip to content

Commit

Permalink
Allow specifying custom auth in resources (#2082)
Browse files Browse the repository at this point in the history
  • Loading branch information
joscha authored Nov 27, 2024
1 parent d9cdc6c commit aa80667
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 2 deletions.
13 changes: 11 additions & 2 deletions dlt/sources/rest_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def create_resources(
client = RESTClient(
base_url=client_config["base_url"],
headers=client_config.get("headers"),
auth=create_auth(client_config.get("auth")),
auth=create_auth(endpoint_config.get("auth", client_config.get("auth"))),
paginator=create_paginator(client_config.get("paginator")),
session=client_config.get("session"),
)
Expand Down Expand Up @@ -409,7 +409,16 @@ def _validate_config(config: RESTAPIConfig) -> None:
if client_config:
auth = client_config.get("auth")
if auth:
auth = _mask_secrets(auth)
_mask_secrets(auth)
resources = c.get("resources", [])
for resource in resources:
if isinstance(resource, (str, DltResource)):
continue
if endpoint := resource.get("endpoint"):
if not isinstance(endpoint, str):
auth = endpoint.get("auth")
if auth:
_mask_secrets(auth)

validate_dict(RESTAPIConfig, c, path=".")

Expand Down
1 change: 1 addition & 0 deletions dlt/sources/rest_api/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ class Endpoint(TypedDict, total=False):
data_selector: Optional[jsonpath.TJsonPath]
response_actions: Optional[List[ResponseAction]]
incremental: Optional[IncrementalConfig]
auth: Optional[AuthConfig]


class ProcessingSteps(TypedDict):
Expand Down
26 changes: 26 additions & 0 deletions docs/website/docs/dlt-ecosystem/verified-sources/rest_api/basic.md
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,32 @@ A resource configuration is used to define a [dlt resource](../../../general-usa
- `include_from_parent`: A list of fields from the parent resource to be included in the resource output. See the [resource relationships](#include-fields-from-the-parent-resource) section for more details.
- `processing_steps`: A list of [processing steps](#processing-steps-filter-and-transform-data) to filter and transform the data.
- `selected`: A flag to indicate if the resource is selected for loading. This could be useful when you want to load data only from child resources and not from the parent resource.
- `auth`: An optional `AuthConfig` instance. If passed, is used over the one defined in the [client](#client) definition. Example:
```py
from dlt.sources.helpers.rest_client.auth import HttpBasicAuth
config = {
"client": {
"auth": {
"type": "bearer",
"token": dlt.secrets["your_api_token"],
}
},
"resources": [
"resource-using-bearer-auth",
{
"name": "my-resource-with-special-auth",
"endpoint": {
# ...
"auth": HttpBasicAuth("user", dlt.secrets["your_basic_auth_password"])
},
# ...
}
]
# ...
}
```
This would use `Bearer` auth as defined in the `client` for `resource-using-bearer-auth` and `Http Basic` auth for `my-resource-with-special-auth`.
You can also pass additional resource parameters that will be used to configure the dlt resource. See [dlt resource API reference](../../../api_reference/extract/decorators#resource) for more details.
Expand Down
12 changes: 12 additions & 0 deletions tests/sources/rest_api/configurations/source_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,18 @@ def repositories():
repositories(),
],
},
{
"client": {"base_url": "https://github.com/api/v2"},
"resources": [
{
"name": "issues",
"endpoint": {
"path": "user/repos",
"auth": HttpBasicAuth("", "BASIC_AUTH_TOKEN"),
},
}
],
},
]


Expand Down
41 changes: 41 additions & 0 deletions tests/sources/rest_api/integration/test_response_actions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import pytest
from dlt.common import json
from dlt.sources.helpers.requests import Response
Expand Down Expand Up @@ -316,3 +317,43 @@ def add_field(response: Response, *args, **kwargs) -> Response:
mock_response_hook_2.assert_called_once()

assert all(record["custom_field"] == "foobar" for record in data)


def test_auth_overwrites_for_specific_endpoints(mock_api_server, mocker):
def custom_hook(response: Response, *args, **kwargs) -> Response:
assert (
response.request.headers["Authorization"]
== f"Basic {base64.b64encode(b'U:P').decode('ascii')}"
)
return response

mock_response_hook = mocker.Mock(side_effect=custom_hook)
mock_source = rest_api_source(
{
"client": {
"base_url": "https://api.example.com",
"auth": {
"type": "bearer",
"token": "T",
},
},
"resources": [
{
"name": "posts",
"endpoint": {
"auth": {
"type": "http_basic",
"username": "U",
"password": "P",
},
"response_actions": [
mock_response_hook,
],
},
},
],
}
)

list(mock_source.with_resources("posts").add_limit(1))
mock_response_hook.assert_called_once()

0 comments on commit aa80667

Please sign in to comment.