Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

backend: Deployments refactor; Add deployment service and fix deployment config setting #831

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ down:

.PHONY: run-unit-tests
run-unit-tests:
poetry run pytest src/backend/tests/unit --cov=src/backend --cov-report=xml
poetry run pytest src/backend/tests/unit/$(file) --cov=src/backend --cov-report=xml

.PHONY: run-community-tests
run-community-tests:
Expand Down
27 changes: 7 additions & 20 deletions src/backend/chat/custom/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from typing import Any

from backend.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS,
get_default_deployment,
)
from backend.exceptions import DeploymentNotFoundError
malexw marked this conversation as resolved.
Show resolved Hide resolved
from backend.model_deployments.base import BaseDeployment
from backend.schemas.context import Context
from backend.services import deployment as deployment_service


def get_deployment(name: str, ctx: Context, **kwargs: Any) -> BaseDeployment:
Expand All @@ -16,22 +14,11 @@ def get_deployment(name: str, ctx: Context, **kwargs: Any) -> BaseDeployment:

Returns:
BaseDeployment: Deployment implementation instance based on the deployment name.

Raises:
ValueError: If the deployment is not supported.
"""
kwargs["ctx"] = ctx
deployment = AVAILABLE_MODEL_DEPLOYMENTS.get(name)

# Check provided deployment against config const
if deployment is not None:
return deployment.deployment_class(**kwargs, **deployment.kwargs)

# Fallback to first available deployment
default = get_default_deployment(**kwargs)
if default is not None:
return default
try:
deployment = deployment_service.get_deployment_by_name(name)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would the DeploymentNotFoundError trigger if no deployment is found when filtering through the DB?

Perhaps we should use the fallback logic in a:

if not deployment:
    .. get_default_deployment()

And wrap the whole thing in a try/catch instead

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think I understand what you're seeing. With the new code here, deployment_service will throw if it can't find a deployment with the specified name. In that case, we catch and instead return a default deployment. And if there are no available deployments at all, get_default_deployment will also throw.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see, I confused the get_deployment_by_name call with a similarly named method I think. Good to go then.

except DeploymentNotFoundError:
deployment = deployment_service.get_default_deployment()

raise ValueError(
f"Deployment {name} is not supported, and no available deployments were found."
)
return deployment(**kwargs)
6 changes: 0 additions & 6 deletions src/backend/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@
from backend.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS as MANAGED_DEPLOYMENTS_SETUP,
)
from community.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS_SETUP,
)


def start():
Expand Down Expand Up @@ -50,9 +47,6 @@ def start():

# SET UP ENVIRONMENT FOR DEPLOYMENTS
all_deployments = MANAGED_DEPLOYMENTS_SETUP.copy()
if use_community_features:
all_deployments.update(COMMUNITY_DEPLOYMENTS_SETUP)

selected_deployments = select_deployments_prompt(all_deployments, secrets)

for deployment in selected_deployments:
Expand Down
4 changes: 2 additions & 2 deletions src/backend/config/default_agent.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import datetime

from backend.config.deployments import ModelDeploymentName
from backend.config.tools import Tool
from backend.model_deployments.cohere_platform import CohereDeployment
from backend.schemas.agent import AgentPublic

DEFAULT_AGENT_ID = "default"
DEFAULT_DEPLOYMENT = ModelDeploymentName.CoherePlatform
DEFAULT_DEPLOYMENT = CohereDeployment.name()
DEFAULT_MODEL = "command-r-plus"

def get_default_agent() -> AgentPublic:
Expand Down
137 changes: 16 additions & 121 deletions src/backend/config/deployments.py
Original file line number Diff line number Diff line change
@@ -1,140 +1,35 @@
from enum import StrEnum

from backend.config.settings import Settings
from backend.model_deployments import (
AzureDeployment,
BedrockDeployment,
CohereDeployment,
SageMakerDeployment,
SingleContainerDeployment,
)
from backend.model_deployments.azure import AZURE_ENV_VARS
from backend.model_deployments.base import BaseDeployment
from backend.model_deployments.bedrock import BEDROCK_ENV_VARS
from backend.model_deployments.cohere_platform import COHERE_ENV_VARS
from backend.model_deployments.sagemaker import SAGE_MAKER_ENV_VARS
from backend.model_deployments.single_container import SC_ENV_VARS
from backend.schemas.deployment import Deployment
from backend.services.logger.utils import LoggerFactory

logger = LoggerFactory().get_logger()


class ModelDeploymentName(StrEnum):
CoherePlatform = "Cohere Platform"
SageMaker = "SageMaker"
Azure = "Azure"
Bedrock = "Bedrock"
SingleContainer = "Single Container"


use_community_features = Settings().get('feature_flags.use_community_features')
ALL_MODEL_DEPLOYMENTS = { d.name(): d for d in BaseDeployment.__subclasses__() }

# TODO names in the map below should not be the display names but ids
ALL_MODEL_DEPLOYMENTS = {
ModelDeploymentName.CoherePlatform: Deployment(
id="cohere_platform",
name=ModelDeploymentName.CoherePlatform,
deployment_class=CohereDeployment,
models=CohereDeployment.list_models(),
is_available=CohereDeployment.is_available(),
env_vars=COHERE_ENV_VARS,
),
ModelDeploymentName.SingleContainer: Deployment(
id="single_container",
name=ModelDeploymentName.SingleContainer,
deployment_class=SingleContainerDeployment,
models=SingleContainerDeployment.list_models(),
is_available=SingleContainerDeployment.is_available(),
env_vars=SC_ENV_VARS,
),
ModelDeploymentName.SageMaker: Deployment(
id="sagemaker",
name=ModelDeploymentName.SageMaker,
deployment_class=SageMakerDeployment,
models=SageMakerDeployment.list_models(),
is_available=SageMakerDeployment.is_available(),
env_vars=SAGE_MAKER_ENV_VARS,
),
ModelDeploymentName.Azure: Deployment(
id="azure",
name=ModelDeploymentName.Azure,
deployment_class=AzureDeployment,
models=AzureDeployment.list_models(),
is_available=AzureDeployment.is_available(),
env_vars=AZURE_ENV_VARS,
),
ModelDeploymentName.Bedrock: Deployment(
id="bedrock",
name=ModelDeploymentName.Bedrock,
deployment_class=BedrockDeployment,
models=BedrockDeployment.list_models(),
is_available=BedrockDeployment.is_available(),
env_vars=BEDROCK_ENV_VARS,
),
}

def get_installed_deployments() -> list[type[BaseDeployment]]:
installed_deployments = list(ALL_MODEL_DEPLOYMENTS.values())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very small nit to rename get_available_deployments

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to rename this to whatever, but the reason I wanted to get away from the name available is because the models have an is_available method on them, and it might give the impression that a function named get_available_deployments was filtering based on is_available.


def get_available_deployments() -> dict[ModelDeploymentName, Deployment]:
if use_community_features:
if Settings().get("feature_flags.use_community_features"):
try:
from community.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS_SETUP,
)

model_deployments = ALL_MODEL_DEPLOYMENTS.copy()
model_deployments.update(COMMUNITY_DEPLOYMENTS_SETUP)
return model_deployments
except ImportError:
installed_deployments.extend(COMMUNITY_DEPLOYMENTS_SETUP.values())
except ImportError as e:
logger.warning(
event="[Deployments] No available community deployments have been configured"
event="[Deployments] No available community deployments have been configured", ex=e
)

deployments = Settings().get('deployments.enabled_deployments')
if deployments is not None and len(deployments) > 0:
return {
key: value
for key, value in ALL_MODEL_DEPLOYMENTS.items()
if value.id in Settings().get('deployments.enabled_deployments')
}

return ALL_MODEL_DEPLOYMENTS


def get_default_deployment(**kwargs) -> BaseDeployment:
# Fallback to the first available deployment
fallback = None
for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values():
if deployment.is_available:
fallback = deployment.deployment_class(**kwargs)
break

default = Settings().get('deployments.default_deployment')
if default:
return next(
(
v.deployment_class(**kwargs)
for k, v in AVAILABLE_MODEL_DEPLOYMENTS.items()
if v.id == default
),
fallback,
)
else:
return fallback


def find_config_by_deployment_id(deployment_id: str) -> Deployment:
for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values():
if deployment.id == deployment_id:
return deployment
return None


def find_config_by_deployment_name(deployment_name: str) -> Deployment:
for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values():
if deployment.name == deployment_name:
return deployment
return None
enabled_deployment_ids = Settings().get("deployments.enabled_deployments")
if enabled_deployment_ids:
return [
deployment
for deployment in installed_deployments
if deployment.id() in enabled_deployment_ids
]

return installed_deployments

AVAILABLE_MODEL_DEPLOYMENTS = get_available_deployments()
AVAILABLE_MODEL_DEPLOYMENTS = get_installed_deployments()
18 changes: 9 additions & 9 deletions src/backend/crud/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

from backend.database_models import Deployment
from backend.model_deployments.utils import class_name_validator
from backend.schemas.deployment import Deployment as DeploymentSchema
from backend.schemas.deployment import DeploymentCreate, DeploymentUpdate
from backend.services.transaction import validate_transaction
from community.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS,
from backend.schemas.deployment import (
DeploymentCreate,
DeploymentDefinition,
DeploymentUpdate,
)
from backend.services.transaction import validate_transaction


@validate_transaction
Expand All @@ -19,7 +19,7 @@ def create_deployment(db: Session, deployment: DeploymentCreate) -> Deployment:

Args:
db (Session): Database session.
deployment (DeploymentSchema): Deployment data to be created.
deployment (DeploymentDefinition): Deployment data to be created.

Returns:
Deployment: Created deployment.
Expand Down Expand Up @@ -132,14 +132,14 @@ def delete_deployment(db: Session, deployment_id: str) -> None:


@validate_transaction
def create_deployment_by_config(db: Session, deployment_config: DeploymentSchema) -> Deployment:
def create_deployment_by_config(db: Session, deployment_config: DeploymentDefinition) -> Deployment:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don’t see this method being used anymore.

"""
Create a new deployment by config.

Args:
db (Session): Database session.
deployment (str): Deployment data to be created.
deployment_config (DeploymentSchema): Deployment config.
deployment_config (DeploymentDefinition): Deployment config.

Returns:
Deployment: Created deployment.
Expand All @@ -152,7 +152,7 @@ def create_deployment_by_config(db: Session, deployment_config: DeploymentSchema
for env_var in deployment_config.env_vars
},
deployment_class_name=deployment_config.deployment_class.__name__,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure that we have deployment_class attribute for new configuration

is_community=deployment_config.name in COMMUNITY_DEPLOYMENTS
is_community=deployment_config.is_community,
)
db.add(deployment)
db.commit()
Expand Down
36 changes: 33 additions & 3 deletions src/backend/crud/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from backend.database_models import Deployment
from backend.database_models.model import Model
from backend.schemas.deployment import Deployment as DeploymentSchema
from backend.schemas.deployment import DeploymentDefinition
from backend.schemas.model import ModelCreate, ModelUpdate
from backend.services.transaction import validate_transaction

Expand Down Expand Up @@ -127,14 +127,44 @@ def delete_model(db: Session, model_id: str) -> None:
db.commit()


def create_model_by_config(db: Session, deployment: Deployment, deployment_config: DeploymentSchema, model: str) -> Model:
def get_models_by_agent_id(
db: Session, agent_id: str, offset: int = 0, limit: int = 100
) -> list[Model]:
"""
List all models by user id

Args:
db (Session): Database session.
agent_id (str): User ID
offset (int): Offset to start the list.
limit (int): Limit of models to be listed.

Returns:
list[Model]: List of models.
"""

return (
db.query(Model)
.join(
AgentDeploymentModel,
agent_id == AgentDeploymentModel.agent_id,
)
.filter(Model.deployment_id == AgentDeploymentModel.deployment_id)
.order_by(Model.name)
.limit(limit)
.offset(offset)
.all()
)


def create_model_by_config(db: Session, deployment: Deployment, deployment_config: DeploymentDefinition, model: str) -> Model:
"""
Create a new model by config if present

Args:
db (Session): Database session.
deployment (Deployment): Deployment data.
deployment_config (DeploymentSchema): Deployment config data.
deployment_config (DeploymentDefinition): Deployment config data.
model (str): Model data.

Returns:
Expand Down
Loading
Loading