generic train.py for multiple architectures set on config.json

pull/10/head
Eren Golge 2019-03-06 13:11:22 +01:00
parent a4474abd83
commit 08162157ee
2 changed files with 115 additions and 125 deletions

232
train.py
View File

@ -14,20 +14,19 @@ from torch import optim
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from datasets.TTSDataset import MyDataset from datasets.TTSDataset import MyDataset
from layers.losses import L1LossMasked from distribute import (DistributedSampler, apply_gradient_allreduce,
from models.tacotron import Tacotron init_distributed, reduce_tensor)
from layers.losses import L1LossMasked, MSELossMasked
from utils.audio import AudioProcessor from utils.audio import AudioProcessor
from utils.generic_utils import ( from utils.generic_utils import (NoamLR, check_update, count_parameters,
NoamLR, check_update, count_parameters, create_experiment_folder, create_experiment_folder, get_commit_hash,
get_commit_hash, load_config, lr_decay, remove_experiment_folder, load_config, lr_decay,
save_best_model, save_checkpoint, sequence_mask, weight_decay) remove_experiment_folder, save_best_model,
save_checkpoint, sequence_mask, weight_decay)
from utils.logger import Logger from utils.logger import Logger
from utils.synthesis import synthesis from utils.synthesis import synthesis
from utils.text.symbols import phonemes, symbols from utils.text.symbols import phonemes, symbols
from utils.visual import plot_alignment, plot_spectrogram from utils.visual import plot_alignment, plot_spectrogram
from distribute import init_distributed, apply_gradient_allreduce, reduce_tensor
from distribute import DistributedSampler
torch.backends.cudnn.enabled = True torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False torch.backends.cudnn.benchmark = False
@ -77,24 +76,19 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
data_loader = setup_loader(is_val=False, verbose=(epoch==0)) data_loader = setup_loader(is_val=False, verbose=(epoch==0))
model.train() model.train()
epoch_time = 0 epoch_time = 0
avg_linear_loss = 0 avg_postnet_loss = 0
avg_mel_loss = 0 avg_decoder_loss = 0
avg_stop_loss = 0 avg_stop_loss = 0
avg_step_time = 0 avg_step_time = 0
print("\n > Epoch {}/{}".format(epoch, c.epochs), flush=True) print("\n > Epoch {}/{}".format(epoch, c.epochs), flush=True)
n_priority_freq = int( batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
3000 / (c.audio['sample_rate'] * 0.5) * c.audio['num_freq'])
if num_gpus > 0:
batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
else:
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
for num_iter, data in enumerate(data_loader): for num_iter, data in enumerate(data_loader):
start_time = time.time() start_time = time.time()
# setup input data # setup input data
text_input = data[0] text_input = data[0]
text_lengths = data[1] text_lengths = data[1]
linear_input = data[2] linear_input = data[2] if c.model == "Tacotron" else None
mel_input = data[3] mel_input = data[3]
mel_lengths = data[4] mel_lengths = data[4]
stop_targets = data[5] stop_targets = data[5]
@ -104,7 +98,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
# set stop targets view, we predict a single stop token per r frames prediction # set stop targets view, we predict a single stop token per r frames prediction
stop_targets = stop_targets.view(text_input.shape[0], stop_targets = stop_targets.view(text_input.shape[0],
stop_targets.size(1) // c.r, -1) stop_targets.size(1) // c.r, -1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float() stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
current_step = num_iter + args.restore_step + \ current_step = num_iter + args.restore_step + \
epoch * len(data_loader) + 1 epoch * len(data_loader) + 1
@ -121,24 +115,21 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
text_lengths = text_lengths.cuda(non_blocking=True) text_lengths = text_lengths.cuda(non_blocking=True)
mel_input = mel_input.cuda(non_blocking=True) mel_input = mel_input.cuda(non_blocking=True)
mel_lengths = mel_lengths.cuda(non_blocking=True) mel_lengths = mel_lengths.cuda(non_blocking=True)
linear_input = linear_input.cuda(non_blocking=True) linear_input = linear_input.cuda(non_blocking=True) if c.model == "Tacotron" else None
stop_targets = stop_targets.cuda(non_blocking=True) stop_targets = stop_targets.cuda(non_blocking=True)
# compute mask for padding # forward pass model
mask = sequence_mask(text_lengths) decoder_output, postnet_output, alignments, stop_tokens = model(
text_input, text_lengths, mel_input)
# forward pass
mel_output, linear_output, alignments, stop_tokens = model(
text_input, mel_input, mask)
# loss computation # loss computation
stop_loss = criterion_st(stop_tokens, stop_targets) stop_loss = criterion_st(stop_tokens, stop_targets)
mel_loss = criterion(mel_output, mel_input, mel_lengths) decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
linear_loss = (1 - c.loss_weight) * criterion(linear_output, linear_input, mel_lengths)\ if c.model == "Tacotron":
+ c.loss_weight * criterion(linear_output[:, :, :n_priority_freq], postnet_loss = criterion(postnet_output, linear_input, mel_lengths)
linear_input[:, :, :n_priority_freq], else:
mel_lengths) postnet_loss = criterion(postnet_output, mel_input, mel_lengths)
loss = mel_loss + linear_loss loss = decoder_loss + postnet_loss
# backpass and check the grad norm for spec losses # backpass and check the grad norm for spec losses
loss.backward(retain_graph=True) loss.backward(retain_graph=True)
@ -157,49 +148,46 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
if current_step % c.print_step == 0: if current_step % c.print_step == 0:
print( print(
" | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} " " | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} PostnetLoss:{:.5f} "
"MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} " "DecoderLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
"GradNormST:{:.5f} AvgTextLen:{:.1f} AvgSpecLen:{:.1f} StepTime:{:.2f} LR:{:.6f}" "GradNormST:{:.5f} AvgTextLen:{:.1f} AvgSpecLen:{:.1f} StepTime:{:.2f} LR:{:.6f}".format(
.format(num_iter, batch_n_iter, current_step, loss.item(), num_iter, batch_n_iter, current_step, loss.item(),
linear_loss.item(), mel_loss.item(), stop_loss.item(), postnet_loss.item(), decoder_loss.item(), stop_loss.item(),
grad_norm, grad_norm_st, avg_text_length, grad_norm, grad_norm_st, avg_text_length, avg_spec_length, step_time, current_lr),
avg_spec_length, step_time, current_lr),
flush=True) flush=True)
# aggregate losses from processes # aggregate losses from processes
if num_gpus > 1: if num_gpus > 1:
linear_loss = reduce_tensor(linear_loss.data, num_gpus) postnet_loss = reduce_tensor(postnet_loss.data, num_gpus)
mel_loss = reduce_tensor(mel_loss.data, num_gpus) decoder_loss = reduce_tensor(decoder_loss.data, num_gpus)
loss = reduce_tensor(loss.data, num_gpus) loss = reduce_tensor(loss.data, num_gpus)
stop_loss = reduce_tensor(stop_loss.data, num_gpus) stop_loss = reduce_tensor(stop_loss.data, num_gpus)
if args.rank == 0: if args.rank == 0:
avg_linear_loss += float(linear_loss.item()) avg_postnet_loss += float(postnet_loss.item())
avg_mel_loss += float(mel_loss.item()) avg_decoder_loss += float(decoder_loss.item())
avg_stop_loss += stop_loss.item() avg_stop_loss += stop_loss.item()
avg_step_time += step_time avg_step_time += step_time
# Plot Training Iter Stats # Plot Training Iter Stats
iter_stats = { iter_stats = {"loss_posnet": postnet_loss.item(),
"loss_posnet": linear_loss.item(), "loss_decoder": decoder_loss.item(),
"loss_decoder": mel_loss.item(), "lr": current_lr,
"lr": current_lr, "grad_norm": grad_norm,
"grad_norm": grad_norm, "grad_norm_st": grad_norm_st,
"grad_norm_st": grad_norm_st, "step_time": step_time}
"step_time": step_time
}
tb_logger.tb_train_iter_stats(current_step, iter_stats) tb_logger.tb_train_iter_stats(current_step, iter_stats)
if current_step % c.save_step == 0: if current_step % c.save_step == 0:
if c.checkpoint: if c.checkpoint:
# save model # save model
save_checkpoint(model, optimizer, optimizer_st, save_checkpoint(model, optimizer, optimizer_st,
linear_loss.item(), OUT_PATH, current_step, postnet_loss.item(), OUT_PATH, current_step,
epoch) epoch)
# Diagnostic visualizations # Diagnostic visualizations
const_spec = linear_output[0].data.cpu().numpy() const_spec = postnet_output[0].data.cpu().numpy()
gt_spec = linear_input[0].data.cpu().numpy() gt_spec = mel_input[0].data.cpu().numpy()
align_img = alignments[0].data.cpu().numpy() align_img = alignments[0].data.cpu().numpy()
figures = { figures = {
@ -210,47 +198,49 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
tb_logger.tb_train_figures(current_step, figures) tb_logger.tb_train_figures(current_step, figures)
# Sample audio # Sample audio
tb_logger.tb_train_audios( if c.model == "Tacotron":
current_step, {'TrainAudio': ap.inv_spectrogram(const_spec.T)}, train_audio = ap.inv_spectrogram(const_spec.T)
c.audio["sample_rate"]) else:
train_audio = ap.inv_mel_spectrogram(const_spec.T)
tb_logger.tb_train_audios(current_step,
{'TrainAudio': train_audio},
c.audio["sample_rate"])
avg_linear_loss /= (num_iter + 1) avg_postnet_loss /= (num_iter + 1)
avg_mel_loss /= (num_iter + 1) avg_decoder_loss /= (num_iter + 1)
avg_stop_loss /= (num_iter + 1) avg_stop_loss /= (num_iter + 1)
avg_total_loss = avg_mel_loss + avg_linear_loss + avg_stop_loss avg_total_loss = avg_decoder_loss + avg_postnet_loss + avg_stop_loss
avg_step_time /= (num_iter + 1) avg_step_time /= (num_iter + 1)
# print epoch stats # print epoch stats
print( print(
" | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} " " | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
"AvgLinearLoss:{:.5f} AvgMelLoss:{:.5f} " "AvgPostnetLoss:{:.5f} AvgDecoderLoss:{:.5f} "
"AvgStopLoss:{:.5f} EpochTime:{:.2f} " "AvgStopLoss:{:.5f} EpochTime:{:.2f} "
"AvgStepTime:{:.2f}".format(current_step, avg_total_loss, "AvgStepTime:{:.2f}".format(current_step, avg_total_loss,
avg_linear_loss, avg_mel_loss, avg_postnet_loss, avg_decoder_loss,
avg_stop_loss, epoch_time, avg_step_time), avg_stop_loss, epoch_time, avg_step_time),
flush=True) flush=True)
# Plot Epoch Stats # Plot Epoch Stats
if args.rank == 0: if args.rank == 0:
# Plot Training Epoch Stats # Plot Training Epoch Stats
epoch_stats = { epoch_stats = {"loss_postnet": avg_postnet_loss,
"loss_postnet": avg_linear_loss, "loss_decoder": avg_decoder_loss,
"loss_decoder": avg_mel_loss, "stop_loss": avg_stop_loss,
"stop_loss": avg_stop_loss, "epoch_time": epoch_time}
"epoch_time": epoch_time
}
tb_logger.tb_train_epoch_stats(current_step, epoch_stats) tb_logger.tb_train_epoch_stats(current_step, epoch_stats)
if c.tb_model_param_stats: if c.tb_model_param_stats:
tb_logger.tb_model_weights(model, current_step) tb_logger.tb_model_weights(model, current_step)
return avg_linear_loss, current_step return avg_postnet_loss, current_step
def evaluate(model, criterion, criterion_st, ap, current_step, epoch): def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
data_loader = setup_loader(is_val=True) data_loader = setup_loader(is_val=True)
model.eval() model.eval()
epoch_time = 0 epoch_time = 0
avg_linear_loss = 0 avg_postnet_loss = 0
avg_mel_loss = 0 avg_decoder_loss = 0
avg_stop_loss = 0 avg_stop_loss = 0
print("\n > Validation") print("\n > Validation")
test_sentences = [ test_sentences = [
@ -269,7 +259,7 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
# setup input data # setup input data
text_input = data[0] text_input = data[0]
text_lengths = data[1] text_lengths = data[1]
linear_input = data[2] linear_input = data[2] if c.model == "Tacotron" else None
mel_input = data[3] mel_input = data[3]
mel_lengths = data[4] mel_lengths = data[4]
stop_targets = data[5] stop_targets = data[5]
@ -278,56 +268,56 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
stop_targets = stop_targets.view(text_input.shape[0], stop_targets = stop_targets.view(text_input.shape[0],
stop_targets.size(1) // c.r, stop_targets.size(1) // c.r,
-1) -1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float() stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
# dispatch data to GPU # dispatch data to GPU
if use_cuda: if use_cuda:
text_input = text_input.cuda() text_input = text_input.cuda()
mel_input = mel_input.cuda() mel_input = mel_input.cuda()
mel_lengths = mel_lengths.cuda() mel_lengths = mel_lengths.cuda()
linear_input = linear_input.cuda() linear_input = linear_input.cuda() if c.model == "Tacotron" else None
stop_targets = stop_targets.cuda() stop_targets = stop_targets.cuda()
# forward pass # forward pass
mel_output, linear_output, alignments, stop_tokens =\ decoder_output, postnet_output, alignments, stop_tokens =\
model.forward(text_input, mel_input) model.forward(text_input, text_lengths, mel_input)
# loss computation # loss computation
stop_loss = criterion_st(stop_tokens, stop_targets) stop_loss = criterion_st(stop_tokens, stop_targets)
mel_loss = criterion(mel_output, mel_input, mel_lengths) decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \ if c.model == "Tacotron":
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq], postnet_loss = criterion(postnet_output, linear_input, mel_lengths)
linear_input[:, :, :n_priority_freq], else:
mel_lengths) postnet_loss = criterion(postnet_output, mel_input, mel_lengths)
loss = mel_loss + linear_loss + stop_loss loss = decoder_loss + postnet_loss + stop_loss
step_time = time.time() - start_time step_time = time.time() - start_time
epoch_time += step_time epoch_time += step_time
if num_iter % c.print_step == 0: if num_iter % c.print_step == 0:
print( print(
" | > TotalLoss: {:.5f} LinearLoss: {:.5f} MelLoss:{:.5f} " " | > TotalLoss: {:.5f} PostnetLoss: {:.5f} DecoderLoss:{:.5f} "
"StopLoss: {:.5f} ".format(loss.item(), "StopLoss: {:.5f} ".format(loss.item(),
linear_loss.item(), postnet_loss.item(),
mel_loss.item(), decoder_loss.item(),
stop_loss.item()), stop_loss.item()),
flush=True) flush=True)
# aggregate losses from processes # aggregate losses from processes
if num_gpus > 1: if num_gpus > 1:
linear_loss = reduce_tensor(linear_loss.data, num_gpus) postnet_loss = reduce_tensor(postnet_loss.data, num_gpus)
mel_loss = reduce_tensor(mel_loss.data, num_gpus) decoder_loss = reduce_tensor(decoder_loss.data, num_gpus)
stop_loss = reduce_tensor(stop_loss.data, num_gpus) stop_loss = reduce_tensor(stop_loss.data, num_gpus)
avg_linear_loss += float(linear_loss.item()) avg_postnet_loss += float(postnet_loss.item())
avg_mel_loss += float(mel_loss.item()) avg_decoder_loss += float(decoder_loss.item())
avg_stop_loss += stop_loss.item() avg_stop_loss += stop_loss.item()
if args.rank == 0: if args.rank == 0:
# Diagnostic visualizations # Diagnostic visualizations
idx = np.random.randint(mel_input.shape[0]) idx = np.random.randint(mel_input.shape[0])
const_spec = linear_output[idx].data.cpu().numpy() const_spec = postnet_output[idx].data.cpu().numpy()
gt_spec = linear_input[idx].data.cpu().numpy() gt_spec = mel_input[idx].data.cpu().numpy()
align_img = alignments[idx].data.cpu().numpy() align_img = alignments[idx].data.cpu().numpy()
eval_figures = { eval_figures = {
@ -338,21 +328,21 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
tb_logger.tb_eval_figures(current_step, eval_figures) tb_logger.tb_eval_figures(current_step, eval_figures)
# Sample audio # Sample audio
tb_logger.tb_eval_audios( if c.model == "Tacotron":
current_step, {"ValAudio": ap.inv_spectrogram(const_spec.T)}, eval_audio = ap.inv_spectrogram(const_spec.T)
c.audio["sample_rate"]) else:
eval_audio = ap.inv_mel_spectrogram(const_spec.T)
tb_logger.tb_eval_audios(current_step, {"ValAudio": eval_audio}, c.audio["sample_rate"])
# compute average losses # compute average losses
avg_linear_loss /= (num_iter + 1) avg_postnet_loss /= (num_iter + 1)
avg_mel_loss /= (num_iter + 1) avg_decoder_loss /= (num_iter + 1)
avg_stop_loss /= (num_iter + 1) avg_stop_loss /= (num_iter + 1)
# Plot Validation Stats # Plot Validation Stats
epoch_stats = { epoch_stats = {"loss_postnet": avg_postnet_loss,
"loss_postnet": avg_linear_loss, "loss_decoder": avg_decoder_loss,
"loss_decoder": avg_mel_loss, "stop_loss": avg_stop_loss}
"stop_loss": avg_stop_loss
}
tb_logger.tb_eval_stats(current_step, epoch_stats) tb_logger.tb_eval_stats(current_step, epoch_stats)
if args.rank == 0 and epoch > c.test_delay_epochs: if args.rank == 0 and epoch > c.test_delay_epochs:
@ -362,7 +352,7 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
print(" | > Synthesizing test sentences") print(" | > Synthesizing test sentences")
for idx, test_sentence in enumerate(test_sentences): for idx, test_sentence in enumerate(test_sentences):
try: try:
wav, alignment, linear_spec, _, stop_tokens = synthesis( wav, alignment, decoder_output, postnet_output, stop_tokens = synthesis(
model, test_sentence, c, use_cuda, ap) model, test_sentence, c, use_cuda, ap)
file_path = os.path.join(AUDIO_PATH, str(current_step)) file_path = os.path.join(AUDIO_PATH, str(current_step))
os.makedirs(file_path, exist_ok=True) os.makedirs(file_path, exist_ok=True)
@ -370,16 +360,14 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
"TestSentence_{}.wav".format(idx)) "TestSentence_{}.wav".format(idx))
ap.save_wav(wav, file_path) ap.save_wav(wav, file_path)
test_audios['{}-audio'.format(idx)] = wav test_audios['{}-audio'.format(idx)] = wav
test_figures['{}-prediction'.format(idx)] = plot_spectrogram( test_figures['{}-prediction'.format(idx)] = plot_spectrogram(postnet_output, ap)
linear_spec, ap) test_figures['{}-alignment'.format(idx)] = plot_alignment(alignment)
test_figures['{}-alignment'.format(idx)] = plot_alignment(
alignment)
except: except:
print(" !! Error creating Test Sentence -", idx) print(" !! Error creating Test Sentence -", idx)
traceback.print_exc() traceback.print_exc()
tb_logger.tb_test_audios(current_step, test_audios, c.audio['sample_rate']) tb_logger.tb_test_audios(current_step, test_audios, c.audio['sample_rate'])
tb_logger.tb_test_figures(current_step, test_figures) tb_logger.tb_test_figures(current_step, test_figures)
return avg_linear_loss return avg_postnet_loss
def main(args): def main(args):
@ -388,25 +376,24 @@ def main(args):
init_distributed(args.rank, num_gpus, args.group_id, init_distributed(args.rank, num_gpus, args.group_id,
c.distributed["backend"], c.distributed["url"]) c.distributed["backend"], c.distributed["url"])
num_chars = len(phonemes) if c.use_phonemes else len(symbols) num_chars = len(phonemes) if c.use_phonemes else len(symbols)
model = Tacotron( model = MyModel(num_chars=num_chars, r=1)
num_chars=num_chars,
linear_dim=ap.num_freq, print(" | > Num output units : {}".format(ap.num_freq), flush=True)
mel_dim=ap.num_mels,
r=c.r,
memory_size=c.memory_size)
optimizer = optim.Adam(model.parameters(), lr=c.lr, weight_decay=0) optimizer = optim.Adam(model.parameters(), lr=c.lr, weight_decay=0)
optimizer_st = optim.Adam( optimizer_st = optim.Adam(
model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0) model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0)
criterion = L1LossMasked() criterion = L1LossMasked() if c.model == "Tacotron" else MSELossMasked()
criterion_st = nn.BCELoss() criterion_st = nn.BCEWithLogitsLoss()
if args.restore_path: if args.restore_path:
checkpoint = torch.load(args.restore_path) checkpoint = torch.load(args.restore_path)
try: try:
# TODO: fix optimizer init, model.cuda() needs to be called before
# optimizer restore
# optimizer.load_state_dict(checkpoint['optimizer'])
model.load_state_dict(checkpoint['model']) model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
except: except:
print(" > Partial model initialization.") print(" > Partial model initialization.")
partial_init_flag = True partial_init_flag = True
@ -438,7 +425,7 @@ def main(args):
print( print(
" > Model restored from step %d" % checkpoint['step'], flush=True) " > Model restored from step %d" % checkpoint['step'], flush=True)
start_epoch = checkpoint['epoch'] start_epoch = checkpoint['epoch']
best_loss = checkpoint['linear_loss'] best_loss = checkpoint['postnet_loss']
args.restore_step = checkpoint['step'] args.restore_step = checkpoint['step']
else: else:
args.restore_step = 0 args.restore_step = 0
@ -496,7 +483,7 @@ if __name__ == '__main__':
parser.add_argument( parser.add_argument(
'--debug', '--debug',
type=bool, type=bool,
default=False, default=True,
help='Do not verify commit integrity to run training.') help='Do not verify commit integrity to run training.')
parser.add_argument( parser.add_argument(
'--data_path', '--data_path',
@ -534,7 +521,7 @@ if __name__ == '__main__':
OUT_PATH = args.output_path OUT_PATH = args.output_path
if args.group_id == '': if args.group_id == '':
OUT_PATH = create_experiment_folder(OUT_PATH, c.model_name, args.debug) OUT_PATH = create_experiment_folder(OUT_PATH, c.run_name, args.debug)
AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios') AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
@ -552,6 +539,9 @@ if __name__ == '__main__':
# Conditional imports # Conditional imports
preprocessor = importlib.import_module('datasets.preprocess') preprocessor = importlib.import_module('datasets.preprocess')
preprocessor = getattr(preprocessor, c.dataset.lower()) preprocessor = getattr(preprocessor, c.dataset.lower())
print(" > Using model: {}".format(c.model))
MyModel = importlib.import_module('models.'+c.model.lower())
MyModel = getattr(MyModel, c.model)
# Audio processor # Audio processor
ap = AudioProcessor(**c.audio) ap = AudioProcessor(**c.audio)

View File

@ -48,10 +48,10 @@ def get_commit_hash():
def create_experiment_folder(root_path, model_name, debug): def create_experiment_folder(root_path, model_name, debug):
""" Create a folder with the current date and time """ """ Create a folder with the current date and time """
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p") date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
if debug: # if debug:
commit_hash = 'debug' # commit_hash = 'debug'
else: # else:
commit_hash = get_commit_hash() commit_hash = get_commit_hash()
output_folder = os.path.join( output_folder = os.path.join(
root_path, model_name + '-' + date_str + '-' + commit_hash) root_path, model_name + '-' + date_str + '-' + commit_hash)
os.makedirs(output_folder, exist_ok=True) os.makedirs(output_folder, exist_ok=True)