stop conditioning with padding for inference_truncated

pull/10/head
Eren Golge 2019-04-01 14:10:38 +02:00
parent 5212a11836
commit 68f8ef730d
1 changed files with 10 additions and 7 deletions

View File

@ -125,8 +125,8 @@ class Attention(nn.Module):
self._mask_value = -float("inf") self._mask_value = -float("inf")
self.windowing = windowing self.windowing = windowing
if self.windowing: if self.windowing:
self.win_back = 1 self.win_back = 3
self.win_front = 3 self.win_front = 6
self.win_idx = None self.win_idx = None
self.norm = norm self.norm = norm
@ -405,7 +405,7 @@ class Decoder(nn.Module):
alignments += [alignment] alignments += [alignment]
stop_flags[0] = stop_flags[0] or stop_token > 0.5 stop_flags[0] = stop_flags[0] or stop_token > 0.5
stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.8 and t > inputs.shape[1]) stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.5 and t > inputs.shape[1])
stop_flags[2] = t > inputs.shape[1] * 2 stop_flags[2] = t > inputs.shape[1] * 2
if all(stop_flags): if all(stop_flags):
stop_count += 1 stop_count += 1
@ -436,6 +436,7 @@ class Decoder(nn.Module):
self.attention_layer.init_win_idx() self.attention_layer.init_win_idx()
outputs, gate_outputs, alignments, t = [], [], [], 0 outputs, gate_outputs, alignments, t = [], [], [], 0
stop_flags = [False, False] stop_flags = [False, False]
stop_count = 0
while True: while True:
memory = self.prenet(self.memory_truncated) memory = self.prenet(self.memory_truncated)
mel_output, gate_output, alignment = self.decode(memory) mel_output, gate_output, alignment = self.decode(memory)
@ -444,14 +445,16 @@ class Decoder(nn.Module):
gate_outputs += [gate_output] gate_outputs += [gate_output]
alignments += [alignment] alignments += [alignment]
stop_flags[0] = stop_flags[0] or gate_output > 0.5 stop_flags[0] = stop_flags[0] or stop_token > 0.5
stop_flags[1] = stop_flags[1] or alignment[0, -2:].sum() > 0.5 stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.5 and t > inputs.shape[1])
stop_flags[2] = t > inputs.shape[1] * 2
if all(stop_flags): if all(stop_flags):
break stop_count += 1
if stop_count > 20:
break
elif len(outputs) == self.max_decoder_steps: elif len(outputs) == self.max_decoder_steps:
print(" | > Decoder stopped with 'max_decoder_steps") print(" | > Decoder stopped with 'max_decoder_steps")
break break
self.memory_truncated = mel_output self.memory_truncated = mel_output
t += 1 t += 1