mirror of https://github.com/coqui-ai/TTS.git
Synchronize DDP processes
parent ecf5f17dca
commit 7c0d564965
@@ -84,7 +84,10 @@ class TrainingArgs(Coqpit):
     config_path: str = field(default="", metadata={"help": "Path to the configuration file."})
     rank: int = field(default=0, metadata={"help": "Process rank in distributed training."})
     group_id: str = field(default="", metadata={"help": "Process group id in distributed training."})
-    use_ddp: bool= field(default=False, metadata={"help": "Use DDP in distributed training. It is to set in `distribute.py`. Do not set manually."})
+    use_ddp: bool = field(
+        default=False,
+        metadata={"help": "Use DDP in distributed training. It is to set in `distribute.py`. Do not set manually."},
+    )


 class Trainer:
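The `TrainingArgs` fields above all follow the plain-dataclass `field(default=..., metadata={"help": ...})` pattern that Coqpit builds its config handling on. A minimal self-contained sketch of the same pattern using only the standard library (the `DemoArgs` class is illustrative, not part of the repo):

from dataclasses import dataclass, field, fields

# Illustrative stand-in for TrainingArgs; not the actual repo class.
@dataclass
class DemoArgs:
    rank: int = field(default=0, metadata={"help": "Process rank in distributed training."})
    use_ddp: bool = field(default=False, metadata={"help": "Set by distribute.py; do not set manually."})

args = DemoArgs()
for f in fields(DemoArgs):
    # Each field carries its own help text, which a CLI layer can surface.
    print(f"--{f.name} (default={f.default}): {f.metadata['help']}")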
@@ -362,7 +365,9 @@ class Trainer:
     ) -> DataLoader:
         if num_gpus > 1:
             if hasattr(model.module, "get_data_loader"):
-                loader = model.module.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus, self.args.rank)
+                loader = model.module.get_data_loader(
+                    config, ap, is_eval, data_items, verbose, num_gpus, self.args.rank
+                )
         else:
             if hasattr(model, "get_data_loader"):
                 loader = model.get_data_loader(config, ap, is_eval, data_items, verbose, num_gpus)
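The `num_gpus > 1` branch exists because `torch.nn.parallel.DistributedDataParallel` wraps the model, so custom methods such as `get_data_loader` have to be reached through the wrapper's `.module` attribute. A runnable single-process illustration on the CPU `gloo` backend (`TinyModel` and the port number are illustrative):

import os
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)

    def custom_method(self):
        return "called on the underlying model"

# Single-process gloo group, just so a DDP wrapper can be constructed.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

wrapped = DDP(TinyModel())
# DDP forwards forward()/__call__, but custom attributes live on .module:
print(wrapped.module.custom_method())
dist.destroy_process_group()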
@@ -797,6 +802,7 @@ class Trainer:
             loader_time = time.time() - loader_start_time
             self.keep_avg_eval.update_values({"avg_loader_time": loader_time})
             outputs, _ = self.eval_step(batch, cur_step)
+            loader_start_time = time.time()
         # plot epoch stats, artifacts and figures
         if self.args.rank == 0:
             figures, audios = None, None
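The one-line addition above resets the clock after each eval step, so `avg_loader_time` measures only the wait on the data loader for the next batch rather than loader wait plus the step itself. The same pattern in isolation, as a hedged sketch rather than repo code:

import time

def timed_batches(loader):
    # Measure only the time spent waiting on the data loader,
    # restarting the clock after each step (mirrors the fix above).
    loader_start_time = time.time()
    for batch in loader:
        loader_time = time.time() - loader_start_time
        yield batch, loader_time
        loader_start_time = time.time()  # restart before fetching the next batch

for batch, loader_time in timed_batches(range(3)):
    print(f"batch={batch}, waited {loader_time:.6f}s on the loader")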
@@ -839,7 +845,7 @@ class Trainer:
         self.total_steps_done = self.restore_step

         for epoch in range(0, self.config.epochs):
-            if self.num_gpus:
+            if self.num_gpus > 1:
                 # let all processes sync up before starting with a new epoch of training
                 dist.barrier()
             self.callbacks.on_epoch_start()
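`dist.barrier()` blocks each rank until every rank in the group has reached the same call, so a fast worker cannot start a new epoch while a slow one is still finishing the last. A minimal runnable demonstration with two CPU processes on the `gloo` backend (the spawn helper and port are illustrative choices):

import os
import time
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    time.sleep(rank)   # simulate ranks arriving at different times
    dist.barrier()     # everyone blocks here until all ranks arrive
    print(f"rank {rank}: starting the epoch together")
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)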
@@ -868,6 +874,9 @@ class Trainer:
             self.callbacks.on_keyboard_interrupt()
             # if the output folder is empty remove the run.
             remove_experiment_folder(self.output_path)
+            # clear the DDP processes
+            if self.num_gpus > 1:
+                dist.destroy_process_group()
             # finish the wandb run and sync data
             self.dashboard_logger.finish()
             # stop without error signal
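`dist.destroy_process_group()` releases the rendezvous and communicator resources allocated by `init_process_group`; without it, an interrupted multi-GPU run can leave worker processes hanging. One way to frame the same guarantee (a sketch of the pattern, not the Trainer's actual control flow):

import torch.distributed as dist

def run_training(num_gpus: int, train_fn):
    # Always tear down the process group, even when training
    # ends via KeyboardInterrupt, as the commit above ensures.
    try:
        train_fn()
    except KeyboardInterrupt:
        pass  # cleanup hooks (callbacks, folder removal) would run here
    finally:
        if num_gpus > 1 and dist.is_initialized():
            dist.destroy_process_group()

The remaining hunks are on the model side, where `BaseTTS.get_data_loader` gains a `rank` argument and a barrier of its own.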
@@ -2,6 +2,7 @@ import os
 from typing import Dict, List, Tuple

 import torch
+import torch.distributed as dist
 from coqpit import Coqpit
 from torch import nn
 from torch.utils.data import DataLoader
@@ -164,7 +165,14 @@ class BaseTTS(BaseModel):
         }

     def get_data_loader(
-        self, config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool, num_gpus: int, rank: int=None
+        self,
+        config: Coqpit,
+        ap: AudioProcessor,
+        is_eval: bool,
+        data_items: List,
+        verbose: bool,
+        num_gpus: int,
+        rank: int = None,
     ) -> "DataLoader":
         if is_eval and not config.run_eval:
             loader = None
@@ -228,6 +236,10 @@ class BaseTTS(BaseModel):
             else:
                 self.train_data_items = dataset.items

+            # halt DDP processes for the main process to finish computing the phoneme cache
+            if num_gpus > 1:
+                dist.barrier()
+
             dataset.sort_items()

             sampler = DistributedSampler(dataset) if num_gpus > 1 else None
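The barrier above is the usual "rank 0 prepares shared state, the other ranks wait" idiom: only the main process computes the on-disk phoneme cache, and no other rank may read it before it is complete. Combined with `DistributedSampler`, which hands each rank a disjoint shard of the dataset, the pattern looks roughly like this (`prepare_cache` and the batch size are illustrative):

import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler

def build_loader(dataset, num_gpus: int, rank: int, prepare_cache):
    # Only the main process computes the shared on-disk cache...
    if rank == 0:
        prepare_cache(dataset)
    # ...and every other rank waits here until it is done.
    if num_gpus > 1:
        dist.barrier()
    # DistributedSampler gives each rank a disjoint shard of the items.
    sampler = DistributedSampler(dataset) if num_gpus > 1 else None
    return DataLoader(dataset, batch_size=8, sampler=sampler,
                      shuffle=(sampler is None))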
@@ -1,5 +1,5 @@
 import importlib
-from typing import Dict, List
+from typing import Dict, List, Tuple

 import torch

@@ -9,7 +9,8 @@ from TTS.utils.training import NoamLR
 def is_apex_available():
     return importlib.util.find_spec("apex") is not None

-def setup_torch_training_env(cudnn_enable:bool, cudnn_benchmark:bool, use_ddp:bool=False) -> Tuple[bool, int]:
+
+def setup_torch_training_env(cudnn_enable: bool, cudnn_benchmark: bool, use_ddp: bool = False) -> Tuple[bool, int]:
     """Setup PyTorch environment for training.

     Args:
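Only the signature of `setup_torch_training_env` appears in this diff; the `Tuple[bool, int]` return plausibly stands for `(use_cuda, num_gpus)`. The body below is therefore an assumption-laden sketch of what such a helper typically does, not the repo's implementation:

from typing import Tuple
import torch

def setup_torch_training_env_sketch(
    cudnn_enable: bool, cudnn_benchmark: bool, use_ddp: bool = False
) -> Tuple[bool, int]:
    # Assumption: the return value is (use_cuda, num_gpus).
    torch.backends.cudnn.enabled = cudnn_enable
    torch.backends.cudnn.benchmark = cudnn_benchmark
    use_cuda = torch.cuda.is_available()
    num_gpus = torch.cuda.device_count()
    if num_gpus > 1 and not use_ddp:
        raise RuntimeError("Multiple GPUs found; launch with distribute.py to use DDP.")
    return use_cuda, num_gpus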