mirror of https://github.com/coqui-ai/TTS.git
Use the last two attention values for the stop condition, and attend to the first encoder vector on the first decoder iteration
parent
d8add67a2f
commit
4b116a2a88
|
@ -105,7 +105,7 @@ class Attention(nn.Module):
|
|||
self.win_idx = None
|
||||
|
||||
def init_win_idx(self):
|
||||
self.win_idx = 0
|
||||
self.win_idx = -1
|
||||
|
||||
def get_attention(self, query, processed_inputs, attention_cat):
|
||||
processed_query = self.query_layer(query.unsqueeze(1))
|
||||
|
@ -132,6 +132,10 @@ class Attention(nn.Module):
|
|||
attention[:, :back_win] = -float("inf")
|
||||
if front_win < inputs.shape[1]:
|
||||
attention[:, front_win:] = -float("inf")
|
||||
# this is a trick to solve a special problem.
|
||||
# but it does not hurt.
|
||||
if self.win_idx == -1:
|
||||
attention[:, 0] = attention.max()
|
||||
# Update the window
|
||||
self.win_idx = torch.argmax(attention, 1).long()[0].item()
|
||||
alignment = torch.sigmoid(attention) / torch.sigmoid(
|
||||
|
@ -355,7 +359,7 @@ class Decoder(nn.Module):
|
|||
alignments += [alignment]
|
||||
|
||||
stop_flags[0] = stop_flags[0] or gate_output > 0.5
|
||||
stop_flags[1] = stop_flags[1] or alignment[0, -3:].sum() > 0.5
|
||||
stop_flags[1] = stop_flags[1] or alignment[0, -2:].sum() > 0.5
|
||||
if all(stop_flags):
|
||||
break
|
||||
elif len(outputs) == self.max_decoder_steps:
|
||||
|
|
Loading…
Reference in New Issue