mirror of https://github.com/coqui-ai/TTS.git
small argument fix
parent
3fa185e63d
commit
1d40909ef7
26
train.py
26
train.py
|
@ -81,7 +81,7 @@ def format_data(data):
|
|||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
speaker_names = data[2]
|
||||
linear_input = data[3] if c.model in ["Tacotron", "TacotronGST"] else None
|
||||
linear_input = data[3] if c.model in ["Tacotron"] else None
|
||||
mel_input = data[4]
|
||||
mel_lengths = data[5]
|
||||
stop_targets = data[6]
|
||||
|
@ -96,7 +96,7 @@ def format_data(data):
|
|||
else:
|
||||
speaker_ids = None
|
||||
|
||||
# set stop targets view, we predict a single stop token per r frames prediction
|
||||
# set stop targets view, we predict a single stop token per iteration.
|
||||
stop_targets = stop_targets.view(text_input.shape[0],
|
||||
stop_targets.size(1) // c.r, -1)
|
||||
stop_targets = (stop_targets.sum(2) >
|
||||
|
@ -108,7 +108,7 @@ def format_data(data):
|
|||
text_lengths = text_lengths.cuda(non_blocking=True)
|
||||
mel_input = mel_input.cuda(non_blocking=True)
|
||||
mel_lengths = mel_lengths.cuda(non_blocking=True)
|
||||
linear_input = linear_input.cuda(non_blocking=True) if c.model in ["Tacotron", "TacotronGST"] else None
|
||||
linear_input = linear_input.cuda(non_blocking=True) if c.model in ["Tacotron"] else None
|
||||
stop_targets = stop_targets.cuda(non_blocking=True)
|
||||
if speaker_ids is not None:
|
||||
speaker_ids = speaker_ids.cuda(non_blocking=True)
|
||||
|
@ -171,7 +171,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
|||
stop_targets) if c.stopnet else torch.zeros(1)
|
||||
if c.loss_masking:
|
||||
decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
|
||||
if c.model in ["Tacotron", "TacotronGST"]:
|
||||
if c.model in ["Tacotron"]:
|
||||
postnet_loss = criterion(postnet_output, linear_input,
|
||||
mel_lengths)
|
||||
else:
|
||||
|
@ -179,7 +179,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
|||
mel_lengths)
|
||||
else:
|
||||
decoder_loss = criterion(decoder_output, mel_input)
|
||||
if c.model in ["Tacotron", "TacotronGST"]:
|
||||
if c.model in ["Tacotron"]:
|
||||
postnet_loss = criterion(postnet_output, linear_input)
|
||||
else:
|
||||
postnet_loss = criterion(postnet_output, mel_input)
|
||||
|
@ -277,7 +277,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
|||
# Diagnostic visualizations
|
||||
const_spec = postnet_output[0].data.cpu().numpy()
|
||||
gt_spec = linear_input[0].data.cpu().numpy() if c.model in [
|
||||
"Tacotron", "TacotronGST"
|
||||
"Tacotron"
|
||||
] else mel_input[0].data.cpu().numpy()
|
||||
align_img = alignments[0].data.cpu().numpy()
|
||||
|
||||
|
@ -293,7 +293,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
|||
tb_logger.tb_train_figures(global_step, figures)
|
||||
|
||||
# Sample audio
|
||||
if c.model in ["Tacotron", "TacotronGST"]:
|
||||
if c.model in ["Tacotron"]:
|
||||
train_audio = ap.inv_spectrogram(const_spec.T)
|
||||
else:
|
||||
train_audio = ap.inv_mel_spectrogram(const_spec.T)
|
||||
|
@ -370,7 +370,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
|||
if c.loss_masking:
|
||||
decoder_loss = criterion(decoder_output, mel_input,
|
||||
mel_lengths)
|
||||
if c.model in ["Tacotron", "TacotronGST"]:
|
||||
if c.model in ["Tacotron"]:
|
||||
postnet_loss = criterion(postnet_output, linear_input,
|
||||
mel_lengths)
|
||||
else:
|
||||
|
@ -378,7 +378,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
|||
mel_lengths)
|
||||
else:
|
||||
decoder_loss = criterion(decoder_output, mel_input)
|
||||
if c.model in ["Tacotron", "TacotronGST"]:
|
||||
if c.model in ["Tacotron"]:
|
||||
postnet_loss = criterion(postnet_output, linear_input)
|
||||
else:
|
||||
postnet_loss = criterion(postnet_output, mel_input)
|
||||
|
@ -434,7 +434,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
|||
idx = np.random.randint(mel_input.shape[0])
|
||||
const_spec = postnet_output[idx].data.cpu().numpy()
|
||||
gt_spec = linear_input[idx].data.cpu().numpy() if c.model in [
|
||||
"Tacotron", "TacotronGST"
|
||||
"Tacotron"
|
||||
] else mel_input[idx].data.cpu().numpy()
|
||||
align_img = alignments[idx].data.cpu().numpy()
|
||||
|
||||
|
@ -445,7 +445,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
|||
}
|
||||
|
||||
# Sample audio
|
||||
if c.model in ["Tacotron", "TacotronGST"]:
|
||||
if c.model in ["Tacotron"]:
|
||||
eval_audio = ap.inv_spectrogram(const_spec.T)
|
||||
else:
|
||||
eval_audio = ap.inv_mel_spectrogram(const_spec.T)
|
||||
|
@ -562,10 +562,10 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
optimizer_st = None
|
||||
|
||||
if c.loss_masking:
|
||||
criterion = L1LossMasked() if c.model in ["Tacotron", "TacotronGST"
|
||||
criterion = L1LossMasked() if c.model in ["Tacotron"
|
||||
] else MSELossMasked()
|
||||
else:
|
||||
criterion = nn.L1Loss() if c.model in ["Tacotron", "TacotronGST"
|
||||
criterion = nn.L1Loss() if c.model in ["Tacotron"
|
||||
] else nn.MSELoss()
|
||||
criterion_st = nn.BCEWithLogitsLoss(
|
||||
pos_weight=torch.tensor(10)) if c.stopnet else None
|
||||
|
|
Loading…
Reference in New Issue