Skip to content

Commit

Permalink
More types supported (#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
Raalsky authored Apr 18, 2024
1 parent f511153 commit f645534
Show file tree
Hide file tree
Showing 8 changed files with 440 additions and 44 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
## [UNRELEASED] neptune-fetcher 0.1.1
## [UNRELEASED] neptune-fetcher 0.2.0

### Features
- Added support for bool, object state, and datetime fields, and for float series ([#19](https://github.com/neptune-ai/neptune-fetcher/pull/19))
- Added support for fetching float series values ([#19](https://github.com/neptune-ai/neptune-fetcher/pull/19))

### Changes
- Using only paths filter endpoint instead of dedicated ones ([#17](https://github.com/neptune-ai/neptune-fetcher/pull/17))
Expand Down
73 changes: 73 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,79 @@ __Returns__:

---

### Datetime
#### `fetch()`
Retrieves the value either from the internal cache (see [`prefetch()`](#prefetch)) or from the API.

__Example__:
```python
created_at = run["sys/creation_time"].fetch()
```

__Returns__:
`datetime.datetime`

---

### Object state
#### `fetch()`
Retrieves the value either from the internal cache (see [`prefetch()`](#prefetch)) or from the API.

__Example__:
```python
state = run["sys/state"].fetch()
```

__Returns__:
`str`

---

### Boolean
#### `fetch()`
Retrieves the value either from the internal cache (see [`prefetch()`](#prefetch)) or from the API.

__Example__:
```python
status = run["sys/failed"].fetch()
```

__Returns__:
`bool`

---

### Float series
#### `fetch()` or `fetch_last()`
Retrieves the last value of the series either from the internal cache (see [`prefetch()`](#prefetch)) or from the API.

__Example__:
```python
loss = run["loss"].fetch_last()
```

__Returns__:
`Optional[float]`

#### `fetch_values()`
Retrieves all series values from the API.

__Parameters__:

| Name | Type | Default | Description |
| ---- |--------|---------|----------------------------|
| `include_timestamp` | `bool` | True | Whether the fetched data should include the timestamp field. |

__Example__:
```python
values = run["loss"].fetch_values()
```

__Returns__:
`pandas.DataFrame`

---

## License

This project is licensed under the Apache License Version 2.0. For more details, see [Apache License Version 2.0](http://www.apache.org/licenses/LICENSE-2.0).
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pattern = "default-unprefixed"
python = "^3.7"

# Base neptune package
neptune = "2.0.0a1"
neptune = "2.0.0a2"

# Optional for default progress update handling
tqdm = { version = ">=4.66.0", optional = true }
Expand Down
44 changes: 33 additions & 11 deletions src/neptune_fetcher/fetchable.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from typing import (
TYPE_CHECKING,
Any,
Union,
)

from neptune.api.models import (
Expand All @@ -53,14 +54,20 @@
StringSeriesField,
StringSetField,
)
from neptune.internal.container_type import ContainerType
from neptune.internal.utils.logger import get_logger
from neptune.internal.warnings import NeptuneUnsupportedType

from neptune_fetcher.fields import (
Bool,
DateTime,
Field,
Float,
FloatSeries,
Integer,
ObjectState,
Series,
String,
StringSet,
)

if TYPE_CHECKING:
Expand All @@ -77,8 +84,14 @@
FieldType.INT,
FieldType.FLOAT,
FieldType.STRING,
FieldType.BOOL,
FieldType.DATETIME,
FieldType.STRING_SET,
FieldType.OBJECT_STATE,
}
SUPPORTED_SERIES_TYPES = {
FieldType.FLOAT_SERIES,
}
SUPPORTED_SERIES_TYPES = set()
SUPPORTED_TYPES = {*SUPPORTED_ATOMS, *SUPPORTED_SERIES_TYPES}


Expand Down Expand Up @@ -113,10 +126,19 @@ def fetch(self):

class FetchableSeries(Fetchable):
    """Fetchable wrapper around a series field.

    All lookups go through the cache entry stored under this field's path
    (``self._cache[self._field.path]``).
    """

    def fetch(self):
        """Return the ``last`` value of the cached series entry for this field."""
        return self._cache[self._field.path].last

    def fetch_last(self):
        """Return the last value of the series; equivalent to :meth:`fetch`."""
        return self.fetch()

    def fetch_values(self, *, include_timestamp: bool = True) -> "DataFrame":
        """Fetch the values of the series as a data frame.

        The call is delegated to the cached series entry's ``fetch_values``.

        Args:
            include_timestamp: Forwarded to the series entry; per the README,
                controls whether the result includes the timestamp field.

        Returns:
            Whatever the series entry's ``fetch_values`` returns (annotated as
            a pandas ``DataFrame``).
        """
        # NOTE(review): the container type is hard-coded to RUN here.
        return self._cache[self._field.path].fetch_values(
            backend=self._backend,
            container_id=self._container_id,
            container_type=ContainerType.RUN,
            path=self._field.path,
            include_timestamp=include_timestamp,
        )


def which_fetchable(field: FieldDefinition, *args: Any, **kwargs: Any) -> Fetchable:
Expand All @@ -127,30 +149,30 @@ def which_fetchable(field: FieldDefinition, *args: Any, **kwargs: Any) -> Fetcha
return NoopFetchable(field, *args, **kwargs)


class FieldToFetchableVisitor(FieldVisitor[Field]):
class FieldToFetchableVisitor(FieldVisitor[Union[Field, Series]]):
def visit_float(self, field: FloatField) -> Field:
return Float(field.type, val=field.value)

def visit_int(self, field: IntField) -> Field:
return Integer(field.type, val=field.value)

def visit_bool(self, field: BoolField) -> Field:
raise NotImplementedError
return Bool(field.type, val=field.value)

def visit_string(self, field: StringField) -> Field:
return String(field.type, val=field.value)

def visit_datetime(self, field: DateTimeField) -> Field:
raise NotImplementedError
return DateTime(field.type, val=field.value)

def visit_file(self, field: FileField) -> Field:
raise NotImplementedError

def visit_file_set(self, field: FileSetField) -> Field:
raise NotImplementedError

def visit_float_series(self, field: FloatSeriesField) -> Field:
raise NotImplementedError
def visit_float_series(self, field: FloatSeriesField) -> Series:
return FloatSeries(field.type, last=field.last)

def visit_string_series(self, field: StringSeriesField) -> Field:
raise NotImplementedError
Expand All @@ -159,13 +181,13 @@ def visit_image_series(self, field: ImageSeriesField) -> Field:
raise NotImplementedError

def visit_string_set(self, field: StringSetField) -> Field:
raise NotImplementedError
return StringSet(field.type, val=field.values)

def visit_git_ref(self, field: GitRefField) -> Field:
raise NotImplementedError

def visit_object_state(self, field: ObjectStateField) -> Field:
raise NotImplementedError
return ObjectState(field.type, val=field.value)

def visit_notebook_ref(self, field: NotebookRefField) -> Field:
raise NotImplementedError
Expand Down
102 changes: 96 additions & 6 deletions src/neptune_fetcher/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,47 +20,121 @@
"Integer",
"Series",
"String",
"Bool",
"DateTime",
"ObjectState",
"StringSet",
]

import typing
import abc
from abc import ABC
from dataclasses import dataclass
from datetime import datetime
from functools import partial
from typing import (
TYPE_CHECKING,
Dict,
Generic,
List,
Optional,
Set,
TypeVar,
Union,
)

from neptune.api.models import FieldType
from neptune.api.fetching_series_values import fetch_series_values
from neptune.api.models import (
FieldType,
FloatPointValue,
FloatSeriesValues,
StringPointValue,
)
from neptune.internal.container_type import ContainerType
from neptune.internal.utils.paths import parse_path

if typing.TYPE_CHECKING:
if TYPE_CHECKING:
from neptune.internal.backends.hosted_neptune_backend import HostedNeptuneBackend
from pandas import DataFrame

T = TypeVar("T")
Row = TypeVar("Row", StringPointValue, FloatPointValue)


def make_row(entry: Row, include_timestamp: bool = True) -> Dict[str, Union[str, float, datetime]]:
    """Convert a single series point into a flat dict suitable for a table row.

    Args:
        entry: Point exposing ``step`` and ``value`` attributes, and
            ``timestamp`` when ``include_timestamp`` is truthy.
        include_timestamp: Whether to add the ``"timestamp"`` key.

    Returns:
        A dict with ``"step"`` and ``"value"`` keys, followed by
        ``"timestamp"`` when requested.
    """
    # The timestamp attribute is only read when it was asked for.
    wanted = ("step", "value", "timestamp") if include_timestamp else ("step", "value")
    return {name: getattr(entry, name) for name in wanted}


@dataclass
class Series(ABC, Generic[T]):
    """Abstract base for series fields.

    Holds the field type and the last value of the series; the full list of
    values is fetched on demand through :meth:`fetch_values`. Subclasses
    provide the backend call in :meth:`_fetch_values_from_backend`.
    """

    # Field type of the series.
    type: FieldType
    # Last value of the series, if any.
    last: Optional[T] = None

    def fetch_values(
        self,
        backend: "HostedNeptuneBackend",
        container_id: str,
        container_type: ContainerType,
        path: str,
        include_timestamp: bool = True,
    ) -> "DataFrame":
        """Fetch the series values and return them as a pandas ``DataFrame``.

        Args:
            backend: Backend handed to :meth:`_fetch_values_from_backend`.
            container_id: Identifier of the container owning the series.
            container_type: Type of the container owning the series.
            path: Field path as a string; it is parsed with ``parse_path``
                before being handed to the backend getter.
            include_timestamp: Whether each row gets a ``"timestamp"`` column.

        Returns:
            A data frame with one row per fetched entry, indexed by the
            entry's position in the fetched sequence, with ``"step"`` and
            ``"value"`` columns (plus ``"timestamp"`` when requested).
        """
        # pandas is imported lazily: at module level it is only imported
        # under TYPE_CHECKING.
        import pandas as pd

        data = fetch_series_values(
            # The getter is pre-bound with everything except the arguments
            # that fetch_series_values supplies itself (presumably ``limit``
            # and ``from_step`` - confirm against its documentation).
            getter=partial(
                self._fetch_values_from_backend,
                backend=backend,
                container_id=container_id,
                container_type=container_type,
                path=parse_path(path),
            ),
            path=path,
            progress_bar=None,
        )

        rows = {
            n: make_row(entry=entry, include_timestamp=include_timestamp)
            for n, entry in enumerate(data)
        }
        return pd.DataFrame.from_dict(data=rows, orient="index")

    @abc.abstractmethod
    def _fetch_values_from_backend(
        self,
        backend: "HostedNeptuneBackend",
        container_id: str,
        container_type: ContainerType,
        path: List[str],
        limit: int,
        from_step: Optional[float] = None,
    ):
        """Fetch series values from the backend; implemented by subclasses.

        Args:
            backend: Backend to query.
            container_id: Identifier of the container owning the series.
            container_type: Type of the container owning the series.
            path: Field path already split into its components.
            limit: Presumably the maximum number of values per call - confirm.
            from_step: Presumably the step to start fetching from - confirm.
        """
        ...

    def fetch_last(self) -> Optional[T]:
        """Return the stored last value of the series (``None`` if absent)."""
        return self.last


class FloatSeries(Series[float]):
    """Series whose values are floats."""

    def _fetch_values_from_backend(
        self,
        backend: "HostedNeptuneBackend",
        container_id: str,
        container_type: ContainerType,
        path: List[str],
        limit: int,
        from_step: Optional[float] = None,
    ) -> FloatSeriesValues:
        """Forward the request to ``backend.get_float_series_values``.

        All arguments are passed through unchanged as keyword arguments.
        """
        request = {
            "container_id": container_id,
            "container_type": container_type,
            "path": path,
            "from_step": from_step,
            "limit": limit,
        }
        return backend.get_float_series_values(**request)


@dataclass
Expand All @@ -82,3 +156,19 @@ class Float(Field[float]):

class String(Field[str]):
...


class Bool(Field[bool]):
    """Field specialised for ``bool`` values."""


class DateTime(Field[datetime]):
    """Field specialised for ``datetime.datetime`` values."""


class ObjectState(Field[str]):
    """Field specialised for object state, represented as a ``str``."""


class StringSet(Field[Set[str]]):
    """Field specialised for a set of strings."""
Loading

0 comments on commit f645534

Please sign in to comment.