mirror of https://github.com/suno-ai/bark.git
fix: use sdpa if dropout > 0 on fine model
parent 6cd7f0ccd7
commit b56c8df48a
@@ -26,9 +26,9 @@ class NonCausalSelfAttention(nn.Module):
         self.n_head = config.n_head
         self.n_embd = config.n_embd
         self.dropout = config.dropout
-        # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
+        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
         self.flash = (
-            hasattr(torch.nn.functional, "scaled_dot_product_attention") and self.dropout == 0.0
+            hasattr(torch.nn.functional, "scaled_dot_product_attention")
         )

     def forward(self, x):
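
For context, a minimal sketch (not part of the commit) of why the `self.dropout == 0.0` guard can be dropped: in PyTorch >= 2.0, torch.nn.functional.scaled_dot_product_attention accepts dropout directly through its dropout_p argument, so the fused path also works when dropout > 0. The tensor shapes below are illustrative, not the fine model's actual dimensions.

import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, n_head, seq_len, head_dim); assumed for
# this sketch, not taken from the model config.
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# The fused kernel applies dropout internally via dropout_p, so there is
# no need to gate self.flash on dropout being zero; a module would pass
# dropout_p=self.dropout if self.training else 0.0.
y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.1)
print(y.shape)  # torch.Size([1, 8, 16, 64])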