Merge pull request stanfordnlp#1638 from chenmoneygithub/fix-saving-arg
Fix saving/loading retriever
okhat authored Oct 16, 2024
2 parents a1eae3f + d5e80ef commit 05b1da8
Showing 2 changed files with 36 additions and 9 deletions.
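
In dspy/primitives/module.py, BaseModule.load_state previously passed use_legacy_loading to every named parameter. As the new inline comment notes, that keyword only applies to BaseModule instances, so passing it to other parameters (such as a dspy.Retrieve) is presumably what broke loading programs that contain a retriever; the fix dispatches on isinstance(param, BaseModule) and otherwise calls load_state(state[name]) without the keyword (see the sketch after that file's diff). The test file corrects the pytest.importorskip target from "llamaindex" to "llama_index" and adds a save/load round-trip test for a program with a retriever; a standalone version of the round trip is sketched after the test.
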
19 changes: 12 additions & 7 deletions dspy/primitives/module.py
@@ -8,6 +8,7 @@
# NOTE: Note: It's important (temporary decision) to maintain named_parameters that's different in behavior from
# named_sub_modules for the time being.


class BaseModule:
def __init__(self):
pass
@@ -29,7 +30,7 @@ def add_parameter(param_name, param_value):
visited.add(id(param_value))
param_name = postprocess_parameter_name(param_name, param_value)
named_parameters.append((param_name, param_value))

elif isinstance(param_value, dspy.Module):
# When a sub-module is pre-compiled, keep it frozen.
if not getattr(param_value, "_compiled", False):
@@ -42,7 +43,7 @@ def add_parameter(param_name, param_value):
for name, value in self.__dict__.items():
if isinstance(value, Parameter):
add_parameter(name, value)

elif isinstance(value, dspy.Module):
# When a sub-module is pre-compiled, keep it frozen.
if not getattr(value, "_compiled", False):
@@ -153,7 +154,11 @@ def dump_state(self, save_verbose):

def load_state(self, state, use_legacy_loading=False):
for name, param in self.named_parameters():
param.load_state(state[name], use_legacy_loading=use_legacy_loading)
if isinstance(param, BaseModule):
param.load_state(state[name], use_legacy_loading=use_legacy_loading)
else:
# `use_legacy_loading` is only applicable for BaseModule instances.
param.load_state(state[name])

def save(self, path, save_field_meta=False):
with open(path, "w") as f:
@@ -168,11 +173,11 @@ def postprocess_parameter_name(name, value):
# For ChainOfThought backward compatibility, remove ending ._predict if it's there
if name.endswith("._predict"):
name = name[:-9]

if name.endswith(".self"):
name = name[:-5]

if name == "_predict":
return "self"
return name

return name
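
To see why the new guard matters, here is a minimal, hypothetical sketch (the class below is illustrative and not part of the dspy codebase): a parameter whose load_state signature lacks the use_legacy_loading keyword fails under the old unconditional call, but loads cleanly when given only the state dict, which is what the patched branch does.

class SimpleRetrieverParam:
    """Hypothetical stand-in for a non-BaseModule parameter such as dspy.Retrieve."""

    def __init__(self, k=3):
        self.k = k

    def load_state(self, state):  # note: no use_legacy_loading keyword
        self.k = state["k"]


param = SimpleRetrieverParam()

# The pre-fix code path passed the keyword to every named parameter:
try:
    param.load_state({"k": 4}, use_legacy_loading=False)
except TypeError as exc:
    print(exc)  # ... got an unexpected keyword argument 'use_legacy_loading'

# The patched code path gives non-BaseModule parameters only the state dict:
param.load_state({"k": 4})
assert param.k == 4
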
26 changes: 24 additions & 2 deletions tests/retrieve/test_llama_index_rm.py
@@ -21,7 +21,7 @@
@pytest.fixture()
def rag_setup() -> dict:
"""Builds the necessary fixtures to test LI"""
pytest.importorskip("llamaindex")
pytest.importorskip("llama_index")
dataset = HotPotQA(train_seed=1, train_size=8, eval_seed=2023, dev_size=4, test_size=0)
trainset = [x.with_inputs("question") for x in dataset.train]
devset = [x.with_inputs("question") for x in dataset.dev]
@@ -46,7 +46,7 @@ def rag_setup() -> dict:

def test_lirm_as_rm(rag_setup):
"""Test the retriever as retriever method"""
pytest.importorskip("llamaindex")
pytest.importorskip("llama_index")
retriever = rag_setup.get("retriever")
test_res_li = retriever.retrieve("At My Window was released by which American singer-songwriter?")
rm = rag_setup.get("rm")
@@ -59,3 +59,25 @@ def test_lirm_as_rm(rag_setup):
assert isinstance(test_res_dspy, list), "Ensuring the results are a list from the DSPy retriever"

assert len(test_res_li) == len(test_res_dspy), "Rough equality check of the results"


def test_save_load_llama_index_rag(rag_setup, tmp_path):
pytest.importorskip("llama_index")

class RAG(dspy.Module):
def __init__(self):
super().__init__()
self.retriever = dspy.Retrieve(k=3)
self.cot = dspy.ChainOfThought("question, context -> answer")

rag = RAG()
rag.retriever.k = 4

file_path = tmp_path / "rag.json"
rag.save(file_path)
loaded_rag = RAG()
# Before loading, the retriever k should be 3.
assert loaded_rag.retriever.k == 3
# After loading, the retriever k should be 4.
loaded_rag.load(file_path)
assert loaded_rag.retriever.k == 4
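
Outside of pytest, the same round trip looks roughly like this (a minimal sketch mirroring the test above; the file name is illustrative):

import dspy

class RAG(dspy.Module):
    def __init__(self):
        super().__init__()
        self.retriever = dspy.Retrieve(k=3)
        self.cot = dspy.ChainOfThought("question, context -> answer")

rag = RAG()
rag.retriever.k = 4        # customize the retriever after construction
rag.save("rag.json")       # writes the program state to JSON

restored = RAG()           # a fresh instance starts with the default k=3
restored.load("rag.json")  # with this fix, the retriever's state loads without error
assert restored.retriever.k == 4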
