
Commit

linting
Filip Michalsky committed Mar 26, 2024
1 parent 483b004 · commit 3432a67
Showing 5 changed files with 76 additions and 60 deletions.
13 changes: 9 additions & 4 deletions salesgpt/agents.py
@@ -1,13 +1,18 @@
 from copy import deepcopy
 from typing import Any, Callable, Dict, List, Union

-from langchain.agents import (AgentExecutor, LLMSingleActionAgent,
-                              create_openai_tools_agent)
+from langchain.agents import (
+    AgentExecutor,
+    LLMSingleActionAgent,
+    create_openai_tools_agent,
+)
 from langchain.chains import LLMChain, RetrievalQA
 from langchain.chains.base import Chain
 from langchain_community.chat_models import ChatLiteLLM
-from langchain_core.agents import (_convert_agent_action_to_messages,
-                                   _convert_agent_observation_to_messages)
+from langchain_core.agents import (
+    _convert_agent_action_to_messages,
+    _convert_agent_observation_to_messages,
+)
 from langchain_core.language_models.llms import create_base_retry_decorator
 from litellm import acompletion
 from pydantic import Field
6 changes: 4 additions & 2 deletions salesgpt/chains.py
@@ -3,8 +3,10 @@
 from langchain_community.chat_models import ChatLiteLLM

 from salesgpt.logger import time_logger
-from salesgpt.prompts import (SALES_AGENT_INCEPTION_PROMPT,
-                              STAGE_ANALYZER_INCEPTION_PROMPT)
+from salesgpt.prompts import (
+    SALES_AGENT_INCEPTION_PROMPT,
+    STAGE_ANALYZER_INCEPTION_PROMPT,
+)


 class StageAnalyzerChain(LLMChain):
13 changes: 5 additions & 8 deletions salesgpt/models.py
@@ -5,11 +5,9 @@
     CallbackManagerForLLMRun,
 )
 from langchain_core.language_models import BaseChatModel, SimpleChatModel
-from langchain_core.messages import AIMessageChunk, BaseMessage, HumanMessage
+from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, HumanMessage
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.runnables import run_in_executor
-from langchain_core.messages import (
-    AIMessage)
 from langchain_openai import ChatOpenAI

 from salesgpt.tools import completion_bedrock
@@ -60,17 +58,16 @@ def _generate(
             run_manager: A run manager with callbacks for the LLM.
         """
         last_message = messages[-1]
-
         print(messages)
         response = completion_bedrock(
             model_id=self.model,
             system_prompt=self.system_prompt,
             messages=[{"content": last_message.content, "role": "user"}],
-            max_tokens=1000
+            max_tokens=1000,
         )
-        print('output', response)
-        content = response['content'][0]['text']
+        print("output", response)
+        content = response["content"][0]["text"]
         message = AIMessage(content=content)
         generation = ChatGeneration(message=message)
         return ChatResult(generations=[generation])

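For orientation, a minimal sketch of how the BedrockCustomModel reformatted above might be exercised (hypothetical usage, not part of this commit; it assumes valid AWS credentials and AWS_REGION_NAME in the environment, and the Claude model id is illustrative):

    from langchain_core.messages import HumanMessage

    from salesgpt.models import BedrockCustomModel

    # Hypothetical usage sketch: model id and prompts are placeholders.
    llm = BedrockCustomModel(
        type="bedrock-model",
        model="anthropic.claude-3-sonnet-20240229-v1:0",
        system_prompt="You are a helpful assistant.",
    )
    # _generate wraps the Bedrock reply in a ChatResult holding an AIMessage,
    # so invoke returns that message directly.
    reply = llm.invoke([HumanMessage(content="Hello!")])
    print(reply.content)

Note that the debug print(messages) and print("output", response) calls survive this lint pass, so every invocation stays chatty on stdout.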
14 changes: 9 additions & 5 deletions salesgpt/salesgptapi.py
@@ -2,9 +2,9 @@
 import json
 import re

-from langchain_community.chat_models import ChatLiteLLM
-from langchain_community.chat_models import BedrockChat
+from langchain_community.chat_models import BedrockChat, ChatLiteLLM
 from langchain_openai import ChatOpenAI
+
 from salesgpt.agents import SalesGPT
 from salesgpt.models import BedrockCustomModel

@@ -23,8 +23,12 @@ def __init__(
         self.verbose = verbose
         self.max_num_turns = max_num_turns
         self.model_name = model_name
-        if 'anthropic' in model_name:
-            self.llm = BedrockCustomModel(type='bedrock-model', model=model_name, system_prompt="You are a helpful assistant.")
+        if "anthropic" in model_name:
+            self.llm = BedrockCustomModel(
+                type="bedrock-model",
+                model=model_name,
+                system_prompt="You are a helpful assistant.",
+            )
         else:
             # self.llm = ChatOpenAI(model_name=model_name, temperature=0)
             self.llm = ChatLiteLLM(temperature=0.2, model_name=model_name)
@@ -138,7 +142,7 @@ def do(self, human_input=None):
             "tool_input": tool_input,
             "action_output": action_output,
             "action_input": action_input,
-            "model_name": self.model_name
+            "model_name": self.model_name,
         }
         return payload

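A short sketch of the routing the reformatted __init__ performs (hypothetical; constructor arguments other than model_name, such as config_path, are assumptions not shown in this hunk):

    from salesgpt.salesgptapi import SalesGPTAPI

    # Model names containing "anthropic" are served by BedrockCustomModel...
    api_claude = SalesGPTAPI(
        config_path="", model_name="anthropic.claude-3-sonnet-20240229-v1:0"
    )

    # ...anything else falls through to ChatLiteLLM at temperature 0.2.
    api_gpt = SalesGPTAPI(config_path="", model_name="gpt-3.5-turbo")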
90 changes: 49 additions & 41 deletions salesgpt/tools.py
@@ -1,15 +1,16 @@
 import json
 import os

 import boto3
 import requests
 from langchain.agents import Tool
 from langchain.chains import RetrievalQA
 from langchain.text_splitter import CharacterTextSplitter
+from langchain_community.chat_models import BedrockChat
 from langchain_community.vectorstores import Chroma
 from langchain_openai import ChatOpenAI, OpenAIEmbeddings
-from langchain_community.chat_models import BedrockChat
 from litellm import completion
+

 def setup_knowledge_base(
     product_catalog: str = None, model_name: str = "gpt-3.5-turbo"
@@ -24,8 +25,8 @@ def setup_knowledge_base(
     text_splitter = CharacterTextSplitter(chunk_size=10, chunk_overlap=0)
     texts = text_splitter.split_text(product_catalog)

-    llm = ChatOpenAI(model_name='gpt-4-turbo-preview', temperature=0)
+    llm = ChatOpenAI(model_name="gpt-4-turbo-preview", temperature=0)

     embeddings = OpenAIEmbeddings()
     docsearch = Chroma.from_texts(
         texts, embeddings, collection_name="product-knowledge-base"
@@ -41,36 +42,39 @@ def completion_bedrock(model_id, system_prompt, messages, max_tokens=1000):
     """
     High-level API call to generate a message with Anthropic Claude.
     """
-    bedrock_runtime = boto3.client(service_name='bedrock-runtime',region_name=os.environ.get('AWS_REGION_NAME'))
-
-    body = json.dumps({
-        "anthropic_version": "bedrock-2023-05-31",
-        "max_tokens": max_tokens,
-        "system": system_prompt,
-        "messages": messages
-    })
-
-    response = bedrock_runtime.invoke_model(
-        body=body,
-        modelId=model_id
+    bedrock_runtime = boto3.client(
+        service_name="bedrock-runtime", region_name=os.environ.get("AWS_REGION_NAME")
     )
-    response_body = json.loads(response.get('body').read())
+
+    body = json.dumps(
+        {
+            "anthropic_version": "bedrock-2023-05-31",
+            "max_tokens": max_tokens,
+            "system": system_prompt,
+            "messages": messages,
+        }
+    )
+
+    response = bedrock_runtime.invoke_model(body=body, modelId=model_id)
+    response_body = json.loads(response.get("body").read())

     return response_body


 def get_product_id_from_query(query, product_price_id_mapping_path):
     # Load product_price_id_mapping from a JSON file
-    with open(product_price_id_mapping_path, 'r') as f:
+    with open(product_price_id_mapping_path, "r") as f:
         product_price_id_mapping = json.load(f)

     # Serialize the product_price_id_mapping to a JSON string for inclusion in the prompt
     product_price_id_mapping_json_str = json.dumps(product_price_id_mapping)

     # Dynamically create the enum list from product_price_id_mapping keys
-    enum_list = list(product_price_id_mapping.values()) + ["No relevant product id found"]
+    enum_list = list(product_price_id_mapping.values()) + [
+        "No relevant product id found"
+    ]
     enum_list_str = json.dumps(enum_list)

     prompt = f"""
     You are an expert data scientist and you are working on a project to recommend products to customers based on their needs.
     Given the following query:
@@ -94,50 +98,54 @@ def get_product_id_from_query(query, product_price_id_mapping_path):
         }}
     Return a valid directly parsable json, dont return in it within a code snippet or add any kind of explanation!!
     """
-    prompt+='{'
+    prompt += "{"
     model_name = os.getenv("GPT_MODEL", "gpt-3.5-turbo-1106")

-    if 'anthropic' in model_name:
+    if "anthropic" in model_name:
         response = completion_bedrock(
             model_id=model_name,
             system_prompt="You are a helpful assistant.",
             messages=[{"content": prompt, "role": "user"}],
-            max_tokens=1000
+            max_tokens=1000,
         )
-        product_id = response['content'][0]['text']
+
+        product_id = response["content"][0]["text"]
+
     else:
         response = completion(
             model=model_name,
             messages=[{"content": prompt, "role": "user"}],
             max_tokens=1000,
-            temperature=0
+            temperature=0,
         )
         product_id = response.choices[0].message.content.strip()
     return product_id


-
 def generate_stripe_payment_link(query: str) -> str:
     """Generate a stripe payment link for a customer based on a single query string."""

     # example testing payment gateway url
-    PAYMENT_GATEWAY_URL = os.getenv("PAYMENT_GATEWAY_URL", "https://agent-payments-gateway.vercel.app/payment")
-    PRODUCT_PRICE_MAPPING = os.getenv("PRODUCT_PRICE_MAPPING","example_product_price_id_mapping.json")
-
+    PAYMENT_GATEWAY_URL = os.getenv(
+        "PAYMENT_GATEWAY_URL", "https://agent-payments-gateway.vercel.app/payment"
+    )
+    PRODUCT_PRICE_MAPPING = os.getenv(
+        "PRODUCT_PRICE_MAPPING", "example_product_price_id_mapping.json"
+    )

     # use LLM to get the price_id from query
     price_id = get_product_id_from_query(query, PRODUCT_PRICE_MAPPING)
     price_id = json.loads(price_id)
-    payload = json.dumps({"prompt": query,
-        **price_id,
-        'stripe_key': os.getenv("STRIPE_API_KEY")
-    })
+    payload = json.dumps(
+        {"prompt": query, **price_id, "stripe_key": os.getenv("STRIPE_API_KEY")}
+    )
     headers = {
-        'Content-Type': 'application/json',
+        "Content-Type": "application/json",
     }

-    response = requests.request("POST", PAYMENT_GATEWAY_URL, headers=headers, data=payload)
-
+    response = requests.request(
+        "POST", PAYMENT_GATEWAY_URL, headers=headers, data=payload
+    )
     return response.text

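Taken together, the two reformatted helpers might be driven like this (a hypothetical sketch, not part of the diff; it assumes AWS credentials, AWS_REGION_NAME, STRIPE_API_KEY, and the example mapping file are all in place, and the Bedrock model id is illustrative):

    import os

    from salesgpt.tools import completion_bedrock, generate_stripe_payment_link

    # Direct Bedrock call, mirroring the signature shown in the hunk above.
    response_body = completion_bedrock(
        model_id="anthropic.claude-3-sonnet-20240229-v1:0",
        system_prompt="You are a helpful assistant.",
        messages=[{"content": "Say hello.", "role": "user"}],
        max_tokens=256,
    )
    print(response_body["content"][0]["text"])

    # End-to-end: the LLM resolves a price_id from the query, then the
    # payment gateway returns a Stripe link in the response body.
    os.environ.setdefault(
        "PRODUCT_PRICE_MAPPING", "example_product_price_id_mapping.json"
    )
    print(generate_stripe_payment_link("I want to buy the premium plan"))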
