Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mypy #291

Merged
merged 29 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
bb523b4
Add mypy
felixhekhorn Jul 21, 2023
103567c
Restrict mypy and start fixing
felixhekhorn Jul 21, 2023
e7b6e93
Start fixing tests for num_flavs_ref=None
felixhekhorn Jul 21, 2023
bbfc166
Apply more mypy fixes
felixhekhorn Jul 21, 2023
2d8d8bf
Apply more mypy fixes 2
felixhekhorn Jul 21, 2023
494da3c
Rename OpMembers
felixhekhorn Jul 21, 2023
48ac90e
Fix mypy in io/legacy
felixhekhorn Jul 21, 2023
bab8cf2
Fix mypy in box/{genpdf,apply}
felixhekhorn Jul 21, 2023
e61c8e6
Fix ekomark/plots
felixhekhorn Jul 24, 2023
e40a718
Fix ekobox/cli
felixhekhorn Jul 24, 2023
9271ebc
Make managers class
felixhekhorn Jul 24, 2023
80b85f1
Use backported return in apply
felixhekhorn Jul 24, 2023
34c68a5
Fix eko/io/legacy masses usage
felixhekhorn Jul 24, 2023
d6d1734
Fix remaining mypy errors
felixhekhorn Jul 24, 2023
6e0f434
Remove init from benchmarks
felixhekhorn Jul 25, 2023
e8c218e
Drop new_op_key variable in grid
felixhekhorn Aug 17, 2023
54d9e53
Define couplings cache key type
felixhekhorn Aug 17, 2023
9b0d804
Define label type in apply
felixhekhorn Aug 17, 2023
3b57600
Cast labels in evop/Operator
felixhekhorn Aug 17, 2023
bbe955c
Update src/eko/runner/operators.py
felixhekhorn Aug 17, 2023
16e3809
Introduce OperatorLabel type
felixhekhorn Aug 17, 2023
c9cdc72
Remove list comprension in msbar
felixhekhorn Aug 17, 2023
c946b3f
Upgrade banana
felixhekhorn Aug 17, 2023
0451697
Merge branch 'master' into mypy
felixhekhorn Jan 12, 2024
a49fca6
Fix evol_pdf
felixhekhorn Jan 12, 2024
4be6d44
Merge branch 'master' into mypy
felixhekhorn Jul 15, 2024
0d13e35
Merge branch 'master' into mypy
felixhekhorn Aug 8, 2024
b651ebe
Fix Rust ev_op patch
felixhekhorn Aug 8, 2024
deb41cd
Merge branch 'master' into mypy
felixhekhorn Aug 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ repos:
args: ["--add-ignore=D107,D105"]
additional_dependencies:
- toml
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.4.1
hooks:
- id: mypy
additional_dependencies: [types-PyYAML]
pass_filenames: false
args: ["--ignore-missing-imports", "src/"]
- repo: local
hooks:
- id: fmt
Expand Down
9 changes: 7 additions & 2 deletions src/eko/couplings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
"""

import logging
from typing import Iterable, List
from typing import Dict, Iterable, List, Tuple

import numba as nb
import numpy as np
import numpy.typing as npt
import scipy

from . import constants, matchings
Expand Down Expand Up @@ -383,6 +384,10 @@ def couplings_expanded_fixed_alphaem(order, couplings_ref, nf, scale_from, scale
return np.array([res_as, aem])


_CouplingsCacheKey = Tuple[float, float, int, float, float]
"""Cache key containing (a0, a1, nf, scale_from, scale_to)."""


class Couplings:
r"""Compute the strong and electromagnetic coupling constants :math:`a_s, a_{em}`.

Expand Down Expand Up @@ -480,7 +485,7 @@ def assert_positive(name, var):
self.decoupled_running,
)
# cache
self.cache = {}
self.cache: Dict[_CouplingsCacheKey, npt.NDArray] = {}

@property
def mu2_ref(self):
Expand Down
21 changes: 13 additions & 8 deletions src/eko/evolution_operator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
import time
from multiprocessing import Pool
from typing import Dict, Tuple

import numba as nb
import numpy as np
Expand All @@ -20,7 +21,7 @@
from .. import basis_rotation as br
from .. import interpolation, mellin
from .. import scale_variations as sv
from ..io.types import EvolutionMethod
from ..io.types import EvolutionMethod, OperatorLabel
from ..kernels import ev_method
from ..kernels import non_singlet as ns
from ..kernels import non_singlet_qed as qed_ns
Expand Down Expand Up @@ -602,6 +603,10 @@ def quad_ker_qed(
return ker


OpMembers = Dict[OperatorLabel, OpMember]
"""Map of all operators."""


class Operator(sv.ScaleVariationModeMixin):
"""Internal representation of a single EKO.

Expand All @@ -627,8 +632,8 @@ class Operator(sv.ScaleVariationModeMixin):

log_label = "Evolution"
# complete list of possible evolution operators labels
full_labels = br.full_labels
full_labels_qed = br.full_unified_labels
full_labels: Tuple[OperatorLabel, ...] = br.full_labels
full_labels_qed: Tuple[OperatorLabel, ...] = br.full_unified_labels

def __init__(
self, config, managers, segment: Segment, mellin_cut=5e-2, is_threshold=False
Expand All @@ -641,9 +646,9 @@ def __init__(
# TODO make 'cut' external parameter?
self._mellin_cut = mellin_cut
self.is_threshold = is_threshold
self.op_members = {}
self.op_members: OpMembers = {}
self.order = tuple(config["order"])
self.alphaem_running = self.managers["couplings"].alphaem_running
self.alphaem_running = self.managers.couplings.alphaem_running
if self.log_label == "Evolution":
self.a = self.compute_a()
self.as_list, self.a_half_list = self.compute_aem_list()
Expand All @@ -665,7 +670,7 @@ def xif2(self):
@property
def int_disp(self):
"""Return the interpolation dispatcher."""
return self.managers["interpol_dispatcher"]
return self.managers.interpolator

@property
def grid_size(self):
Expand All @@ -688,7 +693,7 @@ def mu2(self):

def compute_a(self):
"""Return the computed values for :math:`a_s` and :math:`a_{em}`."""
coupling = self.managers["couplings"]
coupling = self.managers.couplings
a0 = coupling.a(
self.mu2[0],
nf_to=self.nf,
Expand Down Expand Up @@ -724,7 +729,7 @@ def compute_aem_list(self):
as_list = np.array([self.a_s[0], self.a_s[1]])
a_half = np.zeros((ev_op_iterations, 2))
else:
couplings = self.managers["couplings"]
couplings = self.managers.couplings
mu2_steps = np.geomspace(self.q2_from, self.q2_to, 1 + ev_op_iterations)
mu2_l = mu2_steps[0]
as_list = np.array(
Expand Down
20 changes: 10 additions & 10 deletions src/eko/evolution_operator/__init__.py.patch
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
diff --git a/src/eko/evolution_operator/__init__.py b/src/eko/evolution_operator/__init__.py
index ec56b6db..374d0d0b 100644
index fe07ade9..0f58c9e5 100644
--- a/src/eko/evolution_operator/__init__.py
+++ b/src/eko/evolution_operator/__init__.py
@@ -3,15 +3,15 @@ r"""Contains the central operator classes.
@@ -3,16 +3,16 @@ r"""Contains the central operator classes.
See :doc:`Operator overview </code/Operators>`.
"""

Expand All @@ -11,6 +11,7 @@ index ec56b6db..374d0d0b 100644
import os
import time
from multiprocessing import Pool
from typing import Dict, Tuple

+import ekors
import numba as nb
Expand All @@ -20,7 +21,7 @@ index ec56b6db..374d0d0b 100644

import ekore.anomalous_dimensions.polarized.space_like as ad_ps
import ekore.anomalous_dimensions.unpolarized.space_like as ad_us
@@ -29,92 +29,10 @@ from ..kernels import singlet_qed as qed_s
@@ -30,92 +30,10 @@ from ..kernels import singlet_qed as qed_s
from ..kernels import valence_qed as qed_v
from ..matchings import Segment, lepton_number
from ..member import OpMember
Expand Down Expand Up @@ -114,7 +115,7 @@ index ec56b6db..374d0d0b 100644
spec = [
("is_singlet", nb.boolean),
("is_QEDsinglet", nb.boolean),
@@ -186,422 +104,6 @@ class QuadKerBase:
@@ -187,421 +105,6 @@ class QuadKerBase:
return self.path.prefactor * pj * self.path.jac


Expand Down Expand Up @@ -533,11 +534,10 @@ index ec56b6db..374d0d0b 100644
- )
- return ker
-
-
class Operator(sv.ScaleVariationModeMixin):
"""Internal representation of a single EKO.

@@ -787,50 +289,6 @@ class Operator(sv.ScaleVariationModeMixin):
OpMembers = Dict[OperatorLabel, OpMember]
"""Map of all operators."""
@@ -792,50 +295,6 @@ class Operator(sv.ScaleVariationModeMixin):
"""Return the evolution method."""
return ev_method(EvolutionMethod(self.config["method"]))

Expand Down Expand Up @@ -588,7 +588,7 @@ index ec56b6db..374d0d0b 100644
def initialize_op_members(self):
"""Init all operators with the identity or zeros."""
eye = OpMember(
@@ -853,10 +311,7 @@ class Operator(sv.ScaleVariationModeMixin):
@@ -858,10 +317,7 @@ class Operator(sv.ScaleVariationModeMixin):
else:
self.op_members[n] = zero.copy()

Expand All @@ -600,7 +600,7 @@ index ec56b6db..374d0d0b 100644
"""Run the integration for each grid point.

Parameters
@@ -871,18 +326,56 @@ class Operator(sv.ScaleVariationModeMixin):
@@ -876,18 +332,56 @@ class Operator(sv.ScaleVariationModeMixin):
"""
column = []
k, logx = log_grid
Expand Down
38 changes: 23 additions & 15 deletions src/eko/evolution_operator/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
"""

import logging
from dataclasses import astuple
from typing import Dict, List, Optional
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import numpy as np
import numpy.typing as npt
Expand All @@ -18,9 +18,9 @@
from ..interpolation import InterpolatorDispatcher
from ..io.runcards import Configs, Debug
from ..io.types import EvolutionPoint as EPoint
from ..io.types import Order
from ..io.types import Order, SquaredScale
from ..matchings import Atlas, Segment, flavor_shift, is_downward_path
from . import Operator, flavors, matching_condition, physical
from . import Operator, OpMembers, flavors, matching_condition, physical
from .operator_matrix_element import OperatorMatrixElement

logger = logging.getLogger(__name__)
Expand All @@ -29,6 +29,15 @@
"""In particular, only the ``operator`` and ``error`` fields are expected."""


@dataclass(frozen=True)
class Managers:
"""Set of steering objects."""

atlas: Atlas
couplings: Couplings
interpolator: InterpolatorDispatcher


class OperatorGrid(sv.ScaleVariationModeMixin):
"""Collection of evolution operators for several scales.

Expand Down Expand Up @@ -64,7 +73,7 @@ def __init__(
use_fhmruvv: bool,
):
# check
config = {}
config: Dict[str, Any] = {}
config["order"] = order
config["intrinsic_range"] = intrinsic_flavors
config["xif2"] = xif**2
Expand Down Expand Up @@ -95,13 +104,13 @@ def __init__(

self.config = config
self.q2_grid = mu2grid
self.managers = dict(
thresholds_config=atlas,
self.managers = Managers(
atlas=atlas,
couplings=couplings,
interpol_dispatcher=interpol_dispatcher,
interpolator=interpol_dispatcher,
)
self._threshold_operators = {}
self._matching_operators = {}
self._threshold_operators: Dict[Segment, Operator] = {}
self._matching_operators: Dict[SquaredScale, OpMembers] = {}

def get_threshold_operators(self, path: List[Segment]) -> List[Operator]:
"""Generate the threshold operators.
Expand All @@ -123,7 +132,6 @@ def get_threshold_operators(self, path: List[Segment]) -> List[Operator]:
is_downward = is_downward_path(path)
shift = flavor_shift(is_downward)
for seg in path[:-1]:
new_op_key = astuple(seg)
kthr = self.config["thresholds_ratios"][seg.nf - shift]
ome = OperatorMatrixElement(
self.config,
Expand All @@ -134,13 +142,13 @@ def get_threshold_operators(self, path: List[Segment]) -> List[Operator]:
np.log(kthr),
self.config["HQ"] == "MSBAR",
)
if new_op_key not in self._threshold_operators:
if seg not in self._threshold_operators:
# Compute the operator and store it
logger.info("Prepare threshold operator")
op_th = Operator(self.config, self.managers, seg, is_threshold=True)
op_th.compute()
self._threshold_operators[new_op_key] = op_th
thr_ops.append(self._threshold_operators[new_op_key])
self._threshold_operators[seg] = op_th
thr_ops.append(self._threshold_operators[seg])

# Compute the matching conditions and store it
if seg.target not in self._matching_operators:
Expand All @@ -159,7 +167,7 @@ def generate(self, q2: EPoint) -> OpDict:

"""
# The lists of areas as produced by the thresholds
path = self.managers["thresholds_config"].path(q2)
path = self.managers.atlas.path(q2)
# Prepare the path for the composition of the operator
thr_ops = self.get_threshold_operators(path)
# we start composing with the highest operator ...
Expand Down
6 changes: 3 additions & 3 deletions src/eko/evolution_operator/operator_matrix_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ class OperatorMatrixElement(Operator):

log_label = "Matching"
# complete list of possible matching operators labels
full_labels = [
full_labels = (
*br.singlet_labels,
(br.matching_hplus_pid, 21),
(br.matching_hplus_pid, 100),
Expand All @@ -238,7 +238,7 @@ class OperatorMatrixElement(Operator):
(200, br.matching_hminus_pid),
(br.matching_hminus_pid, 200),
(br.matching_hminus_pid, br.matching_hminus_pid),
]
)
# still valid in QED since Sdelta and Vdelta matchings are diagonal
full_labels_qed = copy.deepcopy(full_labels)

Expand Down Expand Up @@ -339,7 +339,7 @@ def a_s(self):

Note that here you need to use :math:`a_s^{n_f+1}`
"""
sc = self.managers["couplings"]
sc = self.managers.couplings
return sc.a_s(
self.q2_from
* (self.xif2 if self.sv_mode == sv.Modes.exponentiated else 1.0),
Expand Down
4 changes: 2 additions & 2 deletions src/eko/io/inventory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
import base64
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Dict, Generic, Optional, Type, TypeVar
from typing import Dict, Generic, Literal, Optional, Type, TypeVar

import yaml

from .access import AccessConfigs
from .items import Header, Operator

NBYTES = 8
ENDIANNESS = "little"
ENDIANNESS: Literal["little", "big"] = "little"

HEADER_EXT = ".yaml"
ARRAY_EXT = [".npy", ".npz"]
Expand Down
Loading
Loading