diff --git a/intermediate_source/transformer_building_blocks.py b/intermediate_source/transformer_building_blocks.py
index 97f7bb53e9..586423b14e 100644
--- a/intermediate_source/transformer_building_blocks.py
+++ b/intermediate_source/transformer_building_blocks.py
@@ -570,14 +570,16 @@ def forward(self, x):
 print(f"Total sequence length in nested key/value {kv_len.sum().item()}, max sequence length {kv_len.max().item()}")
 out = new_mha_layer(query, key, value, is_causal=False)
 
+# TODO: anything else I can add here?
+
 ################################################################################
 # Fully masked rows no longer cause NaNs
 # --------------------------------------
 #
-# There has been a long standing issue with ``nn.MultiheadAttention`` where if a row was
-# fully masked by the key_padding_mask, the output of the attention layer would be NaN
-# See `issue `_. This is because
-# the softmax operation would divide by zero.
+# There has been a long standing issue with ``nn.MultiheadAttention`` and
+# ``scaled_dot_product_attention`` where if a row was fully masked, the output
+# of the attention layer would be NaN. See `issue `_.
+# This is because the softmax operation would divide by zero.
 #
 # Thanks to `this PR `_
 # this is no longer the case. Instead, fully masked rows will be set to zero. More
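
A minimal sketch (not part of the diff) of the behavior the new text describes, assuming a PyTorch build that includes the fix: a query row whose attention mask is fully masked should now come back as zeros instead of NaN from ``scaled_dot_product_attention``. The tensor shapes and mask layout below are illustrative only, not taken from the tutorial.

import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, heads, seq_len, head_dim).
q = torch.randn(1, 1, 4, 8)
k = torch.randn(1, 1, 4, 8)
v = torch.randn(1, 1, 4, 8)

# Boolean mask where True means "attend"; the last query row attends to nothing.
attn_mask = torch.ones(4, 4, dtype=torch.bool)
attn_mask[-1, :] = False

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
# On builds with the fix described above this prints zeros; older releases
# returned NaN for the fully masked row.
print(out[0, 0, -1])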