diff --git a/server/server.py b/server/server.py index 6af119bf..705937e2 100644 --- a/server/server.py +++ b/server/server.py @@ -18,9 +18,9 @@ def create_argparser(): parser.add_argument('--wavernn_file', type=str, default=None, help='path to WaveRNN checkpoint file.') parser.add_argument('--wavernn_config', type=str, default=None, help='path to WaveRNN config file.') parser.add_argument('--is_wavernn_batched', type=convert_boolean, default=False, help='true to use batched WaveRNN.') - parser.add_argument('--pwgan_lib_path', type=str, help='path to ParallelWaveGAN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.') - parser.add_argument('--pwgan_file', type=str, help='path to ParallelWaveGAN checkpoint file.') - parser.add_argument('--pwgan_config', type=str, help='path to ParallelWaveGAN config file.') + parser.add_argument('--pwgan_lib_path', type=str, default=None, help='path to ParallelWaveGAN project folder to be imported. If this is not passed, model uses Griffin-Lim for synthesis.') + parser.add_argument('--pwgan_file', type=str, default=None, help='path to ParallelWaveGAN checkpoint file.') + parser.add_argument('--pwgan_config', type=str, default=None, help='path to ParallelWaveGAN config file.') parser.add_argument('--port', type=int, default=5002, help='port to listen on.') parser.add_argument('--use_cuda', type=convert_boolean, default=False, help='true to use CUDA.') parser.add_argument('--debug', type=convert_boolean, default=False, help='true to enable Flask debug mode.') @@ -29,28 +29,35 @@ def create_argparser(): synthesizer = None -embedded_model_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'model') -checkpoint_file = os.path.join(embedded_model_folder, 'checkpoint.pth.tar') -config_file = os.path.join(embedded_model_folder, 'config.json') +embedded_models_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'model') -# Default options with embedded model files -if os.path.isfile(checkpoint_file): - default_tts_checkpoint = checkpoint_file -else: - default_tts_checkpoint = None +embedded_tts_folder = os.path.join(embedded_models_folder, 'tts') +tts_checkpoint_file = os.path.join(embedded_tts_folder, 'checkpoint.pth.tar') +tts_config_file = os.path.join(embedded_tts_folder, 'config.json') -if os.path.isfile(config_file): - default_tts_config = config_file -else: - default_tts_config = None +embedded_wavernn_folder = os.path.join(embedded_models_folder, 'wavernn') +wavernn_checkpoint_file = os.path.join(embedded_wavernn_folder, 'checkpoint.pth.tar') +wavernn_config_file = os.path.join(embedded_wavernn_folder, 'config.json') + +embedded_pwgan_folder = os.path.join(embedded_models_folder, 'pwgan') +pwgan_checkpoint_file = os.path.join(embedded_pwgan_folder, 'checkpoint.pkl') +pwgan_config_file = os.path.join(embedded_pwgan_folder, 'config.yml') args = create_argparser().parse_args() -# If these were not specified in the CLI args, use default values -if not args.tts_checkpoint: - args.tts_checkpoint = default_tts_checkpoint -if not args.tts_config: - args.tts_config = default_tts_config +# If these were not specified in the CLI args, use default values with embedded model files +if not args.tts_checkpoint and os.path.isfile(tts_checkpoint_file): + args.tts_checkpoint = tts_checkpoint_file +if not args.tts_config and os.path.isfile(tts_config_file): + args.tts_config = tts_config_file +if not args.wavernn_file and os.path.isfile(wavernn_checkpoint_file): + args.wavernn_file = wavernn_checkpoint_file +if not args.wavernn_config and os.path.isfile(wavernn_config_file): + args.wavernn_config = wavernn_config_file +if not args.pwgan_file and os.path.isfile(pwgan_checkpoint_file): + args.pwgan_file = pwgan_checkpoint_file +if not args.pwgan_config and os.path.isfile(pwgan_config_file): + args.pwgan_config = pwgan_config_file synthesizer = Synthesizer(args)