-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
62 lines (47 loc) · 1.75 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from config import config
from utils import logger
import torch
import lightning as L
import lightning.pytorch as pl
from src import SpriteLightning, SpriteDataModule
from utils import make_clear_directory
# Runs once at import time: select PyTorch's 'medium' float32 matmul precision
# for the whole process (see the torch docs for the speed/accuracy trade-off).
torch.set_float32_matmul_precision('medium')
def train():
    """Fit ``SpriteLightning`` on ``SpriteDataModule`` with a config-driven Trainer.

    Every trainer option (epoch count, logging cadence, validation frequency,
    gradient accumulation, sanity-check steps, ``fast_dev_run``,
    ``overfit_batches``) is read from ``config.train``; metrics are written by
    a ``CSVLogger`` under ``config.paths.output.logs``.

    After fitting, the best checkpoint path is logged when one is available.
    """
    # Informational only: the Trainer picks its own hardware through
    # ``accelerator="auto"`` / ``devices="auto"`` below.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")

    dm = SpriteDataModule()
    light = SpriteLightning()

    trainer = pl.Trainer(
        default_root_dir=config.paths.roots.output,
        logger=L.pytorch.loggers.CSVLogger(save_dir=config.paths.output.logs),
        devices='auto',
        accelerator="auto",
        max_epochs=config.train.max_epochs,
        log_every_n_steps=config.train.log_every_n_steps,
        check_val_every_n_epoch=config.train.check_val_every_n_epoch,
        accumulate_grad_batches=config.train.accumulate_grad_batches,
        num_sanity_val_steps=config.train.num_sanity_val_steps,
        enable_model_summary=False,
        fast_dev_run=config.train.fast_dev_run,
        overfit_batches=config.train.overfit_batches,
    )

    # NOTE: to resume an interrupted run, pass ``ckpt_path=<checkpoint file>``
    # (e.g. './output/checkpoints/last-v1.ckpt') to ``fit``.
    trainer.fit(
        light,
        datamodule=dm,
    )

    # ``trainer.checkpoint_callback`` is None when no ModelCheckpoint callback
    # is attached -- Lightning disables checkpoint callbacks for
    # ``fast_dev_run`` -- so guard the lookup instead of dereferencing it.
    # noinspection PyUnresolvedReferences
    best_model_path = getattr(trainer.checkpoint_callback, "best_model_path", "")
    if best_model_path:
        logger.info(f"Best model path : {best_model_path}")
def prep_directories():
    """Run ``make_clear_directory`` on each output directory used by a run.

    The log, validation-image and test-image paths are read from
    ``config.paths.output`` and processed in that order.
    NOTE(review): ``make_clear_directory`` presumably creates the directory
    and/or empties it -- confirm against its definition in ``utils``.
    """
    logger.info("Clearing Directories")
    for attr_name in ("logs", "val_images", "test_images"):
        make_clear_directory(getattr(config.paths.output, attr_name))
def main():
    """Script entry point: free cached CUDA memory, prepare outputs, train."""
    # Step 1: ask PyTorch to release any cached GPU memory.
    torch.cuda.empty_cache()

    # Step 2: get the output directories ready before anything is written.
    prep_directories()

    # Step 3: run the training loop.
    train()
# Only start training when executed as a script, not when imported.
if __name__ == "__main__":
    main()