diff --git a/pyproject.toml b/pyproject.toml index 0851232f37..6f402d1158 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,7 @@ dependencies = [ "langchain-core", "langchain-community", "langchain_openai", + "langchain-google-genai", "openai>1", "pysbd>=0.3.4", "nest-asyncio", diff --git a/src/ragas/llms/base.py b/src/ragas/llms/base.py index f2e0d78202..4272823a67 100644 --- a/src/ragas/llms/base.py +++ b/src/ragas/llms/base.py @@ -136,19 +136,17 @@ def generate_text( stop: t.Optional[t.List[str]] = None, callbacks: t.Optional[Callbacks] = None, ) -> LLMResult: - temperature = self.get_temperature(n=n) + self.langchain_llm.temperature = self.get_temperature(n=n) if is_multiple_completion_supported(self.langchain_llm): return self.langchain_llm.generate_prompt( prompts=[prompt], n=n, - temperature=temperature, stop=stop, callbacks=callbacks, ) else: result = self.langchain_llm.generate_prompt( prompts=[prompt] * n, - temperature=temperature, stop=stop, callbacks=callbacks, ) @@ -166,19 +164,17 @@ async def agenerate_text( stop: t.Optional[t.List[str]] = None, callbacks: t.Optional[Callbacks] = None, ) -> LLMResult: - temperature = self.get_temperature(n=n) + self.langchain_llm.temperature = self.get_temperature(n=n) if is_multiple_completion_supported(self.langchain_llm): return await self.langchain_llm.agenerate_prompt( prompts=[prompt], n=n, - temperature=temperature, stop=stop, callbacks=callbacks, ) else: result = await self.langchain_llm.agenerate_prompt( prompts=[prompt] * n, - temperature=temperature, stop=stop, callbacks=callbacks, ) diff --git a/src/ragas/testset/generator.py b/src/ragas/testset/generator.py index 4febcc5e14..a3f682c59b 100644 --- a/src/ragas/testset/generator.py +++ b/src/ragas/testset/generator.py @@ -9,6 +9,12 @@ from datasets import Dataset from langchain_openai.chat_models import ChatOpenAI from langchain_openai.embeddings import OpenAIEmbeddings +from langchain_google_genai.chat_models import ChatGoogleGenerativeAI +from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings +from langchain_google_genai import ( + HarmBlockThreshold, + HarmCategory, +) from ragas._analytics import TestsetGenerationEvent, track from ragas.embeddings.base import BaseRagasEmbeddings, LangchainEmbeddingsWrapper @@ -30,6 +36,7 @@ from ragas.testset.filters import EvolutionFilter, NodeFilter, QuestionFilter from ragas.utils import check_if_sum_is_close, get_feature_language, is_nan + if t.TYPE_CHECKING: from langchain_core.documents import Document as LCDocument from llama_index.core.schema import Document as LlamaindexDocument @@ -85,6 +92,65 @@ def with_openai( embeddings_model = LangchainEmbeddingsWrapper( OpenAIEmbeddings(model=embeddings) ) + return cls._common_constructor( + chunk_size=chunk_size, + generator_llm_model=generator_llm_model, + embeddings_model=embeddings_model, + critic_llm_model=critic_llm_model, + run_config=run_config, + ) + + + @classmethod + def with_google( + cls, + generator_llm: str = "models/gemini-pro", + critic_llm: str = "models/gemini-pro", + embeddings: str = "models/embedding-001", + docstore: t.Optional[DocumentStore] = None, + chunk_size: int = 512, + run_config: t.Optional[RunConfig] = None, + ) -> "TestsetGenerator": + safety_blocknone = { + HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, + HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, + HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, + } + generator_llm_model = LangchainLLMWrapper( + ChatGoogleGenerativeAI( + model=generator_llm, + safety_settings=safety_blocknone, + ) + ) + critic_llm_model = LangchainLLMWrapper( + ChatGoogleGenerativeAI( + model=critic_llm, + safety_settings=safety_blocknone, + ) + ) + embeddings_model = LangchainEmbeddingsWrapper( + GoogleGenerativeAIEmbeddings(model=embeddings) + ) + return cls._common_constructor( + chunk_size=chunk_size, + generator_llm_model=generator_llm_model, + embeddings_model=embeddings_model, + critic_llm_model=critic_llm_model, + docstore=docstore, + run_config=run_config, + ) + + @classmethod + def _common_constructor( + cls, + chunk_size: int, + generator_llm_model: LangchainLLMWrapper, + embeddings_model: LangchainLLMWrapper, + critic_llm_model: LangchainLLMWrapper, + docstore: t.Optional[DocumentStore], + run_config: t.Optional[RunConfig], + ): keyphrase_extractor = KeyphraseExtractor(llm=generator_llm_model) if docstore is None: from langchain.text_splitter import TokenTextSplitter @@ -109,7 +175,7 @@ def with_openai( embeddings=embeddings_model, docstore=docstore, ) - + # if you add any arguments to this function, make sure to add them to # generate_with_langchain_docs as well def generate_with_llamaindex_docs(