diff --git a/layers/tacotron2.py b/layers/tacotron2.py
index 0a6982e2..6122c38b 100644
--- a/layers/tacotron2.py
+++ b/layers/tacotron2.py
@@ -143,8 +143,8 @@ class Attention(nn.Module):
 
     def init_win_idx(self):
         self.win_idx = -1
-        self.win_back = 1
-        self.win_front = 3
+        self.win_back = 2
+        self.win_front = 6
 
     def init_forward_attn_state(self, inputs):
         """
@@ -165,7 +165,7 @@ class Attention(nn.Module):
         energies = energies.squeeze(-1)
         return energies, processed_query
 
-    def apply_windowing(self, attention):
+    def apply_windowing(self, attention, inputs):
         back_win = self.win_idx - self.win_back
         front_win = self.win_idx + self.win_front
         if back_win > 0:
@@ -199,10 +199,13 @@ class Attention(nn.Module):
         attention, processed_query = self.get_attention(
             attention_hidden_state, processed_inputs, attention_cat)
 
+        # apply masking
         if mask is not None:
             attention.data.masked_fill_(1 - mask, self._mask_value)
+        # apply windowing - only in eval mode
         if not self.training and self.windowing:
-            attention = self.apply_windowing(attention)
+            attention = self.apply_windowing(attention, inputs)
+        # normalize attention values
         if self.norm == "softmax":
             alignment = torch.softmax(attention, dim=-1)
         elif self.norm == "sigmoid":
@@ -210,6 +213,7 @@
                 attention).sum(dim=1).unsqueeze(1)
         else:
             raise RuntimeError("Unknown value for attention norm type")
+        # apply forward attention if enabled
        if self.forward_attn:
            return self.apply_forward_attention(inputs, alignment, processed_query)
        else:
@@ -456,7 +460,7 @@ class Decoder(nn.Module):
             stop_flags[2] = t > inputs.shape[1] * 2
             if all(stop_flags):
                 stop_count += 1
-                if stop_count > 20:
+                if stop_count > 2:
                     break
             elif len(outputs) == self.max_decoder_steps:
                 print(" | > Decoder stopped with 'max_decoder_steps")
@@ -481,16 +485,17 @@ class Decoder(nn.Module):
         self._init_states(inputs, mask=None, keep_states=True)
 
         self.attention_layer.init_win_idx()
-        self.attention_layer.init_forward_attn_state()
-        outputs, gate_outputs, alignments, t = [], [], [], 0
+        if self.attention_layer.forward_attn:
+            self.attention_layer.init_forward_attn_state(inputs)
+        outputs, stop_tokens, alignments, t = [], [], [], 0
         stop_flags = [False, False, False]
         stop_count = 0
         while True:
             memory = self.prenet(self.memory_truncated)
-            mel_output, gate_output, alignment = self.decode(memory)
-            gate_output = torch.sigmoid(gate_output.data)
+            mel_output, stop_token, alignment = self.decode(memory)
+            stop_token = torch.sigmoid(stop_token.data)
             outputs += [mel_output.squeeze(1)]
-            gate_outputs += [gate_output]
+            stop_tokens += [stop_token]
             alignments += [alignment]
 
             stop_flags[0] = stop_flags[0] or stop_token > 0.5
@@ -498,7 +503,7 @@
             stop_flags[2] = t > inputs.shape[1] * 2
             if all(stop_flags):
                 stop_count += 1
-                if stop_count > 20:
+                if stop_count > 2:
                     break
             elif len(outputs) == self.max_decoder_steps:
                 print(" | > Decoder stopped with 'max_decoder_steps")
@@ -507,10 +512,10 @@ class Decoder(nn.Module):
             self.memory_truncated = mel_output
             t += 1
 
-        outputs, gate_outputs, alignments = self._parse_outputs(
-            outputs, gate_outputs, alignments)
+        outputs, stop_tokens, alignments = self._parse_outputs(
+            outputs, stop_tokens, alignments)
 
-        return outputs, gate_outputs, alignments
+        return outputs, stop_tokens, alignments
 
     def inference_step(self, inputs, t, memory=None):
 
diff --git a/utils/generic_utils.py b/utils/generic_utils.py
index ef686962..f22c4a3a 100644
--- a/utils/generic_utils.py
+++ b/utils/generic_utils.py
@@ -252,12 +252,14 @@ def setup_model(num_chars, c):
         model = MyModel(
             num_chars=num_chars,
             r=c.r,
+            attn_win=c.windowing,
             attn_norm=c.attention_norm,
             memory_size=c.memory_size)
     elif c.model.lower() == "tacotron2":
         model = MyModel(
             num_chars=num_chars,
             r=c.r,
+            attn_win=c.windowing,
             attn_norm=c.attention_norm,
             prenet_type=c.prenet_type,
             forward_attn=c.use_forward_attn,
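
Note for reviewers: the apply_windowing hunk ends before the lines where the new `inputs` argument is used. Below is a minimal standalone sketch of the intended behavior, assuming `inputs` is the encoder output of shape (batch, time, channels) and that `inputs.shape[1]` clamps the front edge of the window; the free-standing signature, the -inf fill, and the win_idx update are illustrative, not the exact body in this PR.

import torch

def apply_windowing(attention, inputs, win_idx, win_back=2, win_front=6):
    # Keep only scores inside [win_idx - win_back, win_idx + win_front];
    # everything outside the window is pushed to -inf before normalization.
    back_win = win_idx - win_back
    front_win = win_idx + win_front
    if back_win > 0:
        attention[:, :back_win] = -float("inf")
    if front_win < inputs.shape[1]:  # clamp the window to the encoder length
        attention[:, front_win:] = -float("inf")
    # advance the window by tracking the most attended encoder step
    new_win_idx = torch.argmax(attention, 1).long()[0].item()
    return attention, new_win_idx

The wider defaults in this PR (win_back 1 -> 2, win_front 3 -> 6) give the sliding window more slack on both sides, so a slightly lagging win_idx is less likely to mask out the encoder step the decoder actually needs.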
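In utils/generic_utils.py both model branches now forward `c.windowing` as `attn_win`, so the training config needs a matching boolean field. A hedged sketch of the minimal config object this implies; field names come from this diff, and SimpleNamespace merely stands in for however the real config.json is loaded:

from types import SimpleNamespace

# hypothetical minimal config covering only the fields these hunks touch
c = SimpleNamespace(
    model="Tacotron2",
    r=1,
    windowing=False,            # new field: forwarded to the model as attn_win
    attention_norm="sigmoid",
    prenet_type="original",
    use_forward_attn=True,
)
# setup_model(num_chars, c) then passes attn_win=c.windowing, so the attention
# window is only active when the config enables it (and, per the tacotron2.py
# hunk, only outside training).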