Use new .to() function for CPU/GPU agnostic computing
PyTorch 0.4 introduced the new device API to simplify tensor storage management.
We removed all calls to the old .cuda() function and replaced them with calls to the new .to() method.
  * All functions that previously took a "cuda=" keyword argument now take a "device=" argument that expects either a torch.device object or a 'cpu'/'cuda' string when such an object is not available.
  * The --cuda CLI argument now expects an integer: -1 runs the computation on CPU (the default if the argument is omitted); otherwise, the value is the ordinal of the GPU on which to perform the computation (see the usage sketch below).
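For illustration, here is a minimal sketch of the new convention (not code from this commit; the argument parser, tensor, and Linear module below are placeholders): the integer --cuda flag is mapped to a torch.device, and tensors and models are then moved with the same .to() call whether the target is the CPU or a GPU.

import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument('--cuda', type=int, default=-1,
                    help="GPU ordinal, or -1 to run on CPU")
args = parser.parse_args()

# Map the integer flag to a torch.device (assumes a GPU is present when --cuda >= 0)
if args.cuda < 0:
    device = torch.device('cpu')
else:
    device = torch.device('cuda:{}'.format(args.cuda))

# Tensors and modules are moved with the same .to() call on CPU or GPU
x = torch.randn(4, 8).to(device)
model = torch.nn.Linear(8, 2).to(device)
output = model(x).to('cpu')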
nshaud committed Sep 24, 2018
1 parent eb98f09 commit 2eb12bc
Showing 2 changed files with 31 additions and 41 deletions.
main.py: 27 changes (14 additions & 13 deletions)
@@ -62,8 +62,8 @@
 parser.add_argument('--folder', type=str, help="Folder where to store the "
                     "datasets (defaults to the current working directory).",
                     default="./Datasets/")
-parser.add_argument('--cuda', action='store_true',
-                    help="Use CUDA (defaults to false)")
+parser.add_argument('--cuda', type=int, default=-1,
+                    help="Specify CUDA device (defaults to -1, which learns on CPU)")
 parser.add_argument('--runs', type=int, default=1, help="Number of runs (default: 1)")
 parser.add_argument('--restore', type=str, default=None,
                     help="Weights to use for initialization, e.g. a checkpoint")
@@ -121,7 +121,13 @@
 args = parser.parse_args()

 # Use GPU ?
-CUDA = args.cuda
+if args.cuda < 0:
+    print("Computation on CPU")
+    CUDA_DEVICE = torch.device('cpu')
+else:
+    print("Computation on CUDA GPU device {}".format(args.cuda))
+    CUDA_DEVICE = torch.device('cuda:{}'.format(args.cuda))
+
 # % of training samples
 SAMPLE_PERCENTAGE = args.training_sample / 100
 # Data augmentation ?
@@ -168,11 +174,6 @@
 print("Visdom is not connected. Did you run 'python -m visdom.server' ?")


-if CUDA:
-    print("Using CUDA")
-else:
-    print("Not using CUDA, will run on CPU.")
-
 hyperparams = vars(args)
 # Load the dataset
 img, gt, LABEL_VALUES, IGNORED_LABELS, RGB_BANDS, palette = get_dataset(DATASET,
@@ -202,7 +203,7 @@ def convert_from_color(x):


 # Instantiate the experiment based on predefined networks
-hyperparams.update({'n_classes': N_CLASSES, 'n_bands': N_BANDS, 'ignored_labels': IGNORED_LABELS})
+hyperparams.update({'n_classes': N_CLASSES, 'n_bands': N_BANDS, 'ignored_labels': IGNORED_LABELS, 'device': CUDA_DEVICE})
 hyperparams = dict((k, v) for k, v in hyperparams.items() if v is not None)

 # Show the image and the ground truth
@@ -285,12 +286,12 @@ def convert_from_color(x):
 train_dataset = HyperX(img, train_gt, **hyperparams)
 train_loader = data.DataLoader(train_dataset,
                                batch_size=hyperparams['batch_size'],
-                               pin_memory=hyperparams['cuda'],
+                               #pin_memory=hyperparams['device'],
                                shuffle=True)
 val_dataset = HyperX(img, val_gt, **hyperparams)
 val_loader = data.DataLoader(val_dataset,
-                             batch_size=hyperparams['batch_size'],
-                             pin_memory=hyperparams['cuda'])
+                             #pin_memory=hyperparams['device'],
+                             batch_size=hyperparams['batch_size'])

 print("Network :")
 with torch.no_grad():
@@ -306,7 +307,7 @@ def convert_from_color(x):

 try:
     train(model, optimizer, loss, train_loader, hyperparams['epoch'],
-          scheduler=hyperparams['scheduler'], cuda=hyperparams['cuda'],
+          scheduler=hyperparams['scheduler'], device=hyperparams['device'],
           supervision=hyperparams['supervision'], val_loader=val_loader,
           display=viz)
 except KeyboardInterrupt:
models.py: 45 changes (17 additions & 28 deletions)
@@ -28,14 +28,13 @@ def get_model(name, **kwargs):
         criterion: PyTorch loss Function
         kwargs: hyperparameters with sane defaults
     """
-    cuda = kwargs.setdefault('cuda', False)
+    device = kwargs.setdefault('device', torch.device('cpu'))
     n_classes = kwargs['n_classes']
     n_bands = kwargs['n_bands']
     weights = torch.ones(n_classes)
     weights[torch.LongTensor(kwargs['ignored_labels'])] = 0.
+    weights = weights.to(device)
     weights = kwargs.setdefault('weights', weights)
-    if cuda:
-        kwargs['weights'] = weights.cuda()

     if name == 'nn':
         kwargs.setdefault('patch_size', 1)
@@ -102,8 +101,7 @@ def get_model(name, **kwargs):
         center_pixel = True
         model = HeEtAl(n_bands, n_classes, patch_size=kwargs['patch_size'])
         # For Adagrad, we need to load the model on GPU before creating the optimizer
-        if cuda:
-            model = model.cuda()
+        model = model.to(device)
         optimizer = optim.Adagrad(model.parameters(), lr=lr, weight_decay=0.01)
         criterion = nn.CrossEntropyLoss(weight=kwargs['weights'])
     elif name == 'luo':
@@ -164,15 +162,13 @@ def get_model(name, **kwargs):
         lr = kwargs.setdefault('lr', 1.0)
         model = MouEtAl(n_bands, n_classes)
         # For Adadelta, we need to load the model on GPU before creating the optimizer
-        if cuda:
-            model = model.cuda()
+        model = model.to(device)
         optimizer = optim.Adadelta(model.parameters(), lr=lr)
         criterion = nn.CrossEntropyLoss(weight=kwargs['weights'])
     else:
         raise KeyError("{} model is unknown.".format(name))

-    if cuda:
-        model = model.cuda()
+    model = model.to(device)
     epoch = kwargs.setdefault('epoch', 100)
     kwargs.setdefault('scheduler', optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=epoch//4, verbose=True))
     #kwargs.setdefault('scheduler', None)
@@ -1079,7 +1075,7 @@ def forward(self, x, verbose=False):


 def train(net, optimizer, criterion, data_loader, epoch, scheduler=None,
-          display_iter=100, cuda=True, display=None,
+          display_iter=100, device=torch.device('cpu'), display=None,
           val_loader=None, supervision='full'):
     """
     Training loop to optimize a network for several epochs and a specified loss
@@ -1090,7 +1086,7 @@ def train(net, optimizer, criterion, data_loader, epoch, scheduler=None,
         data_loader: a PyTorch dataset loader
         epoch: int specifying the number of training epochs
         criterion: a PyTorch-compatible loss function, e.g. nn.CrossEntropyLoss
-        cuda (optional): bool set to True to use CUDA/CUDNN
+        device (optional): torch device to use (defaults to CPU)
        display_iter (optional): number of iterations before refreshing the
         display (False/None to switch off).
         scheduler (optional): PyTorch scheduler
@@ -1101,13 +1097,10 @@ def train(net, optimizer, criterion, data_loader, epoch, scheduler=None,
     if criterion is None:
         raise Exception("Missing criterion. You must specify a loss function.")

-    if cuda:
-        net.cuda()
+    net.to(device)

     save_epoch = epoch // 20 if epoch > 20 else 1

-    # Set the network to training mode
-    net.train()

     losses = np.zeros(1000000)
     mean_losses = np.zeros(100000000)
@@ -1116,13 +1109,14 @@ def train(net, optimizer, criterion, data_loader, epoch, scheduler=None,
     val_accuracies = []

     for e in tqdm(range(1, epoch + 1), desc="Training the network"):
+        # Set the network to training mode
+        net.train()
         avg_loss = 0.

         # Run the training loop for one epoch
         for batch_idx, (data, target) in tqdm(enumerate(data_loader), total=len(data_loader)):
             # Load the data into the GPU if required
-            if cuda:
-                data, target = data.cuda(), target.cuda()
+            data, target = data.to(device), target.to(device)

             optimizer.zero_grad()
             if supervision == 'full':
@@ -1174,7 +1168,7 @@ def train(net, optimizer, criterion, data_loader, epoch, scheduler=None,
         # Update the scheduler
         avg_loss /= len(data_loader)
         if val_loader is not None:
-            val_acc = val(net, val_loader, cuda=cuda, supervision=supervision)
+            val_acc = val(net, val_loader, device=device, supervision=supervision)
             val_accuracies.append(val_acc)
             metric = -val_acc
         else:
@@ -1210,7 +1204,7 @@ def test(net, img, hyperparams):
     net.eval()
     patch_size = hyperparams['patch_size']
     center_pixel = hyperparams['center_pixel']
-    batch_size, cuda = hyperparams['batch_size'], hyperparams['cuda']
+    batch_size, device = hyperparams['batch_size'], hyperparams['device']
     n_classes = hyperparams['n_classes']

     kwargs = {'step': hyperparams['test_stride'], 'window_size': (patch_size, patch_size)}
@@ -1234,15 +1228,11 @@ def test(net, img, hyperparams):
             data = data.unsqueeze(1)

         indices = [b[1:] for b in batch]
-        if cuda:
-            data = data.cuda()
+        data = data.to(device)
         output = net(data)
         if isinstance(output, tuple):
             output = output[0]
-        if cuda:
-            output = output.data.cpu()
-        else:
-            output = output.data
+        output = output.to('cpu')

         if patch_size == 1 or center_pixel:
             output = output.numpy()
@@ -1255,15 +1245,14 @@
             probs[x:x + w, y:y + h] += out
     return probs

-def val(net, data_loader, cuda=True, supervision='full'):
+def val(net, data_loader, device='cpu', supervision='full'):
     # TODO : fix me using metrics()
     accuracy, total = 0., 0.
     ignored_labels = data_loader.dataset.ignored_labels
     for batch_idx, (data, target) in enumerate(data_loader):
         with torch.no_grad():
             # Load the data into the GPU if required
-            if cuda:
-                data, target = data.cuda(), target.cuda()
+            data, target = data.to(device), target.to(device)
             if supervision == 'full':
                 output = net(data)
             elif supervision == 'semi':