From 0858afd79889357d06b6b31d1555566555c080fe Mon Sep 17 00:00:00 2001
From: Andrea <andreapasquale97@gmail.com>
Date: Wed, 5 Jun 2024 11:16:02 +0400
Subject: [PATCH] fix: Proper callback for CMA

---
 src/qibo/optimizers.py           | 18 ++++++++++++++----
 tests/test_models_variational.py |  9 ++-------
 2 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/src/qibo/optimizers.py b/src/qibo/optimizers.py
index 6e972b3235..6645326d95 100644
--- a/src/qibo/optimizers.py
+++ b/src/qibo/optimizers.py
@@ -133,10 +133,20 @@ def cmaes(loss, initial_parameters, args=(), callback=None, options=None):
     """
     import cma
 
-    r = cma.fmin2(
-        loss, initial_parameters, 1.7, options=options, args=args, callback=callback
-    )
-    return r[1].result.fbest, r[1].result.xbest, r
+    es = cma.CMAEvolutionStrategy(initial_parameters, sigma0=1.7, inopts=options)
+
+    if callback is not None:
+        while not es.stop():
+            solutions = es.ask()
+            objective_values = [loss(x, *args) for x in solutions]
+            for solution in solutions:
+                callback(solution)
+            es.tell(solutions, objective_values)
+            es.logger.add()
+    else:
+        es.optimize(loss, args=args)
+
+    return es.result.fbest, es.result.xbest, es.result
 
 
 def newtonian(
diff --git a/tests/test_models_variational.py b/tests/test_models_variational.py
index a296e0a0c2..2ee1d2bc59 100644
--- a/tests/test_models_variational.py
+++ b/tests/test_models_variational.py
@@ -92,7 +92,7 @@ def myloss(parameters, circuit, target):
     ("BFGS", {"maxiter": 1}, False, "vqe_bfgs.out"),
     ("parallel_L-BFGS-B", {"maxiter": 1}, True, None),
     ("parallel_L-BFGS-B", {"maxiter": 1}, False, None),
-    ("cma", {"maxfevals": 2}, False, None),
+    ("cma", {"maxiter": 1}, False, None),
     ("sgd", {"nepochs": 5}, False, None),
     ("sgd", {"nepochs": 5}, True, None),
 ]
@@ -132,10 +132,6 @@ def test_vqe(backend, method, options, compile, filename):
     loss_values = []
 
     def callback(parameters, loss_values=loss_values, vqe=v):
-        # cma callback takes as input a CMAEvolutionStrategy class
-        # which keeps track of the best current solution into its .best.x
-        if method == "cma":
-            parameters = parameters.best.x
         vqe.circuit.set_parameters(parameters)
         loss_values.append(vqe.hamiltonian.expectation(vqe.circuit().state()))
 
@@ -153,7 +149,6 @@ def callback(parameters, loss_values=loss_values, vqe=v):
         shutil.rmtree("outcmaes")
     if filename is not None:
         assert_regression_fixture(backend, params, filename)
-
     assert best == min(loss_values)
 
     # test energy fluctuation
@@ -316,7 +311,7 @@ def __call__(self, x):
 test_names = "method,options,compile,filename"
 test_values = [
     ("BFGS", {"maxiter": 1}, False, "aavqe_bfgs.out"),
-    ("cma", {"maxfevals": 2}, False, None),
+    ("cma", {"maxiter": 1}, False, None),
     ("parallel_L-BFGS-B", {"maxiter": 1}, False, None),
 ]