mirror of https://github.com/coqui-ai/TTS.git
generic train.py for multiple architectures set on config.json
parent
a4474abd83
commit
08162157ee
232
train.py
232
train.py
|
@ -14,20 +14,19 @@ from torch import optim
|
|||
from torch.utils.data import DataLoader
|
||||
|
||||
from datasets.TTSDataset import MyDataset
|
||||
from layers.losses import L1LossMasked
|
||||
from models.tacotron import Tacotron
|
||||
from distribute import (DistributedSampler, apply_gradient_allreduce,
|
||||
init_distributed, reduce_tensor)
|
||||
from layers.losses import L1LossMasked, MSELossMasked
|
||||
from utils.audio import AudioProcessor
|
||||
from utils.generic_utils import (
|
||||
NoamLR, check_update, count_parameters, create_experiment_folder,
|
||||
get_commit_hash, load_config, lr_decay, remove_experiment_folder,
|
||||
save_best_model, save_checkpoint, sequence_mask, weight_decay)
|
||||
from utils.generic_utils import (NoamLR, check_update, count_parameters,
|
||||
create_experiment_folder, get_commit_hash,
|
||||
load_config, lr_decay,
|
||||
remove_experiment_folder, save_best_model,
|
||||
save_checkpoint, sequence_mask, weight_decay)
|
||||
from utils.logger import Logger
|
||||
from utils.synthesis import synthesis
|
||||
from utils.text.symbols import phonemes, symbols
|
||||
from utils.visual import plot_alignment, plot_spectrogram
|
||||
from distribute import init_distributed, apply_gradient_allreduce, reduce_tensor
|
||||
from distribute import DistributedSampler
|
||||
|
||||
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
@ -77,24 +76,19 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
|||
data_loader = setup_loader(is_val=False, verbose=(epoch==0))
|
||||
model.train()
|
||||
epoch_time = 0
|
||||
avg_linear_loss = 0
|
||||
avg_mel_loss = 0
|
||||
avg_postnet_loss = 0
|
||||
avg_decoder_loss = 0
|
||||
avg_stop_loss = 0
|
||||
avg_step_time = 0
|
||||
print("\n > Epoch {}/{}".format(epoch, c.epochs), flush=True)
|
||||
n_priority_freq = int(
|
||||
3000 / (c.audio['sample_rate'] * 0.5) * c.audio['num_freq'])
|
||||
if num_gpus > 0:
|
||||
batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
|
||||
else:
|
||||
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
|
||||
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
|
||||
for num_iter, data in enumerate(data_loader):
|
||||
start_time = time.time()
|
||||
|
||||
# setup input data
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
linear_input = data[2]
|
||||
linear_input = data[2] if c.model == "Tacotron" else None
|
||||
mel_input = data[3]
|
||||
mel_lengths = data[4]
|
||||
stop_targets = data[5]
|
||||
|
@ -104,7 +98,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
|||
# set stop targets view, we predict a single stop token per r frames prediction
|
||||
stop_targets = stop_targets.view(text_input.shape[0],
|
||||
stop_targets.size(1) // c.r, -1)
|
||||
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()
|
||||
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
|
||||
|
||||
current_step = num_iter + args.restore_step + \
|
||||
epoch * len(data_loader) + 1
|
||||
|
@ -121,24 +115,21 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
|||
text_lengths = text_lengths.cuda(non_blocking=True)
|
||||
mel_input = mel_input.cuda(non_blocking=True)
|
||||
mel_lengths = mel_lengths.cuda(non_blocking=True)
|
||||
linear_input = linear_input.cuda(non_blocking=True)
|
||||
linear_input = linear_input.cuda(non_blocking=True) if c.model == "Tacotron" else None
|
||||
stop_targets = stop_targets.cuda(non_blocking=True)
|
||||
|
||||
# compute mask for padding
|
||||
mask = sequence_mask(text_lengths)
|
||||
|
||||
# forward pass
|
||||
mel_output, linear_output, alignments, stop_tokens = model(
|
||||
text_input, mel_input, mask)
|
||||
# forward pass model
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model(
|
||||
text_input, text_lengths, mel_input)
|
||||
|
||||
# loss computation
|
||||
stop_loss = criterion_st(stop_tokens, stop_targets)
|
||||
mel_loss = criterion(mel_output, mel_input, mel_lengths)
|
||||
linear_loss = (1 - c.loss_weight) * criterion(linear_output, linear_input, mel_lengths)\
|
||||
+ c.loss_weight * criterion(linear_output[:, :, :n_priority_freq],
|
||||
linear_input[:, :, :n_priority_freq],
|
||||
mel_lengths)
|
||||
loss = mel_loss + linear_loss
|
||||
decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
|
||||
if c.model == "Tacotron":
|
||||
postnet_loss = criterion(postnet_output, linear_input, mel_lengths)
|
||||
else:
|
||||
postnet_loss = criterion(postnet_output, mel_input, mel_lengths)
|
||||
loss = decoder_loss + postnet_loss
|
||||
|
||||
# backpass and check the grad norm for spec losses
|
||||
loss.backward(retain_graph=True)
|
||||
|
@ -157,49 +148,46 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
|||
|
||||
if current_step % c.print_step == 0:
|
||||
print(
|
||||
" | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "
|
||||
"MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
|
||||
"GradNormST:{:.5f} AvgTextLen:{:.1f} AvgSpecLen:{:.1f} StepTime:{:.2f} LR:{:.6f}"
|
||||
.format(num_iter, batch_n_iter, current_step, loss.item(),
|
||||
linear_loss.item(), mel_loss.item(), stop_loss.item(),
|
||||
grad_norm, grad_norm_st, avg_text_length,
|
||||
avg_spec_length, step_time, current_lr),
|
||||
" | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} PostnetLoss:{:.5f} "
|
||||
"DecoderLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
|
||||
"GradNormST:{:.5f} AvgTextLen:{:.1f} AvgSpecLen:{:.1f} StepTime:{:.2f} LR:{:.6f}".format(
|
||||
num_iter, batch_n_iter, current_step, loss.item(),
|
||||
postnet_loss.item(), decoder_loss.item(), stop_loss.item(),
|
||||
grad_norm, grad_norm_st, avg_text_length, avg_spec_length, step_time, current_lr),
|
||||
flush=True)
|
||||
|
||||
# aggregate losses from processes
|
||||
if num_gpus > 1:
|
||||
linear_loss = reduce_tensor(linear_loss.data, num_gpus)
|
||||
mel_loss = reduce_tensor(mel_loss.data, num_gpus)
|
||||
postnet_loss = reduce_tensor(postnet_loss.data, num_gpus)
|
||||
decoder_loss = reduce_tensor(decoder_loss.data, num_gpus)
|
||||
loss = reduce_tensor(loss.data, num_gpus)
|
||||
stop_loss = reduce_tensor(stop_loss.data, num_gpus)
|
||||
|
||||
if args.rank == 0:
|
||||
avg_linear_loss += float(linear_loss.item())
|
||||
avg_mel_loss += float(mel_loss.item())
|
||||
avg_postnet_loss += float(postnet_loss.item())
|
||||
avg_decoder_loss += float(decoder_loss.item())
|
||||
avg_stop_loss += stop_loss.item()
|
||||
avg_step_time += step_time
|
||||
|
||||
# Plot Training Iter Stats
|
||||
iter_stats = {
|
||||
"loss_posnet": linear_loss.item(),
|
||||
"loss_decoder": mel_loss.item(),
|
||||
"lr": current_lr,
|
||||
"grad_norm": grad_norm,
|
||||
"grad_norm_st": grad_norm_st,
|
||||
"step_time": step_time
|
||||
}
|
||||
iter_stats = {"loss_posnet": postnet_loss.item(),
|
||||
"loss_decoder": decoder_loss.item(),
|
||||
"lr": current_lr,
|
||||
"grad_norm": grad_norm,
|
||||
"grad_norm_st": grad_norm_st,
|
||||
"step_time": step_time}
|
||||
tb_logger.tb_train_iter_stats(current_step, iter_stats)
|
||||
|
||||
if current_step % c.save_step == 0:
|
||||
if c.checkpoint:
|
||||
# save model
|
||||
save_checkpoint(model, optimizer, optimizer_st,
|
||||
linear_loss.item(), OUT_PATH, current_step,
|
||||
postnet_loss.item(), OUT_PATH, current_step,
|
||||
epoch)
|
||||
|
||||
# Diagnostic visualizations
|
||||
const_spec = linear_output[0].data.cpu().numpy()
|
||||
gt_spec = linear_input[0].data.cpu().numpy()
|
||||
const_spec = postnet_output[0].data.cpu().numpy()
|
||||
gt_spec = mel_input[0].data.cpu().numpy()
|
||||
align_img = alignments[0].data.cpu().numpy()
|
||||
|
||||
figures = {
|
||||
|
@ -210,47 +198,49 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
|||
tb_logger.tb_train_figures(current_step, figures)
|
||||
|
||||
# Sample audio
|
||||
tb_logger.tb_train_audios(
|
||||
current_step, {'TrainAudio': ap.inv_spectrogram(const_spec.T)},
|
||||
c.audio["sample_rate"])
|
||||
if c.model == "Tacotron":
|
||||
train_audio = ap.inv_spectrogram(const_spec.T)
|
||||
else:
|
||||
train_audio = ap.inv_mel_spectrogram(const_spec.T)
|
||||
tb_logger.tb_train_audios(current_step,
|
||||
{'TrainAudio': train_audio},
|
||||
c.audio["sample_rate"])
|
||||
|
||||
avg_linear_loss /= (num_iter + 1)
|
||||
avg_mel_loss /= (num_iter + 1)
|
||||
avg_postnet_loss /= (num_iter + 1)
|
||||
avg_decoder_loss /= (num_iter + 1)
|
||||
avg_stop_loss /= (num_iter + 1)
|
||||
avg_total_loss = avg_mel_loss + avg_linear_loss + avg_stop_loss
|
||||
avg_total_loss = avg_decoder_loss + avg_postnet_loss + avg_stop_loss
|
||||
avg_step_time /= (num_iter + 1)
|
||||
|
||||
# print epoch stats
|
||||
print(
|
||||
" | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
|
||||
"AvgLinearLoss:{:.5f} AvgMelLoss:{:.5f} "
|
||||
" | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
|
||||
"AvgPostnetLoss:{:.5f} AvgDecoderLoss:{:.5f} "
|
||||
"AvgStopLoss:{:.5f} EpochTime:{:.2f} "
|
||||
"AvgStepTime:{:.2f}".format(current_step, avg_total_loss,
|
||||
avg_linear_loss, avg_mel_loss,
|
||||
avg_postnet_loss, avg_decoder_loss,
|
||||
avg_stop_loss, epoch_time, avg_step_time),
|
||||
flush=True)
|
||||
|
||||
# Plot Epoch Stats
|
||||
if args.rank == 0:
|
||||
# Plot Training Epoch Stats
|
||||
epoch_stats = {
|
||||
"loss_postnet": avg_linear_loss,
|
||||
"loss_decoder": avg_mel_loss,
|
||||
"stop_loss": avg_stop_loss,
|
||||
"epoch_time": epoch_time
|
||||
}
|
||||
epoch_stats = {"loss_postnet": avg_postnet_loss,
|
||||
"loss_decoder": avg_decoder_loss,
|
||||
"stop_loss": avg_stop_loss,
|
||||
"epoch_time": epoch_time}
|
||||
tb_logger.tb_train_epoch_stats(current_step, epoch_stats)
|
||||
if c.tb_model_param_stats:
|
||||
tb_logger.tb_model_weights(model, current_step)
|
||||
return avg_linear_loss, current_step
|
||||
tb_logger.tb_model_weights(model, current_step)
|
||||
return avg_postnet_loss, current_step
|
||||
|
||||
|
||||
def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
|
||||
data_loader = setup_loader(is_val=True)
|
||||
model.eval()
|
||||
epoch_time = 0
|
||||
avg_linear_loss = 0
|
||||
avg_mel_loss = 0
|
||||
avg_postnet_loss = 0
|
||||
avg_decoder_loss = 0
|
||||
avg_stop_loss = 0
|
||||
print("\n > Validation")
|
||||
test_sentences = [
|
||||
|
@ -269,7 +259,7 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
|
|||
# setup input data
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
linear_input = data[2]
|
||||
linear_input = data[2] if c.model == "Tacotron" else None
|
||||
mel_input = data[3]
|
||||
mel_lengths = data[4]
|
||||
stop_targets = data[5]
|
||||
|
@ -278,56 +268,56 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
|
|||
stop_targets = stop_targets.view(text_input.shape[0],
|
||||
stop_targets.size(1) // c.r,
|
||||
-1)
|
||||
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float()
|
||||
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
|
||||
|
||||
# dispatch data to GPU
|
||||
if use_cuda:
|
||||
text_input = text_input.cuda()
|
||||
mel_input = mel_input.cuda()
|
||||
mel_lengths = mel_lengths.cuda()
|
||||
linear_input = linear_input.cuda()
|
||||
linear_input = linear_input.cuda() if c.model == "Tacotron" else None
|
||||
stop_targets = stop_targets.cuda()
|
||||
|
||||
# forward pass
|
||||
mel_output, linear_output, alignments, stop_tokens =\
|
||||
model.forward(text_input, mel_input)
|
||||
decoder_output, postnet_output, alignments, stop_tokens =\
|
||||
model.forward(text_input, text_lengths, mel_input)
|
||||
|
||||
# loss computation
|
||||
stop_loss = criterion_st(stop_tokens, stop_targets)
|
||||
mel_loss = criterion(mel_output, mel_input, mel_lengths)
|
||||
linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \
|
||||
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
|
||||
linear_input[:, :, :n_priority_freq],
|
||||
mel_lengths)
|
||||
loss = mel_loss + linear_loss + stop_loss
|
||||
decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
|
||||
if c.model == "Tacotron":
|
||||
postnet_loss = criterion(postnet_output, linear_input, mel_lengths)
|
||||
else:
|
||||
postnet_loss = criterion(postnet_output, mel_input, mel_lengths)
|
||||
loss = decoder_loss + postnet_loss + stop_loss
|
||||
|
||||
step_time = time.time() - start_time
|
||||
epoch_time += step_time
|
||||
|
||||
if num_iter % c.print_step == 0:
|
||||
print(
|
||||
" | > TotalLoss: {:.5f} LinearLoss: {:.5f} MelLoss:{:.5f} "
|
||||
" | > TotalLoss: {:.5f} PostnetLoss: {:.5f} DecoderLoss:{:.5f} "
|
||||
"StopLoss: {:.5f} ".format(loss.item(),
|
||||
linear_loss.item(),
|
||||
mel_loss.item(),
|
||||
postnet_loss.item(),
|
||||
decoder_loss.item(),
|
||||
stop_loss.item()),
|
||||
flush=True)
|
||||
|
||||
# aggregate losses from processes
|
||||
if num_gpus > 1:
|
||||
linear_loss = reduce_tensor(linear_loss.data, num_gpus)
|
||||
mel_loss = reduce_tensor(mel_loss.data, num_gpus)
|
||||
postnet_loss = reduce_tensor(postnet_loss.data, num_gpus)
|
||||
decoder_loss = reduce_tensor(decoder_loss.data, num_gpus)
|
||||
stop_loss = reduce_tensor(stop_loss.data, num_gpus)
|
||||
|
||||
avg_linear_loss += float(linear_loss.item())
|
||||
avg_mel_loss += float(mel_loss.item())
|
||||
avg_postnet_loss += float(postnet_loss.item())
|
||||
avg_decoder_loss += float(decoder_loss.item())
|
||||
avg_stop_loss += stop_loss.item()
|
||||
|
||||
if args.rank == 0:
|
||||
# Diagnostic visualizations
|
||||
idx = np.random.randint(mel_input.shape[0])
|
||||
const_spec = linear_output[idx].data.cpu().numpy()
|
||||
gt_spec = linear_input[idx].data.cpu().numpy()
|
||||
const_spec = postnet_output[idx].data.cpu().numpy()
|
||||
gt_spec = mel_input[idx].data.cpu().numpy()
|
||||
align_img = alignments[idx].data.cpu().numpy()
|
||||
|
||||
eval_figures = {
|
||||
|
@ -338,21 +328,21 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
|
|||
tb_logger.tb_eval_figures(current_step, eval_figures)
|
||||
|
||||
# Sample audio
|
||||
tb_logger.tb_eval_audios(
|
||||
current_step, {"ValAudio": ap.inv_spectrogram(const_spec.T)},
|
||||
c.audio["sample_rate"])
|
||||
if c.model == "Tacotron":
|
||||
eval_audio = ap.inv_spectrogram(const_spec.T)
|
||||
else:
|
||||
eval_audio = ap.inv_mel_spectrogram(const_spec.T)
|
||||
tb_logger.tb_eval_audios(current_step, {"ValAudio": eval_audio}, c.audio["sample_rate"])
|
||||
|
||||
# compute average losses
|
||||
avg_linear_loss /= (num_iter + 1)
|
||||
avg_mel_loss /= (num_iter + 1)
|
||||
avg_postnet_loss /= (num_iter + 1)
|
||||
avg_decoder_loss /= (num_iter + 1)
|
||||
avg_stop_loss /= (num_iter + 1)
|
||||
|
||||
# Plot Validation Stats
|
||||
epoch_stats = {
|
||||
"loss_postnet": avg_linear_loss,
|
||||
"loss_decoder": avg_mel_loss,
|
||||
"stop_loss": avg_stop_loss
|
||||
}
|
||||
epoch_stats = {"loss_postnet": avg_postnet_loss,
|
||||
"loss_decoder": avg_decoder_loss,
|
||||
"stop_loss": avg_stop_loss}
|
||||
tb_logger.tb_eval_stats(current_step, epoch_stats)
|
||||
|
||||
if args.rank == 0 and epoch > c.test_delay_epochs:
|
||||
|
@ -362,7 +352,7 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
|
|||
print(" | > Synthesizing test sentences")
|
||||
for idx, test_sentence in enumerate(test_sentences):
|
||||
try:
|
||||
wav, alignment, linear_spec, _, stop_tokens = synthesis(
|
||||
wav, alignment, decoder_output, postnet_output, stop_tokens = synthesis(
|
||||
model, test_sentence, c, use_cuda, ap)
|
||||
file_path = os.path.join(AUDIO_PATH, str(current_step))
|
||||
os.makedirs(file_path, exist_ok=True)
|
||||
|
@ -370,16 +360,14 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
|
|||
"TestSentence_{}.wav".format(idx))
|
||||
ap.save_wav(wav, file_path)
|
||||
test_audios['{}-audio'.format(idx)] = wav
|
||||
test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
|
||||
linear_spec, ap)
|
||||
test_figures['{}-alignment'.format(idx)] = plot_alignment(
|
||||
alignment)
|
||||
test_figures['{}-prediction'.format(idx)] = plot_spectrogram(postnet_output, ap)
|
||||
test_figures['{}-alignment'.format(idx)] = plot_alignment(alignment)
|
||||
except:
|
||||
print(" !! Error creating Test Sentence -", idx)
|
||||
traceback.print_exc()
|
||||
tb_logger.tb_test_audios(current_step, test_audios, c.audio['sample_rate'])
|
||||
tb_logger.tb_test_figures(current_step, test_figures)
|
||||
return avg_linear_loss
|
||||
return avg_postnet_loss
|
||||
|
||||
|
||||
def main(args):
|
||||
|
@ -388,25 +376,24 @@ def main(args):
|
|||
init_distributed(args.rank, num_gpus, args.group_id,
|
||||
c.distributed["backend"], c.distributed["url"])
|
||||
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
|
||||
model = Tacotron(
|
||||
num_chars=num_chars,
|
||||
linear_dim=ap.num_freq,
|
||||
mel_dim=ap.num_mels,
|
||||
r=c.r,
|
||||
memory_size=c.memory_size)
|
||||
model = MyModel(num_chars=num_chars, r=1)
|
||||
|
||||
print(" | > Num output units : {}".format(ap.num_freq), flush=True)
|
||||
|
||||
optimizer = optim.Adam(model.parameters(), lr=c.lr, weight_decay=0)
|
||||
optimizer_st = optim.Adam(
|
||||
model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0)
|
||||
|
||||
criterion = L1LossMasked()
|
||||
criterion_st = nn.BCELoss()
|
||||
criterion = L1LossMasked() if c.model == "Tacotron" else MSELossMasked()
|
||||
criterion_st = nn.BCEWithLogitsLoss()
|
||||
|
||||
if args.restore_path:
|
||||
checkpoint = torch.load(args.restore_path)
|
||||
try:
|
||||
# TODO: fix optimizer init, model.cuda() needs to be called before
|
||||
# optimizer restore
|
||||
# optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
model.load_state_dict(checkpoint['model'])
|
||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
except:
|
||||
print(" > Partial model initialization.")
|
||||
partial_init_flag = True
|
||||
|
@ -438,7 +425,7 @@ def main(args):
|
|||
print(
|
||||
" > Model restored from step %d" % checkpoint['step'], flush=True)
|
||||
start_epoch = checkpoint['epoch']
|
||||
best_loss = checkpoint['linear_loss']
|
||||
best_loss = checkpoint['postnet_loss']
|
||||
args.restore_step = checkpoint['step']
|
||||
else:
|
||||
args.restore_step = 0
|
||||
|
@ -496,7 +483,7 @@ if __name__ == '__main__':
|
|||
parser.add_argument(
|
||||
'--debug',
|
||||
type=bool,
|
||||
default=False,
|
||||
default=True,
|
||||
help='Do not verify commit integrity to run training.')
|
||||
parser.add_argument(
|
||||
'--data_path',
|
||||
|
@ -534,7 +521,7 @@ if __name__ == '__main__':
|
|||
OUT_PATH = args.output_path
|
||||
|
||||
if args.group_id == '':
|
||||
OUT_PATH = create_experiment_folder(OUT_PATH, c.model_name, args.debug)
|
||||
OUT_PATH = create_experiment_folder(OUT_PATH, c.run_name, args.debug)
|
||||
|
||||
AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
|
||||
|
||||
|
@ -552,6 +539,9 @@ if __name__ == '__main__':
|
|||
# Conditional imports
|
||||
preprocessor = importlib.import_module('datasets.preprocess')
|
||||
preprocessor = getattr(preprocessor, c.dataset.lower())
|
||||
print(" > Using model: {}".format(c.model))
|
||||
MyModel = importlib.import_module('models.'+c.model.lower())
|
||||
MyModel = getattr(MyModel, c.model)
|
||||
|
||||
# Audio processor
|
||||
ap = AudioProcessor(**c.audio)
|
||||
|
|
|
@ -48,10 +48,10 @@ def get_commit_hash():
|
|||
def create_experiment_folder(root_path, model_name, debug):
|
||||
""" Create a folder with the current date and time """
|
||||
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
|
||||
if debug:
|
||||
commit_hash = 'debug'
|
||||
else:
|
||||
commit_hash = get_commit_hash()
|
||||
# if debug:
|
||||
# commit_hash = 'debug'
|
||||
# else:
|
||||
commit_hash = get_commit_hash()
|
||||
output_folder = os.path.join(
|
||||
root_path, model_name + '-' + date_str + '-' + commit_hash)
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
|
|
Loading…
Reference in New Issue