From 3524a33a718edd08f4e978c421f09a95f142074e Mon Sep 17 00:00:00 2001 From: rodolfocarobene Date: Wed, 17 May 2023 09:49:10 +0400 Subject: [PATCH 1/8] equality --- src/qibolab/pulses.py | 49 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/src/qibolab/pulses.py b/src/qibolab/pulses.py index dc468b298..d628d8d16 100644 --- a/src/qibolab/pulses.py +++ b/src/qibolab/pulses.py @@ -178,6 +178,10 @@ def __init__(self): self.name = "Rectangular" self.pulse: Pulse = None + def __eq__(self, item) -> bool: + """Overloads == operator""" + return type(item) is Rectangular + @property def envelope_waveform_i(self) -> Waveform: """The envelope waveform of the i component of the pulse.""" @@ -227,6 +231,12 @@ def __init__(self, rel_sigma: float): self.pulse: Pulse = None self.rel_sigma: float = float(rel_sigma) + def __eq__(self, item) -> bool: + """Overloads == operator""" + if type(item) is Gaussian: + return self.rel_sigma == item.rel_sigma + return False + @property def envelope_waveform_i(self) -> Waveform: """The envelope waveform of the i component of the pulse.""" @@ -281,6 +291,12 @@ def __init__(self, rel_sigma, beta): self.rel_sigma = float(rel_sigma) self.beta = float(beta) + def __eq__(self, item) -> bool: + """Overloads == operator""" + if type(item) is Drag: + return self.rel_sigma == item.rel_sigma and self.beta == item.beta + return False + @property def envelope_waveform_i(self) -> Waveform: """The envelope waveform of the i component of the pulse.""" @@ -350,6 +366,12 @@ def __init__(self, b, a, target: PulseShape): self.b: np.ndarray = np.array(b) # Check len(a) = len(b) = 2 + def __eq__(self, item) -> bool: + """Overloads == operator""" + if type(item) is IIR: + return self.target == item.target and self.a == item.a and self.b == target.b + return False + @property def pulse(self): return self._pulse @@ -422,6 +444,12 @@ def __init__(self, t_half_flux_pulse=None, b_amplitude=1): self.t_half_flux_pulse: float = t_half_flux_pulse self.b_amplitude: float = b_amplitude + def __eq__(self, item) -> bool: + """Overloads == operator""" + if type(item) is SNZ: + return self.t_half_flux_pulse == item.t_half_flux_pulse + return False + @property def envelope_waveform_i(self) -> Waveform: """The envelope waveform of the i component of the pulse.""" @@ -485,6 +513,12 @@ def __init__(self, alpha: float): self.pulse: Pulse = None self.alpha: float = float(alpha) + def __eq__(self, item) -> bool: + """Overloads == operator""" + if type(item) is eCap: + return self.alpha == item.alpha + return False + @property def envelope_waveform_i(self) -> Waveform: if self.pulse: @@ -1011,6 +1045,19 @@ def shallow_copy(self): # -> Pulse: self._qubit, ) + def is_equal(self, item: Pulse) -> bool: + """Check if two pulses are equal, excepto from the start time""" + return ( + self.duration == item.duration + and self.amplitude == item.amplitude + and self.frequency == item.frequency + and self.relative_phase == item.relative_phase + and self.shape == item.shape + and self.channel == item.channel + and self.type == item.type + and self.qubit == item.qubit + ) + def plot(self, savefig_filename=None): """Plots the pulse envelope and modulated waveforms. @@ -1495,7 +1542,7 @@ def add(self, *items): ps = item for pulse in ps.pulses: self.pulses.append(pulse) - self.pulses.sort(key=lambda item: (item.channel, item.start)) + self.pulses.sort(key=lambda item: (item.start, item.channel)) def index(self, pulse): """Returns the index of a pulse in the sequence.""" From 89da91bbd75ca47c2469198f87d1bcbadafcff78 Mon Sep 17 00:00:00 2001 From: rodolfocarobene Date: Wed, 17 May 2023 09:52:24 +0400 Subject: [PATCH 2/8] fix --- src/qibolab/pulses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/qibolab/pulses.py b/src/qibolab/pulses.py index d628d8d16..137c38502 100644 --- a/src/qibolab/pulses.py +++ b/src/qibolab/pulses.py @@ -1045,7 +1045,7 @@ def shallow_copy(self): # -> Pulse: self._qubit, ) - def is_equal(self, item: Pulse) -> bool: + def is_equal(self, item) -> bool: """Check if two pulses are equal, excepto from the start time""" return ( self.duration == item.duration From 222fae7590b8c58a41fd33add22e36c963edc2d1 Mon Sep 17 00:00:00 2001 From: rodolfocarobene Date: Wed, 17 May 2023 10:21:10 +0400 Subject: [PATCH 3/8] fix tests --- tests/test_pulses.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_pulses.py b/tests/test_pulses.py index 14884c120..240633763 100644 --- a/tests/test_pulses.py +++ b/tests/test_pulses.py @@ -401,9 +401,9 @@ def test_pulses_pulse_split_pulse(): def test_pulses_pulsesequence_init(): - p1 = Pulse(600, 40, 0.9, 100e6, 0, Drag(5, 1), 1, PulseType.DRIVE) + p1 = Pulse(400, 40, 0.9, 100e6, 0, Drag(5, 1), 3, PulseType.DRIVE) p2 = Pulse(500, 40, 0.9, 100e6, 0, Drag(5, 1), 2, PulseType.DRIVE) - p3 = Pulse(400, 40, 0.9, 100e6, 0, Drag(5, 1), 3, PulseType.DRIVE) + p3 = Pulse(600, 40, 0.9, 100e6, 0, Drag(5, 1), 1, PulseType.DRIVE) ps = PulseSequence() assert type(ps) == PulseSequence @@ -433,9 +433,9 @@ def test_pulses_pulsesequence_operators(): ps = ps + ReadoutPulse(800, 200, 0.9, 20e6, 0, Rectangular(), 2) ps = ReadoutPulse(800, 200, 0.9, 20e6, 0, Rectangular(), 3) + ps - p4 = Pulse(300, 40, 0.9, 50e6, 0, Gaussian(5), 1, PulseType.DRIVE) + p4 = Pulse(100, 40, 0.9, 50e6, 0, Gaussian(5), 3, PulseType.DRIVE) p5 = Pulse(200, 40, 0.9, 50e6, 0, Gaussian(5), 2, PulseType.DRIVE) - p6 = Pulse(100, 40, 0.9, 50e6, 0, Gaussian(5), 3, PulseType.DRIVE) + p6 = Pulse(300, 40, 0.9, 50e6, 0, Gaussian(5), 1, PulseType.DRIVE) another_ps = PulseSequence() another_ps.add(p4) From 5bad9aa828fafef584d2a2bd6c31b942202a5588 Mon Sep 17 00:00:00 2001 From: rodolfocarobene Date: Wed, 17 May 2023 10:35:11 +0400 Subject: [PATCH 4/8] fix lint --- src/qibolab/pulses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/qibolab/pulses.py b/src/qibolab/pulses.py index 137c38502..3897f54fb 100644 --- a/src/qibolab/pulses.py +++ b/src/qibolab/pulses.py @@ -369,7 +369,7 @@ def __init__(self, b, a, target: PulseShape): def __eq__(self, item) -> bool: """Overloads == operator""" if type(item) is IIR: - return self.target == item.target and self.a == item.a and self.b == target.b + return self.target == item.target and self.a == item.a and self.b == item.b return False @property From 3df598fd893eebb7bf7c9ab96a011ed21152f742 Mon Sep 17 00:00:00 2001 From: rodolfocarobene Date: Wed, 17 May 2023 13:52:42 +0400 Subject: [PATCH 5/8] add tests --- src/qibolab/pulses.py | 4 +-- tests/test_pulses.py | 74 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 2 deletions(-) diff --git a/src/qibolab/pulses.py b/src/qibolab/pulses.py index 3897f54fb..4c7365c31 100644 --- a/src/qibolab/pulses.py +++ b/src/qibolab/pulses.py @@ -369,7 +369,7 @@ def __init__(self, b, a, target: PulseShape): def __eq__(self, item) -> bool: """Overloads == operator""" if type(item) is IIR: - return self.target == item.target and self.a == item.a and self.b == item.b + return self.target == item.target and (self.a == item.a).all() and (self.b == item.b).all() return False @property @@ -447,7 +447,7 @@ def __init__(self, t_half_flux_pulse=None, b_amplitude=1): def __eq__(self, item) -> bool: """Overloads == operator""" if type(item) is SNZ: - return self.t_half_flux_pulse == item.t_half_flux_pulse + return self.t_half_flux_pulse == item.t_half_flux_pulse and self.b_amplitude == item.b_amplitude return False @property diff --git a/tests/test_pulses.py b/tests/test_pulses.py index 240633763..21f235858 100644 --- a/tests/test_pulses.py +++ b/tests/test_pulses.py @@ -275,6 +275,26 @@ def test_pulses_pulse_attributes(): assert p0.finish == 100 +def test_pulses_is_equal(): + """Checks if two pulses are equal, not looking at start time""" + + p1 = Pulse(0, 40, 0.9, 0, 0, Rectangular(), 0, PulseType.FLUX, 0) + p2 = Pulse(100, 40, 0.9, 0, 0, Rectangular(), 0, PulseType.FLUX, 0) + p3 = Pulse(0, 40, 0.9, 0, 0, Rectangular(), 0, PulseType.FLUX, 0) + p4 = Pulse(200, 40, 0.9, 0, 0, Rectangular(), 2, PulseType.FLUX, 0) + assert p1.is_equal(p2) + assert p1.is_equal(p3) + assert not p1.is_equal(p4) + + p1 = Pulse(0, 40, 0.9, 50e6, 0, Gaussian(5), 0, PulseType.DRIVE, 2) + p2 = Pulse(10, 40, 0.9, 50e6, 0, Gaussian(5), 0, PulseType.DRIVE, 2) + p3 = Pulse(20, 50, 0.8, 50e6, 0, Gaussian(5), 0, PulseType.DRIVE, 2) + p4 = Pulse(30, 40, 0.9, 50e6, 0, Gaussian(4), 0, PulseType.DRIVE, 2) + assert p1.is_equal(p2) + assert not p1.is_equal(p3) + assert not p1.is_equal(p4) + + def test_pulses_pulse_serial(): p11 = Pulse(0, 40, 0.9, 50_000_000, 0, Gaussian(5), 0, PulseType.DRIVE) assert p11.serial == "Pulse(0, 40, 0.9, 50_000_000, 0, Gaussian(5), 0, PulseType.DRIVE, 0)" @@ -810,6 +830,60 @@ def test_pulses_pulseshape_drag(): ) +def test_pulses_pulseshape_eq(): + """Checks == operator for pulse shapes""" + + shape1 = Rectangular() + shape2 = Rectangular() + shape3 = Gaussian(5) + assert shape1 == shape2 + assert not shape1 == shape3 + + shape1 = Gaussian(4) + shape2 = Gaussian(4) + shape3 = Gaussian(5) + assert shape1 == shape2 + assert not shape1 == shape3 + + shape1 = Drag(4, 0.01) + shape2 = Drag(4, 0.01) + shape3 = Drag(5, 0.01) + shape4 = Drag(4, 0.05) + shape5 = Drag(5, 0.05) + assert shape1 == shape2 + assert not shape1 == shape3 + assert not shape1 == shape4 + assert not shape1 == shape5 + + shape1 = IIR([-0.5, 2], [1], Rectangular()) + shape2 = IIR([-0.5, 2], [1], Rectangular()) + shape3 = IIR([-0.5, 4], [1], Rectangular()) + shape4 = IIR([-0.4, 2], [1], Rectangular()) + shape5 = IIR([-0.5, 2], [2], Rectangular()) + shape6 = IIR([-0.5, 2], [2], Gaussian(5)) + assert shape1 == shape2 + assert not shape1 == shape3 + assert not shape1 == shape4 + assert not shape1 == shape5 + assert not shape1 == shape6 + + shape1 = SNZ(17, 0.8) + shape2 = SNZ(17, 0.8) + shape3 = SNZ(18, 0.8) + shape4 = SNZ(17, 0.9) + shape5 = SNZ(18, 0.9) + assert shape1 == shape2 + assert not shape1 == shape3 + assert not shape1 == shape4 + assert not shape1 == shape5 + + shape1 = eCap(4) + shape2 = eCap(4) + shape3 = eCap(5) + assert shape1 == shape2 + assert not shape1 == shape3 + + def test_pulse(): duration = 50 rel_sigma = 5 From 0ae8e601df4a8b59cb2fa34afd26b1bf3424ec11 Mon Sep 17 00:00:00 2001 From: rodolfocarobene Date: Wed, 17 May 2023 16:21:08 +0400 Subject: [PATCH 6/8] better function name --- src/qibolab/pulses.py | 2 +- tests/test_pulses.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/qibolab/pulses.py b/src/qibolab/pulses.py index 4c7365c31..8b87fe386 100644 --- a/src/qibolab/pulses.py +++ b/src/qibolab/pulses.py @@ -1045,7 +1045,7 @@ def shallow_copy(self): # -> Pulse: self._qubit, ) - def is_equal(self, item) -> bool: + def is_equal_ignoring_start(self, item) -> bool: """Check if two pulses are equal, excepto from the start time""" return ( self.duration == item.duration diff --git a/tests/test_pulses.py b/tests/test_pulses.py index 21f235858..49a5509cb 100644 --- a/tests/test_pulses.py +++ b/tests/test_pulses.py @@ -275,24 +275,24 @@ def test_pulses_pulse_attributes(): assert p0.finish == 100 -def test_pulses_is_equal(): +def test_pulses_is_equal_ignoring_start(): """Checks if two pulses are equal, not looking at start time""" p1 = Pulse(0, 40, 0.9, 0, 0, Rectangular(), 0, PulseType.FLUX, 0) p2 = Pulse(100, 40, 0.9, 0, 0, Rectangular(), 0, PulseType.FLUX, 0) p3 = Pulse(0, 40, 0.9, 0, 0, Rectangular(), 0, PulseType.FLUX, 0) p4 = Pulse(200, 40, 0.9, 0, 0, Rectangular(), 2, PulseType.FLUX, 0) - assert p1.is_equal(p2) - assert p1.is_equal(p3) - assert not p1.is_equal(p4) + assert p1.is_equal_ignoring_start(p2) + assert p1.is_equal_ignoring_start(p3) + assert not p1.is_equal_ignoring_start(p4) p1 = Pulse(0, 40, 0.9, 50e6, 0, Gaussian(5), 0, PulseType.DRIVE, 2) p2 = Pulse(10, 40, 0.9, 50e6, 0, Gaussian(5), 0, PulseType.DRIVE, 2) p3 = Pulse(20, 50, 0.8, 50e6, 0, Gaussian(5), 0, PulseType.DRIVE, 2) p4 = Pulse(30, 40, 0.9, 50e6, 0, Gaussian(4), 0, PulseType.DRIVE, 2) - assert p1.is_equal(p2) - assert not p1.is_equal(p3) - assert not p1.is_equal(p4) + assert p1.is_equal_ignoring_start(p2) + assert not p1.is_equal_ignoring_start(p3) + assert not p1.is_equal_ignoring_start(p4) def test_pulses_pulse_serial(): From b39341467e6dcf3dd439e8d71a70616d9a72a797 Mon Sep 17 00:00:00 2001 From: rodolfocarobene Date: Wed, 17 May 2023 16:24:14 +0400 Subject: [PATCH 7/8] move class __eq__ to abstract --- src/qibolab/pulses.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/qibolab/pulses.py b/src/qibolab/pulses.py index 8b87fe386..7a3480ba5 100644 --- a/src/qibolab/pulses.py +++ b/src/qibolab/pulses.py @@ -167,6 +167,10 @@ def modulated_waveforms(self): modulated_waveform_q.serial = f"Modulated_Waveform_Q(num_samples = {num_samples}, amplitude = {format(pulse.amplitude, '.6f').rstrip('0').rstrip('.')}, shape = {str(pulse.shape)}, frequency = {format(pulse.frequency, '_')}, phase = {format(global_phase + pulse.relative_phase, '.6f').rstrip('0').rstrip('.')})" return (modulated_waveform_i, modulated_waveform_q) + def __eq__(self, item) -> bool: + """Overloads == operator""" + return type(item) is self.__class__ + class Rectangular(PulseShape): """ @@ -178,10 +182,6 @@ def __init__(self): self.name = "Rectangular" self.pulse: Pulse = None - def __eq__(self, item) -> bool: - """Overloads == operator""" - return type(item) is Rectangular - @property def envelope_waveform_i(self) -> Waveform: """The envelope waveform of the i component of the pulse.""" @@ -233,7 +233,7 @@ def __init__(self, rel_sigma: float): def __eq__(self, item) -> bool: """Overloads == operator""" - if type(item) is Gaussian: + if super().__eq__(item): return self.rel_sigma == item.rel_sigma return False @@ -293,7 +293,7 @@ def __init__(self, rel_sigma, beta): def __eq__(self, item) -> bool: """Overloads == operator""" - if type(item) is Drag: + if super().__eq__(item): return self.rel_sigma == item.rel_sigma and self.beta == item.beta return False @@ -368,7 +368,7 @@ def __init__(self, b, a, target: PulseShape): def __eq__(self, item) -> bool: """Overloads == operator""" - if type(item) is IIR: + if super().__eq__(item): return self.target == item.target and (self.a == item.a).all() and (self.b == item.b).all() return False @@ -446,7 +446,7 @@ def __init__(self, t_half_flux_pulse=None, b_amplitude=1): def __eq__(self, item) -> bool: """Overloads == operator""" - if type(item) is SNZ: + if super().__eq__(item): return self.t_half_flux_pulse == item.t_half_flux_pulse and self.b_amplitude == item.b_amplitude return False @@ -515,7 +515,7 @@ def __init__(self, alpha: float): def __eq__(self, item) -> bool: """Overloads == operator""" - if type(item) is eCap: + if super().__eq__(item): return self.alpha == item.alpha return False From 3eb9861ebf3efed651e34e2b218100a34cb9f28a Mon Sep 17 00:00:00 2001 From: rodolfocarobene Date: Wed, 17 May 2023 16:26:23 +0400 Subject: [PATCH 8/8] typo --- src/qibolab/pulses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/qibolab/pulses.py b/src/qibolab/pulses.py index 7a3480ba5..1b411cb11 100644 --- a/src/qibolab/pulses.py +++ b/src/qibolab/pulses.py @@ -1046,7 +1046,7 @@ def shallow_copy(self): # -> Pulse: ) def is_equal_ignoring_start(self, item) -> bool: - """Check if two pulses are equal, excepto from the start time""" + """Check if two pulses are equal ignoring start time""" return ( self.duration == item.duration and self.amplitude == item.amplitude