Commit 3fedf959 authored by Federico Betti

03 new version

parent 1cf5eeda
@@ -67,38 +67,7 @@ def train(epoch):
train_loss = Metric('train_loss')
train_accuracy = Metric('train_accuracy')
iteration = epoch * len(train_loader)
with tqdm(total=len(train_loader),
desc='Train Epoch #{}'.format(epoch + 1),
disable=not verbose) as t:
for batch_idx, (data, target) in enumerate(train_loader):
adjust_learning_rate(epoch, batch_idx)
if args.cuda:
data, target = data.cuda(), target.cuda()
optimizer.zero_grad()
# Split data into sub-batches of size batch_size
for i in range(0, len(data), args.batch_size):
data_batch = data[i:i + args.batch_size]
target_batch = target[i:i + args.batch_size]
output = model(data_batch)
train_accuracy.update(accuracy(output, target_batch))
loss = F.cross_entropy(output, target_batch)
train_loss.update(loss)
# Average gradients among sub-batches
loss.div_(math.ceil(float(len(data)) / args.batch_size))
loss.backward()
# Gradient is applied across all ranks
optimizer.step()
t.set_postfix({'loss': train_loss.avg.item(),
'accuracy': 100. * train_accuracy.avg.item()})
t.update(1)
if log_writer:
log_writer.add_scalar('train/loss', train_loss.avg, iteration)
log_writer.add_scalar('train/accuracy', train_accuracy.avg, iteration)
iteration += 1
...
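
The loss.div_(...) call above scales each sub-batch loss by the number of sub-batches, so the gradients accumulated across the inner loop sum to the gradient of the mean loss over the whole allreduce batch: with len(data) = 128 and args.batch_size = 32, each loss is divided by 4. A minimal standalone sketch of the same accumulation pattern, mirroring the names used in the loop above:

import math
import torch.nn.functional as F

def accumulated_step(model, optimizer, data, target, batch_size):
    # Number of sub-batches the full allreduce batch is split into.
    num_sub = math.ceil(float(len(data)) / batch_size)
    optimizer.zero_grad()
    for i in range(0, len(data), batch_size):
        output = model(data[i:i + batch_size])
        loss = F.cross_entropy(output, target[i:i + batch_size])
        # Scale before backward() so the summed gradients equal the
        # gradient of the mean loss over the whole allreduce batch.
        (loss / num_sub).backward()
    # One optimizer step per allreduce batch, exactly as in the loop above.
    optimizer.step()
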
def validate(epoch):
@@ -106,24 +75,7 @@ def validate(epoch):
val_loss = Metric('val_loss')
val_accuracy = Metric('val_accuracy')
with tqdm(total=len(val_loader),
desc='Validate Epoch #{}'.format(epoch + 1),
disable=not verbose) as t:
with torch.no_grad():
for data, target in val_loader:
if args.cuda:
data, target = data.cuda(), target.cuda()
output = model(data)
val_loss.update(F.cross_entropy(output, target))
val_accuracy.update(accuracy(output, target))
t.set_postfix({'loss': val_loss.avg.item(),
'accuracy': 100. * val_accuracy.avg.item()})
t.update(1)
if log_writer:
log_writer.add_scalar('val/loss', val_loss.avg, epoch)
log_writer.add_scalar('val/accuracy', val_accuracy.avg, epoch)
...
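
Both loops rely on the Metric helper, which is defined elsewhere in this script and not shown in the diff. In Horovod's example scripts it is a running average whose updates are allreduced across ranks, roughly like this sketch (a reconstruction from the upstream examples; this repository's version may differ):

import torch
import horovod.torch as hvd

class Metric(object):
    """Running average of a scalar, averaged over all Horovod ranks."""
    def __init__(self, name):
        self.name = name
        self.sum = torch.tensor(0.)
        self.n = torch.tensor(0.)

    def update(self, val):
        # hvd.allreduce averages by default, so every rank records the global mean.
        self.sum += hvd.allreduce(val.detach().cpu(), name=self.name)
        self.n += 1

    @property
    def avg(self):
        return self.sum / self.n
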
# Horovod: using `lr = base_lr * hvd.size()` from the very beginning leads to worse final accuracy.
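
That comment refers to the gradual-warmup recipe of Goyal et al. (arXiv:1706.02677): ramp the learning rate linearly from base_lr up to the scaled value over the first few epochs instead of starting scaled. A sketch of an adjust_learning_rate in that style, assuming args.warmup_epochs and the step decay at epochs 30/60/80 used in the upstream Horovod example (the schedule in this commit may differ):

def adjust_learning_rate(epoch, batch_idx):
    if epoch < args.warmup_epochs:
        # Linear ramp from 1x to hvd.size()x across the warmup epochs.
        progress = epoch + float(batch_idx + 1) / len(train_loader)
        lr_adj = 1. / hvd.size() * (progress * (hvd.size() - 1) / args.warmup_epochs + 1)
    elif epoch < 30:
        lr_adj = 1.
    elif epoch < 60:
        lr_adj = 1e-1
    elif epoch < 80:
        lr_adj = 1e-2
    else:
        lr_adj = 1e-3
    for param_group in optimizer.param_groups:
        param_group['lr'] = args.base_lr * hvd.size() * args.batches_per_allreduce * lr_adj
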
@@ -206,18 +158,13 @@ if __name__ == '__main__':
# Horovod: limit # of CPU threads to be used per worker.
torch.set_num_threads(4)
kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}
# When supported, use 'forkserver' to spawn dataloader workers instead of 'fork' to prevent
# issues with Infiniband implementations that are not fork-safe
if (kwargs.get('num_workers', 0) > 0 and hasattr(mp, '_supports_context') and
mp._supports_context and 'forkserver' in mp.get_all_start_methods()):
kwargs['multiprocessing_context'] = 'forkserver'
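
The kwargs built here are consumed further down, outside the hunks shown, when the data loaders are created; in the upstream Horovod example each rank reads its own shard of the dataset through a DistributedSampler, roughly as in this sketch (not part of this commit):

train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    # Each rank loads batches_per_allreduce sub-batches worth of data at once.
    batch_size=args.batch_size * args.batches_per_allreduce,
    sampler=train_sampler, **kwargs)
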
+crop_size = 64
+test_image_size = 64
train_dataset = \
datasets.ImageFolder(args.train_dir,
transform=transforms.Compose([
-transforms.RandomResizedCrop(224),
+transforms.RandomResizedCrop(crop_size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
@@ -231,8 +178,8 @@ if __name__ == '__main__':
val_dataset = \
datasets.ImageFolder(args.val_dir,
transform=transforms.Compose([
-transforms.Resize(256),
-transforms.CenterCrop(224),
+transforms.Resize(test_image_size),
+transforms.CenterCrop(crop_size),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
@@ -244,21 +191,13 @@ if __name__ == '__main__':
# model = models.resnet50()
model = ...
# By default, Adasum doesn't need scaling up learning rate.
# For sum/average with gradient accumulation, scale the learning rate by
# batches_per_allreduce (e.g. 4 workers with batches_per_allreduce = 2 give lr_scaler = 8).
lr_scaler = args.batches_per_allreduce * hvd.size() if not args.use_adasum else 1
if args.cuda:
# Move model to GPU.
model.cuda()
# If using GPU Adasum allreduce, scale learning rate by local_size.
if args.use_adasum and hvd.nccl_built():
lr_scaler = args.batches_per_allreduce * hvd.local_size()
# Horovod: scale learning rate by the number of GPUs.
optimizer = optim.SGD(model.parameters(),
-lr=(args.base_lr * lr_scaler),
+lr=args.base_lr,
momentum=args.momentum, weight_decay=args.wd)
# Horovod: (optional) compression algorithm.
...
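
The diff is truncated right after the compression comment; in the upstream Horovod example the very next step wraps the optimizer so that each step() allreduces gradients across ranks, optionally fp16-compressed. A sketch of that wiring, assuming flags such as args.fp16_allreduce and args.use_adasum as upstream:

# Horovod: pick fp16 gradient compression if requested.
compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none

# Horovod: wrap the optimizer so step() performs the allreduce.
optimizer = hvd.DistributedOptimizer(
    optimizer,
    named_parameters=model.named_parameters(),
    compression=compression,
    # Accumulate this many local backward passes before each allreduce.
    backward_passes_per_step=args.batches_per_allreduce,
    op=hvd.Adasum if args.use_adasum else hvd.Average)

# Horovod: broadcast initial parameters and optimizer state from rank 0
# so all workers start from the same model.
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)
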