small argument fix

pull/10/head
Eren Golge 2019-11-19 13:07:06 +01:00
parent 3fa185e63d
commit 1d40909ef7
1 changed files with 13 additions and 13 deletions

View File

@ -81,7 +81,7 @@ def format_data(data):
text_input = data[0]
text_lengths = data[1]
speaker_names = data[2]
linear_input = data[3] if c.model in ["Tacotron", "TacotronGST"] else None
linear_input = data[3] if c.model in ["Tacotron"] else None
mel_input = data[4]
mel_lengths = data[5]
stop_targets = data[6]
@ -96,7 +96,7 @@ def format_data(data):
else:
speaker_ids = None
# set stop targets view, we predict a single stop token per r frames prediction
# set stop targets view, we predict a single stop token per iteration.
stop_targets = stop_targets.view(text_input.shape[0],
stop_targets.size(1) // c.r, -1)
stop_targets = (stop_targets.sum(2) >
@ -108,7 +108,7 @@ def format_data(data):
text_lengths = text_lengths.cuda(non_blocking=True)
mel_input = mel_input.cuda(non_blocking=True)
mel_lengths = mel_lengths.cuda(non_blocking=True)
linear_input = linear_input.cuda(non_blocking=True) if c.model in ["Tacotron", "TacotronGST"] else None
linear_input = linear_input.cuda(non_blocking=True) if c.model in ["Tacotron"] else None
stop_targets = stop_targets.cuda(non_blocking=True)
if speaker_ids is not None:
speaker_ids = speaker_ids.cuda(non_blocking=True)
@ -171,7 +171,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
stop_targets) if c.stopnet else torch.zeros(1)
if c.loss_masking:
decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
if c.model in ["Tacotron", "TacotronGST"]:
if c.model in ["Tacotron"]:
postnet_loss = criterion(postnet_output, linear_input,
mel_lengths)
else:
@ -179,7 +179,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
mel_lengths)
else:
decoder_loss = criterion(decoder_output, mel_input)
if c.model in ["Tacotron", "TacotronGST"]:
if c.model in ["Tacotron"]:
postnet_loss = criterion(postnet_output, linear_input)
else:
postnet_loss = criterion(postnet_output, mel_input)
@ -277,7 +277,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
# Diagnostic visualizations
const_spec = postnet_output[0].data.cpu().numpy()
gt_spec = linear_input[0].data.cpu().numpy() if c.model in [
"Tacotron", "TacotronGST"
"Tacotron"
] else mel_input[0].data.cpu().numpy()
align_img = alignments[0].data.cpu().numpy()
@ -293,7 +293,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
tb_logger.tb_train_figures(global_step, figures)
# Sample audio
if c.model in ["Tacotron", "TacotronGST"]:
if c.model in ["Tacotron"]:
train_audio = ap.inv_spectrogram(const_spec.T)
else:
train_audio = ap.inv_mel_spectrogram(const_spec.T)
@ -370,7 +370,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
if c.loss_masking:
decoder_loss = criterion(decoder_output, mel_input,
mel_lengths)
if c.model in ["Tacotron", "TacotronGST"]:
if c.model in ["Tacotron"]:
postnet_loss = criterion(postnet_output, linear_input,
mel_lengths)
else:
@ -378,7 +378,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
mel_lengths)
else:
decoder_loss = criterion(decoder_output, mel_input)
if c.model in ["Tacotron", "TacotronGST"]:
if c.model in ["Tacotron"]:
postnet_loss = criterion(postnet_output, linear_input)
else:
postnet_loss = criterion(postnet_output, mel_input)
@ -434,7 +434,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
idx = np.random.randint(mel_input.shape[0])
const_spec = postnet_output[idx].data.cpu().numpy()
gt_spec = linear_input[idx].data.cpu().numpy() if c.model in [
"Tacotron", "TacotronGST"
"Tacotron"
] else mel_input[idx].data.cpu().numpy()
align_img = alignments[idx].data.cpu().numpy()
@ -445,7 +445,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
}
# Sample audio
if c.model in ["Tacotron", "TacotronGST"]:
if c.model in ["Tacotron"]:
eval_audio = ap.inv_spectrogram(const_spec.T)
else:
eval_audio = ap.inv_mel_spectrogram(const_spec.T)
@ -562,10 +562,10 @@ def main(args): # pylint: disable=redefined-outer-name
optimizer_st = None
if c.loss_masking:
criterion = L1LossMasked() if c.model in ["Tacotron", "TacotronGST"
criterion = L1LossMasked() if c.model in ["Tacotron"
] else MSELossMasked()
else:
criterion = nn.L1Loss() if c.model in ["Tacotron", "TacotronGST"
criterion = nn.L1Loss() if c.model in ["Tacotron"
] else nn.MSELoss()
criterion_st = nn.BCEWithLogitsLoss(
pos_weight=torch.tensor(10)) if c.stopnet else None