Skip to content

Commit

Permalink
feat: Added something similar to LCEL (#54)
Browse files Browse the repository at this point in the history
* bugfix. the chain calls should not erase thge data containing in input dictionary

* Added something similar to LCEL
  • Loading branch information
TesAnti authored Nov 8, 2023
1 parent e545e5d commit f4aa92e
Show file tree
Hide file tree
Showing 21 changed files with 604 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,22 @@
using System.Threading.Tasks;
using LangChain.Abstractions.Embeddings.Base;
using LangChain.Docstore;
using LangChain.Indexes;
using LangChain.TextSplitters;
using LangChain.VectorStores;

namespace LangChain.Databases.InMemory
{
public class InMemoryVectorStore:VectorStore
{
public static async Task<VectorStoreIndexWrapper> CreateIndexFromDocuments(IEmbeddings embeddings,List<Document> documents)
{
InMemoryVectorStore vectorStore = new InMemoryVectorStore(embeddings);
var textSplitter = new CharacterTextSplitter();
VectorStoreIndexCreator indexCreator = new VectorStoreIndexCreator(vectorStore, textSplitter);
var index = await indexCreator.FromDocumentsAsync(documents);
return index;
}

private readonly Func<float[], float[], float> _distanceFunction;
List<(float[] vec, string id, Document doc)> _storage = new List<(float[] vec, string id, Document doc)>();
Expand Down
38 changes: 38 additions & 0 deletions src/libs/LangChain.Core/Chains/Chain.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
using LangChain.Abstractions.Chains.Base;
using LangChain.Chains.HelperChains;
using LangChain.Indexes;
using LangChain.Providers;

namespace LangChain.Chains;

public static class Chain
{
public static BaseStackableChain Template(string template,
string outputKey = "prompt")
{
return new PromptChain(template, outputKey);
}

public static BaseStackableChain Set(string value, string outputKey = "value")
{
return new SetChain(value, outputKey);
}

public static BaseStackableChain LLM(IChatModel llm,
string inputKey = "prompt", string outputKey = "text")
{
return new LLMChain(llm, inputKey, outputKey);
}

public static BaseStackableChain RetreiveDocuments(VectorStoreIndexWrapper index,
string inputKey = "query", string outputKey = "documents")
{
return new RetreiveDocumentsChain(index, inputKey, outputKey);
}

public static BaseStackableChain StuffDocuments(
string inputKey = "documents", string outputKey = "combined")
{
return new StuffDocumentsChain(inputKey, outputKey);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,19 @@ protected override async Task<IChainValues> CallAsync(IChainValues values, Callb
.ToDictionary(kv => kv.Key, kv => kv.Value);

var (output, returnDict) = await CombineDocsAsync((docs as List<Document>), otherKeys);

returnDict[OutputKey] = output;

return new ChainValues(returnDict);
// merge dictionaries
foreach (var kv in returnDict)
{
if (!returnDict.ContainsKey(kv.Key))
{
values.Value[kv.Key] = returnDict[kv.Key];
}
}

return values;
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ public class StuffDocumentsChainInput(ILlmChain llmChain) : BaseCombineDocuments
public ILlmChain LlmChain { get; } = llmChain;

/// <summary>
/// Prompt to use to format each document, gets passed to `format_document`.
/// Template to use to format each document, gets passed to `format_document`.
/// </summary>
public BasePromptTemplate DocumentPrompt { get; set; } = new PromptTemplate(
new PromptTemplateInput(
Expand Down
20 changes: 17 additions & 3 deletions src/libs/LangChain.Core/Chains/LLM/LLMChain.cs
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,24 @@ protected override async Task<IChainValues> CallAsync(IChainValues values, Callb
Console.WriteLine("\n".PadLeft(Console.WindowWidth, '<'));
}

if(string.IsNullOrEmpty(OutputKey))
return new ChainValues(response.Messages.Last().Content);
Dictionary<string, object> returnDict = new Dictionary<string, object>();


if (string.IsNullOrEmpty(OutputKey))
returnDict["text"] = response.Messages.Last().Content;
else
returnDict[OutputKey] = response.Messages.Last().Content;

return new ChainValues(OutputKey,response.Messages.Last().Content);
// merge dictionaries
foreach (var kv in returnDict)
{
if (!returnDict.ContainsKey(kv.Key))
{
values.Value[kv.Key] = returnDict[kv.Key];
}
}

return values;
}

public async Task<object> Predict(ChainValues values)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
using LangChain.Abstractions.Chains.Base;
using LangChain.Abstractions.Schema;
using LangChain.Callback;
using LangChain.Chains.HelperChains.Exceptions;

namespace LangChain.Chains.HelperChains;

public abstract class BaseStackableChain:IChain
{
public string Name { get; set; }
public virtual string[] InputKeys { get; protected set; }
public virtual string[] OutputKeys { get; protected set; }

protected string GenerateName()
{
return GetType().Name;
}

private string GetInputs()
{
return string.Join(",", InputKeys);
}

private string GetOutputs()
{
return string.Join(",", OutputKeys);
}

string FormatInputValues(IChainValues values)
{
List<string> res = new();
foreach (var key in InputKeys)
{
if (!values.Value.ContainsKey(key))
{
res.Add($"{key} is expected but missing");
continue;
};
res.Add($"{key}={values.Value[key]}");
}
return string.Join(",\n", res);
}

public Task<IChainValues> CallAsync(IChainValues values, ICallbacks? callbacks = null,
List<string>? tags = null, Dictionary<string, object>? metadata = null)
{
try
{
return InternallCall(values);
}
catch (StackableChainException)
{
throw;
}
catch (Exception ex)
{
var name=Name??GenerateName();
var inputValues= FormatInputValues(values);
var message = $"Error occured in {name} with inputs \n{inputValues}\n.";

throw new StackableChainException(message,ex);
}

}

protected abstract Task<IChainValues> InternallCall(IChainValues values);

public static StackChain operator |(BaseStackableChain a, BaseStackableChain b)

Check warning on line 68 in src/libs/LangChain.Core/Chains/StackableChains/BaseStackableChain.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Provide a method named 'BitwiseOr' as a friendly alternate for operator op_BitwiseOr (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca2225)
{
return new StackChain(a, b);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
namespace LangChain.Chains.HelperChains.Exceptions;

public class StackableChainException:Exception

Check warning on line 3 in src/libs/LangChain.Core/Chains/StackableChains/Exceptions/StackableChainException.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Add the following constructor to StackableChainException: public StackableChainException() (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1032)

Check warning on line 3 in src/libs/LangChain.Core/Chains/StackableChains/Exceptions/StackableChainException.cs

View workflow job for this annotation

GitHub Actions / Build, test and publish / Build, test and publish

Add the following constructor to StackableChainException: public StackableChainException(string message) (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1032)
{
public StackableChainException(string message,Exception inner) : base(message, inner)
{
}
}
28 changes: 28 additions & 0 deletions src/libs/LangChain.Core/Chains/StackableChains/LLMChain.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
using LangChain.Abstractions.Schema;
using LangChain.Callback;
using LangChain.Providers;

namespace LangChain.Chains.HelperChains;

public class LLMChain:BaseStackableChain
{
private readonly IChatModel _llm;

public LLMChain(IChatModel llm,
string inputKey="prompt",
string outputKey="text"
)
{
InputKeys = new[] { inputKey };
OutputKeys = new[] { outputKey };
_llm = llm;
}

protected override async Task<IChainValues> InternallCall(IChainValues values)
{
var prompt = values.Value[InputKeys[0]].ToString();
var response=await _llm.GenerateAsync(new ChatRequest(new List<Message>() { prompt.AsSystemMessage() }));
values.Value[OutputKeys[0]] = response.Messages.Last().Content;
return values;
}
}
52 changes: 52 additions & 0 deletions src/libs/LangChain.Core/Chains/StackableChains/PromptChain.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
using System.Text.RegularExpressions;
using LangChain.Abstractions.Chains.Base;
using LangChain.Abstractions.Schema;
using LangChain.Callback;
using LangChain.Chains.LLM;
using LangChain.Prompts;

namespace LangChain.Chains.HelperChains;

public class PromptChain: BaseStackableChain
{
private readonly string _template;

public PromptChain(string template,string outputKey="prompt")
{
OutputKeys = new[] { outputKey };
_template = template;
InputKeys = GetVariables().ToArray();
}

List<string> GetVariables()
{
string pattern = @"\{([^\{\}]+)\}";
var variables = new List<string>();
var matches = Regex.Matches(_template, pattern);
foreach (Match match in matches)
{
variables.Add(match.Groups[1].Value);
}
return variables;
}




protected override Task<IChainValues> InternallCall(IChainValues values)
{
// validate that input keys containing all variables
var valueKeys = values.Value.Keys;
var missing = InputKeys.Except(valueKeys);
if (missing.Any())
{
throw new Exception($"Input keys must contain all variables in template. Missing: {string.Join(",",missing)}");
}

var formattedPrompt = PromptTemplate.InterpolateFString(_template,values.Value);

values.Value[OutputKeys[0]]= formattedPrompt;

return Task.FromResult(values);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using LangChain.Abstractions.Schema;
using LangChain.Callback;
using System.Numerics;
using LangChain.Indexes;

namespace LangChain.Chains.HelperChains;

public class RetreiveDocumentsChain:BaseStackableChain
{
private readonly VectorStoreIndexWrapper _index;
private readonly int _amount;

public RetreiveDocumentsChain(VectorStoreIndexWrapper index, string inputKey="query", string outputKey="documents", int amount=4)
{
_index = index;
_amount = amount;
InputKeys = new[] { inputKey };
OutputKeys = new[] { outputKey };
}

protected override async Task<IChainValues> InternallCall(IChainValues values)
{
var retreiver = _index.Store.AsRetreiver();
retreiver.K = _amount;

var query = values.Value[InputKeys[0]].ToString();
var results = await retreiver.GetRelevantDocumentsAsync(query);
values.Value[OutputKeys[0]] = results.ToList();
return values;
}
}
23 changes: 23 additions & 0 deletions src/libs/LangChain.Core/Chains/StackableChains/SetChain.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using LangChain.Abstractions.Chains.Base;
using LangChain.Abstractions.Schema;
using LangChain.Callback;

namespace LangChain.Chains.HelperChains;

public class SetChain: BaseStackableChain
{
private readonly string _query;
public SetChain(string query, string outputKey="query")
{
OutputKeys = new[] { outputKey };
_query = query;
}

protected override Task<IChainValues> InternallCall(IChainValues values)
{
values.Value[OutputKeys[0]] = _query;
return Task.FromResult(values);
}


}
Loading

0 comments on commit f4aa92e

Please sign in to comment.