Commit

Move, rename and cleanups
dagardner-nv committed Nov 26, 2024
1 parent ed08b0c commit 104e233
Showing 1 changed file with 164 additions and 68 deletions.
232 changes: 164 additions & 68 deletions scripts/generate_deps.py → ci/release/download_deps.py
@@ -36,7 +36,7 @@
import yaml

SCRIPT_DIR = os.path.relpath(os.path.dirname(__file__))
PROJ_ROOT = os.path.dirname(SCRIPT_DIR)
PROJ_ROOT = os.path.dirname(os.path.dirname(SCRIPT_DIR))

PIP_FLAGS_RE = re.compile(r"^--.*")
STRIP_VER_RE = re.compile(r"^([\w|-]+).*")
@@ -112,9 +112,8 @@
}

# Please keep sorted
KNOWN_FIRST_PARTY = {
'cuda-cudart', 'cuda-nvrtc', 'cuda-nvtx', 'cuda-version', 'cudf', 'mrc', 'rapids-dask-dependency', 'tritonclient'
}
KNOWN_FIRST_PARTY = frozenset(
{'cuda-cudart', 'cuda-nvrtc', 'cuda-nvtx', 'cuda-version', 'cudf', 'mrc', 'rapids-dask-dependency', 'tritonclient'})

# Some of these packages are installed via CPM (pybind11), others are transitive deps whose version is determined by
# other packages but we use directly (glog), while others exist in the build environment and are statically linked
@@ -168,13 +167,15 @@
logger = logging.getLogger(__file__)


def _get_repo_info(url_map: dict, tag_url_path: str, pkg_name: str, repo_name: str, pkg_version: str) -> dict | None:
try:
repo_url = url_map[repo_name]
except KeyError:
return None

tag_formatter = GIT_TAG_FORMAT.get(repo_name, TAG_V_PREFIX)
def _get_repo_info(*,
git_tag_format,
repo_url_map: dict,
tag_url_path: str,
pkg_name: str,
repo_name: str,
pkg_version: str) -> dict[str, typing.Any]:
repo_url = repo_url_map[repo_name]
tag_formatter = git_tag_format.get(repo_name, TAG_V_PREFIX)
if isinstance(tag_formatter, str):
tag = tag_formatter.format(name=repo_name, version=pkg_version)
else:
@@ -185,23 +186,21 @@ def _get_repo_info(url_map: dict, tag_url_path: str, pkg_name: str, repo_name: s
return {'packages': [pkg_name], 'tag': tag, 'tar_url': tar_url}


def _get_github_info(pkg_name: str, repo_name: str, pkg_version: str) -> dict | None:
return _get_repo_info(KNOWN_GITHUB_URLS, GIT_HUB_TAG_URL_PATH, pkg_name, repo_name, pkg_version)


def _get_gitlab_info(pkg_name: str, repo_name: str, pkg_version: str) -> dict | None:
return _get_repo_info(KNOWN_GITLAB_URLS, GIT_LAB_TAG_URL_PATH, pkg_name, repo_name, pkg_version)


def mk_repo_urls(packages: list[tuple[str, str]]) -> tuple[dict[str, typing.Any], list[str]]:
def mk_repo_urls(*,
known_first_party: frozenset[str],
known_github_urls: dict[str, str],
known_gitlab_urls: dict[str, str],
git_tag_format: dict[str, str],
package_aliases: dict[str, str],
packages: list[tuple[str, str]]) -> tuple[dict[str, typing.Any], list[str]]:
matched = {}
unmatched: list[str] = []
for (pkg_name, pkg_version) in packages:
if pkg_name in KNOWN_FIRST_PARTY:
if pkg_name in known_first_party:
logger.debug("Skipping first party package: %s", pkg_name)
continue

repo_name = PACKAGE_ALIASES.get(pkg_name, pkg_name)
repo_name = package_aliases.get(pkg_name, pkg_name)
if repo_name != pkg_name:
logger.debug("Package %s is knwon as %s", pkg_name, repo_name)

@@ -212,9 +211,16 @@ def mk_repo_urls(packages: list[tuple[str, str]]) -> tuple[dict[str, typing.Any]

i = 0
repo_info = None
repo_getters = (_get_github_info, _get_gitlab_info)
while repo_info is None and i < len(repo_getters):
repo_info = repo_getters[i](pkg_name, repo_name, pkg_version)
repos = ((known_github_urls, GIT_HUB_TAG_URL_PATH), (known_gitlab_urls, GIT_LAB_TAG_URL_PATH))
while repo_info is None and i < len(repos):
(repo_url_map, tag_url_path) = repos[i]
if repo_name in repo_url_map:
repo_info = _get_repo_info(repo_url_map=repo_url_map,
pkg_name=pkg_name,
repo_name=repo_name,
pkg_version=pkg_version,
tag_url_path=tag_url_path,
git_tag_format=git_tag_format)
i += 1

if repo_info is not None:
@@ -409,6 +415,107 @@ def print_summary(dep_urls: dict[str, typing.Any], unmatched_packages: list[str]
return missing_packages


def download_source_deps(*,
conda_yaml: str,
conda_json: str,
package_aliases: dict[str, str],
known_github_urls: dict[str, str],
known_gitlab_urls: dict[str, str],
known_first_party: frozenset[str],
git_tag_format: dict[str, str],
known_non_conda_deps: list[tuple[str, str]],
dry_run: bool = False,
verify_urls: bool = False,
download: bool = False,
download_dir: str | None = None,
extract: bool = False,
extract_dir: str | None = None) -> int:
"""
Main entry point for downloading source dependencies.

Parameters
----------
conda_yaml : str
Path to the Conda environment file to read dependencies from.
conda_json : str
Path to the JSON formatted output of the resolved Conda environment. Generated by running:
`./docker/run_container_release.sh conda list --json > .tmp/container_pkgs.json`
package_aliases : dict[str, str]
Mapping of Conda package names to their upstream repo name. This is needed primarily to handle the case where
multiple Conda packages are derived from a single upstream repo.
known_github_urls : dict[str, str]
Mapping of package names to their GitHub repo URL.
known_gitlab_urls : dict[str, str]
Mapping of package names to their GitLab repo URL. This is kept separate from `known_github_urls` since they
have different tag URL formats.
known_first_party : frozenset[str]
Set of first party packages that are not downloaded.
git_tag_format : dict[str, str]
Mapping of package names to the format of tag names; by default the "v{version}" format is used.
dry_run : bool, optional
If True, do not download or extract any files, just merge dependencies, by default False
verify_urls : bool, optional
If True, verify that the URLs are valid, when `download` and `extract` are `False` this is effectively a more
verbose dry_run, by default False
download : bool, optional
If True, download the tar archives, by default False
download_dir : str, optional
Required when `download` or `extract` is True
extract : bool, optional
If True, extract the tar archives to `extract_dir`, by default False
extract_dir : str, optional
Required when `extract` is True
known_non_conda_deps : list[tuple[str, str]]
List of dependencies that are not specified in the Conda environment file, but are known to be required.

Returns
-------
int
Number of missing packages
"""
declared_deps = parse_env_file(conda_yaml)
resolved_conda_deps = parse_json_deps(conda_json)

merged_deps = merge_deps(declared_deps, known_non_conda_deps, resolved_conda_deps)

if logger.isEnabledFor(logging.DEBUG):
logger.debug("Declared Yaml deps:\n%s", pprint.pformat(sorted(declared_deps)))
logger.debug("Resolved Conda deps:\n%s", pprint.pformat(resolved_conda_deps))
logger.debug("Merged deps:\n%s", pprint.pformat(merged_deps))

(dep_urls, unmatched_packages) = mk_repo_urls(known_first_party=known_first_party,
known_github_urls=known_github_urls,
known_gitlab_urls=known_gitlab_urls,
git_tag_format=git_tag_format,
package_aliases=package_aliases,
packages=merged_deps)
if len(unmatched_packages) > 0:
logger.error(
"\n------------\nPackages without github info which will need to be fetched manually:\n%s\n------------\n",
pprint.pformat(unmatched_packages))

if dry_run:
sys.exit(0)

with requests.Session() as session:
if verify_urls:
verify_tar_urls(session, dep_urls)

if download:
assert download_dir is not None, "download_dir must be set when download is True"
download_tars(session, dep_urls, download_dir)

if extract:
assert extract_dir is not None, "extract_dir must be set when extract is True"
extract_tar_files(dep_urls, extract_dir)

missing_packages = print_summary(dep_urls, unmatched_packages, download, extract)
if extract:
print(f"Exraction location: {extract_dir}")

return len(missing_packages)


def parse_args():
argparser = argparse.ArgumentParser(
"Download source code for third party dependencies specified in a Conda environment yaml file, by default "
Expand All @@ -426,7 +533,8 @@ def parse_args():
"This is used to determine the exact version number actually used by a package which "
"specifies a version range in the Conda environment file."))

argparser.add_argument('--skip_url_verify', default=False, action='store_true')
argparser.add_argument('--dry_run', default=False, action='store_true')
argparser.add_argument('--verify_urls', default=False, action='store_true')
argparser.add_argument('--download', default=False, action='store_true')

argparser.add_argument('--download_dir',
@@ -469,54 +577,42 @@ def main():
logging.getLogger('requests').setLevel(args.http_log_level)
logging.getLogger("urllib3").setLevel(args.http_log_level)

declared_deps = parse_env_file(args.conda_yaml)
resolved_conda_deps = parse_json_deps(args.conda_json)

merged_deps = merge_deps(declared_deps, KNOWN_NON_CONDA_DEPS, resolved_conda_deps)

if logger.isEnabledFor(logging.DEBUG):
logger.debug("Declared Yaml deps:\n%s", pprint.pformat(sorted(declared_deps)))
logger.debug("Resolved Conda deps:\n%s", pprint.pformat(resolved_conda_deps))
logger.debug("Merged deps:\n%s", pprint.pformat(merged_deps))

(dep_urls, unmatched_packages) = mk_repo_urls(merged_deps)
if len(unmatched_packages) > 0:
logger.error(
"\n------------\nPackages without github info which will need to be fetched manually:\n%s\n------------\n",
pprint.pformat(unmatched_packages))

if not args.download and args.skip_url_verify:
sys.exit(0)

with requests.Session() as session:
if not args.skip_url_verify:
verify_tar_urls(session, dep_urls)

download_dir: str | None = args.download_dir
needs_cleanup = False
download_dir: str | None = args.download_dir
if download_dir is None:
if args.download:
if download_dir is None:
download_dir = tempfile.mkdtemp(prefix="morpheus_deps_download_")
logger.info("Created temporary download directory: %s", download_dir)

download_tars(session, dep_urls, download_dir)
download_dir = tempfile.mkdtemp(prefix="morpheus_deps_download_")
logger.info("Created temporary download directory: %s", download_dir)
needs_cleanup = True
elif args.extract:
logger.error("--extract requires either --download or --download_dir to be set.")
sys.exit(1)

extract_dir: str | None = args.extract_dir
if args.extract:
if extract_dir is None:
extract_dir = tempfile.mkdtemp(prefix="morpheus_deps_extract_")
logger.info("Created temporary extract directory: %s", extract_dir)

extract_tar_files(dep_urls, extract_dir)

if args.download_dir is None and download_dir is not None and not args.no_clean:
if extract_dir is None and args.extract:
extract_dir = tempfile.mkdtemp(prefix="morpheus_deps_extract_")
logger.info("Created temporary extract directory: %s", extract_dir)

num_missing_packages = download_source_deps(conda_yaml=args.conda_yaml,
conda_json=args.conda_json,
package_aliases=PACKAGE_ALIASES,
known_github_urls=KNOWN_GITHUB_URLS,
known_gitlab_urls=KNOWN_GITLAB_URLS,
known_first_party=KNOWN_FIRST_PARTY,
git_tag_format=GIT_TAG_FORMAT,
known_non_conda_deps=KNOWN_NON_CONDA_DEPS,
dry_run=args.dry_run,
verify_urls=args.verify_urls,
download=args.download,
download_dir=download_dir,
extract=args.extract,
extract_dir=extract_dir)

if needs_cleanup and not args.no_clean and download_dir is not None:
logger.info("Removing temporary download directory: %s", download_dir)
shutil.rmtree(download_dir)

missing_packages = print_summary(dep_urls, unmatched_packages, args.download, args.extract)
if args.extract:
print(f"Exraction location: {extract_dir}")

if len(missing_packages) > 0:
if num_missing_packages > 0:
sys.exit(1)


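For reference, a minimal sketch of how the relocated module's new keyword-only entry point might be driven programmatically, mirroring the call in main(). The sys.path manipulation and the environment file path are illustrative assumptions; only the download_source_deps signature, its return value, and the module-level constants are taken from the diff above.

# Hypothetical usage of ci/release/download_deps.py as a module; assumes the
# ci/release directory is importable and substitutes a placeholder environment file.
import sys

sys.path.insert(0, "ci/release")

from download_deps import (GIT_TAG_FORMAT, KNOWN_FIRST_PARTY, KNOWN_GITHUB_URLS,
                           KNOWN_GITLAB_URLS, KNOWN_NON_CONDA_DEPS, PACKAGE_ALIASES,
                           download_source_deps)

# With download and extract left False, verify_urls=True is effectively a verbose
# dry run: tarball URLs are checked but nothing is fetched or unpacked.
num_missing = download_source_deps(conda_yaml="path/to/environment.yaml",
                                   conda_json=".tmp/container_pkgs.json",
                                   package_aliases=PACKAGE_ALIASES,
                                   known_github_urls=KNOWN_GITHUB_URLS,
                                   known_gitlab_urls=KNOWN_GITLAB_URLS,
                                   known_first_party=KNOWN_FIRST_PARTY,
                                   git_tag_format=GIT_TAG_FORMAT,
                                   known_non_conda_deps=KNOWN_NON_CONDA_DEPS,
                                   verify_urls=True)

# The return value is the number of packages whose source could not be located.
if num_missing > 0:
    sys.exit(1)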
