mirror of https://github.com/coqui-ai/TTS.git
Prevent weighted sampler use when num_gpus > 1
parent
74cedfac38
commit
2bbcb558dc
|
@ -102,7 +102,7 @@ class BaseTTS(BaseModel):
|
||||||
config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512
|
config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512
|
||||||
)
|
)
|
||||||
# init speaker embedding layer
|
# init speaker embedding layer
|
||||||
if config.use_speaker_embedding and not config.use_d_vector_file::
|
if config.use_speaker_embedding and not config.use_d_vector_file:
|
||||||
print(" > Init speaker_embedding layer.")
|
print(" > Init speaker_embedding layer.")
|
||||||
self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
|
self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
|
||||||
self.speaker_embedding.weight.data.normal_(0, 0.3)
|
self.speaker_embedding.weight.data.normal_(0, 0.3)
|
||||||
|
@ -337,8 +337,15 @@ class BaseTTS(BaseModel):
|
||||||
if config.compute_f0:
|
if config.compute_f0:
|
||||||
dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None))
|
dataset.pitch_extractor.load_pitch_stats(config.get("f0_cache_path", None))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# sampler for DDP
|
# sampler for DDP
|
||||||
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||||
|
|
||||||
|
# Weighted samplers
|
||||||
|
assert not (num_gpus > 1 and getattr(config, "use_language_weighted_sampler", False)), "language_weighted_sampler is not supported with DistributedSampler"
|
||||||
|
assert not (num_gpus > 1 and getattr(config, "use_speaker_weighted_sampler", False)), "speaker_weighted_sampler is not supported with DistributedSampler"
|
||||||
|
|
||||||
if sampler is None:
|
if sampler is None:
|
||||||
if getattr(config, "use_language_weighted_sampler", False):
|
if getattr(config, "use_language_weighted_sampler", False):
|
||||||
print(" > Using Language weighted sampler")
|
print(" > Using Language weighted sampler")
|
||||||
|
|
Loading…
Reference in New Issue