feat: Jamba 1.5 Support (#203)
amirai21 authored Aug 21, 2024
1 parent 82d3abe commit 6c131ac
Showing 36 changed files with 940 additions and 55 deletions.
86 changes: 86 additions & 0 deletions CHANGELOG.md
@@ -1,7 +1,93 @@
# CHANGELOG

## v2.13.0-rc.2 (2024-08-21)

### Feature

* feat: a few more classes to TypedDict ([`99f9160`](https://github.com/AI21Labs/ai21-python/commit/99f91605ade7c1bd9fc43e739b8c3bc5404ea45d))

* feat: response json example ([`75f9d7c`](https://github.com/AI21Labs/ai21-python/commit/75f9d7c704eb3c5103ee7b023df506597bf61f5e))

* feat: unioned chat message types + async unit tests ([`fbb8267`](https://github.com/AI21Labs/ai21-python/commit/fbb82672a50a63580846d0c7bc6c478d233f3701))

### Fix

* fix: PR comments addressed ([`a66d4e1`](https://github.com/AI21Labs/ai21-python/commit/a66d4e17f9913c42b8666d9e37d0157ff86ae051))

* fix: remove keys, again ([`9588c29`](https://github.com/AI21Labs/ai21-python/commit/9588c29257d87cff935239f6ab4edc9219c3e6bc))

## v2.13.0-rc.1 (2024-08-20)

### Chore

* chore(release): v2.13.0-rc.1 [skip ci] ([`6075e8e`](https://github.com/AI21Labs/ai21-python/commit/6075e8e7e85572a9ab848af17454e0c10d6649b0))

### Documentation

* docs: Fixed Conversational RAG README (#199)

* docs: README for RAG

* docs: Added link ([`e20978a`](https://github.com/AI21Labs/ai21-python/commit/e20978a8917e98a9b2ba86d11f6544d277201698))

### Feature

* feat: multiple tools example ([`d95fc69`](https://github.com/AI21Labs/ai21-python/commit/d95fc69986b8f142540abf2040b03fc26a9bb58d))

* feat: revert TypedDict ([`e5b6b48`](https://github.com/AI21Labs/ai21-python/commit/e5b6b48d26a57f2be89dab4c3bdc3c32f6b72352))

* feat: test TypedDict usage ([`3952220`](https://github.com/AI21Labs/ai21-python/commit/39522205e5fafd71673a3add5181c1c11cde511a))

* feat: more flow in the func calling example ([`067ad50`](https://github.com/AI21Labs/ai21-python/commit/067ad50d8106d775316ab6e802f45ce7012a81d9))

* feat: added models to __all__ ([`4ab26c2`](https://github.com/AI21Labs/ai21-python/commit/4ab26c29c6e3cc613cd3e64560ca5238f10f280d))

* feat: fix DocumentSchema ([`98e4b4b`](https://github.com/AI21Labs/ai21-python/commit/98e4b4b6e4367fcf57d98b3b65183600431361da))

* feat: simplify models ([`c9e1b3a`](https://github.com/AI21Labs/ai21-python/commit/c9e1b3a69c15edfe12659199b397643334c82c90))

* feat: PR changes ([`f5a1d5e`](https://github.com/AI21Labs/ai21-python/commit/f5a1d5e78864fc6da332cd67068ec8a084bcb5ed))

* feat: unit tests ([`6fb5c6e`](https://github.com/AI21Labs/ai21-python/commit/6fb5c6e5b0ccf3f8791dc2784d1a176297b5fe09))

* feat: fix formatter ([`a778c8e`](https://github.com/AI21Labs/ai21-python/commit/a778c8e2aeb901a4eed6456b9e967740acf1eaef))

* feat: jamba 1.5 features - tools calls (function calling), documents, response_format ([`5e186b9`](https://github.com/AI21Labs/ai21-python/commit/5e186b928e36d8f3e6d54fb1b85926e7302e2fb7))

### Fix

* fix: replace jamba-instruct usage with jamba-1.5, add documentation (#201)

* fix: replace jamba-instruct usage with jamba-1.5, add jamba-instruct to legacy models, log headers in http_client

* fix: formatting

* fix: model name change ([`47f9668`](https://github.com/AI21Labs/ai21-python/commit/47f96685e3124396531b82c696f707e2eb3a865f))

* fix: lint ([`58e407d`](https://github.com/AI21Labs/ai21-python/commit/58e407d58bb0b3447fe98482a4b71f7960cdf7ec))

* fix: remove key ([`9eb439f`](https://github.com/AI21Labs/ai21-python/commit/9eb439fc4b357e8acc8b57fe29882865e5eb4bb2))

* fix: fix all unit tests and a fix ([`b41b31a`](https://github.com/AI21Labs/ai21-python/commit/b41b31af45308d6ae18f5cda43d03ad80f14a43d))

### Unknown

* Merge pull request #202 from AI21Labs/chat-model-with-tools-docs-and-response-format

feat: Chat model with tools docs and response format ([`9f4ed9b`](https://github.com/AI21Labs/ai21-python/commit/9f4ed9b8dfb238233839e25a5bf245c952db359b))

* Merge branch 'chat-model-with-tools-docs-and-response-format' of github.com:AI21Labs/ai21-python into chat-model-with-tools-docs-and-response-format ([`c1fff86`](https://github.com/AI21Labs/ai21-python/commit/c1fff869bcf9c556bbb68444f8e8d3a10861bf0a))

* Revert "feat: revert TypedDict"

This reverts commit e5b6b48d26a57f2be89dab4c3bdc3c32f6b72352. ([`157635e`](https://github.com/AI21Labs/ai21-python/commit/157635e445dc47972057b41088261b712efa1ea1))

## v2.12.0 (2024-08-07)

### Chore

* chore(release): v2.12.0 [skip ci] ([`38430ad`](https://github.com/AI21Labs/ai21-python/commit/38430ad46686ce94b40d66991598dae67a3dbdd8))

### Feature

* feat: :sparkles: add conversational RAG resource (#198) ([`fca3729`](https://github.com/AI21Labs/ai21-python/commit/fca372988e62c09e6ae03d6cc91ab83f62db2fd9))
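
The headline entry in this changelog — "jamba 1.5 features - tools calls (function calling), documents, response_format" — lands in the SDK as three new keyword arguments on `chat.completions.create`. A minimal function-calling sketch, assuming the `ToolDefinition`/`FunctionToolDefinition`/`ToolParameters` models follow the OpenAI-style shape used in this repository's examples, and using a hypothetical `get_order_status` function name:

```python
from ai21 import AI21Client
from ai21.models.chat import ChatMessage
from ai21.models.chat.function_tool_definition import FunctionToolDefinition
from ai21.models.chat.tool_defintions import ToolDefinition  # module name is spelled this way in the SDK
from ai21.models.chat.tool_parameters import ToolParameters

client = AI21Client()  # reads AI21_API_KEY from the environment

# Describe a hypothetical local function so the model can decide to call it.
get_order_status_tool = ToolDefinition(
    type="function",
    function=FunctionToolDefinition(
        name="get_order_status",
        description="Look up the current status of an order by its ID",
        parameters=ToolParameters(
            type="object",
            properties={"order_id": {"type": "string", "description": "The customer's order ID"}},
            required=["order_id"],
        ),
    ),
)

messages = [
    ChatMessage(role="system", content="You are a support agent with access to order tools."),
    ChatMessage(role="user", content="Where is my order ORD-1234?"),
]

response = client.chat.completions.create(
    model="jamba-1.5-mini",
    messages=messages,
    tools=[get_order_status_tool],
)

# If the model chose to call the tool, the call shows up on the first choice.
print(response.choices[0].message.tool_calls)
```
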
45 changes: 38 additions & 7 deletions README.md
@@ -177,7 +177,7 @@ messages = [

chat_completions = client.chat.completions.create(
    messages=messages,
    model="jamba-instruct-preview",
    model="jamba-1.5-mini",
)
```

@@ -207,7 +207,7 @@ client = AsyncAI21Client(
async def main():
    response = await client.chat.completions.create(
        messages=messages,
        model="jamba-instruct-preview",
        model="jamba-1.5-mini",
    )

    print(response)
@@ -227,8 +227,9 @@ A more detailed example can be found [here](examples/studio/chat/chat_completion
### Supported Models:

- j2-light
- j2-mid
- j2-ultra
- [j2-ultra](#Chat)
- [j2-mid](#Completion)
- [jamba-instruct](#Chat-Completion)

you can read more about the models [here](https://docs.ai21.com/reference/j2-complete-api-ref#jurassic-2-models).

@@ -270,6 +271,36 @@ completion_response = client.completion.create(
)
```

### Chat Completion

```python
from ai21 import AI21Client
from ai21.models.chat import ChatMessage

system = "You're a support engineer in a SaaS company"
messages = [
    ChatMessage(content=system, role="system"),
    ChatMessage(content="Hello, I need help with a signup process.", role="user"),
    ChatMessage(content="Hi Alice, I can help you with that. What seems to be the problem?", role="assistant"),
    ChatMessage(content="I am having trouble signing up for your product with my Google account.", role="user"),
]

client = AI21Client()

response = client.chat.completions.create(
    messages=messages,
    model="jamba-instruct",
    max_tokens=100,
    temperature=0.7,
    top_p=1.0,
    stop=["\n"],
)

print(response)
```

Note that jamba-instruct supports async and streaming as well.

</details>

For a more detailed example, see the completion [examples](examples/studio/completion.py).
@@ -290,7 +321,7 @@ client = AI21Client()

response = client.chat.completions.create(
    messages=messages,
    model="jamba-instruct-preview",
    model="jamba-instruct",
    stream=True,
)
for chunk in response:
@@ -314,7 +345,7 @@ client = AsyncAI21Client()
async def main():
    response = await client.chat.completions.create(
        messages=messages,
        model="jamba-instruct-preview",
        model="jamba-1.5-mini",
        stream=True,
    )
    async for chunk in response:
@@ -700,7 +731,7 @@ messages = [
]

response = client.chat.completions.create(
model="jamba-instruct",
model="jamba-1.5-mini",
messages=messages,
)
```
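
Beyond the model rename, the same `create` call now also accepts the `documents` parameter introduced in this commit. A short grounding sketch, assuming `DocumentSchema` takes a `content` string plus optional `id` and `metadata` fields; the document text is made up:

```python
from ai21 import AI21Client
from ai21.models.chat import ChatMessage
from ai21.models.chat.document_schema import DocumentSchema

client = AI21Client()

documents = [
    DocumentSchema(
        id="faq-42",
        content="Refunds are processed within 5 business days of approval.",
        metadata={"source": "internal-faq"},
    ),
]

response = client.chat.completions.create(
    model="jamba-1.5-mini",
    messages=[ChatMessage(role="user", content="How long do refunds take?")],
    documents=documents,  # passages the model can ground its answer in
)

print(response.choices[0].message.content)
```
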
26 changes: 21 additions & 5 deletions ai21/clients/studio/resources/chat/async_chat_completions.py
@@ -5,7 +5,11 @@
from ai21.clients.studio.resources.studio_resource import AsyncStudioResource
from ai21.clients.studio.resources.chat.base_chat_completions import BaseChatCompletions
from ai21.models import ChatMessage as J2ChatMessage
from ai21.models.chat import ChatMessage, ChatCompletionResponse, ChatCompletionChunk
from ai21.models.chat import ChatCompletionResponse, ChatCompletionChunk
from ai21.models.chat.chat_message import ChatMessageParam
from ai21.models.chat.document_schema import DocumentSchema
from ai21.models.chat.response_format import ResponseFormat
from ai21.models.chat.tool_defintions import ToolDefinition
from ai21.stream.async_stream import AsyncStream
from ai21.types import NotGiven, NOT_GIVEN

@@ -17,13 +21,16 @@ class AsyncChatCompletions(AsyncStudioResource, BaseChatCompletions):
async def create(
self,
model: str,
messages: List[ChatMessage],
messages: List[ChatMessageParam],
max_tokens: int | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
stop: str | List[str] | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
stream: Optional[False] | NotGiven = NOT_GIVEN,
tools: List[ToolDefinition] | NotGiven = NOT_GIVEN,
response_format: ResponseFormat | NotGiven = NOT_GIVEN,
documents: List[DocumentSchema] | NotGiven = NOT_GIVEN,
**kwargs: Any,
) -> ChatCompletionResponse:
pass
@@ -32,27 +39,33 @@ async def create(
async def create(
self,
model: str,
messages: List[ChatMessage],
messages: List[ChatMessageParam],
stream: Literal[True],
max_tokens: int | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
stop: str | List[str] | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
tools: List[ToolDefinition] | NotGiven = NOT_GIVEN,
response_format: ResponseFormat | NotGiven = NOT_GIVEN,
documents: List[DocumentSchema] | NotGiven = NOT_GIVEN,
**kwargs: Any,
) -> AsyncStream[ChatCompletionChunk]:
pass

async def create(
self,
model: str,
messages: List[ChatMessage],
messages: List[ChatMessageParam],
max_tokens: int | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
stop: str | List[str] | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
tools: List[ToolDefinition] | NotGiven = NOT_GIVEN,
response_format: ResponseFormat | NotGiven = NOT_GIVEN,
documents: List[DocumentSchema] | NotGiven = NOT_GIVEN,
**kwargs: Any,
) -> ChatCompletionResponse | AsyncStream[ChatCompletionChunk]:
if any(isinstance(item, J2ChatMessage) for item in messages):
@@ -70,6 +83,9 @@ async def create(
top_p=top_p,
n=n,
stream=stream or False,
tools=tools,
response_format=response_format,
documents=documents,
**kwargs,
)

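
The async client mirrors the sync signature, so the new parameters can be combined with `await` and streaming. A sketch of JSON mode via `response_format`, assuming `ResponseFormat(type="json_object")` is the accepted shape:

```python
import asyncio

from ai21 import AsyncAI21Client
from ai21.models.chat import ChatMessage
from ai21.models.chat.response_format import ResponseFormat

client = AsyncAI21Client()  # reads AI21_API_KEY from the environment

messages = [
    ChatMessage(role="user", content="Give me a JSON object with keys 'city' and 'country' for the capital of France."),
]


async def main() -> None:
    response = await client.chat.completions.create(
        model="jamba-1.5-mini",
        messages=messages,
        response_format=ResponseFormat(type="json_object"),
        max_tokens=200,
    )
    print(response.choices[0].message.content)


asyncio.run(main())
```
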
13 changes: 11 additions & 2 deletions ai21/clients/studio/resources/chat/base_chat_completions.py
@@ -4,7 +4,10 @@
from abc import ABC
from typing import List, Optional, Union, Any, Dict, Literal

from ai21.models.chat import ChatMessage
from ai21.models.chat.chat_message import ChatMessageParam
from ai21.models.chat.document_schema import DocumentSchema
from ai21.models.chat.response_format import ResponseFormat
from ai21.models.chat.tool_defintions import ToolDefinition
from ai21.types import NotGiven
from ai21.utils.typing import remove_not_given
from ai21.models._pydantic_compatibility import _to_dict
@@ -33,13 +36,16 @@ def _get_model(self, model: Optional[str], model_id: Optional[str]) -> str:
def _create_body(
self,
model: str,
messages: List[ChatMessage],
messages: List[ChatMessageParam],
max_tokens: Optional[int] | NotGiven,
temperature: Optional[float] | NotGiven,
top_p: Optional[float] | NotGiven,
stop: Optional[Union[str, List[str]]] | NotGiven,
n: Optional[int] | NotGiven,
stream: Literal[False] | Literal[True] | NotGiven,
tools: List[ToolDefinition] | NotGiven,
response_format: ResponseFormat | NotGiven,
documents: List[DocumentSchema] | NotGiven,
**kwargs: Any,
) -> Dict[str, Any]:
return remove_not_given(
Expand All @@ -52,6 +58,9 @@ def _create_body(
"stop": stop,
"n": n,
"stream": stream,
"tools": tools,
"response_format": response_format,
"documents": documents,
**kwargs,
}
)
24 changes: 20 additions & 4 deletions ai21/clients/studio/resources/chat/chat_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
from ai21.clients.studio.resources.studio_resource import StudioResource
from ai21.clients.studio.resources.chat.base_chat_completions import BaseChatCompletions
from ai21.models import ChatMessage as J2ChatMessage
from ai21.models.chat import ChatMessage, ChatCompletionResponse, ChatCompletionChunk
from ai21.models.chat import ChatCompletionResponse, ChatCompletionChunk
from ai21.models.chat.chat_message import ChatMessageParam
from ai21.models.chat.document_schema import DocumentSchema
from ai21.models.chat.response_format import ResponseFormat
from ai21.models.chat.tool_defintions import ToolDefinition
from ai21.stream.stream import Stream
from ai21.types import NotGiven, NOT_GIVEN

@@ -17,13 +21,16 @@ class ChatCompletions(StudioResource, BaseChatCompletions):
def create(
self,
model: str,
messages: List[ChatMessage],
messages: List[ChatMessageParam],
max_tokens: int | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
stop: str | List[str] | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
tools: List[ToolDefinition] | NotGiven = NOT_GIVEN,
response_format: ResponseFormat | NotGiven = NOT_GIVEN,
documents: List[DocumentSchema] | NotGiven = NOT_GIVEN,
**kwargs: Any,
) -> ChatCompletionResponse:
pass
@@ -32,27 +39,33 @@ def create(
def create(
self,
model: str,
messages: List[ChatMessage],
messages: List[ChatMessageParam],
stream: Literal[True],
max_tokens: int | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
stop: str | List[str] | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
tools: List[ToolDefinition] | NotGiven = NOT_GIVEN,
response_format: ResponseFormat | NotGiven = NOT_GIVEN,
documents: List[DocumentSchema] | NotGiven = NOT_GIVEN,
**kwargs: Any,
) -> Stream[ChatCompletionChunk]:
pass

def create(
self,
messages: List[ChatMessage],
messages: List[ChatMessageParam],
model: Optional[str] = None,
max_tokens: int | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
top_p: float | NotGiven = NOT_GIVEN,
stop: str | List[str] | NotGiven = NOT_GIVEN,
n: int | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
tools: List[ToolDefinition] | NotGiven = NOT_GIVEN,
response_format: ResponseFormat | NotGiven = NOT_GIVEN,
documents: List[DocumentSchema] | NotGiven = NOT_GIVEN,
**kwargs: Any,
) -> ChatCompletionResponse | Stream[ChatCompletionChunk]:
if any(isinstance(item, J2ChatMessage) for item in messages):
@@ -70,6 +83,9 @@ def create(
top_p=top_p,
n=n,
stream=stream or False,
tools=tools,
response_format=response_format,
documents=documents,
**kwargs,
)

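
Once a response comes back with `tool_calls`, the expected round trip is to run the function locally and feed the result back as a tool message. A sketch continuing the function-calling example after the changelog above (it reuses that example's `client`, `messages`, and `response`), assuming tool arguments arrive as a JSON string and that `ToolMessage` carries the matching `tool_call_id`:

```python
import json

from ai21.models.chat import ToolMessage


def get_order_status(order_id: str) -> str:
    # Hypothetical local implementation.
    return f"Order {order_id} is out for delivery."


tool_calls = response.choices[0].message.tool_calls or []
if tool_calls:
    messages.append(response.choices[0].message)  # keep the assistant turn that requested the tool
    for tool_call in tool_calls:
        if tool_call.function.name == "get_order_status":
            arguments = json.loads(tool_call.function.arguments)
            result = get_order_status(arguments["order_id"])
            messages.append(ToolMessage(role="tool", tool_call_id=tool_call.id, content=result))

    # Ask the model for a final answer now that the tool result is in context.
    follow_up = client.chat.completions.create(model="jamba-1.5-mini", messages=messages)
    print(follow_up.choices[0].message.content)
```
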
3 changes: 2 additions & 1 deletion ai21/http_client/http_client.py
@@ -99,7 +99,8 @@ def execute_http_request(

if response.status_code != httpx.codes.OK:
_logger.error(
f"Calling {method} {self._base_url} failed with a non-200 response code: {response.status_code}"
f"Calling {method} {self._base_url} failed with a non-200 "
f"response code: {response.status_code} headers: {response.headers}"
)
handle_non_success_response(response.status_code, response.text)

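
The enriched error line above now includes the response headers, which is handy for pulling request IDs or rate-limit details out of failed calls. A minimal sketch for surfacing it with the standard library logger; the logger name "ai21" is an assumption about how the SDK names its loggers:

```python
import logging

# Send SDK log output to stderr so the non-200 error line (status code plus
# response headers) added in this commit is visible when a call fails.
logging.basicConfig(level=logging.ERROR)
logging.getLogger("ai21").setLevel(logging.ERROR)
```
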