diff --git a/TTS/encoder/utils/visual.py b/TTS/encoder/utils/visual.py index f2db2f3f..6575b86e 100644 --- a/TTS/encoder/utils/visual.py +++ b/TTS/encoder/utils/visual.py @@ -23,7 +23,7 @@ colormap = ( [0, 0, 0], [183, 183, 183], ], - dtype=np.float, + dtype=float, ) / 255 ) diff --git a/TTS/tts/configs/bark_config.py b/TTS/tts/configs/bark_config.py index 647116bd..4d1cd137 100644 --- a/TTS/tts/configs/bark_config.py +++ b/TTS/tts/configs/bark_config.py @@ -1,5 +1,5 @@ import os -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Dict from TTS.tts.configs.shared_configs import BaseTTSConfig @@ -46,11 +46,11 @@ class BarkConfig(BaseTTSConfig): """ model: str = "bark" - audio: BarkAudioConfig = BarkAudioConfig() + audio: BarkAudioConfig = field(default_factory=BarkAudioConfig) num_chars: int = 0 - semantic_config: GPTConfig = GPTConfig() - fine_config: FineGPTConfig = FineGPTConfig() - coarse_config: GPTConfig = GPTConfig() + semantic_config: GPTConfig = field(default_factory=GPTConfig) + fine_config: FineGPTConfig = field(default_factory=FineGPTConfig) + coarse_config: GPTConfig = field(default_factory=GPTConfig) CONTEXT_WINDOW_SIZE: int = 1024 SEMANTIC_RATE_HZ: float = 49.9 SEMANTIC_VOCAB_SIZE: int = 10_000 diff --git a/TTS/tts/layers/bark/hubert/kmeans_hubert.py b/TTS/tts/layers/bark/hubert/kmeans_hubert.py index 7c667755..ee544ee1 100644 --- a/TTS/tts/layers/bark/hubert/kmeans_hubert.py +++ b/TTS/tts/layers/bark/hubert/kmeans_hubert.py @@ -10,11 +10,11 @@ License: MIT import logging from pathlib import Path -import fairseq import torch from einops import pack, unpack from torch import nn from torchaudio.functional import resample +from transformers import HubertModel logging.root.setLevel(logging.ERROR) @@ -49,22 +49,11 @@ class CustomHubert(nn.Module): self.target_sample_hz = target_sample_hz self.seq_len_multiple_of = seq_len_multiple_of self.output_layer = output_layer - if device is not None: self.to(device) - - model_path = Path(checkpoint_path) - - assert model_path.exists(), f"path {checkpoint_path} does not exist" - - checkpoint = torch.load(checkpoint_path) - load_model_input = {checkpoint_path: checkpoint} - model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input) - + self.model = HubertModel.from_pretrained("facebook/hubert-base-ls960") if device is not None: - model[0].to(device) - - self.model = model[0] + self.model.to(device) self.model.eval() @property @@ -81,19 +70,13 @@ class CustomHubert(nn.Module): if exists(self.seq_len_multiple_of): wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of) - embed = self.model( + outputs = self.model.forward( wav_input, - features_only=True, - mask=False, # thanks to @maitycyrus for noticing that mask is defaulted to True in the fairseq code - output_layer=self.output_layer, + output_hidden_states=True, ) - - embed, packed_shape = pack([embed["x"]], "* d") - - # codebook_indices = self.kmeans.predict(embed.cpu().detach().numpy()) - - codebook_indices = torch.from_numpy(embed.cpu().detach().numpy()).to(device) # .long() - + embed = outputs["hidden_states"][self.output_layer] + embed, packed_shape = pack([embed], "* d") + codebook_indices = torch.from_numpy(embed.cpu().detach().numpy()).to(device) if flatten: return codebook_indices diff --git a/TTS/tts/layers/bark/inference_funcs.py b/TTS/tts/layers/bark/inference_funcs.py index fa7a1ebf..da962ab1 100644 --- a/TTS/tts/layers/bark/inference_funcs.py +++ b/TTS/tts/layers/bark/inference_funcs.py @@ -130,7 +130,7 @@ def generate_voice( # generate semantic tokens # Load the HuBERT model hubert_manager = HubertManager() - hubert_manager.make_sure_hubert_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert"]) + # hubert_manager.make_sure_hubert_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert"]) hubert_manager.make_sure_tokenizer_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert_tokenizer"]) hubert_model = CustomHubert(checkpoint_path=model.config.LOCAL_MODEL_PATHS["hubert"]).to(model.device) diff --git a/TTS/tts/utils/helpers.py b/TTS/tts/utils/helpers.py index 56ef2944..c6d1ec2c 100644 --- a/TTS/tts/utils/helpers.py +++ b/TTS/tts/utils/helpers.py @@ -207,7 +207,7 @@ def maximum_path_numpy(value, mask, max_neg_val=None): device = value.device dtype = value.dtype value = value.cpu().detach().numpy() - mask = mask.cpu().detach().numpy().astype(np.bool) + mask = mask.cpu().detach().numpy().astype(bool) b, t_x, t_y = value.shape direction = np.zeros(value.shape, dtype=np.int64)