mirror of https://github.com/coqui-ai/TTS.git

Changes to windowing and some comments

parent 455667d2a4
commit 3c2d500f53
@@ -143,8 +143,8 @@ class Attention(nn.Module):
     def init_win_idx(self):
         self.win_idx = -1
-        self.win_back = 1
-        self.win_front = 3
+        self.win_back = 2
+        self.win_front = 6

     def init_forward_attn_state(self, inputs):
         """
@@ -165,7 +165,7 @@ class Attention(nn.Module):
         energies = energies.squeeze(-1)
         return energies, processed_query

-    def apply_windowing(self, attention):
+    def apply_windowing(self, attention, inputs):
         back_win = self.win_idx - self.win_back
         front_win = self.win_idx + self.win_front
         if back_win > 0:
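For orientation: with the values above, the window now keeps win_back=2 positions behind and win_front=6 positions ahead of the last attended input index, and apply_windowing now also receives inputs (presumably for its time dimension). A minimal, self-contained sketch of this style of windowing, not the repo's exact apply_windowing:

import torch

def window_attention_energies(energies, win_idx, win_back=2, win_front=6):
    # energies: [B, T_in] unnormalized attention scores for one decoder step.
    # Scores outside [win_idx - win_back, win_idx + win_front) are pushed to -inf
    # so the subsequent softmax gives them (near) zero mass.
    t_in = energies.shape[1]
    back_win = win_idx - win_back
    front_win = win_idx + win_front
    if back_win > 0:
        energies[:, :back_win] = -float("inf")
    if front_win < t_in:
        energies[:, front_win:] = -float("inf")
    # Track the window center with the strongest position of the first batch item.
    new_win_idx = torch.argmax(energies, dim=1)[0].item()
    return energies, new_win_idx

# e.g. energies, win_idx = window_attention_energies(torch.randn(1, 80), win_idx=10)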
@@ -199,10 +199,13 @@ class Attention(nn.Module):
         attention, processed_query = self.get_attention(
             attention_hidden_state, processed_inputs, attention_cat)

+        # apply masking
         if mask is not None:
             attention.data.masked_fill_(1 - mask, self._mask_value)
+        # apply windowing - only in eval mode
         if not self.training and self.windowing:
-            attention = self.apply_windowing(attention)
+            attention = self.apply_windowing(attention, inputs)
+        # normalize attention values
         if self.norm == "softmax":
             alignment = torch.softmax(attention, dim=-1)
         elif self.norm == "sigmoid":
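The three comments added in this hunk mark a fixed order inside the attention forward pass: pad masking, eval-only windowing, then normalization. A condensed, self-contained sketch of that ordering (helper and argument names here are illustrative, and a boolean pad mask is assumed instead of the 1 - mask arithmetic above):

import torch

def normalize_attention(energies, pad_mask=None, norm="softmax",
                        mask_value=-float("inf")):
    # 1) apply masking: padded input positions can never be attended.
    if pad_mask is not None:
        energies = energies.masked_fill(~pad_mask, mask_value)
    # 2) apply windowing would happen here, only in eval mode (see the sketch above).
    # 3) normalize attention values.
    if norm == "softmax":
        alignment = torch.softmax(energies, dim=-1)
    elif norm == "sigmoid":
        sig = torch.sigmoid(energies)
        alignment = sig / sig.sum(dim=1, keepdim=True)
    else:
        raise RuntimeError("Unknown value for attention norm type")
    return alignment

# e.g. normalize_attention(torch.randn(2, 80), torch.ones(2, 80, dtype=torch.bool))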
@@ -210,6 +213,7 @@ class Attention(nn.Module):
                 attention).sum(dim=1).unsqueeze(1)
         else:
             raise RuntimeError("Unknown value for attention norm type")
+        # apply forward attention if enabled
         if self.forward_attn:
             return self.apply_forward_attention(inputs, alignment, processed_query)
         else:
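apply_forward_attention itself is outside this diff. For context, a minimal sketch of the forward-attention recursion (Zhang et al., 2018) that such a method typically implements: each input position is reachable only from itself or from the previous position, which nudges the alignment to move monotonically. The repo's implementation (e.g. any transition agent) may differ in detail:

import torch
import torch.nn.functional as F

def forward_attention_step(alignment, prev_alpha, eps=1e-8):
    # alignment:  [B, T_in] normalized attention for the current decoder step.
    # prev_alpha: [B, T_in] forward variable from the previous step.
    shifted = F.pad(prev_alpha[:, :-1], (1, 0))        # alpha_{t-1}(n - 1)
    alpha = (prev_alpha + shifted + eps) * alignment   # stay or advance one step
    alpha = alpha / alpha.sum(dim=1, keepdim=True)     # renormalize to a distribution
    return alpha

# Typical initialization: alpha_0 is a one-hot vector on the first input position.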
@@ -456,7 +460,7 @@ class Decoder(nn.Module):
             stop_flags[2] = t > inputs.shape[1] * 2
             if all(stop_flags):
                 stop_count += 1
-                if stop_count > 20:
+                if stop_count > 2:
                     break
             elif len(outputs) == self.max_decoder_steps:
                 print(" | > Decoder stopped with 'max_decoder_steps")
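Lowering the patience from 20 to 2 means decoding now breaks on the third step at which all three stop flags hold (stop token fired, plus the step/position checks such as t > inputs.shape[1] * 2). A small illustration of that counter logic, with hypothetical names:

def update_stop(stop_flags, stop_count, patience=2):
    # stop_flags: the three booleans tracked in the decoding loop.
    # Returns the updated counter and whether the loop should break.
    if all(stop_flags):
        stop_count += 1
    return stop_count, stop_count > patience

# e.g. update_stop([True, True, True], stop_count=2)  # -> (3, True)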
@@ -481,16 +485,17 @@ class Decoder(nn.Module):
             self._init_states(inputs, mask=None, keep_states=True)

         self.attention_layer.init_win_idx()
-        self.attention_layer.init_forward_attn_state()
-        outputs, gate_outputs, alignments, t = [], [], [], 0
+        if self.attention_layer.forward_attn:
+            self.attention_layer.init_forward_attn_state(inputs)
+        outputs, stop_tokens, alignments, t = [], [], [], 0
         stop_flags = [False, False, False]
         stop_count = 0
         while True:
             memory = self.prenet(self.memory_truncated)
-            mel_output, gate_output, alignment = self.decode(memory)
-            gate_output = torch.sigmoid(gate_output.data)
+            mel_output, stop_token, alignment = self.decode(memory)
+            stop_token = torch.sigmoid(stop_token.data)
             outputs += [mel_output.squeeze(1)]
-            gate_outputs += [gate_output]
+            stop_tokens += [stop_token]
             alignments += [alignment]

             stop_flags[0] = stop_flags[0] or stop_token > 0.5
@@ -498,7 +503,7 @@ class Decoder(nn.Module):
             stop_flags[2] = t > inputs.shape[1] * 2
             if all(stop_flags):
                 stop_count += 1
-                if stop_count > 20:
+                if stop_count > 2:
                     break
             elif len(outputs) == self.max_decoder_steps:
                 print(" | > Decoder stopped with 'max_decoder_steps")
@@ -507,10 +512,10 @@ class Decoder(nn.Module):
             self.memory_truncated = mel_output
             t += 1

-        outputs, gate_outputs, alignments = self._parse_outputs(
-            outputs, gate_outputs, alignments)
+        outputs, stop_tokens, alignments = self._parse_outputs(
+            outputs, stop_tokens, alignments)

-        return outputs, gate_outputs, alignments
+        return outputs, stop_tokens, alignments


     def inference_step(self, inputs, t, memory=None):
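_parse_outputs is untouched here apart from the renamed arguments; it typically just restacks the per-step Python lists into batch-first tensors. A hedged sketch of that behavior (hypothetical helper, not the repo's code), mainly to show what sits behind outputs, stop_tokens, alignments:

import torch

def parse_outputs(outputs, stop_tokens, alignments):
    # Each argument is a list of per-decoder-step tensors; stack them along a new
    # time axis and move it after the batch axis (time-major -> batch-major).
    outputs = torch.stack(outputs).transpose(0, 1).contiguous()
    stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
    alignments = torch.stack(alignments).transpose(0, 1)
    return outputs, stop_tokens, alignments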
@@ -252,12 +252,14 @@ def setup_model(num_chars, c):
         model = MyModel(
             num_chars=num_chars,
             r=c.r,
+            attn_win=c.windowing,
             attn_norm=c.attention_norm,
             memory_size=c.memory_size)
     elif c.model.lower() == "tacotron2":
         model = MyModel(
             num_chars=num_chars,
             r=c.r,
+            attn_win=c.windowing,
             attn_norm=c.attention_norm,
             prenet_type=c.prenet_type,
             forward_attn=c.use_forward_attn,
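attn_win is wired to c.windowing, so the training config has to provide that key alongside the other attention settings read here. A self-contained illustration of just the fields referenced in this hunk, using a simple namespace (values are placeholders, not recommended settings):

from types import SimpleNamespace

c = SimpleNamespace(
    model="Tacotron2",
    r=5,                        # decoder frames per step
    windowing=False,            # -> attn_win: eval-time attention windowing
    attention_norm="sigmoid",   # -> attn_norm: "softmax" or "sigmoid"
    prenet_type="original",     # Tacotron2 branch only
    use_forward_attn=True,      # -> forward_attn
    memory_size=5,              # Tacotron branch only
)

print(c.model.lower() == "tacotron2")  # selects the second branch in setup_model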