diff --git a/main.py b/main.py index ff978f6..3fef992 100644 --- a/main.py +++ b/main.py @@ -275,11 +275,11 @@ def convert_from_color(x): prediction = clf.predict(img.reshape(-1, N_BANDS)) prediction = prediction.reshape(img.shape[:2]) else: - # Neural network - model, optimizer, loss, hyperparams = get_model(MODEL, **hyperparams) if CLASS_BALANCING: weights = compute_imf_weights(train_gt, N_CLASSES, IGNORED_LABELS) hyperparams['weights'] = torch.from_numpy(weights) + # Neural network + model, optimizer, loss, hyperparams = get_model(MODEL, **hyperparams) # Split train set in train/val train_gt, val_gt = sample_gt(train_gt, 0.95, mode='random') # Generate the dataset