Skip to content

Commit

Permalink
fdf
Browse files Browse the repository at this point in the history
  • Loading branch information
dudulightricks committed Nov 21, 2024
1 parent a3fb4e1 commit 4ba9067
Showing 1 changed file with 17 additions and 20 deletions.
37 changes: 17 additions & 20 deletions test/test_pallas_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,34 +112,31 @@ def test_flash_attention_spmd_data_parallel_with_segment_ids(self):
n_devices = xr.global_runtime_device_count()
xs.set_global_mesh(xs.Mesh(range(n_devices), (n_devices, 1, 1, 1)))

q = torch.randn(4, 2, 128, 4).to("xla")
k = torch.randn(4, 2, 128, 4).to("xla")
v = torch.randn(4, 2, 128, 4).to("xla")
q_segment_ids = torch.ones(4, 128, device=q.device, dtype=torch.float32).to("xla")
kv_segment_ids = torch.rand(4, 128).to("xla")
q = torch.randn(16, 32, 2048, 64).to("xla")
k = torch.randn(16, 32, 128, 64).to("xla")
v = torch.randn(16, 32, 128, 64).to("xla")
q_segment_ids = torch.ones(16, 2048, dtype=torch.float32).to("xla")
kv_segment_ids = torch.zeros(16, 1, 128, dtype=torch.float32).to("xla")
kv_segment_ids[:8, :, 30:] = -10000.0
kv_segment_ids[8:, :, 60:] = -10000.0

o = flash_attention(q, k, v, q_segment_ids, kv_segment_ids, partition_spec=range(4))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(o),
f"{{devices=[{n_devices},1,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}}")

attention_mask = F.pad(kv_segment_ids, (0, 16256), value=0.0)
attention_mask = attention_mask.repeat_interleave(2, dim=0)
attention_mask = attention_mask.view(4, 2, 128, 128)
# attention_mask = torch.ones(4, 2, 128, 128).to("xla")
# head_size = self.heads
# current_length: int = attention_mask.shape[-1]
# if current_length != target_length:
# attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

# if attention_mask.shape[0] < 4 * head_size:
# attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
#
# attention_mask = attention_mask.view(
# batch_size, attn.heads, -1, attention_mask.shape[-1]
# )
attention_mask = kv_segment_ids.repeat_interleave(32, dim=0)
attention_mask = attention_mask.view(16, 32, 1, 128)

expected_o = self._attention(q, k, v, attn_mask=attention_mask)
# expected_o = F.scaled_dot_product_attention(
# q,
# k,
# v,
# attn_mask=attention_mask,
# dropout_p=0.0,
# is_causal=False,
# )
diff = (expected_o - o).abs()
# z = torch.allclose(o.cpu(), expected_o.cpu(), atol=1e-05)

Expand Down

0 comments on commit 4ba9067

Please sign in to comment.