Merge pull request #644 from bhearsum/wildcard
feat: add support for deferring upstream artifact selection to runtime
bhearsum authored Jun 24, 2024
2 parents fea42c0 + a3997f2 commit 5d5de10
Showing 6 changed files with 155 additions and 11 deletions.
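In short: an entry in a task's `upstreamArtifacts` payload can now carry a glob pattern instead of a concrete path, and the matching artifact names are resolved at runtime. A hypothetical payload entry (task id and filenames invented for illustration); per the new `verify_task` check in `context.py` below, globbed paths must be marked optional:

```python
# Hypothetical upstreamArtifacts entry exercising the new wildcard support.
# Globbed paths must be optional, per the verify_task() check in context.py.
upstream_artifacts = [
    {
        "taskId": "dependency1",  # invented task id
        "taskType": "build",
        "paths": ["*.log"],  # expanded at runtime instead of task-definition time
        "optional": True,
    }
]
```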
30 changes: 28 additions & 2 deletions src/scriptworker/artifacts.py
@@ -6,6 +6,7 @@
 """

 import asyncio
+import fnmatch
 import gzip
 import logging
 import mimetypes
@@ -15,6 +16,7 @@
 import aiohttp
 import arrow
 import async_timeout
+from taskcluster.exceptions import TaskclusterFailure

 from scriptworker.client import validate_artifact_url
 from scriptworker.exceptions import DownloadError, ScriptWorkerRetryException, ScriptWorkerTaskException
@@ -221,6 +223,16 @@ def get_artifact_url(context, task_id, path):
     return url


+# list_latest_artifacts {{{1
+async def list_latest_artifacts(queue, task_id, exception=TaskclusterFailure):
+    return await queue.listLatestArtifacts(task_id)
+
+
+async def retry_list_latest_artifacts(queue, task_id, exception=TaskclusterFailure, **kwargs):
+    kwargs.setdefault("retry_exceptions", tuple(set([TaskclusterFailure, exception])))
+    return await retry_async(list_latest_artifacts, args=(queue, task_id), kwargs={"exception": exception}, **kwargs)
+
+
 # get_expiration_arrow {{{1
 def get_expiration_arrow(context):
     """Return an arrow matching `context.task['expires']`.
@@ -321,8 +333,12 @@ def get_upstream_artifacts_full_paths_per_task_id(context):
     for task_id, paths in task_ids_and_relative_paths:
         for path in paths:
             try:
-                path_to_add = get_and_check_single_upstream_artifact_full_path(context, task_id, path)
-                add_enumerable_item_to_dict(dict_=upstream_artifacts_full_paths_per_task_id, key=task_id, item=path_to_add)
+                if "*" in path:
+                    for path_to_add in get_artifacts_matching_glob(context, task_id, path):
+                        add_enumerable_item_to_dict(dict_=upstream_artifacts_full_paths_per_task_id, key=task_id, item=path_to_add)
+                else:
+                    path_to_add = get_and_check_single_upstream_artifact_full_path(context, task_id, path)
+                    add_enumerable_item_to_dict(dict_=upstream_artifacts_full_paths_per_task_id, key=task_id, item=path_to_add)
             except ScriptWorkerTaskException:
                 if path in optional_artifacts_per_task_id.get(task_id, []):
                     log.warning('Optional artifact "{}" of task "{}" not found'.format(path, task_id))
@@ -417,3 +433,13 @@ def assert_is_parent(path, parent_dir):
     p2 = Path(os.path.realpath(parent_dir))
     if p1 != p2 and p2 not in p1.parents:
         raise ScriptWorkerTaskException("{} is not under {}!".format(p1, p2))
+
+
+def get_artifacts_matching_glob(context, task_id, pattern):
+    parent_dir = os.path.abspath(os.path.join(context.config["work_dir"], "cot", task_id))
+    matching = []
+    for root, _, files in os.walk(parent_dir):
+        for f in files:
+            if fnmatch.fnmatch(f, pattern):
+                matching.append(os.path.join(root, f))
+    return matching
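Note that `get_artifacts_matching_glob` applies `fnmatch` to the bare filename, so a pattern like `*.log` matches at any depth under `work_dir/cot/<task_id>`. A minimal sketch of the helper in isolation (invented directory layout, assuming a scriptworker checkout with this commit):

```python
import os
import tempfile
from types import SimpleNamespace

from scriptworker.artifacts import get_artifacts_matching_glob

# Invented layout mirroring work_dir/cot/<task_id>/public/ after downloads.
work_dir = tempfile.mkdtemp()
public = os.path.join(work_dir, "cot", "dependency1", "public")
os.makedirs(public)
for name in ("live.log", "test.log", "target.zip"):
    open(os.path.join(public, name), "w").close()

# Only the .log files should match.
context = SimpleNamespace(config={"work_dir": work_dir})
print(get_artifacts_matching_glob(context, "dependency1", "*.log"))
```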
4 changes: 3 additions & 1 deletion src/scriptworker/context.py
@@ -108,7 +108,9 @@ def verify_task(self) -> None:
             task_id: str = upstream_artifact["taskId"]
             for path in upstream_artifact["paths"]:
                 if os.path.isabs(path) or ".." in path:
-                    raise CoTError("upstreamArtifacts taskId {} has illegal path {}!".format(task_id, path))
+                    raise CoTError(f"upstreamArtifacts taskId {task_id} has illegal path {path}!")
+                if "*" in path and not upstream_artifact.get("optional", False):
+                    raise CoTError(f"upstreamArtifacts taskId {task_id} has globbed path {path} as a non-optional artifact!")

     @property
     def credentials(self) -> Optional[Dict[str, Any]]:
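The effect of the new check, using the same calls the updated `test_context.py` makes below (payloads invented for illustration):

```python
from scriptworker.context import Context
from scriptworker.exceptions import CoTError

context = Context()

# A globbed path marked optional passes verification.
context.task = {"payload": {"upstreamArtifacts": [{"taskId": "t1", "paths": ["*"], "optional": True}]}}
context.verify_task()

# The same path as a non-optional artifact is rejected.
context.task = {"payload": {"upstreamArtifacts": [{"taskId": "t1", "paths": ["*"]}]}}
try:
    context.verify_task()
except CoTError as e:
    print(e)
```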
29 changes: 25 additions & 4 deletions src/scriptworker/cot/verify.py
@@ -9,6 +9,7 @@

 import argparse
 import asyncio
+import fnmatch
 import logging
 import os
 import pprint
@@ -23,7 +24,13 @@
 from immutabledict import immutabledict
 from taskcluster.aio import Queue

-from scriptworker.artifacts import download_artifacts, get_artifact_url, get_optional_artifacts_per_task_id, get_single_upstream_artifact_full_path
+from scriptworker.artifacts import (
+    download_artifacts,
+    get_artifact_url,
+    get_optional_artifacts_per_task_id,
+    get_single_upstream_artifact_full_path,
+    retry_list_latest_artifacts,
+)
 from scriptworker.config import apply_product_config, read_worker_creds
 from scriptworker.constants import DEFAULT_CONFIG
 from scriptworker.context import Context
@@ -762,14 +769,28 @@ async def download_cot_artifacts(chain):

     mandatory_artifact_tasks = []
     optional_artifact_tasks = []
+    latest_artifacts = {}
     for task_id, paths in all_artifacts_per_task_id.items():
         for path in paths:
-            coroutine = asyncio.ensure_future(download_cot_artifact(chain, task_id, path))
+            if "*" in path:
+                # Paths with wildcards in them indicate that the concrete
+                # artifact names aren't known when the task definition is
+                # created. For these cases, we need to fetch the list of
+                # artifacts from the completed tasks and then determine
+                # which are needed based on the pattern given.
+                if not latest_artifacts.get(task_id):
+                    latest_artifacts[task_id] = (await retry_list_latest_artifacts(chain.context.queue, task_id))["artifacts"]
+                coroutines = []
+                for artifact in latest_artifacts[task_id]:
+                    if fnmatch.fnmatch(artifact["name"], path):
+                        coroutines.append(asyncio.ensure_future(download_cot_artifact(chain, task_id, artifact["name"])))
+            else:
+                coroutines = [asyncio.ensure_future(download_cot_artifact(chain, task_id, path))]

             if is_artifact_optional(chain, task_id, path):
-                optional_artifact_tasks.append(coroutine)
+                optional_artifact_tasks.extend(coroutines)
             else:
-                mandatory_artifact_tasks.append(coroutine)
+                mandatory_artifact_tasks.extend(coroutines)

     mandatory_artifacts_paths = await raise_future_exceptions(mandatory_artifact_tasks)
     succeeded_optional_artifacts_paths, failed_optional_artifacts = await get_results_and_future_exceptions(optional_artifact_tasks)
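At its core, the runtime selection above is just filtering the queue's `listLatestArtifacts` response with `fnmatch`; a standalone sketch, with artifact names borrowed from the test further down:

```python
import fnmatch

# Shape of queue.listLatestArtifacts(task_id)'s response, trimmed to the
# one field the selection logic reads.
response = {"artifacts": [{"name": "foo"}, {"name": "live.log"}, {"name": "test.log"}]}

pattern = "*.log"
selected = [a["name"] for a in response["artifacts"] if fnmatch.fnmatch(a["name"], pattern)]
print(selected)  # ['live.log', 'test.log']
```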
32 changes: 32 additions & 0 deletions tests/test_artifacts.py
@@ -280,6 +280,38 @@ def test_get_upstream_artifacts_full_paths_per_task_id(context):
     }


+@pytest.mark.parametrize(
+    "artifacts_to_create,artifact_filenames,pattern",
+    (
+        (("file_a1", "file_b1", "file_c1"), ("file_a1", "file_b1", "file_c1"), "*"),
+        (("file_a1", "file_b1", "file_c1", "foo.log", "bar.log"), ("foo.log", "bar.log"), "*.log"),
+    ),
+)
+def test_get_upstream_artifacts_full_paths_per_task_id_with_globs(context, artifacts_to_create, artifact_filenames, pattern):
+    context.task["payload"] = {
+        "upstreamArtifacts": [
+            {"paths": [pattern], "taskId": "dependency1", "taskType": "build"},
+        ]
+    }
+
+    for artifact in artifacts_to_create:
+        folder = os.path.join(context.config["work_dir"], "cot", "dependency1", "public")
+
+        try:
+            os.makedirs(os.path.join(folder))
+        except FileExistsError:
+            pass
+        touch(os.path.join(folder, artifact))
+
+    succeeded_artifacts, failed_artifacts = get_upstream_artifacts_full_paths_per_task_id(context)
+
+    # glob expansion order isn't guaranteed, so compare as sets
+    assert "dependency1" in succeeded_artifacts
+    expected = set([os.path.join(context.config["work_dir"], "cot", "dependency1", "public", f) for f in artifact_filenames])
+    assert set(succeeded_artifacts["dependency1"]) == expected
+    assert failed_artifacts == {}
+
+
 def test_fail_get_upstream_artifacts_full_paths_per_task_id(context):
     context.task["payload"] = {"upstreamArtifacts": [{"paths": ["public/failed_mandatory_file"], "taskId": "failedDependency", "taskType": "signing"}]}
     with pytest.raises(ScriptWorkerTaskException):
15 changes: 11 additions & 4 deletions tests/test_context.py
@@ -136,17 +136,24 @@ def test_set_event_loop(mocker):
     assert rw_context.event_loop is fake_loop


-def test_verify_task(claim_task):
+def test_verify_task():
     rw_context = swcontext.Context()
     rw_context.task = {"payload": {"upstreamArtifacts": [{"taskId": "foo", "paths": ["bar"]}]}}
     # should not throw
     rw_context.verify_task()


-@pytest.mark.parametrize("bad_path", ("/abspath/foo", "public/../../../blah"))
-def test_bad_verify_task(claim_task, bad_path):
+@pytest.mark.parametrize(
+    "upstream_artifacts",
+    (
+        [{"taskId": "bar", "paths": ["baz", "public/../../../blah"]}],
+        [{"taskId": "bar", "paths": ["baz", "/abspath/foo"]}],
+        [{"taskId": "bar", "paths": ["*"], "optional": False}],
+    ),
+)
+def test_bad_verify_task(upstream_artifacts):
     context = swcontext.Context()
-    context.task = {"payload": {"upstreamArtifacts": [{"taskId": "bar", "paths": ["baz", bad_path]}]}}
+    context.task = {"payload": {"upstreamArtifacts": upstream_artifacts}}
     with pytest.raises(CoTError):
         context.verify_task()
56 changes: 56 additions & 0 deletions tests/test_cot_verify.py
@@ -825,6 +825,62 @@ async def fake_download(x, y, path):
     assert sorted(result) == ["path1", "path2", "path3"]


+# download_cot_artifacts {{{1
+@pytest.mark.parametrize(
+    "upstreamArtifacts,expected",
+    (
+        ([{"taskId": "task_id", "paths": ["*"]}], ["foo", "bar", "baz", "live.log", "test.log"]),
+        ([{"taskId": "task_id", "paths": ["*.log"]}], ["live.log", "test.log"]),
+    ),
+)
+@pytest.mark.asyncio
+async def test_download_cot_artifacts_wildcard(chain, mocker, upstreamArtifacts, expected):
+    async def fake_download(x, y, path):
+        return path
+
+    async def fake_artifacts(*args, **kwargs):
+        return {
+            "artifacts": [
+                {
+                    "storageType": "s3",
+                    "name": "foo",
+                    "expires": "2025-03-01T16:04:04.463Z",
+                    "contentType": "text/plain",
+                },
+                {
+                    "storageType": "s3",
+                    "name": "bar",
+                    "expires": "2025-03-01T16:04:04.463Z",
+                    "contentType": "text/plain",
+                },
+                {
+                    "storageType": "s3",
+                    "name": "baz",
+                    "expires": "2025-03-01T16:04:04.463Z",
+                    "contentType": "text/plain",
+                },
+                {
+                    "storageType": "s3",
+                    "name": "live.log",
+                    "expires": "2025-03-01T16:04:04.463Z",
+                    "contentType": "text/plain",
+                },
+                {
+                    "storageType": "s3",
+                    "name": "test.log",
+                    "expires": "2025-03-01T16:04:04.463Z",
+                    "contentType": "text/plain",
+                },
+            ]
+        }
+
+    chain.task["payload"]["upstreamArtifacts"] = upstreamArtifacts
+    mocker.patch.object(cotverify, "download_cot_artifact", new=fake_download)
+    mocker.patch.object(cotverify, "retry_list_latest_artifacts", new=fake_artifacts)
+    result = await cotverify.download_cot_artifacts(chain)
+    assert sorted(result) == sorted(expected)
+
+
 # is_artifact_optional {{{1
 @pytest.mark.parametrize(
     "upstream_artifacts, task_id, path, expected",