Fix all-zero duration case for GlowTTS

pull/847/head
Eren Gölge 2021-10-01 09:20:07 +00:00
parent 37959ad0c7
commit 4dbe7ed0de
2 changed files with 3 additions and 2 deletions

View File

@ -310,7 +310,7 @@ class GlowTTS(BaseTTS):
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
# compute output durations
w = (torch.exp(o_dur_log) - 1) * x_mask * self.length_scale
w_ceil = torch.ceil(w)
w_ceil = torch.clamp_min(torch.ceil(w), 1)
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
y_max_length = None
# compute masks

View File

@ -10,7 +10,7 @@ output_path = os.path.join(get_tests_output_path(), "train_outputs")
config = GlowTTSConfig(
batch_size=8,
batch_size=2,
eval_batch_size=8,
num_loader_workers=0,
num_eval_loader_workers=0,
@ -27,6 +27,7 @@ config = GlowTTSConfig(
test_sentences=[
"Be a voice, not an echo.",
],
data_dep_init_steps=1.0,
)
config.audio.do_trim_silence = True
config.audio.trim_db = 60