diff --git a/train.py b/train.py index 7aed721..81f5290 100644 --- a/train.py +++ b/train.py @@ -152,7 +152,7 @@ def main(): print('* check forward path...', end=' ') di = train_config.in_size do = train_config.out_size - dx = model.xp.zeros((args.batch_size, 3, di, di), dtype=np.float32) + dx = model.xp.zeros((args.batch_size, ch, di, di), dtype=np.float32) dy = model(dx) if dy.shape[2:] != (do, do): raise ValueError('Invlid output size\n'