mirror of https://github.com/coqui-ai/TTS.git
bug fixes on train.py
parent
2a4adf0c33
commit
8ff9253abd
16
train.py
16
train.py
|
@ -43,7 +43,9 @@ def setup_loader(is_val=False):
|
|||
preprocessor=preprocessor,
|
||||
ap=ap,
|
||||
batch_group_size=0 if is_val else 8 * c.batch_size,
|
||||
min_seq_len=0 if is_val else c.min_seq_len)
|
||||
min_seq_len=0 if is_val else c.min_seq_len,
|
||||
max_seq_len=float("inf") if is_val else c.max_seq_len
|
||||
cached=False if c.dataset ~= "tts_cache" else True)
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=c.eval_batch_size if is_val else c.batch_size,
|
||||
|
@ -164,8 +166,8 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st,
|
|||
grad_norm, grad_norm_st, avg_text_length, avg_spec_length, step_time, current_lr),
|
||||
flush=True)
|
||||
|
||||
avg_linear_loss += linear_loss.item()
|
||||
avg_mel_loss += mel_loss.item()
|
||||
avg_linear_loss += float(linear_loss.item())
|
||||
avg_mel_loss += float(mel_loss.item())
|
||||
avg_stop_loss += stop_loss.item()
|
||||
avg_step_time += step_time
|
||||
|
||||
|
@ -198,7 +200,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st,
|
|||
# Sample audio
|
||||
tb_logger.tb_train_audios(current_step,
|
||||
{'TrainAudio': ap.inv_spectrogram(const_spec.T)},
|
||||
c.sample_rate)
|
||||
c.audio["sample_rate"])
|
||||
|
||||
avg_linear_loss /= (num_iter + 1)
|
||||
avg_mel_loss /= (num_iter + 1)
|
||||
|
@ -295,8 +297,8 @@ def evaluate(model, criterion, criterion_st, ap, current_step):
|
|||
stop_loss.item()),
|
||||
flush=True)
|
||||
|
||||
avg_linear_loss += linear_loss.item()
|
||||
avg_mel_loss += mel_loss.item()
|
||||
avg_linear_loss += float(linear_loss.item())
|
||||
avg_mel_loss += float(mel_loss.item())
|
||||
avg_stop_loss += stop_loss.item()
|
||||
|
||||
# Diagnostic visualizations
|
||||
|
@ -442,7 +444,7 @@ if __name__ == '__main__':
|
|||
default=False,
|
||||
help='Do not verify commit integrity to run training.')
|
||||
parser.add_argument(
|
||||
'--data_path', type=str, help='dataset path.', default='Defines the data path. It overwrites config.json.')
|
||||
'--data_path', type=str, default='', default='Defines the data path. It overwrites config.json.')
|
||||
args = parser.parse_args()
|
||||
|
||||
# setup output paths and read configs
|
||||
|
|
Loading…
Reference in New Issue