-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0ad2537
commit 3310cbd
Showing
24 changed files
with
647 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
name: CD | ||
on: | ||
push: | ||
tags: | ||
- "v*.*.*" | ||
jobs: | ||
build: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v3 | ||
- name: Build and publish to pypi | ||
uses: JRubics/[email protected] | ||
with: | ||
pypi_token: ${{ secrets.PYPI_TOKEN }} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
name: CI | ||
on: [push, pull_request] | ||
|
||
jobs: | ||
tests: | ||
strategy: | ||
fail-fast: false | ||
matrix: | ||
python-version: ["3.12"] | ||
poetry-version: ["latest", "1.8.3"] | ||
os: [ubuntu-22.04, macos-latest, windows-latest] | ||
runs-on: ${{ matrix.os }} | ||
steps: | ||
- uses: actions/checkout@v4 | ||
- uses: actions/setup-python@v5 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
- name: Run image | ||
uses: abatilo/actions-poetry@v2 | ||
with: | ||
poetry-version: ${{ matrix.poetry-version }} | ||
- name: Install dependencies | ||
run: poetry install | ||
- name: Test with pytest | ||
run: | | ||
poetry run pytest |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +0,0 @@ | ||
# ml-tracker | ||
A simple tool for tracking machine learning aggregates and storing their data. | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from mltracker.adapters.tinydb.getters import get_experiments_collection, get_experiment, get_aggregates_collection | ||
from mltracker.adapters.tinydb.experiments import Experiment | ||
from mltracker.adapters.tinydb.aggregates import Aggregate | ||
from mltracker.adapters.tinydb.metrics import Metric | ||
from mltracker.adapters.tinydb.iterations import Iteration | ||
from mltracker.adapters.tinydb.aggregates import Module | ||
|
||
#TODO: Fix this to support other adapters. For now only tinydb is supported. |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from typing import Any | ||
from typing import Optional | ||
from mltracker.adapters.tinydb.metrics import Metrics | ||
from mltracker.adapters.tinydb.iterations import Iterations | ||
from mltracker.ports.aggregates import Aggregates as Collection | ||
from mltracker.ports.aggregates import Aggregate, asdict | ||
from mltracker.ports.modules import Module | ||
from tinydb import TinyDB, where | ||
|
||
class Aggregates(Collection): | ||
def __init__(self, owner: Any, database: TinyDB): | ||
self.owner = str(owner) | ||
self.database = database | ||
self.table = self.database.table('aggregates') | ||
|
||
def create(self, id: str, modules: list[Module]) -> Aggregate: | ||
if self.table.contains((where('owner') == self.owner) & (where('id') == id)): | ||
raise ValueError(f'Aggregate with id {id} already exists') | ||
|
||
aggregate = Aggregate( | ||
id=id, | ||
epochs=0, | ||
modules=modules, | ||
metrics=Metrics(id, self.database), | ||
iterations=Iterations(id, self.database) | ||
) | ||
|
||
self.table.insert({ | ||
'owner': self.owner, | ||
'id': id, | ||
'epochs': 0, | ||
'modules': [asdict(module) for module in modules] | ||
}) | ||
return aggregate | ||
|
||
def get(self, id: str) -> Optional[Aggregate]: | ||
result = self.table.get((where('owner') == self.owner) & (where('id') == id)) | ||
if result: | ||
return Aggregate( | ||
**{key: value for key, value in result.items() if key != 'owner' and key != 'modules'}, | ||
modules=[Module(**module) for module in result['modules']], | ||
metrics=Metrics(id, self.database), | ||
iterations=Iterations(id, self.database) | ||
) | ||
return None | ||
|
||
def patch(self, id: str, epochs: int): | ||
self.table.update({'epochs': epochs}, (where('owner') == self.owner) & (where('id') == id)) | ||
|
||
def list(self) -> list[Aggregate]: | ||
results = self.table.search(where('owner') == self.owner) | ||
return [Aggregate( | ||
**{key: value for key, value in result.items() if key != 'owner' and key != 'modules'}, | ||
modules=[Module(**module) for module in result['modules']], | ||
metrics=Metrics(result['id'], self.database), | ||
iterations=Iterations(result['id'], self.database) | ||
) for result in results] | ||
|
||
def remove(self, aggregate: Aggregate): | ||
aggregate.metrics.clear(), aggregate.iterations.clear() | ||
self.table.remove((where('owner') == self.owner) & (where('id') == aggregate.id)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from os import path, makedirs | ||
from uuid import uuid4 | ||
from typing import Optional | ||
from tinydb import TinyDB, where | ||
from mltracker.ports.experiments import Experiment | ||
from mltracker.ports.experiments import Experiments as Collection | ||
|
||
class Experiments(Collection): | ||
def __init__(self, database: TinyDB): | ||
self.database = database | ||
self.table = self.database.table('experiments') | ||
|
||
def read(self, name: str) -> Optional[Experiment]: | ||
result = self.table.get(where('name') == name) | ||
return Experiment(id=result['id'], name=result['name']) if result else None | ||
|
||
def create(self, name: str) -> Experiment: | ||
result = self.table.get(where('name') == name) | ||
if result: | ||
raise ValueError(f'Experiment with name {name} already exists') | ||
id = uuid4() | ||
self.table.insert({'id': str(id), 'name': name}) | ||
return Experiment(id=id, name=name) | ||
|
||
def update(self, experiment: Experiment): | ||
self.table.update({'name': experiment.name}, where('id') == str(experiment.id)) | ||
|
||
def delete(self, name: str): | ||
self.table.remove(where('name') == name) | ||
|
||
def list(self) -> list[Experiment]: | ||
return [Experiment(id=result['id'], name=result['name']) for result in self.table.all()] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from os import path, makedirs | ||
from tinydb import TinyDB | ||
from mltracker.adapters.tinydb.experiments import Experiment, Experiments | ||
from mltracker.adapters.tinydb.aggregates import Aggregates | ||
|
||
def get_experiments_collection(database_location: str) -> Experiments: | ||
if not path.exists(database_location): | ||
makedirs(database_location) | ||
database = TinyDB(path.join(database_location, 'database.json')) | ||
return Experiments(database) | ||
|
||
def get_experiment(name: str, database_location: str) -> Experiment: | ||
experiments = get_experiments_collection(database_location) | ||
experiment = experiments.read(name) | ||
if not experiment: | ||
experiment = experiments.create(name) | ||
return experiment | ||
|
||
def get_aggregates_collection(experiment_name: str, database_location: str) -> Aggregates: | ||
experiment = get_experiment(experiment_name, database_location) | ||
return Aggregates(experiment.id, database_location) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from typing import Any | ||
from mltracker.ports.iterations import Iterations as Collection | ||
from mltracker.ports.iterations import Iteration, asdict | ||
from mltracker.ports.modules import Module | ||
from mltracker.ports.iterations import Dataset | ||
from tinydb import TinyDB, where | ||
|
||
|
||
class Iterations(Collection): | ||
def __init__(self, owner: Any, database: TinyDB): | ||
self.owner = str(owner) | ||
self.database = database | ||
self.table = self.database.table('iterations') | ||
|
||
def put(self, iteration: Iteration): | ||
self.table.upsert({ | ||
'owner': self.owner, | ||
'hash': iteration.hash, | ||
'iteration': {key: value for key, value in asdict(iteration).items() if key != 'hash'} | ||
}, where('hash') == iteration.hash) | ||
|
||
def list(self) -> list[Iteration]: | ||
results = self.table.search(where('owner') == self.owner) | ||
return [Iteration( | ||
hash=result['hash'], | ||
phase=result['iteration']['phase'], | ||
epoch=result['iteration']['epoch'], | ||
dataset=Dataset(**result['iteration']['dataset']), | ||
arguments=result['iteration']['arguments'], | ||
modules=[Module(**module) for module in result['iteration']['modules']], | ||
) for result in results] | ||
|
||
def clear(self): | ||
self.table.remove(where('owner') == self.owner) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from os import makedirs, path | ||
from mltracker.ports.metrics import Metrics as Collection | ||
from mltracker.ports.metrics import Metric, asdict | ||
from tinydb import TinyDB, where | ||
|
||
class Metrics(Collection): | ||
def __init__(self, owner: str, database: TinyDB): | ||
self.owner = str(owner) | ||
self.database = database | ||
self.table = self.database.table('metrics') | ||
|
||
def add(self, metric: Metric): | ||
self.table.insert({'owner': self.owner, **asdict(metric)}) | ||
|
||
def list(self) -> list[Metric]: | ||
results = self.table.search(where('owner') == self.owner) | ||
return [Metric(**{key: value for key, value in result.items() if key != 'owner'}) for result in results] | ||
|
||
def clear(self): | ||
self.table.remove(where('owner') == self.owner) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from typing import Optional | ||
from typing import Any | ||
from abc import ABC, abstractmethod | ||
from dataclasses import dataclass, asdict | ||
from copy import deepcopy | ||
from mltracker.ports.modules import Module | ||
from mltracker.ports.metrics import Metrics | ||
from mltracker.ports.iterations import Iterations | ||
|
||
@dataclass | ||
class Aggregate: | ||
id: Any | ||
epochs: int | ||
modules: dict[str, Module] | ||
metrics: Metrics | ||
iterations: Iterations | ||
|
||
def __eq__(self, value: object) -> bool: | ||
if not isinstance(value, Aggregate): | ||
return False | ||
return asdict(self) == asdict(value) | ||
|
||
def __hash__(self) -> int: | ||
return hash(self.id) | ||
|
||
class Aggregates(ABC): | ||
|
||
@abstractmethod | ||
def create(self, id: str, modules: list[Module]) -> Aggregate:... | ||
|
||
@abstractmethod | ||
def get(self, id: str) -> Optional[Aggregate]:... | ||
|
||
@abstractmethod | ||
def patch(self, id: str, epochs: int):... | ||
|
||
@abstractmethod | ||
def remove(self, aggregate: Aggregate):... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from typing import Optional | ||
from typing import Any | ||
from abc import ABC, abstractmethod | ||
from dataclasses import dataclass | ||
|
||
@dataclass | ||
class Experiment: | ||
id: Optional[Any] | ||
name: str | ||
|
||
class Experiments(ABC): | ||
|
||
@abstractmethod | ||
def create(self, name: str) -> Experiment: ... | ||
|
||
@abstractmethod | ||
def read(self, name: str) -> Optional[Experiment]: ... | ||
|
||
@abstractmethod | ||
def delete(self, name: str): ... | ||
|
||
@abstractmethod | ||
def list(self) -> list[Experiment]: ... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from typing import Any | ||
from abc import ABC, abstractmethod | ||
from dataclasses import dataclass, asdict, field | ||
from mltracker.ports.modules import Module | ||
|
||
@dataclass | ||
class Dataset: | ||
hash: str | ||
name: str | ||
arguments: dict[str, Any] | ||
|
||
@dataclass | ||
class Iteration: | ||
hash: str | ||
phase: str | ||
epoch: int | ||
dataset: Dataset | ||
arguments: dict[str, Any] | ||
modules: list[Module] | ||
|
||
def __eq__(self, value: object) -> bool: | ||
if not isinstance(value, Iteration): | ||
return False | ||
return self.hash == value.hash | ||
|
||
def __hash__(self) -> int: | ||
return hash(self.hash) | ||
|
||
|
||
class Iterations(ABC): | ||
|
||
@abstractmethod | ||
def put(self, iteration: Iteration): ... | ||
|
||
@abstractmethod | ||
def list(self) -> list[Iteration]: ... | ||
|
||
@abstractmethod | ||
def clear(self): ... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from abc import ABC | ||
from abc import abstractmethod | ||
from typing import Any | ||
from dataclasses import dataclass, asdict | ||
|
||
@dataclass | ||
class Metric: | ||
name: str | ||
value: Any | ||
batch: int | ||
epoch: int | ||
phase: str | ||
|
||
class Metrics(ABC): | ||
|
||
@abstractmethod | ||
def add(self, metric: Metric): ... | ||
|
||
@abstractmethod | ||
def list(self) -> list[Metric]: ... | ||
|
||
@abstractmethod | ||
def clear(self): ... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from typing import Any | ||
from dataclasses import dataclass | ||
|
||
@dataclass | ||
class Module: | ||
type: str | ||
hash: str | ||
name: str | ||
arguments: dict[str, Any] | ||
|
||
def __eq__(self, value: object) -> bool: | ||
if not isinstance(value, Module): | ||
return False | ||
return self.hash == value.hash | ||
|
||
def __hash__(self) -> int: | ||
return hash(self.hash) |
Oops, something went wrong.