bug fixes on train.py

pull/10/head
Eren Golge 2018-12-17 16:37:06 +01:00
parent 2a4adf0c33
commit 8ff9253abd
1 changed files with 9 additions and 7 deletions

View File

@ -43,7 +43,9 @@ def setup_loader(is_val=False):
preprocessor=preprocessor,
ap=ap,
batch_group_size=0 if is_val else 8 * c.batch_size,
min_seq_len=0 if is_val else c.min_seq_len)
min_seq_len=0 if is_val else c.min_seq_len,
max_seq_len=float("inf") if is_val else c.max_seq_len
cached=False if c.dataset ~= "tts_cache" else True)
loader = DataLoader(
dataset,
batch_size=c.eval_batch_size if is_val else c.batch_size,
@ -164,8 +166,8 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st,
grad_norm, grad_norm_st, avg_text_length, avg_spec_length, step_time, current_lr),
flush=True)
avg_linear_loss += linear_loss.item()
avg_mel_loss += mel_loss.item()
avg_linear_loss += float(linear_loss.item())
avg_mel_loss += float(mel_loss.item())
avg_stop_loss += stop_loss.item()
avg_step_time += step_time
@ -198,7 +200,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st,
# Sample audio
tb_logger.tb_train_audios(current_step,
{'TrainAudio': ap.inv_spectrogram(const_spec.T)},
c.sample_rate)
c.audio["sample_rate"])
avg_linear_loss /= (num_iter + 1)
avg_mel_loss /= (num_iter + 1)
@ -295,8 +297,8 @@ def evaluate(model, criterion, criterion_st, ap, current_step):
stop_loss.item()),
flush=True)
avg_linear_loss += linear_loss.item()
avg_mel_loss += mel_loss.item()
avg_linear_loss += float(linear_loss.item())
avg_mel_loss += float(mel_loss.item())
avg_stop_loss += stop_loss.item()
# Diagnostic visualizations
@ -442,7 +444,7 @@ if __name__ == '__main__':
default=False,
help='Do not verify commit integrity to run training.')
parser.add_argument(
'--data_path', type=str, help='dataset path.', default='Defines the data path. It overwrites config.json.')
'--data_path', type=str, default='', default='Defines the data path. It overwrites config.json.')
args = parser.parse_args()
# setup output paths and read configs