Commit

Langchain::Assistant works with AWS Bedrock-hosted Anthropic models (#849)

* Langchain::Assistant works with AWS Bedrock-hosted Anthropic models

* specs

* Update adapter.rb

* Fixes

* changelog entry
andreibondarev authored Oct 29, 2024
1 parent 7baf643 commit 683a85b
Showing 9 changed files with 164 additions and 235 deletions.
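For context, a minimal usage sketch of what this change enables: driving Langchain::Assistant with a Bedrock-hosted Anthropic model. This is illustrative only; the model id is an assumption, and AWS credentials are expected in the environment.

```ruby
require "langchain"

# A sketch, not part of the diff; the chat_model id is an assumption.
llm = Langchain::LLM::AwsBedrock.new(
  default_options: {chat_model: "anthropic.claude-3-5-sonnet-20240620-v1:0"}
)

# The Assistant now routes Bedrock-hosted Anthropic models through the new
# AwsBedrockAnthropic adapter (see lib/langchain/assistant/llm/adapter.rb below).
assistant = Langchain::Assistant.new(
  llm: llm,
  instructions: "You are a helpful assistant."
)

assistant.add_message_and_run!(content: "What is the capital of France?")
```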
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -4,14 +4,16 @@
- [BREAKING]: A breaking change. After an upgrade, your app may need modifications to keep working correctly.
- [FEATURE]: A non-breaking improvement to the app. Either introduces new functionality, or improves on an existing feature.
- [BUGFIX]: Fixes a bug with a non-breaking change.
- [COMPAT]: Compatibility improvements - changes to make Administrate more compatible with different dependency versions.
- [COMPAT]: Compatibility improvements - changes to make Langchain.rb more compatible with different dependency versions.
- [OPTIM]: Optimization or performance increase.
- [DOCS]: Documentation changes. No changes to the library's behavior.
- [SECURITY]: A change which fixes a security vulnerability.

## [Unreleased]
- [FEATURE] [https://github.com/patterns-ai-core/langchainrb/pull/858] Assistant, when using Anthropic, now also accepts image_url in the message.
- [FEATURE] [https://github.com/patterns-ai-core/langchainrb/pull/861] Clean up passing `max_tokens` to Anthropic constructor and chat method
- [FEATURE] [https://github.com/patterns-ai-core/langchainrb/pull/849] Langchain::Assistant now works with AWS Bedrock-hosted Anthropic models
- [OPTIM] [https://github.com/patterns-ai-core/langchainrb/pull/849] Simplify Langchain::LLM::AwsBedrock class

## [0.19.0] - 2024-10-23
- [BREAKING] [https://github.com/patterns-ai-core/langchainrb/pull/840] Rename `chat_completion_model_name` parameter to `chat_model` in Langchain::LLM parameters.
16 changes: 8 additions & 8 deletions Gemfile.lock
@@ -48,16 +48,16 @@ GEM
faraday-multipart (>= 1)
ast (2.4.2)
aws-eventstream (1.3.0)
aws-partitions (1.937.0)
aws-sdk-bedrockruntime (1.9.0)
aws-sdk-core (~> 3, >= 3.193.0)
aws-sigv4 (~> 1.1)
aws-sdk-core (3.196.1)
aws-partitions (1.992.0)
aws-sdk-bedrockruntime (1.27.0)
aws-sdk-core (~> 3, >= 3.210.0)
aws-sigv4 (~> 1.5)
aws-sdk-core (3.210.0)
aws-eventstream (~> 1, >= 1.3.0)
aws-partitions (~> 1, >= 1.651.0)
aws-sigv4 (~> 1.8)
aws-partitions (~> 1, >= 1.992.0)
aws-sigv4 (~> 1.9)
jmespath (~> 1, >= 1.6.1)
aws-sigv4 (1.8.0)
aws-sigv4 (1.10.0)
aws-eventstream (~> 1, >= 1.0.2)
baran (0.1.12)
base64 (0.2.0)
13 changes: 7 additions & 6 deletions lib/langchain/assistant/llm/adapter.rb
@@ -6,16 +6,17 @@ module LLM
# TODO: Fix the message truncation when context window is exceeded
class Adapter
def self.build(llm)
case llm
when Langchain::LLM::Anthropic
if llm.is_a?(Langchain::LLM::Anthropic)
LLM::Adapters::Anthropic.new
when Langchain::LLM::GoogleGemini, Langchain::LLM::GoogleVertexAI
elsif llm.is_a?(Langchain::LLM::AwsBedrock) && llm.defaults[:chat_model].include?("anthropic")
LLM::Adapters::AwsBedrockAnthropic.new
elsif llm.is_a?(Langchain::LLM::GoogleGemini) || llm.is_a?(Langchain::LLM::GoogleVertexAI)
LLM::Adapters::GoogleGemini.new
when Langchain::LLM::MistralAI
elsif llm.is_a?(Langchain::LLM::MistralAI)
LLM::Adapters::MistralAI.new
when Langchain::LLM::Ollama
elsif llm.is_a?(Langchain::LLM::Ollama)
LLM::Adapters::Ollama.new
when Langchain::LLM::OpenAI
elsif llm.is_a?(Langchain::LLM::OpenAI)
LLM::Adapters::OpenAI.new
else
raise ArgumentError, "Unsupported LLM type: #{llm.class}"
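A sketch of how the rewritten dispatch behaves, assuming a Bedrock LLM whose `chat_model` is an Anthropic id (as in the DEFAULTS further down):

```ruby
llm = Langchain::LLM::AwsBedrock.new(
  default_options: {chat_model: "anthropic.claude-3-5-sonnet-20240620-v1:0"}
)

adapter = Langchain::Assistant::LLM::Adapter.build(llm)
adapter.class #=> Langchain::Assistant::LLM::Adapters::AwsBedrockAnthropic

# A Bedrock LLM whose chat_model does not include "anthropic" matches no
# branch and falls through to the ArgumentError above.
```

Note that `case/when` already matched via `===` (i.e. `is_a?`); the switch to explicit `if/elsif` mainly makes room for the compound Bedrock-plus-Anthropic condition.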
35 changes: 35 additions & 0 deletions lib/langchain/assistant/llm/adapters/aws_bedrock_anthropic.rb
@@ -0,0 +1,35 @@
# frozen_string_literal: true

module Langchain
class Assistant
module LLM
module Adapters
class AwsBedrockAnthropic < Anthropic
private

# @param [String] choice
# @param [Boolean] _parallel_tool_calls
# @return [Hash]
def build_tool_choice(choice, _parallel_tool_calls)
# Aws Bedrock hosted Anthropic does not support parallel tool calls
Langchain.logger.warn "WARNING: parallel_tool_calls is not supported by AWS Bedrock Anthropic currently"

tool_choice_object = {}

case choice
when "auto"
tool_choice_object[:type] = "auto"
when "any"
tool_choice_object[:type] = "any"
else
tool_choice_object[:type] = "tool"
tool_choice_object[:name] = choice
end

tool_choice_object
end
end
end
end
end
end
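For illustration, the `tool_choice` hashes the override above produces (`build_tool_choice` is private, so `send` is used; this snippet is not part of the diff):

```ruby
adapter = Langchain::Assistant::LLM::Adapters::AwsBedrockAnthropic.new

adapter.send(:build_tool_choice, "auto", false)        #=> {type: "auto"}
adapter.send(:build_tool_choice, "any", false)         #=> {type: "any"}
adapter.send(:build_tool_choice, "get_weather", false) #=> {type: "tool", name: "get_weather"}
# Each call also logs the warning, since parallel tool calls are
# unsupported on Bedrock-hosted Anthropic and the flag is ignored.
```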
183 changes: 69 additions & 114 deletions lib/langchain/llm/aws_bedrock.rb
@@ -7,51 +7,40 @@ module Langchain::LLM
# gem 'aws-sdk-bedrockruntime', '~> 1.1'
#
# Usage:
# llm = Langchain::LLM::AwsBedrock.new(llm_options: {})
# llm = Langchain::LLM::AwsBedrock.new(default_options: {})
#
class AwsBedrock < Base
DEFAULTS = {
chat_model: "anthropic.claude-v2",
completion_model: "anthropic.claude-v2",
chat_model: "anthropic.claude-3-5-sonnet-20240620-v1:0",
completion_model: "anthropic.claude-v2:1",
embedding_model: "amazon.titan-embed-text-v1",
max_tokens_to_sample: 300,
temperature: 1,
top_k: 250,
top_p: 0.999,
stop_sequences: ["\n\nHuman:"],
anthropic_version: "bedrock-2023-05-31",
return_likelihoods: "NONE",
count_penalty: {
scale: 0,
apply_to_whitespaces: false,
apply_to_punctuations: false,
apply_to_numbers: false,
apply_to_stopwords: false,
apply_to_emojis: false
},
presence_penalty: {
scale: 0,
apply_to_whitespaces: false,
apply_to_punctuations: false,
apply_to_numbers: false,
apply_to_stopwords: false,
apply_to_emojis: false
},
frequency_penalty: {
scale: 0,
apply_to_whitespaces: false,
apply_to_punctuations: false,
apply_to_numbers: false,
apply_to_stopwords: false,
apply_to_emojis: false
}
return_likelihoods: "NONE"
}.freeze

attr_reader :client, :defaults

SUPPORTED_COMPLETION_PROVIDERS = %i[anthropic ai21 cohere meta].freeze
SUPPORTED_CHAT_COMPLETION_PROVIDERS = %i[anthropic].freeze
SUPPORTED_EMBEDDING_PROVIDERS = %i[amazon cohere].freeze
SUPPORTED_COMPLETION_PROVIDERS = %i[
anthropic
ai21
cohere
meta
].freeze

SUPPORTED_CHAT_COMPLETION_PROVIDERS = %i[
anthropic
ai21
mistral
].freeze

SUPPORTED_EMBEDDING_PROVIDERS = %i[
amazon
cohere
].freeze

def initialize(aws_client_options: {}, default_options: {})
depends_on "aws-sdk-bedrockruntime", req: "aws-sdk-bedrockruntime"
@@ -64,8 +53,7 @@ def initialize(aws_client_options: {}, default_options: {})
temperature: {},
max_tokens: {default: @defaults[:max_tokens_to_sample]},
metadata: {},
system: {},
anthropic_version: {default: "bedrock-2023-05-31"}
system: {}
)
chat_parameters.ignore(:n, :user)
chat_parameters.remap(stop: :stop_sequences)
@@ -100,23 +88,25 @@ def embed(text:, **params)
# @param params extra parameters passed to Aws::BedrockRuntime::Client#invoke_model
# @return [Langchain::LLM::AnthropicResponse], [Langchain::LLM::CohereResponse] or [Langchain::LLM::AI21Response] Response object
#
def complete(prompt:, **params)
raise "Completion provider #{completion_provider} is not supported." unless SUPPORTED_COMPLETION_PROVIDERS.include?(completion_provider)
def complete(
prompt:,
model: @defaults[:completion_model],
**params
)
raise "Completion provider #{model} is not supported." unless SUPPORTED_COMPLETION_PROVIDERS.include?(provider_name(model))

raise "Model #{@defaults[:completion_model]} only supports #chat." if @defaults[:completion_model].include?("claude-3")

parameters = compose_parameters params
parameters = compose_parameters(params, model)

parameters[:prompt] = wrap_prompt prompt

response = client.invoke_model({
model_id: @defaults[:completion_model],
model_id: model,
body: parameters.to_json,
content_type: "application/json",
accept: "application/json"
})

parse_response response
parse_response(response, model)
end

# Generate a chat completion for a given prompt
@@ -137,10 +127,11 @@ def complete(prompt:, **params)
# @return [Langchain::LLM::AnthropicResponse] Response object
def chat(params = {}, &block)
parameters = chat_parameters.to_params(params)
parameters = compose_parameters(parameters, parameters[:model])

raise ArgumentError.new("messages argument is required") if Array(parameters[:messages]).empty?

raise "Model #{parameters[:model]} does not support chat completions." unless Langchain::LLM::AwsBedrock::SUPPORTED_CHAT_COMPLETION_PROVIDERS.include?(completion_provider)
unless SUPPORTED_CHAT_COMPLETION_PROVIDERS.include?(provider_name(parameters[:model]))
raise "Chat provider #{parameters[:model]} is not supported."
end

if block
response_chunks = []
@@ -168,12 +159,26 @@ def chat(params = {}, &block)
accept: "application/json"
})

parse_response response
parse_response(response, parameters[:model])
end
end
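# Illustrative call against the reworked #chat (a sketch; the default model id
# comes from DEFAULTS and AWS credentials from the environment):
#   llm.chat(messages: [{role: "user", content: "Hello!"}])
#   #=> Langchain::LLM::AnthropicResponse, via parse_response below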

private

def parse_model_id(model_id)
model_id
.gsub("us.", "") # Meta append "us." to their model ids
.split(".")
end

def provider_name(model_id)
parse_model_id(model_id).first.to_sym
end

def model_name(model_id)
parse_model_id(model_id).last
end
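# Illustrative behavior of the helpers above (examples, not part of the diff):
#   provider_name("anthropic.claude-3-5-sonnet-20240620-v1:0") #=> :anthropic
#   provider_name("us.meta.llama3-2-90b-instruct-v1:0")        #=> :meta ("us." prefix stripped)
#   model_name("amazon.titan-embed-text-v1")                   #=> "titan-embed-text-v1"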

def completion_provider
@defaults[:completion_model].split(".").first.to_sym
end
@@ -200,15 +205,17 @@ def max_tokens_key
end
end

def compose_parameters(params)
if completion_provider == :anthropic
compose_parameters_anthropic params
elsif completion_provider == :cohere
compose_parameters_cohere params
elsif completion_provider == :ai21
compose_parameters_ai21 params
elsif completion_provider == :meta
compose_parameters_meta params
def compose_parameters(params, model_id)
if provider_name(model_id) == :anthropic
compose_parameters_anthropic(params)
elsif provider_name(model_id) == :cohere
compose_parameters_cohere(params)
elsif provider_name(model_id) == :ai21
params
elsif provider_name(model_id) == :meta
params
elsif provider_name(model_id) == :mistral
params
end
end

Expand All @@ -220,15 +227,17 @@ def compose_embedding_parameters(params)
end
end

def parse_response(response)
if completion_provider == :anthropic
def parse_response(response, model_id)
if provider_name(model_id) == :anthropic
Langchain::LLM::AnthropicResponse.new(JSON.parse(response.body.string))
elsif completion_provider == :cohere
elsif provider_name(model_id) == :cohere
Langchain::LLM::CohereResponse.new(JSON.parse(response.body.string))
elsif completion_provider == :ai21
elsif provider_name(model_id) == :ai21
Langchain::LLM::AI21Response.new(JSON.parse(response.body.string, symbolize_names: true))
elsif completion_provider == :meta
elsif provider_name(model_id) == :meta
Langchain::LLM::AwsBedrockMetaResponse.new(JSON.parse(response.body.string))
elsif provider_name(model_id) == :mistral
Langchain::LLM::MistralAIResponse.new(JSON.parse(response.body.string))
end
end

@@ -276,61 +285,7 @@ def compose_parameters_cohere(params)
end

def compose_parameters_anthropic(params)
default_params = @defaults.merge(params)

{
max_tokens_to_sample: default_params[:max_tokens_to_sample],
temperature: default_params[:temperature],
top_k: default_params[:top_k],
top_p: default_params[:top_p],
stop_sequences: default_params[:stop_sequences],
anthropic_version: default_params[:anthropic_version]
}
end

def compose_parameters_ai21(params)
default_params = @defaults.merge(params)

{
maxTokens: default_params[:max_tokens_to_sample],
temperature: default_params[:temperature],
topP: default_params[:top_p],
stopSequences: default_params[:stop_sequences],
countPenalty: {
scale: default_params[:count_penalty][:scale],
applyToWhitespaces: default_params[:count_penalty][:apply_to_whitespaces],
applyToPunctuations: default_params[:count_penalty][:apply_to_punctuations],
applyToNumbers: default_params[:count_penalty][:apply_to_numbers],
applyToStopwords: default_params[:count_penalty][:apply_to_stopwords],
applyToEmojis: default_params[:count_penalty][:apply_to_emojis]
},
presencePenalty: {
scale: default_params[:presence_penalty][:scale],
applyToWhitespaces: default_params[:presence_penalty][:apply_to_whitespaces],
applyToPunctuations: default_params[:presence_penalty][:apply_to_punctuations],
applyToNumbers: default_params[:presence_penalty][:apply_to_numbers],
applyToStopwords: default_params[:presence_penalty][:apply_to_stopwords],
applyToEmojis: default_params[:presence_penalty][:apply_to_emojis]
},
frequencyPenalty: {
scale: default_params[:frequency_penalty][:scale],
applyToWhitespaces: default_params[:frequency_penalty][:apply_to_whitespaces],
applyToPunctuations: default_params[:frequency_penalty][:apply_to_punctuations],
applyToNumbers: default_params[:frequency_penalty][:apply_to_numbers],
applyToStopwords: default_params[:frequency_penalty][:apply_to_stopwords],
applyToEmojis: default_params[:frequency_penalty][:apply_to_emojis]
}
}
end

def compose_parameters_meta(params)
default_params = @defaults.merge(params)

{
temperature: default_params[:temperature],
top_p: default_params[:top_p],
max_gen_len: default_params[:max_tokens_to_sample]
}
params.merge(anthropic_version: "bedrock-2023-05-31")
end

def response_from_chunks(chunks)
