From 6f65c01251f4138706aeab8626824f0968f3bb61 Mon Sep 17 00:00:00 2001 From: Peter Jung Date: Fri, 6 Dec 2024 12:39:05 +0100 Subject: [PATCH] Twitter posting fucntion for Microchain (#567) --- .../microchain_agent/microchain_agent_keys.py | 6 ++++ .../microchain_agent/twitter_functions.py | 29 ++++++++++++++----- .../social_media/twitter_handler.py | 19 ++++++++---- 3 files changed, 40 insertions(+), 14 deletions(-) create mode 100644 prediction_market_agent/agents/microchain_agent/microchain_agent_keys.py diff --git a/prediction_market_agent/agents/microchain_agent/microchain_agent_keys.py b/prediction_market_agent/agents/microchain_agent/microchain_agent_keys.py new file mode 100644 index 00000000..b74bd147 --- /dev/null +++ b/prediction_market_agent/agents/microchain_agent/microchain_agent_keys.py @@ -0,0 +1,6 @@ +from prediction_market_agent.utils import APIKeys + + +class MicrochainAgentKeys(APIKeys): + # Double check to make sure you want to actually post on public social media. + ENABLE_SOCIAL_MEDIA: bool = False diff --git a/prediction_market_agent/agents/microchain_agent/twitter_functions.py b/prediction_market_agent/agents/microchain_agent/twitter_functions.py index 863452d3..e1a743c3 100644 --- a/prediction_market_agent/agents/microchain_agent/twitter_functions.py +++ b/prediction_market_agent/agents/microchain_agent/twitter_functions.py @@ -1,21 +1,34 @@ from microchain import Function +from prediction_market_agent_tooling.loggers import logger + +from prediction_market_agent.agents.microchain_agent.microchain_agent_keys import ( + MicrochainAgentKeys, +) +from prediction_market_agent.agents.social_media_agent.social_media.twitter_handler import ( + POST_MAX_LENGTH, + TwitterHandler, +) class SendTweet(Function): @property def description(self) -> str: - return "Use this function to post a tweet on Twitter." + return f"Use this function to post a tweet on Twitter. Maximum length of the tweet is {POST_MAX_LENGTH} characters." @property def example_args(self) -> list[str]: - return ["This is my message."] + return ["This is my tweet."] - def __call__( - self, - message: str, - ) -> str: - # TODO: Complete the logic. - return "Message sent." + def __call__(self, tweet: str) -> str: + if TwitterHandler.does_post_length_exceed_max_length(tweet): + return f"Tweet length exceeds the maximum allowed length of {POST_MAX_LENGTH} characters, because it is {len(tweet)} characters long. Please shorten the tweet." + if MicrochainAgentKeys().ENABLE_SOCIAL_MEDIA: + # Use the raw `post_tweet`, instead of `post`, let the general agent shorten the tweet on its own if it's required. + TwitterHandler().post_tweet(tweet) + else: + # Log as error, so we are notified about it, if we forget to turn it on in production. + logger.error(f"Social media is disabled. Tweeting skipped: {tweet}") + return "Tweet sent." TWITTER_FUNCTIONS: list[type[Function]] = [ diff --git a/prediction_market_agent/agents/social_media_agent/social_media/twitter_handler.py b/prediction_market_agent/agents/social_media_agent/social_media/twitter_handler.py index f275453d..8e87d611 100644 --- a/prediction_market_agent/agents/social_media_agent/social_media/twitter_handler.py +++ b/prediction_market_agent/agents/social_media_agent/social_media/twitter_handler.py @@ -13,7 +13,7 @@ from prediction_market_agent.agents.social_media_agent.social_media.abstract_handler import ( AbstractSocialMediaHandler, ) -from prediction_market_agent.utils import APIKeys, SocialMediaAPIKeys +from prediction_market_agent.utils import SocialMediaAPIKeys class TwitterHandler(AbstractSocialMediaHandler): @@ -36,7 +36,7 @@ def __init__( self.llm = ChatOpenAI( temperature=0, model=model, - api_key=APIKeys().openai_api_key_secretstr_v1, + api_key=keys.openai_api_key_secretstr_v1, ) @observe() @@ -57,11 +57,12 @@ def make_tweet_more_concise(self, tweet: str) -> str: def does_post_length_exceed_max_length(tweet: str) -> bool: return len(tweet) > POST_MAX_LENGTH - def post(self, text: str, reasoning_reply_tweet: str) -> None: + def post(self, text: str, reasoning_reply_tweet: str | None = None) -> None: quote_tweet_id = self.post_else_retry_with_summarization(text) - self.post_else_retry_with_summarization( - reasoning_reply_tweet, quote_tweet_id=quote_tweet_id - ) + if reasoning_reply_tweet is not None: + self.post_else_retry_with_summarization( + reasoning_reply_tweet, quote_tweet_id=quote_tweet_id + ) def post_else_retry_with_summarization( self, text: str, quote_tweet_id: str | None = None @@ -79,6 +80,12 @@ def post_else_retry_with_summarization( f"Tweet too long. Length: {len(text)}, max length: {POST_MAX_LENGTH}" ) return None + return self.post_tweet(text, quote_tweet_id) + + def post_tweet(self, text: str, quote_tweet_id: str | None = None) -> str | None: + """ + Posts the provided text on Twitter. + """ posted_tweet = self.client.create_tweet( text=text, quote_tweet_id=quote_tweet_id )