[WIP] Improve gpu utilization πŸš‚ (#87)
Addresses #7
KarelZe authored Dec 22, 2022
1 parent c8771fe commit 85c44bc
Showing 4 changed files with 23 additions and 18 deletions.
6 changes: 3 additions & 3 deletions src/otc/models/callback.py
@@ -16,10 +16,10 @@

import optuna
import torch
import wandb
from catboost import CatBoostClassifier
from torch import nn

import wandb
from otc.config.config import settings
from otc.data.fs import fs
from otc.utils.colors import Colors
@@ -179,7 +179,7 @@ def on_train_end(
).as_posix()

fs.put(loc_training_stats, uri_training_stats)
m_artifact = wandb.Artifact(name=file_model, type="model")
m_artifact = wandb.Artifact(name=file_model, type="model") # type: ignore # noqa: E501

m_artifact.add_reference(uri_training_stats, name=file_training_stats)
logger.info(
@@ -199,7 +199,7 @@ def on_train_end(
with fs.open(uri_model, "wb") as f:
torch.save(model.state_dict(), f)

m_artifact = wandb.Artifact(name=file_model, type="model")
m_artifact = wandb.Artifact(name=file_model, type="model") # type: ignore # noqa: E501
else:
return

15 changes: 10 additions & 5 deletions src/otc/models/objective.py
@@ -180,8 +180,9 @@ def __call__(self, trial: optuna.Trial) -> float:
weight_decay: float = trial.suggest_float("weight_decay", 1e-6, 1e-1)
lr = trial.suggest_float("lr", 1e-6, 4e-3, log=False)
dropout = trial.suggest_float("dropout", 0, 0.5, step=0.1)
batch_size: int = trial.suggest_categorical("batch_size", [8192, 16384, 32768]) # type: ignore # noqa: E501
batch_size: int = trial.suggest_categorical("batch_size", [16192, 32768, 65536]) # type: ignore # noqa: E501

no_devices = torch.cuda.device_count()
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

@@ -193,7 +194,10 @@ def __call__(self, trial: optuna.Trial) -> float:
)

dl_kwargs: dict[str, Any] = {
"batch_size": batch_size,
"batch_size": batch_size
* max(
no_devices, 1
), # DataParallel splits each batch across devices along dim 0 (the batch dim).
"shuffle": False,
"device": device,
}
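
The scaling above keeps the Optuna-tuned value as the per-GPU batch: `torch.nn.DataParallel` chunks each input batch along dim 0 across the visible devices, so the DataLoader has to produce a batch that is `no_devices` times larger. A minimal sketch of that arithmetic (not part of this diff; the DataParallel wrapping itself is not shown in this hunk, and the values are illustrative):

```python
# Sketch only: how a loader batch scaled by the device count ends up as the tuned
# per-GPU batch under nn.DataParallel, which chunks inputs along dim 0.
import torch

batch_size = 16192                          # one of the choices suggested above
no_devices = torch.cuda.device_count()      # 0 on a CPU-only machine
global_batch = batch_size * max(no_devices, 1)

x = torch.randn(global_batch, 10)           # one batch as produced by the DataLoader
chunks = torch.chunk(x, max(no_devices, 1), dim=0)
print([c.shape[0] for c in chunks])         # roughly batch_size rows per device
```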
@@ -373,11 +377,11 @@ def __call__(self, trial: optuna.Trial) -> float:

weight_decay: float = trial.suggest_float("weight_decay", 1e-6, 1e-1)
lr = trial.suggest_float("lr", 1e-6, 4e-3, log=False)
bs = [8192, 16384, 32768]
batch_size: int = trial.suggest_categorical("batch_size", bs) # type: ignore
batch_size: int = trial.suggest_categorical("batch_size", [16192, 32768, 65536]) # type: ignore # noqa: E501

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
no_devices = torch.cuda.device_count()

training_data = TabDataset(
self.x_train, self.y_train, self._cat_features, self._cat_cardinalities
@@ -387,7 +391,8 @@ def __call__(self, trial: optuna.Trial) -> float:
)

dl_kwargs: dict[str, Any] = {
"batch_size": batch_size,
"batch_size": batch_size
* max(no_devices, 1), # DataParallel splits batches across devices
"shuffle": False,
"device": device,
}
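
The second objective applies the same scaling. For context, a hedged sketch of how such a scaled batch would be consumed once the network is wrapped in `nn.DataParallel`; the wrapping call is an assumption and not part of this hunk, and the module below is a stand-in for the real model:

```python
# Assumed usage, not from this diff: wrap the module so each forward pass scatters
# the scaled batch across GPUs and gathers the outputs on the default device.
import torch
from torch import nn

model = nn.Linear(10, 2)                    # placeholder for the actual network
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)          # replicas each see a dim-0 chunk of the batch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

features = torch.randn(65536, 10).to(device)   # one scaled batch from the loader
logits = model(features)
print(logits.shape)                         # (65536, 2), gathered back on `device`
```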
14 changes: 7 additions & 7 deletions src/otc/models/tabtransformer.py
@@ -15,7 +15,6 @@

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import einsum, nn

from otc.models.activation import GeGLU
@@ -169,16 +168,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
Returns:
torch.Tensor: output tensor.
"""
h = self.heads
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
b, n, _ = q.shape
# reshape and permute: b n (h d) -> b h n d
q, k, v = map(
lambda t: t.reshape(b, n, self.heads, -1).permute(0, 2, 1, 3), (q, k, v)
)
sim = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale

attn = sim.softmax(dim=-1)
attn = self.dropout(attn)

out = einsum("b h i j, b h j d -> b h i d", attn, v)
out = rearrange(out, "b h n d -> b n (h d)", h=h)
# permute and reshape back: b h n d -> b n (h d)
out = out.permute(0, 2, 1, 3).reshape(b, n, -1)
return self.to_out(out)
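
The forward pass above swaps the einops `rearrange` calls for native `reshape`/`permute`, matching the removal of the `from einops import rearrange` import in the hunk above. A quick equivalence check (a sketch, not part of the diff; einops is imported here only to compare against the old path):

```python
# Sanity sketch: the reshape/permute pattern reproduces the replaced rearrange calls.
import torch
from einops import rearrange

b, n, heads, d = 2, 5, 4, 8
t = torch.randn(b, n, heads * d)

# old: rearrange(t, "b n (h d) -> b h n d", h=heads)
old = rearrange(t, "b n (h d) -> b h n d", h=heads)
new = t.reshape(b, n, heads, -1).permute(0, 2, 1, 3)
assert torch.equal(old, new)

# old: rearrange(out, "b h n d -> b n (h d)")
out = torch.randn(b, heads, n, d)
assert torch.equal(
    rearrange(out, "b h n d -> b n (h d)"),
    out.permute(0, 2, 1, 3).reshape(b, n, -1),
)
```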


@@ -414,7 +415,6 @@ def __init__(
)

# mlp to logits

input_size = (dim * self.num_categories) + num_continuous
j = input_size // 8

6 changes: 3 additions & 3 deletions tests/test_objective.py
@@ -33,7 +33,7 @@ def setup(self) -> None:
"""
self._old_cwd = os.getcwd()
start = dt.datetime(2020, 1, 1).replace(tzinfo=dt.timezone.utc)
end = dt.datetime(2020, 12, 31).replace(tzinfo=dt.timezone.utc)
end = dt.datetime(2021, 12, 31).replace(tzinfo=dt.timezone.utc)
index = pd.date_range(start=start, end=end, freq="15min")

# make 1 const feature and 1 non-const feature, as catboost requires non-const
@@ -111,7 +111,7 @@ def test_fttransformer_objective(self) -> None:
"ffn_dropout": 0.1,
"weight_decay": 0.01,
"learning_rate": 3e-4,
"batch_size": 8192,
"batch_size": 16192,
}

study = optuna.create_study(direction="maximize")
@@ -145,7 +145,7 @@ def test_tabtransformer_objective(self) -> None:
"dropout": 0.1,
"weight_decay": 0.01,
"learning_rate": 3e-4,
"batch_size": 8192,
"batch_size": 16192,
}

study = optuna.create_study(direction="maximize")
