Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updater subclass #606

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions assemblyline_v4_service/common/api.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from __future__ import annotations

import os
import requests
import time
import traceback

from abc import ABC, abstractmethod

from assemblyline_core.safelist_client import SafelistClient
from io import StringIO
DEFAULT_SERVICE_SERVER = "http://localhost:5003"
Expand All @@ -27,7 +31,17 @@ def __init__(self, message, status_code, api_response=None, api_version=None):
self.status_code = status_code


class ServiceAPI:
class ServiceAPI(ABC):
@abstractmethod
def get_safelist(self, tag_list: list[str] | None = None):
pass

@abstractmethod
def lookup_safelist(self, qhash):
pass


class HostedServiceAPI(ServiceAPI):
def __init__(self, service_attributes, logger):
self.log = logger
self.service_api_host = os.environ.get("SERVICE_API_HOST", DEFAULT_SERVICE_SERVER)
Expand Down Expand Up @@ -68,7 +82,7 @@ def _with_retries(self, func, url):
retries += 1
time.sleep(min(2, 2 ** (retries - 7)))

def get_safelist(self, tag_list=None):
def get_safelist(self, tag_list: list[str] | None = None):
if DEVELOPMENT_MODE:
return {}

Expand All @@ -93,12 +107,12 @@ def lookup_safelist(self, qhash):
raise


class PrivilegedServiceAPI:
class PrivilegedServiceAPI(ServiceAPI):
def __init__(self, logger):
self.log = logger
self.safelist_client = SafelistClient()

def get_safelist(self, tag_list=None):
def get_safelist(self, tag_list: list[str] | None = None):
if DEVELOPMENT_MODE:
return {}
tag_types = None
Expand Down
92 changes: 56 additions & 36 deletions assemblyline_v4_service/common/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import hashlib
import logging
import os
import requests
import shutil
import tarfile
import tempfile
Expand All @@ -13,11 +12,13 @@
from typing import Dict, Optional
from pathlib import Path

import requests

from assemblyline.common import exceptions, log, version
from assemblyline.common.digests import get_sha256_for_file
from assemblyline.odm.messages.task import Task as ServiceTask
from assemblyline_v4_service.common import helper
from assemblyline_v4_service.common.api import PrivilegedServiceAPI, ServiceAPI
from assemblyline_v4_service.common.api import PrivilegedServiceAPI, HostedServiceAPI, ServiceAPI
from assemblyline_v4_service.common.request import ServiceRequest
from assemblyline_v4_service.common.task import Task
from assemblyline_v4_service.common.ontology_helper import OntologyHelper
Expand All @@ -42,6 +43,7 @@ def is_recoverable_runtime_error(error):


class ServiceBase:
"""Base class for Assemblyline services"""
def __init__(self, config: Optional[Dict] = None) -> None:
# Load the service attributes from the service manifest
self.service_attributes = helper.get_service_attributes()
Expand All @@ -63,24 +65,18 @@ def __init__(self, config: Optional[Dict] = None) -> None:
self.log.warning = self._warning
self.log.error = self._error

self._task = None
self._task: Task | None = None

self._working_directory = None
self._working_directory: str | None = None

# Initialize interface for interacting with system safelist
self._api_interface = None
self._api_interface: ServiceAPI | None = None

self.dependencies = self._get_dependencies_info()
self.ontology = OntologyHelper(self.log, self.service_attributes.name)

# Updater-related
self.rules_directory: str = None
self.rules_list: list = []
self.update_time: int = None
self.rules_hash: str = None

@property
def api_interface(self):
def api_interface(self) -> ServiceAPI:
return self.get_api_interface()

def _get_dependencies_info(self) -> Dict[str, Dict[str, str]]:
Expand All @@ -96,14 +92,10 @@ def _get_dependencies_info(self) -> Dict[str, Dict[str, str]]:
def _cleanup(self) -> None:
self._task = None
self._working_directory = None
if self.dependencies.get('updates', None):
try:
self._download_rules()
except Exception as e:
raise Exception(f"Something went wrong while trying to load {self.name} rules: {str(e)}")

def _handle_execute_failure(self, exception, stack_info) -> None:
# Clear the result, in case it caused the problem
assert self._task
self._task.result = None

# Clear the extracted and supplementary files
Expand All @@ -120,6 +112,7 @@ def _handle_execute_failure(self, exception, stack_info) -> None:
self._task.save_error(stack_info, recoverable=False)

def _success(self) -> None:
assert self._task
self._task.success()

def _warning(self, msg: str, *args, **kwargs) -> None:
Expand All @@ -132,12 +125,12 @@ def _error(self, msg: str, *args, **kwargs) -> None:
msg = f"({self._task.sid}/{self._task.sha256}): {msg}"
self._log_error(msg, *args, **kwargs)

def get_api_interface(self):
def get_api_interface(self) -> ServiceAPI:
if not self._api_interface:
if PRIVILEGED:
self._api_interface = PrivilegedServiceAPI(self.log)
else:
self._api_interface = ServiceAPI(self.service_attributes, self.log)
self._api_interface = HostedServiceAPI(self.service_attributes, self.log)

return self._api_interface

Expand All @@ -153,7 +146,7 @@ def get_service_version(self) -> str:

# noinspection PyMethodMayBeStatic
def get_tool_version(self) -> Optional[str]:
return self.rules_hash
return None

def handle_task(self, task: ServiceTask) -> None:
try:
Expand Down Expand Up @@ -187,22 +180,6 @@ def start(self) -> None:

def start_service(self) -> None:
self.log.info(f"Starting service: {self.service_attributes.name} ({self.service_attributes.version})")

if self.dependencies.get('updates', None):
# Start with a clean update dir
if os.path.exists(UPDATES_DIR):
for files in os.scandir(UPDATES_DIR):
path = os.path.join(UPDATES_DIR, files)
try:
shutil.rmtree(path)
except OSError:
os.remove(path)

try:
self._download_rules()
except Exception as e:
raise Exception(f"Something went wrong while trying to load {self.name} rules: {str(e)}")

self.start()

def stop(self) -> None:
Expand Down Expand Up @@ -234,6 +211,48 @@ def working_directory(self):

return self._working_directory


class RulesServiceBase(ServiceBase):
"""Base Class for services with updateable rules"""
def __init__(self, config: Optional[Dict] = None) -> None:
self.rules_directory: str | None = None
self.rules_list: list = []
self.update_time: int | None = None
self.rules_hash: str | None = None
super().__init__(config)

def get_tool_version(self) -> Optional[str]:
return self.rules_hash

def start_service(self) -> None:
self.log.info(f"Starting service: {self.service_attributes.name} ({self.service_attributes.version})")

if self.dependencies.get('updates', None):
# Start with a clean update dir
if os.path.exists(UPDATES_DIR):
for files in os.scandir(UPDATES_DIR):
path = os.path.join(UPDATES_DIR, files)
try:
shutil.rmtree(path)
except OSError:
os.remove(path)

try:
self._download_rules()
except Exception as e:
raise Exception(f"Something went wrong while trying to load {self.name} rules: {str(e)}")

self.start()

def _cleanup(self) -> None:
self._task = None
self._working_directory = None
if self.dependencies.get('updates', None):
try:
self._download_rules()
except Exception as e:
raise Exception(f"Something went wrong while trying to load {self.name} rules: {str(e)}")

# Only relevant for services using updaters (reserving 'updates' as the defacto container name)
def _download_rules(self):
scheme, verify = 'http', None
Expand Down Expand Up @@ -301,6 +320,7 @@ def _download_rules(self):

# Generate the rules_hash and init rules_list based on the raw files in the rules_directory from updater
def _gen_rules_hash(self) -> str:
assert self.rules_directory
self.rules_list = [str(f) for f in Path(self.rules_directory).rglob("*") if os.path.isfile(str(f))]
all_sha256s = [get_sha256_for_file(f) for f in self.rules_list]

Expand Down
6 changes: 3 additions & 3 deletions assemblyline_v4_service/common/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
import tempfile

from PIL import Image
from typing import Any, Dict, Optional, TextIO, Union
from typing import Any, Dict, Optional, TextIO

from assemblyline.common import forge
from assemblyline.common import log as al_log
from assemblyline.common.classification import Classification
from assemblyline_v4_service.common.api import ServiceAPI, PrivilegedServiceAPI
from assemblyline_v4_service.common.api import ServiceAPI
from assemblyline_v4_service.common.extractor.ocr import ocr_detections
from assemblyline_v4_service.common.result import Heuristic, Result, ResultKeyValueSection
from assemblyline_v4_service.common.task import Task, MaxExtractedExceeded
Expand Down Expand Up @@ -39,7 +39,7 @@ def __init__(self, task: Task) -> None:

def add_extracted(self, path: str, name: str, description: str,
classification: Optional[Classification] = None,
safelist_interface: Optional[Union[ServiceAPI, PrivilegedServiceAPI]] = None,
safelist_interface: Optional[ServiceAPI] = None,
allow_dynamic_recursion: bool = False, parent_relation: str = 'EXTRACTED') -> bool:
"""
Add an extracted file for additional processing.
Expand Down
4 changes: 2 additions & 2 deletions assemblyline_v4_service/common/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from assemblyline.common.digests import get_digests_for_file, get_sha256_for_file
from assemblyline.common.isotime import now_as_iso
from assemblyline.odm.messages.task import Task as ServiceTask
from assemblyline_v4_service.common.api import ServiceAPI, PrivilegedServiceAPI
from assemblyline_v4_service.common.api import ServiceAPI
from assemblyline_v4_service.common.result import Result
from assemblyline_v4_service.common.helper import get_service_manifest

Expand Down Expand Up @@ -106,7 +106,7 @@ def _add_file(self, path: str, name: str, description: str,

def add_extracted(self, path: str, name: str, description: str,
classification: Optional[Classification] = None,
safelist_interface: Optional[Union[ServiceAPI, PrivilegedServiceAPI]] = None,
safelist_interface: Optional[ServiceAPI] = None,
allow_dynamic_recursion: bool = False, parent_relation: str = 'EXTRACTED') -> bool:

# Service-based safelisting of files has to be configured at the global configuration
Expand Down