diff --git a/tests/test_cli.py b/tests/test_cli.py index 463f356876..7620247f66 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import List, Optional +from typing import List, Optional, Type from unittest import mock import pytest @@ -11,6 +11,11 @@ from unblob.handlers import BUILTIN_HANDLERS from unblob.models import DirectoryHandler, Glob, Handler, HexString, MultiFile from unblob.processing import DEFAULT_DEPTH, DEFAULT_PROCESS_NUM, ExtractionConfig +from unblob.ui import ( + NullProgressReporter, + ProgressReporter, + RichConsoleProgressReporter, +) class TestHandler(Handler): @@ -174,18 +179,50 @@ def test_dir_for_file(tmp_path: Path): @pytest.mark.parametrize( - "params, expected_depth, expected_entropy_depth, expected_process_num, expected_verbosity", + "params, expected_depth, expected_entropy_depth, expected_process_num, expected_verbosity, expected_progress_reporter", [ - pytest.param([], DEFAULT_DEPTH, 1, DEFAULT_PROCESS_NUM, 0, id="empty"), pytest.param( - ["--verbose"], DEFAULT_DEPTH, 1, DEFAULT_PROCESS_NUM, 1, id="verbose-1" + [], + DEFAULT_DEPTH, + 1, + DEFAULT_PROCESS_NUM, + 0, + RichConsoleProgressReporter, + id="empty", + ), + pytest.param( + ["--verbose"], + DEFAULT_DEPTH, + 1, + DEFAULT_PROCESS_NUM, + 1, + NullProgressReporter, + id="verbose-1", + ), + pytest.param( + ["-vv"], + DEFAULT_DEPTH, + 1, + DEFAULT_PROCESS_NUM, + 2, + NullProgressReporter, + id="verbose-2", + ), + pytest.param( + ["-vvv"], + DEFAULT_DEPTH, + 1, + DEFAULT_PROCESS_NUM, + 3, + NullProgressReporter, + id="verbose-3", + ), + pytest.param( + ["--depth", "2"], 2, 1, DEFAULT_PROCESS_NUM, 0, mock.ANY, id="depth" ), - pytest.param(["-vv"], DEFAULT_DEPTH, 1, DEFAULT_PROCESS_NUM, 2, id="verbose-2"), pytest.param( - ["-vvv"], DEFAULT_DEPTH, 1, DEFAULT_PROCESS_NUM, 3, id="verbose-3" + ["--process-num", "2"], DEFAULT_DEPTH, 1, 2, 0, mock.ANY, id="process-num" ), - pytest.param(["--depth", "2"], 2, 1, DEFAULT_PROCESS_NUM, 0, id="depth"), - pytest.param(["--process-num", "2"], DEFAULT_DEPTH, 1, 2, 0, id="process-num"), ], ) def test_archive_success( @@ -194,6 +231,7 @@ def test_archive_success( expected_entropy_depth: int, expected_process_num: int, expected_verbosity: int, + expected_progress_reporter: Type[ProgressReporter], tmp_path: Path, ): runner = CliRunner() @@ -225,6 +263,7 @@ def test_archive_success( process_num=expected_process_num, handlers=BUILTIN_HANDLERS, verbose=expected_verbosity, + progress_reporter=expected_progress_reporter, ) process_file_mock.assert_called_once_with(config, in_path, None) logger_config_mock.assert_called_once_with(expected_verbosity, tmp_path, log_path) diff --git a/unblob/cli.py b/unblob/cli.py index 077a019626..aeffb1b616 100755 --- a/unblob/cli.py +++ b/unblob/cli.py @@ -26,6 +26,7 @@ ExtractionConfig, process_file, ) +from .ui import NullProgressReporter, RichConsoleProgressReporter logger = get_logger() @@ -258,6 +259,9 @@ def cli( dir_handlers=dir_handlers, keep_extracted_chunks=keep_extracted_chunks, verbose=verbose, + progress_reporter=NullProgressReporter + if verbose + else RichConsoleProgressReporter, ) logger.info("Start processing file", file=file) diff --git a/unblob/processing.py b/unblob/processing.py index 76cc98fddc..cc34fb7127 100644 --- a/unblob/processing.py +++ b/unblob/processing.py @@ -2,13 +2,11 @@ import shutil from operator import attrgetter from pathlib import Path -from typing import Iterable, List, Optional, Sequence, Set, Tuple +from typing import Iterable, List, Optional, Sequence, Set, Tuple, Type import attr import magic import plotext as plt -from rich import progress -from rich.style import Style from structlog import get_logger from unblob_native import math_tools as mt @@ -45,6 +43,7 @@ UnknownError, ) from .signals import terminate_gracefully +from .ui import NullProgressReporter, ProgressReporter logger = get_logger() @@ -94,6 +93,7 @@ class ExtractionConfig: handlers: Handlers = BUILTIN_HANDLERS dir_handlers: DirectoryHandlers = BUILTIN_DIR_HANDLERS verbose: int = 1 + progress_reporter: Type[ProgressReporter] = NullProgressReporter def get_extract_dir_for(self, path: Path) -> Path: """Return extraction dir under root with the name of path.""" @@ -146,26 +146,11 @@ def _process_task(config: ExtractionConfig, task: Task) -> ProcessResult: processor = Processor(config) aggregated_result = ProcessResult() - if not config.verbose: - progress_display = progress.Progress( - progress.TextColumn( - "Extraction progress: {task.percentage:>3.0f}%", - style=Style(color="#00FFC8"), - ), - progress.BarColumn( - complete_style=Style(color="#00FFC8"), style=Style(color="#002060") - ), - ) - progress_display.start() - overall_progress_task = progress_display.add_task("Extraction progress:") + progress_reporter = config.progress_reporter() def process_result(pool, result): - if config.verbose == 0 and progress_display.tasks[0].total is not None: - progress_display.update( - overall_progress_task, - advance=1, - total=progress_display.tasks[0].total + len(result.subtasks), - ) + progress_reporter.update(result) + for new_task in result.subtasks: pool.submit(new_task) aggregated_result.register(result) @@ -176,14 +161,10 @@ def process_result(pool, result): result_callback=process_result, ) - with pool: + with pool, progress_reporter: pool.submit(task) pool.process_until_done() - if not config.verbose: - progress_display.remove_task(overall_progress_task) - progress_display.stop() - return aggregated_result diff --git a/unblob/ui.py b/unblob/ui.py new file mode 100644 index 0000000000..a332d54778 --- /dev/null +++ b/unblob/ui.py @@ -0,0 +1,57 @@ +from typing import Protocol + +from rich import progress +from rich.style import Style + +from .models import TaskResult + + +class ProgressReporter(Protocol): + def __enter__(self): + ... + + def __exit__(self, _exc_type, _exc_value, _tb): + ... + + def update(self, result: TaskResult): + ... + + +class NullProgressReporter: + def __enter__(self): + pass + + def __exit__(self, _exc_type, _exc_value, _tb): + pass + + def update(self, result: TaskResult): + pass + + +class RichConsoleProgressReporter: + def __init__(self): + self._progress = progress.Progress( + progress.TextColumn( + "Extraction progress: {task.percentage:>3.0f}%", + style=Style(color="#00FFC8"), + ), + progress.BarColumn( + complete_style=Style(color="#00FFC8"), style=Style(color="#002060") + ), + ) + self._overall_progress_task = self._progress.add_task("Extraction progress:") + + def __enter__(self): + self._progress.start() + + def __exit__(self, _exc_type, _exc_value, _tb): + self._progress.remove_task(self._overall_progress_task) + self._progress.stop() + + def update(self, result: TaskResult): + if (total := self._progress.tasks[0].total) is not None: + self._progress.update( + self._overall_progress_task, + advance=1, + total=total + len(result.subtasks), + )