use stop token again

pull/10/head
Eren Golge 2019-04-18 15:20:19 +02:00
parent e05263769e
commit f450fe3571
1 changed files with 2 additions and 2 deletions

View File

@ -445,7 +445,7 @@ class Decoder(nn.Module):
self.attention_layer.init_forward_attn_state(inputs)
outputs, stop_tokens, alignments, t = [], [], [], 0
stop_flags = [True, False, False]
stop_flags = [False, False, False]
stop_count = 0
while True:
memory = self.prenet(memory)
@ -456,7 +456,7 @@ class Decoder(nn.Module):
alignments += [alignment]
stop_flags[0] = stop_flags[0] or stop_token > 0.5
stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.5 and t > inputs.shape[1])
stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.8 and t > inputs.shape[1])
stop_flags[2] = t > inputs.shape[1] * 2
if all(stop_flags):
stop_count += 1