update train_vocoder_gan.py for coqpit

pull/476/head
Eren Gölge 2021-05-07 03:39:49 +02:00
parent 6ee6a563bc
commit e6f45b9eb7
15 changed files with 7 additions and 6 deletions

View File

@ -16,7 +16,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from TTS.utils.arguments import parse_arguments, process_args
from TTS.utils.arguments import init_training
from TTS.utils.audio import AudioProcessor
from TTS.utils.distribute import init_distributed
from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict
@ -163,7 +163,6 @@ def train(
y_hat_sub=y_hat_sub,
y_sub=y_G_sub,
)
loss_G = loss_G_dict["G_loss"]
# optimizer generator
@ -469,7 +468,7 @@ def main(args): # pylint: disable=redefined-outer-name
eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)
# setup audio processor
ap = AudioProcessor(**c.audio)
ap = AudioProcessor(**c.audio.to_dict())
# DISTRUBUTED
if num_gpus > 1:
@ -620,13 +619,15 @@ def main(args): # pylint: disable=redefined-outer-name
if __name__ == "__main__":
args = parse_arguments(sys.argv)
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="vocoder")
args, c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv)
try:
main(args)
except KeyboardInterrupt:
remove_experiment_folder(OUT_PATH)
try:
sys.exit(0)
except SystemExit:
os._exit(0) # pylint: disable=protected-access
except Exception: # pylint: disable=broad-except
remove_experiment_folder(OUT_PATH)
traceback.print_exc()