From 5205f97478e3c1f9eaeafc13a802c5c96b63a72a Mon Sep 17 00:00:00 2001 From: Stavros Efthymiou <35475381+stavros11@users.noreply.github.com> Date: Wed, 28 Aug 2024 18:49:30 +0300 Subject: [PATCH] refactor: make Sweeper pydantic Model --- src/qibolab/sweeper.py | 30 +++++++++++++++++++++++------- tests/conftest.py | 8 ++++++-- tests/test_sweeper.py | 30 ++++++++++++++++++++++-------- 3 files changed, 51 insertions(+), 17 deletions(-) diff --git a/src/qibolab/sweeper.py b/src/qibolab/sweeper.py index b02ca2c3e7..9736f30caa 100644 --- a/src/qibolab/sweeper.py +++ b/src/qibolab/sweeper.py @@ -1,8 +1,13 @@ -from dataclasses import dataclass from enum import Enum, auto from typing import Optional +import numpy as np import numpy.typing as npt +from pydantic import model_validator + +from .identifier import ChannelId +from .pulses import Pulse +from .serialize import Model class Parameter(Enum): @@ -38,8 +43,7 @@ class Parameter(Enum): } -@dataclass -class Sweeper: +class Sweeper(Model): """Data structure for Sweeper object. This object is passed as an argument to the method :func:`qibolab.platforms.platform.Platform.execute` @@ -77,11 +81,13 @@ class Sweeper: """ parameter: Parameter - values: npt.NDArray - pulses: Optional[list] = None - channels: Optional[list] = None + values: Optional[npt.NDArray] = None + linspace: Optional[tuple[float, float, float]] = None + pulses: Optional[list[Pulse]] = None + channels: Optional[list[ChannelId]] = None - def __post_init__(self): + @model_validator(mode="after") + def check_values(self): if self.pulses is not None and self.channels is not None: raise ValueError( "Cannot create a sweeper by using both pulses and channels." @@ -98,6 +104,16 @@ def __post_init__(self): raise ValueError( "Cannot create a sweeper without specifying pulses or channels." ) + if self.linspace is not None and self.values is not None: + raise ValueError("'linspace' and 'values' are mutually exclusive") + + return self + + @property + def values_array(self) -> npt.NDArray: + if self.linspace is not None: + return np.linspace(*self.linspace) + return self.values ParallelSweepers = list[Sweeper] diff --git a/tests/conftest.py b/tests/conftest.py index 0ace831064..25a5e088d8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -156,10 +156,14 @@ def wrapped( amp_values = np.arange(0.01, 0.06, 0.01) freq_values = np.arange(-4e6, 4e6, 1e6) sweeper1 = Sweeper( - Parameter.bias, amp_values, channels=[qubit.flux.name] + parameter=Parameter.bias, + values=amp_values, + channels=[qubit.flux.name], ) sweeper2 = Sweeper( - Parameter.amplitude, freq_values, pulses=[probe_pulse] + parameter=Parameter.amplitude, + values=freq_values, + pulses=[probe_pulse], ) sweepers = [[sweeper1], [sweeper2]] if target is None: diff --git a/tests/test_sweeper.py b/tests/test_sweeper.py index 848f114112..a3e1229064 100644 --- a/tests/test_sweeper.py +++ b/tests/test_sweeper.py @@ -1,6 +1,7 @@ import numpy as np import pytest +from qibolab.identifier import ChannelId from qibolab.pulses import Pulse, Rectangular from qibolab.sweeper import ChannelParameter, Parameter, Sweeper @@ -20,26 +21,30 @@ def test_sweeper_pulses(parameter): with pytest.raises( ValueError, match="Cannot create a sweeper .* without specifying channels" ): - _ = Sweeper(parameter, parameter_range, pulses=[pulse]) + _ = Sweeper(parameter=parameter, values=parameter_range, pulses=[pulse]) else: - sweeper = Sweeper(parameter, parameter_range, pulses=[pulse]) + sweeper = Sweeper(parameter=parameter, values=parameter_range, pulses=[pulse]) assert sweeper.parameter is parameter @pytest.mark.parametrize("parameter", Parameter) def test_sweeper_channels(parameter): + channel = ChannelId.load("0/probe") parameter_range = np.random.randint(10, size=10) if parameter in ChannelParameter: - sweeper = Sweeper(parameter, parameter_range, channels=["some channel"]) + sweeper = Sweeper( + parameter=parameter, values=parameter_range, channels=[channel] + ) assert sweeper.parameter is parameter else: with pytest.raises( ValueError, match="Cannot create a sweeper .* without specifying pulses" ): - _ = Sweeper(parameter, parameter_range, channels=["canal"]) + _ = Sweeper(parameter=parameter, values=parameter_range, channels=[channel]) def test_sweeper_errors(): + channel = ChannelId.load("0/probe") pulse = Pulse( duration=40, amplitude=0.1, @@ -50,13 +55,22 @@ def test_sweeper_errors(): ValueError, match="Cannot create a sweeper without specifying pulses or channels", ): - Sweeper(Parameter.frequency, parameter_range) + Sweeper(parameter=Parameter.frequency, values=parameter_range) with pytest.raises( ValueError, match="Cannot create a sweeper by using both pulses and channels" ): Sweeper( - Parameter.frequency, - parameter_range, + parameter=Parameter.frequency, + values=parameter_range, pulses=[pulse], - channels=["some channel"], + channels=[channel], + ) + with pytest.raises( + ValueError, match="'linspace' and 'values' are mutually exclusive" + ): + Sweeper( + parameter=Parameter.frequency, + values=parameter_range, + linspace=(0, 10, 1), + channels=[channel], )