
Commit

Merge pull request #13 from Genaios/fix/provider-bugs
Add mistral to bedrock and upgrade anthropic
asarvazyan authored Apr 26, 2024
2 parents 21efd11 + a395b25 commit a608aa8
Showing 4 changed files with 31 additions and 15 deletions.
17 changes: 8 additions & 9 deletions text_machina/src/constrainers/length.py
@@ -45,20 +45,19 @@ def get_constraints(self) -> Dict[str, int]:
         """
         min_new_tokens, max_new_tokens = self.estimate()

-        if self.provider in ["openai", "azure_openai"]:
-            return {"max_tokens": max_new_tokens}
-        elif self.provider == "anthropic":
-            return {"max_tokens_to_sample": max_new_tokens}
-        elif self.provider == "cohere":
+        if self.provider in [
+            "openai",
+            "azure_openai",
+            "anthropic",
+            "cohere",
+            "ai21",
+            "inference_server",
+        ]:
             return {"max_tokens": max_new_tokens}
         elif self.provider == "vertex":
             return {"max_output_tokens": max_new_tokens}
         elif self.provider == "bedrock":
             return {"maxTokenCount": max_new_tokens}
-        elif self.provider == "ai21":
-            return {"max_tokens": max_new_tokens}
-        elif self.provider == "inference_server":
-            return {"max_tokens": max_new_tokens}
         return {
             "max_new_tokens": max_new_tokens,
             "min_new_tokens": min_new_tokens,
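
Net effect of this change: every provider whose SDK takes a plain max_tokens argument now shares a single branch (including anthropic, which previously used max_tokens_to_sample and is migrated to the Messages API below), Vertex and Bedrock keep their provider-specific key names, and any other provider falls through to the min_new_tokens/max_new_tokens pair. A runnable sketch of the resulting mapping, written as a standalone function with invented token estimates instead of the actual LengthConstrainer class:

from typing import Dict


# Hypothetical stand-in for LengthConstrainer.get_constraints(): in text_machina the
# min/max values come from self.estimate(); here they are passed in directly.
def length_constraints(
    provider: str, min_new_tokens: int, max_new_tokens: int
) -> Dict[str, int]:
    if provider in [
        "openai",
        "azure_openai",
        "anthropic",
        "cohere",
        "ai21",
        "inference_server",
    ]:
        return {"max_tokens": max_new_tokens}
    elif provider == "vertex":
        return {"max_output_tokens": max_new_tokens}
    elif provider == "bedrock":
        return {"maxTokenCount": max_new_tokens}
    # Fallback for any other provider: forward both bounds.
    return {"max_new_tokens": max_new_tokens, "min_new_tokens": min_new_tokens}


print(length_constraints("anthropic", 64, 256))  # {'max_tokens': 256}
print(length_constraints("bedrock", 64, 256))    # {'maxTokenCount': 256}
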
14 changes: 9 additions & 5 deletions text_machina/src/models/anthropic.py
@@ -34,11 +34,15 @@ def generate_completion(
         generation_config: Dict,
     ) -> str:
         try:
-            completion = self.client.completions.create(
-                model=self.model_config.model_name,
-                prompt=prompt,
-                **generation_config,
-            ).completion
+            completion = (
+                self.client.messages.create(
+                    model=self.model_config.model_name,
+                    messages=[{"role": "user", "content": prompt}],
+                    **generation_config,
+                )
+                .content[0]
+                .text
+            )
         except Exception as e:
             _logger.info(f"Unrecoverable exception during the request: {e}")
             return GENERATION_ERROR
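
This swaps Anthropic's legacy Text Completions endpoint (client.completions.create(...).completion) for the Messages API, which takes a list of chat messages and returns a list of content blocks, hence .content[0].text; it also matches the length constrainer above now passing max_tokens rather than max_tokens_to_sample. A minimal self-contained sketch of the same call outside text_machina, assuming a recent anthropic SDK, ANTHROPIC_API_KEY in the environment, and an example model name:

# Sketch of the Messages API call the wrapper now performs; the model name and
# max_tokens value are examples, not text_machina defaults.
import anthropic

client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment

response = client.messages.create(
    model="claude-3-haiku-20240307",
    max_tokens=256,  # required by the Messages API; supplied via generation_config in text_machina
    messages=[{"role": "user", "content": "Write one sentence about penguins."}],
)

# The response carries a list of content blocks; the generated text is in the first one.
print(response.content[0].text)
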
11 changes: 11 additions & 0 deletions text_machina/src/models/bedrock.py
@@ -85,6 +85,7 @@ def get_request_body(
             "anthropic",
             "cohere",
             "meta",
+            "mistral",
         }

         if bedrock_provider == "amazon":
@@ -118,6 +119,13 @@ def get_request_body(
                     "maxTokenCount"
                 )
             request_body = {"prompt": prompt, **generation_config}
+        elif bedrock_provider == "mistral":
+            if "maxTokenCount" in generation_config:
+                generation_config["max_tokens"] = generation_config.pop(
+                    "maxTokenCount"
+                )
+            request_body = {"prompt": prompt, **generation_config}
+
         return json.dumps(request_body)

     def get_completion_from_response_body(self, response_body: Dict) -> str:
@@ -138,6 +146,7 @@ def get_completion_from_response_body(self, response_body: Dict) -> str:
             "anthropic",
             "cohere",
             "meta",
+            "mistral",
         }

         if bedrock_provider == "amazon":
@@ -150,5 +159,7 @@ def get_completion_from_response_body(self, response_body: Dict) -> str:
             completion = response_body["generations"][0]["text"]
         elif bedrock_provider == "meta":
             completion = response_body["generation"]
+        elif bedrock_provider == "mistral":
+            completion = response_body["outputs"][0]["text"]

         return completion
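
For Mistral models on Bedrock the request body carries prompt and max_tokens (hence the rename from the internal maxTokenCount key), and the generation comes back under outputs[0]["text"]. A standalone sketch of the equivalent raw boto3 call, with an example region and model id, and AWS credentials assumed to be configured:

# Sketch of invoking a Mistral model on Bedrock directly with boto3; text_machina's
# bedrock backend builds an equivalent body from the prompt and generation_config.
import json

import boto3

client = boto3.client("bedrock-runtime", region_name="us-east-1")  # example region

body = json.dumps(
    {
        "prompt": "<s>[INST] Write one sentence about penguins. [/INST]",
        "max_tokens": 256,  # the key the backend renames from maxTokenCount
    }
)

response = client.invoke_model(
    modelId="mistral.mistral-7b-instruct-v0:2",  # example Mistral model id on Bedrock
    body=body,
)

response_body = json.loads(response["body"].read())
print(response_body["outputs"][0]["text"])
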
4 changes: 3 additions & 1 deletion text_machina/version.py
@@ -1,10 +1,12 @@
 _MAJOR = "0"
 _MINOR = "2"
-_REVISION = "7"
+_REVISION = "8"

 VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR)
 VERSION = "{0}.{1}.{2}".format(_MAJOR, _MINOR, _REVISION)

 __version__ = VERSION
+
+
 def _is_newer_than(version: str) -> bool:
     """True if current version is newer than 'version'."""
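
For reference, the bumped constants render as follows (values copied from the diff above):

# Mirrors version.py after the bump.
_MAJOR, _MINOR, _REVISION = "0", "2", "8"

VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR)           # "0.2"
VERSION = "{0}.{1}.{2}".format(_MAJOR, _MINOR, _REVISION)  # "0.2.8"
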
