Skip to content

Commit

Permalink
allow pip requirements in environment.yml (#209)
Browse files Browse the repository at this point in the history
  • Loading branch information
dsavchenko authored Nov 14, 2024
1 parent 06bb652 commit 6fde915
Showing 1 changed file with 82 additions and 58 deletions.
140 changes: 82 additions & 58 deletions nb2workflow/galaxy.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations

import xml.etree.ElementTree as ET
import os
Expand Down Expand Up @@ -28,8 +29,8 @@
import isort

logger = logging.getLogger()


default_ontology_path = 'http://odahub.io/ontology/ontology.ttl'

global_req = []
Expand Down Expand Up @@ -189,9 +190,8 @@ def to_xml_tree(self):
element = ET.Element('data', **attrs)

return element





def _nb2script(nba, ontology_path):
input_nb = nba.notebook_fn
mynb = nbformat.read(input_nb, as_version=4)
Expand Down Expand Up @@ -335,15 +335,15 @@ def _nb2script(nba, ontology_path):


class Requirements:

def __init__(self, available_channels, conda_env_yml = None, requirements_txt = None, extra_req = []):
self.tmpdir = tempfile.TemporaryDirectory()
self.fullenv_file_path = os.path.join(self.tmpdir.name, 'environment.yml')

self.micromamba = self._get_micromamba_binary()

boilerplate_env_dict = {'channels': available_channels, 'dependencies': extra_req}

if conda_env_yml is not None:
with open(conda_env_yml, 'r') as fd:
self.env_dict = yaml.safe_load(fd)
Expand All @@ -361,29 +361,43 @@ def __init__(self, available_channels, conda_env_yml = None, requirements_txt =
self.env_dict = boilerplate_env_dict
else:
self.env_dict = boilerplate_env_dict

match_spec = re.compile(r'^(?P<pac>[^=<> ]+)')
self._direct_dependencies = []
for dep in self.env_dict['dependencies']:
m = match_spec.match(dep)
self._direct_dependencies.append((m.group('pac'), '', 0, ''))

if requirements_txt is not None:
pip_reqs = self._parse_requirements_txt(requirements_txt)


pip_reqs = []
pip_dep_idx = None
for i, dep in enumerate(self.env_dict['dependencies']):
if isinstance(dep, dict) and list(dep.keys()) == ['pip']:
for pipdep in dep['pip']:
parsed = self._parse_requirements_line(pipdep)
if parsed is not None:
pip_reqs.append(parsed)
pip_dep_idx = i
else:
m = match_spec.match(dep)
self._direct_dependencies.append((m.group('pac'), '', 0, ''))

if pip_dep_idx is not None:
del self.env_dict['dependencies'][pip_dep_idx]

if requirements_txt is not None or pip_reqs:
if requirements_txt is not None:
pip_reqs += self._parse_requirements_txt(requirements_txt)

channels_cl = []
for ch in self.env_dict['channels']:
channels_cl.append('-c')
channels_cl.append(ch)

for req in pip_reqs:
if req[2] == 2:
self._direct_dependencies.append(req)
continue
run_cmd = [self.micromamba, 'search', '--json']
run_cmd.extend(channels_cl)
run_cmd.append(req[0]+req[1])

search_res = sp.run(run_cmd, check=True, capture_output=True, text=True)
search_json = json.loads(search_res.stdout)
if search_json['result']['pkgs']:
Expand All @@ -392,15 +406,15 @@ def __init__(self, available_channels, conda_env_yml = None, requirements_txt =
else:
logger.warning(f'Dependency {req[0]} not found in conda channels.')
self._direct_dependencies.append((req[0], req[1], 2, req[3]))

if self.env_dict["dependencies"]:
with open(self.fullenv_file_path, 'w') as fd:
yaml.dump(self.env_dict, fd)

resolved_env = self._resolve_environment_yml()
else:
resolved_env = {}

self.final_dependencies = {}
for dep in self._direct_dependencies:
if dep[2] == 2:
Expand All @@ -409,13 +423,12 @@ def __init__(self, available_channels, conda_env_yml = None, requirements_txt =
continue
else:
self.final_dependencies[dep[0]] = (resolved_env[dep[0]], dep[2], dep[3])



def _resolve_environment_yml(self):

with open(self.fullenv_file_path) as fd:
logger.info(f'Will resolve environment:\n\n{fd.read()}')

run_command = [str(self.micromamba),
'env', 'create',
'-n', '__temp_env_name',
Expand All @@ -425,7 +438,7 @@ def _resolve_environment_yml(self):
run_proc = sp.run(run_command, capture_output=True, check=True, text=True)
resolved_env = json.loads(run_proc.stdout)['actions']['FETCH']
resolved_env = {x['name']: x['version'] for x in resolved_env}

return resolved_env

def to_xml_tree(self):
Expand All @@ -440,39 +453,51 @@ def to_xml_tree(self):
type='package',
version = det[0]))
reqs_elements[-1].text = name

return reqs_elements

@staticmethod
def _parse_requirements_txt(filepath):
def _parse_requirements_line(line: str) -> tuple | None:
match_spec = re.compile(
r"^(?P<pac>[A-Z0-9][A-Z0-9._-]*[A-Z0-9])\s*(?:\[.*\])?\s*(?P<eq>[~=]{0,2})(?P<uneq>[<>]?=?)\s*(?P<ver>[0-9.\*]*)",
re.I,
)
match_from_url = re.compile(
r"^(?P<pac>[A-Z0-9][A-Z0-9._-]*[A-Z0-9])\s*@(?P<path>.*)", re.I
)

match_spec = re.compile(r'^(?P<pac>[A-Z0-9][A-Z0-9._-]*[A-Z0-9])\s*(?:\[.*\])?\s*(?P<eq>[~=]{0,2})(?P<uneq>[<>]?=?)\s*(?P<ver>[0-9.\*]*)', re.I)
match_from_url = re.compile(r'^(?P<pac>[A-Z0-9][A-Z0-9._-]*[A-Z0-9])\s*@(?P<path>.*)', re.I)

# TODO: basic, see https://pip.pypa.io/en/stable/reference/requirement-specifiers/

with open(filepath, 'r') as fd:
reqs_str_list = []
if line.startswith("#") or re.match(r"^\s*$", line):
return
elif line.startswith("git+"):
logger.warning("Dependency from git repo is not supported: %s", line)
return (line, "", 2, line)
elif match_from_url.match(line):
logger.warning("Dependency from url is not supported %s", line)
return (line, "", 2, line)
else:
m = match_spec.match(line)
if m is None:
logger.warning("Dependency spec not recognised for %s", line)
return (line, "", 2, line)
if m.group("ver"):
ver = (
m.group("uneq") + m.group("ver")
if m.group("uneq")
else m.group("eq") + m.group("ver")
)
else:
ver = ""
return (m.group("pac"), ver, 1, line)

def _parse_requirements_txt(self, filepath: str | os.PathLike) -> list[tuple]:

with open(filepath, "r") as fd:
reqs_str_list = []
for line in fd:
if line.startswith('#') or re.match(r'^\s*$', line):
continue
elif line.startswith('git+'):
logger.warning('Dependency from git repo is not supported: %s', line)
reqs_str_list.append((line, '', 2, line))
elif match_from_url.match(line):
logger.warning('Dependency from url is not supported %s', line)
reqs_str_list.append((line, '', 2, line))
else:
m = match_spec.match(line)
if m is None:
logger.warning('Dependency spec not recognised for %s', line)
reqs_str_list.append((line, '', 2, line))
if m.group('ver'):
ver = m.group('uneq') + m.group('ver') if m.group('uneq') else m.group('eq') + m.group('ver')
else:
ver = ''
reqs_str_list.append((m.group('pac'), ver, 1, line))

parsed = self._parse_requirements_line(line)
if parsed is not None:
reqs_str_list.append(parsed)

return reqs_str_list

@staticmethod
Expand Down Expand Up @@ -540,7 +565,7 @@ def _test_data_location(repo_dir, tool_dir, default_value, base_url=None):
# TODO: support the case of nb not in repo root
location = os.path.join(base_url, default_value)
return value, location

def to_galaxy(input_path,
toolname,
out_dir,
Expand Down Expand Up @@ -724,4 +749,3 @@ def main():

if __name__ == '__main__':
main()

0 comments on commit 6fde915

Please sign in to comment.