From d5e80ef66ede1989babfa570deef4f65278b749e Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Wed, 16 Oct 2024 14:22:09 -0700 Subject: [PATCH] fix saving/loading retriever --- dspy/primitives/module.py | 19 ++++++++++++------- tests/retrieve/test_llama_index_rm.py | 26 ++++++++++++++++++++++++-- 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/dspy/primitives/module.py b/dspy/primitives/module.py index 87ce8b1aa..762387ef7 100644 --- a/dspy/primitives/module.py +++ b/dspy/primitives/module.py @@ -8,6 +8,7 @@ # NOTE: Note: It's important (temporary decision) to maintain named_parameters that's different in behavior from # named_sub_modules for the time being. + class BaseModule: def __init__(self): pass @@ -29,7 +30,7 @@ def add_parameter(param_name, param_value): visited.add(id(param_value)) param_name = postprocess_parameter_name(param_name, param_value) named_parameters.append((param_name, param_value)) - + elif isinstance(param_value, dspy.Module): # When a sub-module is pre-compiled, keep it frozen. if not getattr(param_value, "_compiled", False): @@ -42,7 +43,7 @@ def add_parameter(param_name, param_value): for name, value in self.__dict__.items(): if isinstance(value, Parameter): add_parameter(name, value) - + elif isinstance(value, dspy.Module): # When a sub-module is pre-compiled, keep it frozen. if not getattr(value, "_compiled", False): @@ -153,7 +154,11 @@ def dump_state(self, save_verbose): def load_state(self, state, use_legacy_loading=False): for name, param in self.named_parameters(): - param.load_state(state[name], use_legacy_loading=use_legacy_loading) + if isinstance(param, BaseModule): + param.load_state(state[name], use_legacy_loading=use_legacy_loading) + else: + # `use_legacy_loading` is only applicable for BaseModule instances. + param.load_state(state[name]) def save(self, path, save_field_meta=False): with open(path, "w") as f: @@ -168,11 +173,11 @@ def postprocess_parameter_name(name, value): # For ChainOfThought backward compatibility, remove ending ._predict if it's there if name.endswith("._predict"): name = name[:-9] - + if name.endswith(".self"): name = name[:-5] - + if name == "_predict": return "self" - - return name \ No newline at end of file + + return name diff --git a/tests/retrieve/test_llama_index_rm.py b/tests/retrieve/test_llama_index_rm.py index f06f96388..9a4246fa9 100644 --- a/tests/retrieve/test_llama_index_rm.py +++ b/tests/retrieve/test_llama_index_rm.py @@ -21,7 +21,7 @@ @pytest.fixture() def rag_setup() -> dict: """Builds the necessary fixtures to test LI""" - pytest.importorskip("llamaindex") + pytest.importorskip("llama_index") dataset = HotPotQA(train_seed=1, train_size=8, eval_seed=2023, dev_size=4, test_size=0) trainset = [x.with_inputs("question") for x in dataset.train] devset = [x.with_inputs("question") for x in dataset.dev] @@ -46,7 +46,7 @@ def rag_setup() -> dict: def test_lirm_as_rm(rag_setup): """Test the retriever as retriever method""" - pytest.importorskip("llamaindex") + pytest.importorskip("llama_index") retriever = rag_setup.get("retriever") test_res_li = retriever.retrieve("At My Window was released by which American singer-songwriter?") rm = rag_setup.get("rm") @@ -59,3 +59,25 @@ def test_lirm_as_rm(rag_setup): assert isinstance(test_res_dspy, list), "Ensuring the results are a list from the DSPy retriever" assert len(test_res_li) == len(test_res_dspy), "Rough equality check of the results" + + +def test_save_load_llama_index_rag(rag_setup, tmp_path): + pytest.importorskip("llama_index") + + class RAG(dspy.Module): + def __init__(self): + super().__init__() + self.retriever = dspy.Retrieve(k=3) + self.cot = dspy.ChainOfThought("question, context -> answer") + + rag = RAG() + rag.retriever.k = 4 + + file_path = tmp_path / "rag.json" + rag.save(file_path) + loaded_rag = RAG() + # Before loading, the retriever k should be 3. + assert loaded_rag.retriever.k == 3 + # After loading, the retriever k should be 4. + loaded_rag.load(file_path) + assert loaded_rag.retriever.k == 4