diff --git a/camel/configs/__init__.py b/camel/configs/__init__.py index d3e1f9bf83..a6a352cfde 100644 --- a/camel/configs/__init__.py +++ b/camel/configs/__init__.py @@ -17,6 +17,7 @@ from .groq_config import GROQ_API_PARAMS, GroqConfig from .litellm_config import LITELLM_API_PARAMS, LiteLLMConfig from .mistral_config import MISTRAL_API_PARAMS, MistralConfig +from .nexa_config import NEXA_API_PARAMS, NexaConfig from .ollama_config import OLLAMA_API_PARAMS, OllamaConfig from .openai_config import OPENAI_API_PARAMS, ChatGPTConfig from .reka_config import REKA_API_PARAMS, RekaConfig @@ -62,4 +63,6 @@ 'SAMBA_CLOUD_API_PARAMS', 'TogetherAIConfig', 'TOGETHERAI_API_PARAMS', + 'NEXA_API_PARAMS', + 'NexaConfig', ] diff --git a/camel/configs/nexa_config.py b/camel/configs/nexa_config.py new file mode 100644 index 0000000000..260872d151 --- /dev/null +++ b/camel/configs/nexa_config.py @@ -0,0 +1,55 @@ +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the “License”); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an “AS IS” BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +from __future__ import annotations + +from typing import Sequence, Union + +from openai._types import NOT_GIVEN, NotGiven + +from camel.configs.base_config import BaseConfig + + +class NexaConfig(BaseConfig): + r"""Defines the parameters for generating chat completions using Nexa API + + Args: + temperature (float, optional): Sampling temperature to use, between + :obj:`0` and :obj:`2`. Higher values make the output more random, + while lower values make it more focused and deterministic. + (default: :obj:`0.7`) + max_tokens (int, optional): The maximum number of new tokens to + generate in the chat completion. The total length of input tokens + and generated tokens is limited by the model's context length. + (default: :obj:`1024`) + top_p (float, optional): An alternative to sampling with temperature, + called nucleus sampling, where the model considers the results of + the tokens with top_p probability mass. So :obj:`0.1` means only + the tokens comprising the top 10% probability mass are considered. + (default: :obj:`1.0`) + stream (bool, optional): If True, partial message deltas will be sent + as data-only server-sent events as they become available. + (default: :obj:`False`) + stop (str or list, optional): List of stop words for early stopping + generating further tokens. (default: :obj:`None`) + + """ + + temperature: float = 0.7 + max_tokens: int = 1024 + top_p: float = 1.0 + stream: bool = False + stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN + + +NEXA_API_PARAMS = {param for param in NexaConfig.model_fields.keys()} diff --git a/camel/models/__init__.py b/camel/models/__init__.py index 780c647c6d..4766353013 100644 --- a/camel/models/__init__.py +++ b/camel/models/__init__.py @@ -20,6 +20,7 @@ from .mistral_model import MistralModel from .model_factory import ModelFactory from .nemotron_model import NemotronModel +from .nexa_model import NexaModel from .ollama_model import OllamaModel from .openai_audio_models import OpenAIAudioModels from .openai_compatible_model import OpenAICompatibleModel @@ -51,4 +52,5 @@ 'RekaModel', 'SambaModel', 'TogetherAIModel', + 'NexaModel', ] diff --git a/camel/models/model_factory.py b/camel/models/model_factory.py index a787187d1e..779c1f0b97 100644 --- a/camel/models/model_factory.py +++ b/camel/models/model_factory.py @@ -20,6 +20,7 @@ from camel.models.groq_model import GroqModel from camel.models.litellm_model import LiteLLMModel from camel.models.mistral_model import MistralModel +from camel.models.nexa_model import NexaModel # Add this import from camel.models.ollama_model import OllamaModel from camel.models.openai_compatible_model import OpenAICompatibleModel from camel.models.openai_model import OpenAIModel @@ -81,6 +82,8 @@ def create( model_class = OllamaModel elif model_platform.is_vllm: model_class = VLLMModel + elif model_platform.is_nexa: + model_class = NexaModel elif model_platform.is_openai_compatible_model: model_class = OpenAICompatibleModel elif model_platform.is_samba: diff --git a/camel/models/nexa_model.py b/camel/models/nexa_model.py new file mode 100644 index 0000000000..a68113424d --- /dev/null +++ b/camel/models/nexa_model.py @@ -0,0 +1,171 @@ +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the “License”); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an “AS IS” BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== + +import os +import subprocess +from typing import Any, Dict, List, Optional, Union + +from openai import OpenAI, Stream + +from camel.configs import NEXA_API_PARAMS, NexaConfig +from camel.messages import OpenAIMessage +from camel.models import BaseModelBackend +from camel.types import ChatCompletion, ChatCompletionChunk, ModelType +from camel.utils import BaseTokenCounter, OpenAITokenCounter + + +class NexaModel(BaseModelBackend): + """Nexa service interface.""" + + def __init__( + self, + model_type: Union[ModelType, str], + model_config_dict: Optional[Dict[str, Any]] = None, + api_key: Optional[str] = None, + url: Optional[str] = None, + token_counter: Optional[BaseTokenCounter] = None, + ) -> None: + """Constructor for Nexa backend. + + Args: + model_type (Union[ModelType, str]): Model for which a backend is + created. + model_config_dict (Optional[Dict[str, Any]], optional): A + dictionary that will be fed into the API call. If None, default + configuration will be used. (default: None) + api_key (Optional[str], optional): The API key for the model + service. (default: None) + url (Optional[str], optional): The url to the model service. If not + provided, environment variable or default URL will be used. + (default: None) + token_counter (Optional[BaseTokenCounter], optional): Token counter + to use for the model. If not provided, + OpenAITokenCounter(ModelType.GPT_4O_MINI) will be used. + (default: None) + """ + if model_config_dict is None: + model_config_dict = NexaConfig().as_dict() + url = url or os.environ.get("NEXA_BASE_URL") + super().__init__( + model_type, model_config_dict, api_key, url, token_counter + ) + + if not self._url: + self._start_server() + + self._client = OpenAI( + timeout=60, + max_retries=3, + base_url=self._url, + api_key=self._api_key, + ) + + def _start_server(self) -> None: + """Starts the Nexa server in a subprocess.""" + try: + subprocess.Popen( + [ + "nexa", + "server", + self.model_type, + "--port", + "8000", + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + print( + f"Nexa server started on http://localhost:8000 " + f"for {self.model_type} model." + ) + except Exception as e: + print(f"Failed to start Nexa server: {e}.") + + @property + def token_counter(self) -> BaseTokenCounter: + """Initialize the token counter for the model backend. + + Returns: + BaseTokenCounter: The token counter following the model's + tokenization style. + """ + if not self._token_counter: + self._token_counter = OpenAITokenCounter(ModelType.GPT_4O_MINI) + return self._token_counter + + def check_model_config(self): + """Check whether the model configuration contains any + unexpected arguments to Nexa API. + + Raises: + ValueError: If the model configuration dictionary contains any + unexpected arguments to Nexa API. + """ + for param in self.model_config_dict: + if param not in NEXA_API_PARAMS: + raise ValueError( + f"Unexpected argument `{param}` is " + "input into Nexa model backend." + ) + + def run( + self, + messages: List[OpenAIMessage], + ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]: + """Runs inference of Nexa chat completion. + + Args: + messages (List[OpenAIMessage]): Message list with the chat history + in OpenAI API format. + + Returns: + Union[ChatCompletion, Stream[ChatCompletionChunk]]: + `ChatCompletion` in the non-stream mode, or + `Stream[ChatCompletionChunk]` in the stream mode. + """ + try: + response = self._client.chat.completions.create( + messages=messages, + model=self.model_type, + **self.model_config_dict, + ) + return response + except Exception as e: + raise Exception(f"Error calling Nexa API: {e!s}") + + @property + def token_limit(self) -> int: + """Returns the maximum token limit for the given model. + + Returns: + int: The maximum token limit for the given model. + """ + max_tokens = self.model_config_dict.get("max_tokens") + if isinstance(max_tokens, int): + return max_tokens + print( + "Must set `max_tokens` as an integer in `model_config_dict`" + " when setting up the model. Using 4096 as default value." + ) + return 4096 + + @property + def stream(self) -> bool: + """Returns whether the model is in stream mode, which sends partial + results each time. + + Returns: + bool: Whether the model is in stream mode. + """ + return self.model_config_dict.get('stream', False) diff --git a/camel/types/enums.py b/camel/types/enums.py index 680830bdec..70019c3b2d 100644 --- a/camel/types/enums.py +++ b/camel/types/enums.py @@ -439,6 +439,7 @@ class ModelPlatformType(Enum): TOGETHER = "together" OPENAI_COMPATIBLE_MODEL = "openai-compatible-model" SAMBA = "samba-nova" + NEXA = "nexa" @property def is_openai(self) -> bool: @@ -511,6 +512,11 @@ def is_samba(self) -> bool: r"""Returns whether this platform is Samba Nova.""" return self is ModelPlatformType.SAMBA + @property + def is_nexa(self) -> bool: + """Returns whether this platform is Nexa.""" + return self is ModelPlatformType.NEXA + class AudioModelType(Enum): TTS_1 = "tts-1" diff --git a/test/models/test_nexa_model.py b/test/models/test_nexa_model.py new file mode 100644 index 0000000000..0b27dd59d6 --- /dev/null +++ b/test/models/test_nexa_model.py @@ -0,0 +1,50 @@ +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the “License”); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an “AS IS” BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +import re + +import pytest + +from camel.configs import NexaConfig +from camel.models import NexaModel +from camel.utils import OpenAITokenCounter + + +@pytest.mark.model_backend +@pytest.mark.parametrize( + "model_type", + ["llama3.2"], +) +def test_nexa_model(model_type: str): + model = NexaModel(model_type, url="http://localhost:8000/v1") + assert model.model_type == model_type + assert model.model_config_dict == NexaConfig().as_dict() + assert isinstance(model.token_counter, OpenAITokenCounter) + assert isinstance(model.token_limit, int) + + +@pytest.mark.model_backend +def test_nexa_model_unexpected_argument(): + model_type = "llama3.2" + model_config_dict = {"unexpected_arg": "value"} + + with pytest.raises( + ValueError, + match=re.escape( + ( + "Unexpected argument `unexpected_arg` is " + "input into Nexa model backend." + ) + ), + ): + _ = NexaModel(model_type, model_config_dict, api_key="nexa_api_key")