diff --git a/scripts/generate_deps.py b/ci/release/download_deps.py
similarity index 70%
rename from scripts/generate_deps.py
rename to ci/release/download_deps.py
index 28e162e405..6744af3642 100755
--- a/scripts/generate_deps.py
+++ b/ci/release/download_deps.py
@@ -36,7 +36,7 @@
 import yaml

 SCRIPT_DIR = os.path.relpath(os.path.dirname(__file__))
-PROJ_ROOT = os.path.dirname(SCRIPT_DIR)
+PROJ_ROOT = os.path.dirname(os.path.dirname(SCRIPT_DIR))

 PIP_FLAGS_RE = re.compile(r"^--.*")
 STRIP_VER_RE = re.compile(r"^([\w|-]+).*")
@@ -112,9 +112,8 @@
 }

 # Please keep sorted
-KNOWN_FIRST_PARTY = {
-    'cuda-cudart', 'cuda-nvrtc', 'cuda-nvtx', 'cuda-version', 'cudf', 'mrc', 'rapids-dask-dependency', 'tritonclient'
-}
+KNOWN_FIRST_PARTY = frozenset(
+    {'cuda-cudart', 'cuda-nvrtc', 'cuda-nvtx', 'cuda-version', 'cudf', 'mrc', 'rapids-dask-dependency', 'tritonclient'})

 # Some of these packages are installed via CPM (pybind11), others are transitive deps whose version is determined by
 # other packages but we use directly (glog), while others exist in the build environment and are statically linked
@@ -168,13 +167,15 @@
 logger = logging.getLogger(__file__)


-def _get_repo_info(url_map: dict, tag_url_path: str, pkg_name: str, repo_name: str, pkg_version: str) -> dict | None:
-    try:
-        repo_url = url_map[repo_name]
-    except KeyError:
-        return None
-
-    tag_formatter = GIT_TAG_FORMAT.get(repo_name, TAG_V_PREFIX)
+def _get_repo_info(*,
+                   git_tag_format: dict,
+                   repo_url_map: dict,
+                   tag_url_path: str,
+                   pkg_name: str,
+                   repo_name: str,
+                   pkg_version: str) -> dict[str, typing.Any]:
+    repo_url = repo_url_map[repo_name]
+    tag_formatter = git_tag_format.get(repo_name, TAG_V_PREFIX)
     if isinstance(tag_formatter, str):
         tag = tag_formatter.format(name=repo_name, version=pkg_version)
     else:
@@ -185,23 +186,21 @@ def _get_repo_info(url_map: dict, tag_url_path: str, pkg_name: str, repo_name: s
     return {'packages': [pkg_name], 'tag': tag, 'tar_url': tar_url}


-def _get_github_info(pkg_name: str, repo_name: str, pkg_version: str) -> dict | None:
-    return _get_repo_info(KNOWN_GITHUB_URLS, GIT_HUB_TAG_URL_PATH, pkg_name, repo_name, pkg_version)
-
-
-def _get_gitlab_info(pkg_name: str, repo_name: str, pkg_version: str) -> dict | None:
-    return _get_repo_info(KNOWN_GITLAB_URLS, GIT_LAB_TAG_URL_PATH, pkg_name, repo_name, pkg_version)
-
-
-def mk_repo_urls(packages: list[tuple[str, str]]) -> tuple[dict[str, typing.Any], list[str]]:
+def mk_repo_urls(*,
+                 known_first_party: frozenset[str],
+                 known_github_urls: dict[str, str],
+                 known_gitlab_urls: dict[str, str],
+                 git_tag_format: dict[str, str],
+                 package_aliases: dict[str, str],
+                 packages: list[tuple[str, str]]) -> tuple[dict[str, typing.Any], list[str]]:
     matched = {}
     unmatched: list[str] = []
     for (pkg_name, pkg_version) in packages:
-        if pkg_name in KNOWN_FIRST_PARTY:
+        if pkg_name in known_first_party:
             logger.debug("Skipping first party package: %s", pkg_name)
             continue

-        repo_name = PACKAGE_ALIASES.get(pkg_name, pkg_name)
+        repo_name = package_aliases.get(pkg_name, pkg_name)
         if repo_name != pkg_name:
             logger.debug("Package %s is known as %s", pkg_name, repo_name)

@@ -212,9 +211,16 @@ def mk_repo_urls(packages: list[tuple[str, str]]) -> tuple[dict[str, typing.Any]
         i = 0
         repo_info = None
-        repo_getters = (_get_github_info, _get_gitlab_info)
-        while repo_info is None and i < len(repo_getters):
-            repo_info = repo_getters[i](pkg_name, repo_name, pkg_version)
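+        # Check each known forge in order: GitHub first, then GitLab. The two
+        # URL maps stay separate because the forges use different tag archive
+        # URL layouts (GIT_HUB_TAG_URL_PATH vs GIT_LAB_TAG_URL_PATH).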
+        repos = ((known_github_urls, GIT_HUB_TAG_URL_PATH), (known_gitlab_urls, GIT_LAB_TAG_URL_PATH))
+        while repo_info is None and i < len(repos):
+            (repo_url_map, tag_url_path) = repos[i]
+            if repo_name in repo_url_map:
+                repo_info = _get_repo_info(repo_url_map=repo_url_map,
+                                           pkg_name=pkg_name,
+                                           repo_name=repo_name,
+                                           pkg_version=pkg_version,
+                                           tag_url_path=tag_url_path,
+                                           git_tag_format=git_tag_format)
             i += 1

         if repo_info is not None:
@@ -409,6 +415,107 @@ def print_summary(dep_urls: dict[str, typing.Any], unmatched_packages: list[str]
     return missing_packages


+def download_source_deps(*,
+                         conda_yaml: str,
+                         conda_json: str,
+                         package_aliases: dict[str, str],
+                         known_github_urls: dict[str, str],
+                         known_gitlab_urls: dict[str, str],
+                         known_first_party: frozenset[str],
+                         git_tag_format: dict[str, str],
+                         known_non_conda_deps: list[tuple[str, str]],
+                         dry_run: bool = False,
+                         verify_urls: bool = False,
+                         download: bool = False,
+                         download_dir: str | None = None,
+                         extract: bool = False,
+                         extract_dir: str | None = None) -> int:
+    """
+    Main entry point for downloading source dependencies.
+
+    Parameters
+    ----------
+    conda_yaml : str
+        Path to the Conda environment file to read dependencies from.
+    conda_json : str
+        Path to the JSON formatted output of the resolved Conda environment. Generated by running:
+        `./docker/run_container_release.sh conda list --json > .tmp/container_pkgs.json`
+    package_aliases : dict[str, str]
+        Mapping of Conda package names to their upstream repo name. This is needed primarily to handle the case where
+        multiple Conda packages are derived from a single upstream repo.
+    known_github_urls : dict[str, str]
+        Mapping of package names to their GitHub repo URL.
+    known_gitlab_urls : dict[str, str]
+        Mapping of package names to their GitLab repo URL. This is kept separate from `known_github_urls` since they
+        have different tag URL formats.
+    known_first_party : frozenset[str]
+        Set of first party packages that are not downloaded.
+    git_tag_format : dict[str, str]
+        Mapping of package names to the format of tag names; by default the "v{version}" format is used.
+    known_non_conda_deps : list[tuple[str, str]]
+        List of dependencies that are not specified in the Conda environment file but are known to be required.
+    dry_run : bool, optional
+        If True, do not download or extract any files, just merge dependencies, by default False
+    verify_urls : bool, optional
+        If True, verify that the URLs are valid. When `download` and `extract` are `False` this is effectively a more
+        verbose dry run, by default False
+    download : bool, optional
+        If True, download the tar archives, by default False
+    download_dir : str, optional
+        Required when `download` or `extract` is True
+    extract : bool, optional
+        If True, extract the tar archives to `extract_dir`, by default False
+    extract_dir : str, optional
+        Required when `extract` is True
+
+    Returns
+    -------
+    int
+        Number of missing packages
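+
+    Examples
+    --------
+    A minimal dry-run sketch; the YAML path below is an illustrative placeholder, while the upper-case
+    mappings are this module's defaults:
+
+    >>> num_missing = download_source_deps(conda_yaml="path/to/env.yaml",
+    ...                                    conda_json=".tmp/container_pkgs.json",
+    ...                                    package_aliases=PACKAGE_ALIASES,
+    ...                                    known_github_urls=KNOWN_GITHUB_URLS,
+    ...                                    known_gitlab_urls=KNOWN_GITLAB_URLS,
+    ...                                    known_first_party=KNOWN_FIRST_PARTY,
+    ...                                    git_tag_format=GIT_TAG_FORMAT,
+    ...                                    known_non_conda_deps=KNOWN_NON_CONDA_DEPS,
+    ...                                    dry_run=True)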
+    """
+    declared_deps = parse_env_file(conda_yaml)
+    resolved_conda_deps = parse_json_deps(conda_json)
+
+    merged_deps = merge_deps(declared_deps, known_non_conda_deps, resolved_conda_deps)
+
+    if logger.isEnabledFor(logging.DEBUG):
+        logger.debug("Declared Yaml deps:\n%s", pprint.pformat(sorted(declared_deps)))
+        logger.debug("Resolved Conda deps:\n%s", pprint.pformat(resolved_conda_deps))
+        logger.debug("Merged deps:\n%s", pprint.pformat(merged_deps))
+
+    (dep_urls, unmatched_packages) = mk_repo_urls(known_first_party=known_first_party,
+                                                  known_github_urls=known_github_urls,
+                                                  known_gitlab_urls=known_gitlab_urls,
+                                                  git_tag_format=git_tag_format,
+                                                  package_aliases=package_aliases,
+                                                  packages=merged_deps)
+    if len(unmatched_packages) > 0:
+        logger.error(
+            "\n------------\nPackages without GitHub/GitLab info which will need to be fetched manually:\n%s\n------------\n",
+            pprint.pformat(unmatched_packages))
+
+    if dry_run:
+        return 0
+
+    with requests.Session() as session:
+        if verify_urls:
+            verify_tar_urls(session, dep_urls)
+
+        if download:
+            assert download_dir is not None, "download_dir must be set when download is True"
+            download_tars(session, dep_urls, download_dir)
+
+    if extract:
+        assert extract_dir is not None, "extract_dir must be set when extract is True"
+        extract_tar_files(dep_urls, extract_dir)
+
+    missing_packages = print_summary(dep_urls, unmatched_packages, download, extract)
+    if extract:
+        print(f"Extraction location: {extract_dir}")
+
+    return len(missing_packages)
+
+
 def parse_args():
     argparser = argparse.ArgumentParser(
         "Download source code for third party dependencies specified in a Conda environment yaml file, by default "
@@ -426,7 +533,8 @@ def parse_args():
                             "This is used to determine the exact version number actually used by a package which "
                             "specifies a version range in the Conda environment file."))

-    argparser.add_argument('--skip_url_verify', default=False, action='store_true')
+    argparser.add_argument('--dry_run', default=False, action='store_true')
+    argparser.add_argument('--verify_urls', default=False, action='store_true')

     argparser.add_argument('--download', default=False, action='store_true')
     argparser.add_argument('--download_dir',
@@ -469,54 +577,42 @@ def main():
     logging.getLogger('requests').setLevel(args.http_log_level)
     logging.getLogger("urllib3").setLevel(args.http_log_level)

-    declared_deps = parse_env_file(args.conda_yaml)
-    resolved_conda_deps = parse_json_deps(args.conda_json)
-
-    merged_deps = merge_deps(declared_deps, KNOWN_NON_CONDA_DEPS, resolved_conda_deps)
-
-    if logger.isEnabledFor(logging.DEBUG):
-        logger.debug("Declared Yaml deps:\n%s", pprint.pformat(sorted(declared_deps)))
-        logger.debug("Resolved Conda deps:\n%s", pprint.pformat(resolved_conda_deps))
-        logger.debug("Merged deps:\n%s", pprint.pformat(merged_deps))
-
-    (dep_urls, unmatched_packages) = mk_repo_urls(merged_deps)
-    if len(unmatched_packages) > 0:
-        logger.error(
-            "\n------------\nPackages without github info which will need to be fetched manually:\n%s\n------------\n",
-            pprint.pformat(unmatched_packages))
-
-    if not args.download and args.skip_url_verify:
-        sys.exit(0)
-
-    with requests.Session() as session:
-        if not args.skip_url_verify:
-            verify_tar_urls(session, dep_urls)
-
-        download_dir: str | None = args.download_dir
+    needs_cleanup = False
+    download_dir: str | None = args.download_dir
+    if download_dir is None:
         if args.download:
-            if download_dir is None:
-                download_dir = tempfile.mkdtemp(prefix="morpheus_deps_download_")
-                logger.info("Created temporary download directory: %s", download_dir)
-
-            download_tars(session, dep_urls, download_dir)
+            download_dir = tempfile.mkdtemp(prefix="morpheus_deps_download_")
+            logger.info("Created temporary download directory: %s", download_dir)
+            needs_cleanup = True
+        elif args.extract:
+            logger.error("--extract requires either --download or --download_dir to be set.")
+            sys.exit(1)

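+    # As with download_dir, fall back to a temporary directory when --extract is
+    # requested without an explicit --extract_dir.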
     extract_dir: str | None = args.extract_dir
-    if args.extract:
-        if extract_dir is None:
-            extract_dir = tempfile.mkdtemp(prefix="morpheus_deps_extract_")
-            logger.info("Created temporary extract directory: %s", extract_dir)
-
-        extract_tar_files(dep_urls, extract_dir)
-
-    if args.download_dir is None and download_dir is not None and not args.no_clean:
+    if extract_dir is None and args.extract:
+        extract_dir = tempfile.mkdtemp(prefix="morpheus_deps_extract_")
+        logger.info("Created temporary extract directory: %s", extract_dir)
+
+    num_missing_packages = download_source_deps(conda_yaml=args.conda_yaml,
+                                                conda_json=args.conda_json,
+                                                package_aliases=PACKAGE_ALIASES,
+                                                known_github_urls=KNOWN_GITHUB_URLS,
+                                                known_gitlab_urls=KNOWN_GITLAB_URLS,
+                                                known_first_party=KNOWN_FIRST_PARTY,
+                                                git_tag_format=GIT_TAG_FORMAT,
+                                                known_non_conda_deps=KNOWN_NON_CONDA_DEPS,
+                                                dry_run=args.dry_run,
+                                                verify_urls=args.verify_urls,
+                                                download=args.download,
+                                                download_dir=download_dir,
+                                                extract=args.extract,
+                                                extract_dir=extract_dir)
+
+    if needs_cleanup and not args.no_clean and download_dir is not None:
         logger.info("Removing temporary download directory: %s", download_dir)
         shutil.rmtree(download_dir)

-    missing_packages = print_summary(dep_urls, unmatched_packages, args.download, args.extract)
-    if args.extract:
-        print(f"Exraction location: {extract_dir}")
-
-    if len(missing_packages) > 0:
+    if num_missing_packages > 0:
         sys.exit(1)