From 2bbcb558dc74950f9555d705b716328479e3e0ac Mon Sep 17 00:00:00 2001 From: WeberJulian Date: Sun, 28 Nov 2021 00:48:53 +0100 Subject: [PATCH] Prevent weighted sampler use when num_gpus > 1 --- TTS/tts/models/base_tts.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index f1fdbd33..1f92bfc7 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -102,7 +102,7 @@ class BaseTTS(BaseModel): config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512 ) # init speaker embedding layer - if config.use_speaker_embedding and not config.use_d_vector_file:: + if config.use_speaker_embedding and not config.use_d_vector_file: print(" > Init speaker_embedding layer.") self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim) self.speaker_embedding.weight.data.normal_(0, 0.3) @@ -337,8 +337,15 @@ class BaseTTS(BaseModel): if config.compute_f0: dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None)) + + # sampler for DDP sampler = DistributedSampler(dataset) if num_gpus > 1 else None + + # Weighted samplers + assert not (num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)), "language_weighted_sampler is not supported with DistributedSampler" + assert not (num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False)), "speaker_weighted_sampler is not supported with DistributedSampler" + if sampler is None: if getattr(config, "use_language_weighted_sampler", False): print(" > Using Language weighted sampler")