diff --git a/nb2workflow/galaxy.py b/nb2workflow/galaxy.py index 4b3b88e..048f760 100644 --- a/nb2workflow/galaxy.py +++ b/nb2workflow/galaxy.py @@ -1,3 +1,4 @@ +from __future__ import annotations import xml.etree.ElementTree as ET import os @@ -28,8 +29,8 @@ import isort logger = logging.getLogger() - - + + default_ontology_path = 'http://odahub.io/ontology/ontology.ttl' global_req = [] @@ -189,9 +190,8 @@ def to_xml_tree(self): element = ET.Element('data', **attrs) return element - - - + + def _nb2script(nba, ontology_path): input_nb = nba.notebook_fn mynb = nbformat.read(input_nb, as_version=4) @@ -335,15 +335,15 @@ def _nb2script(nba, ontology_path): class Requirements: - + def __init__(self, available_channels, conda_env_yml = None, requirements_txt = None, extra_req = []): self.tmpdir = tempfile.TemporaryDirectory() self.fullenv_file_path = os.path.join(self.tmpdir.name, 'environment.yml') - + self.micromamba = self._get_micromamba_binary() - + boilerplate_env_dict = {'channels': available_channels, 'dependencies': extra_req} - + if conda_env_yml is not None: with open(conda_env_yml, 'r') as fd: self.env_dict = yaml.safe_load(fd) @@ -361,21 +361,35 @@ def __init__(self, available_channels, conda_env_yml = None, requirements_txt = self.env_dict = boilerplate_env_dict else: self.env_dict = boilerplate_env_dict - + match_spec = re.compile(r'^(?P[^=<> ]+)') self._direct_dependencies = [] - for dep in self.env_dict['dependencies']: - m = match_spec.match(dep) - self._direct_dependencies.append((m.group('pac'), '', 0, '')) - - if requirements_txt is not None: - pip_reqs = self._parse_requirements_txt(requirements_txt) - + + pip_reqs = [] + pip_dep_idx = None + for i, dep in enumerate(self.env_dict['dependencies']): + if isinstance(dep, dict) and list(dep.keys()) == ['pip']: + for pipdep in dep['pip']: + parsed = self._parse_requirements_line(pipdep) + if parsed is not None: + pip_reqs.append(parsed) + pip_dep_idx = i + else: + m = match_spec.match(dep) + self._direct_dependencies.append((m.group('pac'), '', 0, '')) + + if pip_dep_idx is not None: + del self.env_dict['dependencies'][pip_dep_idx] + + if requirements_txt is not None or pip_reqs: + if requirements_txt is not None: + pip_reqs += self._parse_requirements_txt(requirements_txt) + channels_cl = [] for ch in self.env_dict['channels']: channels_cl.append('-c') channels_cl.append(ch) - + for req in pip_reqs: if req[2] == 2: self._direct_dependencies.append(req) @@ -383,7 +397,7 @@ def __init__(self, available_channels, conda_env_yml = None, requirements_txt = run_cmd = [self.micromamba, 'search', '--json'] run_cmd.extend(channels_cl) run_cmd.append(req[0]+req[1]) - + search_res = sp.run(run_cmd, check=True, capture_output=True, text=True) search_json = json.loads(search_res.stdout) if search_json['result']['pkgs']: @@ -392,15 +406,15 @@ def __init__(self, available_channels, conda_env_yml = None, requirements_txt = else: logger.warning(f'Dependency {req[0]} not found in conda channels.') self._direct_dependencies.append((req[0], req[1], 2, req[3])) - + if self.env_dict["dependencies"]: with open(self.fullenv_file_path, 'w') as fd: yaml.dump(self.env_dict, fd) - + resolved_env = self._resolve_environment_yml() else: resolved_env = {} - + self.final_dependencies = {} for dep in self._direct_dependencies: if dep[2] == 2: @@ -409,13 +423,12 @@ def __init__(self, available_channels, conda_env_yml = None, requirements_txt = continue else: self.final_dependencies[dep[0]] = (resolved_env[dep[0]], dep[2], dep[3]) - - + def _resolve_environment_yml(self): - + with open(self.fullenv_file_path) as fd: logger.info(f'Will resolve environment:\n\n{fd.read()}') - + run_command = [str(self.micromamba), 'env', 'create', '-n', '__temp_env_name', @@ -425,7 +438,7 @@ def _resolve_environment_yml(self): run_proc = sp.run(run_command, capture_output=True, check=True, text=True) resolved_env = json.loads(run_proc.stdout)['actions']['FETCH'] resolved_env = {x['name']: x['version'] for x in resolved_env} - + return resolved_env def to_xml_tree(self): @@ -440,39 +453,51 @@ def to_xml_tree(self): type='package', version = det[0])) reqs_elements[-1].text = name - + return reqs_elements - + @staticmethod - def _parse_requirements_txt(filepath): + def _parse_requirements_line(line: str) -> tuple | None: + match_spec = re.compile( + r"^(?P[A-Z0-9][A-Z0-9._-]*[A-Z0-9])\s*(?:\[.*\])?\s*(?P[~=]{0,2})(?P[<>]?=?)\s*(?P[0-9.\*]*)", + re.I, + ) + match_from_url = re.compile( + r"^(?P[A-Z0-9][A-Z0-9._-]*[A-Z0-9])\s*@(?P.*)", re.I + ) - match_spec = re.compile(r'^(?P[A-Z0-9][A-Z0-9._-]*[A-Z0-9])\s*(?:\[.*\])?\s*(?P[~=]{0,2})(?P[<>]?=?)\s*(?P[0-9.\*]*)', re.I) - match_from_url = re.compile(r'^(?P[A-Z0-9][A-Z0-9._-]*[A-Z0-9])\s*@(?P.*)', re.I) - - # TODO: basic, see https://pip.pypa.io/en/stable/reference/requirement-specifiers/ - - with open(filepath, 'r') as fd: - reqs_str_list = [] + if line.startswith("#") or re.match(r"^\s*$", line): + return + elif line.startswith("git+"): + logger.warning("Dependency from git repo is not supported: %s", line) + return (line, "", 2, line) + elif match_from_url.match(line): + logger.warning("Dependency from url is not supported %s", line) + return (line, "", 2, line) + else: + m = match_spec.match(line) + if m is None: + logger.warning("Dependency spec not recognised for %s", line) + return (line, "", 2, line) + if m.group("ver"): + ver = ( + m.group("uneq") + m.group("ver") + if m.group("uneq") + else m.group("eq") + m.group("ver") + ) + else: + ver = "" + return (m.group("pac"), ver, 1, line) + + def _parse_requirements_txt(self, filepath: str | os.PathLike) -> list[tuple]: + + with open(filepath, "r") as fd: + reqs_str_list = [] for line in fd: - if line.startswith('#') or re.match(r'^\s*$', line): - continue - elif line.startswith('git+'): - logger.warning('Dependency from git repo is not supported: %s', line) - reqs_str_list.append((line, '', 2, line)) - elif match_from_url.match(line): - logger.warning('Dependency from url is not supported %s', line) - reqs_str_list.append((line, '', 2, line)) - else: - m = match_spec.match(line) - if m is None: - logger.warning('Dependency spec not recognised for %s', line) - reqs_str_list.append((line, '', 2, line)) - if m.group('ver'): - ver = m.group('uneq') + m.group('ver') if m.group('uneq') else m.group('eq') + m.group('ver') - else: - ver = '' - reqs_str_list.append((m.group('pac'), ver, 1, line)) - + parsed = self._parse_requirements_line(line) + if parsed is not None: + reqs_str_list.append(parsed) + return reqs_str_list @staticmethod @@ -540,7 +565,7 @@ def _test_data_location(repo_dir, tool_dir, default_value, base_url=None): # TODO: support the case of nb not in repo root location = os.path.join(base_url, default_value) return value, location - + def to_galaxy(input_path, toolname, out_dir, @@ -724,4 +749,3 @@ def main(): if __name__ == '__main__': main() -