mirror of https://github.com/coqui-ai/TTS.git
Add a padded stop condition to inference_truncated (keep decoding until stop_count exceeds 20 after all stop flags fire)
parent
5212a11836
commit
68f8ef730d
|
@ -125,8 +125,8 @@ class Attention(nn.Module):
|
||||||
self._mask_value = -float("inf")
|
self._mask_value = -float("inf")
|
||||||
self.windowing = windowing
|
self.windowing = windowing
|
||||||
if self.windowing:
|
if self.windowing:
|
||||||
self.win_back = 1
|
self.win_back = 3
|
||||||
self.win_front = 3
|
self.win_front = 6
|
||||||
self.win_idx = None
|
self.win_idx = None
|
||||||
self.norm = norm
|
self.norm = norm
|
||||||
|
|
||||||
|
@ -405,7 +405,7 @@ class Decoder(nn.Module):
|
||||||
alignments += [alignment]
|
alignments += [alignment]
|
||||||
|
|
||||||
stop_flags[0] = stop_flags[0] or stop_token > 0.5
|
stop_flags[0] = stop_flags[0] or stop_token > 0.5
|
||||||
stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.8 and t > inputs.shape[1])
|
stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.5 and t > inputs.shape[1])
|
||||||
stop_flags[2] = t > inputs.shape[1] * 2
|
stop_flags[2] = t > inputs.shape[1] * 2
|
||||||
if all(stop_flags):
|
if all(stop_flags):
|
||||||
stop_count += 1
|
stop_count += 1
|
||||||
|
@ -436,6 +436,7 @@ class Decoder(nn.Module):
|
||||||
self.attention_layer.init_win_idx()
|
self.attention_layer.init_win_idx()
|
||||||
outputs, gate_outputs, alignments, t = [], [], [], 0
|
outputs, gate_outputs, alignments, t = [], [], [], 0
|
||||||
stop_flags = [False, False]
|
stop_flags = [False, False]
|
||||||
|
stop_count = 0
|
||||||
while True:
|
while True:
|
||||||
memory = self.prenet(self.memory_truncated)
|
memory = self.prenet(self.memory_truncated)
|
||||||
mel_output, gate_output, alignment = self.decode(memory)
|
mel_output, gate_output, alignment = self.decode(memory)
|
||||||
|
@ -444,14 +445,16 @@ class Decoder(nn.Module):
|
||||||
gate_outputs += [gate_output]
|
gate_outputs += [gate_output]
|
||||||
alignments += [alignment]
|
alignments += [alignment]
|
||||||
|
|
||||||
stop_flags[0] = stop_flags[0] or gate_output > 0.5
|
stop_flags[0] = stop_flags[0] or stop_token > 0.5
|
||||||
stop_flags[1] = stop_flags[1] or alignment[0, -2:].sum() > 0.5
|
stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.5 and t > inputs.shape[1])
|
||||||
|
stop_flags[2] = t > inputs.shape[1] * 2
|
||||||
if all(stop_flags):
|
if all(stop_flags):
|
||||||
break
|
stop_count += 1
|
||||||
|
if stop_count > 20:
|
||||||
|
break
|
||||||
elif len(outputs) == self.max_decoder_steps:
|
elif len(outputs) == self.max_decoder_steps:
|
||||||
print(" | > Decoder stopped with 'max_decoder_steps")
|
print(" | > Decoder stopped with 'max_decoder_steps")
|
||||||
break
|
break
|
||||||
|
|
||||||
self.memory_truncated = mel_output
|
self.memory_truncated = mel_output
|
||||||
t += 1
|
t += 1
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue