bug fixes, linter update and test updates

pull/10/head
Eren Golge 2019-10-29 14:28:49 +01:00
parent 89ef71ead8
commit 002991ca15
8 changed files with 33 additions and 32 deletions

View File

@ -103,8 +103,8 @@ class CBHG(nn.Module):
num_highways (int): number of highways layers
Shapes:
- input: batch x time x dim
- output: batch x time x dim*2
- input: B x D x T_in
- output: B x T_in x D*2
"""
def __init__(self,

View File

@ -100,7 +100,7 @@ class Decoder(nn.Module):
#pylint: disable=attribute-defined-outside-init
def __init__(self, in_features, memory_dim, r, attn_win, attn_norm,
prenet_type, prenet_dropout, forward_attn, trans_agent,
forward_attn_mask, location_attn, separate_stopnet,
forward_attn_mask, location_attn, separate_stopnet,
speaker_embedding_dim):
super(Decoder, self).__init__()
self.memory_dim = memory_dim
@ -117,7 +117,7 @@ class Decoder(nn.Module):
self.p_decoder_dropout = 0.1
# memory -> |Prenet| -> processed_memory
prenet_dim = self.memory_dim + speaker_embedding_dim
prenet_dim = self.memory_dim
self.prenet = Prenet(
prenet_dim,
prenet_type,
@ -244,7 +244,10 @@ class Decoder(nn.Module):
memory = self.get_go_frame(inputs).unsqueeze(0)
memories = self._reshape_memory(memories)
memories = torch.cat((memory, memories), dim=0)
memories = self.prenet(self._update_memory(memories))
memories = self._update_memory(memories)
if speaker_embeddings is not None:
memories = torch.cat([memories, speaker_embeddings], dim=-1)
memories = self.prenet(memories)
self._init_states(inputs, mask=mask)
self.attention.init_states(inputs)
@ -252,8 +255,6 @@ class Decoder(nn.Module):
outputs, stop_tokens, alignments = [], [], []
while len(outputs) < memories.size(0) - 1:
memory = memories[len(outputs)]
if speaker_embeddings is not None:
memory = torch.cat([memory, speaker_embeddings], dim=-1)
mel_output, attention_weights, stop_token = self.decode(memory)
outputs += [mel_output.squeeze(1)]
stop_tokens += [stop_token.squeeze(1)]
@ -277,7 +278,7 @@ class Decoder(nn.Module):
while True:
memory = self.prenet(memory)
if speaker_embeddings is not None:
memory = torch.cat([memory, speaker_embeddings], dim=-1)
memory = torch.cat([memory, speaker_embeddings], dim=-1)
mel_output, alignment, stop_token = self.decode(memory)
stop_token = torch.sigmoid(stop_token.data)
outputs += [mel_output.squeeze(1)]

View File

@ -96,7 +96,6 @@ class Tacotron(nn.Module):
- speaker_ids: B x 1
"""
self._init_states()
B = characters.size(0)
mask = sequence_mask(text_lengths).to(characters.device)
# B x T_in x embed_dim
inputs = self.embedding(characters)
@ -132,14 +131,13 @@ class Tacotron(nn.Module):
return decoder_outputs, postnet_outputs, alignments, stop_tokens
def inference(self, characters, speaker_ids=None, style_mel=None):
B = characters.size(0)
inputs = self.embedding(characters)
self._init_states()
self.compute_speaker_embedding(speaker_ids)
if self.num_speakers > 1:
inputs = self._concat_speaker_embedding(inputs,
self.speaker_embeddings)
encoder_outputs = self.encoder(inputs)
encoder_outputs = self.encoder(inputs)
if self.gst and style_mel is not None:
encoder_outputs = self.compute_gst(encoder_outputs, style_mel)
if self.num_speakers > 1:

View File

@ -28,8 +28,8 @@ class Tacotron2(nn.Module):
self.decoder_output_dim = decoder_output_dim
self.n_frames_per_step = r
self.bidirectional_decoder = bidirectional_decoder
decoder_dim = 512 + 256 if num_speakers > 1 else 512
encoder_dim = 512 + 256 if num_speakers > 1 else 512
decoder_dim = 512 if num_speakers > 1 else 512
encoder_dim = 512 if num_speakers > 1 else 512
proj_speaker_dim = 80 if num_speakers > 1 else 0
# embedding layer
self.embedding = nn.Embedding(num_chars, 512)
@ -39,6 +39,8 @@ class Tacotron2(nn.Module):
if num_speakers > 1:
self.speaker_embedding = nn.Embedding(num_speakers, 512)
self.speaker_embedding.weight.data.normal_(0, 0.3)
self.speaker_embeddings = None
self.speaker_embeddings_projected = None
self.encoder = Encoder(encoder_dim)
self.decoder = Decoder(decoder_dim, self.decoder_output_dim, r, attn_win,
attn_norm, prenet_type, prenet_dropout,
@ -47,7 +49,7 @@ class Tacotron2(nn.Module):
if self.bidirectional_decoder:
self.decoder_backward = copy.deepcopy(self.decoder)
self.postnet = Postnet(self.decoder_output_dim)
def _init_states(self):
self.speaker_embeddings = None
self.speaker_embeddings_projected = None

View File

@ -44,6 +44,7 @@
"prenet_dropout": true, // ONLY TACOTRON2 - enable/disable dropout at prenet.
"use_forward_attn": true, // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.
"forward_attn_mask": false,
"bidirectional_decoder": false,
"transition_agent": false, // ONLY TACOTRON2 - enable/disable transition agent of forward attention.
"location_attn": false, // ONLY TACOTRON2 - enable_disable location sensitive attention. It is enabled for TACOTRON by default.
"loss_masking": true, // enable / disable loss masking against the sequence padding.

View File

@ -29,7 +29,8 @@ class CBHGTests(unittest.TestCase):
highway_features=80,
gru_features=80,
num_highways=4)
dummy_input = T.rand(4, 8, 128)
# B x D x T
dummy_input = T.rand(4, 128, 8)
print(layer)
output = layer(dummy_input)
@ -63,8 +64,8 @@ class DecoderTests(unittest.TestCase):
dummy_input, dummy_memory, mask=None)
assert output.shape[0] == 4
assert output.shape[1] == 1, "size not {}".format(output.shape[1])
assert output.shape[2] == 80 * 2, "size not {}".format(output.shape[2])
assert output.shape[1] == 80, "size not {}".format(output.shape[1])
assert output.shape[2] == 2, "size not {}".format(output.shape[2])
assert stop_tokens.shape[0] == 4
@staticmethod
@ -92,8 +93,8 @@ class DecoderTests(unittest.TestCase):
dummy_input, dummy_memory, mask=None, speaker_embeddings=dummy_embed)
assert output.shape[0] == 4
assert output.shape[1] == 1, "size not {}".format(output.shape[1])
assert output.shape[2] == 80 * 2, "size not {}".format(output.shape[2])
assert output.shape[1] == 80, "size not {}".format(output.shape[1])
assert output.shape[2] == 2, "size not {}".format(output.shape[2])
assert stop_tokens.shape[0] == 4

View File

@ -49,8 +49,8 @@ class TacotronTrainTest(unittest.TestCase):
model = Tacotron(
num_chars=32,
num_speakers=5,
linear_dim=c.audio['num_freq'],
mel_dim=c.audio['num_mels'],
postnet_output_dim=c.audio['num_freq'],
decoder_output_dim=c.audio['num_mels'],
r=c.r,
memory_size=c.memory_size
).to(device) #FIXME: missing num_speakers parameter to Tacotron ctor
@ -112,8 +112,8 @@ class TacotronGSTTrainTest(unittest.TestCase):
num_chars=32,
num_speakers=5,
gst=True,
linear_dim=c.audio['num_freq'],
mel_dim=c.audio['num_mels'],
postnet_output_dim=c.audio['num_freq'],
decoder_output_dim=c.audio['num_mels'],
r=c.r,
memory_size=c.memory_size
).to(device) #FIXME: missing num_speakers parameter to Tacotron ctor

View File

@ -80,8 +80,7 @@ def format_data(data):
text_input = data[0]
text_lengths = data[1]
speaker_names = data[2]
linear_input = data[3] if c.model in ["Tacotron", "TacotronGST"
] else None
linear_input = data[3] if c.model in ["Tacotron", "TacotronGST"] else None
mel_input = data[4]
mel_lengths = data[5]
stop_targets = data[6]
@ -98,7 +97,7 @@ def format_data(data):
# set stop targets view, we predict a single stop token per r frames prediction
stop_targets = stop_targets.view(text_input.shape[0],
stop_targets.size(1) // c.r, -1)
stop_targets.size(1) // c.r, -1)
stop_targets = (stop_targets.sum(2) >
0.0).unsqueeze(2).float().squeeze(2)
@ -108,9 +107,7 @@ def format_data(data):
text_lengths = text_lengths.cuda(non_blocking=True)
mel_input = mel_input.cuda(non_blocking=True)
mel_lengths = mel_lengths.cuda(non_blocking=True)
linear_input = linear_input.cuda(
non_blocking=True) if c.model in ["Tacotron", "TacotronGST"
] else None
linear_input = linear_input.cuda(non_blocking=True) if c.model in ["Tacotron", "TacotronGST"] else None
stop_targets = stop_targets.cuda(non_blocking=True)
if speaker_ids is not None:
speaker_ids = speaker_ids.cuda(non_blocking=True)
@ -352,8 +349,8 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
start_time = time.time()
# format data
text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, avg_text_length, avg_spec_length = format_data(data)
assert mel_input.shape[1] % model.decoder.r == 0
text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, _, _ = format_data(data)
assert mel_input.shape[1] % model.decoder.r == 0
# forward pass model
if c.bidirectional_decoder:
@ -622,7 +619,8 @@ def main(args): # pylint: disable=redefined-outer-name
r, c.batch_size = gradual_training_scheduler(global_step, c)
c.r = r
model.decoder.set_r(r)
if c.bidirectional_decoder: model.decoder_backward.set_r(r)
if c.bidirectional_decoder:
model.decoder_backward.set_r(r)
print(" > Number of outputs per iteration:", model.decoder.r)
train_loss, global_step = train(model, criterion, criterion_st,