diff --git a/TTS/utils/arguments.py b/TTS/utils/arguments.py index 9dc1c596..5fd4cf3f 100644 --- a/TTS/utils/arguments.py +++ b/TTS/utils/arguments.py @@ -18,16 +18,11 @@ from TTS.utils.tensorboard_logger import TensorboardLogger def parse_arguments(argv): """Parse command line arguments of training scripts. - Parameters - ---------- - argv : list - This is a list of input arguments as given by sys.argv - - Returns - ------- - argparse.Namespace - Parsed arguments. + Args: + argv (list): This is a list of input arguments as given by sys.argv + Returns: + argparse.Namespace: Parsed arguments. """ parser = argparse.ArgumentParser() parser.add_argument( @@ -45,7 +40,8 @@ def parse_arguments(argv): parser.add_argument( "--best_path", type=str, - help="Best model file to be used for extracting best loss.", + help=("Best model file to be used for extracting best loss." + "If not specified, the latest best model in continue path is used"), default="") parser.add_argument( "--config_path", @@ -77,21 +73,14 @@ def get_last_models(path): It is based on globbing for `*.pth.tar` and the RegEx `(checkpoint|best_model)_([0-9]+)`. - Parameters - ---------- - path : list - Path to files to be compared. + Args: + path (list): Path to files to be compared. - Raises - ------ - ValueError - If no checkpoint or best_model files are found. - - Returns - ------- - last_checkpoint : str - Last checkpoint filename. + Raises: + ValueError: If no checkpoint or best_model files are found. + Returns: + last_checkpoint (str): Last checkpoint filename. """ file_names = glob.glob(os.path.join(path, "*.pth.tar")) last_models = {} @@ -131,8 +120,8 @@ def process_args(args, model_type): Args: args (argparse.Namespace or dict like): Parsed input arguments. - model_type (str): Model type used to check config parameters and setup the TensorBoard - logger. One of: + model_type (str): Model type used to check config parameters and setup + the TensorBoard logger. One of: - tacotron - glow_tts - speedy_speech @@ -141,15 +130,16 @@ def process_args(args, model_type): - wavernn Raises: - ValueError - If `model_type` is not one of implemented choices. + ValueError: If `model_type` is not one of implemented choices. Returns: c (TTS.utils.io.AttrDict): Config paramaters. out_path (str): Path to save models and logging. audio_path (str): Path to save generated test audios. - c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does logging to the console. - tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does the TensorBoard loggind. + c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does + logging to the console. + tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does + the TensorBoard loggind. """ if args.continue_path: args.output_path = args.continue_path