Skip to content

Commit

Permalink
Externalize texts
Browse files Browse the repository at this point in the history
  • Loading branch information
mickahell committed Jan 29, 2024
1 parent d634545 commit a53624a
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 6 deletions.
24 changes: 18 additions & 6 deletions client/purplecaffeine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(
circuits: Optional[List[List[Union[str, QuantumCircuit]]]] = None,
operators: Optional[List[List[Union[str, Operator]]]] = None,
artifacts: Optional[List[List[str]]] = None,
texts: Optional[List[List[str]]] = None,
texts: Optional[List[List[Union[str, str]]]] = None,
arrays: Optional[List[List[Union[str, np.ndarray]]]] = None,
tags: Optional[List[str]] = None,
versions: Optional[List[List[str]]] = None,
Expand Down Expand Up @@ -481,15 +481,22 @@ def save(self, trial: Trial) -> str:
if not os.path.isdir(save_path):
os.makedirs(save_path)

with open(
os.path.join(save_path, "trial.json"), "w", encoding="utf-8"
) as trial_file:
json.dump(trial.__dict__, trial_file, cls=TrialEncoder, indent=4)

for circuit in trial.circuits:
save_circuit = os.path.join(save_path, f"circuit_{circuit[0]}.json")
with open(save_circuit, "w", encoding="utf-8") as circuit_file:
json.dump(circuit, circuit_file, cls=RuntimeEncoder, indent=4)
circuit[1] = f"Check the circuit_{circuit[0]}.json file."

for text in trial.texts:
save_text = os.path.join(save_path, f"text_{text[0]}.json")
with open(save_text, "w", encoding="utf-8") as text_file:
json.dump(text, text_file, cls=RuntimeEncoder, indent=4)
text[1] = f"Check the text_{text[0]}.json file."

with open(
os.path.join(save_path, "trial.json"), "w", encoding="utf-8"
) as trial_file:
json.dump(trial.__dict__, trial_file, cls=TrialEncoder, indent=4)

return self.path

Expand Down Expand Up @@ -519,6 +526,11 @@ def get(self, trial_id: str) -> Trial:
with open(circ_path, "r", encoding="utf-8") as circ_file:
trial.circuits[index] = json.load(circ_file, cls=TrialDecoder)

for index, text in enumerate(copy.copy(trial.texts)):
text_path = os.path.join(trial_path, f"text_{text[0]}.json")
with open(text_path, "r", encoding="utf-8") as text_file:
trial.texts[index] = json.load(text_file, cls=TrialDecoder)

return trial

def list(
Expand Down
1 change: 1 addition & 0 deletions client/tests/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def test_save_get_list_local_storage(self):
self.assertTrue(isinstance(recovered, Trial))
self.assertEqual(recovered.parameters, [["test_parameter", "parameter"]])
self.assertEqual(recovered.circuits, [["test_circuit", QuantumCircuit(2)]])
self.assertEqual(recovered.texts, [["test_text", "text"]])
with self.assertRaises(ValueError):
self.local_storage.get(trial_id="999")
# List
Expand Down
4 changes: 4 additions & 0 deletions client/tests/utils/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,15 @@ def test_encoder_decoder(self):
self.assertTrue(isinstance(trial_encode, str))
for circuit in my_trial.circuits:
circ_encode = json.dumps(circuit, cls=RuntimeEncoder, indent=4)
for text in my_trial.texts:
text_encode = json.dumps(text, cls=RuntimeEncoder, indent=4)

# Decoder
trial_decode = Trial(**json.loads(trial_encode, cls=TrialDecoder))
for index, circuit in enumerate(copy.copy(trial_decode.circuits)):
trial_decode.circuits[index] = json.loads(circ_encode, cls=TrialDecoder)
for index, text in enumerate(copy.copy(trial_decode.texts)):
trial_decode.texts[index] = json.loads(text_encode, cls=TrialDecoder)

self.assertTrue(isinstance(trial_decode, Trial))
self.assertEqual(trial_decode.metrics, my_trial.metrics)
Expand Down

0 comments on commit a53624a

Please sign in to comment.