diff --git a/CHANGELOG.md b/CHANGELOG.md
index 36d234d..7e91ba3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,19 @@
 # Changelogs
 
+#### v0.0.52 (2024-02-28)
+
+- Added support for the following parameters in `model_configurations` in `OpenAIManager`:
+
+  - `ping_timeout` - allows setting a custom healthcheck ping timeout for each client.
+
+  - `included_models` - allows more flexibly specifying which models a client serves (useful for Azure).
+
+  - `weight` - allows for weighted selection of clients.
+
+- Improved healthcheck behavior: successful healthchecks are now cached for a period of time, while failed healthchecks are always rechecked.
+
+- Added `dimension` parameter for `embedding` models.
+
 #### v0.0.51rc (2024-02-07)
 
 - Modification of `async_openai.types.context.ModelContextHandler` to a proxied object singleton.
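
For illustration, a client entry combining the three new parameters might look like the sketch below. The keys `enabled`, `api_base`, `weight`, `ping_timeout`, `included_models`, and `excluded_models` mirror the keys popped in `register_client_endpoints` later in this diff; the client names and endpoint URLs are invented for the example.

```python
# Hypothetical configuration sketch; names and URLs are illustrative only.
client_configurations = {
    'az-east': {
        'enabled': True,
        'api_base': 'https://my-east.openai.azure.com',
        'weight': 3.0,        # chosen ~3x as often as a weight-1.0 client
        'ping_timeout': 2.5,  # per-client healthcheck timeout, in seconds
        'included_models': ['gpt-4-0125-preview'],  # whitelist, instead of excluded_models
    },
    'az-west': {
        'enabled': True,
        'api_base': 'https://my-west.openai.azure.com',
        'weight': 1.0,
    },
}
```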
""" - with contextlib.suppress(Exception): - response = await self.client.async_get('/', timeout = timeout) - data = response.json() - # we should expect a 404 with a json response - if data.get('error'): return True + try: + response = await self.client.async_get(base_url or '/', timeout = timeout) + try: + data = response.json() + # we should expect a 404 with a json response + if data.get('error'): return True + except Exception as e: + logger.error(f"[{self.name} - {response.status_code}] API Ping Failed: {response.text[:500]}") + except Exception as e: + logger.error(f"[{self.name}] API Ping Failed: {e}") return False diff --git a/async_openai/external_client.py b/async_openai/external_client.py index dfd24a6..8be34e2 100644 --- a/async_openai/external_client.py +++ b/async_openai/external_client.py @@ -291,11 +291,13 @@ async def aping(self, timeout: Optional[float] = 1.0) -> bool: """ Pings the API Endpoint to check if it's alive. """ - with contextlib.suppress(Exception): + try: response = await self.client.async_get('/', timeout = timeout) data = response.json() # we should expect a 404 with a json response if data.get('error'): return True + except Exception as e: + logger.error(f"[{self.name}] API Ping Failed: {e}") return False diff --git a/async_openai/loadbalancer.py b/async_openai/loadbalancer.py index 146ec35..35ca8f6 100644 --- a/async_openai/loadbalancer.py +++ b/async_openai/loadbalancer.py @@ -199,7 +199,6 @@ def get_api_client(self, client_name: Optional[str] = None, require_azure: Optio client_name = 'default' if client_name and client_name not in self.clients: self.clients[client_name] = self.init_api_client(client_name = client_name, **kwargs) - if not client_name and require_azure: while not self.api.is_azure: self.increase_rotate_index() @@ -212,16 +211,23 @@ def get_api_client_from_list(self, client_names: List[str], require_azure: Optio Initializes a new OpenAI client or Returns an existing one from a list of client names. """ if not self.healthcheck: - name = random.choice(client_names) + name = self.manager.select_client_name_from_weights(client_names) if self.manager.has_client_weights else random.choice(client_names) return self.get_api_client(client_name = name, require_azure = require_azure, **kwargs) + available = [] for client_name in client_names: if client_name not in self.clients: self.clients[client_name] = self.init_api_client(client_name = client_name, **kwargs) if require_azure and not self.clients[client_name].is_azure: continue - if not self.clients[client_name].ping(): + if not self.clients[client_name].ping(**self.manager.get_client_ping_params(client_name)): continue - return self.clients[client_name] + if not self.manager.has_client_weights: + return self.clients[client_name] + available.append(client_name) + # return self.clients[client_name] + if available: + name = self.manager.select_client_name_from_weights(available) + return self.clients[name] raise ValueError(f'No healthy client found from: {client_names}') async def aget_api_client_from_list(self, client_names: List[str], require_azure: Optional[bool] = None, **kwargs) -> 'OpenAIClient': @@ -229,16 +235,23 @@ async def aget_api_client_from_list(self, client_names: List[str], require_azure Initializes a new OpenAI client or Returns an existing one from a list of client names. 
""" if not self.healthcheck: - name = random.choice(client_names) + name = self.manager.select_client_name_from_weights(client_names) if self.manager.has_client_weights else random.choice(client_names) return self.get_api_client(client_name = name, require_azure = require_azure, **kwargs) + available = [] for client_name in client_names: if client_name not in self.clients: self.clients[client_name] = self.init_api_client(client_name = client_name, **kwargs) if require_azure and not self.clients[client_name].is_azure: continue - if not await self.clients[client_name].aping(): + if not await self.clients[client_name].aping(**self.manager.get_client_ping_params(client_name)): continue - return self.clients[client_name] + if not self.manager.has_client_weights: + return self.clients[client_name] + available.append(client_name) + + if available: + name = self.manager.select_client_name_from_weights(available) + return self.clients[name] raise ValueError(f'No healthy client found from: {client_names}') def __getitem__(self, key: Union[str, int]) -> 'OpenAIClient': diff --git a/async_openai/manager.py b/async_openai/manager.py index 22806b6..465d55f 100644 --- a/async_openai/manager.py +++ b/async_openai/manager.py @@ -15,6 +15,7 @@ from async_openai.types.context import ModelContextHandler from async_openai.utils.config import get_settings, OpenAISettings from async_openai.utils.external_config import ExternalProviderSettings +from async_openai.utils.helpers import weighted_choice from async_openai.types.functions import FunctionManager, OpenAIFunctions from async_openai.utils.logs import logger @@ -37,7 +38,17 @@ 'gpt-3.5-turbo-0613': 'gpt-35-turbo-0613', 'gpt-3.5-turbo-1106': 'gpt-35-turbo-1106', } - +DefaultAvailableModels = [ + 'gpt-4', + 'gpt-4-32k', + 'gpt-4-turbo', + 'gpt-4-1106-preview', + 'gpt-4-0125-preview', + 'text-embedding-ada-2', + 'text-embedding-3-small', + 'text-embedding-3-large', +] + list(DefaultModelMapping.values()) + class OpenAIManager(abc.ABC): name: Optional[str] = "openai" on_error: Optional[Callable] = None @@ -61,7 +72,11 @@ def __init__(self, **kwargs): """ Initializes the OpenAI API Client """ + self.client_weights: Optional[Dict[str, float]] = {} + self.client_ping_timeouts: Optional[Dict[str, float]] = {} self.client_model_exclusions: Optional[Dict[str, Dict[str, Union[bool, List[str]]]]] = {} + self.client_base_urls: Optional[Dict[str, str]] = {} + self.no_proxy_client_names: Optional[List[str]] = [] self.client_callbacks: Optional[List[Callable]] = [] self.functions: FunctionManager = OpenAIFunctions @@ -74,6 +89,13 @@ def __init__(self, **kwargs): self.external_client_default: Optional[str] = None # self._external_clients: + @property + def has_client_weights(self) -> bool: + """ + Returns if the client has weights + """ + return bool(self.client_weights) + def add_callback(self, callback: Callable): """ Adds a callback to the client @@ -271,16 +293,24 @@ def get_api_client_from_list( if not client_names: return self.apis.get_api_client(**kwargs) return self.apis.get_api_client_from_list(client_names = client_names, **kwargs) if not client_names: return self.get_api_client(**kwargs) + if not self.auto_healthcheck: - name = random.choice(client_names) + name = self.select_client_name_from_weights(client_names) if self.has_client_weights else random.choice(client_names) return self.get_api_client(client_name = name, **kwargs) + available = [] for client_name in client_names: if client_name not in self._clients: self._clients[client_name] = 
diff --git a/async_openai/manager.py b/async_openai/manager.py
index 22806b6..465d55f 100644
--- a/async_openai/manager.py
+++ b/async_openai/manager.py
@@ -15,6 +15,7 @@
 from async_openai.types.context import ModelContextHandler
 from async_openai.utils.config import get_settings, OpenAISettings
 from async_openai.utils.external_config import ExternalProviderSettings
+from async_openai.utils.helpers import weighted_choice
 from async_openai.types.functions import FunctionManager, OpenAIFunctions
 from async_openai.utils.logs import logger
 
@@ -37,7 +38,17 @@
     'gpt-3.5-turbo-0613': 'gpt-35-turbo-0613',
     'gpt-3.5-turbo-1106': 'gpt-35-turbo-1106',
 }
-
+DefaultAvailableModels = [
+    'gpt-4',
+    'gpt-4-32k',
+    'gpt-4-turbo',
+    'gpt-4-1106-preview',
+    'gpt-4-0125-preview',
+    'text-embedding-ada-002',
+    'text-embedding-3-small',
+    'text-embedding-3-large',
+] + list(DefaultModelMapping.values())
+
 class OpenAIManager(abc.ABC):
     name: Optional[str] = "openai"
     on_error: Optional[Callable] = None
@@ -61,7 +72,11 @@ def __init__(self, **kwargs):
         """
         Initializes the OpenAI API Client
         """
+        self.client_weights: Optional[Dict[str, float]] = {}
+        self.client_ping_timeouts: Optional[Dict[str, float]] = {}
         self.client_model_exclusions: Optional[Dict[str, Dict[str, Union[bool, List[str]]]]] = {}
+        self.client_base_urls: Optional[Dict[str, str]] = {}
+
         self.no_proxy_client_names: Optional[List[str]] = []
         self.client_callbacks: Optional[List[Callable]] = []
         self.functions: FunctionManager = OpenAIFunctions
@@ -74,6 +89,13 @@ def __init__(self, **kwargs):
         self.external_client_default: Optional[str] = None
         # self._external_clients:
 
+    @property
+    def has_client_weights(self) -> bool:
+        """
+        Returns if the client has weights
+        """
+        return bool(self.client_weights)
+
     def add_callback(self, callback: Callable):
         """
         Adds a callback to the client
@@ -271,16 +293,24 @@ def get_api_client_from_list(
             if not client_names: return self.apis.get_api_client(**kwargs)
             return self.apis.get_api_client_from_list(client_names = client_names, **kwargs)
         if not client_names: return self.get_api_client(**kwargs)
+
         if not self.auto_healthcheck:
-            name = random.choice(client_names)
+            name = self.select_client_name_from_weights(client_names) if self.has_client_weights else random.choice(client_names)
             return self.get_api_client(client_name = name, **kwargs)
+        available = []
         for client_name in client_names:
             if client_name not in self._clients:
                 self._clients[client_name] = self.init_api_client(client_name = client_name, **kwargs)
-            if not self._clients[client_name].ping():
+            if not self._clients[client_name].ping(**self.get_client_ping_params(client_name)):
                 continue
-            return self._clients[client_name]
+            if not self.has_client_weights:
+                return self._clients[client_name]
+            available.append(client_name)
+
+        if available:
+            name = self.select_client_name_from_weights(available)
+            return self._clients[name]
         raise ValueError(f'No healthy client found from: {client_names}')
 
     async def aget_api_client_from_list(
@@ -296,15 +326,22 @@ async def aget_api_client_from_list(
             return await self.apis.aget_api_client_from_list(client_name = client_name, **kwargs)
         if not client_names: return self.get_api_client(**kwargs)
         if not self.auto_healthcheck:
-            name = random.choice(client_names)
+            name = self.select_client_name_from_weights(client_names) if self.has_client_weights else random.choice(client_names)
             return self.get_api_client(client_name = name, **kwargs)
+        available = []
         for client_name in client_names:
             if client_name not in self._clients:
                 self._clients[client_name] = self.init_api_client(client_name = client_name, **kwargs)
-            if not await self._clients[client_name].aping():
+            if not await self._clients[client_name].aping(**self.get_client_ping_params(client_name)):
                 continue
-            return self._clients[client_name]
+            if not self.has_client_weights:
+                return self._clients[client_name]
+            available.append(client_name)
+
+        if available:
+            name = self.select_client_name_from_weights(available)
+            return self._clients[name]
         raise ValueError(f'No healthy client found from: {client_names}')
 
@@ -535,6 +572,26 @@ def _ensure_api(self):
         """
         if self._api is None: self.configure_internal_apis()
 
+    def select_client_name_from_weights(self, names: List[str]) -> str:
+        """
+        Selects a client name based on the configured client weights
+        """
+        return weighted_choice([(name, self.client_weights.get(name, 1.0)) for name in names])
+
+    # def get_client_ping_timeout(self, name: str) -> Optional[float]:
+    #     """
+    #     Returns the client timeout
+    #     """
+    #     return self.client_ping_timeouts.get(name, 1.0)
+
+    def get_client_ping_params(self, name: str) -> Dict[str, Union[float, str]]:
+        """
+        Returns the client ping parameters
+        """
+        return {
+            'timeout': self.client_ping_timeouts.get(name, 1.0),
+            'base_url': self.client_base_urls.get(name),
+        }
+
     """
     API Routes
     """
@@ -640,11 +697,13 @@ def register_default_endpoints(self):
 
         self.init_api_client('azure', is_azure = True)
 
-    def register_client_endpoints(self):
+    def register_client_endpoints(self): # sourcery skip: low-code-quality
         """
         Register the Client Endpoints
         """
         client_configs = copy.deepcopy(self.settings.client_configurations)
+        seen_models = set(DefaultAvailableModels)
+        has_weights = any(c.get('weight') for c in client_configs.values())
         for name, config in client_configs.items():
             is_enabled = config.pop('enabled', False)
             if not is_enabled: continue
@@ -652,20 +711,33 @@ def register_client_endpoints(self):
             is_default = config.pop('default', False)
             proxy_disabled = config.pop('proxy_disabled', False)
             source_endpoint = config.get('api_base')
+            client_weight = config.pop('weight', 1.0) if has_weights else None
+            client_ping_timeout = config.pop('ping_timeout', None)
+
             if self.debug_enabled is not None: config['debug_enabled'] = self.debug_enabled
             if excluded_models := config.pop('excluded_models', None):
                 self.client_model_exclusions[name] = {
                     'models': excluded_models, 'is_azure': is_azure,
                 }
+                seen_models.update(excluded_models)
             else:
                 self.client_model_exclusions[name] = {
                     'models': None, 'is_azure': is_azure,
                 }
+            if included_models := config.pop('included_models', None):
+                self.client_model_exclusions[name]['included_models'] = included_models
+
+            if client_weight: self.client_weights[name] = client_weight
+            if client_ping_timeout is not None: self.client_ping_timeouts[name] = client_ping_timeout
+
             if (self.settings.proxy.enabled and not proxy_disabled) and config.get('api_base'):
                 # Initialize a non-proxy version of the client
                 config['api_base'] = source_endpoint
                 non_proxy_name = f'{name}_noproxy'
+                self.client_base_urls[name] = source_endpoint
+                if client_weight: self.client_weights[non_proxy_name] = client_weight
+                if client_ping_timeout is not None: self.client_ping_timeouts[non_proxy_name] = client_ping_timeout
                 self.client_model_exclusions[non_proxy_name] = self.client_model_exclusions[name].copy()
                 self.no_proxy_client_names.append(non_proxy_name)
                 self.init_api_client(non_proxy_name, is_azure = is_azure, set_as_default = False, **config)
@@ -675,8 +747,19 @@
                 )
             config['api_base'] = self.settings.proxy.endpoint
             c = self.init_api_client(name, is_azure = is_azure, set_as_default = is_default, **config)
-            logger.info(f'Registered: `|g|{c.name}|e|` @ `{source_endpoint or c.base_url}` (Azure: {c.is_azure})', colored = True)
-
+            msg = f'Registered: `|g|{c.name}|e|` @ `{source_endpoint or c.base_url}` (Azure: {c.is_azure}'
+            if has_weights: msg += f', Weight: {client_weight}'
+            msg += ')'
+            logger.info(msg, colored = True)
+
+        # Set the models for inclusion
+        for name in self.client_model_exclusions:
+            if not self.client_model_exclusions[name].get('included_models'): continue
+            included_models = self.client_model_exclusions[name].pop('included_models')
+            self.client_model_exclusions[name]['models'] = [m for m in seen_models if m not in included_models]
+            # if self.settings.debug_enabled:
+            #     logger.info(f'|g|{name}|e| Included: {included_models}, Excluded: {self.client_model_exclusions[name]["models"]}', colored = True)
+
     def select_client_names(
         self,
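
The `included_models` handling above works by inversion: after all clients are registered, a client's inclusion list is rewritten as an exclusion of every other model seen. A worked example of that final loop, using a reduced set of model names from `DefaultAvailableModels`:

```python
# Inversion performed at the end of register_client_endpoints:
# an inclusion list becomes "exclude everything else that was seen".
seen_models = {'gpt-4', 'gpt-4-32k', 'text-embedding-3-small'}
included_models = ['gpt-4']

excluded = [m for m in seen_models if m not in included_models]
print(sorted(excluded))  # ['gpt-4-32k', 'text-embedding-3-small']
```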
diff --git a/async_openai/types/functions.py b/async_openai/types/functions.py
index 71c5215..eac9e9d 100644
--- a/async_openai/types/functions.py
+++ b/async_openai/types/functions.py
@@ -38,6 +38,7 @@ class BaseFunctionModel(BaseModel):
     function_name: Optional[str] = Field(None, hidden = True)
     function_model: Optional[str] = Field(None, hidden = True)
     function_duration: Optional[float] = Field(None, hidden = True)
+    function_client_name: Optional[str] = Field(None, hidden = True)
 
     if TYPE_CHECKING:
         function_usage: Optional[Usage]
@@ -98,6 +99,7 @@ def _set_values_from_response(
         self,
         response: 'ChatResponse',
         name: Optional[str] = None,
+        client_name: Optional[str] = None,
         **kwargs
     ) -> 'BaseFunctionModel':
         """
@@ -107,6 +109,7 @@ def _set_values_from_response(
         self.function_usage = response.usage
         if response.response_ms: self.function_duration = response.response_ms / 1000
         self.function_model = response.model
+        if client_name: self.function_client_name = client_name
 
     @property
     def function_cost(self) -> Optional[float]:
@@ -427,10 +430,10 @@ async def arun_chat_function(
                 **kwargs,
             )
         except errors.InvalidRequestError as e:
-            self.logger.info(f"[{current_attempt}/{self.retry_limit}] [{self.name} - {model}] Invalid Request Error. |r|{e}|e|", colored=True)
+            self.logger.info(f"[{current_attempt}/{self.retry_limit}] [{self.name} - {chat.name}:{model}] Invalid Request Error. |r|{e}|e|", colored=True)
             raise e
         except errors.MaxRetriesExceeded as e:
-            self.autologger.info(f"[{current_attempt}/{self.retry_limit}] [{self.name} - {model}] Retrying...", colored=True)
+            self.autologger.info(f"[{current_attempt}/{self.retry_limit}] [{self.name} - {chat.name}:{model}] Retrying...", colored=True)
             return await self.arun_chat_function(
                 messages = messages,
                 cachable = cachable,
@@ -445,7 +448,7 @@ async def arun_chat_function(
                 **kwargs,
             )
         except Exception as e:
-            self.autologger.info(f"[{current_attempt}/{self.retry_limit}] [{self.name} - {model}] Unknown Error Trying to run chat function: |r|{e}|e|", colored=True)
+            self.autologger.info(f"[{current_attempt}/{self.retry_limit}] [{self.name} - {chat.name}:{model}] Unknown Error Trying to run chat function: |r|{e}|e|", colored=True)
             return await self.arun_chat_function(
                 messages = messages,
                 cachable = cachable,
@@ -502,6 +505,7 @@ def parse_response(
         response: 'ChatResponse',
         schema: Optional[Type[FunctionSchemaT]] = None,
         include_name: Optional[bool] = True,
+        client_name: Optional[str] = None,
     ) -> Optional[FunctionSchemaT]: # sourcery skip: extract-duplicate-method
         """
         Parses the response
@@ -509,7 +513,7 @@ def parse_response(
         schema = schema or self.schema
         try:
             result = schema.model_validate(response.function_results[0].arguments, from_attributes = True)
-            result._set_values_from_response(response, name = self.name if include_name else None)
+            result._set_values_from_response(response, name = self.name if include_name else None, client_name = client_name)
             return result
         except IndexError as e:
             self.autologger.error(f"[{self.name} - {response.model} - {response.usage}] No function results found: {e}\n{response.text}")
@@ -685,7 +689,7 @@ def run_function_loop(
             **kwargs,
         )
 
-        result = self.parse_response(response, include_name = True)
+        result = self.parse_response(response, include_name = True, client_name = chat.name)
         if result is not None: return result
 
         # Try Again
@@ -700,10 +704,10 @@ def run_function_loop(
                 cachable = False,
                 **kwargs,
             )
-            result = self.parse_response(response, include_name = True)
+            result = self.parse_response(response, include_name = True, client_name = chat.name)
             if result is not None: return result
             attempts += 1
-        self.autologger.error(f"Unable to parse the response for {self.name} after {self.max_attempts} attempts.")
+        self.autologger.error(f"[{chat.name}:{model}] Unable to parse the response for {self.name} after {self.max_attempts} attempts.")
         if raise_errors:
             raise errors.MaxRetriesExhausted(
                 name = self.name,
                 func_name = self.name,
@@ -731,7 +735,7 @@ async def arun_function_loop(
             **kwargs,
         )
 
-        result = self.parse_response(response, include_name = True)
+        result = self.parse_response(response, include_name = True, client_name = chat.name)
         if result is not None: return result
 
         # Try Again
@@ -746,10 +750,10 @@ async def arun_function_loop(
                 cachable = False,
                 **kwargs,
             )
-            result = self.parse_response(response, include_name = True)
+            result = self.parse_response(response, include_name = True, client_name = chat.name)
             if result is not None: return result
             attempts += 1
-        self.autologger.error(f"Unable to parse the response for {self.name} after {self.max_attempts} attempts.")
+        self.autologger.error(f"[{chat.name}:{model}] Unable to parse the response for {self.name} after {self.max_attempts} attempts.")
         if raise_errors:
             raise errors.MaxRetriesExhausted(
                 name = self.name,
                 func_name = self.name,
@@ -949,6 +953,7 @@ def execute(
         with_index: Optional[bool] = False,
         **function_kwargs
     ) -> Union[Optional['FunctionSchemaT'], Tuple[int, Optional['FunctionSchemaT']]]:
+        # sourcery skip: low-code-quality
        """
         Runs the function
         """
@@ -983,7 +988,7 @@ def execute(
             if self.cache_enabled and function.is_valid_response(result):
                 self.cache.set(key, result)
 
-        self.autologger.info(f"Function: {function.name} in {t.total_s} (Cache Hit: {cache_hit})", prefix = key, colored = True)
+        self.autologger.info(f"Function: {function.name} in {t.total_s} (Cache Hit: {cache_hit}, Client: {result.function_client_name})", prefix = key, colored = True)
         if is_iterator and with_index:
             return idx, result if function.is_valid_response(result) else (idx, None)
         return result if function.is_valid_response(result) else None
@@ -1034,7 +1039,7 @@ async def aexecute(
             if self.cache_enabled and function.is_valid_response(result):
                 await self.cache.aset(key, result)
 
-        self.autologger.info(f"Function: {function.name} in {t.total_s} (Cache Hit: {cache_hit})", prefix = key, colored = True)
+        self.autologger.info(f"Function: {function.name} in {t.total_s} (Cache Hit: {cache_hit}, Client: {result.function_client_name})", prefix = key, colored = True)
         if is_iterator and with_index:
             return idx, result if function.is_valid_response(result) else (idx, None)
         return result if function.is_valid_response(result) else None
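
The net effect of the `functions.py` changes is observability: the serving client's name travels from the `chat` client through `parse_response` into the schema instance, where it can be read back after a run. A self-contained sketch of the three tracking fields (this stub is illustrative; the real `BaseFunctionModel` carries many more fields and hidden-field handling):

```python
from typing import Optional
from pydantic import BaseModel

class FunctionResultStub(BaseModel):
    """Stub mirroring the tracking fields on BaseFunctionModel."""
    function_model: Optional[str] = None
    function_duration: Optional[float] = None
    function_client_name: Optional[str] = None  # new in this release

result = FunctionResultStub(
    function_model = 'gpt-4-0125-preview',
    function_duration = 1.42,
    function_client_name = 'az-east_noproxy',  # illustrative client name
)
print(f"served by {result.function_client_name} "
      f"({result.function_model}, {result.function_duration}s)")
```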
|r|{e}|e|", colored=True) raise e except errors.MaxRetriesExceeded as e: - self.autologger.info(f"[{current_attempt}/{self.retry_limit}] [{self.name} - {model}] Retrying...", colored=True) + self.autologger.info(f"[{current_attempt}/{self.retry_limit}] [{self.name} - {chat.name}:{model}] Retrying...", colored=True) return await self.arun_chat_function( messages = messages, cachable = cachable, @@ -445,7 +448,7 @@ async def arun_chat_function( **kwargs, ) except Exception as e: - self.autologger.info(f"[{current_attempt}/{self.retry_limit}] [{self.name} - {model}] Unknown Error Trying to run chat function: |r|{e}|e|", colored=True) + self.autologger.info(f"[{current_attempt}/{self.retry_limit}] [{self.name} - {chat.name}:{model}] Unknown Error Trying to run chat function: |r|{e}|e|", colored=True) return await self.arun_chat_function( messages = messages, cachable = cachable, @@ -502,6 +505,7 @@ def parse_response( response: 'ChatResponse', schema: Optional[Type[FunctionSchemaT]] = None, include_name: Optional[bool] = True, + client_name: Optional[str] = None, ) -> Optional[FunctionSchemaT]: # sourcery skip: extract-duplicate-method """ Parses the response @@ -509,7 +513,7 @@ def parse_response( schema = schema or self.schema try: result = schema.model_validate(response.function_results[0].arguments, from_attributes = True) - result._set_values_from_response(response, name = self.name if include_name else None) + result._set_values_from_response(response, name = self.name if include_name else None, client_name = client_name) return result except IndexError as e: self.autologger.error(f"[{self.name} - {response.model} - {response.usage}] No function results found: {e}\n{response.text}") @@ -685,7 +689,7 @@ def run_function_loop( **kwargs, ) - result = self.parse_response(response, include_name = True) + result = self.parse_response(response, include_name = True, client_name = chat.name) if result is not None: return result # Try Again @@ -700,10 +704,10 @@ def run_function_loop( cachable = False, **kwargs, ) - result = self.parse_response(response, include_name = True) + result = self.parse_response(response, include_name = True, client_name = chat.name) if result is not None: return result attempts += 1 - self.autologger.error(f"Unable to parse the response for {self.name} after {self.max_attempts} attempts.") + self.autologger.error(f"[{chat.name}:{model}] Unable to parse the response for {self.name} after {self.max_attempts} attempts.") if raise_errors: raise errors.MaxRetriesExhausted( name = self.name, func_name = self.name, @@ -731,7 +735,7 @@ async def arun_function_loop( **kwargs, ) - result = self.parse_response(response, include_name = True) + result = self.parse_response(response, include_name = True, client_name = chat.name) if result is not None: return result # Try Again @@ -746,10 +750,10 @@ async def arun_function_loop( cachable = False, **kwargs, ) - result = self.parse_response(response, include_name = True) + result = self.parse_response(response, include_name = True, client_name = chat.name) if result is not None: return result attempts += 1 - self.autologger.error(f"Unable to parse the response for {self.name} after {self.max_attempts} attempts.") + self.autologger.error(f"[{chat.name}:{model}] Unable to parse the response for {self.name} after {self.max_attempts} attempts.") if raise_errors: raise errors.MaxRetriesExhausted( name = self.name, func_name = self.name, @@ -949,6 +953,7 @@ def execute( with_index: Optional[bool] = False, **function_kwargs ) -> 
diff --git a/async_openai/utils/helpers.py b/async_openai/utils/helpers.py
index aa1791a..0f9c0fd 100644
--- a/async_openai/utils/helpers.py
+++ b/async_openai/utils/helpers.py
@@ -1,9 +1,13 @@
+import random
 import inspect
 import aiohttpx
+import bisect
+import itertools
+
 from datetime import datetime, timedelta
-from typing import Dict, Optional, Iterator, AsyncIterator, Union
+from typing import Dict, Optional, Iterator, AsyncIterator, Union, List, Tuple
 from lazyops.utils.helpers import timed, timer, is_coro_func
 
 __all__ = [
@@ -111,4 +115,18 @@ async def aparse_stream(response: aiohttpx.Response) -> AsyncIterator[str]:
     async for line in response.aiter_lines():
         _line = parse_stream_line(line)
         if _line is not None:
-            yield _line
\ No newline at end of file
+            yield _line
+
+
+def weighted_choice(choices: Union[List[Tuple[str, float]], Dict[str, float]]) -> str:
+    """
+    Randomly selects a choice based on the weights provided
+    """
+    if isinstance(choices, dict):
+        choices = list(choices.items())
+    weights = list(zip(*choices))[1]
+    return choices[bisect.bisect(
+        list(itertools.accumulate(weights)),
+        random.uniform(0, sum(weights))
+    )][0]
+
diff --git a/async_openai/version.py b/async_openai/version.py
index 4c87d21..f92cb62 100644
--- a/async_openai/version.py
+++ b/async_openai/version.py
@@ -1 +1 @@
-VERSION = '0.0.51rc2'
\ No newline at end of file
+VERSION = '0.0.52'
\ No newline at end of file
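
`weighted_choice` is classic roulette-wheel selection: accumulate the weights into a running sum, draw uniformly from `[0, total]`, and bisect to find which slot the draw lands in. A quick empirical check using the helper exactly as added above:

```python
from collections import Counter
from async_openai.utils.helpers import weighted_choice

# With weights 3:1, 'a' should be chosen roughly 75% of the time.
counts = Counter(weighted_choice({'a': 3.0, 'b': 1.0}) for _ in range(10_000))
print(counts['a'] / 10_000)  # ~0.75
```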
diff --git a/setup.py b/setup.py
index 1d10693..bbe52f5 100644
--- a/setup.py
+++ b/setup.py
@@ -16,7 +16,7 @@
     # 'file-io',
     'backoff',
     'tiktoken',
-    'lazyops >= 0.2.74', # Pydantic Support
+    'lazyops >= 0.2.76', # Pydantic Support
     'pydantic',
     'jinja2',
     'pyyaml',