setup gradual training schedule wrt num gpus

pull/10/head
Eren Golge 2019-11-20 12:08:23 +01:00
parent a31d1cb2d9
commit e41a7c1a02
1 changed file with 7 additions and 1 deletion

@@ -347,9 +347,15 @@ def split_dataset(items):
 def gradual_training_scheduler(global_step, config):
+    """Setup the gradual training schedule wrt number
+    of active GPUs"""
+    num_gpus = torch.cuda.device_count()
+    if num_gpus == 0:
+        num_gpus = 1
     new_values = None
+    # we set the scheduling wrt num_gpus
     for values in config.gradual_training:
-        if global_step >= values[0]:
+        if global_step * num_gpus >= values[0]:
             new_values = values
     return new_values[1], new_values[2]
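
The change scales the step counter by the number of visible GPUs: with N GPUs, each optimizer step consumes roughly N times more samples, so the schedule thresholds should be crossed N times sooner. Below is a minimal standalone sketch of how the patched function behaves. It assumes a Mozilla-TTS-style config where `gradual_training` is a list of `[start_step, r, batch_size]` entries sorted by `start_step`; the `Config` class and the example schedule values are hypothetical, not part of the commit.

```python
import torch


class Config:
    # hypothetical schedule: each entry is [start_step, r, batch_size];
    # as training progresses, the decoder reduction factor `r` shrinks
    gradual_training = [[0, 7, 64], [10000, 5, 64], [50000, 3, 32]]


def gradual_training_scheduler(global_step, config):
    """Setup the gradual training schedule wrt number
    of active GPUs"""
    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        num_gpus = 1  # CPU-only run counts as a single worker
    new_values = None
    # pick the last schedule entry whose threshold is already passed,
    # treating `global_step * num_gpus` as the effective sample progress
    for values in config.gradual_training:
        if global_step * num_gpus >= values[0]:
            new_values = values
    return new_values[1], new_values[2]


r, batch_size = gradual_training_scheduler(12000, Config())
print(r, batch_size)  # 5 64 on a 1-GPU (or CPU) machine; 3 32 on 8 GPUs
```

Note that if `global_step * num_gpus` is below the first threshold, `new_values` stays `None` and the return line raises a `TypeError`; the commit avoids this in practice by starting the schedule at step 0.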