Changes at windowing and some comments

pull/10/head
Eren Golge 2019-04-12 16:13:40 +02:00
parent 455667d2a4
commit 3c2d500f53
2 changed files with 21 additions and 14 deletions

View File

@@ -143,8 +143,8 @@ class Attention(nn.Module):
     def init_win_idx(self):
         self.win_idx = -1
-        self.win_back = 1
-        self.win_front = 3
+        self.win_back = 2
+        self.win_front = 6

     def init_forward_attn_state(self, inputs):
         """
@@ -165,7 +165,7 @@ class Attention(nn.Module):
         energies = energies.squeeze(-1)
         return energies, processed_query

-    def apply_windowing(self, attention):
+    def apply_windowing(self, attention, inputs):
         back_win = self.win_idx - self.win_back
         front_win = self.win_idx + self.win_front
         if back_win > 0:
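
For readers unfamiliar with attention windowing: in eval mode the layer only lets the alignment move inside a band around the last attended input index, and this commit widens that band from (1 back, 3 front) to (2 back, 6 front). Below is a minimal sketch of how such a band mask can be applied to the attention energies; the function name, tensor shapes, and the use of -inf as the mask value are assumptions for illustration, not taken from the repository.

import torch

def windowed_energies(energies, win_idx, win_back=2, win_front=6):
    """Sketch: mask energies outside [win_idx - win_back, win_idx + win_front).

    energies: FloatTensor [batch, T_in], unnormalized scores for one decoder step.
    win_idx:  input index attended at the previous decoder step.
    A softmax over the returned tensor keeps the alignment near win_idx.
    """
    t_in = energies.shape[1]
    lo = max(win_idx - win_back, 0)
    hi = min(win_idx + win_front, t_in)
    mask = torch.full_like(energies, float("-inf"))
    mask[:, lo:hi] = 0.0
    return energies + mask
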
@@ -199,10 +199,13 @@ class Attention(nn.Module):
         attention, processed_query = self.get_attention(
             attention_hidden_state, processed_inputs, attention_cat)
+        # apply masking
         if mask is not None:
             attention.data.masked_fill_(1 - mask, self._mask_value)
+        # apply windowing - only in eval mode
         if not self.training and self.windowing:
-            attention = self.apply_windowing(attention)
+            attention = self.apply_windowing(attention, inputs)
+        # normalize attention values
         if self.norm == "softmax":
             alignment = torch.softmax(attention, dim=-1)
         elif self.norm == "sigmoid":
@@ -210,6 +213,7 @@ class Attention(nn.Module):
                 attention).sum(dim=1).unsqueeze(1)
         else:
             raise RuntimeError("Unknown value for attention norm type")
+        # apply forward attention if enabled
         if self.forward_attn:
             return self.apply_forward_attention(inputs, alignment, processed_query)
         else:
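
The orphaned line `attention).sum(dim=1).unsqueeze(1)` above is the tail of the sigmoid normalization branch, split across the hunk boundary. Written out, the two normalization options that the new comments annotate compute roughly the following (a restatement for clarity, not code from the commit):

import torch

def normalize_attention(attention, norm="softmax"):
    # "softmax": standard normalization over the input axis
    if norm == "softmax":
        return torch.softmax(attention, dim=-1)
    # "sigmoid": element-wise sigmoid, rescaled so each frame sums to one
    if norm == "sigmoid":
        sig = torch.sigmoid(attention)
        return sig / sig.sum(dim=1).unsqueeze(1)
    raise RuntimeError("Unknown value for attention norm type")
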
@@ -456,7 +460,7 @@ class Decoder(nn.Module):
             stop_flags[2] = t > inputs.shape[1] * 2
             if all(stop_flags):
                 stop_count += 1
-                if stop_count > 20:
+                if stop_count > 2:
                     break
             elif len(outputs) == self.max_decoder_steps:
                 print(" | > Decoder stopped with 'max_decoder_steps")
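
The threshold change from 20 to 2 makes inference stop much sooner once the stopping conditions agree: the loop keeps a counter that is advanced only while all three stop flags hold (a thresholded stop token, the decoder running past twice the input length, and a third condition outside this hunk), and now breaks after the counter exceeds 2. A minimal sketch of that logic, with names assumed for illustration:

def check_stop(stop_flags, stop_count, patience=2):
    """Sketch: return (should_break, updated_count).

    stop_flags: three latched booleans, e.g. stop_token > 0.5 and
    t > 2 * input length as seen in the diff.
    """
    if all(stop_flags):
        stop_count += 1
        if stop_count > patience:  # previously 20, now 2
            return True, stop_count
    return False, stop_count
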
@@ -481,16 +485,17 @@ class Decoder(nn.Module):
         self._init_states(inputs, mask=None, keep_states=True)

         self.attention_layer.init_win_idx()
-        self.attention_layer.init_forward_attn_state()
-        outputs, gate_outputs, alignments, t = [], [], [], 0
+        if self.attention_layer.forward_attn:
+            self.attention_layer.init_forward_attn_state(inputs)
+        outputs, stop_tokens, alignments, t = [], [], [], 0
         stop_flags = [False, False, False]
         stop_count = 0

         while True:
             memory = self.prenet(self.memory_truncated)
-            mel_output, gate_output, alignment = self.decode(memory)
-            gate_output = torch.sigmoid(gate_output.data)
+            mel_output, stop_token, alignment = self.decode(memory)
+            stop_token = torch.sigmoid(stop_token.data)
             outputs += [mel_output.squeeze(1)]
-            gate_outputs += [gate_output]
+            stop_tokens += [stop_token]
             alignments += [alignment]

             stop_flags[0] = stop_flags[0] or stop_token > 0.5
@@ -498,7 +503,7 @@ class Decoder(nn.Module):
             stop_flags[2] = t > inputs.shape[1] * 2
             if all(stop_flags):
                 stop_count += 1
-                if stop_count > 20:
+                if stop_count > 2:
                     break
             elif len(outputs) == self.max_decoder_steps:
                 print(" | > Decoder stopped with 'max_decoder_steps")
@@ -507,10 +512,10 @@ class Decoder(nn.Module):
             self.memory_truncated = mel_output
             t += 1

-        outputs, gate_outputs, alignments = self._parse_outputs(
-            outputs, gate_outputs, alignments)
+        outputs, stop_tokens, alignments = self._parse_outputs(
+            outputs, stop_tokens, alignments)

-        return outputs, gate_outputs, alignments
+        return outputs, stop_tokens, alignments

     def inference_step(self, inputs, t, memory=None):

View File

@@ -252,12 +252,14 @@ def setup_model(num_chars, c):
         model = MyModel(
             num_chars=num_chars,
             r=c.r,
+            attn_win=c.windowing,
             attn_norm=c.attention_norm,
             memory_size=c.memory_size)
     elif c.model.lower() == "tacotron2":
         model = MyModel(
             num_chars=num_chars,
             r=c.r,
+            attn_win=c.windowing,
             attn_norm=c.attention_norm,
             prenet_type=c.prenet_type,
             forward_attn=c.use_forward_attn,
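
In the second file, setup_model now forwards a `windowing` flag from the config as `attn_win` to both model constructors. A hypothetical illustration of the mapping between config fields and the constructor arguments touched here; only the fields visible in this diff are listed, the real config contains more entries, and the values shown are made up.

from types import SimpleNamespace

# Hypothetical config stub; field names follow the diff, values are illustrative.
c = SimpleNamespace(
    model="Tacotron2",
    r=5,                       # passed as r=c.r
    windowing=False,           # passed as attn_win=c.windowing (eval-time attention window)
    attention_norm="sigmoid",  # passed as attn_norm=c.attention_norm
    prenet_type="original",    # passed as prenet_type=c.prenet_type (Tacotron2 branch)
    use_forward_attn=True,     # passed as forward_attn=c.use_forward_attn
    memory_size=5,             # passed as memory_size=c.memory_size (Tacotron branch)
)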