feat: Support nexa model with nexa-sdk #1053

Open · wants to merge 4 commits into master
3 changes: 3 additions & 0 deletions camel/configs/__init__.py
@@ -17,6 +17,7 @@
from .groq_config import GROQ_API_PARAMS, GroqConfig
from .litellm_config import LITELLM_API_PARAMS, LiteLLMConfig
from .mistral_config import MISTRAL_API_PARAMS, MistralConfig
from .nexa_config import NEXA_API_PARAMS, NexaConfig
from .ollama_config import OLLAMA_API_PARAMS, OllamaConfig
from .openai_config import OPENAI_API_PARAMS, ChatGPTConfig
from .reka_config import REKA_API_PARAMS, RekaConfig
@@ -62,4 +63,6 @@
'SAMBA_CLOUD_API_PARAMS',
'TogetherAIConfig',
'TOGETHERAI_API_PARAMS',
'NEXA_API_PARAMS',
'NexaConfig',
]
55 changes: 55 additions & 0 deletions camel/configs/nexa_config.py
@@ -0,0 +1,55 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from __future__ import annotations

from typing import Sequence, Union

from openai._types import NOT_GIVEN, NotGiven

from camel.configs.base_config import BaseConfig


class NexaConfig(BaseConfig):
r"""Defines the parameters for generating chat completions using Nexa API

Args:
temperature (float, optional): Sampling temperature to use, between
:obj:`0` and :obj:`2`. Higher values make the output more random,
while lower values make it more focused and deterministic.
(default: :obj:`0.7`)
max_tokens (int, optional): The maximum number of new tokens to
generate in the chat completion. The total length of input tokens
and generated tokens is limited by the model's context length.
(default: :obj:`1024`)
top_p (float, optional): An alternative to sampling with temperature,
called nucleus sampling, where the model considers the results of
the tokens with top_p probability mass. So :obj:`0.1` means only
the tokens comprising the top 10% probability mass are considered.
(default: :obj:`1.0`)
stream (bool, optional): If True, partial message deltas will be sent
as data-only server-sent events as they become available.
(default: :obj:`False`)
        stop (str or list, optional): A stop word or list of stop words at
            which the model stops generating further tokens.
            (default: :obj:`None`)

"""

temperature: float = 0.7
max_tokens: int = 1024
top_p: float = 1.0
stream: bool = False
stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN


NEXA_API_PARAMS = {param for param in NexaConfig.model_fields.keys()}
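
A quick sanity check of the config surface — a minimal sketch, assuming `as_dict()` (inherited from `BaseConfig`, and used by `NexaModel` below) simply mirrors the pydantic fields:

```python
from camel.configs import NEXA_API_PARAMS, NexaConfig

# Override a couple of defaults; everything else keeps its default value.
config = NexaConfig(temperature=0.2, max_tokens=512)

# as_dict() produces the kwargs NexaModel feeds into the API call.
config_dict = config.as_dict()
assert config_dict["temperature"] == 0.2
assert config_dict["max_tokens"] == 512

# Every key in the dict is a parameter the backend accepts.
assert set(config_dict) <= NEXA_API_PARAMS
```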
2 changes: 2 additions & 0 deletions camel/models/__init__.py
@@ -20,6 +20,7 @@
from .mistral_model import MistralModel
from .model_factory import ModelFactory
from .nemotron_model import NemotronModel
from .nexa_model import NexaModel
from .ollama_model import OllamaModel
from .openai_audio_models import OpenAIAudioModels
from .openai_compatible_model import OpenAICompatibleModel
@@ -51,4 +52,5 @@
'RekaModel',
'SambaModel',
'TogetherAIModel',
'NexaModel',
]
3 changes: 3 additions & 0 deletions camel/models/model_factory.py
@@ -20,6 +20,7 @@
from camel.models.groq_model import GroqModel
from camel.models.litellm_model import LiteLLMModel
from camel.models.mistral_model import MistralModel
from camel.models.nexa_model import NexaModel
from camel.models.ollama_model import OllamaModel
from camel.models.openai_compatible_model import OpenAICompatibleModel
from camel.models.openai_model import OpenAIModel
@@ -81,6 +82,8 @@ def create(
model_class = OllamaModel
elif model_platform.is_vllm:
model_class = VLLMModel
elif model_platform.is_nexa:
model_class = NexaModel
elif model_platform.is_openai_compatible_model:
model_class = OpenAICompatibleModel
elif model_platform.is_samba:
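For context, a hedged sketch of how this dispatch is typically reached — assuming `ModelFactory.create` accepts the same `model_platform`/`model_type`/`model_config_dict`/`url` keywords it uses for the other platforms, and that a Nexa server is already listening at the given URL:

```python
from camel.configs import NexaConfig
from camel.models import ModelFactory
from camel.types import ModelPlatformType

model = ModelFactory.create(
    model_platform=ModelPlatformType.NEXA,  # routes to NexaModel above
    model_type="llama3.2",  # a model name the Nexa server can serve
    model_config_dict=NexaConfig(temperature=0.2).as_dict(),
    url="http://localhost:8000/v1",  # skip the auto-started local server
)
```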
171 changes: 171 additions & 0 deletions camel/models/nexa_model.py
@@ -0,0 +1,171 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========

import os
import subprocess
from typing import Any, Dict, List, Optional, Union

from openai import OpenAI, Stream

from camel.configs import NEXA_API_PARAMS, NexaConfig
from camel.messages import OpenAIMessage
from camel.models import BaseModelBackend
from camel.types import ChatCompletion, ChatCompletionChunk, ModelType
from camel.utils import BaseTokenCounter, OpenAITokenCounter


class NexaModel(BaseModelBackend):
"""Nexa service interface."""

def __init__(
self,
model_type: Union[ModelType, str],
model_config_dict: Optional[Dict[str, Any]] = None,
api_key: Optional[str] = None,
url: Optional[str] = None,
token_counter: Optional[BaseTokenCounter] = None,
) -> None:
"""Constructor for Nexa backend.

Args:
model_type (Union[ModelType, str]): Model for which a backend is
created.
model_config_dict (Optional[Dict[str, Any]], optional): A
dictionary that will be fed into the API call. If None, default
configuration will be used. (default: None)
api_key (Optional[str], optional): The API key for the model
service. (default: None)
            url (Optional[str], optional): The URL of the model service. If
                not provided, the `NEXA_BASE_URL` environment variable is
                used; if that is also unset, a local Nexa server is started
                on port 8000. (default: None)
token_counter (Optional[BaseTokenCounter], optional): Token counter
to use for the model. If not provided,
OpenAITokenCounter(ModelType.GPT_4O_MINI) will be used.
(default: None)
"""
if model_config_dict is None:
model_config_dict = NexaConfig().as_dict()
url = url or os.environ.get("NEXA_BASE_URL")
super().__init__(
model_type, model_config_dict, api_key, url, token_counter
)

        if not self._url:
            self._start_server()
            # Point the client at the server we just spawned; otherwise the
            # OpenAI client would fall back to its default base URL.
            self._url = "http://localhost:8000/v1"

self._client = OpenAI(
timeout=60,
max_retries=3,
base_url=self._url,
api_key=self._api_key,
)

def _start_server(self) -> None:
"""Starts the Nexa server in a subprocess."""
try:
subprocess.Popen(
[
"nexa",
"server",
self.model_type,
"--port",
"8000",
],
                # Discard server output; unread PIPE buffers can eventually
                # block the child process.
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
)
print(
f"Nexa server started on http://localhost:8000 "
f"for {self.model_type} model."
)
except Exception as e:
print(f"Failed to start Nexa server: {e}.")

@property
def token_counter(self) -> BaseTokenCounter:
"""Initialize the token counter for the model backend.

Returns:
BaseTokenCounter: The token counter following the model's
tokenization style.
"""
if not self._token_counter:
self._token_counter = OpenAITokenCounter(ModelType.GPT_4O_MINI)
return self._token_counter

def check_model_config(self):
"""Check whether the model configuration contains any
unexpected arguments to Nexa API.

Raises:
ValueError: If the model configuration dictionary contains any
unexpected arguments to Nexa API.
"""
for param in self.model_config_dict:
if param not in NEXA_API_PARAMS:
raise ValueError(
f"Unexpected argument `{param}` is "
"input into Nexa model backend."
)

def run(
self,
messages: List[OpenAIMessage],
) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
"""Runs inference of Nexa chat completion.

Args:
messages (List[OpenAIMessage]): Message list with the chat history
in OpenAI API format.

Returns:
Union[ChatCompletion, Stream[ChatCompletionChunk]]:
`ChatCompletion` in the non-stream mode, or
`Stream[ChatCompletionChunk]` in the stream mode.
"""
try:
response = self._client.chat.completions.create(
messages=messages,
model=self.model_type,
**self.model_config_dict,
)
return response
except Exception as e:
            raise RuntimeError(f"Error calling Nexa API: {e!s}") from e

@property
def token_limit(self) -> int:
"""Returns the maximum token limit for the given model.

Returns:
int: The maximum token limit for the given model.
"""
max_tokens = self.model_config_dict.get("max_tokens")
if isinstance(max_tokens, int):
return max_tokens
print(
"Must set `max_tokens` as an integer in `model_config_dict`"
" when setting up the model. Using 4096 as default value."
)
return 4096

@property
def stream(self) -> bool:
"""Returns whether the model is in stream mode, which sends partial
results each time.

Returns:
bool: Whether the model is in stream mode.
"""
return self.model_config_dict.get('stream', False)
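
A minimal end-to-end sketch of the backend itself, assuming a Nexa server is already running at the URL the tests below use — `run` takes messages in the OpenAI chat format:

```python
from camel.models import NexaModel

model = NexaModel("llama3.2", url="http://localhost:8000/v1")

response = model.run(
    [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello in one sentence."},
    ]
)
# With stream=False (the default), this is a ChatCompletion object.
print(response.choices[0].message.content)
```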
6 changes: 6 additions & 0 deletions camel/types/enums.py
@@ -439,6 +439,7 @@ class ModelPlatformType(Enum):
TOGETHER = "together"
OPENAI_COMPATIBLE_MODEL = "openai-compatible-model"
SAMBA = "samba-nova"
NEXA = "nexa"

@property
def is_openai(self) -> bool:
@@ -511,6 +512,11 @@ def is_samba(self) -> bool:
r"""Returns whether this platform is Samba Nova."""
return self is ModelPlatformType.SAMBA

@property
def is_nexa(self) -> bool:
"""Returns whether this platform is Nexa."""
return self is ModelPlatformType.NEXA


class AudioModelType(Enum):
TTS_1 = "tts-1"
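A small illustration of the new platform flag (the enum value mirrors the `"nexa"` string above):

```python
from camel.types import ModelPlatformType

assert ModelPlatformType("nexa") is ModelPlatformType.NEXA
assert ModelPlatformType.NEXA.is_nexa
assert not ModelPlatformType.SAMBA.is_nexa
```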
50 changes: 50 additions & 0 deletions test/models/test_nexa_model.py
@@ -0,0 +1,50 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
import re

import pytest

from camel.configs import NexaConfig
from camel.models import NexaModel
from camel.utils import OpenAITokenCounter


@pytest.mark.model_backend
@pytest.mark.parametrize(
"model_type",
["llama3.2"],
)
def test_nexa_model(model_type: str):
model = NexaModel(model_type, url="http://localhost:8000/v1")
assert model.model_type == model_type
assert model.model_config_dict == NexaConfig().as_dict()
assert isinstance(model.token_counter, OpenAITokenCounter)
assert isinstance(model.token_limit, int)


@pytest.mark.model_backend
def test_nexa_model_unexpected_argument():
model_type = "llama3.2"
model_config_dict = {"unexpected_arg": "value"}

with pytest.raises(
ValueError,
match=re.escape(
(
"Unexpected argument `unexpected_arg` is "
"input into Nexa model backend."
)
),
):
_ = NexaModel(model_type, model_config_dict, api_key="nexa_api_key")
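
Note: the first test needs a live endpoint at `http://localhost:8000/v1` — e.g. one started with `nexa server llama3.2 --port 8000`, the same command `_start_server` spawns — while the unexpected-argument test raises in `check_model_config` during `super().__init__`, before any server is started or request is made.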