[pre-commit.ci] pre-commit autoupdate #688

Open · wants to merge 3 commits into main
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
@@ -1,6 +1,6 @@
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.6.0 # Use the ref you want to point at
+    rev: v5.0.0 # Use the ref you want to point at
     hooks:
       - id: trailing-whitespace
       - id: check-ast
@@ -16,7 +16,7 @@ repos:
       - id: check-toml

   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: 'v0.6.5'
+    rev: 'v0.8.0'
     hooks:
       - id: ruff
         args: [--fix, --exit-non-zero-on-fix]
@@ -25,7 +25,7 @@ repos:
         types_or: [python, jupyter]

   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.11.2
+    rev: v1.13.0
     hooks:
       - id: mypy
         entry: python3 -m mypy --config-file pyproject.toml
@@ -41,7 +41,7 @@ repos:
         entry: python3 -m nbstripout

   - repo: https://github.com/nbQA-dev/nbQA
-    rev: 1.8.7
+    rev: 1.9.1
     hooks:
       - id: nbqa-black
       - id: nbqa-ruff
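These four bumps (pre-commit-hooks v4.6.0 → v5.0.0, ruff-pre-commit v0.6.5 → v0.8.0, mirrors-mypy v1.11.2 → v1.13.0, nbQA 1.8.7 → 1.9.1) are standard pre-commit.ci autoupdate changes: each rev pins a hook repository to a tag, and autoupdate moves the pin to the latest release. A minimal sketch of how a reviewer might re-run the updated hooks locally, assuming the pre-commit CLI is installed in the environment:

    import subprocess

    # Run every configured hook against every file in the repo; with
    # check=True a non-zero exit raises CalledProcessError, meaning some
    # hook failed under the new versions.
    subprocess.run(["pre-commit", "run", "--all-files"], check=True)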
13 changes: 8 additions & 5 deletions cyclops/data/features/medical_image.py
@@ -209,11 +209,14 @@ def decode_example(
             use_auth_token = token_per_repo_id.get(repo_id)
         except ValueError:
             use_auth_token = None
-        with xopen(
-            path,
-            "rb",
-            use_auth_token=use_auth_token,
-        ) as file_obj, BytesIO(file_obj.read()) as buffer:
+        with (
+            xopen(
+                path,
+                "rb",
+                use_auth_token=use_auth_token,
+            ) as file_obj,
+            BytesIO(file_obj.read()) as buffer,
+        ):
             image, metadata = self._read_file_from_bytes(buffer)
             metadata["filename_or_obj"] = path

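This hunk is purely stylistic: the two context managers are wrapped in parentheses so each sits on its own line, a form officially supported since Python 3.10 and the layout newer ruff releases format toward. The same refactor is applied to the two formatted_as contexts in cyclops/models/wrappers/pt_model.py below. A standalone sketch of the before/after, using a plain open instead of the repo's xopen (the file name is hypothetical):

    from io import BytesIO

    path = "example.bin"  # hypothetical scratch file for the demo
    with open(path, "wb") as f:
        f.write(b"\x00\x01\x02")

    # Old spelling: both managers share one header line, so long argument
    # lists force awkward continuations.
    with open(path, "rb") as file_obj, BytesIO(file_obj.read()) as buffer:
        assert buffer.read() == b"\x00\x01\x02"

    # New spelling: parenthesized context managers, one per line, with a
    # trailing comma; behavior is identical.
    with (
        open(path, "rb") as file_obj,
        BytesIO(file_obj.read()) as buffer,
    ):
        assert buffer.read() == b"\x00\x01\x02"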
8 changes: 5 additions & 3 deletions cyclops/evaluate/utils.py
@@ -159,9 +159,11 @@ def get_columns_as_array(
     if isinstance(columns, str):
         columns = [columns]

-    with dataset.formatted_as("arrow", columns=columns, output_all_columns=True) if (
-        isinstance(dataset, Dataset) and dataset.format != "arrow"
-    ) else nullcontext():
+    with (
+        dataset.formatted_as("arrow", columns=columns, output_all_columns=True)
+        if (isinstance(dataset, Dataset) and dataset.format != "arrow")
+        else nullcontext()
+    ):
         out_arr = squeeze_all(
             xp.stack(
                 [xp.asarray(dataset[col].to_pylist()) for col in columns], axis=-1
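Behavior is unchanged here as well: the dataset is temporarily put into arrow format only when it is a Dataset not already in that format, with contextlib.nullcontext() as the no-op branch; moving the conditional expression inside the parentheses just makes the with header readable. A standalone sketch of the conditional-context-manager pattern, with hypothetical names not taken from the repo:

    from contextlib import nullcontext
    from threading import Lock

    _lock = Lock()

    def increment(counter: dict, thread_safe: bool = True) -> dict:
        # Acquire the lock only when thread safety is requested; in the
        # other branch nullcontext() keeps the with statement a valid no-op.
        with _lock if thread_safe else nullcontext():
            counter["n"] = counter.get("n", 0) + 1
        return counter

    print(increment({}, thread_safe=False))  # {'n': 1}, no lock taken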
21 changes: 12 additions & 9 deletions cyclops/models/wrappers/pt_model.py
@@ -968,14 +968,17 @@ def fit(
             splits_mapping["validation"] = val_split

         format_kwargs = {} if transforms is None else {"transform": transforms}
-        with X[train_split].formatted_as(
-            "custom" if transforms is not None else "torch",
-            columns=feature_columns + target_columns,
-            **format_kwargs,
-        ), X[val_split].formatted_as(
-            "custom" if transforms is not None else "torch",
-            columns=feature_columns + target_columns,
-            **format_kwargs,
-        ):
+        with (
+            X[train_split].formatted_as(
+                "custom" if transforms is not None else "torch",
+                columns=feature_columns + target_columns,
+                **format_kwargs,
+            ),
+            X[val_split].formatted_as(
+                "custom" if transforms is not None else "torch",
+                columns=feature_columns + target_columns,
+                **format_kwargs,
+            ),
+        ):
             self.partial_fit(
                 X,
@@ -1309,7 +1312,7 @@ def save_model(self, filepath: str, overwrite: bool = True, **kwargs):
         if include_lr_scheduler:
             state_dict["lr_scheduler"] = self.lr_scheduler_.state_dict() # type: ignore[attr-defined]

-        epoch = kwargs.get("epoch", None)
+        epoch = kwargs.get("epoch")
         if epoch is not None:
             filename, extension = os.path.basename(filepath).split(".")
             filepath = join(
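The save_model tweak is a no-op cleanup: dict.get already returns None for a missing key, so the explicit default is redundant. A quick check of the equivalence:

    kwargs = {"lr": 0.01}

    # dict.get defaults to None when the key is absent, so both spellings
    # yield the same value.
    assert kwargs.get("epoch", None) is None
    assert kwargs.get("epoch") is None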