mirror of https://github.com/coqui-ai/TTS.git
editing the pr #310 and merging
parent
574de86b9b
commit
8af75cad46
|
@ -11,3 +11,4 @@ scipy==0.19.0
|
||||||
tqdm
|
tqdm
|
||||||
git+git://github.com/bootphon/phonemizer@master
|
git+git://github.com/bootphon/phonemizer@master
|
||||||
soundfile
|
soundfile
|
||||||
|
bokeh==1.4.0
|
||||||
|
|
1
setup.py
1
setup.py
|
@ -87,6 +87,7 @@ setup(
|
||||||
"flask",
|
"flask",
|
||||||
# "lws",
|
# "lws",
|
||||||
"tqdm",
|
"tqdm",
|
||||||
|
"bokeh==1.4.0",
|
||||||
"soundfile",
|
"soundfile",
|
||||||
"phonemizer @ https://github.com/bootphon/phonemizer/tarball/master",
|
"phonemizer @ https://github.com/bootphon/phonemizer/tarball/master",
|
||||||
],
|
],
|
||||||
|
|
|
@ -11,17 +11,26 @@ from TTS.utils.audio import AudioProcessor
|
||||||
from TTS.utils.generic_utils import load_config
|
from TTS.utils.generic_utils import load_config
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Compute embedding vectors for each wav file in a dataset. "
|
description='Compute embedding vectors for each wav file in a dataset. ')
|
||||||
|
parser.add_argument(
|
||||||
|
'model_path',
|
||||||
|
type=str,
|
||||||
|
help='Path to model outputs (checkpoint, tensorboard etc.).')
|
||||||
|
parser.add_argument(
|
||||||
|
'config_path',
|
||||||
|
type=str,
|
||||||
|
help='Path to config file for training.',
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'data_path',
|
'data_path',
|
||||||
type=str,
|
type=str,
|
||||||
help='Data path for wav files - directory or CSV file')
|
help='Data path for wav files - directory or CSV file')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"config_path", type=str, help="Path to config file for training.",
|
'output_path',
|
||||||
)
|
type=str,
|
||||||
|
help='path for training outputs.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"data_path", type=str, help="Defines the data path. It overwrites config.json."
|
'--use_cuda', type=bool, help='flag to set cuda.', default=False
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--separator', type=str, help='Separator used in file if CSV is passed for data_path', default='|'
|
'--separator', type=str, help='Separator used in file if CSV is passed for data_path', default='|'
|
||||||
|
@ -30,7 +39,7 @@ args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
c = load_config(args.config_path)
|
c = load_config(args.config_path)
|
||||||
ap = AudioProcessor(**c["audio"])
|
ap = AudioProcessor(**c['audio'])
|
||||||
|
|
||||||
data_path = args.data_path
|
data_path = args.data_path
|
||||||
split_ext = os.path.splitext(data_path)
|
split_ext = os.path.splitext(data_path)
|
||||||
|
@ -65,7 +74,7 @@ for output_file in output_files:
|
||||||
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||||
|
|
||||||
model = SpeakerEncoder(**c.model)
|
model = SpeakerEncoder(**c.model)
|
||||||
model.load_state_dict(torch.load(args.model_path)["model"])
|
model.load_state_dict(torch.load(args.model_path)['model'])
|
||||||
model.eval()
|
model.eval()
|
||||||
if args.use_cuda:
|
if args.use_cuda:
|
||||||
model.cuda()
|
model.cuda()
|
||||||
|
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue