Skip to content

Commit

Permalink
Create proper enums for job statuses (#412)
Browse files Browse the repository at this point in the history
  • Loading branch information
michalkrzem authored Oct 22, 2024
1 parent 951993f commit c218523
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 24 deletions.
30 changes: 15 additions & 15 deletions src/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
and_,
update,
col,
delete,
)

from .models.agentgroup import AgentGroup
from .models.configentry import ConfigEntry
from .models.job import Job
from .models.job import Job, JobStatus
from .models.jobagent import JobAgent
from .models.match import Match
from .schema import MatchesSchema, ConfigSchema
Expand Down Expand Up @@ -67,7 +68,11 @@ def cancel_job(self, job: JobId, error=None) -> None:
session.execute(
update(Job)
.where(Job.id == job)
.values(status="cancelled", finished=int(time()), error=error)
.values(
status=JobStatus.cancelled,
finished=int(time()),
error=error,
)
)
session.commit()

Expand All @@ -85,23 +90,18 @@ def get_job(self, job: JobId) -> Job:
return self.__get_job(session, job)

def get_valid_jobs(self, username_filter: Optional[str]) -> List[Job]:
"""Retrieves valid (accessible and not removed) jobs from the database."""
"""Retrieves valid (accessible) jobs from the database."""
with self.session() as session:
query = (
select(Job)
.where(Job.status != "removed")
.order_by(col(Job.submitted).desc())
)
query = select(Job).order_by(col(Job.submitted).desc())
if username_filter:
query = query.where(Job.rule_author == username_filter)
return session.exec(query).all()

def remove_query(self, job: JobId) -> None:
"""Sets the job status to removed."""
"""Delete the job, linked match and job agent from the database."""
with self.session() as session:
session.execute(
update(Job).where(Job.id == job).values(status="removed")
)
delete_query = delete(Job).where(Job.id == job)
session.execute(delete_query)
session.commit()

def add_match(self, job: JobId, match: Match) -> None:
Expand Down Expand Up @@ -149,7 +149,7 @@ def agent_finish_job(self, job: Job) -> None:
session.execute(
update(Job)
.where(Job.internal_id == job.internal_id)
.values(finished=int(time()), status="done")
.values(finished=int(time()), status=JobStatus.done)
)
session.commit()

Expand Down Expand Up @@ -220,7 +220,7 @@ def init_job_datasets(self, job: JobId, num_datasets: int) -> None:
.values(
total_datasets=num_datasets,
datasets_left=num_datasets,
status="processing",
status=JobStatus.processing,
)
)
session.commit()
Expand Down Expand Up @@ -253,7 +253,7 @@ def create_search_task(
with self.session() as session:
obj = Job(
id=job,
status="new",
status=JobStatus.new,
rule_name=rule_name,
rule_author=rule_author,
raw_yara=raw_yara,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""add jobstatus
Revision ID: 6b495d5a4855
Revises: dbb81bd4d47f
Create Date: 2024-10-15 08:17:30.036531
"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "6b495d5a4855"
down_revision = "dbb81bd4d47f"
branch_labels = None
depends_on = None

# Enum type object shared by upgrade() (which creates it and converts
# job.status to it) and downgrade() (which converts the column back).
# "removed" is not among its labels; upgrade() deletes rows with that status
# before converting the column.
job_status = sa.Enum(
    "done", "new", "cancelled", "processing", name="jobstatus"
)


def upgrade() -> None:
    """Apply this revision.

    Steps, in order:

    1. Recreate the ``jobagent`` and ``match`` foreign keys to
       ``job.internal_id`` with ``ON DELETE CASCADE``.
    2. Delete every job whose status is ``'removed'``.
    3. Create the ``jobstatus`` enum type and convert ``job.status`` from
       VARCHAR to that enum.
    """
    # Same constraint names are reused; only the ON DELETE behaviour changes.
    cascading_keys = (
        ("jobagent_job_id_fkey", "jobagent"),
        ("match_job_id_fkey", "match"),
    )
    for fk_name, child_table in cascading_keys:
        op.drop_constraint(fk_name, child_table, type_="foreignkey")
        op.create_foreign_key(
            constraint_name=fk_name,
            source_table=child_table,
            referent_table="job",
            local_cols=["job_id"],
            remote_cols=["internal_id"],
            ondelete="CASCADE",
        )

    # 'removed' is not a label of the jobstatus enum, so these rows have to go
    # before the column is cast. This runs after the cascading keys exist, so
    # rows in jobagent/match that reference the deleted jobs are removed too.
    op.execute("DELETE FROM job WHERE status = 'removed';")

    # The enum type must exist before the column can be altered to use it.
    connection = op.get_bind()
    job_status.create(connection)

    op.alter_column(
        "job",
        "status",
        existing_type=sa.VARCHAR(),
        type_=job_status,
        # Explicit cast expression for converting the existing string values.
        postgresql_using="status::jobstatus",
        nullable=True,
    )


def downgrade() -> None:
    """Revert this revision.

    Converts ``job.status`` back to a non-nullable VARCHAR, drops the
    ``jobstatus`` enum type, and recreates the ``jobagent`` and ``match``
    foreign keys to ``job.internal_id`` without ``ON DELETE CASCADE``.

    Jobs deleted by ``upgrade()`` (status ``'removed'``) are not restored.
    """
    op.alter_column(
        "job",
        "status",
        existing_type=job_status,
        type_=sa.VARCHAR(),
        nullable=False,
    )

    # The type can only be dropped once no column uses it any more.
    op.execute("DROP TYPE IF EXISTS jobstatus")

    # Recreate both keys under the same names, this time with no ondelete rule.
    plain_keys = (
        ("jobagent_job_id_fkey", "jobagent"),
        ("match_job_id_fkey", "match"),
    )
    for fk_name, child_table in plain_keys:
        op.drop_constraint(fk_name, child_table, type_="foreignkey")
        op.create_foreign_key(
            constraint_name=fk_name,
            source_table=child_table,
            referent_table="job",
            local_cols=["job_id"],
            remote_cols=["internal_id"],
        )
18 changes: 17 additions & 1 deletion src/models/job.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import enum

from sqlalchemy.dialects import postgresql

from sqlmodel import SQLModel, Field, ARRAY, String, Column, Relationship
from typing import Optional, List, Union, TYPE_CHECKING

Expand All @@ -6,11 +10,20 @@
from ..models.jobagent import JobAgent


class JobStatus(enum.Enum):
    """Closed set of status values a job can hold.

    Each member's value is the lowercase string equal to its name. The
    ``status`` field of ``JobView`` is declared with a PostgreSQL enum named
    ``jobstatus`` built from this class.
    """

    done = "done"
    new = "new"
    cancelled = "cancelled"
    processing = "processing"


class JobView(SQLModel):
"""Public fields of mquery jobs."""

__table_args__ = {"extend_existing": True}

id: str
status: str
status: JobStatus = Field(sa_column=Column(postgresql.ENUM(JobStatus, name="jobstatus"))) # type: ignore
error: Optional[str]
rule_name: str
rule_author: str
Expand All @@ -29,6 +42,9 @@ class JobView(SQLModel):
total_datasets: int
agents_left: int

class Config:
arbitrary_types_allowed = True


class Job(JobView, table=True):
"""Job object in the database. Internal ID is an implementation detail."""
Expand Down
5 changes: 4 additions & 1 deletion src/models/jobagent.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from sqlalchemy import Column, ForeignKey
from sqlmodel import SQLModel, Field, Relationship
from typing import Union, TYPE_CHECKING

Expand All @@ -12,7 +13,9 @@ class JobAgent(SQLModel, table=True):
id: Union[int, None] = Field(default=None, primary_key=True)
task_in_progress: int

job_id: int = Field(foreign_key="job.internal_id")
job_id: int = Field(
sa_column=Column(ForeignKey("job.internal_id", ondelete="CASCADE"))
)
job: "Job" = Relationship(back_populates="agents")

agent_id: int = Field(foreign_key="agentgroup.id")
Expand Down
5 changes: 4 additions & 1 deletion src/models/match.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from sqlalchemy import ForeignKey
from sqlmodel import SQLModel, Field, ARRAY, String, Column, JSON, Relationship
from typing import List, Union, Dict, Any

Expand All @@ -15,5 +16,7 @@ class Match(SQLModel, table=True):
# A list of yara rules matched to this file
matches: List[str] = Field(sa_column=Column(ARRAY(String)))

job_id: int = Field(foreign_key="job.internal_id")
job_id: int = Field(
sa_column=Column(ForeignKey("job.internal_id", ondelete="CASCADE"))
)
job: Job = Relationship(back_populates="matches")
3 changes: 1 addition & 2 deletions src/mqueryfront/src/utils.js
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
export const isStatusFinished = (status) =>
["done", "cancelled", "removed"].includes(status);
["done", "cancelled"].includes(status);

const statusClassMap = {
done: "success",
new: "info",
processing: "info",
cancelled: "danger",
removed: "dark",
};

export const isAuthEnabled = (config) =>
Expand Down
8 changes: 4 additions & 4 deletions src/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .util import make_sha256_tag
from .config import app_config
from .plugins import PluginManager
from .models.job import Job
from .models.job import Job, JobStatus
from .models.match import Match
from .lib.yaraparse import parse_yara, combine_rules
from .lib.ursadb import Json, UrsaDb
Expand Down Expand Up @@ -182,7 +182,7 @@ def start_search(job_id: JobId) -> None:
"""
with job_context(job_id) as agent:
job = agent.db.get_job(job_id)
if job.status == "cancelled":
if job.status == JobStatus.cancelled:
logging.info("Job was cancelled, returning...")
return

Expand Down Expand Up @@ -232,7 +232,7 @@ def query_ursadb(job_id: JobId, dataset_id: str, ursadb_query: str) -> None:
"""Queries ursadb and creates yara scans tasks with file batches."""
with job_context(job_id) as agent:
job = agent.db.get_job(job_id)
if job.status == "cancelled":
if job.status == JobStatus.cancelled:
logging.info("Job was cancelled, returning...")
return

Expand Down Expand Up @@ -271,7 +271,7 @@ def run_yara_batch(job_id: JobId, iterator: str, batch_size: int) -> None:
"""Actually scans files, and updates a database with the results."""
with job_context(job_id) as agent:
job = agent.db.get_job(job_id)
if job.status == "cancelled":
if job.status == JobStatus.cancelled:
logging.info("Job was cancelled, returning...")
return

Expand Down

0 comments on commit c218523

Please sign in to comment.