```python
import torch.nn.functional as F


def torch_attn(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, *args, **kwargs):
    # Inputs arrive as (batch, seq_len, num_heads, head_dim); SDPA expects
    # (batch, num_heads, seq_len, head_dim), so move the head axis forward.
    batch_size, seq_len, hs, hd = q.size()
    query = q.view(batch_size, -1, hs, hd).transpose(1, 2)
    key = k.view(batch_size, -1, hs, hd).transpose(1, 2)
    value = v.view(batch_size, -1, hs, hd).transpose(1, 2)

    # scale=None falls back to the default 1/sqrt(head_dim) scaling;
    # the explicit `scale` argument requires PyTorch >= 2.1.
    hidden_states = F.scaled_dot_product_attention(
        query, key, value, dropout_p=dropout_p, is_causal=causal, scale=softmax_scale
    )

    # Move the head axis back and restore the (batch, seq_len, num_heads, head_dim) layout.
    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, hs, hd)
    hidden_states = hidden_states.to(query.dtype)
    return hidden_states
```
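For context, a minimal usage sketch (the tensor shapes and the printed output below are illustrative assumptions, not from the original): `torch_attn` takes flash-attention-style inputs laid out as `(batch, seq_len, num_heads, head_dim)` and returns an output with the same layout.

```python
import torch

# Hypothetical example: 2 sequences of 128 tokens, 8 heads of dim 64.
q = torch.randn(2, 128, 8, 64)
k = torch.randn(2, 128, 8, 64)
v = torch.randn(2, 128, 8, 64)

out = torch_attn(q, k, v, causal=True)
print(out.shape)  # torch.Size([2, 128, 8, 64])
```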