mirror of https://github.com/coqui-ai/TTS.git

commit e1cd253d65
parent 2e361e2306

change stop conditioning
@@ -1,6 +1,6 @@
 {
     "run_name": "bos",
-    "run_description": "bos character added to get away with the first char miss",
+    "run_description": "finetune entropy model due to some spelling mistakes.",
 
     "audio":{
         // Audio processing parameters
@@ -41,7 +41,7 @@
     "memory_size": 5,  // TO BE IMPLEMENTED -- memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
     "attention_norm": "softmax",  // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
 
-    "batch_size": 16,  // Batch size for training. Lower values than 32 might cause hard to learn attention.
+    "batch_size": 32,  // Batch size for training. Lower values than 32 might cause hard to learn attention.
     "eval_batch_size":16,
     "r": 1,  // Number of frames to predict for step.
     "wd": 0.000001,  // Weight decay weight.
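For reference, this config is JSON extended with // comments, which a stock json.loads rejects. A minimal sketch of loading such a file (the comment-stripping regex and the "config.json" path are assumptions for illustration; the repo ships its own loader):

import json
import re

def load_commented_json(path):
    # Naive // comment stripper: assumes no "//" occurs inside string values.
    with open(path, "r") as f:
        text = re.sub(r"//[^\n]*", "", f.read())
    return json.loads(text)

cfg = load_commented_json("config.json")  # path is an assumption
# batch_size was raised to 32 above; the config's own comment warns that
# smaller batches can make attention hard to learn.
print(cfg["batch_size"], cfg["r"], cfg["wd"])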
@@ -125,8 +125,8 @@ class Attention(nn.Module):
         self._mask_value = -float("inf")
         self.windowing = windowing
         if self.windowing:
-            self.win_back = 3
-            self.win_front = 6
+            self.win_back = 1
+            self.win_front = 3
         self.win_idx = None
         self.norm = norm
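The win_back/win_front change above tightens the attention window used at inference. A minimal PyTorch sketch of that windowing idea, assuming energies is a [batch, seq_len] score tensor and win_idx is the previously attended index (the function itself is illustrative; in the repo this state lives on the Attention module):

import torch

def apply_attention_window(energies, win_idx, win_back=1, win_front=3):
    # Mask scores outside [win_idx - win_back, win_idx + win_front),
    # mirroring the -inf mask value set in Attention.__init__ above.
    back = win_idx - win_back
    front = win_idx + win_front
    if back > 0:
        energies[:, :back] = -float("inf")
    if front < energies.shape[1]:
        energies[:, front:] = -float("inf")
    # The next window is centered on the new argmax.
    new_win_idx = int(torch.argmax(energies, dim=1)[0])
    return energies, new_win_idx

# Shrinking (win_back, win_front) from (3, 6) to (1, 3), as in this commit,
# limits how far the alignment may jump between consecutive decoder steps.
scores = torch.randn(1, 50)
scores, win_idx = apply_attention_window(scores, win_idx=10)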
@@ -394,7 +394,8 @@ class Decoder(nn.Module):
 
         self.attention_layer.init_win_idx()
         outputs, stop_tokens, alignments, t = [], [], [], 0
-        stop_flags = [False, False, False]
+        stop_flags = [True, False, False]
+        stop_count = 0
         while True:
             memory = self.prenet(memory)
             mel_output, stop_token, alignment = self.decode(memory)
@@ -404,10 +405,12 @@ class Decoder(nn.Module):
             alignments += [alignment]
 
             stop_flags[0] = stop_flags[0] or stop_token > 0.5
-            stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.5 and t > inputs.shape[1])
-            stop_flags[2] = t > inputs.shape[1]
+            stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.8 and t > inputs.shape[1])
+            stop_flags[2] = t > inputs.shape[1] * 2
             if all(stop_flags):
-                break
+                stop_count += 1
+                if stop_count > 10:
+                    break
             elif len(outputs) == self.max_decoder_steps:
                 print("   | > Decoder stopped with 'max_decoder_steps")
                 break
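Read together, the two decoder hunks replace "break as soon as all three stop flags hold" with "only break after the combined condition has held for more than 10 steps", raise the alignment-mass threshold from 0.5 to 0.8, and double the step cap to twice the input length; requiring persistence guards against a single noisy frame ending synthesis early. A stripped-down sketch of the resulting inference loop (the dummy stop_token/alignment values and the max_decoder_steps value are stand-ins; only the flag logic mirrors the diff):

import torch

def update_stop_flags(stop_flags, stop_token, alignment, t, seq_len):
    # Flag 0 is pre-set to True, so the stop-token net alone no longer gates
    # stopping; flag 1 needs attention mass on the last two input positions,
    # flag 2 needs t to pass twice the input length.
    stop_flags[0] = stop_flags[0] or bool(stop_token > 0.5)
    stop_flags[1] = stop_flags[1] or (bool(alignment[0, -2:].sum() > 0.8) and t > seq_len)
    stop_flags[2] = t > seq_len * 2
    return stop_flags

seq_len, max_decoder_steps = 5, 100   # stand-in sizes
stop_flags = [True, False, False]
stop_count, t, outputs = 0, 0, []
while True:
    stop_token = torch.rand(())                            # dummy stop probability
    alignment = torch.softmax(torch.randn(1, seq_len), 1)  # dummy alignment row
    outputs.append(t)
    stop_flags = update_stop_flags(stop_flags, stop_token, alignment, t, seq_len)
    if all(stop_flags):
        stop_count += 1        # condition must persist for 10 more steps
        if stop_count > 10:
            break
    elif len(outputs) == max_decoder_steps:
        print("   | > Decoder stopped with 'max_decoder_steps")
        break
    t += 1
print(f"stopped after {len(outputs)} decoder steps")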