editing the pr #310 and merging

pull/10/head
Eren Golge 2019-11-14 16:14:01 +01:00
parent 574de86b9b
commit 8af75cad46
4 changed files with 42 additions and 485 deletions

View File

@ -11,3 +11,4 @@ scipy==0.19.0
tqdm tqdm
git+git://github.com/bootphon/phonemizer@master git+git://github.com/bootphon/phonemizer@master
soundfile soundfile
bokeh==1.4.0

View File

@ -87,6 +87,7 @@ setup(
"flask", "flask",
# "lws", # "lws",
"tqdm", "tqdm",
"bokeh==1.4.0",
"soundfile", "soundfile",
"phonemizer @ https://github.com/bootphon/phonemizer/tarball/master", "phonemizer @ https://github.com/bootphon/phonemizer/tarball/master",
], ],

View File

@ -11,17 +11,26 @@ from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import load_config from TTS.utils.generic_utils import load_config
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Compute embedding vectors for each wav file in a dataset. " description='Compute embedding vectors for each wav file in a dataset. ')
parser.add_argument(
'model_path',
type=str,
help='Path to model outputs (checkpoint, tensorboard etc.).')
parser.add_argument(
'config_path',
type=str,
help='Path to config file for training.',
) )
parser.add_argument( parser.add_argument(
'data_path', 'data_path',
type=str, type=str,
help='Data path for wav files - directory or CSV file') help='Data path for wav files - directory or CSV file')
parser.add_argument( parser.add_argument(
"config_path", type=str, help="Path to config file for training.", 'output_path',
) type=str,
help='path for training outputs.')
parser.add_argument( parser.add_argument(
"data_path", type=str, help="Defines the data path. It overwrites config.json." '--use_cuda', type=bool, help='flag to set cuda.', default=False
) )
parser.add_argument( parser.add_argument(
'--separator', type=str, help='Separator used in file if CSV is passed for data_path', default='|' '--separator', type=str, help='Separator used in file if CSV is passed for data_path', default='|'
@ -30,7 +39,7 @@ args = parser.parse_args()
c = load_config(args.config_path) c = load_config(args.config_path)
ap = AudioProcessor(**c["audio"]) ap = AudioProcessor(**c['audio'])
data_path = args.data_path data_path = args.data_path
split_ext = os.path.splitext(data_path) split_ext = os.path.splitext(data_path)
@ -65,7 +74,7 @@ for output_file in output_files:
os.makedirs(os.path.dirname(output_file), exist_ok=True) os.makedirs(os.path.dirname(output_file), exist_ok=True)
model = SpeakerEncoder(**c.model) model = SpeakerEncoder(**c.model)
model.load_state_dict(torch.load(args.model_path)["model"]) model.load_state_dict(torch.load(args.model_path)['model'])
model.eval() model.eval()
if args.use_cuda: if args.use_cuda:
model.cuda() model.cuda()

File diff suppressed because one or more lines are too long