Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More types supported #19

Merged
merged 5 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
## [UNRELEASED] neptune-fetcher 0.1.1
## [UNRELEASED] neptune-fetcher 0.2.0

### Features
- Added support for bool, state, datetime and float series ([#19](https://github.com/neptune-ai/neptune-fetcher/pull/19))
- Added support for fetching float series values ([#19](https://github.com/neptune-ai/neptune-fetcher/pull/19))

### Changes
- Using only paths filter endpoint instead of dedicated ones ([#17](https://github.com/neptune-ai/neptune-fetcher/pull/17))
Expand Down
73 changes: 73 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,79 @@ __Returns__:

---

### Datetime
#### `fetch()`
Retrieves value either from the internal cache (see [`prefetch()`](#prefetch)) or from the API.

__Example__:
```python
created_at = run["sys/creation_time"].fetch()
```

__Returns__:
`datetime.datetime`

---

### Object state
#### `fetch()`
Retrieves value either from the internal cache (see [`prefetch()`](#prefetch)) or from the API.

__Example__:
```python
state = run["sys/state"].fetch()
```

__Returns__:
`str`

---

### Boolean
#### `fetch()`
Retrieves value either from the internal cache (see [`prefetch()`](#prefetch)) or from the API.

__Example__:
```python
status = run["sys/failed"].fetch()
```

__Returns__:
`bool`

---

### Float series
#### `fetch()` or `fetch_last()`
Retrieves last series value either from the internal cache (see [`prefetch()`](#prefetch)) or from the API.

__Example__:
```python
loss = run["loss"].fetch_last()
```

__Returns__:
`Optional[float]`
normandy7 marked this conversation as resolved.
Show resolved Hide resolved

#### `fetch_values()`
Retrieves all series values from the API.

__Parameters__:

| Name | Type | Default | Description |
| ---- |--------|---------|----------------------------|
| `include_timestamp` | `bool` | True | Whether the fetched data should include the timestamp field. |

__Example__:
```python
values = run["loss"].fetch_values()
```

__Returns__:
`pandas.DataFrame`

---

## License

This project is licensed under the Apache License Version 2.0. For more details, see [Apache License Version 2.0](http://www.apache.org/licenses/LICENSE-2.0).
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pattern = "default-unprefixed"
python = "^3.7"

# Base neptune package
neptune = "2.0.0a1"
neptune = "2.0.0a2"

# Optional for default progress update handling
tqdm = { version = ">=4.66.0", optional = true }
Expand Down
44 changes: 33 additions & 11 deletions src/neptune_fetcher/fetchable.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from typing import (
TYPE_CHECKING,
Any,
Union,
)

from neptune.api.models import (
Expand All @@ -53,14 +54,20 @@
StringSeriesField,
StringSetField,
)
from neptune.internal.container_type import ContainerType
from neptune.internal.utils.logger import get_logger
from neptune.internal.warnings import NeptuneUnsupportedType

from neptune_fetcher.fields import (
Bool,
DateTime,
Field,
Float,
FloatSeries,
Integer,
ObjectState,
Series,
String,
StringSet,
)

if TYPE_CHECKING:
Expand All @@ -77,8 +84,14 @@
FieldType.INT,
FieldType.FLOAT,
FieldType.STRING,
FieldType.BOOL,
FieldType.DATETIME,
FieldType.STRING_SET,
FieldType.OBJECT_STATE,
}
SUPPORTED_SERIES_TYPES = {
FieldType.FLOAT_SERIES,
}
SUPPORTED_SERIES_TYPES = set()
SUPPORTED_TYPES = {*SUPPORTED_ATOMS, *SUPPORTED_SERIES_TYPES}


Expand Down Expand Up @@ -113,10 +126,19 @@ def fetch(self):

class FetchableSeries(Fetchable):
def fetch(self):
raise NeptuneUnsupportedType()
return self._cache[self._field.path].last

def fetch_last(self):
return self.fetch()

def fetch_values(self, *, include_timestamp: bool = True) -> "DataFrame":
raise NeptuneUnsupportedType()
return self._cache[self._field.path].fetch_values(
backend=self._backend,
container_id=self._container_id,
container_type=ContainerType.RUN,
path=self._field.path,
include_timestamp=include_timestamp,
)


def which_fetchable(field: FieldDefinition, *args: Any, **kwargs: Any) -> Fetchable:
Expand All @@ -127,30 +149,30 @@ def which_fetchable(field: FieldDefinition, *args: Any, **kwargs: Any) -> Fetcha
return NoopFetchable(field, *args, **kwargs)


class FieldToFetchableVisitor(FieldVisitor[Field]):
class FieldToFetchableVisitor(FieldVisitor[Union[Field, Series]]):
def visit_float(self, field: FloatField) -> Field:
return Float(field.type, val=field.value)

def visit_int(self, field: IntField) -> Field:
return Integer(field.type, val=field.value)

def visit_bool(self, field: BoolField) -> Field:
raise NotImplementedError
return Bool(field.type, val=field.value)

def visit_string(self, field: StringField) -> Field:
return String(field.type, val=field.value)

def visit_datetime(self, field: DateTimeField) -> Field:
raise NotImplementedError
return DateTime(field.type, val=field.value)

def visit_file(self, field: FileField) -> Field:
raise NotImplementedError

def visit_file_set(self, field: FileSetField) -> Field:
raise NotImplementedError

def visit_float_series(self, field: FloatSeriesField) -> Field:
raise NotImplementedError
def visit_float_series(self, field: FloatSeriesField) -> Series:
return FloatSeries(field.type, last=field.last)

def visit_string_series(self, field: StringSeriesField) -> Field:
raise NotImplementedError
Expand All @@ -159,13 +181,13 @@ def visit_image_series(self, field: ImageSeriesField) -> Field:
raise NotImplementedError

def visit_string_set(self, field: StringSetField) -> Field:
raise NotImplementedError
return StringSet(field.type, val=field.values)

def visit_git_ref(self, field: GitRefField) -> Field:
raise NotImplementedError

def visit_object_state(self, field: ObjectStateField) -> Field:
raise NotImplementedError
return ObjectState(field.type, val=field.value)

def visit_notebook_ref(self, field: NotebookRefField) -> Field:
raise NotImplementedError
Expand Down
102 changes: 96 additions & 6 deletions src/neptune_fetcher/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,47 +20,121 @@
"Integer",
"Series",
"String",
"Bool",
"DateTime",
"ObjectState",
"StringSet",
]

import typing
import abc
from abc import ABC
from dataclasses import dataclass
from datetime import datetime
from functools import partial
from typing import (
TYPE_CHECKING,
Dict,
Generic,
List,
Optional,
Set,
TypeVar,
Union,
)

from neptune.api.models import FieldType
from neptune.api.fetching_series_values import fetch_series_values
from neptune.api.models import (
FieldType,
FloatPointValue,
FloatSeriesValues,
StringPointValue,
)
from neptune.internal.container_type import ContainerType
from neptune.internal.utils.paths import parse_path

if typing.TYPE_CHECKING:
if TYPE_CHECKING:
from neptune.internal.backends.hosted_neptune_backend import HostedNeptuneBackend
from pandas import DataFrame

T = TypeVar("T")
Row = TypeVar("Row", StringPointValue, FloatPointValue)


def make_row(entry: Row, include_timestamp: bool = True) -> Dict[str, Union[str, float, datetime]]:
row: Dict[str, Union[str, float, datetime]] = {
"step": entry.step,
"value": entry.value,
}

if include_timestamp:
row["timestamp"] = entry.timestamp

return row


@dataclass
class Series(ABC, Generic[T]):
type: FieldType
last: Optional[T] = None

def fetch_values(
self,
backend: "HostedNeptuneBackend",
container_id: str,
container_type: ContainerType,
path: typing.List[str],
path: str,
include_timestamp: bool = True,
) -> "DataFrame":
raise NotImplementedError
import pandas as pd

data = fetch_series_values(
getter=partial(
self._fetch_values_from_backend,
backend=backend,
container_id=container_id,
container_type=container_type,
path=parse_path(path),
),
path=path,
progress_bar=None,
)

rows = dict((n, make_row(entry=entry, include_timestamp=include_timestamp)) for (n, entry) in enumerate(data))
return pd.DataFrame.from_dict(data=rows, orient="index")

@abc.abstractmethod
def _fetch_values_from_backend(
self,
backend: "HostedNeptuneBackend",
container_id: str,
container_type: ContainerType,
path: List[str],
limit: int,
from_step: Optional[float] = None,
):
...

def fetch_last(self) -> Optional[T]:
return self.last


class FloatSeries(Series[float]):
...
def _fetch_values_from_backend(
self,
backend: "HostedNeptuneBackend",
container_id: str,
container_type: ContainerType,
path: List[str],
limit: int,
from_step: Optional[float] = None,
) -> FloatSeriesValues:
return backend.get_float_series_values(
container_id=container_id,
container_type=container_type,
path=path,
from_step=from_step,
limit=limit,
)


@dataclass
Expand All @@ -82,3 +156,19 @@ class Float(Field[float]):

class String(Field[str]):
...


class Bool(Field[bool]):
...


class DateTime(Field[datetime]):
...


class ObjectState(Field[str]):
...


class StringSet(Field[Set[str]]):
...
Loading
Loading