mirror of https://github.com/coqui-ai/TTS.git
update train_vocoder_gan.py for coqpit
parent
6ee6a563bc
commit
e6f45b9eb7
|
@ -16,7 +16,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP_th
|
|||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from TTS.utils.arguments import parse_arguments, process_args
|
||||
from TTS.utils.arguments import init_training
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.distribute import init_distributed
|
||||
from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict
|
||||
|
@ -163,7 +163,6 @@ def train(
|
|||
y_hat_sub=y_hat_sub,
|
||||
y_sub=y_G_sub,
|
||||
)
|
||||
|
||||
loss_G = loss_G_dict["G_loss"]
|
||||
|
||||
# optimizer generator
|
||||
|
@ -469,7 +468,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)
|
||||
|
||||
# setup audio processor
|
||||
ap = AudioProcessor(**c.audio)
|
||||
ap = AudioProcessor(**c.audio.to_dict())
|
||||
|
||||
# DISTRUBUTED
|
||||
if num_gpus > 1:
|
||||
|
@ -620,13 +619,15 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_arguments(sys.argv)
|
||||
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="vocoder")
|
||||
|
||||
args, c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv)
|
||||
try:
|
||||
main(args)
|
||||
except KeyboardInterrupt:
|
||||
remove_experiment_folder(OUT_PATH)
|
||||
try:
|
||||
sys.exit(0)
|
||||
except SystemExit:
|
||||
os._exit(0) # pylint: disable=protected-access
|
||||
except Exception: # pylint: disable=broad-except
|
||||
remove_experiment_folder(OUT_PATH)
|
||||
traceback.print_exc()
|
||||
|
|
Loading…
Reference in New Issue