Commit d42e93d
Add support for gemini API
joy13975 committed Feb 28, 2024
1 parent 12ad190 commit d42e93d
Showing 3 changed files with 70 additions and 7 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -8,6 +8,7 @@ dependencies = [
     "langchain-core",
     "langchain-community",
     "langchain_openai",
+    "langchain-google-genai",
     "openai>1",
     "pysbd>=0.3.4",
     "nest-asyncio",
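For an existing environment the new dependency has to be installed before the Gemini code paths can be imported; assuming the standard PyPI package of the same name as the pyproject.toml entry, a plain pip install is enough:

    pip install langchain-google-genai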
8 changes: 2 additions & 6 deletions src/ragas/llms/base.py
@@ -136,19 +136,17 @@ def generate_text(
         stop: t.Optional[t.List[str]] = None,
         callbacks: t.Optional[Callbacks] = None,
     ) -> LLMResult:
-        temperature = self.get_temperature(n=n)
+        self.langchain_llm.temperature = self.get_temperature(n=n)
         if is_multiple_completion_supported(self.langchain_llm):
             return self.langchain_llm.generate_prompt(
                 prompts=[prompt],
                 n=n,
-                temperature=temperature,
                 stop=stop,
                 callbacks=callbacks,
             )
         else:
             result = self.langchain_llm.generate_prompt(
                 prompts=[prompt] * n,
-                temperature=temperature,
                 stop=stop,
                 callbacks=callbacks,
             )
@@ -166,19 +164,17 @@ async def agenerate_text(
         stop: t.Optional[t.List[str]] = None,
         callbacks: t.Optional[Callbacks] = None,
     ) -> LLMResult:
-        temperature = self.get_temperature(n=n)
+        self.langchain_llm.temperature = self.get_temperature(n=n)
         if is_multiple_completion_supported(self.langchain_llm):
             return await self.langchain_llm.agenerate_prompt(
                 prompts=[prompt],
                 n=n,
-                temperature=temperature,
                 stop=stop,
                 callbacks=callbacks,
             )
         else:
             result = await self.langchain_llm.agenerate_prompt(
                 prompts=[prompt] * n,
-                temperature=temperature,
                 stop=stop,
                 callbacks=callbacks,
             )
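The base.py change stops threading temperature through each generate_prompt call and instead writes it onto the wrapped LangChain model once per request. This keeps both branches identical and also covers models, such as the Gemini chat model added in this commit, that may not accept a temperature keyword argument at call time. A minimal sketch of the resulting pattern (assumes GOOGLE_API_KEY is set; the prompt value is a hypothetical one-off):

    from langchain_core.prompt_values import StringPromptValue
    from langchain_google_genai.chat_models import ChatGoogleGenerativeAI

    llm = ChatGoogleGenerativeAI(model="models/gemini-pro")
    # Set once on the instance, mirroring what LangchainLLMWrapper does after
    # this commit; generate_prompt() is then called without a temperature kwarg.
    llm.temperature = 0.2
    result = llm.generate_prompt(prompts=[StringPromptValue(text="Hello!")])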
68 changes: 67 additions & 1 deletion src/ragas/testset/generator.py
@@ -9,6 +9,12 @@
 from datasets import Dataset
 from langchain_openai.chat_models import ChatOpenAI
 from langchain_openai.embeddings import OpenAIEmbeddings
+from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
+from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
+from langchain_google_genai import (
+    HarmBlockThreshold,
+    HarmCategory,
+)
 
 from ragas._analytics import TestsetGenerationEvent, track
 from ragas.embeddings.base import BaseRagasEmbeddings, LangchainEmbeddingsWrapper
@@ -30,6 +36,7 @@
 from ragas.testset.filters import EvolutionFilter, NodeFilter, QuestionFilter
 from ragas.utils import check_if_sum_is_close, get_feature_language, is_nan
 
+
 if t.TYPE_CHECKING:
     from langchain_core.documents import Document as LCDocument
     from llama_index.core.schema import Document as LlamaindexDocument
@@ -85,6 +92,65 @@ def with_openai(
         embeddings_model = LangchainEmbeddingsWrapper(
             OpenAIEmbeddings(model=embeddings)
         )
+        return cls._common_constructor(
+            chunk_size=chunk_size,
+            generator_llm_model=generator_llm_model,
+            embeddings_model=embeddings_model,
+            critic_llm_model=critic_llm_model,
+            run_config=run_config,
+        )
+
+
+    @classmethod
+    def with_google(
+        cls,
+        generator_llm: str = "models/gemini-pro",
+        critic_llm: str = "models/gemini-pro",
+        embeddings: str = "models/embedding-001",
+        docstore: t.Optional[DocumentStore] = None,
+        chunk_size: int = 512,
+        run_config: t.Optional[RunConfig] = None,
+    ) -> "TestsetGenerator":
+        safety_blocknone = {
+            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
+            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
+            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
+            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
+        }
+        generator_llm_model = LangchainLLMWrapper(
+            ChatGoogleGenerativeAI(
+                model=generator_llm,
+                safety_settings=safety_blocknone,
+            )
+        )
+        critic_llm_model = LangchainLLMWrapper(
+            ChatGoogleGenerativeAI(
+                model=critic_llm,
+                safety_settings=safety_blocknone,
+            )
+        )
+        embeddings_model = LangchainEmbeddingsWrapper(
+            GoogleGenerativeAIEmbeddings(model=embeddings)
+        )
+        return cls._common_constructor(
+            chunk_size=chunk_size,
+            generator_llm_model=generator_llm_model,
+            embeddings_model=embeddings_model,
+            critic_llm_model=critic_llm_model,
+            docstore=docstore,
+            run_config=run_config,
+        )
+
+    @classmethod
+    def _common_constructor(
+        cls,
+        chunk_size: int,
+        generator_llm_model: LangchainLLMWrapper,
+        embeddings_model: LangchainLLMWrapper,
+        critic_llm_model: LangchainLLMWrapper,
+        docstore: t.Optional[DocumentStore],
+        run_config: t.Optional[RunConfig],
+    ):
         keyphrase_extractor = KeyphraseExtractor(llm=generator_llm_model)
         if docstore is None:
             from langchain.text_splitter import TokenTextSplitter
@@ -109,7 +175,7 @@ def with_openai(
             embeddings=embeddings_model,
             docstore=docstore,
         )
-
+
     # if you add any arguments to this function, make sure to add them to
     # generate_with_langchain_docs as well
     def generate_with_llamaindex_docs(
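The new with_google classmethod mirrors with_openai: it wraps the Gemini chat model for both the generator and critic roles, wraps the Google embeddings model, and routes everything through the shared _common_constructor. All four harm categories are pinned to BLOCK_NONE, presumably so that Gemini's safety filters do not reject the synthetic question-generation and critique prompts. A hypothetical usage sketch (assumes GOOGLE_API_KEY is exported and documents is an existing list of LangChain Document objects):

    from ragas.testset.generator import TestsetGenerator
    from ragas.testset.evolutions import simple, reasoning, multi_context

    generator = TestsetGenerator.with_google(
        generator_llm="models/gemini-pro",
        critic_llm="models/gemini-pro",
        embeddings="models/embedding-001",
    )
    testset = generator.generate_with_langchain_docs(
        documents,  # your own langchain_core Document list (assumed)
        test_size=10,
        distributions={simple: 0.5, reasoning: 0.25, multi_context: 0.25},
    )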