Skip to content

Commit

Permalink
Merge pull request #417 from qiboteam/update-pulses
Browse files Browse the repository at this point in the history
Pulses: change order, check equality
  • Loading branch information
rodolfocarobene authored May 19, 2023
2 parents 59a7748 + 3eb9861 commit 047f013
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 5 deletions.
49 changes: 48 additions & 1 deletion src/qibolab/pulses.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,10 @@ def modulated_waveforms(self):
modulated_waveform_q.serial = f"Modulated_Waveform_Q(num_samples = {num_samples}, amplitude = {format(pulse.amplitude, '.6f').rstrip('0').rstrip('.')}, shape = {str(pulse.shape)}, frequency = {format(pulse.frequency, '_')}, phase = {format(global_phase + pulse.relative_phase, '.6f').rstrip('0').rstrip('.')})"
return (modulated_waveform_i, modulated_waveform_q)

def __eq__(self, item) -> bool:
"""Overloads == operator"""
return type(item) is self.__class__


class Rectangular(PulseShape):
"""
Expand Down Expand Up @@ -227,6 +231,12 @@ def __init__(self, rel_sigma: float):
self.pulse: Pulse = None
self.rel_sigma: float = float(rel_sigma)

def __eq__(self, item) -> bool:
"""Overloads == operator"""
if super().__eq__(item):
return self.rel_sigma == item.rel_sigma
return False

@property
def envelope_waveform_i(self) -> Waveform:
"""The envelope waveform of the i component of the pulse."""
Expand Down Expand Up @@ -281,6 +291,12 @@ def __init__(self, rel_sigma, beta):
self.rel_sigma = float(rel_sigma)
self.beta = float(beta)

def __eq__(self, item) -> bool:
"""Overloads == operator"""
if super().__eq__(item):
return self.rel_sigma == item.rel_sigma and self.beta == item.beta
return False

@property
def envelope_waveform_i(self) -> Waveform:
"""The envelope waveform of the i component of the pulse."""
Expand Down Expand Up @@ -350,6 +366,12 @@ def __init__(self, b, a, target: PulseShape):
self.b: np.ndarray = np.array(b)
# Check len(a) = len(b) = 2

def __eq__(self, item) -> bool:
"""Overloads == operator"""
if super().__eq__(item):
return self.target == item.target and (self.a == item.a).all() and (self.b == item.b).all()
return False

@property
def pulse(self):
return self._pulse
Expand Down Expand Up @@ -422,6 +444,12 @@ def __init__(self, t_half_flux_pulse=None, b_amplitude=1):
self.t_half_flux_pulse: float = t_half_flux_pulse
self.b_amplitude: float = b_amplitude

def __eq__(self, item) -> bool:
"""Overloads == operator"""
if super().__eq__(item):
return self.t_half_flux_pulse == item.t_half_flux_pulse and self.b_amplitude == item.b_amplitude
return False

@property
def envelope_waveform_i(self) -> Waveform:
"""The envelope waveform of the i component of the pulse."""
Expand Down Expand Up @@ -485,6 +513,12 @@ def __init__(self, alpha: float):
self.pulse: Pulse = None
self.alpha: float = float(alpha)

def __eq__(self, item) -> bool:
"""Overloads == operator"""
if super().__eq__(item):
return self.alpha == item.alpha
return False

@property
def envelope_waveform_i(self) -> Waveform:
if self.pulse:
Expand Down Expand Up @@ -1011,6 +1045,19 @@ def shallow_copy(self): # -> Pulse:
self._qubit,
)

def is_equal_ignoring_start(self, item) -> bool:
"""Check if two pulses are equal ignoring start time"""
return (
self.duration == item.duration
and self.amplitude == item.amplitude
and self.frequency == item.frequency
and self.relative_phase == item.relative_phase
and self.shape == item.shape
and self.channel == item.channel
and self.type == item.type
and self.qubit == item.qubit
)

def plot(self, savefig_filename=None):
"""Plots the pulse envelope and modulated waveforms.
Expand Down Expand Up @@ -1495,7 +1542,7 @@ def add(self, *items):
ps = item
for pulse in ps.pulses:
self.pulses.append(pulse)
self.pulses.sort(key=lambda item: (item.channel, item.start))
self.pulses.sort(key=lambda item: (item.start, item.channel))

def index(self, pulse):
"""Returns the index of a pulse in the sequence."""
Expand Down
82 changes: 78 additions & 4 deletions tests/test_pulses.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,26 @@ def test_pulses_pulse_attributes():
assert p0.finish == 100


def test_pulses_is_equal_ignoring_start():
"""Checks if two pulses are equal, not looking at start time"""

p1 = Pulse(0, 40, 0.9, 0, 0, Rectangular(), 0, PulseType.FLUX, 0)
p2 = Pulse(100, 40, 0.9, 0, 0, Rectangular(), 0, PulseType.FLUX, 0)
p3 = Pulse(0, 40, 0.9, 0, 0, Rectangular(), 0, PulseType.FLUX, 0)
p4 = Pulse(200, 40, 0.9, 0, 0, Rectangular(), 2, PulseType.FLUX, 0)
assert p1.is_equal_ignoring_start(p2)
assert p1.is_equal_ignoring_start(p3)
assert not p1.is_equal_ignoring_start(p4)

p1 = Pulse(0, 40, 0.9, 50e6, 0, Gaussian(5), 0, PulseType.DRIVE, 2)
p2 = Pulse(10, 40, 0.9, 50e6, 0, Gaussian(5), 0, PulseType.DRIVE, 2)
p3 = Pulse(20, 50, 0.8, 50e6, 0, Gaussian(5), 0, PulseType.DRIVE, 2)
p4 = Pulse(30, 40, 0.9, 50e6, 0, Gaussian(4), 0, PulseType.DRIVE, 2)
assert p1.is_equal_ignoring_start(p2)
assert not p1.is_equal_ignoring_start(p3)
assert not p1.is_equal_ignoring_start(p4)


def test_pulses_pulse_serial():
p11 = Pulse(0, 40, 0.9, 50_000_000, 0, Gaussian(5), 0, PulseType.DRIVE)
assert p11.serial == "Pulse(0, 40, 0.9, 50_000_000, 0, Gaussian(5), 0, PulseType.DRIVE, 0)"
Expand Down Expand Up @@ -401,9 +421,9 @@ def test_pulses_pulse_split_pulse():


def test_pulses_pulsesequence_init():
p1 = Pulse(600, 40, 0.9, 100e6, 0, Drag(5, 1), 1, PulseType.DRIVE)
p1 = Pulse(400, 40, 0.9, 100e6, 0, Drag(5, 1), 3, PulseType.DRIVE)
p2 = Pulse(500, 40, 0.9, 100e6, 0, Drag(5, 1), 2, PulseType.DRIVE)
p3 = Pulse(400, 40, 0.9, 100e6, 0, Drag(5, 1), 3, PulseType.DRIVE)
p3 = Pulse(600, 40, 0.9, 100e6, 0, Drag(5, 1), 1, PulseType.DRIVE)

ps = PulseSequence()
assert type(ps) == PulseSequence
Expand Down Expand Up @@ -433,9 +453,9 @@ def test_pulses_pulsesequence_operators():
ps = ps + ReadoutPulse(800, 200, 0.9, 20e6, 0, Rectangular(), 2)
ps = ReadoutPulse(800, 200, 0.9, 20e6, 0, Rectangular(), 3) + ps

p4 = Pulse(300, 40, 0.9, 50e6, 0, Gaussian(5), 1, PulseType.DRIVE)
p4 = Pulse(100, 40, 0.9, 50e6, 0, Gaussian(5), 3, PulseType.DRIVE)
p5 = Pulse(200, 40, 0.9, 50e6, 0, Gaussian(5), 2, PulseType.DRIVE)
p6 = Pulse(100, 40, 0.9, 50e6, 0, Gaussian(5), 3, PulseType.DRIVE)
p6 = Pulse(300, 40, 0.9, 50e6, 0, Gaussian(5), 1, PulseType.DRIVE)

another_ps = PulseSequence()
another_ps.add(p4)
Expand Down Expand Up @@ -810,6 +830,60 @@ def test_pulses_pulseshape_drag():
)


def test_pulses_pulseshape_eq():
"""Checks == operator for pulse shapes"""

shape1 = Rectangular()
shape2 = Rectangular()
shape3 = Gaussian(5)
assert shape1 == shape2
assert not shape1 == shape3

shape1 = Gaussian(4)
shape2 = Gaussian(4)
shape3 = Gaussian(5)
assert shape1 == shape2
assert not shape1 == shape3

shape1 = Drag(4, 0.01)
shape2 = Drag(4, 0.01)
shape3 = Drag(5, 0.01)
shape4 = Drag(4, 0.05)
shape5 = Drag(5, 0.05)
assert shape1 == shape2
assert not shape1 == shape3
assert not shape1 == shape4
assert not shape1 == shape5

shape1 = IIR([-0.5, 2], [1], Rectangular())
shape2 = IIR([-0.5, 2], [1], Rectangular())
shape3 = IIR([-0.5, 4], [1], Rectangular())
shape4 = IIR([-0.4, 2], [1], Rectangular())
shape5 = IIR([-0.5, 2], [2], Rectangular())
shape6 = IIR([-0.5, 2], [2], Gaussian(5))
assert shape1 == shape2
assert not shape1 == shape3
assert not shape1 == shape4
assert not shape1 == shape5
assert not shape1 == shape6

shape1 = SNZ(17, 0.8)
shape2 = SNZ(17, 0.8)
shape3 = SNZ(18, 0.8)
shape4 = SNZ(17, 0.9)
shape5 = SNZ(18, 0.9)
assert shape1 == shape2
assert not shape1 == shape3
assert not shape1 == shape4
assert not shape1 == shape5

shape1 = eCap(4)
shape2 = eCap(4)
shape3 = eCap(5)
assert shape1 == shape2
assert not shape1 == shape3


def test_pulse():
duration = 50
rel_sigma = 5
Expand Down

0 comments on commit 047f013

Please sign in to comment.