Skip to content

Commit

Permalink
TLK-1864 agents deployments models refactoring (#824)
Browse files Browse the repository at this point in the history
* TLK-1864 agents deployments models refactoring

* TLK-1864 agents deployments models refactoring - review fixes

* TLK-1864 agents deployments models refactoring - review fixes
  • Loading branch information
EugeneLightsOn authored Nov 11, 2024
1 parent fa64235 commit 32be86b
Show file tree
Hide file tree
Showing 22 changed files with 273 additions and 1,094 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""update agent deployment model
Revision ID: 74ba7e1b4810
Revises: 20b03fd331e8
Create Date: 2024-10-28 13:27:22.299287
"""
from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '74ba7e1b4810'
down_revision: Union[str, None] = '20b03fd331e8'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('agents', sa.Column('deployment_id', sa.String(), nullable=True))
op.add_column('agents', sa.Column('model_id', sa.String(), nullable=True))
op.create_foreign_key('agents_model_id_fkey', 'agents', 'models', ['model_id'], ['id'], ondelete='CASCADE')
op.create_foreign_key('agents_deployment_id_fkey', 'agents', 'deployments', ['deployment_id'], ['id'], ondelete='CASCADE')
# set the deployment_id and model_id for the agents using agent_deployment_model table
# and then drop the table agent_deployment_model
op.execute(
"""
UPDATE agents
SET deployment_id = agent_deployment_model.deployment_id,
model_id = agent_deployment_model.model_id
FROM agent_deployment_model
WHERE agents.id = agent_deployment_model.agent_id;
"""
)
op.drop_table('agent_deployment_model')
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint('agents_deployment_id_fkey', 'agents', type_='foreignkey')
op.drop_constraint('agents_model_id_fkey', 'agents', type_='foreignkey')
op.drop_column('agents', 'model_id')
op.drop_column('agents', 'deployment_id')
# ### end Alembic commands ###
148 changes: 6 additions & 142 deletions src/backend/crud/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import false, true

from backend.database_models import Deployment
from backend.database_models.agent import Agent, AgentDeploymentModel
from backend.schemas.agent import AgentVisibility, UpdateAgentRequest
from backend.database_models.agent import Agent
from backend.schemas.agent import AgentVisibility, UpdateAgentDB
from backend.services.transaction import validate_transaction


Expand Down Expand Up @@ -78,59 +77,6 @@ def get_agent_by_name(db: Session, agent_name: str, user_id: str) -> Agent:
return agent


@validate_transaction
def get_association_by_deployment_name(
db: Session, agent: Agent, deployment_name: str
) -> AgentDeploymentModel:
"""
Get an agent deployment model association by deployment name.
Args:
db (Session): Database session.
agent (Agent): Agent to get the association.
deployment_name (str): Deployment name.
Returns:
AgentDeploymentModel: Agent deployment model association.
"""
return (
db.query(AgentDeploymentModel)
.join(Deployment, Deployment.id == AgentDeploymentModel.deployment_id)
.filter(
Deployment.name == deployment_name,
AgentDeploymentModel.agent_id == agent.id,
)
.first()
)


@validate_transaction
def get_association_by_deployment_id(
db: Session, agent: Agent, deployment_id: str
) -> AgentDeploymentModel:
"""
Get an agent deployment model association by deployment id.
Args:
db (Session): Database session.
agent (Agent): Agent to get the association.
deployment_id (str): Deployment ID.
Returns:
AgentDeploymentModel: Agent deployment model association.
"""
return (
db.query(AgentDeploymentModel)
.filter(
AgentDeploymentModel.deployment_id == deployment_id,
AgentDeploymentModel.agent_id == agent.id,
AgentDeploymentModel.is_default_deployment == true(),
AgentDeploymentModel.is_default_model == true(),
)
.first()
)


@validate_transaction
def get_agents(
db: Session,
Expand Down Expand Up @@ -176,93 +122,9 @@ def get_agents(
return query.all()


@validate_transaction
def get_agent_model_deployment_association(
db: Session, agent: Agent, model_id: str, deployment_id: str
) -> AgentDeploymentModel:
"""
Get an agent model deployment association.
Args:
db (Session): Database session.
agent (Agent): Agent to get the association.
model_id (str): Model ID.
deployment_id (str): Deployment ID.
Returns:
AgentDeploymentModel: Agent model deployment association.
"""
return (
db.query(AgentDeploymentModel)
.filter(
AgentDeploymentModel.agent_id == agent.id,
AgentDeploymentModel.model_id == model_id,
AgentDeploymentModel.deployment_id == deployment_id,
)
.first()
)


@validate_transaction
def delete_agent_model_deployment_association(
db: Session, agent: Agent, model_id: str, deployment_id: str
):
"""
Delete an agent model deployment association.
Args:
db (Session): Database session.
agent (Agent): Agent to delete the association.
model_id (str): Model ID.
deployment_id (str): Deployment ID.
"""
db.query(AgentDeploymentModel).filter(
AgentDeploymentModel.agent_id == agent.id,
AgentDeploymentModel.model_id == model_id,
AgentDeploymentModel.deployment_id == deployment_id,
).delete()
db.commit()


@validate_transaction
def assign_model_deployment_to_agent(
db: Session,
agent: Agent,
model_id: str,
deployment_id: str,
deployment_config: dict[str, str] = {},
set_default: bool = False,
) -> Agent:
"""
Assign a model and deployment to an agent.
Args:
agent (Agent): Agent to assign the model and deployment.
model_id (str): Model ID.
deployment_id (str): Deployment ID.
deployment_config (dict[str, str]): Deployment configuration.
set_default (bool): Set the model and deployment as default.
Returns:
Agent: Agent with the assigned model and deployment.
"""
agent_deployment = AgentDeploymentModel(
agent_id=agent.id,
model_id=model_id,
deployment_id=deployment_id,
is_default_deployment=set_default,
is_default_model=set_default,
deployment_config=deployment_config,
)
db.add(agent_deployment)
db.commit()
db.refresh(agent)
return agent


@validate_transaction
def update_agent(
db: Session, agent: Agent, new_agent: UpdateAgentRequest, user_id: str
db: Session, agent: Agent, new_agent: UpdateAgentDB, user_id: str
) -> Agent:
"""
Update an agent.
Expand All @@ -278,7 +140,9 @@ def update_agent(
if agent.is_private and agent.user_id != user_id:
return None

for attr, value in new_agent.model_dump(exclude_none=True).items():
new_agent_cleaned = new_agent.dict(exclude_unset=True, exclude_none=True)

for attr, value in new_agent_cleaned.items():
setattr(agent, attr, value)

db.commit()
Expand Down
65 changes: 2 additions & 63 deletions src/backend/crud/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from sqlalchemy.orm import Session

from backend.database_models import AgentDeploymentModel, Deployment
from backend.database_models import Deployment
from backend.model_deployments.utils import class_name_validator
from backend.schemas.deployment import Deployment as DeploymentSchema
from backend.schemas.deployment import DeploymentCreate, DeploymentUpdate
Expand Down Expand Up @@ -92,70 +92,9 @@ def get_available_deployments(
"""
all_deployments = db.query(Deployment).all()
return [deployment for deployment in all_deployments if deployment.is_available][
offset : offset + limit
offset: offset + limit
]


def get_deployments_by_agent_id(
db: Session, agent_id: str, offset: int = 0, limit: int = 100
) -> list[Deployment]:
"""
List all deployments by user id
Args:
db (Session): Database session.
agent_id (str): User ID
offset (int): Offset to start the list.
limit (int): Limit of deployments to be listed.
Returns:
list[Deployment]: List of deployments.
"""
return (
db.query(Deployment)
.join(
AgentDeploymentModel,
Deployment.id == AgentDeploymentModel.deployment_id,
)
.filter(AgentDeploymentModel.agent_id == agent_id)
.limit(limit)
.offset(offset)
.all()
)


def get_available_deployments_by_agent_id(
db: Session, agent_id: str, offset: int = 0, limit: int = 100
) -> list[Deployment]:
"""
List all deployments by user id
Args:
db (Session): Database session.
agent_id (str): User ID
offset (int): Offset to start the list.
limit (int): Limit of deployments to be listed.
Returns:
list[Deployment]: List of deployments.
"""
agent_deployments = (
db.query(Deployment)
.join(
AgentDeploymentModel,
Deployment.id == AgentDeploymentModel.deployment_id,
)
.filter(AgentDeploymentModel.agent_id == agent_id)
.limit(limit)
.offset(offset)
.all()
)

return [deployment for deployment in agent_deployments if deployment.is_available][
offset : offset + limit
]


@validate_transaction
def update_deployment(
db: Session, deployment: Deployment, new_deployment: DeploymentUpdate
Expand Down
32 changes: 1 addition & 31 deletions src/backend/crud/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from sqlalchemy.orm import Session

from backend.database_models import AgentDeploymentModel, Deployment
from backend.database_models import Deployment
from backend.database_models.model import Model
from backend.schemas.deployment import Deployment as DeploymentSchema
from backend.schemas.model import ModelCreate, ModelUpdate
Expand Down Expand Up @@ -127,36 +127,6 @@ def delete_model(db: Session, model_id: str) -> None:
db.commit()


def get_models_by_agent_id(
db: Session, agent_id: str, offset: int = 0, limit: int = 100
) -> list[Model]:
"""
List all models by user id
Args:
db (Session): Database session.
agent_id (str): User ID
offset (int): Offset to start the list.
limit (int): Limit of models to be listed.
Returns:
list[Model]: List of models.
"""

return (
db.query(Model)
.join(
AgentDeploymentModel,
agent_id == AgentDeploymentModel.agent_id,
)
.filter(Model.deployment_id == AgentDeploymentModel.deployment_id)
.order_by(Model.name)
.limit(limit)
.offset(offset)
.all()
)


def create_model_by_config(db: Session, deployment: Deployment, deployment_config: DeploymentSchema, model: str) -> Model:
"""
Create a new model by config if present
Expand Down
Loading

0 comments on commit 32be86b

Please sign in to comment.