diff --git a/src/maxdiffusion/eval.py b/src/maxdiffusion/eval.py index 3c052892..ed918c8c 100644 --- a/src/maxdiffusion/eval.py +++ b/src/maxdiffusion/eval.py @@ -38,6 +38,7 @@ from keras.preprocessing.image import ImageDataGenerator from tqdm import tqdm from PIL import Image +cc.initialize_cache(os.path.expanduser("~/jax_cache")) def load_captions(file_path): captions_df = pd.read_csv(file_path, delimiter='\t', header=0, names=['image_id','id', 'caption'])