From 566c2a4678856d23d1cce4b22ff1a960855d315a Mon Sep 17 00:00:00 2001 From: erogol Date: Wed, 12 Feb 2020 10:29:30 +0100 Subject: [PATCH] add torch.no_grad decorator for inference --- models/tacotron.py | 1 + models/tacotron2.py | 1 + 2 files changed, 2 insertions(+) diff --git a/models/tacotron.py b/models/tacotron.py index a2d9e1c4..04ecd573 100644 --- a/models/tacotron.py +++ b/models/tacotron.py @@ -132,6 +132,7 @@ class Tacotron(nn.Module): return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward return decoder_outputs, postnet_outputs, alignments, stop_tokens + @torch.no_grad() def inference(self, characters, speaker_ids=None, style_mel=None): inputs = self.embedding(characters) self._init_states() diff --git a/models/tacotron2.py b/models/tacotron2.py index 852b1886..3a3863de 100644 --- a/models/tacotron2.py +++ b/models/tacotron2.py @@ -82,6 +82,7 @@ class Tacotron2(nn.Module): return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward return decoder_outputs, postnet_outputs, alignments, stop_tokens + @torch.no_grad() def inference(self, text, speaker_ids=None): embedded_inputs = self.embedding(text).transpose(1, 2) encoder_outputs = self.encoder.inference(embedded_inputs)