Skip to content

Commit

Permalink
refactor: make Sweeper pydantic Model
Browse files Browse the repository at this point in the history
  • Loading branch information
stavros11 committed Aug 28, 2024
1 parent fea2969 commit 5205f97
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 17 deletions.
30 changes: 23 additions & 7 deletions src/qibolab/sweeper.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional

import numpy as np
import numpy.typing as npt
from pydantic import model_validator

from .identifier import ChannelId
from .pulses import Pulse
from .serialize import Model


class Parameter(Enum):
Expand Down Expand Up @@ -38,8 +43,7 @@ class Parameter(Enum):
}


@dataclass
class Sweeper:
class Sweeper(Model):
"""Data structure for Sweeper object.
This object is passed as an argument to the method :func:`qibolab.platforms.platform.Platform.execute`
Expand Down Expand Up @@ -77,11 +81,13 @@ class Sweeper:
"""

parameter: Parameter
values: npt.NDArray
pulses: Optional[list] = None
channels: Optional[list] = None
values: Optional[npt.NDArray] = None
linspace: Optional[tuple[float, float, float]] = None
pulses: Optional[list[Pulse]] = None
channels: Optional[list[ChannelId]] = None

def __post_init__(self):
@model_validator(mode="after")
def check_values(self):
if self.pulses is not None and self.channels is not None:
raise ValueError(
"Cannot create a sweeper by using both pulses and channels."
Expand All @@ -98,6 +104,16 @@ def __post_init__(self):
raise ValueError(
"Cannot create a sweeper without specifying pulses or channels."
)
if self.linspace is not None and self.values is not None:
raise ValueError("'linspace' and 'values' are mutually exclusive")

return self

@property
def values_array(self) -> npt.NDArray:
if self.linspace is not None:
return np.linspace(*self.linspace)
return self.values


ParallelSweepers = list[Sweeper]
Expand Down
8 changes: 6 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,14 @@ def wrapped(
amp_values = np.arange(0.01, 0.06, 0.01)
freq_values = np.arange(-4e6, 4e6, 1e6)
sweeper1 = Sweeper(
Parameter.bias, amp_values, channels=[qubit.flux.name]
parameter=Parameter.bias,
values=amp_values,
channels=[qubit.flux.name],
)
sweeper2 = Sweeper(
Parameter.amplitude, freq_values, pulses=[probe_pulse]
parameter=Parameter.amplitude,
values=freq_values,
pulses=[probe_pulse],
)
sweepers = [[sweeper1], [sweeper2]]
if target is None:
Expand Down
30 changes: 22 additions & 8 deletions tests/test_sweeper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pytest

from qibolab.identifier import ChannelId
from qibolab.pulses import Pulse, Rectangular
from qibolab.sweeper import ChannelParameter, Parameter, Sweeper

Expand All @@ -20,26 +21,30 @@ def test_sweeper_pulses(parameter):
with pytest.raises(
ValueError, match="Cannot create a sweeper .* without specifying channels"
):
_ = Sweeper(parameter, parameter_range, pulses=[pulse])
_ = Sweeper(parameter=parameter, values=parameter_range, pulses=[pulse])
else:
sweeper = Sweeper(parameter, parameter_range, pulses=[pulse])
sweeper = Sweeper(parameter=parameter, values=parameter_range, pulses=[pulse])
assert sweeper.parameter is parameter


@pytest.mark.parametrize("parameter", Parameter)
def test_sweeper_channels(parameter):
channel = ChannelId.load("0/probe")
parameter_range = np.random.randint(10, size=10)
if parameter in ChannelParameter:
sweeper = Sweeper(parameter, parameter_range, channels=["some channel"])
sweeper = Sweeper(
parameter=parameter, values=parameter_range, channels=[channel]
)
assert sweeper.parameter is parameter
else:
with pytest.raises(
ValueError, match="Cannot create a sweeper .* without specifying pulses"
):
_ = Sweeper(parameter, parameter_range, channels=["canal"])
_ = Sweeper(parameter=parameter, values=parameter_range, channels=[channel])


def test_sweeper_errors():
channel = ChannelId.load("0/probe")
pulse = Pulse(
duration=40,
amplitude=0.1,
Expand All @@ -50,13 +55,22 @@ def test_sweeper_errors():
ValueError,
match="Cannot create a sweeper without specifying pulses or channels",
):
Sweeper(Parameter.frequency, parameter_range)
Sweeper(parameter=Parameter.frequency, values=parameter_range)
with pytest.raises(
ValueError, match="Cannot create a sweeper by using both pulses and channels"
):
Sweeper(
Parameter.frequency,
parameter_range,
parameter=Parameter.frequency,
values=parameter_range,
pulses=[pulse],
channels=["some channel"],
channels=[channel],
)
with pytest.raises(
ValueError, match="'linspace' and 'values' are mutually exclusive"
):
Sweeper(
parameter=Parameter.frequency,
values=parameter_range,
linspace=(0, 10, 1),
channels=[channel],
)

0 comments on commit 5205f97

Please sign in to comment.