Skip to content

Commit

Permalink
release with pydantic v2 support
Browse files Browse the repository at this point in the history
  • Loading branch information
trisongz committed Oct 18, 2023
1 parent b14e9b0 commit f6fd5cf
Show file tree
Hide file tree
Showing 10 changed files with 159 additions and 151 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
# Changelogs

#### v0.0.40 (2023-10-18)

**Potentially Breaking Changes**

This version introduces full compatibility with `pydantic v1/v2` where previous versions would only work with `pydantic v1`. Auto-detection and handling of deprecated methods of `pydantic` models are handled by `lazyops`, and require `lazyops >= 0.2.60`.

With `pydantic v2` support, there should be a slight performance increase in parsing `pydantic` objects, although the majority of the time is spent waiting for the API to respond.

Additionally, support is added for handling the response like a `dict` object, so you can access the response like `response['choices']` rather than `response.choices`.

#### v0.0.36 (2023-10-11)

Expand Down
69 changes: 31 additions & 38 deletions async_openai/schemas/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import asyncio
import aiohttpx
import contextlib
from pydantic import root_validator, Field, BaseModel
from typing import Optional, Type, Any, Union, List, Dict, Iterator, TypeVar, AsyncIterator, Generator, AsyncGenerator, TYPE_CHECKING
from lazyops.types import validator, lazyproperty
from lazyops.types.models import root_validator, pre_root_validator, Field, BaseModel, PYD_VERSION, get_pyd_schema

from async_openai.types.options import OpenAIModel, get_consumption_cost
from async_openai.types.resources import BaseResource, Usage
Expand Down Expand Up @@ -67,12 +67,6 @@ def validate_arguments(cls, v) -> Dict[str, Any]:
return json.loads(v)
return v

def __getitem__(self, key: str) -> Any:
"""
Mimic dict
"""
return getattr(self, key)

# TODO Add support for name
class ChatMessage(BaseResource):
content: Optional[str] = None
Expand All @@ -83,24 +77,12 @@ class ChatMessage(BaseResource):
def dict(self, *args, exclude_none: bool = True, **kwargs):
return super().dict(*args, exclude_none = exclude_none, **kwargs)

def get(self, key: str, default: Any = None, **kwargs) -> Any:
"""
Mimic dict
"""
return getattr(self, key, default)

def __getitem__(self, key: str) -> Any:
"""
Mimic dict
"""
return getattr(self, key)


class ChatChoice(BaseResource):
message: ChatMessage
index: int
logprobs: Optional[Any]
finish_reason: Optional[str]
logprobs: Optional[Any] = None
finish_reason: Optional[str] = None

def __getitem__(self, key: str) -> Any:
"""
Expand All @@ -113,10 +95,13 @@ class Function(BaseResource):
Represents a function
"""
# Must be a-z, A-Z, 0-9, or contain underscores and dashes
name: str = Field(..., max_length = 64, regex = r'^[a-zA-Z0-9_]+$')
if PYD_VERSION == 2:
name: str = Field(..., max_length = 64, pattern = r'^[a-zA-Z0-9_]+$')
else:
name: str = Field(..., max_length = 64, regex = r'^[a-zA-Z0-9_]+$')
parameters: Union[Dict[str, Any], SchemaType, str]
description: Optional[str] = None
source_object: Optional[SchemaType] = Field(default = None, exclude = True)
source_object: Optional[Union[SchemaType, Any]] = Field(default = None, exclude = True)

@root_validator(pre = True)
def validate_parameters(cls, values: Dict[str, Any]) -> Dict[str, Any]:
Expand All @@ -126,15 +111,17 @@ def validate_parameters(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if params := values.get('parameters'):
if isinstance(params, dict):
pass
elif issubclass(params, BaseModel):
values['parameters'] = params.schema()
elif issubclass(params, BaseModel) or isinstance(params, type(BaseModel)):
values['parameters'] = get_pyd_schema(params)
# params.schema()
values['source_object'] = params
elif isinstance(params, str):
try:
values['parameters'] = json.loads(params)
except Exception as e:
raise ValueError(f'Invalid JSON: {params}, {e}. Must be a dict or pydantic BaseModel.') from e
else:
# logger.warning(f'Invalid parameters: {params}. Must be a dict or pydantic BaseModel.')
raise ValueError(f'Parameters must be a dict or pydantic BaseModel. Provided: {type(params)}')
return values

Expand Down Expand Up @@ -171,6 +158,7 @@ def validate_messages(cls, v) -> List[ChatMessage]:
v = [v]
for i in v:
if isinstance(i, dict):
# vals.append(pyd_parse_obj(ChatMessage, i, strict = False))
vals.append(ChatMessage.parse_obj(i))
elif isinstance(i, str):
vals.append(ChatMessage(content = i))
Expand Down Expand Up @@ -236,6 +224,7 @@ def dict(self, *args, exclude: Any = None, **kwargs):
"""
Returns the dict representation of the response
"""
# data = get_pyd_dict(self, *args, exclude = exclude, **kwargs)
data = super().dict(*args, exclude = exclude, **kwargs)
# data['stream'] = False
if data.get('model'):
Expand All @@ -244,14 +233,14 @@ def dict(self, *args, exclude: Any = None, **kwargs):
# del data['max_tokens']
return data

@root_validator()
@root_validator(pre = True)
def validate_obj(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""
Validate the object
"""
# Auto validate max tokens
# if values['validate_max_tokens'] or (values.get('max_tokens') is not None and values['max_tokens'] <= 0):
if (values['validate_max_tokens'] and values.get('max_tokens')) \
if (values.get('validate_max_tokens') and values.get('max_tokens')) \
or (values.get('max_tokens') is not None and values['max_tokens'] <= 0):
values['max_tokens'] = get_max_chat_tokens(
messages = values['messages'],
Expand All @@ -268,15 +257,18 @@ def validate_obj(cls, values: Dict[str, Any]) -> Dict[str, Any]:


class ChatResponse(BaseResponse):
choices: Optional[List[ChatChoice]]
choices: Optional[List[ChatChoice]] = None
choice_model: Optional[Type[BaseResource]] = ChatChoice
_input_object: Optional[ChatObject] = None
input_object: Optional[ChatObject] = None

@lazyproperty
def messages(self) -> List[ChatMessage]:
"""
Returns the messages for the completions
"""
if self.choices_results:
return [choice.message for choice in self.choices]
return self._response.text
return self.response.text

@lazyproperty
def function_results(self) -> List[FunctionCall]:
Expand All @@ -291,9 +283,9 @@ def function_result_objects(self) -> List[Union[SchemaType, Dict[str, Any]]]:
Returns the function result objects for the completions
"""
results = []
source_function: Function = self._input_object.functions[0] if self._input_object.function_call == "auto" else (
source_function: Function = self.input_object.functions[0] if self.input_object.function_call == "auto" else (
[
f for f in self._input_object.functions if f.name == self._input_object.function_call
f for f in self.input_object.functions if f.name == self.input_object.function_call
]
)[0]

Expand All @@ -317,14 +309,14 @@ def input_text(self) -> str:
"""
Returns the input text for the input prompt
"""
return '\n'.join([f'{msg.role}: {msg.content}' for msg in self._input_object.messages])
return '\n'.join([f'{msg.role}: {msg.content}' for msg in self.input_object.messages])

@lazyproperty
def input_messages(self) -> List[ChatMessage]:
"""
Returns the input messages for the input prompt
"""
return self._input_object.messages
return self.input_object.messages

@lazyproperty
def text(self) -> str:
Expand All @@ -333,7 +325,7 @@ def text(self) -> str:
"""
if self.choices_results:
return '\n'.join([f'{msg.role}: {msg.content}' for msg in self.messages])
return self._response.text
return self.response.text


@lazyproperty
Expand All @@ -343,14 +335,14 @@ def only_text(self) -> str:
"""
if self.choices_results:
return '\n'.join([msg.content for msg in self.messages])
return self._response.text
return self.response.text

@lazyproperty
def chat_model(self):
"""
Returns the model for the completions
"""
return self._input_object.model or None
return self.input_object.model or None
# return OpenAIModel(value=self.model, mode='chat') if self.model else None

@lazyproperty
Expand All @@ -365,7 +357,7 @@ def _validate_usage(self):
Validate usage
"""
if self.usage and self.usage.total_tokens and self.usage.prompt_tokens: return
if self._response.status_code == 200:
if self.response.status_code == 200:
self.usage = Usage(
# prompt_tokens = get_token_count(self.input_text),
# completion_tokens = get_token_count(self.text),
Expand Down Expand Up @@ -549,6 +541,7 @@ async def ahandle_stream(

except Exception as e:
logger.trace(f'Error: {line}', e)
# self.ctx.stream_consumed = True
self._stream_consumed = True
for remaining_result in results.values():
if streaming:
Expand Down
31 changes: 10 additions & 21 deletions async_openai/schemas/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import asyncio
import aiohttpx
import contextlib
from pydantic import root_validator, Field
from typing import Optional, Type, Any, Union, List, Dict, Iterator, AsyncIterator, Generator, AsyncGenerator, TYPE_CHECKING
from lazyops.types import validator, lazyproperty
from lazyops.types.models import root_validator, pre_root_validator, Field

from async_openai.types.options import OpenAIModel, get_consumption_cost
from async_openai.types.resources import BaseResource, Usage
Expand Down Expand Up @@ -36,14 +36,9 @@ class StreamedCompletionChoice(BaseResource):
class CompletionChoice(BaseResource):
text: str
index: int
logprobs: Optional[Any]
finish_reason: Optional[str]
logprobs: Optional[Any] = None
finish_reason: Optional[str] = None

def __getitem__(self, key: str) -> Any:
"""
Mimic dict
"""
return getattr(self, key)


class CompletionObject(BaseResource):
Expand Down Expand Up @@ -130,12 +125,12 @@ def dict(self, *args, exclude: Any = None, **kwargs):
data['model'] = data['model'].value
return data

@root_validator()
@root_validator(pre = True)
def validate_obj(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""
Validate the object
"""
if (values['validate_max_tokens'] and values.get('max_tokens')) \
if (values.get('validate_max_tokens') and values.get('max_tokens')) \
or (values.get('max_tokens') is not None and values['max_tokens'] <= 0):
values['max_tokens'] = get_max_tokens(
text = values['prompt'],
Expand All @@ -148,9 +143,9 @@ def validate_obj(cls, values: Dict[str, Any]) -> Dict[str, Any]:


class CompletionResponse(BaseResponse):
choices: Optional[List[CompletionChoice]]
choices: Optional[List[CompletionChoice]] = None
choice_model: Optional[Type[BaseResource]] = CompletionChoice
_input_object: Optional[CompletionObject] = None
input_object: Optional[CompletionObject] = None

@lazyproperty
def text(self) -> str:
Expand All @@ -159,7 +154,7 @@ def text(self) -> str:
"""
if self.choices_results:
return ''.join([choice.text for choice in self.choices])
return self._response.text
return self.response.text

@lazyproperty
def openai_model(self):
Expand All @@ -181,9 +176,9 @@ def _validate_usage(self):
Validate usage
"""
if self.usage and self.usage.total_tokens: return
if self._response.status_code == 200:
if self.response.status_code == 200:
self.usage = Usage(
prompt_tokens = get_token_count(self._input_object.prompt),
prompt_tokens = get_token_count(self.input_object.prompt),
completion_tokens = get_token_count(self.text),
)
self.usage.total_tokens = self.usage.prompt_tokens + self.usage.completion_tokens
Expand Down Expand Up @@ -212,12 +207,6 @@ def dict(self, *args, exclude: Any = None, **kwargs):
data['completion_model'] = data['completion_model'].dict()
return data

def __getitem__(self, key: str) -> Any:
"""
Mimic dict
"""
return getattr(self, key)

def parse_stream_item(self, item: Union[Dict, Any], **kwargs) -> Optional[StreamedCompletionChoice]:
"""
Parses a single stream item
Expand Down
2 changes: 1 addition & 1 deletion async_openai/schemas/edits.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def text(self) -> str:
"""
if self.choices:
return ''.join([choice.text for choice in self.choices])
return self._response.text
return self.response.text


class EditRoute(BaseRoute):
Expand Down
4 changes: 2 additions & 2 deletions async_openai/schemas/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def dict(self, *args, exclude: Any = None, **kwargs):
class EmbeddingResponse(BaseResponse):
data: Optional[List[EmbeddingData]]
data_model: Optional[Type[BaseResource]] = EmbeddingData
_input_object: Optional[EmbeddingObject] = None
input_object: Optional[EmbeddingObject] = None

@lazyproperty
def embeddings(self) -> List[List[float]]:
Expand All @@ -70,7 +70,7 @@ def openai_model(self):
"""
Returns the model for the completions
"""
return self.headers.get('openai-model', self._input_object.model.value)
return self.headers.get('openai-model', self.input_object.model.value)

@lazyproperty
def consumption(self) -> int:
Expand Down
Loading

0 comments on commit f6fd5cf

Please sign in to comment.