add torch.no_grad decorator for inference

pull/10/head
erogol 2020-02-12 10:29:30 +01:00
parent c553c7ecd4
commit 566c2a4678
2 changed files with 2 additions and 0 deletions

View File

@ -132,6 +132,7 @@ class Tacotron(nn.Module):
return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
return decoder_outputs, postnet_outputs, alignments, stop_tokens
@torch.no_grad()
def inference(self, characters, speaker_ids=None, style_mel=None):
inputs = self.embedding(characters)
self._init_states()

View File

@ -82,6 +82,7 @@ class Tacotron2(nn.Module):
return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
return decoder_outputs, postnet_outputs, alignments, stop_tokens
@torch.no_grad()
def inference(self, text, speaker_ids=None):
embedded_inputs = self.embedding(text).transpose(1, 2)
encoder_outputs = self.encoder.inference(embedded_inputs)