Commit

gpt_bigcode: fixed wrong indentation (#1376)
mgonchar authored Sep 30, 2024
1 parent e9f8388 commit d8e0adf
Showing 1 changed file with 26 additions and 26 deletions.
@@ -134,32 +134,32 @@ def apply_FusedSDPA(
         else:
             use_causal_mask = self.is_causal and attention_mask is None and query_length > 1

-            if query_length > 8192:
-                sdpa_result = self.gaudi_flash_attn_v1(
-                    query,
-                    key,
-                    value,
-                    attention_mask,
-                    self.attn_pdrop if self.training else 0.0,
-                    use_causal_mask,
-                    scale,
-                    "fast" if flash_attention_fast_softmax else "None",
-                    enable_recompute,
-                    self.block_size,
-                )
-                htcore.mark_step()
-            else:
-                sdpa_result = self.fused_scaled_dot_product_attention(
-                    query,
-                    key,
-                    value,
-                    attention_mask,
-                    self.attn_pdrop if self.training else 0.0,
-                    use_causal_mask,
-                    scale,
-                    "fast" if flash_attention_fast_softmax else "None",
-                    enable_recompute,
-                )
+        if query_length > 8192:
+            sdpa_result = self.gaudi_flash_attn_v1(
+                query,
+                key,
+                value,
+                attention_mask,
+                self.attn_pdrop if self.training else 0.0,
+                use_causal_mask,
+                scale,
+                "fast" if flash_attention_fast_softmax else "None",
+                enable_recompute,
+                self.block_size,
+            )
+            htcore.mark_step()
+        else:
+            sdpa_result = self.fused_scaled_dot_product_attention(
+                query,
+                key,
+                value,
+                attention_mask,
+                self.attn_pdrop if self.training else 0.0,
+                use_causal_mask,
+                scale,
+                "fast" if flash_attention_fast_softmax else "None",
+                enable_recompute,
+            )

         if self.multi_query:
             # (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim)
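The change is a pure one-level dedent of the SDPA dispatch block. Below is a minimal sketch of why that matters, assuming the block previously sat inside the `else:` branch visible in the hunk context; `long_path`, `short_path`, `dispatch_before_fix`, and `dispatch_after_fix` are hypothetical stand-ins for illustration, not code from the repository.

# Hedged sketch of the control-flow effect of this commit's dedent.
# The names below are illustrative stand-ins, not the real optimum-habana API:
#   long_path  ~ self.gaudi_flash_attn_v1 (followed by htcore.mark_step())
#   short_path ~ self.fused_scaled_dot_product_attention
# The outer if/else is simplified; only its `else:` branch appears in the hunk.

def long_path(use_causal_mask: bool) -> str:
    return f"chunked flash attention v1 (causal={use_causal_mask})"


def short_path(use_causal_mask: bool) -> str:
    return f"fused scaled dot-product attention (causal={use_causal_mask})"


def dispatch_before_fix(query_length: int, override_causal: bool) -> str:
    if override_causal:
        use_causal_mask = True
    else:
        use_causal_mask = query_length > 1

        # Wrongly indented: the dispatch only executes on the `else` path,
        # so sdpa_result is never assigned when override_causal is True.
        if query_length > 8192:
            sdpa_result = long_path(use_causal_mask)
        else:
            sdpa_result = short_path(use_causal_mask)
    return sdpa_result  # UnboundLocalError on the override_causal=True path


def dispatch_after_fix(query_length: int, override_causal: bool) -> str:
    if override_causal:
        use_causal_mask = True
    else:
        use_causal_mask = query_length > 1

    # Dedented one level, as in the commit: the length-based dispatch now runs
    # no matter which branch chose use_causal_mask.
    if query_length > 8192:
        sdpa_result = long_path(use_causal_mask)
    else:
        sdpa_result = short_path(use_causal_mask)
    return sdpa_result


if __name__ == "__main__":
    print(dispatch_after_fix(query_length=16384, override_causal=True))
    print(dispatch_after_fix(query_length=512, override_causal=False))
    try:
        dispatch_before_fix(query_length=512, override_causal=True)
    except UnboundLocalError as err:
        print(f"pre-fix control flow fails: {err}")

Run as-is, the sketch prints the two dispatch outcomes for the corrected layout and shows the pre-fix layout raising UnboundLocalError whenever the first branch is taken.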
