mirror of https://github.com/suno-ai/bark.git
Merge 8f5c8638d5
into f4f32d4cd4
commit
b5416017f8
|
@ -1,2 +1,2 @@
|
|||
from .api import generate_audio, text_to_semantic, semantic_to_waveform, save_as_prompt
|
||||
from .api import generate_audio, text_to_semantic, semantic_to_waveform, save_as_prompt, set_seed
|
||||
from .generation import SAMPLE_RATE, preload_models
|
||||
|
|
45
bark/api.py
45
bark/api.py
|
@ -1,6 +1,9 @@
|
|||
from typing import Dict, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import random
|
||||
import os
|
||||
|
||||
from .generation import codec_decode, generate_coarse, generate_fine, generate_text_semantic
|
||||
|
||||
|
@ -83,6 +86,48 @@ def save_as_prompt(filepath, full_generation):
|
|||
np.savez(filepath, **full_generation)
|
||||
|
||||
|
||||
def set_seed(seed: int = 0):
|
||||
"""Set the seed
|
||||
|
||||
seed = 0 Generate a random seed
|
||||
seed = -1 Disable deterministic algorithms
|
||||
0 < seed < 2**32 Set the seed
|
||||
|
||||
Args:
|
||||
seed: integer to use as seed
|
||||
|
||||
Returns:
|
||||
integer used as seed
|
||||
"""
|
||||
|
||||
original_seed = seed
|
||||
|
||||
# See for more informations: https://pytorch.org/docs/stable/notes/randomness.html
|
||||
if seed == -1:
|
||||
# Disable deterministic
|
||||
torch.backends.cudnn.deterministic = False
|
||||
torch.backends.cudnn.benchmark = True
|
||||
else:
|
||||
# Enable deterministic
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
if seed <= 0:
|
||||
# Generate random seed
|
||||
# Use default_rng() because it is independent of np.random.seed()
|
||||
seed = np.random.default_rng().integers(1, 2**32 - 1)
|
||||
|
||||
assert(0 < seed and seed < 2**32)
|
||||
|
||||
np.random.seed(seed)
|
||||
random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
os.environ["PYTHONHASHSEED"] = str(seed)
|
||||
|
||||
return original_seed if original_seed != 0 else seed
|
||||
|
||||
|
||||
def generate_audio(
|
||||
text: str,
|
||||
history_prompt: Optional[Union[Dict, str]] = None,
|
||||
|
|
Loading…
Reference in New Issue