"""Training entry point for a Tacotron-style TTS model (TensorFlow 1.x).

Builds the data feeder and model graph, then runs the optimization loop,
periodically writing TensorBoard summaries, checkpoints, and sample
audio/alignment artifacts to the log directory.
"""
import argparse
import math
import os
import subprocess
import time
import traceback
from datetime import datetime

import numpy as np
import tensorflow as tf

from datasets.datafeeder import DataFeeder
from hparams import hparams, hparams_debug_string
from models import create_model
from text import sequence_to_text
from util import audio, infolog, plot, ValueWindow

log = infolog.log


def get_git_commit():
  """Return the first 10 characters of the current git commit hash.

  Raises:
    subprocess.CalledProcessError: if the working tree has uncommitted
      changes (the `diff-index --quiet` check) or git is unavailable, so
      every training run is traceable to a clean commit.
  """
  subprocess.check_output(['git', 'diff-index', '--quiet', 'HEAD'])  # Verify client is clean
  commit = subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode().strip()[:10]
  log('Git commit: %s' % commit)
  return commit


def add_stats(model):
  """Attach TensorBoard summary ops for the model's outputs, targets,
  losses, learning rate, and per-variable gradient norms.

  Args:
    model: an initialized model exposing linear/mel outputs and targets,
      the individual and total losses, learning_rate, and gradients.

  Returns:
    A merged summary op covering everything registered on the graph.
  """
  with tf.variable_scope('stats'):
    tf.summary.histogram('linear_outputs', model.linear_outputs)
    tf.summary.histogram('linear_targets', model.linear_targets)
    tf.summary.histogram('mel_outputs', model.mel_outputs)
    tf.summary.histogram('mel_targets', model.mel_targets)
    tf.summary.scalar('loss_mel', model.mel_loss)
    tf.summary.scalar('loss_linear', model.linear_loss)
    tf.summary.scalar('learning_rate', model.learning_rate)
    tf.summary.scalar('loss', model.loss)
    gradient_norms = [tf.norm(grad) for grad in model.gradients]
    tf.summary.histogram('gradient_norm', gradient_norms)
    tf.summary.scalar('max_gradient_norm', tf.reduce_max(gradient_norms))
    return tf.summary.merge_all()


def time_string():
  """Return the current local time as 'YYYY-MM-DD HH:MM' for log/plot labels."""
  return datetime.now().strftime('%Y-%m-%d %H:%M')


def train(log_dir, args):
  """Build the graph and run the training loop until stopped or an error occurs.

  Args:
    log_dir: directory receiving checkpoints, summaries, audio samples,
      and alignment plots.
    args: parsed command-line namespace (see main() for the flags).
  """
  commit = get_git_commit() if args.git else 'None'
  checkpoint_path = os.path.join(log_dir, 'model.ckpt')
  input_path = os.path.join(args.base_dir, args.input)
  log('Checkpoint path: %s' % checkpoint_path)
  log('Loading training data from: %s' % input_path)
  log('Using model: %s' % args.model)
  log(hparams_debug_string())

  # Set up DataFeeder:
  coord = tf.train.Coordinator()
  with tf.variable_scope('datafeeder'):
    feeder = DataFeeder(coord, input_path, hparams)

  # Set up model:
  global_step = tf.Variable(0, name='global_step', trainable=False)
  with tf.variable_scope('model'):
    model = create_model(args.model, hparams)
    model.initialize(feeder.inputs, feeder.input_lengths, feeder.mel_targets, feeder.linear_targets)
    model.add_loss()
    model.add_optimizer(global_step)
    stats = add_stats(model)

  # Bookkeeping:
  step = 0
  time_window = ValueWindow(100)
  loss_window = ValueWindow(100)
  saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=2)

  # Train!
  gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_fraction)
  with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    try:
      summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
      sess.run(tf.global_variables_initializer())

      # NOTE: compare against None (not truthiness) so an explicit
      # --restore_step 0 is honored rather than silently starting fresh.
      if args.restore_step is not None:
        # Restore from a checkpoint if the user requested it.
        restore_path = '%s-%d' % (checkpoint_path, args.restore_step)
        saver.restore(sess, restore_path)
        log('Resuming from checkpoint: %s at commit: %s' % (restore_path, commit), slack=True)
      else:
        log('Starting new training run at commit: %s' % commit, slack=True)

      feeder.start_in_session(sess)

      while not coord.should_stop():
        start_time = time.time()
        step, loss, opt = sess.run([global_step, model.loss, model.optimize])
        time_window.append(time.time() - start_time)
        loss_window.append(loss)
        message = 'Step %-7d [%.03f sec/step, loss=%.05f, avg_loss=%.05f]' % (
          step, time_window.average, loss, loss_window.average)
        log(message, slack=(step % args.checkpoint_interval == 0))

        # Fail fast on divergence so a bad run doesn't burn GPU hours.
        if loss > 100 or math.isnan(loss):
          log('Loss exploded to %.05f at step %d!' % (loss, step), slack=True)
          raise Exception('Loss Exploded')

        if step % args.summary_interval == 0:
          log('Writing summary at step: %d' % step)
          summary_writer.add_summary(sess.run(stats), step)

        if step % args.checkpoint_interval == 0:
          log('Saving checkpoint to: %s-%d' % (checkpoint_path, step))
          saver.save(sess, checkpoint_path, global_step=step)
          log('Saving audio and alignment...')
          # Dump the first example in the batch: its input text, synthesized
          # waveform, and attention alignment plot.
          input_seq, spectrogram, alignment = sess.run([
            model.inputs[0], model.linear_outputs[0], model.alignments[0]])
          waveform = audio.inv_spectrogram(spectrogram.T)
          audio.save_wav(waveform, os.path.join(log_dir, 'step-%d-audio.wav' % step))
          plot.plot_alignment(alignment, os.path.join(log_dir, 'step-%d-align.png' % step),
            info='%s, %s, %s, step=%d, loss=%.5f' % (args.model, commit, time_string(), step, loss))
          log('Input: %s' % sequence_to_text(input_seq))

    except Exception as e:
      # Report the failure (including to Slack), then ask the feeder's
      # coordinator to stop its background threads.
      log('Exiting due to exception: %s' % e, slack=True)
      traceback.print_exc()
      coord.request_stop(e)


def main():
  """Parse command-line flags, set up logging and GPU environment, and train."""
  parser = argparse.ArgumentParser()
  parser.add_argument('--base_dir', default=os.path.expanduser('~/tacotron'))
  parser.add_argument('--input', default='training/train.txt')
  parser.add_argument('--model', default='tacotron')
  parser.add_argument('--name', help='Name of the run. Used for logging. Defaults to model name.')
  parser.add_argument('--hparams', default='',
    help='Hyperparameter overrides as a comma-separated list of name=value pairs')
  parser.add_argument('--restore_step', type=int, help='Global step to restore from checkpoint.')
  parser.add_argument('--summary_interval', type=int, default=100,
    help='Steps between running summary ops.')
  parser.add_argument('--checkpoint_interval', type=int, default=1000,
    help='Steps between writing checkpoints.')
  parser.add_argument('--slack_url', help='Slack webhook URL to get periodic reports.')
  parser.add_argument('--tf_log_level', type=int, default=1, help='Tensorflow C++ log level.')
  parser.add_argument('--git', action='store_true', help='If set, verify that the client is clean.')
  parser.add_argument('--gpu_assignment', default='0', help='Set the gpu the model should run on')
  # Default must be a float to match type=float (was the string '1.0').
  parser.add_argument('--gpu_fraction', type=float, default=1.0,
    help='Set the fraction of gpu memory to allocate. 0 - 1')
  args = parser.parse_args()

  # Constrain TF to the chosen GPU and log level before the graph is built.
  os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_assignment
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(args.tf_log_level)

  run_name = args.name or args.model
  log_dir = os.path.join(args.base_dir, 'logs-%s' % run_name)
  os.makedirs(log_dir, exist_ok=True)
  infolog.init(os.path.join(log_dir, 'train.log'), run_name, args.slack_url)
  hparams.parse(args.hparams)
  train(log_dir, args)


if __name__ == '__main__':
  main()