From db0d13e15d1c561507c32dcf4552db4fafb38290 Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Wed, 31 Jan 2024 11:37:27 +1000 Subject: [PATCH] BREAKING: generic object return for `full_ouput=True` Instead of returning a dict that loses the type of the implied probabilities, we can return an object that is generic on the implied probability return type. This means that as far as mypy is concerned, the following two are equivalent: ```py calculate_implied_probabilities([2.0, 2.0]) calculate_implied_probabilities([2.0, 2.0], full_output=True).implied_probabilities ``` --- python/shin/__init__.py | 63 ++++++++++++++++++++++++++++++---------- tests/test_shin.py | 28 +++++++++--------- typesafety/test_shin.yml | 4 +-- 3 files changed, 63 insertions(+), 32 deletions(-) diff --git a/python/shin/__init__.py b/python/shin/__init__.py index 5d2f05e..ff4de8a 100644 --- a/python/shin/__init__.py +++ b/python/shin/__init__.py @@ -1,12 +1,14 @@ from __future__ import annotations from collections.abc import Mapping, Sequence +from dataclasses import dataclass from math import sqrt -from typing import Any, Literal, TypeVar, overload +from typing import Any, Generic, Literal, TypeAlias, TypeVar, overload from .shin import optimise as _optimise_rust T = TypeVar("T") +OutputT = TypeVar("OutputT", bound="list[float] | dict[Any, float]") def _optimise( @@ -33,6 +35,14 @@ def _optimise( return z, delta, iterations +@dataclass +class FullOutput(Generic[OutputT]): + implied_probabilities: OutputT + iterations: float + delta: float + z: float + + # sequence input, full output False @overload def calculate_implied_probabilities( @@ -59,16 +69,29 @@ def calculate_implied_probabilities( ... -# full output True +# sequence, full output True @overload def calculate_implied_probabilities( - odds: Sequence[float] | Mapping[Any, float], + odds: Sequence[float], *, max_iterations: int = ..., convergence_threshold: float = ..., full_output: Literal[True], force_python_optimiser: bool = ..., -) -> dict[str, Any]: +) -> FullOutput[list[float]]: + ... + + +# mapping, full output True +@overload +def calculate_implied_probabilities( + odds: Mapping[T, float], + *, + max_iterations: int = ..., + convergence_threshold: float = ..., + full_output: Literal[True], + force_python_optimiser: bool = ..., +) -> FullOutput[dict[T, float]]: ... @@ -79,7 +102,7 @@ def calculate_implied_probabilities( convergence_threshold: float = 1e-12, full_output: bool = False, force_python_optimiser: bool = False, -) -> dict[str, Any] | list[float] | dict[T, float]: +) -> FullOutput[list[float]] | FullOutput[dict[T, float]] | list[float] | dict[T, float]: odds_seq = odds.values() if isinstance(odds, Mapping) else odds if len(odds_seq) < 2: @@ -110,19 +133,27 @@ def calculate_implied_probabilities( convergence_threshold=convergence_threshold, ) - p: list[float] | dict[Any, float] = [ + p_gen = ( (sqrt(z**2 + 4 * (1 - z) * io**2 / sum_inverse_odds) - z) / (2 * (1 - z)) for io in inverse_odds - ] + ) if isinstance(odds, Mapping): - p = {k: v for k, v in zip(odds, p)} + d = {k: v for k, v in zip(odds, p_gen)} + if full_output: + return FullOutput( + implied_probabilities=d, + iterations=iterations, + delta=delta, + z=z, + ) + return d + l = list(p_gen) if full_output: - return { - "implied_probabilities": p, - "iterations": iterations, - "delta": delta, - "z": z, - } - else: - return p + return FullOutput( + implied_probabilities=l, + iterations=iterations, + delta=delta, + z=z, + ) + return l diff --git a/tests/test_shin.py b/tests/test_shin.py index 6b2d7a7..29ac4ab 100644 --- a/tests/test_shin.py +++ b/tests/test_shin.py @@ -15,19 +15,19 @@ def test_calculate_implied_probabilities(): result = shin.calculate_implied_probabilities([2.6, 2.4, 4.3], full_output=True) - assert pytest.approx(0.3729941) == result["implied_probabilities"][0] - assert pytest.approx(0.4047794) == result["implied_probabilities"][1] - assert pytest.approx(0.2222265) == result["implied_probabilities"][2] - assert pytest.approx(0.01694251) == result["z"] + assert pytest.approx(0.3729941) == result.implied_probabilities[0] + assert pytest.approx(0.4047794) == result.implied_probabilities[1] + assert pytest.approx(0.2222265) == result.implied_probabilities[2] + assert pytest.approx(0.01694251) == result.z result = shin.calculate_implied_probabilities( [2.6, 2.4, 4.3], full_output=True, force_python_optimiser=True ) - assert pytest.approx(0.3729941) == result["implied_probabilities"][0] - assert pytest.approx(0.4047794) == result["implied_probabilities"][1] - assert pytest.approx(0.2222265) == result["implied_probabilities"][2] - assert pytest.approx(0.01694251) == result["z"] + assert pytest.approx(0.3729941) == result.implied_probabilities[0] + assert pytest.approx(0.4047794) == result.implied_probabilities[1] + assert pytest.approx(0.2222265) == result.implied_probabilities[2] + assert pytest.approx(0.01694251) == result.z result = shin.calculate_implied_probabilities([2.6, 2.4, 4.3]) assert pytest.approx(0.3729941) == result[0] @@ -50,23 +50,23 @@ def test_calculate_implied_probabilities(): "HOME": pytest.approx(0.3729941), "AWAY": pytest.approx(0.4047794), "DRAW": pytest.approx(0.2222265), - } == result["implied_probabilities"] - assert pytest.approx(0.01694251) == result["z"] + } == result.implied_probabilities + assert pytest.approx(0.01694251) == result.z odds = [1.5, 2.74] inverse_odds = [1 / o for o in odds] sum_inverse_odds = sum(inverse_odds) result = shin.calculate_implied_probabilities(odds, full_output=True) - assert result["iterations"] == 0 - assert result["delta"] == 0 + assert result.iterations == 0 + assert result.delta == 0 # With two outcomes, Shin is equivalent to the Additive Method described in Clarke et al. (2017) assert ( pytest.approx(inverse_odds[0] - (sum_inverse_odds - 1) / 2) - == result["implied_probabilities"][0] + == result.implied_probabilities[0] ) assert ( pytest.approx(inverse_odds[1] - (sum_inverse_odds - 1) / 2) - == result["implied_probabilities"][1] + == result.implied_probabilities[1] ) diff --git a/typesafety/test_shin.yml b/typesafety/test_shin.yml index aaa8c37..6f8cf30 100644 --- a/typesafety/test_shin.yml +++ b/typesafety/test_shin.yml @@ -28,7 +28,7 @@ reveal_type(shin.calculate_implied_probabilities([3.0, 3.0, 3.0], full_output=True)) out: | - main:3: note: Revealed type is "builtins.dict[builtins.str, Any]" + main:3: note: Revealed type is "shin.FullOutput[builtins.list[builtins.float]]" - case: test_mapping_input_full_output_overload main: | @@ -36,4 +36,4 @@ reveal_type(shin.calculate_implied_probabilities({1: 3.0, 2: 3.0, 3: 3.0}, full_output=True)) out: | - main:3: note: Revealed type is "builtins.dict[builtins.str, Any]" + main:3: note: Revealed type is "shin.FullOutput[builtins.dict[builtins.int, builtins.float]]"