Merge branch 'suno-ai:main' into update_readme_with_transformers

pull/391/head
Yoach Lacombe 2023-07-18 10:26:26 +02:00 committed by GitHub
commit 309ee029d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -26,9 +26,9 @@ class NonCausalSelfAttention(nn.Module):
self.n_head = config.n_head
self.n_embd = config.n_embd
self.dropout = config.dropout
# flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
# flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
self.flash = (
hasattr(torch.nn.functional, "scaled_dot_product_attention") and self.dropout == 0.0
hasattr(torch.nn.functional, "scaled_dot_product_attention")
)
def forward(self, x):