Skip to content

Commit

Permalink
Merge pull request #680 from onekey-sec/refactor-progress-reporting
Browse files Browse the repository at this point in the history
processing: extract progress reporting from business logic
  • Loading branch information
qkaiser authored Nov 15, 2023
2 parents 861d76b + 9d940c8 commit 6cf2950
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 34 deletions.
55 changes: 47 additions & 8 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Type
from unittest import mock

import pytest
Expand All @@ -11,6 +11,11 @@
from unblob.handlers import BUILTIN_HANDLERS
from unblob.models import DirectoryHandler, Glob, Handler, HexString, MultiFile
from unblob.processing import DEFAULT_DEPTH, DEFAULT_PROCESS_NUM, ExtractionConfig
from unblob.ui import (
NullProgressReporter,
ProgressReporter,
RichConsoleProgressReporter,
)


class TestHandler(Handler):
Expand Down Expand Up @@ -174,18 +179,50 @@ def test_dir_for_file(tmp_path: Path):


@pytest.mark.parametrize(
"params, expected_depth, expected_entropy_depth, expected_process_num, expected_verbosity",
"params, expected_depth, expected_entropy_depth, expected_process_num, expected_verbosity, expected_progress_reporter",
[
pytest.param([], DEFAULT_DEPTH, 1, DEFAULT_PROCESS_NUM, 0, id="empty"),
pytest.param(
["--verbose"], DEFAULT_DEPTH, 1, DEFAULT_PROCESS_NUM, 1, id="verbose-1"
[],
DEFAULT_DEPTH,
1,
DEFAULT_PROCESS_NUM,
0,
RichConsoleProgressReporter,
id="empty",
),
pytest.param(
["--verbose"],
DEFAULT_DEPTH,
1,
DEFAULT_PROCESS_NUM,
1,
NullProgressReporter,
id="verbose-1",
),
pytest.param(
["-vv"],
DEFAULT_DEPTH,
1,
DEFAULT_PROCESS_NUM,
2,
NullProgressReporter,
id="verbose-2",
),
pytest.param(
["-vvv"],
DEFAULT_DEPTH,
1,
DEFAULT_PROCESS_NUM,
3,
NullProgressReporter,
id="verbose-3",
),
pytest.param(
["--depth", "2"], 2, 1, DEFAULT_PROCESS_NUM, 0, mock.ANY, id="depth"
),
pytest.param(["-vv"], DEFAULT_DEPTH, 1, DEFAULT_PROCESS_NUM, 2, id="verbose-2"),
pytest.param(
["-vvv"], DEFAULT_DEPTH, 1, DEFAULT_PROCESS_NUM, 3, id="verbose-3"
["--process-num", "2"], DEFAULT_DEPTH, 1, 2, 0, mock.ANY, id="process-num"
),
pytest.param(["--depth", "2"], 2, 1, DEFAULT_PROCESS_NUM, 0, id="depth"),
pytest.param(["--process-num", "2"], DEFAULT_DEPTH, 1, 2, 0, id="process-num"),
],
)
def test_archive_success(
Expand All @@ -194,6 +231,7 @@ def test_archive_success(
expected_entropy_depth: int,
expected_process_num: int,
expected_verbosity: int,
expected_progress_reporter: Type[ProgressReporter],
tmp_path: Path,
):
runner = CliRunner()
Expand Down Expand Up @@ -225,6 +263,7 @@ def test_archive_success(
process_num=expected_process_num,
handlers=BUILTIN_HANDLERS,
verbose=expected_verbosity,
progress_reporter=expected_progress_reporter,
)
process_file_mock.assert_called_once_with(config, in_path, None)
logger_config_mock.assert_called_once_with(expected_verbosity, tmp_path, log_path)
Expand Down
4 changes: 4 additions & 0 deletions unblob/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
ExtractionConfig,
process_file,
)
from .ui import NullProgressReporter, RichConsoleProgressReporter

logger = get_logger()

Expand Down Expand Up @@ -258,6 +259,9 @@ def cli(
dir_handlers=dir_handlers,
keep_extracted_chunks=keep_extracted_chunks,
verbose=verbose,
progress_reporter=NullProgressReporter
if verbose
else RichConsoleProgressReporter,
)

logger.info("Start processing file", file=file)
Expand Down
33 changes: 7 additions & 26 deletions unblob/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@
import shutil
from operator import attrgetter
from pathlib import Path
from typing import Iterable, List, Optional, Sequence, Set, Tuple
from typing import Iterable, List, Optional, Sequence, Set, Tuple, Type

import attr
import magic
import plotext as plt
from rich import progress
from rich.style import Style
from structlog import get_logger
from unblob_native import math_tools as mt

Expand Down Expand Up @@ -45,6 +43,7 @@
UnknownError,
)
from .signals import terminate_gracefully
from .ui import NullProgressReporter, ProgressReporter

logger = get_logger()

Expand Down Expand Up @@ -94,6 +93,7 @@ class ExtractionConfig:
handlers: Handlers = BUILTIN_HANDLERS
dir_handlers: DirectoryHandlers = BUILTIN_DIR_HANDLERS
verbose: int = 1
progress_reporter: Type[ProgressReporter] = NullProgressReporter

def get_extract_dir_for(self, path: Path) -> Path:
"""Return extraction dir under root with the name of path."""
Expand Down Expand Up @@ -146,26 +146,11 @@ def _process_task(config: ExtractionConfig, task: Task) -> ProcessResult:
processor = Processor(config)
aggregated_result = ProcessResult()

if not config.verbose:
progress_display = progress.Progress(
progress.TextColumn(
"Extraction progress: {task.percentage:>3.0f}%",
style=Style(color="#00FFC8"),
),
progress.BarColumn(
complete_style=Style(color="#00FFC8"), style=Style(color="#002060")
),
)
progress_display.start()
overall_progress_task = progress_display.add_task("Extraction progress:")
progress_reporter = config.progress_reporter()

def process_result(pool, result):
if config.verbose == 0 and progress_display.tasks[0].total is not None:
progress_display.update(
overall_progress_task,
advance=1,
total=progress_display.tasks[0].total + len(result.subtasks),
)
progress_reporter.update(result)

for new_task in result.subtasks:
pool.submit(new_task)
aggregated_result.register(result)
Expand All @@ -176,14 +161,10 @@ def process_result(pool, result):
result_callback=process_result,
)

with pool:
with pool, progress_reporter:
pool.submit(task)
pool.process_until_done()

if not config.verbose:
progress_display.remove_task(overall_progress_task)
progress_display.stop()

return aggregated_result


Expand Down
57 changes: 57 additions & 0 deletions unblob/ui.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from typing import Protocol

from rich import progress
from rich.style import Style

from .models import TaskResult


class ProgressReporter(Protocol):
def __enter__(self):
...

def __exit__(self, _exc_type, _exc_value, _tb):
...

def update(self, result: TaskResult):
...


class NullProgressReporter:
def __enter__(self):
pass

def __exit__(self, _exc_type, _exc_value, _tb):
pass

def update(self, result: TaskResult):
pass


class RichConsoleProgressReporter:
def __init__(self):
self._progress = progress.Progress(
progress.TextColumn(
"Extraction progress: {task.percentage:>3.0f}%",
style=Style(color="#00FFC8"),
),
progress.BarColumn(
complete_style=Style(color="#00FFC8"), style=Style(color="#002060")
),
)
self._overall_progress_task = self._progress.add_task("Extraction progress:")

def __enter__(self):
self._progress.start()

def __exit__(self, _exc_type, _exc_value, _tb):
self._progress.remove_task(self._overall_progress_task)
self._progress.stop()

def update(self, result: TaskResult):
if (total := self._progress.tasks[0].total) is not None:
self._progress.update(
self._overall_progress_task,
advance=1,
total=total + len(result.subtasks),
)

0 comments on commit 6cf2950

Please sign in to comment.