pull/175/merge
Jairo Correa 2024-04-08 15:47:07 -06:00 committed by GitHub
commit b5416017f8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 46 additions and 1 deletions

View File

@ -1,2 +1,2 @@
from .api import generate_audio, text_to_semantic, semantic_to_waveform, save_as_prompt
from .api import generate_audio, text_to_semantic, semantic_to_waveform, save_as_prompt, set_seed
from .generation import SAMPLE_RATE, preload_models

View File

@ -1,6 +1,9 @@
from typing import Dict, Optional, Union
import numpy as np
import torch
import random
import os
from .generation import codec_decode, generate_coarse, generate_fine, generate_text_semantic
@ -83,6 +86,48 @@ def save_as_prompt(filepath, full_generation):
np.savez(filepath, **full_generation)
def set_seed(seed: int = 0):
"""Set the seed
seed = 0 Generate a random seed
seed = -1 Disable deterministic algorithms
0 < seed < 2**32 Set the seed
Args:
seed: integer to use as seed
Returns:
integer used as seed
"""
original_seed = seed
# See for more informations: https://pytorch.org/docs/stable/notes/randomness.html
if seed == -1:
# Disable deterministic
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
else:
# Enable deterministic
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
if seed <= 0:
# Generate random seed
# Use default_rng() because it is independent of np.random.seed()
seed = np.random.default_rng().integers(1, 2**32 - 1)
assert(0 < seed and seed < 2**32)
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
return original_seed if original_seed != 0 else seed
def generate_audio(
text: str,
history_prompt: Optional[Union[Dict, str]] = None,