Skip to content

Commit

Permalink
Merge pull request #4 from google/fix_unittest
Browse files Browse the repository at this point in the history
cleanup after tests to prevent oom
  • Loading branch information
entrpn authored Feb 21, 2024
2 parents 253e793 + 22d4afd commit b3f1414
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 3 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/UnitTests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ on:
branches: [ "main" ]
workflow_dispatch:
schedule:
# Run the job every 30 mins
- cron: '*/30 * * * *'
# Run the job every 2 hours
- cron: '0 */2 * * *'

jobs:
build:
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
limitations under the License.
-->

[![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/sshahrokhi/maxdiffusion/actions/workflows/UnitTests.yml)
[![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/google/maxdiffusion/actions/workflows/UnitTests.yml)

# Overview

Expand Down
11 changes: 11 additions & 0 deletions src/maxdiffusion/tests/input_pipeline_interface_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
limitations under the License.
"""
import os
import pathlib
import shutil
import unittest
from absl.testing import absltest

Expand All @@ -25,7 +27,12 @@
from maxdiffusion.input_pipeline.input_pipeline_interface import make_pokemon_train_iterator
from maxdiffusion import FlaxStableDiffusionPipeline

HOME_DIR = pathlib.Path.home()
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
DATASET_DIR = str(HOME_DIR / ".cache" / "huggingface" / "datasets")

def cleanup(output_dir):
shutil.rmtree(output_dir)

class InputPipelineInterface(unittest.TestCase):
"""Test Unet sharding"""
Expand Down Expand Up @@ -65,6 +72,8 @@ def test_make_pokemon_train_iterator(self):
assert data["input_ids"].shape == (device_count,77)
assert data["pixel_values"].shape == (device_count, 3, config.resolution, config.resolution)

cleanup(DATASET_DIR)

def test_make_pokemon_train_iterator_w_latents_caching(self):
pyconfig.initialize([None,os.path.join(THIS_DIR,'..','configs','base21.yml'),
"pretrained_model_name_or_path=stabilityai/stable-diffusion-2-1",
Expand Down Expand Up @@ -104,5 +113,7 @@ def test_make_pokemon_train_iterator_w_latents_caching(self):
config.resolution // vae_scale_factor,
config.resolution // vae_scale_factor)

cleanup(DATASET_DIR)

if __name__ == '__main__':
absltest.main()
6 changes: 6 additions & 0 deletions src/maxdiffusion/tests/train_smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

""" Smoke test """
import os
import pathlib
import shutil
import unittest
from maxdiffusion.models.train import main as train_main
Expand All @@ -27,6 +28,7 @@
import numpy as np
from PIL import Image

HOME_DIR = pathlib.Path.home()
THIS_DIR = os.path.dirname(os.path.abspath(__file__))

def cleanup(output_dir):
Expand Down Expand Up @@ -64,6 +66,8 @@ def test_sd21_config(self):
assert ssim_compare >=0.70

cleanup(output_dir)
dataset_dir = str(HOME_DIR / ".cache" / "huggingface" / "datasets")
cleanup(dataset_dir)

def test_sd_2_base_config(self):
output_dir="train-smoke-test"
Expand All @@ -89,6 +93,8 @@ def test_sd_2_base_config(self):
assert ssim_compare >=0.70

cleanup(output_dir)
dataset_dir = str(HOME_DIR / ".cache" / "huggingface" / "datasets")
cleanup(dataset_dir)

if __name__ == '__main__':
absltest.main()

0 comments on commit b3f1414

Please sign in to comment.