diff --git a/layers/tacotron2.py b/layers/tacotron2.py index 70c26743..3a3fc62c 100644 --- a/layers/tacotron2.py +++ b/layers/tacotron2.py @@ -125,8 +125,8 @@ class Attention(nn.Module): self._mask_value = -float("inf") self.windowing = windowing if self.windowing: - self.win_back = 1 - self.win_front = 3 + self.win_back = 3 + self.win_front = 6 self.win_idx = None self.norm = norm @@ -405,7 +405,7 @@ class Decoder(nn.Module): alignments += [alignment] stop_flags[0] = stop_flags[0] or stop_token > 0.5 - stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.8 and t > inputs.shape[1]) + stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.5 and t > inputs.shape[1]) stop_flags[2] = t > inputs.shape[1] * 2 if all(stop_flags): stop_count += 1 @@ -436,6 +436,7 @@ class Decoder(nn.Module): self.attention_layer.init_win_idx() outputs, gate_outputs, alignments, t = [], [], [], 0 stop_flags = [False, False] + stop_count = 0 while True: memory = self.prenet(self.memory_truncated) mel_output, gate_output, alignment = self.decode(memory) @@ -444,14 +445,16 @@ class Decoder(nn.Module): gate_outputs += [gate_output] alignments += [alignment] - stop_flags[0] = stop_flags[0] or gate_output > 0.5 - stop_flags[1] = stop_flags[1] or alignment[0, -2:].sum() > 0.5 + stop_flags[0] = stop_flags[0] or stop_token > 0.5 + stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.5 and t > inputs.shape[1]) + stop_flags[2] = t > inputs.shape[1] * 2 if all(stop_flags): - break + stop_count += 1 + if stop_count > 20: + break elif len(outputs) == self.max_decoder_steps: print(" | > Decoder stopped with 'max_decoder_steps") break - self.memory_truncated = mel_output t += 1