Skip to content

Instantly share code, notes, and snippets.

@victoroliv2
Created March 21, 2024 01:58
Show Gist options
  • Save victoroliv2/3668f07e11a0757febb6e55a8d78592a to your computer and use it in GitHub Desktop.
Save victoroliv2/3668f07e11a0757febb6e55a8d78592a to your computer and use it in GitHub Desktop.
pytorch_fmha_nested_tensor.py
import torch
BATCH = 4
EMB_DIM = 256
HEADS = 8
Q_TOKENS = 512
KV_TOKENS = 16384


def _zeros_nested(max_tokens):
    """Return a zero-filled fp16 CUDA nested tensor with BATCH ragged components.

    Component ``i`` has shape ``(HEADS, max_tokens // (i + 1), EMB_DIM)``: the
    sequence length shrinks across the batch, so the nested tensor is ragged
    in its sequence dimension.
    """
    # Allocating each component directly as fp16 on the GPU avoids staging
    # default-dtype (normally fp32) CPU tensors that nested_tensor would then
    # have to convert and copy to the device.
    return torch.nested.nested_tensor(
        [
            torch.zeros(
                HEADS,
                max_tokens // (i + 1),
                EMB_DIM,
                dtype=torch.half,
                device="cuda",
            )
            for i in range(BATCH)
        ],
        dtype=torch.half,
        device="cuda",
    )


# NOTE(review): components are laid out (HEADS, seq, EMB_DIM), i.e. already in
# (batch, heads, seq, dim) order; whether SDPA's nested-tensor path has to
# transpose/copy this layout internally is worth confirming in the trace.
q_proj = _zeros_nested(Q_TOKENS)
k_proj = _zeros_nested(KV_TOKENS)
v_proj = _zeros_nested(KV_TOKENS)
def trace_ready(p, *, path="chrome_trace"):
    """Profiler ``on_trace_ready`` callback: save the trace as a gzipped file.

    Exports the Chrome trace of ``p`` to ``path``, compresses it into
    ``path + ".gz"`` (overwriting any previous archive), then deletes the
    uncompressed export.

    Args:
        p: the profiler passed to the callback; only its
            ``export_chrome_trace(path)`` method is used.
        path: base file name of the exported trace. Defaults to
            ``"chrome_trace"`` in the current working directory, matching the
            file names used before this parameter existed.

    Raises:
        OSError: if the exported trace cannot be read, compressed or removed.
    """
    import gzip
    import os
    import shutil

    print('trace ready!')
    p.export_chrome_trace(path)
    # Compress with the stdlib instead of shelling out to `rm` and `gzip`.
    # This needs no POSIX shell, does not emit an `rm` error when no old
    # archive exists, and surfaces failures instead of ignoring exit codes.
    # gzip.open(..., "wb") truncates an existing archive, so no explicit
    # removal is needed first.
    with open(path, "rb") as src, gzip.open(path + ".gz", "wb") as dst:
        shutil.copyfileobj(src, dst)
    # Mirror the `gzip` CLI, which replaces the input with its .gz archive.
    os.remove(path)
# Profile ten scaled_dot_product_attention calls on the nested q/k/v tensors,
# restricted to the flash-attention backend (math and mem-efficient disabled).
# The backend-selection context is loop-invariant, so it is entered once
# around the whole loop rather than on every iteration.
# NOTE(review): newer PyTorch releases deprecate torch.backends.cuda.sdp_kernel
# in favour of torch.nn.attention.sdpa_kernel; kept here for compatibility.
with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],
    on_trace_ready=trace_ready,
    with_stack=True,
), torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    for _ in range(10):
        out = torch.nn.functional.scaled_dot_product_attention(
            q_proj, k_proj, v_proj, attn_mask=None, dropout_p=0.0
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment