change stop conditioning

pull/10/head
Eren Golge 2019-03-31 16:44:17 +02:00
parent 2e361e2306
commit e1cd253d65
2 changed files with 11 additions and 8 deletions

View File

@ -1,6 +1,6 @@
{
"run_name": "bos",
"run_description": "bos character added to get away with the first char miss",
"run_description": "finetune entropy model due to some spelling mistakes.",
"audio":{
// Audio processing parameters
@ -41,7 +41,7 @@
"memory_size": 5, // TO BE IMPLEMENTED -- memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
"attention_norm": "softmax", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
"batch_size": 16, // Batch size for training. Values lower than 32 might make attention hard to learn.
"batch_size": 32, // Batch size for training. Values lower than 32 might make attention hard to learn.
"eval_batch_size":16,
"r": 1, // Number of frames to predict per step.
"wd": 0.000001, // Weight decay weight.

View File

@ -125,8 +125,8 @@ class Attention(nn.Module):
self._mask_value = -float("inf")
self.windowing = windowing
if self.windowing:
self.win_back = 3
self.win_front = 6
self.win_back = 1
self.win_front = 3
self.win_idx = None
self.norm = norm
@ -394,7 +394,8 @@ class Decoder(nn.Module):
self.attention_layer.init_win_idx()
outputs, stop_tokens, alignments, t = [], [], [], 0
stop_flags = [False, False, False]
stop_flags = [True, False, False]
stop_count = 0
while True:
memory = self.prenet(memory)
mel_output, stop_token, alignment = self.decode(memory)
@ -404,10 +405,12 @@ class Decoder(nn.Module):
alignments += [alignment]
stop_flags[0] = stop_flags[0] or stop_token > 0.5
stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.5 and t > inputs.shape[1])
stop_flags[2] = t > inputs.shape[1]
stop_flags[1] = stop_flags[1] or (alignment[0, -2:].sum() > 0.8 and t > inputs.shape[1])
stop_flags[2] = t > inputs.shape[1] * 2
if all(stop_flags):
break
stop_count += 1
if stop_count > 10:
break
elif len(outputs) == self.max_decoder_steps:
print(" | > Decoder stopped with 'max_decoder_steps")
break