mirror of https://github.com/coqui-ai/TTS.git
Add max_audio_length parameter
parent
ceb8b05abe
commit
1a60767d83
|
@ -8,7 +8,7 @@ from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrai
|
|||
from TTS.utils.manage import ModelManager
|
||||
|
||||
|
||||
def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path):
|
||||
def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path, max_audio_length=255995):
|
||||
# Logging parameters
|
||||
RUN_NAME = "GPT_XTTS_FT"
|
||||
PROJECT_NAME = "XTTS_trainer"
|
||||
|
@ -79,7 +79,7 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
|
|||
max_conditioning_length=132300, # 6 secs
|
||||
min_conditioning_length=66150, # 3 secs
|
||||
debug_loading_failures=False,
|
||||
max_wav_length=255995, # ~11.6 seconds
|
||||
max_wav_length=max_audio_length, # in audio samples; the default 255995 is ~11.6 seconds at 22050 Hz
|
||||
max_text_length=200,
|
||||
mel_norm_file=MEL_NORM_FILE,
|
||||
dvae_checkpoint=DVAE_CHECKPOINT,
|
||||
|
|
|
@ -147,6 +147,13 @@ if __name__ == "__main__":
|
|||
help="Grad accumulation steps. Default: 1",
|
||||
default=1,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_audio_length",
|
||||
type=int,
|
||||
help="Max permitted audio size in seconds. Default: 11",
|
||||
default=11,
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
with gr.Blocks() as demo:
|
||||
|
@ -250,6 +257,13 @@ if __name__ == "__main__":
|
|||
step=1,
|
||||
value=args.grad_acumm,
|
||||
)
|
||||
max_audio_length = gr.Slider(
|
||||
label="Max permitted audio size in seconds:",
|
||||
minimum=2,
|
||||
maximum=20,
|
||||
step=1,
|
||||
value=args.max_audio_length,
|
||||
)
|
||||
progress_train = gr.Label(
|
||||
label="Progress:"
|
||||
)
|
||||
|
@ -260,12 +274,14 @@ if __name__ == "__main__":
|
|||
demo.load(read_logs, None, logs_tts_train, every=1)
|
||||
train_btn = gr.Button(value="Step 2 - Run the training")
|
||||
|
||||
def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path):
|
||||
def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length):
|
||||
clear_gpu_cache()
|
||||
if not train_csv or not eval_csv:
|
||||
return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", ""
|
||||
try:
|
||||
config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path)
|
||||
# convert seconds to a number of waveform samples (22050 samples per second)
|
||||
max_audio_length = int(max_audio_length * 22050)
|
||||
config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path, max_audio_length=max_audio_length)
|
||||
except:
|
||||
traceback.print_exc()
|
||||
error = traceback.format_exc()
|
||||
|
@ -280,7 +296,6 @@ if __name__ == "__main__":
|
|||
clear_gpu_cache()
|
||||
return "Model training done!", config_path, vocab_file, ft_xtts_checkpoint, speaker_wav
|
||||
|
||||
|
||||
with gr.Tab("3 - Inference"):
|
||||
with gr.Row():
|
||||
with gr.Column() as col1:
|
||||
|
@ -367,6 +382,7 @@ if __name__ == "__main__":
|
|||
batch_size,
|
||||
grad_acumm,
|
||||
out_path,
|
||||
max_audio_length,
|
||||
],
|
||||
outputs=[progress_train, xtts_config, xtts_vocab, xtts_checkpoint, speaker_reference_audio],
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue