Use torch.no_grad for VITS inference

pull/1573/head
Eren Gölge 2022-05-11 11:29:36 +02:00
parent 3f03e3012c
commit 5021a03de0
1 changed files with 1 additions and 0 deletions

View File

@ -982,6 +982,7 @@ class Vits(BaseTTS):
return aux_input["x_lengths"]
return torch.tensor(x.shape[1:2]).to(x.device)
@torch.no_grad()
def inference(
self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None, "language_ids": None}
): # pylint: disable=dangerous-default-value