From f4aa92e9696f1b3da0eea399cbf290f5ccbd4b91 Mon Sep 17 00:00:00 2001 From: TesAnti <8780022+TesAnti@users.noreply.github.com> Date: Wed, 8 Nov 2023 12:01:20 +0100 Subject: [PATCH] feat: Added something similar to LCEL (#54) * bugfix. the chain calls should not erase thge data containing in input dictionary * Added something similar to LCEL --- .../InMemoryVectorStore.cs | 10 ++ src/libs/LangChain.Core/Chains/Chain.cs | 38 +++++++ .../BaseCombineDocumentsChain.cs | 13 ++- .../StuffDocumentsChainInput.cs | 2 +- .../LangChain.Core/Chains/LLM/LLMChain.cs | 20 +++- .../StackableChains/BaseStackableChain.cs | 72 ++++++++++++ .../Exceptions/StackableChainException.cs | 8 ++ .../Chains/StackableChains/LLMChain.cs | 28 +++++ .../Chains/StackableChains/PromptChain.cs | 52 +++++++++ .../StackableChains/RetreiveDocumentsChain.cs | 31 +++++ .../Chains/StackableChains/SetChain.cs | 23 ++++ .../Chains/StackableChains/StackChain.cs | 78 +++++++++++++ .../StackableChains/StuffDocumentsChain.cs | 55 +++++++++ .../Docstore/DocumentExtensions.cs | 21 ++++ src/libs/LangChain.Core/LangChain.Core.csproj | 8 +- .../LangChain.Core/Prompts/PromptTemplate.cs | 24 +++- .../VectorStores/VectorStore.cs | 2 +- .../VectorStores/VectorStoreRetriever.cs | 2 +- .../LLamaSharpEmbeddings.cs | 10 ++ .../LLamaSharpModelInstruction.cs | 10 ++ .../ChainTests.cs | 107 ++++++++++++++++++ 21 files changed, 604 insertions(+), 10 deletions(-) create mode 100644 src/libs/LangChain.Core/Chains/Chain.cs create mode 100644 src/libs/LangChain.Core/Chains/StackableChains/BaseStackableChain.cs create mode 100644 src/libs/LangChain.Core/Chains/StackableChains/Exceptions/StackableChainException.cs create mode 100644 src/libs/LangChain.Core/Chains/StackableChains/LLMChain.cs create mode 100644 src/libs/LangChain.Core/Chains/StackableChains/PromptChain.cs create mode 100644 src/libs/LangChain.Core/Chains/StackableChains/RetreiveDocumentsChain.cs create mode 100644 
src/libs/LangChain.Core/Chains/StackableChains/SetChain.cs create mode 100644 src/libs/LangChain.Core/Chains/StackableChains/StackChain.cs create mode 100644 src/libs/LangChain.Core/Chains/StackableChains/StuffDocumentsChain.cs create mode 100644 src/libs/LangChain.Core/Docstore/DocumentExtensions.cs create mode 100644 src/tests/LangChain.Providers.LLamaSharp.IntegrationTests/ChainTests.cs diff --git a/src/libs/Databases/LangChain.Databases.InMemory/InMemoryVectorStore.cs b/src/libs/Databases/LangChain.Databases.InMemory/InMemoryVectorStore.cs index b9577aee..80b6d336 100644 --- a/src/libs/Databases/LangChain.Databases.InMemory/InMemoryVectorStore.cs +++ b/src/libs/Databases/LangChain.Databases.InMemory/InMemoryVectorStore.cs @@ -5,12 +5,22 @@ using System.Threading.Tasks; using LangChain.Abstractions.Embeddings.Base; using LangChain.Docstore; +using LangChain.Indexes; +using LangChain.TextSplitters; using LangChain.VectorStores; namespace LangChain.Databases.InMemory { public class InMemoryVectorStore:VectorStore { + public static async Task CreateIndexFromDocuments(IEmbeddings embeddings,List documents) + { + InMemoryVectorStore vectorStore = new InMemoryVectorStore(embeddings); + var textSplitter = new CharacterTextSplitter(); + VectorStoreIndexCreator indexCreator = new VectorStoreIndexCreator(vectorStore, textSplitter); + var index = await indexCreator.FromDocumentsAsync(documents); + return index; + } private readonly Func _distanceFunction; List<(float[] vec, string id, Document doc)> _storage = new List<(float[] vec, string id, Document doc)>(); diff --git a/src/libs/LangChain.Core/Chains/Chain.cs b/src/libs/LangChain.Core/Chains/Chain.cs new file mode 100644 index 00000000..24abc5b2 --- /dev/null +++ b/src/libs/LangChain.Core/Chains/Chain.cs @@ -0,0 +1,38 @@ +using LangChain.Abstractions.Chains.Base; +using LangChain.Chains.HelperChains; +using LangChain.Indexes; +using LangChain.Providers; + +namespace LangChain.Chains; + +public static class Chain +{ + 
public static BaseStackableChain Template(string template, + string outputKey = "prompt") + { + return new PromptChain(template, outputKey); + } + + public static BaseStackableChain Set(string value, string outputKey = "value") + { + return new SetChain(value, outputKey); + } + + public static BaseStackableChain LLM(IChatModel llm, + string inputKey = "prompt", string outputKey = "text") + { + return new LLMChain(llm, inputKey, outputKey); + } + + public static BaseStackableChain RetreiveDocuments(VectorStoreIndexWrapper index, + string inputKey = "query", string outputKey = "documents") + { + return new RetreiveDocumentsChain(index, inputKey, outputKey); + } + + public static BaseStackableChain StuffDocuments( + string inputKey = "documents", string outputKey = "combined") + { + return new StuffDocumentsChain(inputKey, outputKey); + } +} \ No newline at end of file diff --git a/src/libs/LangChain.Core/Chains/CombineDocuments/BaseCombineDocumentsChain.cs b/src/libs/LangChain.Core/Chains/CombineDocuments/BaseCombineDocumentsChain.cs index e047a8ad..f2d8c443 100644 --- a/src/libs/LangChain.Core/Chains/CombineDocuments/BaseCombineDocumentsChain.cs +++ b/src/libs/LangChain.Core/Chains/CombineDocuments/BaseCombineDocumentsChain.cs @@ -42,10 +42,19 @@ protected override async Task CallAsync(IChainValues values, Callb .ToDictionary(kv => kv.Key, kv => kv.Value); var (output, returnDict) = await CombineDocsAsync((docs as List), otherKeys); - + returnDict[OutputKey] = output; - return new ChainValues(returnDict); + // merge dictionaries (do not overwrite keys already present in the input values) + foreach (var kv in returnDict) + { + if (!values.Value.ContainsKey(kv.Key)) + { + values.Value[kv.Key] = returnDict[kv.Key]; + } + } + + return values; } /// diff --git a/src/libs/LangChain.Core/Chains/CombineDocuments/StuffDocumentsChainInput.cs b/src/libs/LangChain.Core/Chains/CombineDocuments/StuffDocumentsChainInput.cs index f8d42ab1..e6738aae 100644 --- a/src/libs/LangChain.Core/Chains/CombineDocuments/StuffDocumentsChainInput.cs +++ 
b/src/libs/LangChain.Core/Chains/CombineDocuments/StuffDocumentsChainInput.cs @@ -13,7 +13,7 @@ public class StuffDocumentsChainInput(ILlmChain llmChain) : BaseCombineDocuments public ILlmChain LlmChain { get; } = llmChain; /// - /// Prompt to use to format each document, gets passed to `format_document`. + /// Template to use to format each document, gets passed to `format_document`. /// public BasePromptTemplate DocumentPrompt { get; set; } = new PromptTemplate( new PromptTemplateInput( diff --git a/src/libs/LangChain.Core/Chains/LLM/LLMChain.cs b/src/libs/LangChain.Core/Chains/LLM/LLMChain.cs index a0365c88..a8125abc 100644 --- a/src/libs/LangChain.Core/Chains/LLM/LLMChain.cs +++ b/src/libs/LangChain.Core/Chains/LLM/LLMChain.cs @@ -73,10 +73,24 @@ protected override async Task CallAsync(IChainValues values, Callb Console.WriteLine("\n".PadLeft(Console.WindowWidth, '<')); } - if(string.IsNullOrEmpty(OutputKey)) - return new ChainValues(response.Messages.Last().Content); + Dictionary returnDict = new Dictionary(); + + + if (string.IsNullOrEmpty(OutputKey)) + returnDict["text"] = response.Messages.Last().Content; + else + returnDict[OutputKey] = response.Messages.Last().Content; - return new ChainValues(OutputKey,response.Messages.Last().Content); + // merge dictionaries (do not overwrite keys already present in the input values) + foreach (var kv in returnDict) + { + if (!values.Value.ContainsKey(kv.Key)) + { + values.Value[kv.Key] = returnDict[kv.Key]; + } + } + + return values; } public async Task Predict(ChainValues values) diff --git a/src/libs/LangChain.Core/Chains/StackableChains/BaseStackableChain.cs b/src/libs/LangChain.Core/Chains/StackableChains/BaseStackableChain.cs new file mode 100644 index 00000000..20a2fbc1 --- /dev/null +++ b/src/libs/LangChain.Core/Chains/StackableChains/BaseStackableChain.cs @@ -0,0 +1,72 @@ +using LangChain.Abstractions.Chains.Base; +using LangChain.Abstractions.Schema; +using LangChain.Callback; +using LangChain.Chains.HelperChains.Exceptions; + +namespace LangChain.Chains.HelperChains; + 
+public abstract class BaseStackableChain:IChain +{ + public string Name { get; set; } + public virtual string[] InputKeys { get; protected set; } + public virtual string[] OutputKeys { get; protected set; } + + protected string GenerateName() + { + return GetType().Name; + } + + private string GetInputs() + { + return string.Join(",", InputKeys); + } + + private string GetOutputs() + { + return string.Join(",", OutputKeys); + } + + string FormatInputValues(IChainValues values) + { + List res = new(); + foreach (var key in InputKeys) + { + if (!values.Value.ContainsKey(key)) + { + res.Add($"{key} is expected but missing"); + continue; + }; + res.Add($"{key}={values.Value[key]}"); + } + return string.Join(",\n", res); + } + + public async Task CallAsync(IChainValues values, ICallbacks? callbacks = null, + List? tags = null, Dictionary? metadata = null) + { + try + { + // await here so that asynchronously-thrown exceptions are caught and wrapped below + return await InternallCall(values); + } + catch (StackableChainException) + { + throw; + } + catch (Exception ex) + { + var name=Name??GenerateName(); + var inputValues= FormatInputValues(values); + var message = $"Error occurred in {name} with inputs \n{inputValues}\n."; + + throw new StackableChainException(message,ex); + } + + } + + protected abstract Task InternallCall(IChainValues values); + + public static StackChain operator |(BaseStackableChain a, BaseStackableChain b) + { + return new StackChain(a, b); + } +} \ No newline at end of file diff --git a/src/libs/LangChain.Core/Chains/StackableChains/Exceptions/StackableChainException.cs b/src/libs/LangChain.Core/Chains/StackableChains/Exceptions/StackableChainException.cs new file mode 100644 index 00000000..caa230ec --- /dev/null +++ b/src/libs/LangChain.Core/Chains/StackableChains/Exceptions/StackableChainException.cs @@ -0,0 +1,8 @@ +namespace LangChain.Chains.HelperChains.Exceptions; + +public class StackableChainException:Exception +{ + public StackableChainException(string message,Exception inner) : base(message, inner) + { + } +} \ No newline at end of 
file diff --git a/src/libs/LangChain.Core/Chains/StackableChains/LLMChain.cs b/src/libs/LangChain.Core/Chains/StackableChains/LLMChain.cs new file mode 100644 index 00000000..48615bcd --- /dev/null +++ b/src/libs/LangChain.Core/Chains/StackableChains/LLMChain.cs @@ -0,0 +1,28 @@ +using LangChain.Abstractions.Schema; +using LangChain.Callback; +using LangChain.Providers; + +namespace LangChain.Chains.HelperChains; + +public class LLMChain:BaseStackableChain +{ + private readonly IChatModel _llm; + + public LLMChain(IChatModel llm, + string inputKey="prompt", + string outputKey="text" + ) + { + InputKeys = new[] { inputKey }; + OutputKeys = new[] { outputKey }; + _llm = llm; + } + + protected override async Task InternallCall(IChainValues values) + { + var prompt = values.Value[InputKeys[0]].ToString(); + var response=await _llm.GenerateAsync(new ChatRequest(new List() { prompt.AsSystemMessage() })); + values.Value[OutputKeys[0]] = response.Messages.Last().Content; + return values; + } +} \ No newline at end of file diff --git a/src/libs/LangChain.Core/Chains/StackableChains/PromptChain.cs b/src/libs/LangChain.Core/Chains/StackableChains/PromptChain.cs new file mode 100644 index 00000000..c6cd6c71 --- /dev/null +++ b/src/libs/LangChain.Core/Chains/StackableChains/PromptChain.cs @@ -0,0 +1,52 @@ +using System.Text.RegularExpressions; +using LangChain.Abstractions.Chains.Base; +using LangChain.Abstractions.Schema; +using LangChain.Callback; +using LangChain.Chains.LLM; +using LangChain.Prompts; + +namespace LangChain.Chains.HelperChains; + +public class PromptChain: BaseStackableChain +{ + private readonly string _template; + + public PromptChain(string template,string outputKey="prompt") + { + OutputKeys = new[] { outputKey }; + _template = template; + InputKeys = GetVariables().ToArray(); + } + + List GetVariables() + { + string pattern = @"\{([^\{\}]+)\}"; + var variables = new List(); + var matches = Regex.Matches(_template, pattern); + foreach (Match match in 
matches) + { + variables.Add(match.Groups[1].Value); + } + return variables; + } + + + + + protected override Task InternallCall(IChainValues values) + { + // validate that input keys containing all variables + var valueKeys = values.Value.Keys; + var missing = InputKeys.Except(valueKeys); + if (missing.Any()) + { + throw new Exception($"Input keys must contain all variables in template. Missing: {string.Join(",",missing)}"); + } + + var formattedPrompt = PromptTemplate.InterpolateFString(_template,values.Value); + + values.Value[OutputKeys[0]]= formattedPrompt; + + return Task.FromResult(values); + } +} \ No newline at end of file diff --git a/src/libs/LangChain.Core/Chains/StackableChains/RetreiveDocumentsChain.cs b/src/libs/LangChain.Core/Chains/StackableChains/RetreiveDocumentsChain.cs new file mode 100644 index 00000000..c21d61ed --- /dev/null +++ b/src/libs/LangChain.Core/Chains/StackableChains/RetreiveDocumentsChain.cs @@ -0,0 +1,31 @@ +using LangChain.Abstractions.Schema; +using LangChain.Callback; +using System.Numerics; +using LangChain.Indexes; + +namespace LangChain.Chains.HelperChains; + +public class RetreiveDocumentsChain:BaseStackableChain +{ + private readonly VectorStoreIndexWrapper _index; + private readonly int _amount; + + public RetreiveDocumentsChain(VectorStoreIndexWrapper index, string inputKey="query", string outputKey="documents", int amount=4) + { + _index = index; + _amount = amount; + InputKeys = new[] { inputKey }; + OutputKeys = new[] { outputKey }; + } + + protected override async Task InternallCall(IChainValues values) + { + var retreiver = _index.Store.AsRetreiver(); + retreiver.K = _amount; + + var query = values.Value[InputKeys[0]].ToString(); + var results = await retreiver.GetRelevantDocumentsAsync(query); + values.Value[OutputKeys[0]] = results.ToList(); + return values; + } +} \ No newline at end of file diff --git a/src/libs/LangChain.Core/Chains/StackableChains/SetChain.cs 
b/src/libs/LangChain.Core/Chains/StackableChains/SetChain.cs new file mode 100644 index 00000000..9ee7f1b9 --- /dev/null +++ b/src/libs/LangChain.Core/Chains/StackableChains/SetChain.cs @@ -0,0 +1,23 @@ +using LangChain.Abstractions.Chains.Base; +using LangChain.Abstractions.Schema; +using LangChain.Callback; + +namespace LangChain.Chains.HelperChains; + +public class SetChain: BaseStackableChain +{ + private readonly string _query; + public SetChain(string query, string outputKey="query") + { + OutputKeys = new[] { outputKey }; + _query = query; + } + + protected override Task InternallCall(IChainValues values) + { + values.Value[OutputKeys[0]] = _query; + return Task.FromResult(values); + } + + +} \ No newline at end of file diff --git a/src/libs/LangChain.Core/Chains/StackableChains/StackChain.cs b/src/libs/LangChain.Core/Chains/StackableChains/StackChain.cs new file mode 100644 index 00000000..d51577c7 --- /dev/null +++ b/src/libs/LangChain.Core/Chains/StackableChains/StackChain.cs @@ -0,0 +1,78 @@ +using LangChain.Abstractions.Schema; +using LangChain.Callback; +using LangChain.Schema; + +namespace LangChain.Chains.HelperChains; + +public class StackChain:BaseStackableChain +{ + private readonly BaseStackableChain _a; + private readonly BaseStackableChain _b; + + public string[] IsolatedInputKeys { get; set; }=new string[0]; + public string[] IsolatedOutputKeys { get; set; }=new string[0]; + + public StackChain(BaseStackableChain a, BaseStackableChain b) + { + _a = a; + _b = b; + + } + + public StackChain AsIsolated(string[] inputKeys = null, string[] outputKeys = null) + { + IsolatedInputKeys = inputKeys ?? IsolatedInputKeys; + IsolatedOutputKeys = outputKeys ?? 
IsolatedOutputKeys; + return this; + } + + public StackChain AsIsolated(string inputKey = null, string outputKey = null) + { + if (inputKey != null) IsolatedInputKeys = new[] { inputKey }; + if (outputKey != null) IsolatedOutputKeys = new[] { outputKey }; + return this; + } + + protected override async Task InternallCall(IChainValues values) + { + // since it is reference type, the values would be changed anyhow + var originalValues = values; + + if (IsolatedInputKeys.Length>0) + { + var res = new ChainValues(); + foreach (var key in IsolatedInputKeys) + { + res.Value[key] = values.Value[key]; + } + values = res; + } + await _a.CallAsync(values); + await _b.CallAsync(values); + if (IsolatedOutputKeys.Length > 0) + { + + foreach (var key in IsolatedOutputKeys) + { + originalValues.Value[key] = values.Value[key]; + } + + } + return originalValues; + } + + + + public async Task Run() + { + + var res = await CallAsync(new ChainValues()); + return res; + } + + public async Task Run(string resultKey) + { + var res = await CallAsync(new ChainValues()); + return res.Value[resultKey].ToString(); + } +} \ No newline at end of file diff --git a/src/libs/LangChain.Core/Chains/StackableChains/StuffDocumentsChain.cs b/src/libs/LangChain.Core/Chains/StackableChains/StuffDocumentsChain.cs new file mode 100644 index 00000000..32c7004f --- /dev/null +++ b/src/libs/LangChain.Core/Chains/StackableChains/StuffDocumentsChain.cs @@ -0,0 +1,55 @@ +using LangChain.Abstractions.Schema; +using LangChain.Callback; +using LangChain.Docstore; +using LangChain.Prompts; + +namespace LangChain.Chains.HelperChains; + +public class StuffDocumentsChain:BaseStackableChain +{ + public string DocumentsSeparator { get; set; } = "\n\n"; + + public string Format { get; set; } = "{document}"; + public string FormatKey { get; set; } = "document"; + + public StuffDocumentsChain(string inputKey="documents", string outputKey="combined") + { + InputKeys = new[] { inputKey }; + OutputKeys = new[] { outputKey }; + 
+ } + + public StuffDocumentsChain WithSeparator(string separator) + { + DocumentsSeparator = separator; + return this; + } + + public StuffDocumentsChain WithFormat(string format, string key="document") + { + Format = format; + FormatKey = key; + return this; + } + + protected override Task InternallCall(IChainValues values) + { + var documentsObject = values.Value[InputKeys[0]]; + if (documentsObject is not List docs) + { + throw new ArgumentException($"{InputKeys[0]} is not a list of documents"); + } + + var docStrings = new List(); + foreach (var doc in docs) + { + var docString = PromptTemplate.InterpolateFStringSafe(Format,new Dictionary{{FormatKey,doc.PageContent}}); + docStrings.Add(docString); + } + + var docsString = String.Join(DocumentsSeparator, docStrings); + + values.Value[OutputKeys[0]] = docsString; + return Task.FromResult(values); + } +} \ No newline at end of file diff --git a/src/libs/LangChain.Core/Docstore/DocumentExtensions.cs b/src/libs/LangChain.Core/Docstore/DocumentExtensions.cs new file mode 100644 index 00000000..bce00cc2 --- /dev/null +++ b/src/libs/LangChain.Core/Docstore/DocumentExtensions.cs @@ -0,0 +1,21 @@ +using System.Runtime.CompilerServices; + +namespace LangChain.Docstore; + +public static class DocumentExtensions +{ + public static Document ToDocument(this string self) + { + return new Document(self); + } + + public static List ToDocuments(this IEnumerable self) + { + List documents = new(); + foreach (var item in self) + { + documents.Add(item.ToDocument()); + } + return documents; + } +} \ No newline at end of file diff --git a/src/libs/LangChain.Core/LangChain.Core.csproj b/src/libs/LangChain.Core/LangChain.Core.csproj index 038c4ce7..2b9e8e5a 100644 --- a/src/libs/LangChain.Core/LangChain.Core.csproj +++ b/src/libs/LangChain.Core/LangChain.Core.csproj @@ -7,6 +7,12 @@ $(NoWarn);CA1716;CA1819;CA1012;CA2227;CA1707;CA2214;CA1854;CA1040;CA1051;CS1591;CS8600;CS8602;CS8603;CS1998;CS1574;CS8604;CS0219;CS8629;CA2237 + + + + + + 
@@ -32,7 +38,7 @@ - + diff --git a/src/libs/LangChain.Core/Prompts/PromptTemplate.cs b/src/libs/LangChain.Core/Prompts/PromptTemplate.cs index 493b52c5..31f25bee 100644 --- a/src/libs/LangChain.Core/Prompts/PromptTemplate.cs +++ b/src/libs/LangChain.Core/Prompts/PromptTemplate.cs @@ -115,7 +115,7 @@ public static async Task Deserialize(SerializedPromptTemplate da { if (string.IsNullOrEmpty(data.Template)) { - throw new Exception("Prompt template must have a template"); + throw new Exception("Prompt template must have a template"); } return new PromptTemplate(new PromptTemplateInput(data.Template, data.InputVariables) @@ -159,7 +159,29 @@ public static string InterpolateFString(string template, Dictionary + /// Safer version of that will not throw an exception if a variable is missing. + /// + public static string InterpolateFStringSafe(string template, Dictionary values) + { + List nodes = ParseFString(template); + return nodes.Aggregate("", (res, node) => + { + if (node.Type == "variable") + { + var parsedNode = node as VariableNode; + if (values.ContainsKey(parsedNode.Name)) + { + return res + values[parsedNode.Name]; + } + + return res + "{" + parsedNode.Name + "}"; + } + + return res + (node as LiteralNode).Text; + }); + } public static List ParseFString(string template) { // Core logic replicated from internals of pythons built in Formatter class. diff --git a/src/libs/LangChain.Core/VectorStores/VectorStore.cs b/src/libs/LangChain.Core/VectorStores/VectorStore.cs index d631e5f5..23e2f5e2 100644 --- a/src/libs/LangChain.Core/VectorStores/VectorStore.cs +++ b/src/libs/LangChain.Core/VectorStores/VectorStore.cs @@ -166,7 +166,7 @@ public abstract Task> MaxMarginalRelevanceSearchByVector( /// /// Maximal marginal relevance optimizes for similarity to query AND diversity among selected documents. /// - /// Query to look up documents similar to. + /// Query to look up documents similar to. /// Number of Documents to return. Defaults to 4. 
/// Number of Documents to fetch to pass to MMR algorithm. /// Number between 0 and 1 that determines the degree diff --git a/src/libs/LangChain.Core/VectorStores/VectorStoreRetriever.cs b/src/libs/LangChain.Core/VectorStores/VectorStoreRetriever.cs index 99e6bb29..3e511142 100644 --- a/src/libs/LangChain.Core/VectorStores/VectorStoreRetriever.cs +++ b/src/libs/LangChain.Core/VectorStores/VectorStoreRetriever.cs @@ -13,7 +13,7 @@ public class VectorStoreRetriever : BaseRetriever public VectorStore Vectorstore { get; init; } private ESearchType SearchType { get; init; } - private int K { get; init; } = 4; + public int K { get; set; } = 4; private float? ScoreThreshold { get; init; } diff --git a/src/libs/Providers/LangChain.Providers.LLamaSharp/LLamaSharpEmbeddings.cs b/src/libs/Providers/LangChain.Providers.LLamaSharp/LLamaSharpEmbeddings.cs index 1c4ed7d6..70465bab 100644 --- a/src/libs/Providers/LangChain.Providers.LLamaSharp/LLamaSharpEmbeddings.cs +++ b/src/libs/Providers/LangChain.Providers.LLamaSharp/LLamaSharpEmbeddings.cs @@ -6,6 +6,16 @@ namespace LangChain.Providers.LLamaSharp; public class LLamaSharpEmbeddings:IEmbeddings { + + public static LLamaSharpEmbeddings FromPath(string path, float temperature = 0) + { + return new LLamaSharpEmbeddings(new LLamaSharpConfiguration + { + PathToModelFile = path, + Temperature = temperature + }); + + } protected readonly LLamaSharpConfiguration _configuration; protected readonly LLamaWeights _model; protected readonly ModelParams _parameters; diff --git a/src/libs/Providers/LangChain.Providers.LLamaSharp/LLamaSharpModelInstruction.cs b/src/libs/Providers/LangChain.Providers.LLamaSharp/LLamaSharpModelInstruction.cs index f9e3c785..b87f51de 100644 --- a/src/libs/Providers/LangChain.Providers.LLamaSharp/LLamaSharpModelInstruction.cs +++ b/src/libs/Providers/LangChain.Providers.LLamaSharp/LLamaSharpModelInstruction.cs @@ -6,6 +6,16 @@ namespace LangChain.Providers.LLamaSharp; public class 
LLamaSharpModelInstruction:LLamaSharpModelBase { + public static LLamaSharpModelInstruction FromPath(string path, float temperature = 0) + { + return new LLamaSharpModelInstruction(new LLamaSharpConfiguration + { + PathToModelFile = path, + Temperature = temperature + }); + + } + public LLamaSharpModelInstruction(LLamaSharpConfiguration configuration) : base(configuration) { } diff --git a/src/tests/LangChain.Providers.LLamaSharp.IntegrationTests/ChainTests.cs b/src/tests/LangChain.Providers.LLamaSharp.IntegrationTests/ChainTests.cs new file mode 100644 index 00000000..ef092ff4 --- /dev/null +++ b/src/tests/LangChain.Providers.LLamaSharp.IntegrationTests/ChainTests.cs @@ -0,0 +1,107 @@ +using LangChain.Chains; +using LangChain.Databases.InMemory; +using LangChain.Docstore; +using LangChain.Providers.Downloader; +using static LangChain.Chains.Chain; +namespace LangChain.Providers.LLamaSharp.IntegrationTests; + +[TestClass] +public class ChainTests +{ + string ModelPath => HuggingFaceModelDownloader.Instance.GetModel("TheBloke/Thespis-13B-v0.5-GGUF", "thespis-13b-v0.5.Q2_K.gguf", "main").Result; + + [TestMethod] + public void PromptTest() + { + var chain= + Set("World", outputKey:"var2") + |Set("Hello", outputKey: "var1") + |Template("{var1}, {var2}", outputKey: "prompt"); + + var res = chain.Run(resultKey:"prompt").Result; + + Assert.AreEqual("Hello, World", res); + } + + [TestMethod] +#if CONTINUOUS_INTEGRATION_BUILD + [Ignore] +#endif + public void LLMChainTest() + { + var llm = LLamaSharpModelInstruction.FromPath(ModelPath); + var promptText = + @"You will be provided with information about pet. Your goal is to extract the pet name. 
+ +Information: +{information} + +The pet name is +"; + + var chain= + Set("My dog name is Bob", outputKey: "information") + |Template(promptText, outputKey: "prompt") + |LLM(llm,inputKey:"prompt", outputKey:"text"); + + var res = chain.Run(resultKey:"text").Result; + + Assert.AreEqual("Bob", res); + } + + [TestMethod] +#if CONTINUOUS_INTEGRATION_BUILD + [Ignore] +#endif + public void RetreivalChainTest() + { + var llm = LLamaSharpModelInstruction.FromPath(ModelPath); + var embeddings = LLamaSharpEmbeddings.FromPath(ModelPath); + var documents = new string[] + { + "I spent entire day watching TV", + "My dog name is Bob", + "This icecream is delicious", + "It is cold in space" + }.ToDocuments(); + var index = InMemoryVectorStore + .CreateIndexFromDocuments(embeddings, documents).Result; + + string prompt1Text = + @"Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. + +{context} + +Question: {question} +Helpful Answer:"; + + var prompt2Text = + @"Human will provide you with sentence about pet. You need to answer with pet name. + +Human: My dog name is Jack +Answer: Jack +Human: I think the best name for a pet is ""Jerry"" +Answer: Jerry +Human: {pet_sentence} +Answer: "; + + + + var chainQuestion = + Set("What is the good name for a pet?", outputKey: "question") + | RetreiveDocuments(index, inputKey: "question", outputKey: "documents") + | StuffDocuments(inputKey: "documents", outputKey: "context") + | Template(prompt1Text, outputKey: "prompt") + | LLM(llm, inputKey: "prompt", outputKey: "pet_sentence"); + + var chainFilter = + // do not move the entire dictionary from the other chain + chainQuestion.AsIsolated(outputKey: "pet_sentence") + | Template(prompt2Text, outputKey: "prompt") + | LLM(llm, inputKey: "prompt", outputKey: "text"); + + + var res = chainFilter.Run(resultKey: "text").Result; + Assert.AreEqual("Bob", res); + } +} \ No newline at end of file