From 3310cbdc506593fd1f62edb011989f913f3cc49c Mon Sep 17 00:00:00 2001 From: mr-raccoon-97 Date: Thu, 21 Nov 2024 02:57:00 +0000 Subject: [PATCH] First push --- .github/workflows/python-CD.yml | 14 ++++ .github/workflows/python-CI.yml | 26 +++++++ README.md | 2 - mltracker/__init__.py | 8 ++ mltracker/adapters/tinydb/__init__.py | 0 mltracker/adapters/tinydb/aggregates.py | 61 +++++++++++++++ mltracker/adapters/tinydb/experiments.py | 32 ++++++++ mltracker/adapters/tinydb/getters.py | 21 +++++ mltracker/adapters/tinydb/iterations.py | 34 +++++++++ mltracker/adapters/tinydb/metrics.py | 20 +++++ mltracker/ports/aggregates.py | 38 ++++++++++ mltracker/ports/experiments.py | 23 ++++++ mltracker/ports/iterations.py | 39 ++++++++++ mltracker/ports/metrics.py | 23 ++++++ mltracker/ports/modules.py | 17 +++++ poetry.lock | 85 +++++++++++++++++++++ pyproject.toml | 22 ++++++ pytest.ini | 2 + tests/__init__.py | 0 tests/tinydb/conftest.py | 17 +++++ tests/tinydb/test_aggregates.py | 26 +++++++ tests/tinydb/test_experiments.py | 27 +++++++ tests/tinydb/test_iterations.py | 97 ++++++++++++++++++++++++ tests/tinydb/test_metrics.py | 15 ++++ 24 files changed, 647 insertions(+), 2 deletions(-) create mode 100644 .github/workflows/python-CD.yml create mode 100644 .github/workflows/python-CI.yml create mode 100644 mltracker/__init__.py create mode 100644 mltracker/adapters/tinydb/__init__.py create mode 100644 mltracker/adapters/tinydb/aggregates.py create mode 100644 mltracker/adapters/tinydb/experiments.py create mode 100644 mltracker/adapters/tinydb/getters.py create mode 100644 mltracker/adapters/tinydb/iterations.py create mode 100644 mltracker/adapters/tinydb/metrics.py create mode 100644 mltracker/ports/aggregates.py create mode 100644 mltracker/ports/experiments.py create mode 100644 mltracker/ports/iterations.py create mode 100644 mltracker/ports/metrics.py create mode 100644 mltracker/ports/modules.py create mode 100644 poetry.lock create mode 100644 pyproject.toml create mode 100644 pytest.ini create mode 100644 tests/__init__.py create mode 100644 tests/tinydb/conftest.py create mode 100644 tests/tinydb/test_aggregates.py create mode 100644 tests/tinydb/test_experiments.py create mode 100644 tests/tinydb/test_iterations.py create mode 100644 tests/tinydb/test_metrics.py diff --git a/.github/workflows/python-CD.yml b/.github/workflows/python-CD.yml new file mode 100644 index 0000000..565a92c --- /dev/null +++ b/.github/workflows/python-CD.yml @@ -0,0 +1,14 @@ +name: CD +on: + push: + tags: + - "v*.*.*" +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Build and publish to pypi + uses: JRubics/poetry-publish@v2.0 + with: + pypi_token: ${{ secrets.PYPI_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/python-CI.yml b/.github/workflows/python-CI.yml new file mode 100644 index 0000000..ff29957 --- /dev/null +++ b/.github/workflows/python-CI.yml @@ -0,0 +1,26 @@ +name: CI +on: [push, pull_request] + +jobs: + tests: + strategy: + fail-fast: false + matrix: + python-version: ["3.12"] + poetry-version: ["latest", "1.8.3"] + os: [ubuntu-22.04, macos-latest, windows-latest] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Run image + uses: abatilo/actions-poetry@v2 + with: + poetry-version: ${{ matrix.poetry-version }} + - name: Install dependencies + run: poetry install + - name: Test with pytest + run: | + poetry run pytest \ No newline at end of file diff --git a/README.md b/README.md index 06786a3..e69de29 100644 --- a/README.md +++ b/README.md @@ -1,2 +0,0 @@ -# ml-tracker -A simple tool for tracking machine learning aggregates and storing their data. diff --git a/mltracker/__init__.py b/mltracker/__init__.py new file mode 100644 index 0000000..5ba8c72 --- /dev/null +++ b/mltracker/__init__.py @@ -0,0 +1,8 @@ +from mltracker.adapters.tinydb.getters import get_experiments_collection, get_experiment, get_aggregates_collection +from mltracker.adapters.tinydb.experiments import Experiment +from mltracker.adapters.tinydb.aggregates import Aggregate +from mltracker.adapters.tinydb.metrics import Metric +from mltracker.adapters.tinydb.iterations import Iteration +from mltracker.adapters.tinydb.aggregates import Module + +#TODO: Fix this to support other adapters. For now only tinydb is supported. \ No newline at end of file diff --git a/mltracker/adapters/tinydb/__init__.py b/mltracker/adapters/tinydb/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mltracker/adapters/tinydb/aggregates.py b/mltracker/adapters/tinydb/aggregates.py new file mode 100644 index 0000000..5426b1f --- /dev/null +++ b/mltracker/adapters/tinydb/aggregates.py @@ -0,0 +1,61 @@ +from typing import Any +from typing import Optional +from mltracker.adapters.tinydb.metrics import Metrics +from mltracker.adapters.tinydb.iterations import Iterations +from mltracker.ports.aggregates import Aggregates as Collection +from mltracker.ports.aggregates import Aggregate, asdict +from mltracker.ports.modules import Module +from tinydb import TinyDB, where + +class Aggregates(Collection): + def __init__(self, owner: Any, database: TinyDB): + self.owner = str(owner) + self.database = database + self.table = self.database.table('aggregates') + + def create(self, id: str, modules: list[Module]) -> Aggregate: + if self.table.contains((where('owner') == self.owner) & (where('id') == id)): + raise ValueError(f'Aggregate with id {id} already exists') + + aggregate = Aggregate( + id=id, + epochs=0, + modules=modules, + metrics=Metrics(id, self.database), + iterations=Iterations(id, self.database) + ) + + self.table.insert({ + 'owner': self.owner, + 'id': id, + 'epochs': 0, + 'modules': [asdict(module) for module in modules] + }) + return aggregate + + def get(self, id: str) -> Optional[Aggregate]: + result = self.table.get((where('owner') == self.owner) & (where('id') == id)) + if result: + return Aggregate( + **{key: value for key, value in result.items() if key != 'owner' and key != 'modules'}, + modules=[Module(**module) for module in result['modules']], + metrics=Metrics(id, self.database), + iterations=Iterations(id, self.database) + ) + return None + + def patch(self, id: str, epochs: int): + self.table.update({'epochs': epochs}, (where('owner') == self.owner) & (where('id') == id)) + + def list(self) -> list[Aggregate]: + results = self.table.search(where('owner') == self.owner) + return [Aggregate( + **{key: value for key, value in result.items() if key != 'owner' and key != 'modules'}, + modules=[Module(**module) for module in result['modules']], + metrics=Metrics(result['id'], self.database), + iterations=Iterations(result['id'], self.database) + ) for result in results] + + def remove(self, aggregate: Aggregate): + aggregate.metrics.clear(), aggregate.iterations.clear() + self.table.remove((where('owner') == self.owner) & (where('id') == aggregate.id)) \ No newline at end of file diff --git a/mltracker/adapters/tinydb/experiments.py b/mltracker/adapters/tinydb/experiments.py new file mode 100644 index 0000000..251c727 --- /dev/null +++ b/mltracker/adapters/tinydb/experiments.py @@ -0,0 +1,32 @@ +from os import path, makedirs +from uuid import uuid4 +from typing import Optional +from tinydb import TinyDB, where +from mltracker.ports.experiments import Experiment +from mltracker.ports.experiments import Experiments as Collection + +class Experiments(Collection): + def __init__(self, database: TinyDB): + self.database = database + self.table = self.database.table('experiments') + + def read(self, name: str) -> Optional[Experiment]: + result = self.table.get(where('name') == name) + return Experiment(id=result['id'], name=result['name']) if result else None + + def create(self, name: str) -> Experiment: + result = self.table.get(where('name') == name) + if result: + raise ValueError(f'Experiment with name {name} already exists') + id = uuid4() + self.table.insert({'id': str(id), 'name': name}) + return Experiment(id=id, name=name) + + def update(self, experiment: Experiment): + self.table.update({'name': experiment.name}, where('id') == str(experiment.id)) + + def delete(self, name: str): + self.table.remove(where('name') == name) + + def list(self) -> list[Experiment]: + return [Experiment(id=result['id'], name=result['name']) for result in self.table.all()] \ No newline at end of file diff --git a/mltracker/adapters/tinydb/getters.py b/mltracker/adapters/tinydb/getters.py new file mode 100644 index 0000000..63e6ac8 --- /dev/null +++ b/mltracker/adapters/tinydb/getters.py @@ -0,0 +1,21 @@ +from os import path, makedirs +from tinydb import TinyDB +from mltracker.adapters.tinydb.experiments import Experiment, Experiments +from mltracker.adapters.tinydb.aggregates import Aggregates + +def get_experiments_collection(database_location: str) -> Experiments: + if not path.exists(database_location): + makedirs(database_location) + database = TinyDB(path.join(database_location, 'database.json')) + return Experiments(database) + +def get_experiment(name: str, database_location: str) -> Experiment: + experiments = get_experiments_collection(database_location) + experiment = experiments.read(name) + if not experiment: + experiment = experiments.create(name) + return experiment + +def get_aggregates_collection(experiment_name: str, database_location: str) -> Aggregates: + experiment = get_experiment(experiment_name, database_location) + return Aggregates(experiment.id, database_location) \ No newline at end of file diff --git a/mltracker/adapters/tinydb/iterations.py b/mltracker/adapters/tinydb/iterations.py new file mode 100644 index 0000000..d410ae6 --- /dev/null +++ b/mltracker/adapters/tinydb/iterations.py @@ -0,0 +1,34 @@ +from typing import Any +from mltracker.ports.iterations import Iterations as Collection +from mltracker.ports.iterations import Iteration, asdict +from mltracker.ports.modules import Module +from mltracker.ports.iterations import Dataset +from tinydb import TinyDB, where + + +class Iterations(Collection): + def __init__(self, owner: Any, database: TinyDB): + self.owner = str(owner) + self.database = database + self.table = self.database.table('iterations') + + def put(self, iteration: Iteration): + self.table.upsert({ + 'owner': self.owner, + 'hash': iteration.hash, + 'iteration': {key: value for key, value in asdict(iteration).items() if key != 'hash'} + }, where('hash') == iteration.hash) + + def list(self) -> list[Iteration]: + results = self.table.search(where('owner') == self.owner) + return [Iteration( + hash=result['hash'], + phase=result['iteration']['phase'], + epoch=result['iteration']['epoch'], + dataset=Dataset(**result['iteration']['dataset']), + arguments=result['iteration']['arguments'], + modules=[Module(**module) for module in result['iteration']['modules']], + ) for result in results] + + def clear(self): + self.table.remove(where('owner') == self.owner) \ No newline at end of file diff --git a/mltracker/adapters/tinydb/metrics.py b/mltracker/adapters/tinydb/metrics.py new file mode 100644 index 0000000..ab423b4 --- /dev/null +++ b/mltracker/adapters/tinydb/metrics.py @@ -0,0 +1,20 @@ +from os import makedirs, path +from mltracker.ports.metrics import Metrics as Collection +from mltracker.ports.metrics import Metric, asdict +from tinydb import TinyDB, where + +class Metrics(Collection): + def __init__(self, owner: str, database: TinyDB): + self.owner = str(owner) + self.database = database + self.table = self.database.table('metrics') + + def add(self, metric: Metric): + self.table.insert({'owner': self.owner, **asdict(metric)}) + + def list(self) -> list[Metric]: + results = self.table.search(where('owner') == self.owner) + return [Metric(**{key: value for key, value in result.items() if key != 'owner'}) for result in results] + + def clear(self): + self.table.remove(where('owner') == self.owner) \ No newline at end of file diff --git a/mltracker/ports/aggregates.py b/mltracker/ports/aggregates.py new file mode 100644 index 0000000..5b7c86e --- /dev/null +++ b/mltracker/ports/aggregates.py @@ -0,0 +1,38 @@ +from typing import Optional +from typing import Any +from abc import ABC, abstractmethod +from dataclasses import dataclass, asdict +from copy import deepcopy +from mltracker.ports.modules import Module +from mltracker.ports.metrics import Metrics +from mltracker.ports.iterations import Iterations + +@dataclass +class Aggregate: + id: Any + epochs: int + modules: dict[str, Module] + metrics: Metrics + iterations: Iterations + + def __eq__(self, value: object) -> bool: + if not isinstance(value, Aggregate): + return False + return asdict(self) == asdict(value) + + def __hash__(self) -> int: + return hash(self.id) + +class Aggregates(ABC): + + @abstractmethod + def create(self, id: str, modules: list[Module]) -> Aggregate:... + + @abstractmethod + def get(self, id: str) -> Optional[Aggregate]:... + + @abstractmethod + def patch(self, id: str, epochs: int):... + + @abstractmethod + def remove(self, aggregate: Aggregate):... \ No newline at end of file diff --git a/mltracker/ports/experiments.py b/mltracker/ports/experiments.py new file mode 100644 index 0000000..283315e --- /dev/null +++ b/mltracker/ports/experiments.py @@ -0,0 +1,23 @@ +from typing import Optional +from typing import Any +from abc import ABC, abstractmethod +from dataclasses import dataclass + +@dataclass +class Experiment: + id: Optional[Any] + name: str + +class Experiments(ABC): + + @abstractmethod + def create(self, name: str) -> Experiment: ... + + @abstractmethod + def read(self, name: str) -> Optional[Experiment]: ... + + @abstractmethod + def delete(self, name: str): ... + + @abstractmethod + def list(self) -> list[Experiment]: ... \ No newline at end of file diff --git a/mltracker/ports/iterations.py b/mltracker/ports/iterations.py new file mode 100644 index 0000000..b624f06 --- /dev/null +++ b/mltracker/ports/iterations.py @@ -0,0 +1,39 @@ +from typing import Any +from abc import ABC, abstractmethod +from dataclasses import dataclass, asdict, field +from mltracker.ports.modules import Module + +@dataclass +class Dataset: + hash: str + name: str + arguments: dict[str, Any] + +@dataclass +class Iteration: + hash: str + phase: str + epoch: int + dataset: Dataset + arguments: dict[str, Any] + modules: list[Module] + + def __eq__(self, value: object) -> bool: + if not isinstance(value, Iteration): + return False + return self.hash == value.hash + + def __hash__(self) -> int: + return hash(self.hash) + + +class Iterations(ABC): + + @abstractmethod + def put(self, iteration: Iteration): ... + + @abstractmethod + def list(self) -> list[Iteration]: ... + + @abstractmethod + def clear(self): ... \ No newline at end of file diff --git a/mltracker/ports/metrics.py b/mltracker/ports/metrics.py new file mode 100644 index 0000000..69bf59c --- /dev/null +++ b/mltracker/ports/metrics.py @@ -0,0 +1,23 @@ +from abc import ABC +from abc import abstractmethod +from typing import Any +from dataclasses import dataclass, asdict + +@dataclass +class Metric: + name: str + value: Any + batch: int + epoch: int + phase: str + +class Metrics(ABC): + + @abstractmethod + def add(self, metric: Metric): ... + + @abstractmethod + def list(self) -> list[Metric]: ... + + @abstractmethod + def clear(self): ... \ No newline at end of file diff --git a/mltracker/ports/modules.py b/mltracker/ports/modules.py new file mode 100644 index 0000000..ceb76be --- /dev/null +++ b/mltracker/ports/modules.py @@ -0,0 +1,17 @@ +from typing import Any +from dataclasses import dataclass + +@dataclass +class Module: + type: str + hash: str + name: str + arguments: dict[str, Any] + + def __eq__(self, value: object) -> bool: + if not isinstance(value, Module): + return False + return self.hash == value.hash + + def __hash__(self) -> int: + return hash(self.hash) \ No newline at end of file diff --git a/poetry.lock b/poetry.lock new file mode 100644 index 0000000..11acbbd --- /dev/null +++ b/poetry.lock @@ -0,0 +1,85 @@ +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. + +[[package]] +name = "colorama" +version = "0.4.6" +description = "Cross-platform colored terminal text." +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +files = [ + {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, + {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, +] + +[[package]] +name = "iniconfig" +version = "2.0.0" +description = "brain-dead simple config-ini parsing" +optional = false +python-versions = ">=3.7" +files = [ + {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, + {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, +] + +[[package]] +name = "packaging" +version = "24.2" +description = "Core utilities for Python packages" +optional = false +python-versions = ">=3.8" +files = [ + {file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"}, + {file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"}, +] + +[[package]] +name = "pluggy" +version = "1.5.0" +description = "plugin and hook calling mechanisms for python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"}, + {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, +] + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["pytest", "pytest-benchmark"] + +[[package]] +name = "pytest" +version = "8.3.3" +description = "pytest: simple powerful testing with Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-8.3.3-py3-none-any.whl", hash = "sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2"}, + {file = "pytest-8.3.3.tar.gz", hash = "sha256:70b98107bd648308a7952b06e6ca9a50bc660be218d53c257cc1fc94fda10181"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=1.5,<2" + +[package.extras] +dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] + +[[package]] +name = "tinydb" +version = "4.8.2" +description = "TinyDB is a tiny, document oriented database optimized for your happiness :)" +optional = false +python-versions = "<4.0,>=3.8" +files = [ + {file = "tinydb-4.8.2-py3-none-any.whl", hash = "sha256:f97030ee5cbc91eeadd1d7af07ab0e48ceb04aa63d4a983adbaca4cba16e86c3"}, + {file = "tinydb-4.8.2.tar.gz", hash = "sha256:f7dfc39b8d7fda7a1ca62a8dbb449ffd340a117c1206b68c50b1a481fb95181d"}, +] + +[metadata] +lock-version = "2.0" +python-versions = "^3.12" +content-hash = "f5077a6d0cfc8e3a1b050ecdc3c80dfa14a7fbf3c81c3cfc8923a607f0874994" diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..05c033a --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,22 @@ +[tool.poetry] +name = "mltracker" +version = "0.1.0" +description = "" +authors = ["mr-raccoon-97 "] +readme = "README.md" + +[tool.poetry.dependencies] +python = "^3.12" + + +[tool.poetry.group.tests.dependencies] +pytest = "^8.3.3" + + + +[tool.poetry.group.tinydb.dependencies] +tinydb = "^4.8.2" + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" \ No newline at end of file diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..03f586d --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +pythonpath = . \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/tinydb/conftest.py b/tests/tinydb/conftest.py new file mode 100644 index 0000000..430be59 --- /dev/null +++ b/tests/tinydb/conftest.py @@ -0,0 +1,17 @@ +from os import path, makedirs +from shutil import rmtree +from pytest import fixture +from logging import getLogger +from tinydb import TinyDB + +logger = getLogger(__name__) + +@fixture(scope='session') +def database(): + if not path.exists('data'): + makedirs('data') + yield TinyDB('data/database.json') + try: + rmtree('data') + except PermissionError: + logger.warning('Could not remove data directory') \ No newline at end of file diff --git a/tests/tinydb/test_aggregates.py b/tests/tinydb/test_aggregates.py new file mode 100644 index 0000000..a7166d1 --- /dev/null +++ b/tests/tinydb/test_aggregates.py @@ -0,0 +1,26 @@ +from pytest import fixture, raises +from mltracker.adapters.tinydb.aggregates import Module, Aggregates + +@fixture +def aggregates(database): + return Aggregates('test', database) + +def test_aggregates(aggregates: Aggregates): + aggregate = aggregates.create('test', [Module('test', 'test', 'test', {'test': 'test'})]) + assert aggregate.id == 'test' + assert aggregate.modules == [Module('test', 'test', 'test', {'test': 'test'})] + + aggregate = aggregates.get('test') + assert aggregate.id == 'test' + assert aggregates.get('test2') is None + assert len(aggregates.list()) == 1 + + aggregate = aggregates.create('test2', []) + assert aggregate.id == 'test2' + assert len(aggregates.list()) == 2 + + aggregate = aggregates.get('test') + aggregates.remove(aggregate) + assert aggregates.get('test') is None + with raises(ValueError): + aggregates.create('test2', []) \ No newline at end of file diff --git a/tests/tinydb/test_experiments.py b/tests/tinydb/test_experiments.py new file mode 100644 index 0000000..a0a2d29 --- /dev/null +++ b/tests/tinydb/test_experiments.py @@ -0,0 +1,27 @@ +from pytest import fixture, raises +from mltracker.adapters.tinydb.experiments import Experiments, Experiment + +@fixture +def experiments(database): + return Experiments(database) + +def test_experiments(experiments: Experiments): + experiment = experiments.create('test') + assert experiment.name == 'test' + experiment = experiments.read('test') + assert experiment.name == 'test' + assert experiments.read('test2') is None + assert len(experiments.list()) == 1 + + experiment = experiments.create('test2') + assert experiment.name == 'test2' + assert len(experiments.list()) == 2 + + experiments.delete('test') + experiment.name = 'test' + experiments.update(experiment) + + experiment = experiments.read('test') + assert experiment.name == 'test' + with raises(ValueError): + experiments.create('test') \ No newline at end of file diff --git a/tests/tinydb/test_iterations.py b/tests/tinydb/test_iterations.py new file mode 100644 index 0000000..e690e55 --- /dev/null +++ b/tests/tinydb/test_iterations.py @@ -0,0 +1,97 @@ +from pytest import fixture +from mltracker.adapters.tinydb.iterations import Iteration, Iterations +from mltracker.ports.iterations import Dataset +from mltracker.ports.modules import Module + +@fixture +def iterations(database): + return Iterations('test', database) + +def test_iterations(iterations: Iterations): + iteration = Iteration( + hash='1', + epoch=1, + phase='train', + dataset=Dataset( + hash='123', + name='mnist', + arguments={'train': True, 'normalize': True} + ), + arguments={'batch_size': 32, 'shuffle': True}, + modules=[ + Module( + type='criterion', + hash='123', + name='CrossEntropyLoss', + arguments={'reduction': 'mean'} + ), + + Module( + type='optimizer', + hash='123', + name='SGD', + arguments={'lr': 0.01} + ) + ] + ) + + + iteration2 = Iteration( + hash='2', + epoch=1, + phase='train', + dataset=Dataset( + hash='123', + name='mnist', + arguments={'train': True, 'normalize': True} + ), + arguments={'batch_size': 32, 'shuffle': True}, + modules=[ + Module( + type='criterion', + hash='123', + name='CrossEntropyLoss', + arguments={'reduction': 'mean'} + ), + + Module( + type='optimizer', + hash='123', + name='SGD', + arguments={'lr': 0.01} + ) + ] + ) + + + iteration3 = Iteration( + hash='2', + epoch=1, + phase='train', + dataset=Dataset( + hash='123', + name='mnist', + arguments={'train': True, 'normalize': True} + ), + arguments={'batch_size': 32, 'shuffle': True}, + modules=[ + Module( + type='criterion', + hash='123', + name='CrossEntropyLoss', + arguments={'reduction': 'mean'} + ), + + Module( + type='optimizer', + hash='123', + name='SGD', + arguments={'lr': 0.01} + ) + ] + ) + + iterations.put(iteration) + iterations.put(iteration2) + iterations.put(iteration3) #overriding iteration2 since hash is the same + assert iterations.list() == [iteration, iteration3] \ No newline at end of file diff --git a/tests/tinydb/test_metrics.py b/tests/tinydb/test_metrics.py new file mode 100644 index 0000000..a83048d --- /dev/null +++ b/tests/tinydb/test_metrics.py @@ -0,0 +1,15 @@ +from pytest import fixture +from mltracker.adapters.tinydb.metrics import Metrics, Metric + +@fixture +def metrics(database): + return Metrics('test', database) + +def test_metrics(metrics: Metrics): + metrics.add(Metric('accuracy', value=0.9, batch=100, epoch=1, phase='train')) + metrics.add(Metric('accuracy', value=0.2, batch=100, epoch=2, phase='train')) + metrics.add(Metric('accuracy', value=0.2, batch=100, epoch=2, phase='test')) + metrics.add(Metric('loss', value=0.3, batch=100, epoch=1, phase='train')) + assert len(metrics.list()) == 4 + metrics.clear() + assert len(metrics.list()) == 0 \ No newline at end of file