use decorator for torch.no_grad

pull/10/head
erogol 2020-02-07 14:21:57 +01:00
parent abf8ea4633
commit 2cec58320b
1 changed files with 112 additions and 112 deletions

View File

@ -327,6 +327,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
return keep_avg['avg_postnet_loss'], global_step
@torch.no_grad()
def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
data_loader = setup_loader(ap, model.decoder.r, is_val=True)
if c.use_speaker_embedding:
@ -346,7 +347,6 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
keep_avg.add_values(eval_values_dict)
print("\n > Validation")
with torch.no_grad():
if data_loader is not None:
for num_iter, data in enumerate(data_loader):
start_time = time.time()