diff --git a/src/backend/alembic/versions/2024_10_28_74ba7e1b4810_update_agent_deployment_model.py b/src/backend/alembic/versions/2024_10_28_74ba7e1b4810_update_agent_deployment_model.py new file mode 100644 index 0000000000..e774649b5d --- /dev/null +++ b/src/backend/alembic/versions/2024_10_28_74ba7e1b4810_update_agent_deployment_model.py @@ -0,0 +1,47 @@ +"""update agent deployment model + +Revision ID: 74ba7e1b4810 +Revises: 20b03fd331e8 +Create Date: 2024-10-28 13:27:22.299287 + +""" +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = '74ba7e1b4810' +down_revision: Union[str, None] = '20b03fd331e8' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('agents', sa.Column('deployment_id', sa.String(), nullable=True)) + op.add_column('agents', sa.Column('model_id', sa.String(), nullable=True)) + op.create_foreign_key('agents_model_id_fkey', 'agents', 'models', ['model_id'], ['id'], ondelete='CASCADE') + op.create_foreign_key('agents_deployment_id_fkey', 'agents', 'deployments', ['deployment_id'], ['id'], ondelete='CASCADE') + # set the deployment_id and model_id for the agents using agent_deployment_model table + # and then drop the table agent_deployment_model + op.execute( + """ + UPDATE agents + SET deployment_id = agent_deployment_model.deployment_id, + model_id = agent_deployment_model.model_id + FROM agent_deployment_model + WHERE agents.id = agent_deployment_model.agent_id; + """ + ) + op.drop_table('agent_deployment_model') + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint('agents_deployment_id_fkey', 'agents', type_='foreignkey') + op.drop_constraint('agents_model_id_fkey', 'agents', type_='foreignkey') + op.drop_column('agents', 'model_id') + op.drop_column('agents', 'deployment_id') + # ### end Alembic commands ### diff --git a/src/backend/crud/agent.py b/src/backend/crud/agent.py index 5219689ccd..4eaa3c58ea 100644 --- a/src/backend/crud/agent.py +++ b/src/backend/crud/agent.py @@ -3,9 +3,8 @@ from sqlalchemy.orm import Session from sqlalchemy.sql.expression import false, true -from backend.database_models import Deployment -from backend.database_models.agent import Agent, AgentDeploymentModel -from backend.schemas.agent import AgentVisibility, UpdateAgentRequest +from backend.database_models.agent import Agent +from backend.schemas.agent import AgentVisibility, UpdateAgentDB from backend.services.transaction import validate_transaction @@ -78,59 +77,6 @@ def get_agent_by_name(db: Session, agent_name: str, user_id: str) -> Agent: return agent -@validate_transaction -def get_association_by_deployment_name( - db: Session, agent: Agent, deployment_name: str -) -> AgentDeploymentModel: - """ - Get an agent deployment model association by deployment name. - - Args: - db (Session): Database session. - agent (Agent): Agent to get the association. - deployment_name (str): Deployment name. - - Returns: - AgentDeploymentModel: Agent deployment model association. - """ - return ( - db.query(AgentDeploymentModel) - .join(Deployment, Deployment.id == AgentDeploymentModel.deployment_id) - .filter( - Deployment.name == deployment_name, - AgentDeploymentModel.agent_id == agent.id, - ) - .first() - ) - - -@validate_transaction -def get_association_by_deployment_id( - db: Session, agent: Agent, deployment_id: str -) -> AgentDeploymentModel: - """ - Get an agent deployment model association by deployment id. - - Args: - db (Session): Database session. - agent (Agent): Agent to get the association. - deployment_id (str): Deployment ID. - - Returns: - AgentDeploymentModel: Agent deployment model association. - """ - return ( - db.query(AgentDeploymentModel) - .filter( - AgentDeploymentModel.deployment_id == deployment_id, - AgentDeploymentModel.agent_id == agent.id, - AgentDeploymentModel.is_default_deployment == true(), - AgentDeploymentModel.is_default_model == true(), - ) - .first() - ) - - @validate_transaction def get_agents( db: Session, @@ -176,93 +122,9 @@ def get_agents( return query.all() -@validate_transaction -def get_agent_model_deployment_association( - db: Session, agent: Agent, model_id: str, deployment_id: str -) -> AgentDeploymentModel: - """ - Get an agent model deployment association. - - Args: - db (Session): Database session. - agent (Agent): Agent to get the association. - model_id (str): Model ID. - deployment_id (str): Deployment ID. - - Returns: - AgentDeploymentModel: Agent model deployment association. - """ - return ( - db.query(AgentDeploymentModel) - .filter( - AgentDeploymentModel.agent_id == agent.id, - AgentDeploymentModel.model_id == model_id, - AgentDeploymentModel.deployment_id == deployment_id, - ) - .first() - ) - - -@validate_transaction -def delete_agent_model_deployment_association( - db: Session, agent: Agent, model_id: str, deployment_id: str -): - """ - Delete an agent model deployment association. - - Args: - db (Session): Database session. - agent (Agent): Agent to delete the association. - model_id (str): Model ID. - deployment_id (str): Deployment ID. - """ - db.query(AgentDeploymentModel).filter( - AgentDeploymentModel.agent_id == agent.id, - AgentDeploymentModel.model_id == model_id, - AgentDeploymentModel.deployment_id == deployment_id, - ).delete() - db.commit() - - -@validate_transaction -def assign_model_deployment_to_agent( - db: Session, - agent: Agent, - model_id: str, - deployment_id: str, - deployment_config: dict[str, str] = {}, - set_default: bool = False, -) -> Agent: - """ - Assign a model and deployment to an agent. - - Args: - agent (Agent): Agent to assign the model and deployment. - model_id (str): Model ID. - deployment_id (str): Deployment ID. - deployment_config (dict[str, str]): Deployment configuration. - set_default (bool): Set the model and deployment as default. - - Returns: - Agent: Agent with the assigned model and deployment. - """ - agent_deployment = AgentDeploymentModel( - agent_id=agent.id, - model_id=model_id, - deployment_id=deployment_id, - is_default_deployment=set_default, - is_default_model=set_default, - deployment_config=deployment_config, - ) - db.add(agent_deployment) - db.commit() - db.refresh(agent) - return agent - - @validate_transaction def update_agent( - db: Session, agent: Agent, new_agent: UpdateAgentRequest, user_id: str + db: Session, agent: Agent, new_agent: UpdateAgentDB, user_id: str ) -> Agent: """ Update an agent. @@ -278,7 +140,9 @@ def update_agent( if agent.is_private and agent.user_id != user_id: return None - for attr, value in new_agent.model_dump(exclude_none=True).items(): + new_agent_cleaned = new_agent.dict(exclude_unset=True, exclude_none=True) + + for attr, value in new_agent_cleaned.items(): setattr(agent, attr, value) db.commit() diff --git a/src/backend/crud/deployment.py b/src/backend/crud/deployment.py index 6c2090291a..a6a94c7046 100644 --- a/src/backend/crud/deployment.py +++ b/src/backend/crud/deployment.py @@ -2,7 +2,7 @@ from sqlalchemy.orm import Session -from backend.database_models import AgentDeploymentModel, Deployment +from backend.database_models import Deployment from backend.model_deployments.utils import class_name_validator from backend.schemas.deployment import Deployment as DeploymentSchema from backend.schemas.deployment import DeploymentCreate, DeploymentUpdate @@ -92,70 +92,9 @@ def get_available_deployments( """ all_deployments = db.query(Deployment).all() return [deployment for deployment in all_deployments if deployment.is_available][ - offset : offset + limit + offset: offset + limit ] - -def get_deployments_by_agent_id( - db: Session, agent_id: str, offset: int = 0, limit: int = 100 -) -> list[Deployment]: - """ - List all deployments by user id - - Args: - db (Session): Database session. - agent_id (str): User ID - offset (int): Offset to start the list. - limit (int): Limit of deployments to be listed. - - Returns: - list[Deployment]: List of deployments. - """ - return ( - db.query(Deployment) - .join( - AgentDeploymentModel, - Deployment.id == AgentDeploymentModel.deployment_id, - ) - .filter(AgentDeploymentModel.agent_id == agent_id) - .limit(limit) - .offset(offset) - .all() - ) - - -def get_available_deployments_by_agent_id( - db: Session, agent_id: str, offset: int = 0, limit: int = 100 -) -> list[Deployment]: - """ - List all deployments by user id - - Args: - db (Session): Database session. - agent_id (str): User ID - offset (int): Offset to start the list. - limit (int): Limit of deployments to be listed. - - Returns: - list[Deployment]: List of deployments. - """ - agent_deployments = ( - db.query(Deployment) - .join( - AgentDeploymentModel, - Deployment.id == AgentDeploymentModel.deployment_id, - ) - .filter(AgentDeploymentModel.agent_id == agent_id) - .limit(limit) - .offset(offset) - .all() - ) - - return [deployment for deployment in agent_deployments if deployment.is_available][ - offset : offset + limit - ] - - @validate_transaction def update_deployment( db: Session, deployment: Deployment, new_deployment: DeploymentUpdate diff --git a/src/backend/crud/model.py b/src/backend/crud/model.py index 84122891a1..a891c74ccc 100644 --- a/src/backend/crud/model.py +++ b/src/backend/crud/model.py @@ -1,6 +1,6 @@ from sqlalchemy.orm import Session -from backend.database_models import AgentDeploymentModel, Deployment +from backend.database_models import Deployment from backend.database_models.model import Model from backend.schemas.deployment import Deployment as DeploymentSchema from backend.schemas.model import ModelCreate, ModelUpdate @@ -127,36 +127,6 @@ def delete_model(db: Session, model_id: str) -> None: db.commit() -def get_models_by_agent_id( - db: Session, agent_id: str, offset: int = 0, limit: int = 100 -) -> list[Model]: - """ - List all models by user id - - Args: - db (Session): Database session. - agent_id (str): User ID - offset (int): Offset to start the list. - limit (int): Limit of models to be listed. - - Returns: - list[Model]: List of models. - """ - - return ( - db.query(Model) - .join( - AgentDeploymentModel, - agent_id == AgentDeploymentModel.agent_id, - ) - .filter(Model.deployment_id == AgentDeploymentModel.deployment_id) - .order_by(Model.name) - .limit(limit) - .offset(offset) - .all() - ) - - def create_model_by_config(db: Session, deployment: Deployment, deployment_config: DeploymentSchema, model: str) -> Model: """ Create a new model by config if present diff --git a/src/backend/database_models/agent.py b/src/backend/database_models/agent.py index 2b5a8f4fcf..a86b7b6558 100644 --- a/src/backend/database_models/agent.py +++ b/src/backend/database_models/agent.py @@ -7,31 +7,6 @@ from backend.database_models.base import Base -class AgentDeploymentModel(Base): - __tablename__ = "agent_deployment_model" - - agent_id: Mapped[str] = mapped_column(ForeignKey("agents.id", ondelete="CASCADE")) - deployment_id: Mapped[str] = mapped_column( - ForeignKey("deployments.id", ondelete="CASCADE") - ) - model_id: Mapped[str] = mapped_column(ForeignKey("models.id", ondelete="CASCADE")) - deployment_config: Mapped[Optional[dict]] = mapped_column(JSON) - is_default_deployment: Mapped[bool] = mapped_column(Boolean, default=False) - is_default_model: Mapped[bool] = mapped_column(Boolean, default=False) - - agent = relationship("Agent", back_populates="agent_deployment_associations") - deployment = relationship( - "Deployment", back_populates="agent_deployment_associations" - ) - model = relationship("Model", back_populates="agent_deployment_associations") - - __table_args__ = ( - UniqueConstraint( - "deployment_id", "agent_id", "model_id", name="deployment_agent_model_uc" - ), - ) - - class Agent(Base): __tablename__ = "agents" @@ -54,116 +29,32 @@ class Agent(Base): ) is_private: Mapped[bool] = mapped_column(Boolean, default=False) - deployments = relationship( - "Deployment", - secondary="agent_deployment_model", - back_populates="agents", - overlaps="deployments,models,agents,agent,agent_deployment_associations,deployment,model", - ) - models = relationship( - "Model", - secondary="agent_deployment_model", - back_populates="agents", - overlaps="deployments,models,agents,agent,agent_deployment_associations,model", + deployment_id: Mapped[Optional[str]] = mapped_column( + ForeignKey( + "deployments.id", name="agents_deployment_id_fkey", ondelete="CASCADE" + ) ) - agent_deployment_associations = relationship( - "AgentDeploymentModel", back_populates="agent" + + model_id: Mapped[Optional[str]] = mapped_column( + ForeignKey( + "models.id", name="agents_model_id_fkey", ondelete="CASCADE" + ) ) user = relationship("User", back_populates="agents") + assigned_deployment = relationship("Deployment", backref="agents") + assigned_model = relationship("Model", backref="agents") + # TODO Eugene - add the composite index here if needed __table_args__ = ( UniqueConstraint("name", "version", "user_id", name="_name_version_user_uc"), ) @property - def default_model_association(self): - default_association = next( - ( - agent_deployment - for agent_deployment in self.agent_deployment_associations - if agent_deployment.is_default_deployment - and agent_deployment.is_default_model - ), - None, - ) - if not default_association: - default_association = ( - self.agent_deployment_associations[0] - if self.agent_deployment_associations - else None - ) - return default_association - - @property - def deployment(self): - default_model_association = next( - ( - agent_deployment - for agent_deployment in self.agent_deployment_associations - if agent_deployment.is_default_deployment - and agent_deployment.is_default_model - ), - None, - ) - if not default_model_association: - default_model_association = ( - self.agent_deployment_associations[0] - if self.agent_deployment_associations - else None - ) - # TODO Eugene - return the deployment object here when FE is ready Discuss with Scott - return ( - default_model_association.deployment.name - if default_model_association - else None - ) + def model(self) -> Optional[str]: + return self.assigned_model.name if self.assigned_model else None + # Property for deployment name @property - def model(self): - default_model_association = next( - ( - agent_deployment - for agent_deployment in self.agent_deployment_associations - if agent_deployment.is_default_deployment - and agent_deployment.is_default_model - ), - None, - ) - if not default_model_association: - default_model_association = ( - self.agent_deployment_associations[0] - if self.agent_deployment_associations - else None - ) - # TODO Eugene - return the model object here when FE is ready Discuss with Scott - return ( - default_model_association.model.name if default_model_association else None - ) - - def set_default_agent_deployment_model(self, deployment_id: str, model_id: str): - default_model_deployment = next( - ( - agent_deployment - for agent_deployment in self.agent_deployment_associations - if agent_deployment.is_default_deployment - and agent_deployment.is_default_model - ), - None, - ) - if default_model_deployment: - default_model_deployment.is_default_deployment = False - default_model_deployment.is_default_model = False - - new_default_model_deployment = next( - ( - agent_deployment - for agent_deployment in self.agent_deployment_associations - if agent_deployment.deployment_id == deployment_id - and agent_deployment.model_id == model_id - ), - None, - ) - if new_default_model_deployment: - new_default_model_deployment.is_default_deployment = True - new_default_model_deployment.is_default_model = True + def deployment(self) -> Optional[str]: + return self.assigned_deployment.name if self.assigned_deployment else None diff --git a/src/backend/database_models/deployment.py b/src/backend/database_models/deployment.py index 7dc4fb3aac..579a2441d6 100644 --- a/src/backend/database_models/deployment.py +++ b/src/backend/database_models/deployment.py @@ -20,16 +20,6 @@ class Deployment(Base): models = relationship("Model", back_populates="deployment") - agents = relationship( - "Agent", - secondary="agent_deployment_model", - back_populates="deployments", - overlaps="deployments,models,agents,agent,agent_deployment_associations,deployment", - ) - agent_deployment_associations = relationship( - "AgentDeploymentModel", back_populates="deployment" - ) - __table_args__ = (UniqueConstraint("name", name="deployment_name_uc"),) def __str__(self): @@ -37,13 +27,7 @@ def __str__(self): @property def is_available(self) -> bool: - # Check if an agent has a deployment config set - for agent_assoc in self.agent_deployment_associations: - if not agent_assoc.deployment_config: - continue - if all(value != "" for value in agent_assoc.deployment_config.values()): - return True - # if no agent has a deployment config set, check if the deployment has a default config + # check if the deployment has a default config if not self.default_deployment_config: return False return all(value != "" for value in self.default_deployment_config.values()) diff --git a/src/backend/database_models/model.py b/src/backend/database_models/model.py index 202b4b5548..ff630e06c2 100644 --- a/src/backend/database_models/model.py +++ b/src/backend/database_models/model.py @@ -18,15 +18,7 @@ class Model(Base): ) deployment = relationship("Deployment", back_populates="models") - agent_deployment_associations = relationship( - "AgentDeploymentModel", back_populates="model" - ) - agents = relationship( - "Agent", - secondary="agent_deployment_model", - back_populates="models", - overlaps="deployments,models,agents,agent,agent_deployment_associations,model", - ) + def __str__(self): return self.name diff --git a/src/backend/routers/agent.py b/src/backend/routers/agent.py index bd16455b22..0be7784d8f 100644 --- a/src/backend/routers/agent.py +++ b/src/backend/routers/agent.py @@ -14,7 +14,10 @@ AgentToolMetadata as AgentToolMetadataModel, ) from backend.database_models.database import DBSessionDep -from backend.routers.utils import get_deployment_model_from_agent +from backend.routers.utils import ( + get_default_deployment_model, + get_deployment_model_from_agent, +) from backend.schemas.agent import ( Agent, AgentPublic, @@ -25,6 +28,7 @@ CreateAgentToolMetadataRequest, DeleteAgent, DeleteAgentToolMetadata, + UpdateAgentDB, UpdateAgentRequest, UpdateAgentToolMetadataRequest, ) @@ -52,6 +56,7 @@ ) router.name = RouterName.AGENT + @router.post( "", response_model=AgentPublic, @@ -61,9 +66,9 @@ ], ) async def create_agent( - session: DBSessionDep, - agent: CreateAgentRequest, - ctx: Context = Depends(get_context), + session: DBSessionDep, + agent: CreateAgentRequest, + ctx: Context = Depends(get_context), ) -> AgentPublic: """ Create an agent. @@ -81,45 +86,35 @@ async def create_agent( user_id = ctx.get_user_id() logger = ctx.get_logger() - agent_data = AgentModel( - name=agent.name, - description=agent.description, - preamble=agent.preamble, - temperature=agent.temperature, - user_id=user_id, - organization_id=agent.organization_id, - tools=agent.tools, - is_private=agent.is_private, - ) deployment_db, model_db = get_deployment_model_from_agent(agent, session) + default_deployment_db, default_model_db = get_default_deployment_model(session) try: - created_agent = agent_crud.create_agent(session, agent_data) - - if agent.tools_metadata: - for tool_metadata in agent.tools_metadata: - await update_or_create_tool_metadata( - created_agent, tool_metadata, session, ctx - ) - if deployment_db and model_db: - deployment_config = ( - agent.deployment_config - if agent.deployment_config - else deployment_db.default_deployment_config - ) - agent_crud.assign_model_deployment_to_agent( - session, - agent=created_agent, - deployment_id=deployment_db.id, - model_id=model_db.id, - deployment_config=deployment_config, - set_default=True, + agent_data = AgentModel( + name=agent.name, + description=agent.description, + preamble=agent.preamble, + temperature=agent.temperature, + user_id=user_id, + organization_id=agent.organization_id, + tools=agent.tools, + is_private=agent.is_private, + deployment_id=deployment_db.id if deployment_db else default_deployment_db.id if default_deployment_db else None, + model_id=model_db.id if model_db else default_model_db.id if default_model_db else None, ) - agent_schema = Agent.model_validate(created_agent) - ctx.with_agent(agent_schema) + created_agent = agent_crud.create_agent(session, agent_data) + + if agent.tools_metadata: + for tool_metadata in agent.tools_metadata: + await update_or_create_tool_metadata( + created_agent, tool_metadata, session, ctx + ) + + agent_schema = Agent.model_validate(created_agent) + ctx.with_agent(agent_schema) + return created_agent - return created_agent except Exception as e: logger.exception(event=e) raise HTTPException(status_code=500, detail=str(e)) @@ -127,13 +122,13 @@ async def create_agent( @router.get("", response_model=list[AgentPublic]) async def list_agents( - *, - offset: int = 0, - limit: int = 100, - session: DBSessionDep, - visibility: AgentVisibility = AgentVisibility.ALL, - organization_id: Optional[str] = None, - ctx: Context = Depends(get_context), + *, + offset: int = 0, + limit: int = 100, + session: DBSessionDep, + visibility: AgentVisibility = AgentVisibility.ALL, + organization_id: Optional[str] = None, + ctx: Context = Depends(get_context), ) -> list[AgentPublic]: """ List all agents. @@ -171,7 +166,7 @@ async def list_agents( @router.get("/{agent_id}", response_model=AgentPublic) async def get_agent_by_id( - agent_id: str, session: DBSessionDep, ctx: Context = Depends(get_context) + agent_id: str, session: DBSessionDep, ctx: Context = Depends(get_context) ) -> Agent: """ Args: @@ -205,9 +200,9 @@ async def get_agent_by_id( @router.get("/{agent_id}/deployments", response_model=list[DeploymentSchema]) -async def get_agent_deployments( - agent_id: str, session: DBSessionDep, ctx: Context = Depends(get_context) -) -> list[DeploymentSchema]: +async def get_agent_deployment( + agent_id: str, session: DBSessionDep, ctx: Context = Depends(get_context) +) -> DeploymentSchema: """ Args: agent_id (str): Agent ID. @@ -226,10 +221,7 @@ async def get_agent_deployments( agent_schema = Agent.model_validate(agent) ctx.with_agent(agent_schema) - return [ - DeploymentSchema.custom_transform(deployment) - for deployment in agent.deployments - ] + return DeploymentSchema.custom_transform(agent.deployment) @router.put( @@ -241,10 +233,10 @@ async def get_agent_deployments( ], ) async def update_agent( - agent_id: str, - new_agent: UpdateAgentRequest, - session: DBSessionDep, - ctx: Context = Depends(get_context), + agent_id: str, + new_agent: UpdateAgentRequest, + session: DBSessionDep, + ctx: Context = Depends(get_context), ) -> AgentPublic: """ Update an agent by ID. @@ -277,54 +269,12 @@ async def update_agent( try: db_deployment, db_model = get_deployment_model_from_agent(new_agent, session) - deployment_config = new_agent.deployment_config - is_default_deployment = new_agent.is_default_deployment - # Remove association fields - handled manually - new_agent_cleaned = new_agent.dict( - exclude={ - "model", - "deployment", - "deployment_config", - "is_default_deployment", - "is_default_model", - } - ) + new_agent_db = UpdateAgentDB(**new_agent.dict()) if db_deployment and db_model: - current_association = agent_crud.get_agent_model_deployment_association( - session, agent, db_model.id, db_deployment.id - ) - if current_association: - current_config = current_association.deployment_config - agent_crud.delete_agent_model_deployment_association( - session, agent, db_model.id, db_deployment.id - ) - if not deployment_config: - deployment_config = ( - current_config - if current_config - else current_association.deployment.default_deployment_config - ) - agent = agent_crud.assign_model_deployment_to_agent( - session, - agent, - db_model.id, - db_deployment.id, - deployment_config, - is_default_deployment, - ) - else: - deployment_config = db_deployment.default_deployment_config - agent = agent_crud.assign_model_deployment_to_agent( - session, - agent, - db_model.id, - db_deployment.id, - deployment_config, - is_default_deployment, - ) - + new_agent_db.model_id = db_model.id + new_agent_db.deployment_id = db_deployment.id agent = agent_crud.update_agent( - session, agent, UpdateAgentRequest(**new_agent_cleaned), user_id + session, agent, new_agent_db, user_id ) agent_schema = Agent.model_validate(agent) ctx.with_agent(agent_schema) @@ -340,9 +290,9 @@ async def update_agent( @router.delete("/{agent_id}", response_model=DeleteAgent) async def delete_agent( - agent_id: str, - session: DBSessionDep, - ctx: Context = Depends(get_context), + agent_id: str, + session: DBSessionDep, + ctx: Context = Depends(get_context), ) -> DeleteAgent: """ Delete an agent by ID. @@ -374,10 +324,10 @@ async def delete_agent( async def handle_tool_metadata_update( - agent: Agent, - new_agent: Agent, - session: DBSessionDep, - ctx: Context = Depends(get_context), + agent: Agent, + new_agent: Agent, + session: DBSessionDep, + ctx: Context = Depends(get_context), ) -> Agent: """Update or create tool metadata for an agent. @@ -415,10 +365,10 @@ async def handle_tool_metadata_update( async def update_or_create_tool_metadata( - agent: Agent, - new_tool_metadata: AgentToolMetadata, - session: DBSessionDep, - ctx: Context = Depends(get_context), + agent: Agent, + new_tool_metadata: AgentToolMetadata, + session: DBSessionDep, + ctx: Context = Depends(get_context), ) -> None: """Update or create tool metadata for an agent. @@ -444,7 +394,7 @@ async def update_or_create_tool_metadata( @router.get("/{agent_id}/tool-metadata", response_model=list[AgentToolMetadataPublic]) async def list_agent_tool_metadata( - agent_id: str, session: DBSessionDep, ctx: Context = Depends(get_context) + agent_id: str, session: DBSessionDep, ctx: Context = Depends(get_context) ) -> list[AgentToolMetadataPublic]: """ List all agent tool metadata by agent ID. @@ -476,10 +426,10 @@ async def list_agent_tool_metadata( response_model=AgentToolMetadataPublic, ) def create_agent_tool_metadata( - session: DBSessionDep, - agent_id: str, - agent_tool_metadata: CreateAgentToolMetadataRequest, - ctx: Context = Depends(get_context), + session: DBSessionDep, + agent_id: str, + agent_tool_metadata: CreateAgentToolMetadataRequest, + ctx: Context = Depends(get_context), ) -> AgentToolMetadataPublic: """ Create an agent tool metadata. @@ -525,11 +475,11 @@ def create_agent_tool_metadata( @router.put("/{agent_id}/tool-metadata/{agent_tool_metadata_id}") async def update_agent_tool_metadata( - agent_id: str, - agent_tool_metadata_id: str, - session: DBSessionDep, - new_agent_tool_metadata: UpdateAgentToolMetadataRequest, - ctx: Context = Depends(get_context), + agent_id: str, + agent_tool_metadata_id: str, + session: DBSessionDep, + new_agent_tool_metadata: UpdateAgentToolMetadataRequest, + ctx: Context = Depends(get_context), ) -> AgentToolMetadata: """ Update an agent tool metadata by ID. @@ -569,10 +519,10 @@ async def update_agent_tool_metadata( @router.delete("/{agent_id}/tool-metadata/{agent_tool_metadata_id}") async def delete_agent_tool_metadata( - agent_id: str, - agent_tool_metadata_id: str, - session: DBSessionDep, - ctx: Context = Depends(get_context), + agent_id: str, + agent_tool_metadata_id: str, + session: DBSessionDep, + ctx: Context = Depends(get_context), ) -> DeleteAgentToolMetadata: """ Delete an agent tool metadata by ID. @@ -611,9 +561,9 @@ async def delete_agent_tool_metadata( @router.post("/batch_upload_file", response_model=list[UploadAgentFileResponse]) async def batch_upload_file( - session: DBSessionDep, - files: list[FastAPIUploadFile] = RequestFile(...), - ctx: Context = Depends(get_context), + session: DBSessionDep, + files: list[FastAPIUploadFile] = RequestFile(...), + ctx: Context = Depends(get_context), ) -> UploadAgentFileResponse: user_id = ctx.get_user_id() @@ -635,10 +585,10 @@ async def batch_upload_file( @router.delete("/{agent_id}/files/{file_id}") async def delete_agent_file( - agent_id: str, - file_id: str, - session: DBSessionDep, - ctx: Context = Depends(get_context), + agent_id: str, + file_id: str, + session: DBSessionDep, + ctx: Context = Depends(get_context), ) -> DeleteAgentFileResponse: """ Delete an agent file by ID. diff --git a/src/backend/routers/utils.py b/src/backend/routers/utils.py index fe082e97f7..dada42e225 100644 --- a/src/backend/routers/utils.py +++ b/src/backend/routers/utils.py @@ -1,3 +1,4 @@ +from backend.config.deployments import ModelDeploymentName from backend.database_models.database import DBSessionDep from backend.schemas.agent import Agent @@ -6,15 +7,34 @@ def get_deployment_model_from_agent(agent: Agent, session: DBSessionDep): from backend.crud import deployment as deployment_crud model_db = None - deployment_db = deployment_crud.get_deployment_by_name(session, agent.deployment) - if not deployment_db: - deployment_db = deployment_crud.get_deployment(session, agent.deployment) + deployment_db = None + if agent.deployment: + deployment_db = deployment_crud.get_deployment_by_name(session, agent.deployment) + if not deployment_db: + deployment_db = deployment_crud.get_deployment(session, agent.deployment) + if deployment_db: + model_db = next( + ( + model + for model in deployment_db.models + if model.name == agent.model or model.id == agent.model + ), + None, + ) + return deployment_db, model_db + + +def get_default_deployment_model(session: DBSessionDep): + from backend.crud import deployment as deployment_crud + + deployment_db = deployment_crud.get_deployment_by_name(session, ModelDeploymentName.CoherePlatform) + model_db = None if deployment_db: model_db = next( ( model for model in deployment_db.models - if model.name == agent.model or model.id == agent.model + if model.name == 'command-r-plus' ), None, ) diff --git a/src/backend/schemas/agent.py b/src/backend/schemas/agent.py index 5d8dc8441f..610910adf8 100644 --- a/src/backend/schemas/agent.py +++ b/src/backend/schemas/agent.py @@ -4,8 +4,6 @@ from pydantic import BaseModel, Field -from backend.schemas.deployment import DeploymentWithModels as DeploymentSchema - class AgentToolMetadataArtifactsType(StrEnum): DOMAIN = "domain" @@ -74,7 +72,6 @@ class Agent(AgentBase): temperature: float tools: Optional[list[str]] tools_metadata: list[AgentToolMetadataPublic] - deployments: list[DeploymentSchema] deployment: Optional[str] model: Optional[str] is_private: Optional[bool] @@ -98,7 +95,6 @@ class CreateAgentRequest(BaseModel): tools: Optional[list[str]] = None tools_metadata: Optional[list[CreateAgentToolMetadataRequest]] = None deployment_config: Optional[dict[str, str]] = None - is_default_deployment: Optional[bool] = False # model_id or model_name model: str # deployment_id or deployment_name @@ -115,26 +111,37 @@ class ListAgentsResponse(BaseModel): agents: list[Agent] -class UpdateAgentRequest(BaseModel): +class UpdateAgentNoDeploymentModel(BaseModel): name: Optional[str] = None version: Optional[int] = None description: Optional[str] = None preamble: Optional[str] = None temperature: Optional[float] = None - model: Optional[str] = None - deployment: Optional[str] = None - deployment_config: Optional[dict[str, str]] = None - is_default_deployment: Optional[bool] = False - is_default_model: Optional[bool] = False - organization_id: Optional[str] = None tools: Optional[list[str]] = None - tools_metadata: Optional[list[CreateAgentToolMetadataRequest]] = None + organization_id: Optional[str] = None is_private: Optional[bool] = None class Config: from_attributes = True use_enum_values = True +class UpdateAgentDB(UpdateAgentNoDeploymentModel): + model_id: Optional[str] = None + deployment_id: Optional[str] = None + + class Config: + from_attributes = True + use_enum_values = True + + +class UpdateAgentRequest(UpdateAgentNoDeploymentModel): + deployment: Optional[str] = None + model: Optional[str] = None + tools_metadata: Optional[list[CreateAgentToolMetadataRequest]] = None + class Config: + from_attributes = True + use_enum_values = True + class DeleteAgent(BaseModel): pass diff --git a/src/backend/tests/integration/crud/test_deployment.py b/src/backend/tests/integration/crud/test_deployment.py index 65a5f6640e..48c2c7bc74 100644 --- a/src/backend/tests/integration/crud/test_deployment.py +++ b/src/backend/tests/integration/crud/test_deployment.py @@ -85,21 +85,9 @@ def test_list_deployments_with_pagination(session): def test_get_available_deployments(session, user): session.query(Deployment).delete() deployment = get_factory("Deployment", session).create() - another_deployment = get_factory("Deployment", session).create( + _ = get_factory("Deployment", session).create( default_deployment_config={} ) - agent = get_factory("Agent", session).create(user=user) - model = get_factory("Model", session).create(deployment=deployment) - another_model = get_factory("Model", session).create(deployment=another_deployment) - _ = get_factory("AgentDeploymentModel", session).create( - agent=agent, deployment=deployment, model=model - ) - _ = get_factory("AgentDeploymentModel", session).create( - agent=agent, - deployment=another_deployment, - model=another_model, - deployment_config={}, - ) deployments = deployment_crud.get_available_deployments(session) @@ -114,47 +102,6 @@ def test_get_available_deployments_empty(session, user): assert len(deployments) == 0 -def test_get_available_deployments_by_agent_id(session, user): - session.query(Deployment).delete() - deployment = get_factory("Deployment", session).create() - another_deployment = get_factory("Deployment", session).create( - default_deployment_config={} - ) - agent = get_factory("Agent", session).create(user=user) - model = get_factory("Model", session).create(deployment_id=deployment.id) - another_model = get_factory("Model", session).create( - deployment_id=another_deployment.id - ) - _ = get_factory("AgentDeploymentModel", session).create( - agent=agent, deployment=deployment, model=model - ) - _ = get_factory("AgentDeploymentModel", session).create( - agent=agent, deployment=deployment, model=another_model, deployment_config={} - ) - - deployments = deployment_crud.get_available_deployments_by_agent_id( - session, agent.id - ) - - assert len(deployments) == 1 - assert deployments[0].id == deployment.id - - -def test_get_deployments_by_agent_id(session, user): - session.query(Deployment).delete() - deployment = get_factory("Deployment", session).create() - agent = get_factory("Agent", session).create(user=user) - model = get_factory("Model", session).create(deployment_id=deployment.id) - _ = get_factory("AgentDeploymentModel", session).create( - agent=agent, deployment=deployment, model=model - ) - - deployments = deployment_crud.get_deployments_by_agent_id(session, agent.id) - - assert len(deployments) == 1 - assert deployments[0].id == deployment.id - - def test_update_deployment(session, deployment): new_deployment_data = DeploymentUpdate( name="NewName", diff --git a/src/backend/tests/integration/crud/test_model.py b/src/backend/tests/integration/crud/test_model.py index b0cb2ac9ec..5fbd3bacf0 100644 --- a/src/backend/tests/integration/crud/test_model.py +++ b/src/backend/tests/integration/crud/test_model.py @@ -172,21 +172,3 @@ def test_delete_nonexistent_model(session): model_crud.delete_model(session, "123") # no error model = model_crud.get_model(session, "123") assert model is None - - -def test_get_models_by_agent_id(session, user, deployment): - agent = get_factory("Agent", session).create(user=user) - for i in range(10): - model = get_factory("Model", session).create( - name=f"Test Model {i}", deployment=deployment - ) - - _ = get_factory("AgentDeploymentModel", session).create( - agent=agent, deployment=deployment, model=model - ) - - models = model_crud.get_models_by_agent_id(session, agent.id) - - assert len(models) == 10 - for i, model in enumerate(models): - assert model.name == f"Test Model {i}" diff --git a/src/backend/tests/unit/crud/test_agent.py b/src/backend/tests/unit/crud/test_agent.py index 5da2fafdc8..d7348c674d 100644 --- a/src/backend/tests/unit/crud/test_agent.py +++ b/src/backend/tests/unit/crud/test_agent.py @@ -5,11 +5,13 @@ from backend.config.tools import Tool from backend.crud import agent as agent_crud from backend.database_models.agent import Agent -from backend.schemas.agent import AgentVisibility, UpdateAgentRequest +from backend.schemas.agent import AgentVisibility, UpdateAgentDB from backend.tests.unit.factories import get_factory def test_create_agent(session, user): + deployment = get_factory("Deployment", session).create() + model = get_factory("Model", session).create(deployment_id=deployment.id) agent_data = Agent( user_id=user.id, version=1, @@ -17,6 +19,8 @@ def test_create_agent(session, user): description="test", preamble="test", temperature=0.5, + model_id=model.id, + deployment_id=deployment.id, tools=[Tool.Wiki_Retriever_LangChain.value.ID, Tool.Search_File.value.ID], is_private=True, ) @@ -30,6 +34,8 @@ def test_create_agent(session, user): assert agent.temperature == 0.5 assert agent.tools == [Tool.Wiki_Retriever_LangChain.value.ID, Tool.Search_File.value.ID] assert agent.is_private + assert agent.deployment == deployment.name + assert agent.model == model.name agent = agent_crud.get_agent_by_id(session, agent.id, user.id) assert agent.user_id == user.id @@ -39,6 +45,9 @@ def test_create_agent(session, user): assert agent.preamble == "test" assert agent.temperature == 0.5 assert agent.tools == [Tool.Wiki_Retriever_LangChain.value.ID, Tool.Search_File.value.ID] + assert agent.is_private + assert agent.deployment == deployment.name + assert agent.model == model.name def test_create_agent_empty_non_required_fields(session, user): @@ -198,6 +207,8 @@ def test_list_agents_with_pagination(session, user): def test_update_agent(session, user): + deployment = get_factory("Deployment", session).create() + model = get_factory("Model", session).create(deployment_id=deployment.id) agent = get_factory("Agent", session).create( name="test_agent", description="This is a test agent", @@ -205,16 +216,22 @@ def test_update_agent(session, user): preamble="test", temperature=0.5, user=user, + deployment_id=deployment.id, + model_id=model.id, tools=[Tool.Wiki_Retriever_LangChain.value.ID, Tool.Search_File.value.ID], ) - new_agent_data = UpdateAgentRequest( + new_deployment = get_factory("Deployment", session).create() + new_model = get_factory("Model", session).create(deployment_id=new_deployment.id) + new_agent_data = UpdateAgentDB( name="new_test_agent", description="This is a new test agent", version=2, preamble="new_test", temperature=0.6, tools=[Tool.Python_Interpreter.value.ID, Tool.Calculator.value.ID], + model_id=new_model.id, + deployment_id=new_deployment.id, ) agent = agent_crud.update_agent(session, agent, new_agent_data, user.id) @@ -224,6 +241,10 @@ def test_update_agent(session, user): assert agent.preamble == new_agent_data.preamble assert agent.temperature == new_agent_data.temperature assert agent.tools == [Tool.Python_Interpreter.value.ID, Tool.Calculator.value.ID] + assert agent.model_id == new_model.id + assert agent.deployment_id == new_deployment.id + assert agent.model == new_model.name + assert agent.deployment == new_deployment.name def test_delete_agent(session, user): @@ -248,53 +269,11 @@ def test_delete_agent_by_another_user(session, user): assert status is False -def test_get_association_by_deployment_name(session, user): - agent = get_factory("Agent", session).create(user=user) - deployment = get_factory("Deployment", session).create() - model = get_factory("Model", session).create(deployment_id=deployment.id) - new_association = get_factory("AgentDeploymentModel", session).create( - agent=agent, deployment=deployment, model=model - ) - association = agent_crud.get_association_by_deployment_name( - session, agent, deployment.name - ) - assert association.agent_id == agent.id - assert association.deployment_id == deployment.id - assert association.model_id == model.id - assert new_association.deployment.name == deployment.name - - -def test_get_association_by_deployment_id(session, user): - agent = get_factory("Agent", session).create(user=user) - deployment = get_factory("Deployment", session).create() - model = get_factory("Model", session).create(deployment_id=deployment.id) - new_association = get_factory("AgentDeploymentModel", session).create( - agent=agent, - deployment=deployment, - model=model, - is_default_deployment=True, - is_default_model=True, - ) - association = agent_crud.get_association_by_deployment_id( - session, agent, deployment.id - ) - assert association.agent_id == agent.id - assert association.deployment_id == deployment.id - assert association.model_id == model.id - assert new_association.deployment.id == deployment.id - - def test_get_agents_by_user_id(session, user): - agent = get_factory("Agent", session).create(user=user) deployment = get_factory("Deployment", session).create() model = get_factory("Model", session).create(deployment_id=deployment.id) - _ = get_factory("AgentDeploymentModel", session).create( - agent=agent, - deployment=deployment, - model=model, - is_default_deployment=True, - is_default_model=True, - ) + agent = get_factory("Agent", session).create(user=user, deployment_id=deployment.id, model_id=model.id) + agents = agent_crud.get_agents(session, user_id=user.id) assert len(agents) == 1 assert agents[0].user_id == user.id @@ -305,91 +284,12 @@ def test_get_agents_by_organization_id(session): organization = get_factory("Organization", session).create() user = get_factory("User", session).create() user.organizations.append(organization) - agent = get_factory("Agent", session).create(user=user, organization=organization) deployment = get_factory("Deployment", session).create() model = get_factory("Model", session).create(deployment_id=deployment.id) - _ = get_factory("AgentDeploymentModel", session).create( - agent=agent, - deployment=deployment, - model=model, - is_default_deployment=True, - is_default_model=True, - ) + agent = get_factory("Agent", session).create(user=user, organization=organization, deployment_id=deployment.id, + model_id=model.id) + agents = agent_crud.get_agents(session, user.id, organization_id=organization.id) assert len(agents) == 1 assert agents[0].user_id == user.id assert agents[0].id == agent.id - - -def test_get_agent_model_deployment_association(session): - organization = get_factory("Organization", session).create() - user = get_factory("User", session).create() - user.organizations.append(organization) - agent = get_factory("Agent", session).create(user=user, organization=organization) - deployment = get_factory("Deployment", session).create() - model = get_factory("Model", session).create(deployment_id=deployment.id) - new_association = get_factory("AgentDeploymentModel", session).create( - agent=agent, - deployment=deployment, - model=model, - is_default_deployment=True, - is_default_model=True, - ) - association = agent_crud.get_agent_model_deployment_association( - session, agent, model.id, deployment.id - ) - - assert association.id == new_association.id - - -def test_delete_agent_model_deployment_association(session): - organization = get_factory("Organization", session).create() - user = get_factory("User", session).create() - user.organizations.append(organization) - agent = get_factory("Agent", session).create(user=user, organization=organization) - deployment = get_factory("Deployment", session).create() - model = get_factory("Model", session).create(deployment_id=deployment.id) - _ = get_factory("AgentDeploymentModel", session).create( - agent=agent, - deployment=deployment, - model=model, - is_default_deployment=True, - is_default_model=True, - ) - agent_crud.delete_agent_model_deployment_association( - session, agent, model.id, deployment.id - ) - - association = agent_crud.get_agent_model_deployment_association( - session, agent, model.id, deployment.id - ) - - assert association is None - - -def test_delete_non_existing_agent_model_deployment_association(session): - organization = get_factory("Organization", session).create() - user = get_factory("User", session).create() - user.organizations.append(organization) - agent = get_factory("Agent", session).create(user=user, organization=organization) - - agent_crud.delete_agent_model_deployment_association(session, agent, "123", "123") - - -def test_assign_model_deployment_to_agent(session): - organization = get_factory("Organization", session).create() - user = get_factory("User", session).create() - user.organizations.append(organization) - agent = get_factory("Agent", session).create(user=user, organization=organization) - deployment = get_factory("Deployment", session).create() - model = get_factory("Model", session).create(deployment_id=deployment.id) - - agent_crud.assign_model_deployment_to_agent(session, agent, model.id, deployment.id) - - association = agent_crud.get_agent_model_deployment_association( - session, agent, model.id, deployment.id - ) - - assert association.agent_id == agent.id - assert association.deployment_id == deployment.id - assert association.model_id == model.id diff --git a/src/backend/tests/unit/factories/__init__.py b/src/backend/tests/unit/factories/__init__.py index c72fb83953..4b239c5992 100644 --- a/src/backend/tests/unit/factories/__init__.py +++ b/src/backend/tests/unit/factories/__init__.py @@ -1,7 +1,4 @@ from backend.tests.unit.factories.agent import AgentFactory -from backend.tests.unit.factories.agent_deployment_model import ( - AgentDeploymentModelFactory, -) from backend.tests.unit.factories.agent_tool_metadata import AgentToolMetadataFactory from backend.tests.unit.factories.blacklist import BlacklistFactory from backend.tests.unit.factories.citation import CitationFactory @@ -48,7 +45,6 @@ "AgentToolMetadata": AgentToolMetadataFactory, "Model": ModelFactory, "Deployment": DeploymentFactory, - "AgentDeploymentModel": AgentDeploymentModelFactory, "ConversationFileAssociation": ConversationFileAssociationFactory, "MessageFileAssociation": MessageFileAssociationFactory, "Group": GroupFactory, diff --git a/src/backend/tests/unit/factories/agent.py b/src/backend/tests/unit/factories/agent.py index 0b04348157..14e74e67e2 100644 --- a/src/backend/tests/unit/factories/agent.py +++ b/src/backend/tests/unit/factories/agent.py @@ -13,6 +13,8 @@ class Meta: user = factory.SubFactory(UserFactory) user_id = factory.SelfAttribute("user.id") organization_id = None + deployment_id = None + model_id = None name = factory.Faker("sentence") description = factory.Faker("sentence") preamble = factory.Faker("sentence") diff --git a/src/backend/tests/unit/factories/agent_deployment_model.py b/src/backend/tests/unit/factories/agent_deployment_model.py deleted file mode 100644 index 1cb250d04e..0000000000 --- a/src/backend/tests/unit/factories/agent_deployment_model.py +++ /dev/null @@ -1,24 +0,0 @@ -import factory - -from backend.database_models.agent import AgentDeploymentModel -from backend.tests.unit.factories.agent import AgentFactory -from backend.tests.unit.factories.base import BaseFactory -from backend.tests.unit.factories.deployment import DeploymentFactory -from backend.tests.unit.factories.model import ModelFactory - - -class AgentDeploymentModelFactory(BaseFactory): - class Meta: - model = AgentDeploymentModel - - agent = factory.SubFactory(AgentFactory) - deployment = factory.SubFactory(DeploymentFactory) - model = factory.SubFactory(ModelFactory) - agent_id = factory.SelfAttribute("agent.id") - deployment_id = factory.SelfAttribute("deployment.id") - model_id = factory.SelfAttribute("model.id") - deployment_config: factory.Faker( - "pydict", nb_elements=3, variable_nb_elements=True, value_types=["str"] - ) - is_default_deployment: False - is_default_model: False diff --git a/src/backend/tests/unit/routers/test_agent.py b/src/backend/tests/unit/routers/test_agent.py index 725c2a752e..e7b7d5df75 100644 --- a/src/backend/tests/unit/routers/test_agent.py +++ b/src/backend/tests/unit/routers/test_agent.py @@ -6,7 +6,6 @@ from backend.config.deployments import ModelDeploymentName from backend.config.tools import Tool -from backend.crud import agent as agent_crud from backend.crud import deployment as deployment_crud from backend.database_models.agent import Agent from backend.database_models.agent_tool_metadata import AgentToolMetadata @@ -310,43 +309,6 @@ def list_public_and_private_agents( assert len(response_agents) == 5 -def test_list_agent_deployments( - session_client: TestClient, session: Session, user -) -> None: - agent = get_factory("Agent", session).create(user=user) - for i in range(3): - deployment = get_factory("Deployment", session).create( - name=f"test deployment {i}" - ) - model = get_factory("Model", session).create( - deployment=deployment, name=f"test r+ ({i})", cohere_name="command-r-plus" - ) - agent_crud.assign_model_deployment_to_agent( - session, - agent, - model.id, - deployment.id, - deployment_config=deployment.default_deployment_config, - ) - model1 = get_factory("Model", session).create( - deployment=deployment, name=f"test r ({i})", cohere_name="command-r" - ) - agent_crud.assign_model_deployment_to_agent( - session, - agent, - model1.id, - deployment.id, - deployment_config=deployment.default_deployment_config, - ) - - response = session_client.get( - f"/v1/agents/{agent.id}/deployments", headers={"User-Id": user.id} - ) - assert response.status_code == 200 - response_deployments = response.json() - assert len(response_deployments) == 3 - - def test_list_agents_with_pagination( session_client: TestClient, session: Session, user ) -> None: @@ -367,6 +329,7 @@ def test_list_agents_with_pagination( response_agents = response.json() assert len(response_agents) == 1 + def test_get_agent(session_client: TestClient, session: Session, user) -> None: agent = get_factory("Agent", session).create(name="test agent", user_id=user.id) agent_tool_metadata = get_factory("AgentToolMetadata", session).create( diff --git a/src/backend/tests/unit/routers/test_chat.py b/src/backend/tests/unit/routers/test_chat.py index 559865f040..f26599aa66 100644 --- a/src/backend/tests/unit/routers/test_chat.py +++ b/src/backend/tests/unit/routers/test_chat.py @@ -9,7 +9,6 @@ from backend.chat.enums import StreamEvent from backend.config.deployments import ModelDeploymentName -from backend.database_models import Agent from backend.database_models.conversation import Conversation from backend.database_models.message import Message, MessageAgent from backend.database_models.user import User @@ -27,43 +26,6 @@ def user(session_chat: Session) -> User: return get_factory("User", session_chat).create() -@pytest.fixture() -def default_agent_copy(session_chat: Session, user: User) -> Agent: - agent = session_chat.query(Agent).get("default") - # to avoid agent related entities sessions conflicts(conversations created, ...) - # during ROLLBACK we need to create a copy of the default db agent - # and test the streaming chat with the new agent stored in the DB - agent_defaults = ( - agent.default_model_association if agent.default_model_association else None - ) - new_deployment = get_factory("Deployment", session_chat).create( - default_deployment_config=( - agent_defaults.deployment.default_deployment_config - if agent_defaults - else None - ) - ) - new_model = get_factory("Model", session_chat).create( - deployment=new_deployment, - cohere_name=agent_defaults.model.cohere_name if agent_defaults else None, - ) - new_agent = get_factory("Agent", session_chat).create(user=user, tools=[]) - get_factory("AgentDeploymentModel", session_chat).create( - agent=new_agent, - deployment=new_deployment, - model=new_model, - is_default_deployment=True, - is_default_model=True, - deployment_config=( - agent_defaults.deployment.default_deployment_config - if agent_defaults - else None - ), - ) - - return new_agent - - # STREAMING CHAT TESTS @pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_streaming_new_chat( @@ -88,16 +50,11 @@ def test_streaming_new_chat( def test_streaming_new_chat_with_agent( session_client_chat: TestClient, session_chat: Session, user: User ): - agent = get_factory("Agent", session_chat).create(user=user, tools=[]) deployment = get_factory("Deployment", session_chat).create() model = get_factory("Model", session_chat).create(deployment=deployment) - get_factory("AgentDeploymentModel", session_chat).create( - agent=agent, - deployment=deployment, - model=model, - is_default_deployment=True, - is_default_model=True, - ) + agent = get_factory("Agent", session_chat).create(user=user, tools=[], deployment_id=deployment.id, + model_id=model.id) + response = session_client_chat.post( "/v1/chat-stream", headers={ @@ -117,16 +74,11 @@ def test_streaming_new_chat_with_agent( def test_streaming_new_chat_with_agent_existing_conversation( session_client_chat: TestClient, session_chat: Session, user: User ): - agent = get_factory("Agent", session_chat).create(user=user, tools=[]) deployment = get_factory("Deployment", session_chat).create() model = get_factory("Model", session_chat).create(deployment=deployment) - get_factory("AgentDeploymentModel", session_chat).create( - agent=agent, - deployment=deployment, - model=model, - is_default_deployment=True, - is_default_model=True, - ) + agent = get_factory("Agent", session_chat).create(user=user, tools=[], deployment_id=deployment.id, + model_id=model.id) + agent.preamble = "you are a smart assistant" session_chat.refresh(agent) @@ -218,16 +170,11 @@ def test_streaming_chat_with_existing_conversation_from_other_agent( def test_streaming_chat_with_tools_not_in_agent_tools( session_client_chat: TestClient, session_chat: Session, user: User ): - agent = get_factory("Agent", session_chat).create(user=user, tools=["wikipedia"]) deployment = get_factory("Deployment", session_chat).create() model = get_factory("Model", session_chat).create(deployment=deployment) - get_factory("AgentDeploymentModel", session_chat).create( - agent=agent, - deployment=deployment, - model=model, - is_default_deployment=True, - is_default_model=True, - ) + agent = get_factory("Agent", session_chat).create(user=user, tools=["wikipedia"], deployment_id=deployment.id, + model_id=model.id) + response = session_client_chat.post( "/v1/chat-stream", headers={ @@ -249,16 +196,11 @@ def test_streaming_chat_with_tools_not_in_agent_tools( def test_streaming_chat_with_agent_tools_and_empty_request_tools( session_client_chat: TestClient, session_chat: Session, user: User ): - agent = get_factory("Agent", session_chat).create(user=user, tools=["tavily_web_search"]) deployment = get_factory("Deployment", session_chat).create() model = get_factory("Model", session_chat).create(deployment=deployment) - get_factory("AgentDeploymentModel", session_chat).create( - agent=agent, - deployment=deployment, - model=model, - is_default_deployment=True, - is_default_model=True, - ) + agent = get_factory("Agent", session_chat).create(user=user, tools=["tavily_web_search"], + deployment_id=deployment.id, model_id=model.id) + response = session_client_chat.post( "/v1/chat-stream", headers={ diff --git a/src/interfaces/assistants_web/src/cohere-client/generated/schemas.gen.ts b/src/interfaces/assistants_web/src/cohere-client/generated/schemas.gen.ts index 2043c8e93d..dbe174354b 100644 --- a/src/interfaces/assistants_web/src/cohere-client/generated/schemas.gen.ts +++ b/src/interfaces/assistants_web/src/cohere-client/generated/schemas.gen.ts @@ -82,13 +82,6 @@ export const $AgentPublic = { ], title: 'Tools Metadata', }, - deployments: { - items: { - $ref: '#/components/schemas/DeploymentWithModels', - }, - type: 'array', - title: 'Deployments', - }, deployment: { anyOf: [ { @@ -135,7 +128,6 @@ export const $AgentPublic = { 'preamble', 'temperature', 'tools', - 'deployments', 'deployment', 'model', 'is_private', @@ -1038,18 +1030,6 @@ export const $CreateAgentRequest = { ], title: 'Deployment Config', }, - is_default_deployment: { - anyOf: [ - { - type: 'boolean', - }, - { - type: 'null', - }, - ], - title: 'Is Default Deployment', - default: false, - }, model: { type: 'string', title: 'Model', @@ -1449,78 +1429,6 @@ export const $DeploymentUpdate = { title: 'DeploymentUpdate', } as const; -export const $DeploymentWithModels = { - properties: { - id: { - anyOf: [ - { - type: 'string', - }, - { - type: 'null', - }, - ], - title: 'Id', - }, - name: { - type: 'string', - title: 'Name', - }, - description: { - anyOf: [ - { - type: 'string', - }, - { - type: 'null', - }, - ], - title: 'Description', - }, - is_available: { - type: 'boolean', - title: 'Is Available', - default: false, - }, - is_community: { - anyOf: [ - { - type: 'boolean', - }, - { - type: 'null', - }, - ], - title: 'Is Community', - default: false, - }, - env_vars: { - anyOf: [ - { - items: { - type: 'string', - }, - type: 'array', - }, - { - type: 'null', - }, - ], - title: 'Env Vars', - }, - models: { - items: { - $ref: '#/components/schemas/ModelSimple', - }, - type: 'array', - title: 'Models', - }, - }, - type: 'object', - required: ['name', 'env_vars', 'models'], - title: 'DeploymentWithModels', -} as const; - export const $Document = { properties: { text: { @@ -2120,44 +2028,6 @@ export const $ModelCreate = { title: 'ModelCreate', } as const; -export const $ModelSimple = { - properties: { - id: { - type: 'string', - title: 'Id', - }, - name: { - type: 'string', - title: 'Name', - }, - cohere_name: { - anyOf: [ - { - type: 'string', - }, - { - type: 'null', - }, - ], - title: 'Cohere Name', - }, - description: { - anyOf: [ - { - type: 'string', - }, - { - type: 'null', - }, - ], - title: 'Description', - }, - }, - type: 'object', - required: ['id', 'name', 'cohere_name', 'description'], - title: 'ModelSimple', -} as const; - export const $ModelUpdate = { properties: { name: { @@ -3292,55 +3162,32 @@ export const $UpdateAgentRequest = { ], title: 'Temperature', }, - model: { - anyOf: [ - { - type: 'string', - }, - { - type: 'null', - }, - ], - title: 'Model', - }, - deployment: { - anyOf: [ - { - type: 'string', - }, - { - type: 'null', - }, - ], - title: 'Deployment', - }, - deployment_config: { + tools: { anyOf: [ { - additionalProperties: { + items: { type: 'string', }, - type: 'object', + type: 'array', }, { type: 'null', }, ], - title: 'Deployment Config', + title: 'Tools', }, - is_default_deployment: { + organization_id: { anyOf: [ { - type: 'boolean', + type: 'string', }, { type: 'null', }, ], - title: 'Is Default Deployment', - default: false, + title: 'Organization Id', }, - is_default_model: { + is_private: { anyOf: [ { type: 'boolean', @@ -3349,10 +3196,9 @@ export const $UpdateAgentRequest = { type: 'null', }, ], - title: 'Is Default Model', - default: false, + title: 'Is Private', }, - organization_id: { + deployment: { anyOf: [ { type: 'string', @@ -3361,21 +3207,18 @@ export const $UpdateAgentRequest = { type: 'null', }, ], - title: 'Organization Id', + title: 'Deployment', }, - tools: { + model: { anyOf: [ { - items: { - type: 'string', - }, - type: 'array', + type: 'string', }, { type: 'null', }, ], - title: 'Tools', + title: 'Model', }, tools_metadata: { anyOf: [ @@ -3391,17 +3234,6 @@ export const $UpdateAgentRequest = { ], title: 'Tools Metadata', }, - is_private: { - anyOf: [ - { - type: 'boolean', - }, - { - type: 'null', - }, - ], - title: 'Is Private', - }, }, type: 'object', title: 'UpdateAgentRequest', diff --git a/src/interfaces/assistants_web/src/cohere-client/generated/services.gen.ts b/src/interfaces/assistants_web/src/cohere-client/generated/services.gen.ts index 2bc613de69..2dec2e21ed 100644 --- a/src/interfaces/assistants_web/src/cohere-client/generated/services.gen.ts +++ b/src/interfaces/assistants_web/src/cohere-client/generated/services.gen.ts @@ -61,8 +61,8 @@ import type { GenerateTitleV1ConversationsConversationIdGenerateTitlePostResponse, GetAgentByIdV1AgentsAgentIdGetData, GetAgentByIdV1AgentsAgentIdGetResponse, - GetAgentDeploymentsV1AgentsAgentIdDeploymentsGetData, - GetAgentDeploymentsV1AgentsAgentIdDeploymentsGetResponse, + GetAgentDeploymentV1AgentsAgentIdDeploymentsGetData, + GetAgentDeploymentV1AgentsAgentIdDeploymentsGetResponse, GetConversationV1ConversationsConversationIdGetData, GetConversationV1ConversationsConversationIdGetResponse, GetDeploymentV1DeploymentsDeploymentIdGetData, @@ -1393,7 +1393,7 @@ export class DefaultService { } /** - * Get Agent Deployments + * Get Agent Deployment * Args: * agent_id (str): Agent ID. * session (DBSessionDep): Database session. @@ -1409,9 +1409,9 @@ export class DefaultService { * @returns Deployment Successful Response * @throws ApiError */ - public getAgentDeploymentsV1AgentsAgentIdDeploymentsGet( - data: GetAgentDeploymentsV1AgentsAgentIdDeploymentsGetData - ): CancelablePromise { + public getAgentDeploymentV1AgentsAgentIdDeploymentsGet( + data: GetAgentDeploymentV1AgentsAgentIdDeploymentsGetData + ): CancelablePromise { return this.httpRequest.request({ method: 'GET', url: '/v1/agents/{agent_id}/deployments', diff --git a/src/interfaces/assistants_web/src/cohere-client/generated/types.gen.ts b/src/interfaces/assistants_web/src/cohere-client/generated/types.gen.ts index c8329f7488..46e38e6459 100644 --- a/src/interfaces/assistants_web/src/cohere-client/generated/types.gen.ts +++ b/src/interfaces/assistants_web/src/cohere-client/generated/types.gen.ts @@ -12,7 +12,6 @@ export type AgentPublic = { temperature: number; tools: Array | null; tools_metadata?: Array | null; - deployments: Array; deployment: string | null; model: string | null; is_private: boolean | null; @@ -191,7 +190,6 @@ export type CreateAgentRequest = { deployment_config?: { [key: string]: string; } | null; - is_default_deployment?: boolean | null; model: string; deployment: string; organization_id?: string | null; @@ -281,16 +279,6 @@ export type DeploymentUpdate = { } | null; }; -export type DeploymentWithModels = { - id?: string | null; - name: string; - description?: string | null; - is_available?: boolean; - is_community?: boolean | null; - env_vars: Array | null; - models: Array; -}; - export type Document = { text: string; document_id: string; @@ -428,13 +416,6 @@ export type ModelCreate = { deployment_id: string; }; -export type ModelSimple = { - id: string; - name: string; - cohere_name: string | null; - description: string | null; -}; - export type ModelUpdate = { name?: string | null; cohere_name?: string | null; @@ -691,17 +672,12 @@ export type UpdateAgentRequest = { description?: string | null; preamble?: string | null; temperature?: number | null; - model?: string | null; - deployment?: string | null; - deployment_config?: { - [key: string]: string; - } | null; - is_default_deployment?: boolean | null; - is_default_model?: boolean | null; - organization_id?: string | null; tools?: Array | null; - tools_metadata?: Array | null; + organization_id?: string | null; is_private?: boolean | null; + deployment?: string | null; + model?: string | null; + tools_metadata?: Array | null; }; export type UpdateAgentToolMetadataRequest = { @@ -1039,11 +1015,11 @@ export type DeleteAgentV1AgentsAgentIdDeleteData = { export type DeleteAgentV1AgentsAgentIdDeleteResponse = DeleteAgent; -export type GetAgentDeploymentsV1AgentsAgentIdDeploymentsGetData = { +export type GetAgentDeploymentV1AgentsAgentIdDeploymentsGetData = { agentId: string; }; -export type GetAgentDeploymentsV1AgentsAgentIdDeploymentsGetResponse = Array; +export type GetAgentDeploymentV1AgentsAgentIdDeploymentsGetResponse = Array; export type ListAgentToolMetadataV1AgentsAgentIdToolMetadataGetData = { agentId: string; @@ -1786,7 +1762,7 @@ export type $OpenApiTs = { }; '/v1/agents/{agent_id}/deployments': { get: { - req: GetAgentDeploymentsV1AgentsAgentIdDeploymentsGetData; + req: GetAgentDeploymentV1AgentsAgentIdDeploymentsGetData; res: { /** * Successful Response diff --git a/src/interfaces/assistants_web/src/constants/conversation.ts b/src/interfaces/assistants_web/src/constants/conversation.ts index 4a55b10e35..057032c900 100644 --- a/src/interfaces/assistants_web/src/constants/conversation.ts +++ b/src/interfaces/assistants_web/src/constants/conversation.ts @@ -21,7 +21,6 @@ export const DEFAULT_AGENT_TOOLS = [TOOL_SEARCH_FILE_ID, TOOL_READ_DOCUMENT_ID, export const BASE_AGENT: AgentPublic = { id: '', - deployments: [], name: 'Command R+', description: 'Ask questions and get answers based on your files.', created_at: new Date().toISOString(),