general purpose distribute.py

pull/10/head
erogol 2020-10-08 01:29:11 +02:00
parent 6f0654f9a8
commit c5074cfd8e
1 changed files with 69 additions and 0 deletions

69
TTS/bin/distribute.py Normal file
View File

@ -0,0 +1,69 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import sys
import pathlib
import time
import subprocess
import argparse
import torch
def main() -> None:
    """Launch one copy of a training script per visible CUDA device.

    Command-line options ``--script``, ``--continue_path``, ``--restore_path``
    and ``--config_path`` are parsed here. The script (resolved relative to
    the directory containing this file) is then started once for every device
    reported by ``torch.cuda.device_count()``. All workers receive the same
    path options and a shared time-stamped ``--group_id``; each one gets its
    own ``--rank``. Only rank 0 inherits this process's stdout, the other
    ranks have their stdout discarded. The call blocks until every worker has
    exited.

    Raises:
        SystemExit: argparse exits with status 2 when ``--script`` is missing
            or when neither ``--continue_path`` nor ``--config_path`` is given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--script',
        type=str,
        help='Target training script to distibute.',
        # Without a script there is nothing to launch; fail with a usage
        # message rather than a TypeError from os.path.join(…, None).
        required=True)
    parser.add_argument(
        '--continue_path',
        type=str,
        help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
        # None marks "not supplied"; it is mapped back to '' before being
        # forwarded so the worker sees the same value as before.
        default=None)
    parser.add_argument(
        '--restore_path',
        type=str,
        help='Model file to be restored. Use to finetune a model.',
        default='')
    parser.add_argument(
        '--config_path',
        type=str,
        help='Path to config file for training.')
    args = parser.parse_args()

    # At least one of the two options must be present. Checking the parsed
    # values (instead of searching sys.argv for the bare flag) also accepts
    # the "--option=value" spelling.
    if args.continue_path is None and args.config_path is None:
        parser.error(
            'one of the arguments --continue_path --config_path is required')
    continue_path = '' if args.continue_path is None else args.continue_path

    num_gpus = torch.cuda.device_count()
    group_id = time.strftime("%Y_%m_%d-%H%M%S")

    # Arguments shared by every worker; the rank is appended per process.
    folder_path = pathlib.Path(__file__).parent.absolute()
    base_command = [
        os.path.join(folder_path, args.script),
        '--continue_path={}'.format(continue_path),
        '--restore_path={}'.format(args.restore_path),
        '--config_path={}'.format(args.config_path),
        '--group_id=group_{}'.format(group_id),
    ]

    # Run the workers with the interpreter that runs this launcher, so they
    # see the same environment; fall back to 'python3' if it is unknown.
    interpreter = sys.executable or 'python3'

    processes = []
    for rank in range(num_gpus):
        worker_env = os.environ.copy()
        # Each rank gets its own egg cache directory.
        worker_env["PYTHON_EGG_CACHE"] = "/tmp/tmp{}".format(rank)
        worker_command = base_command + ['--rank={}'.format(rank)]
        # Only rank 0 writes to our stdout; DEVNULL avoids leaking an open
        # file object per additional worker.
        stdout = None if rank == 0 else subprocess.DEVNULL
        processes.append(
            subprocess.Popen([interpreter] + worker_command,
                             stdout=stdout,
                             env=worker_env))
        print(worker_command)

    for process in processes:
        process.wait()
# Script entry point: start the launcher only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()