Skip to content

Commit

Permalink
Merge pull request #291 from NNPDF/mypy
Browse files Browse the repository at this point in the history
Add mypy
  • Loading branch information
felixhekhorn authored Aug 8, 2024
2 parents 48476f9 + deb41cd commit b8761a5
Show file tree
Hide file tree
Showing 26 changed files with 213 additions and 138 deletions.
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ repos:
args: ["--add-ignore=D107,D105"]
additional_dependencies:
- toml
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.4.1
hooks:
- id: mypy
additional_dependencies: [types-PyYAML]
pass_filenames: false
args: ["--ignore-missing-imports", "src/"]
- repo: local
hooks:
- id: fmt
Expand Down
9 changes: 7 additions & 2 deletions src/eko/couplings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
"""

import logging
from typing import Iterable, List
from typing import Dict, Iterable, List, Tuple

import numba as nb
import numpy as np
import numpy.typing as npt
import scipy

from . import constants, matchings
Expand Down Expand Up @@ -383,6 +384,10 @@ def couplings_expanded_fixed_alphaem(order, couplings_ref, nf, scale_from, scale
return np.array([res_as, aem])


_CouplingsCacheKey = Tuple[float, float, int, float, float]
"""Cache key containing (a0, a1, nf, scale_from, scale_to)."""


class Couplings:
r"""Compute the strong and electromagnetic coupling constants :math:`a_s, a_{em}`.
Expand Down Expand Up @@ -480,7 +485,7 @@ def assert_positive(name, var):
self.decoupled_running,
)
# cache
self.cache = {}
self.cache: Dict[_CouplingsCacheKey, npt.NDArray] = {}

@property
def mu2_ref(self):
Expand Down
21 changes: 13 additions & 8 deletions src/eko/evolution_operator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
import time
from multiprocessing import Pool
from typing import Dict, Tuple

import numba as nb
import numpy as np
Expand All @@ -20,7 +21,7 @@
from .. import basis_rotation as br
from .. import interpolation, mellin
from .. import scale_variations as sv
from ..io.types import EvolutionMethod
from ..io.types import EvolutionMethod, OperatorLabel
from ..kernels import ev_method
from ..kernels import non_singlet as ns
from ..kernels import non_singlet_qed as qed_ns
Expand Down Expand Up @@ -602,6 +603,10 @@ def quad_ker_qed(
return ker


OpMembers = Dict[OperatorLabel, OpMember]
"""Map of all operators."""


class Operator(sv.ScaleVariationModeMixin):
"""Internal representation of a single EKO.
Expand All @@ -627,8 +632,8 @@ class Operator(sv.ScaleVariationModeMixin):

log_label = "Evolution"
# complete list of possible evolution operators labels
full_labels = br.full_labels
full_labels_qed = br.full_unified_labels
full_labels: Tuple[OperatorLabel, ...] = br.full_labels
full_labels_qed: Tuple[OperatorLabel, ...] = br.full_unified_labels

def __init__(
self, config, managers, segment: Segment, mellin_cut=5e-2, is_threshold=False
Expand All @@ -641,9 +646,9 @@ def __init__(
# TODO make 'cut' external parameter?
self._mellin_cut = mellin_cut
self.is_threshold = is_threshold
self.op_members = {}
self.op_members: OpMembers = {}
self.order = tuple(config["order"])
self.alphaem_running = self.managers["couplings"].alphaem_running
self.alphaem_running = self.managers.couplings.alphaem_running
if self.log_label == "Evolution":
self.a = self.compute_a()
self.as_list, self.a_half_list = self.compute_aem_list()
Expand All @@ -665,7 +670,7 @@ def xif2(self):
@property
def int_disp(self):
"""Return the interpolation dispatcher."""
return self.managers["interpol_dispatcher"]
return self.managers.interpolator

@property
def grid_size(self):
Expand All @@ -688,7 +693,7 @@ def mu2(self):

def compute_a(self):
"""Return the computed values for :math:`a_s` and :math:`a_{em}`."""
coupling = self.managers["couplings"]
coupling = self.managers.couplings
a0 = coupling.a(
self.mu2[0],
nf_to=self.nf,
Expand Down Expand Up @@ -724,7 +729,7 @@ def compute_aem_list(self):
as_list = np.array([self.a_s[0], self.a_s[1]])
a_half = np.zeros((ev_op_iterations, 2))
else:
couplings = self.managers["couplings"]
couplings = self.managers.couplings
mu2_steps = np.geomspace(self.q2_from, self.q2_to, 1 + ev_op_iterations)
mu2_l = mu2_steps[0]
as_list = np.array(
Expand Down
20 changes: 10 additions & 10 deletions src/eko/evolution_operator/__init__.py.patch
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
diff --git a/src/eko/evolution_operator/__init__.py b/src/eko/evolution_operator/__init__.py
index ec56b6db..374d0d0b 100644
index fe07ade9..0f58c9e5 100644
--- a/src/eko/evolution_operator/__init__.py
+++ b/src/eko/evolution_operator/__init__.py
@@ -3,15 +3,15 @@ r"""Contains the central operator classes.
@@ -3,16 +3,16 @@ r"""Contains the central operator classes.
See :doc:`Operator overview </code/Operators>`.
"""

Expand All @@ -11,6 +11,7 @@ index ec56b6db..374d0d0b 100644
import os
import time
from multiprocessing import Pool
from typing import Dict, Tuple

+import ekors
import numba as nb
Expand All @@ -20,7 +21,7 @@ index ec56b6db..374d0d0b 100644

import ekore.anomalous_dimensions.polarized.space_like as ad_ps
import ekore.anomalous_dimensions.unpolarized.space_like as ad_us
@@ -29,92 +29,10 @@ from ..kernels import singlet_qed as qed_s
@@ -30,92 +30,10 @@ from ..kernels import singlet_qed as qed_s
from ..kernels import valence_qed as qed_v
from ..matchings import Segment, lepton_number
from ..member import OpMember
Expand Down Expand Up @@ -114,7 +115,7 @@ index ec56b6db..374d0d0b 100644
spec = [
("is_singlet", nb.boolean),
("is_QEDsinglet", nb.boolean),
@@ -186,422 +104,6 @@ class QuadKerBase:
@@ -187,421 +105,6 @@ class QuadKerBase:
return self.path.prefactor * pj * self.path.jac


Expand Down Expand Up @@ -533,11 +534,10 @@ index ec56b6db..374d0d0b 100644
- )
- return ker
-
-
class Operator(sv.ScaleVariationModeMixin):
"""Internal representation of a single EKO.

@@ -787,50 +289,6 @@ class Operator(sv.ScaleVariationModeMixin):
OpMembers = Dict[OperatorLabel, OpMember]
"""Map of all operators."""
@@ -792,50 +295,6 @@ class Operator(sv.ScaleVariationModeMixin):
"""Return the evolution method."""
return ev_method(EvolutionMethod(self.config["method"]))

Expand Down Expand Up @@ -588,7 +588,7 @@ index ec56b6db..374d0d0b 100644
def initialize_op_members(self):
"""Init all operators with the identity or zeros."""
eye = OpMember(
@@ -853,10 +311,7 @@ class Operator(sv.ScaleVariationModeMixin):
@@ -858,10 +317,7 @@ class Operator(sv.ScaleVariationModeMixin):
else:
self.op_members[n] = zero.copy()

Expand All @@ -600,7 +600,7 @@ index ec56b6db..374d0d0b 100644
"""Run the integration for each grid point.

Parameters
@@ -871,18 +326,56 @@ class Operator(sv.ScaleVariationModeMixin):
@@ -876,18 +332,56 @@ class Operator(sv.ScaleVariationModeMixin):
"""
column = []
k, logx = log_grid
Expand Down
38 changes: 23 additions & 15 deletions src/eko/evolution_operator/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
"""

import logging
from dataclasses import astuple
from typing import Dict, List, Optional
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import numpy as np
import numpy.typing as npt
Expand All @@ -18,9 +18,9 @@
from ..interpolation import InterpolatorDispatcher
from ..io.runcards import Configs, Debug
from ..io.types import EvolutionPoint as EPoint
from ..io.types import Order
from ..io.types import Order, SquaredScale
from ..matchings import Atlas, Segment, flavor_shift, is_downward_path
from . import Operator, flavors, matching_condition, physical
from . import Operator, OpMembers, flavors, matching_condition, physical
from .operator_matrix_element import OperatorMatrixElement

logger = logging.getLogger(__name__)
Expand All @@ -29,6 +29,15 @@
"""In particular, only the ``operator`` and ``error`` fields are expected."""


@dataclass(frozen=True)
class Managers:
"""Set of steering objects."""

atlas: Atlas
couplings: Couplings
interpolator: InterpolatorDispatcher


class OperatorGrid(sv.ScaleVariationModeMixin):
"""Collection of evolution operators for several scales.
Expand Down Expand Up @@ -64,7 +73,7 @@ def __init__(
use_fhmruvv: bool,
):
# check
config = {}
config: Dict[str, Any] = {}
config["order"] = order
config["intrinsic_range"] = intrinsic_flavors
config["xif2"] = xif**2
Expand Down Expand Up @@ -95,13 +104,13 @@ def __init__(

self.config = config
self.q2_grid = mu2grid
self.managers = dict(
thresholds_config=atlas,
self.managers = Managers(
atlas=atlas,
couplings=couplings,
interpol_dispatcher=interpol_dispatcher,
interpolator=interpol_dispatcher,
)
self._threshold_operators = {}
self._matching_operators = {}
self._threshold_operators: Dict[Segment, Operator] = {}
self._matching_operators: Dict[SquaredScale, OpMembers] = {}

def get_threshold_operators(self, path: List[Segment]) -> List[Operator]:
"""Generate the threshold operators.
Expand All @@ -123,7 +132,6 @@ def get_threshold_operators(self, path: List[Segment]) -> List[Operator]:
is_downward = is_downward_path(path)
shift = flavor_shift(is_downward)
for seg in path[:-1]:
new_op_key = astuple(seg)
kthr = self.config["thresholds_ratios"][seg.nf - shift]
ome = OperatorMatrixElement(
self.config,
Expand All @@ -134,13 +142,13 @@ def get_threshold_operators(self, path: List[Segment]) -> List[Operator]:
np.log(kthr),
self.config["HQ"] == "MSBAR",
)
if new_op_key not in self._threshold_operators:
if seg not in self._threshold_operators:
# Compute the operator and store it
logger.info("Prepare threshold operator")
op_th = Operator(self.config, self.managers, seg, is_threshold=True)
op_th.compute()
self._threshold_operators[new_op_key] = op_th
thr_ops.append(self._threshold_operators[new_op_key])
self._threshold_operators[seg] = op_th
thr_ops.append(self._threshold_operators[seg])

# Compute the matching conditions and store it
if seg.target not in self._matching_operators:
Expand All @@ -159,7 +167,7 @@ def generate(self, q2: EPoint) -> OpDict:
"""
# The lists of areas as produced by the thresholds
path = self.managers["thresholds_config"].path(q2)
path = self.managers.atlas.path(q2)
# Prepare the path for the composition of the operator
thr_ops = self.get_threshold_operators(path)
# we start composing with the highest operator ...
Expand Down
6 changes: 3 additions & 3 deletions src/eko/evolution_operator/operator_matrix_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ class OperatorMatrixElement(Operator):

log_label = "Matching"
# complete list of possible matching operators labels
full_labels = [
full_labels = (
*br.singlet_labels,
(br.matching_hplus_pid, 21),
(br.matching_hplus_pid, 100),
Expand All @@ -238,7 +238,7 @@ class OperatorMatrixElement(Operator):
(200, br.matching_hminus_pid),
(br.matching_hminus_pid, 200),
(br.matching_hminus_pid, br.matching_hminus_pid),
]
)
# still valid in QED since Sdelta and Vdelta matchings are diagonal
full_labels_qed = copy.deepcopy(full_labels)

Expand Down Expand Up @@ -339,7 +339,7 @@ def a_s(self):
Note that here you need to use :math:`a_s^{n_f+1}`
"""
sc = self.managers["couplings"]
sc = self.managers.couplings
return sc.a_s(
self.q2_from
* (self.xif2 if self.sv_mode == sv.Modes.exponentiated else 1.0),
Expand Down
4 changes: 2 additions & 2 deletions src/eko/io/inventory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
import base64
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Dict, Generic, Optional, Type, TypeVar
from typing import Dict, Generic, Literal, Optional, Type, TypeVar

import yaml

from .access import AccessConfigs
from .items import Header, Operator

NBYTES = 8
ENDIANNESS = "little"
ENDIANNESS: Literal["little", "big"] = "little"

HEADER_EXT = ".yaml"
ARRAY_EXT = [".npy", ".npz"]
Expand Down
Loading

0 comments on commit b8761a5

Please sign in to comment.