-
-
Notifications
You must be signed in to change notification settings - Fork 90
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Added something similar to LCEL (#54)
* bugfix. the chain calls should not erase thge data containing in input dictionary * Added something similar to LCEL
- Loading branch information
Showing
21 changed files
with
604 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
using LangChain.Abstractions.Chains.Base; | ||
using LangChain.Chains.HelperChains; | ||
using LangChain.Indexes; | ||
using LangChain.Providers; | ||
|
||
namespace LangChain.Chains; | ||
|
||
public static class Chain | ||
{ | ||
public static BaseStackableChain Template(string template, | ||
string outputKey = "prompt") | ||
{ | ||
return new PromptChain(template, outputKey); | ||
} | ||
|
||
public static BaseStackableChain Set(string value, string outputKey = "value") | ||
{ | ||
return new SetChain(value, outputKey); | ||
} | ||
|
||
public static BaseStackableChain LLM(IChatModel llm, | ||
string inputKey = "prompt", string outputKey = "text") | ||
{ | ||
return new LLMChain(llm, inputKey, outputKey); | ||
} | ||
|
||
public static BaseStackableChain RetreiveDocuments(VectorStoreIndexWrapper index, | ||
string inputKey = "query", string outputKey = "documents") | ||
{ | ||
return new RetreiveDocumentsChain(index, inputKey, outputKey); | ||
} | ||
|
||
public static BaseStackableChain StuffDocuments( | ||
string inputKey = "documents", string outputKey = "combined") | ||
{ | ||
return new StuffDocumentsChain(inputKey, outputKey); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
72 changes: 72 additions & 0 deletions
72
src/libs/LangChain.Core/Chains/StackableChains/BaseStackableChain.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
using LangChain.Abstractions.Chains.Base; | ||
using LangChain.Abstractions.Schema; | ||
using LangChain.Callback; | ||
using LangChain.Chains.HelperChains.Exceptions; | ||
|
||
namespace LangChain.Chains.HelperChains; | ||
|
||
public abstract class BaseStackableChain:IChain | ||
{ | ||
public string Name { get; set; } | ||
public virtual string[] InputKeys { get; protected set; } | ||
public virtual string[] OutputKeys { get; protected set; } | ||
|
||
protected string GenerateName() | ||
{ | ||
return GetType().Name; | ||
} | ||
|
||
private string GetInputs() | ||
{ | ||
return string.Join(",", InputKeys); | ||
} | ||
|
||
private string GetOutputs() | ||
{ | ||
return string.Join(",", OutputKeys); | ||
} | ||
|
||
string FormatInputValues(IChainValues values) | ||
{ | ||
List<string> res = new(); | ||
foreach (var key in InputKeys) | ||
{ | ||
if (!values.Value.ContainsKey(key)) | ||
{ | ||
res.Add($"{key} is expected but missing"); | ||
continue; | ||
}; | ||
res.Add($"{key}={values.Value[key]}"); | ||
} | ||
return string.Join(",\n", res); | ||
} | ||
|
||
public Task<IChainValues> CallAsync(IChainValues values, ICallbacks? callbacks = null, | ||
List<string>? tags = null, Dictionary<string, object>? metadata = null) | ||
{ | ||
try | ||
{ | ||
return InternallCall(values); | ||
} | ||
catch (StackableChainException) | ||
{ | ||
throw; | ||
} | ||
catch (Exception ex) | ||
{ | ||
var name=Name??GenerateName(); | ||
var inputValues= FormatInputValues(values); | ||
var message = $"Error occured in {name} with inputs \n{inputValues}\n."; | ||
|
||
throw new StackableChainException(message,ex); | ||
} | ||
|
||
} | ||
|
||
protected abstract Task<IChainValues> InternallCall(IChainValues values); | ||
|
||
public static StackChain operator |(BaseStackableChain a, BaseStackableChain b) | ||
Check warning on line 68 in src/libs/LangChain.Core/Chains/StackableChains/BaseStackableChain.cs GitHub Actions / Build, test and publish / Build, test and publish
|
||
{ | ||
return new StackChain(a, b); | ||
} | ||
} |
8 changes: 8 additions & 0 deletions
8
src/libs/LangChain.Core/Chains/StackableChains/Exceptions/StackableChainException.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
namespace LangChain.Chains.HelperChains.Exceptions; | ||
|
||
public class StackableChainException:Exception | ||
Check warning on line 3 in src/libs/LangChain.Core/Chains/StackableChains/Exceptions/StackableChainException.cs GitHub Actions / Build, test and publish / Build, test and publish
Check warning on line 3 in src/libs/LangChain.Core/Chains/StackableChains/Exceptions/StackableChainException.cs GitHub Actions / Build, test and publish / Build, test and publish
|
||
{ | ||
public StackableChainException(string message,Exception inner) : base(message, inner) | ||
{ | ||
} | ||
} |
28 changes: 28 additions & 0 deletions
28
src/libs/LangChain.Core/Chains/StackableChains/LLMChain.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
using LangChain.Abstractions.Schema; | ||
using LangChain.Callback; | ||
using LangChain.Providers; | ||
|
||
namespace LangChain.Chains.HelperChains; | ||
|
||
public class LLMChain:BaseStackableChain | ||
{ | ||
private readonly IChatModel _llm; | ||
|
||
public LLMChain(IChatModel llm, | ||
string inputKey="prompt", | ||
string outputKey="text" | ||
) | ||
{ | ||
InputKeys = new[] { inputKey }; | ||
OutputKeys = new[] { outputKey }; | ||
_llm = llm; | ||
} | ||
|
||
protected override async Task<IChainValues> InternallCall(IChainValues values) | ||
{ | ||
var prompt = values.Value[InputKeys[0]].ToString(); | ||
var response=await _llm.GenerateAsync(new ChatRequest(new List<Message>() { prompt.AsSystemMessage() })); | ||
values.Value[OutputKeys[0]] = response.Messages.Last().Content; | ||
return values; | ||
} | ||
} |
52 changes: 52 additions & 0 deletions
52
src/libs/LangChain.Core/Chains/StackableChains/PromptChain.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
using System.Text.RegularExpressions; | ||
using LangChain.Abstractions.Chains.Base; | ||
using LangChain.Abstractions.Schema; | ||
using LangChain.Callback; | ||
using LangChain.Chains.LLM; | ||
using LangChain.Prompts; | ||
|
||
namespace LangChain.Chains.HelperChains; | ||
|
||
public class PromptChain: BaseStackableChain | ||
{ | ||
private readonly string _template; | ||
|
||
public PromptChain(string template,string outputKey="prompt") | ||
{ | ||
OutputKeys = new[] { outputKey }; | ||
_template = template; | ||
InputKeys = GetVariables().ToArray(); | ||
} | ||
|
||
List<string> GetVariables() | ||
{ | ||
string pattern = @"\{([^\{\}]+)\}"; | ||
var variables = new List<string>(); | ||
var matches = Regex.Matches(_template, pattern); | ||
foreach (Match match in matches) | ||
{ | ||
variables.Add(match.Groups[1].Value); | ||
} | ||
return variables; | ||
} | ||
|
||
|
||
|
||
|
||
protected override Task<IChainValues> InternallCall(IChainValues values) | ||
{ | ||
// validate that input keys containing all variables | ||
var valueKeys = values.Value.Keys; | ||
var missing = InputKeys.Except(valueKeys); | ||
if (missing.Any()) | ||
{ | ||
throw new Exception($"Input keys must contain all variables in template. Missing: {string.Join(",",missing)}"); | ||
} | ||
|
||
var formattedPrompt = PromptTemplate.InterpolateFString(_template,values.Value); | ||
|
||
values.Value[OutputKeys[0]]= formattedPrompt; | ||
|
||
return Task.FromResult(values); | ||
} | ||
} |
31 changes: 31 additions & 0 deletions
31
src/libs/LangChain.Core/Chains/StackableChains/RetreiveDocumentsChain.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
using LangChain.Abstractions.Schema; | ||
using LangChain.Callback; | ||
using System.Numerics; | ||
using LangChain.Indexes; | ||
|
||
namespace LangChain.Chains.HelperChains; | ||
|
||
public class RetreiveDocumentsChain:BaseStackableChain | ||
{ | ||
private readonly VectorStoreIndexWrapper _index; | ||
private readonly int _amount; | ||
|
||
public RetreiveDocumentsChain(VectorStoreIndexWrapper index, string inputKey="query", string outputKey="documents", int amount=4) | ||
{ | ||
_index = index; | ||
_amount = amount; | ||
InputKeys = new[] { inputKey }; | ||
OutputKeys = new[] { outputKey }; | ||
} | ||
|
||
protected override async Task<IChainValues> InternallCall(IChainValues values) | ||
{ | ||
var retreiver = _index.Store.AsRetreiver(); | ||
retreiver.K = _amount; | ||
|
||
var query = values.Value[InputKeys[0]].ToString(); | ||
var results = await retreiver.GetRelevantDocumentsAsync(query); | ||
values.Value[OutputKeys[0]] = results.ToList(); | ||
return values; | ||
} | ||
} |
23 changes: 23 additions & 0 deletions
23
src/libs/LangChain.Core/Chains/StackableChains/SetChain.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
using LangChain.Abstractions.Chains.Base; | ||
using LangChain.Abstractions.Schema; | ||
using LangChain.Callback; | ||
|
||
namespace LangChain.Chains.HelperChains; | ||
|
||
public class SetChain: BaseStackableChain | ||
{ | ||
private readonly string _query; | ||
public SetChain(string query, string outputKey="query") | ||
{ | ||
OutputKeys = new[] { outputKey }; | ||
_query = query; | ||
} | ||
|
||
protected override Task<IChainValues> InternallCall(IChainValues values) | ||
{ | ||
values.Value[OutputKeys[0]] = _query; | ||
return Task.FromResult(values); | ||
} | ||
|
||
|
||
} |
Oops, something went wrong.