mirror of https://github.com/suno-ai/bark.git
Merge branch 'suno-ai:main' into update_readme_with_transformers
commit 309ee029d7
@@ -26,9 +26,9 @@ class NonCausalSelfAttention(nn.Module):
         self.n_head = config.n_head
         self.n_embd = config.n_embd
         self.dropout = config.dropout
-        # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
+        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
         self.flash = (
-            hasattr(torch.nn.functional, "scaled_dot_product_attention") and self.dropout == 0.0
+            hasattr(torch.nn.functional, "scaled_dot_product_attention")
         )

     def forward(self, x):
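For context, `self.flash` is the nanoGPT-style switch between the fused `torch.nn.functional.scaled_dot_product_attention` kernel (available in PyTorch >= 2.0) and a manual attention fallback. Below is a minimal sketch of how such a flag is typically consumed at attention time, assuming nanoGPT-style tensor shapes; the function name `attention_forward` is hypothetical and this is an illustration, not the exact body of the repo's `forward`:

    import math
    import torch
    import torch.nn.functional as F

    def attention_forward(q, k, v, dropout_p=0.0, training=False):
        # q, k, v: (B, n_head, T, head_dim)
        flash = hasattr(F, "scaled_dot_product_attention")
        if flash:
            # Fused kernel; dropout is passed to the op itself, which is why the
            # old `and self.dropout == 0.0` guard in the diff could be dropped.
            return F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=dropout_p if training else 0.0,
                is_causal=False,  # non-causal: every position attends to every position
            )
        # Manual fallback: softmax(Q K^T / sqrt(d)) V
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = F.softmax(att, dim=-1)
        if training and dropout_p > 0.0:
            att = F.dropout(att, p=dropout_p)
        return att @ v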