Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mypy #291

Merged
merged 29 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
bb523b4
Add mypy
felixhekhorn Jul 21, 2023
103567c
Restrict mypy and start fixing
felixhekhorn Jul 21, 2023
e7b6e93
Start fixing tests for num_flavs_ref=None
felixhekhorn Jul 21, 2023
bbfc166
Apply more mypy fixes
felixhekhorn Jul 21, 2023
2d8d8bf
Apply more mypy fixes 2
felixhekhorn Jul 21, 2023
494da3c
Rename OpMembers
felixhekhorn Jul 21, 2023
48ac90e
Fix mypy in io/legacy
felixhekhorn Jul 21, 2023
bab8cf2
Fix mypy in box/{genpdf,apply}
felixhekhorn Jul 21, 2023
e61c8e6
Fix ekomark/plots
felixhekhorn Jul 24, 2023
e40a718
Fix ekobox/cli
felixhekhorn Jul 24, 2023
9271ebc
Make managers class
felixhekhorn Jul 24, 2023
80b85f1
Use backported return in apply
felixhekhorn Jul 24, 2023
34c68a5
Fix eko/io/legacy masses usage
felixhekhorn Jul 24, 2023
d6d1734
Fix remaining mypy errors
felixhekhorn Jul 24, 2023
6e0f434
Remove init from benchmarks
felixhekhorn Jul 25, 2023
e8c218e
Drop new_op_key variable in grid
felixhekhorn Aug 17, 2023
54d9e53
Define couplings cache key type
felixhekhorn Aug 17, 2023
9b0d804
Define label type in apply
felixhekhorn Aug 17, 2023
3b57600
Cast labels in evop/Operator
felixhekhorn Aug 17, 2023
bbe955c
Update src/eko/runner/operators.py
felixhekhorn Aug 17, 2023
16e3809
Introduce OperatorLabel type
felixhekhorn Aug 17, 2023
c9cdc72
Remove list comprension in msbar
felixhekhorn Aug 17, 2023
c946b3f
Upgrade banana
felixhekhorn Aug 17, 2023
0451697
Merge branch 'master' into mypy
felixhekhorn Jan 12, 2024
a49fca6
Fix evol_pdf
felixhekhorn Jan 12, 2024
4be6d44
Merge branch 'master' into mypy
felixhekhorn Jul 15, 2024
0d13e35
Merge branch 'master' into mypy
felixhekhorn Aug 8, 2024
b651ebe
Fix Rust ev_op patch
felixhekhorn Aug 8, 2024
deb41cd
Merge branch 'master' into mypy
felixhekhorn Aug 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ repos:
args: ["--add-ignore=D107,D105"]
additional_dependencies:
- toml
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.4.1
hooks:
- id: mypy
additional_dependencies: [types-PyYAML]
pass_filenames: false
args: ["--ignore-missing-imports", "src/"]
- repo: local
hooks:
- id: fmt
Expand Down
9 changes: 7 additions & 2 deletions src/eko/couplings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
"""

import logging
from typing import Iterable, List
from typing import Dict, Iterable, List, Tuple

import numba as nb
import numpy as np
import numpy.typing as npt
import scipy

from . import constants, matchings
Expand Down Expand Up @@ -383,6 +384,10 @@ def couplings_expanded_fixed_alphaem(order, couplings_ref, nf, scale_from, scale
return np.array([res_as, aem])


_CouplingsCacheKey = Tuple[float, float, int, float, float]
"""Cache key containing (a0, a1, nf, scale_from, scale_to)."""


class Couplings:
r"""Compute the strong and electromagnetic coupling constants :math:`a_s, a_{em}`.

Expand Down Expand Up @@ -480,7 +485,7 @@ def assert_positive(name, var):
self.decoupled_running,
)
# cache
self.cache = {}
self.cache: Dict[_CouplingsCacheKey, npt.NDArray] = {}

@property
def mu2_ref(self):
Expand Down
21 changes: 13 additions & 8 deletions src/eko/evolution_operator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
import time
from multiprocessing import Pool
from typing import Dict, Tuple

import numba as nb
import numpy as np
Expand All @@ -20,7 +21,7 @@
from .. import basis_rotation as br
from .. import interpolation, mellin
from .. import scale_variations as sv
from ..io.types import EvolutionMethod
from ..io.types import EvolutionMethod, OperatorLabel
from ..kernels import ev_method
from ..kernels import non_singlet as ns
from ..kernels import non_singlet_qed as qed_ns
Expand Down Expand Up @@ -602,6 +603,10 @@ def quad_ker_qed(
return ker


OpMembers = Dict[OperatorLabel, OpMember]
"""Map of all operators."""


class Operator(sv.ScaleVariationModeMixin):
"""Internal representation of a single EKO.

Expand All @@ -627,8 +632,8 @@ class Operator(sv.ScaleVariationModeMixin):

log_label = "Evolution"
# complete list of possible evolution operators labels
full_labels = br.full_labels
full_labels_qed = br.full_unified_labels
full_labels: Tuple[OperatorLabel, ...] = br.full_labels
full_labels_qed: Tuple[OperatorLabel, ...] = br.full_unified_labels

def __init__(
self, config, managers, segment: Segment, mellin_cut=5e-2, is_threshold=False
Expand All @@ -641,9 +646,9 @@ def __init__(
# TODO make 'cut' external parameter?
self._mellin_cut = mellin_cut
self.is_threshold = is_threshold
self.op_members = {}
self.op_members: OpMembers = {}
self.order = tuple(config["order"])
self.alphaem_running = self.managers["couplings"].alphaem_running
self.alphaem_running = self.managers.couplings.alphaem_running
if self.log_label == "Evolution":
self.a = self.compute_a()
self.as_list, self.a_half_list = self.compute_aem_list()
Expand All @@ -665,7 +670,7 @@ def xif2(self):
@property
def int_disp(self):
"""Return the interpolation dispatcher."""
return self.managers["interpol_dispatcher"]
return self.managers.interpolator

@property
def grid_size(self):
Expand All @@ -688,7 +693,7 @@ def mu2(self):

def compute_a(self):
"""Return the computed values for :math:`a_s` and :math:`a_{em}`."""
coupling = self.managers["couplings"]
coupling = self.managers.couplings
a0 = coupling.a(
self.mu2[0],
nf_to=self.nf,
Expand Down Expand Up @@ -724,7 +729,7 @@ def compute_aem_list(self):
as_list = np.array([self.a_s[0], self.a_s[1]])
a_half = np.zeros((ev_op_iterations, 2))
else:
couplings = self.managers["couplings"]
couplings = self.managers.couplings
mu2_steps = np.geomspace(self.q2_from, self.q2_to, 1 + ev_op_iterations)
mu2_l = mu2_steps[0]
as_list = np.array(
Expand Down
38 changes: 23 additions & 15 deletions src/eko/evolution_operator/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
"""

import logging
from dataclasses import astuple
from typing import Dict, List, Optional
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import numpy as np
import numpy.typing as npt
Expand All @@ -18,9 +18,9 @@
from ..interpolation import InterpolatorDispatcher
from ..io.runcards import Configs, Debug
from ..io.types import EvolutionPoint as EPoint
from ..io.types import Order
from ..io.types import Order, SquaredScale
from ..matchings import Atlas, Segment, flavor_shift, is_downward_path
from . import Operator, flavors, matching_condition, physical
from . import Operator, OpMembers, flavors, matching_condition, physical
from .operator_matrix_element import OperatorMatrixElement

logger = logging.getLogger(__name__)
Expand All @@ -29,6 +29,15 @@
"""In particular, only the ``operator`` and ``error`` fields are expected."""


@dataclass(frozen=True)
class Managers:
"""Set of steering objects."""

atlas: Atlas
couplings: Couplings
interpolator: InterpolatorDispatcher


class OperatorGrid(sv.ScaleVariationModeMixin):
"""Collection of evolution operators for several scales.

Expand Down Expand Up @@ -64,7 +73,7 @@ def __init__(
use_fhmruvv: bool,
):
# check
config = {}
config: Dict[str, Any] = {}
config["order"] = order
config["intrinsic_range"] = intrinsic_flavors
config["xif2"] = xif**2
Expand Down Expand Up @@ -95,13 +104,13 @@ def __init__(

self.config = config
self.q2_grid = mu2grid
self.managers = dict(
thresholds_config=atlas,
self.managers = Managers(
atlas=atlas,
couplings=couplings,
interpol_dispatcher=interpol_dispatcher,
interpolator=interpol_dispatcher,
)
self._threshold_operators = {}
self._matching_operators = {}
self._threshold_operators: Dict[Segment, Operator] = {}
self._matching_operators: Dict[SquaredScale, OpMembers] = {}

def get_threshold_operators(self, path: List[Segment]) -> List[Operator]:
"""Generate the threshold operators.
Expand All @@ -123,7 +132,6 @@ def get_threshold_operators(self, path: List[Segment]) -> List[Operator]:
is_downward = is_downward_path(path)
shift = flavor_shift(is_downward)
for seg in path[:-1]:
new_op_key = astuple(seg)
kthr = self.config["thresholds_ratios"][seg.nf - shift]
ome = OperatorMatrixElement(
self.config,
Expand All @@ -134,13 +142,13 @@ def get_threshold_operators(self, path: List[Segment]) -> List[Operator]:
np.log(kthr),
self.config["HQ"] == "MSBAR",
)
if new_op_key not in self._threshold_operators:
if seg not in self._threshold_operators:
# Compute the operator and store it
logger.info("Prepare threshold operator")
op_th = Operator(self.config, self.managers, seg, is_threshold=True)
op_th.compute()
self._threshold_operators[new_op_key] = op_th
thr_ops.append(self._threshold_operators[new_op_key])
self._threshold_operators[seg] = op_th
thr_ops.append(self._threshold_operators[seg])

# Compute the matching conditions and store it
if seg.target not in self._matching_operators:
Expand All @@ -159,7 +167,7 @@ def generate(self, q2: EPoint) -> OpDict:

"""
# The lists of areas as produced by the thresholds
path = self.managers["thresholds_config"].path(q2)
path = self.managers.atlas.path(q2)
# Prepare the path for the composition of the operator
thr_ops = self.get_threshold_operators(path)
# we start composing with the highest operator ...
Expand Down
6 changes: 3 additions & 3 deletions src/eko/evolution_operator/operator_matrix_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ class OperatorMatrixElement(Operator):

log_label = "Matching"
# complete list of possible matching operators labels
full_labels = [
full_labels = (
*br.singlet_labels,
(br.matching_hplus_pid, 21),
(br.matching_hplus_pid, 100),
Expand All @@ -238,7 +238,7 @@ class OperatorMatrixElement(Operator):
(200, br.matching_hminus_pid),
(br.matching_hminus_pid, 200),
(br.matching_hminus_pid, br.matching_hminus_pid),
]
)
# still valid in QED since Sdelta and Vdelta matchings are diagonal
full_labels_qed = copy.deepcopy(full_labels)

Expand Down Expand Up @@ -339,7 +339,7 @@ def a_s(self):

Note that here you need to use :math:`a_s^{n_f+1}`
"""
sc = self.managers["couplings"]
sc = self.managers.couplings
return sc.a_s(
self.q2_from
* (self.xif2 if self.sv_mode == sv.Modes.exponentiated else 1.0),
Expand Down
4 changes: 2 additions & 2 deletions src/eko/io/inventory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
import base64
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Dict, Generic, Optional, Type, TypeVar
from typing import Dict, Generic, Literal, Optional, Type, TypeVar

import yaml

from .access import AccessConfigs
from .items import Header, Operator

NBYTES = 8
ENDIANNESS = "little"
ENDIANNESS: Literal["little", "big"] = "little"

HEADER_EXT = ".yaml"
ARRAY_EXT = [".npy", ".npz"]
Expand Down
40 changes: 27 additions & 13 deletions src/eko/io/legacy.py
felixhekhorn marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,23 @@
import numpy as np
import yaml

from eko.interpolation import XGrid
from eko.io.runcards import flavored_mugrid
from eko.quantities.heavy_quarks import HeavyInfo, HeavyQuarkMasses, MatchingRatios

from ..interpolation import XGrid
from ..io.runcards import flavored_mugrid
from ..quantities.heavy_quarks import (
HeavyInfo,
HeavyQuarkMasses,
MatchingRatios,
QuarkMassScheme,
)
from . import raw
from .dictlike import DictLike
from .struct import EKO, Operator
from .types import EvolutionPoint as EPoint
from .types import RawCard
from .types import RawCard, ReferenceRunning

_MC = 1.51
_MB = 4.92
_MT = 172.5


def load_tar(source: os.PathLike, dest: os.PathLike, errors: bool = False):
Expand All @@ -39,8 +47,8 @@ def load_tar(source: os.PathLike, dest: os.PathLike, errors: bool = False):
whether to load also errors (default ``False``)

"""
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir = pathlib.Path(tmpdir)
with tempfile.TemporaryDirectory() as tmpdirr:
tmpdir = pathlib.Path(tmpdirr)

with tarfile.open(source, "r") as tar:
raw.safe_extractall(tar, tmpdir)
Expand All @@ -60,7 +68,7 @@ def load_tar(source: os.PathLike, dest: os.PathLike, errors: bool = False):
if op5 is None:
op5 = metaold["mu2grid"]
grid = op5to4(
flavored_mugrid(op5, theory.heavy.masses, theory.heavy.matching_ratios), arrays
flavored_mugrid(op5, [_MC, _MB, _MT], theory.heavy.matching_ratios), arrays
)

with EKO.create(dest) as builder:
Expand Down Expand Up @@ -93,10 +101,16 @@ def from_old(cls, old: RawCard):
"""Load from old metadata."""
heavy = HeavyInfo(
num_flavs_init=4,
num_flavs_max_pdf=None,
intrinsic_flavors=None,
masses=HeavyQuarkMasses([1.51, 4.92, 172.5]),
masses_scheme=None,
num_flavs_max_pdf=5,
felixhekhorn marked this conversation as resolved.
Show resolved Hide resolved
intrinsic_flavors=[],
masses=HeavyQuarkMasses(
[
ReferenceRunning([_MC, np.inf]),
ReferenceRunning([_MB, np.inf]),
ReferenceRunning([_MT, np.inf]),
]
),
masses_scheme=QuarkMassScheme.POLE,
matching_ratios=MatchingRatios([1.0, 1.0, 1.0]),
)
return cls(heavy=heavy)
Expand Down Expand Up @@ -125,7 +139,7 @@ def from_old(cls, old: RawCard):
mu2list = old["mu2grid"]
mu2grid = np.array(mu2list)
evolgrid = flavored_mugrid(
np.sqrt(mu2grid).tolist(), [1.51, 4.92, 172.5], [1, 1, 1]
np.sqrt(mu2grid).tolist(), [_MC, _MB, _MT], [1, 1, 1]
)

xgrid = XGrid(old["interpolation_xgrid"])
Expand Down
28 changes: 15 additions & 13 deletions src/eko/io/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,15 +273,17 @@ def approx(
Raises
------
ValueError
if multiple values are find in the neighbourhood
if multiple values are found in the neighbourhood

"""
eps = np.array([ep_ for ep_ in self if ep_[1] == ep[1]])
mu2s = np.array([mu2 for mu2, _ in eps])
close = eps[np.isclose(ep[0], mu2s, rtol=rtol, atol=atol)]

if len(close) == 1:
return tuple(close[0])
found = close[0]
assert isinstance(found[0], float)
return (found[0], int(found[1]))
if len(close) == 0:
return None
raise ValueError(f"Multiple values of Q2 have been found close to {ep}")
Expand Down Expand Up @@ -374,17 +376,17 @@ def open(cls, path: os.PathLike, mode="r"):
raise ValueError(f"Unknown file mode: {mode}")

tmpdir = pathlib.Path(tempfile.mkdtemp(prefix=TEMP_PREFIX))
if load:
cls.load(path, tmpdir)
metadata = Metadata.load(tmpdir)
opened = cls(
**inventories(tmpdir, access),
metadata=metadata,
access=access,
)
opened.operators.sync()
else:
opened = Builder(path=tmpdir, access=access)
if not load:
return Builder(path=tmpdir, access=access)
# load existing instead
cls.load(path, tmpdir)
metadata = Metadata.load(tmpdir)
opened: EKO = cls(
**inventories(tmpdir, access),
metadata=metadata,
access=access,
)
opened.operators.sync()

return opened

Expand Down
Loading
Loading