Drop fairseq for Hubert

pull/2700/head
Eren G??lge 2023-06-26 19:27:48 +02:00
parent c03768bb53
commit 17ac188958
5 changed files with 16 additions and 33 deletions

View File

@ -23,7 +23,7 @@ colormap = (
[0, 0, 0],
[183, 183, 183],
],
dtype=np.float,
dtype=float,
)
/ 255
)

View File

@ -1,5 +1,5 @@
import os
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Dict
from TTS.tts.configs.shared_configs import BaseTTSConfig
@ -46,11 +46,11 @@ class BarkConfig(BaseTTSConfig):
"""
model: str = "bark"
audio: BarkAudioConfig = BarkAudioConfig()
audio: BarkAudioConfig = field(default_factory=BarkAudioConfig)
num_chars: int = 0
semantic_config: GPTConfig = GPTConfig()
fine_config: FineGPTConfig = FineGPTConfig()
coarse_config: GPTConfig = GPTConfig()
semantic_config: GPTConfig = field(default_factory=GPTConfig)
fine_config: FineGPTConfig = field(default_factory=FineGPTConfig)
coarse_config: GPTConfig = field(default_factory=GPTConfig)
CONTEXT_WINDOW_SIZE: int = 1024
SEMANTIC_RATE_HZ: float = 49.9
SEMANTIC_VOCAB_SIZE: int = 10_000

View File

@ -10,11 +10,11 @@ License: MIT
import logging
from pathlib import Path
import fairseq
import torch
from einops import pack, unpack
from torch import nn
from torchaudio.functional import resample
from transformers import HubertModel
logging.root.setLevel(logging.ERROR)
@ -49,22 +49,11 @@ class CustomHubert(nn.Module):
self.target_sample_hz = target_sample_hz
self.seq_len_multiple_of = seq_len_multiple_of
self.output_layer = output_layer
if device is not None:
self.to(device)
model_path = Path(checkpoint_path)
assert model_path.exists(), f"path {checkpoint_path} does not exist"
checkpoint = torch.load(checkpoint_path)
load_model_input = {checkpoint_path: checkpoint}
model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input)
self.model = HubertModel.from_pretrained("facebook/hubert-base-ls960")
if device is not None:
model[0].to(device)
self.model = model[0]
self.model.to(device)
self.model.eval()
@property
@ -81,19 +70,13 @@ class CustomHubert(nn.Module):
if exists(self.seq_len_multiple_of):
wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of)
embed = self.model(
outputs = self.model.forward(
wav_input,
features_only=True,
mask=False, # thanks to @maitycyrus for noticing that mask is defaulted to True in the fairseq code
output_layer=self.output_layer,
output_hidden_states=True,
)
embed, packed_shape = pack([embed["x"]], "* d")
# codebook_indices = self.kmeans.predict(embed.cpu().detach().numpy())
codebook_indices = torch.from_numpy(embed.cpu().detach().numpy()).to(device) # .long()
embed = outputs["hidden_states"][self.output_layer]
embed, packed_shape = pack([embed], "* d")
codebook_indices = torch.from_numpy(embed.cpu().detach().numpy()).to(device)
if flatten:
return codebook_indices

View File

@ -130,7 +130,7 @@ def generate_voice(
# generate semantic tokens
# Load the HuBERT model
hubert_manager = HubertManager()
hubert_manager.make_sure_hubert_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert"])
# hubert_manager.make_sure_hubert_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert"])
hubert_manager.make_sure_tokenizer_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert_tokenizer"])
hubert_model = CustomHubert(checkpoint_path=model.config.LOCAL_MODEL_PATHS["hubert"]).to(model.device)

View File

@ -207,7 +207,7 @@ def maximum_path_numpy(value, mask, max_neg_val=None):
device = value.device
dtype = value.dtype
value = value.cpu().detach().numpy()
mask = mask.cpu().detach().numpy().astype(np.bool)
mask = mask.cpu().detach().numpy().astype(bool)
b, t_x, t_y = value.shape
direction = np.zeros(value.shape, dtype=np.int64)