mirror of https://github.com/coqui-ai/TTS.git
use decorater for torch.no_grad
parent
abf8ea4633
commit
2cec58320b
2
train.py
2
train.py
|
@ -327,6 +327,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
||||||
return keep_avg['avg_postnet_loss'], global_step
|
return keep_avg['avg_postnet_loss'], global_step
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
||||||
data_loader = setup_loader(ap, model.decoder.r, is_val=True)
|
data_loader = setup_loader(ap, model.decoder.r, is_val=True)
|
||||||
if c.use_speaker_embedding:
|
if c.use_speaker_embedding:
|
||||||
|
@ -346,7 +347,6 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
||||||
keep_avg.add_values(eval_values_dict)
|
keep_avg.add_values(eval_values_dict)
|
||||||
print("\n > Validation")
|
print("\n > Validation")
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
if data_loader is not None:
|
if data_loader is not None:
|
||||||
for num_iter, data in enumerate(data_loader):
|
for num_iter, data in enumerate(data_loader):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
Loading…
Reference in New Issue