Skip to content
This repository has been archived by the owner on Sep 9, 2024. It is now read-only.

Commit

Permalink
Fix class balancing weights not being taken into account (#15)
Browse files Browse the repository at this point in the history
Co-authored-by: Ayush Pandey <[email protected]>
  • Loading branch information
Ayush-iitkgp and Ayush Pandey authored Jun 26, 2020
1 parent 64af478 commit 7d171e8
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,11 +275,11 @@ def convert_from_color(x):
prediction = clf.predict(img.reshape(-1, N_BANDS))
prediction = prediction.reshape(img.shape[:2])
else:
# Neural network
model, optimizer, loss, hyperparams = get_model(MODEL, **hyperparams)
if CLASS_BALANCING:
weights = compute_imf_weights(train_gt, N_CLASSES, IGNORED_LABELS)
hyperparams['weights'] = torch.from_numpy(weights)
# Neural network
model, optimizer, loss, hyperparams = get_model(MODEL, **hyperparams)
# Split train set in train/val
train_gt, val_gt = sample_gt(train_gt, 0.95, mode='random')
# Generate the dataset
Expand Down

0 comments on commit 7d171e8

Please sign in to comment.