mirror of https://github.com/suno-ai/bark.git
fix: use sdpa if dropout > 0 on fine model
parent
6cd7f0ccd7
commit
b56c8df48a
|
@ -26,9 +26,9 @@ class NonCausalSelfAttention(nn.Module):
|
|||
self.n_head = config.n_head
|
||||
self.n_embd = config.n_embd
|
||||
self.dropout = config.dropout
|
||||
# flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
|
||||
# flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
|
||||
self.flash = (
|
||||
hasattr(torch.nn.functional, "scaled_dot_product_attention") and self.dropout == 0.0
|
||||
hasattr(torch.nn.functional, "scaled_dot_product_attention")
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
|
Loading…
Reference in New Issue