fix: use sdpa if dropout > 0 on fine model

pull/364/head
Michael Wei 2023-06-20 11:37:21 -07:00
parent 6cd7f0ccd7
commit b56c8df48a
1 changed file with 2 additions and 2 deletions

@@ -26,9 +26,9 @@ class NonCausalSelfAttention(nn.Module):
         self.n_head = config.n_head
         self.n_embd = config.n_embd
         self.dropout = config.dropout
-        # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
+        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
         self.flash = (
-            hasattr(torch.nn.functional, "scaled_dot_product_attention") and self.dropout == 0.0
+            hasattr(torch.nn.functional, "scaled_dot_product_attention")
        )
    def forward(self, x):
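
For context, a minimal sketch (not taken from the changed file) of why the self.dropout == 0.0 gate can be dropped: torch.nn.functional.scaled_dot_product_attention accepts a dropout_p argument, so attention dropout is applied inside the fused kernel and the flash path does not need to fall back to the manual implementation when dropout > 0. The tensor shapes and the dropout/training variables below are illustrative.

# Sketch: SDPA with dropout, assuming PyTorch >= 2.0.
import torch
import torch.nn.functional as F

q = torch.randn(1, 4, 16, 32)  # (batch, n_head, seq_len, head_dim) -- illustrative shapes
k = torch.randn(1, 4, 16, 32)
v = torch.randn(1, 4, 16, 32)

dropout = 0.1      # stands in for config.dropout
training = True    # stands in for module.training

# Attention dropout is handled by the kernel itself via dropout_p,
# so there is no need to disable the flash path when dropout > 0.
y = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=None,
    dropout_p=dropout if training else 0.0,
    is_causal=False,  # non-causal self-attention, as in the fine model
)
print(y.shape)  # torch.Size([1, 4, 16, 32])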