mirror of https://github.com/coqui-ai/TTS.git
Syncronize DDP processes
parent
ecf5f17dca
commit
7c0d564965
|
@ -84,7 +84,10 @@ class TrainingArgs(Coqpit):
|
|||
config_path: str = field(default="", metadata={"help": "Path to the configuration file."})
|
||||
rank: int = field(default=0, metadata={"help": "Process rank in distributed training."})
|
||||
group_id: str = field(default="", metadata={"help": "Process group id in distributed training."})
|
||||
use_ddp: bool= field(default=False, metadata={"help": "Use DDP in distributed training. It is to set in `distribute.py`. Do not set manually."})
|
||||
use_ddp: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use DDP in distributed training. It is to set in `distribute.py`. Do not set manually."},
|
||||
)
|
||||
|
||||
|
||||
class Trainer:
|
||||
|
@ -362,7 +365,9 @@ class Trainer:
|
|||
) -> DataLoader:
|
||||
if num_gpus > 1:
|
||||
if hasattr(model.module, "get_data_loader"):
|
||||
loader = model.module.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus, self.args.rank)
|
||||
loader = model.module.get_data_loader(
|
||||
config, ap, is_eval, data_items, verbose, num_gpus, self.args.rank
|
||||
)
|
||||
else:
|
||||
if hasattr(model, "get_data_loader"):
|
||||
loader = model.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus)
|
||||
|
@ -797,6 +802,7 @@ class Trainer:
|
|||
loader_time = time.time() - loader_start_time
|
||||
self.keep_avg_eval.update_values({"avg_loader_time": loader_time})
|
||||
outputs, _ = self.eval_step(batch, cur_step)
|
||||
loader_start_time = time.time()
|
||||
# plot epoch stats, artifacts and figures
|
||||
if self.args.rank == 0:
|
||||
figures, audios = None, None
|
||||
|
@ -839,7 +845,7 @@ class Trainer:
|
|||
self.total_steps_done = self.restore_step
|
||||
|
||||
for epoch in range(0, self.config.epochs):
|
||||
if self.num_gpus:
|
||||
if self.num_gpus > 1:
|
||||
# let all processes sync up before starting with a new epoch of training
|
||||
dist.barrier()
|
||||
self.callbacks.on_epoch_start()
|
||||
|
@ -868,6 +874,9 @@ class Trainer:
|
|||
self.callbacks.on_keyboard_interrupt()
|
||||
# if the output folder is empty remove the run.
|
||||
remove_experiment_folder(self.output_path)
|
||||
# clear the DDP processes
|
||||
if self.num_gpus > 1:
|
||||
dist.destroy_process_group()
|
||||
# finish the wandb run and sync data
|
||||
self.dashboard_logger.finish()
|
||||
# stop without error signal
|
||||
|
|
|
@ -2,6 +2,7 @@ import os
|
|||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
@ -164,7 +165,14 @@ class BaseTTS(BaseModel):
|
|||
}
|
||||
|
||||
def get_data_loader(
|
||||
self, config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool, num_gpus: int, rank: int=None
|
||||
self,
|
||||
config: Coqpit,
|
||||
ap: AudioProcessor,
|
||||
is_eval: bool,
|
||||
data_items: List,
|
||||
verbose: bool,
|
||||
num_gpus: int,
|
||||
rank: int = None,
|
||||
) -> "DataLoader":
|
||||
if is_eval and not config.run_eval:
|
||||
loader = None
|
||||
|
@ -228,6 +236,10 @@ class BaseTTS(BaseModel):
|
|||
else:
|
||||
self.train_data_items = dataset.items
|
||||
|
||||
# halt DDP processes for the main process to finish computing the phoneme cache
|
||||
if num_gpus > 1:
|
||||
dist.barrier()
|
||||
|
||||
dataset.sort_items()
|
||||
|
||||
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import importlib
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -9,7 +9,8 @@ from TTS.utils.training import NoamLR
|
|||
def is_apex_available():
|
||||
return importlib.util.find_spec("apex") is not None
|
||||
|
||||
def setup_torch_training_env(cudnn_enable:bool, cudnn_benchmark:bool, use_ddp:bool=False) -> Tuple[bool, int]:
|
||||
|
||||
def setup_torch_training_env(cudnn_enable: bool, cudnn_benchmark: bool, use_ddp: bool = False) -> Tuple[bool, int]:
|
||||
"""Setup PyTorch environment for training.
|
||||
|
||||
Args:
|
||||
|
|
Loading…
Reference in New Issue