Skip to content

Commit

Permalink
feat: callbacks refactor (#49)
Browse files Browse the repository at this point in the history
* callbacks refactor

* naming

* build fix

* implement base and console trace

* console logger

* remove commented

---------

Co-authored-by: Evgenii Khoroshev <[email protected]>
  • Loading branch information
khoroshevj and Evgenii Khoroshev authored Nov 6, 2023
1 parent ea43091 commit 6b7c5dc
Show file tree
Hide file tree
Showing 42 changed files with 1,699 additions and 568 deletions.
5 changes: 4 additions & 1 deletion examples/LangChain.Samples.Prompts/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@
HumanMessagePromptTemplate.FromTemplate("{text}")
});

var chainB = new LlmChain(new LlmChainInput(chat, chatPrompt));
var chainB = new LlmChain(new LlmChainInput(chat, chatPrompt)
{
Verbose = true
});

var resultB = await chainB.CallAsync(new ChainValues(new Dictionary<string, object>(3)
{
Expand Down
17 changes: 11 additions & 6 deletions examples/LangChain.Samples.SequentialChain/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

var chainOne = new LlmChain(new LlmChainInput(llm, firstPrompt)
{
Verbose = true,
OutputKey = "company_name"
});

Expand All @@ -20,16 +21,20 @@

var chainTwo = new LlmChain(new LlmChainInput(llm, secondPrompt));

var overallChain = new SequentialChain(new SequentialChainInput(new []
{
chainOne,
chainTwo
}, new []{"product"}));
var overallChain = new SequentialChain(new SequentialChainInput(
new[]
{
chainOne,
chainTwo
},
new[] { "product" },
new[] { "company_name", "text" }
));

var result = await overallChain.CallAsync(new ChainValues(new Dictionary<string, object>(1)
{
{ "product", "colourful socks" }
}));

Console.WriteLine(result.Value["text"]);
Console.WriteLine("Test");
Console.WriteLine("SequentialChain sample finished.");
128 changes: 74 additions & 54 deletions src/libs/LangChain.Core/Base/BaseCallbackHandler.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
using LangChain.Abstractions.Chains.Base;
using LangChain.Docstore;
using LangChain.LLMS;
using LangChain.Providers;
using LangChain.Retrievers;
using LangChain.Schema;

namespace LangChain.Base;
Expand All @@ -7,11 +11,36 @@ namespace LangChain.Base;
public abstract class BaseCallbackHandler : IBaseCallbackHandler
{
/// <inheritdoc />
public string Name { get; protected set; }
public abstract string Name { get; }

public bool IgnoreLlm { get; set; }
public bool IgnoreRetry { get; set; }
public bool IgnoreChain { get; set; }
public bool IgnoreAgent { get; set; }
public bool IgnoreRetriever { get; set; }
public bool IgnoreChatModel { get; set; }

/// <summary>
///
/// </summary>
/// <param name="input"></param>
protected BaseCallbackHandler(IBaseCallbackHandlerInput input)
{
input = input ?? throw new ArgumentNullException(nameof(input));

IgnoreLlm = input.IgnoreLlm;
IgnoreRetry = input.IgnoreRetry;
IgnoreChain = input.IgnoreChain;
IgnoreAgent = input.IgnoreAgent;
IgnoreRetriever = input.IgnoreRetriever;
IgnoreChatModel = input.IgnoreChatModel;
}

/// <inheritdoc />
public abstract Task HandleLlmStartAsync(BaseLlm llm, string[] prompts, string runId, string? parentRunId = null,
Dictionary<string, object>? extraParams = null);
public abstract Task HandleLlmStartAsync(
BaseLlm llm, string[] prompts, string runId, string? parentRunId = null,
List<string>? tags = null, Dictionary<string, object>? metadata = null,
string name = null, Dictionary<string, object>? extraParams = null);

/// <inheritdoc />
public abstract Task HandleLlmNewTokenAsync(string token, string runId, string? parentRunId = null);
Expand All @@ -23,20 +52,42 @@ public abstract Task HandleLlmStartAsync(BaseLlm llm, string[] prompts, string r
public abstract Task HandleLlmEndAsync(LlmResult output, string runId, string? parentRunId = null);

/// <inheritdoc />
public abstract Task HandleChatModelStartAsync(Dictionary<string, object> llm, List<List<object>> messages, string runId, string? parentRunId = null,
public abstract Task HandleChatModelStartAsync(BaseLlm llm, List<List<Message>> messages, string runId,
string? parentRunId = null,
Dictionary<string, object>? extraParams = null);

/// <inheritdoc />
public abstract Task HandleChainStartAsync(Dictionary<string, object> chain, Dictionary<string, object> inputs, string runId, string? parentRunId = null);
public abstract Task HandleChainStartAsync(IChain chain, Dictionary<string, object> inputs,
string runId, string? parentRunId = null,
List<string>? tags = null,
Dictionary<string, object>? metadata = null,
string runType = null,
string name = null,
Dictionary<string, object>? extraParams = null);

/// <inheritdoc />
public abstract Task HandleChainErrorAsync(Exception err, string runId, string? parentRunId = null);
public abstract Task HandleChainErrorAsync(
Exception err, string runId,
Dictionary<string, object>? inputs = null,
string? parentRunId = null);

/// <inheritdoc />
public abstract Task HandleChainEndAsync(Dictionary<string, object> outputs, string runId, string? parentRunId = null);
public abstract Task HandleChainEndAsync(
Dictionary<string, object>? inputs,
Dictionary<string, object> outputs,
string runId,
string? parentRunId = null);

/// <inheritdoc />
public abstract Task HandleToolStartAsync(Dictionary<string, object> tool, string input, string runId, string? parentRunId = null);
public abstract Task HandleToolStartAsync(
Dictionary<string, object> tool,
string input, string runId,
string? parentRunId = null,
List<string>? tags = null,
Dictionary<string, object>? metadata = null,
string runType = null,
string name = null,
Dictionary<string, object>? extraParams = null);

/// <inheritdoc />
public abstract Task HandleToolErrorAsync(Exception err, string runId, string? parentRunId = null);
Expand All @@ -54,55 +105,24 @@ public abstract Task HandleChatModelStartAsync(Dictionary<string, object> llm, L
public abstract Task HandleAgentEndAsync(Dictionary<string, object> action, string runId, string? parentRunId = null);

/// <inheritdoc />
public abstract Task HandleRetrieverStartAsync(string query, string runId, string? parentRunId);
public abstract Task HandleRetrieverStartAsync(
BaseRetriever retriever,
string query,
string runId,
string? parentRunId,
List<string>? tags = null,
Dictionary<string, object>? metadata = null,
string? runType = null,
string? name = null,
Dictionary<string, object>? extraParams = null);

/// <inheritdoc />
public abstract Task HandleRetrieverEndAsync(string query, string runId, string? parentRunId);
public abstract Task HandleRetrieverEndAsync(
string query,
List<Document> documents,
string runId,
string? parentRunId);

/// <inheritdoc />
public abstract Task HandleRetrieverErrorAsync(Exception error, string query, string runId, string? parentRunId);

/// <summary>
///
/// </summary>
public bool IgnoreLlm { get; set; }

/// <summary>
///
/// </summary>
public bool IgnoreChain { get; set; }

/// <summary>
///
/// </summary>
public bool IgnoreAgent { get; set; }

public bool IgnoreRetriever { get; set; }

/// <summary>
///
/// </summary>
protected BaseCallbackHandler()
{
Name = Guid.NewGuid().ToString();
}

/// <summary>
///
/// </summary>
/// <param name="input"></param>
protected BaseCallbackHandler(IBaseCallbackHandlerInput input) : this()
{
input = input ?? throw new ArgumentNullException(nameof(input));

IgnoreLlm = input.IgnoreLlm;
IgnoreChain = input.IgnoreChain;
IgnoreAgent = input.IgnoreAgent;
}

/// <summary>
///
/// </summary>
/// <returns></returns>
public abstract IBaseCallbackHandler Copy();
}
48 changes: 45 additions & 3 deletions src/libs/LangChain.Core/Base/BaseChain.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using LangChain.Abstractions.Chains.Base;
using LangChain.Abstractions.Schema;
using LangChain.Callback;
using LangChain.Chains;
using LangChain.Schema;

Expand All @@ -9,7 +10,7 @@ namespace LangChain.Base;
using LoadValues = Dictionary<string, object>;

/// <inheritdoc />
public abstract class BaseChain : IChain
public abstract class BaseChain(IChainInputs fields) : IChain
{
const string RunKey = "__run";

Expand Down Expand Up @@ -57,7 +58,7 @@ public abstract class BaseChain : IChain

throw new Exception("Return values have multiple keys, 'run' only supported when one key currently");
}

/// <summary>
/// Run the chain using a simple input/output.
/// </summary>
Expand All @@ -83,8 +84,49 @@ public virtual async Task<string> Run(Dictionary<string, object> input)
/// Execute the chain, using the values provided.
/// </summary>
/// <param name="values">The <see cref="ChainValues"/> to use.</param>
/// <param name="callbacks"></param>
/// <param name="tags"></param>
/// <param name="metadata"></param>
/// <returns></returns>
public async Task<IChainValues> CallAsync(
IChainValues values,
ICallbacks? callbacks = null,
List<string>? tags = null,
Dictionary<string, object>? metadata = null)
{
var callbackManager = await CallbackManager.Configure(
callbacks,
fields.Callbacks,
fields.Verbose,
tags,
fields.Tags,
metadata,
fields.Metadata);

var runManager = await callbackManager.HandleChainStart(this, values);

try
{
var result = await CallAsync(values, runManager);

await runManager.HandleChainEndAsync(values, result);

return result;
}
catch (Exception e)
{
await runManager.HandleChainErrorAsync(e, values);
throw;
}
}

/// <summary>
/// Execute the chain, using the values provided.
/// </summary>
/// <param name="values">The <see cref="ChainValues"/> to use.</param>
/// <param name="runManager"></param>
/// <returns></returns>
public abstract Task<IChainValues> CallAsync(IChainValues values);
protected abstract Task<IChainValues> CallAsync(IChainValues values, CallbackManagerForChainRun? runManager);

/// <summary>
///
Expand Down
37 changes: 37 additions & 0 deletions src/libs/LangChain.Core/Base/BaseChainInput.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
using LangChain.Callback;

namespace LangChain.Base;

public interface IBaseChainInput
{
/// <summary>
/// Optional list of callback handlers (or callback manager). Defaults to None.
/// Callback handlers are called throughout the lifecycle of a call to a chain,
/// starting with on_chain_start, ending with on_chain_end or on_chain_error.
/// Each custom chain can optionally call additional callback methods, see Callback docs
/// for full details.
/// </summary>
public ICallbacks? Callbacks { get; set; }

/// <summary>
/// Whether or not run in verbose mode. In verbose mode, some intermediate logs
/// will be printed to the console.
/// </summary>
public bool Verbose { get; set; }

/// <summary>
/// Optional list of tags associated with the chain. Defaults to None.
/// These tags will be associated with each call to this chain,
/// and passed as arguments to the handlers defined in `callbacks`.
/// You can use these to eg identify a specific instance of a chain with its use case.
/// </summary>
public List<string> Tags { get; set; }

/// <summary>
/// Optional metadata associated with the chain. Defaults to None.
/// This metadata will be associated with each call to this chain,
/// and passed as arguments to the handlers defined in `callbacks`.
/// You can use these to eg identify a specific instance of a chain with its use case.
/// </summary>
public Dictionary<string, object> Metadata { get; set; }
}
6 changes: 2 additions & 4 deletions src/libs/LangChain.Core/Base/BaseLangChain.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@ namespace LangChain.Base;
/// <inheritdoc />
public abstract class BaseLangChain : IBaseLangChainParams
{
private const bool DefaultVerbosity = false;

/// <summary>
///
/// </summary>
public bool? Verbose { get; set; }
public bool Verbose { get; set; }

/// <summary>
///
Expand All @@ -18,6 +16,6 @@ protected BaseLangChain(IBaseLangChainParams parameters)
{
parameters = parameters ?? throw new ArgumentNullException(nameof(parameters));

Verbose = parameters.Verbose ?? DefaultVerbosity;
Verbose = parameters.Verbose;
}
}
9 changes: 4 additions & 5 deletions src/libs/LangChain.Core/Base/ChainInputs.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@ namespace LangChain.Base;
/// <inheritdoc />
public class ChainInputs : IChainInputs
{
/// <inheritdoc />
public CallbackManager? CallbackManager { get; set; }

/// <inheritdoc />
public bool? Verbose { get; set; }
public ICallbacks? Callbacks { get; set; }
public List<string> Tags { get; set; }
public Dictionary<string, object> Metadata { get; set; }
public bool Verbose { get; set; }
}
4 changes: 1 addition & 3 deletions src/libs/LangChain.Core/Base/Handler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@ namespace LangChain.Base;
/// <inheritdoc />
public abstract class Handler : BaseCallbackHandler
{
/// <inheritdoc />
public override IBaseCallbackHandler Copy()
protected Handler(IBaseCallbackHandlerInput input) : base(input)
{
throw new NotImplementedException();
}
}
Loading

0 comments on commit 6b7c5dc

Please sign in to comment.