# Annotated AVX-512 gemm kernel
# Source: GitHub gist robertknight/d95b9a6c6ac79ef8bf64cea9d534b177,
# created December 29, 2023 18:20.
# Note (from the gist page): this file may contain bidirectional Unicode
# text that can be interpreted or compiled differently than it appears.
# To review it, open the file in an editor that reveals hidden Unicode
# characters.
#-----------------------------------------------------------------------
# wasnn::gemm::kernels::x64::Avx512Kernel::kernel_avx_512
#
# Annotated compiler output (Mach-O, Intel operand order: dst, src) for
# the AVX-512 GEMM micro-kernel. It computes one MR x NR = 6 x 32 tile of
# f32 values:
#
#     tile = alpha * (A_panel * B_panel) + beta * tile
#
# It has fast paths for (alpha == 1, beta == 0), which stores without
# reading the tile, and for (alpha == 1, beta == 1), which does tile += tmp.
#
# This appears to be a demangled listing. Symbol names containing `::`,
# and the labels LCPI758_0 / l___unnamed_* (defined outside this listing),
# mean it is not directly assemblable as-is.
#
# At entry params are (the registers match SysV AMD64 here; the Rust ABI
# does not guarantee this in general -- NOTE(review)):
#
#   tile_ptr         (rdi)       f32 output tile, row 0
#   tile_row_stride  (rsi)       distance between tile rows, in f32
#                                elements (scaled by 4 when addressing)
#   a                (rdx, rcx)  slice (ptr, len) of packed A; each depth
#                                step reads 6 consecutive f32 (24 bytes)
#   b                (r8, r9)    slice (ptr, len) of packed B; each depth
#                                step reads 32 consecutive f32 (128 bytes)
#   depth            (stack)     number of k steps, read from [rbp + 16]
#   alpha            (xmm0)
#   beta             (xmm1)
#
# Panics (via core::panicking::panic) if a.len() < depth * 6 or
# b.len() < depth * 32. NOTE(review): both products are formed with
# wrapping lea/shl arithmetic, so a huge `depth` could wrap past the check.
#
# Accumulator layout (`tmp[row]` = two zmm registers of 16 f32 lanes):
#   row 0: zmm13 (cols 0-15), zmm12 (cols 16-31)
#   row 1: zmm11, zmm10
#   row 2: zmm9,  zmm8
#   row 3: zmm7,  zmm6
#   row 4: zmm5,  zmm4
#   row 5: zmm3,  zmm2
# `b_rows` live in zmm14 (cols 0-15) and zmm15 (cols 16-31). The
# broadcast element of `a` for the current row lives in zmm16.
#
# Clobbers: rax, rcx, rdx, rsi, r8, r10, zmm0-zmm16, flags.
#-----------------------------------------------------------------------
        .section __TEXT,__text,regular,pure_instructions
        .p2align 4, 0x90                        # 16-byte align the entry; pad with NOP (0x90)
wasnn::gemm::kernels::x64::Avx512Kernel::kernel_avx_512:
Lfunc_begin758:
        .cfi_startproc
        # Prologue: frame pointer, so the stack argument `depth` is at [rbp + 16].
        push    rbp
        .cfi_def_cfa_offset 16
        .cfi_offset rbp, -16
        mov     rbp, rsp
        .cfi_def_cfa_register rbp

        # Bounds checks (unsigned compares; jump to a panic stub on failure).
        mov     rax, qword ptr [rbp + 16]       # Load `depth` into rax
        lea     r10, [rax + rax]                # r10 = depth * 2
        lea     r10, [r10 + 2*r10]              # Set r10 = `depth * MR` (where MR == 6)
        cmp     r10, rcx                        # Compare `depth * MR` with `a.len()`
        ja      LBB758_15                       #   depth * MR > a.len() -> panic
        mov     rcx, rax
        shl     rcx, 5                          # Set rcx = `depth * NR` (where NR == 32 = 2 zmm x 16 f32 lanes)
        cmp     rcx, r9                         # Compare `depth * NR` with `b.len()`
        ja      LBB758_16                       #   depth * NR > b.len() -> panic
        test    rax, rax                        # Check if there are any loop iterations
        je      LBB758_3                        #   depth == 0 -> tmp stays all zeros

        # Bias the pointers so the loop addresses with small negative
        # displacements: rdx points at a[5] and r8 at b[16] of the current step.
        add     rdx, 20
        add     r8, 64

        # Clear the 12 registers that hold `tmp`. A VEX-encoded `vxorps` on the
        # xmm view also zeroes the upper bits of the matching zmm register.
        # The registers that hold `b_rows` (zmm14, zmm15) are not cleared: they
        # are overwritten by loads at the top of every iteration, so any
        # initial value would be a dead store.
        vxorps  xmm2, xmm2, xmm2
        vxorps  xmm3, xmm3, xmm3
        vxorps  xmm4, xmm4, xmm4
        vxorps  xmm5, xmm5, xmm5
        vxorps  xmm6, xmm6, xmm6
        vxorps  xmm7, xmm7, xmm7
        vxorps  xmm8, xmm8, xmm8
        vxorps  xmm9, xmm9, xmm9
        vxorps  xmm10, xmm10, xmm10
        vxorps  xmm11, xmm11, xmm11
        vxorps  xmm12, xmm12, xmm12
        vxorps  xmm13, xmm13, xmm13
        .p2align 4, 0x90                        # 16-byte align the loop head
LBB758_8:
        # Depth loop; rax counts down the remaining k steps.
        # Load `b_rows[0]` (cols 0-15) and `b_rows[1]` (cols 16-31) for this k.
        vmovups zmm14, zmmword ptr [r8 - 64]
        vmovups zmm15, zmmword ptr [r8]
        # tmp[i][j] = fmadd(broadcast(a[i]), b_rows[j], tmp[i][j]), i in 0..6, j in 0..2
        vbroadcastss zmm16, dword ptr [rdx - 20]        # a[0]
        vfmadd231ps zmm13, zmm16, zmm14
        vfmadd231ps zmm12, zmm15, zmm16
        vbroadcastss zmm16, dword ptr [rdx - 16]        # a[1]
        vfmadd231ps zmm11, zmm16, zmm14
        vfmadd231ps zmm10, zmm15, zmm16
        vbroadcastss zmm16, dword ptr [rdx - 12]        # a[2]
        vfmadd231ps zmm9, zmm16, zmm14
        vfmadd231ps zmm8, zmm15, zmm16
        vbroadcastss zmm16, dword ptr [rdx - 8]         # a[3]
        vfmadd231ps zmm7, zmm16, zmm14
        vfmadd231ps zmm6, zmm15, zmm16
        vbroadcastss zmm16, dword ptr [rdx - 4]         # a[4]
        vfmadd231ps zmm5, zmm16, zmm14
        vfmadd231ps zmm4, zmm15, zmm16
        vbroadcastss zmm16, dword ptr [rdx]             # a[5]
        vfmadd231ps zmm3, zmm16, zmm14
        vfmadd231ps zmm2, zmm16, zmm15
        add     rdx, 24                         # a += MR f32 (6 * 4 bytes)
        sub     r8, -128                        # b += NR f32 (32 * 4 bytes); -128 fits an imm8, +128 would not
        dec     rax
        jne     LBB758_8                        # Jump to top of depth loop if not final iteration

        # Test if `alpha == 1`. LCPI758_0 is presumably the f32 constant 1.0
        # (defined outside this listing). ZF=1 with PF=0 means ordered-equal;
        # a NaN alpha sets PF and takes the general path.
        vucomiss xmm0, dword ptr [rip + LCPI758_0]
        jne     LBB758_9
        jnp     LBB758_5                        #   alpha == 1 -> try the beta == 0 store path
        jmp     LBB758_9
LBB758_3:
        # depth == 0: no products to accumulate, so `tmp` is all zeros.
        vxorps  xmm2, xmm2, xmm2
        vxorps  xmm3, xmm3, xmm3
        vxorps  xmm4, xmm4, xmm4
        vxorps  xmm5, xmm5, xmm5
        vxorps  xmm6, xmm6, xmm6
        vxorps  xmm7, xmm7, xmm7
        vxorps  xmm8, xmm8, xmm8
        vxorps  xmm9, xmm9, xmm9
        vxorps  xmm10, xmm10, xmm10
        vxorps  xmm11, xmm11, xmm11
        vxorps  xmm12, xmm12, xmm12
        vxorps  xmm13, xmm13, xmm13
        # Test if `alpha == 1` (unordered/NaN counts as "not equal").
        vucomiss xmm0, dword ptr [rip + LCPI758_0]
        jne     LBB758_9
        jp      LBB758_9
        # alpha == 1: fall through into the beta == 0 test.
LBB758_5:
        # alpha == 1 here. Test if `beta == 0` (zmm14 is free scratch now).
        vxorps  xmm14, xmm14, xmm14
        vucomiss xmm1, xmm14
        jne     LBB758_9
        jp      LBB758_9
        # alpha == 1, beta == 0: store `tmp[i][j]` to `tile_ptr` without
        # reading the tile. Row i starts at byte offset 4 * i * rsi.
        vmovups zmmword ptr [rdi], zmm13                # row 0
        vmovups zmmword ptr [rdi + 64], zmm12
        vmovups zmmword ptr [rdi + 4*rsi], zmm11        # row 1
        vmovups zmmword ptr [rdi + 4*rsi + 64], zmm10
        vmovups zmmword ptr [rdi + 8*rsi], zmm9         # row 2
        vmovups zmmword ptr [rdi + 8*rsi + 64], zmm8
        lea     rax, [rsi + 2*rsi]                      # rax = 3 * rsi
        vmovups zmmword ptr [rdi + 4*rax], zmm7         # row 3
        vmovups zmmword ptr [rdi + 4*rax + 64], zmm6
        lea     rax, [rsi + 4*rsi]                      # rax = 5 * rsi
        shl     rsi, 4                                  # rsi = 16 * rsi = byte offset of row 4
        vmovups zmmword ptr [rdi + rsi], zmm5           # row 4
        vmovups zmmword ptr [rdi + rsi + 64], zmm4
        vmovups zmmword ptr [rdi + 4*rax], zmm3         # row 5
        vmovups zmmword ptr [rdi + 4*rax + 64], zmm2
        pop     rbp
        # Clear upper ymm/zmm state before returning, to avoid SSE/AVX
        # transition penalties in callers that use legacy SSE. See
        # https://community.intel.com/t5/Intel-ISA-Extensions/What-is-the-status-of-VZEROUPPER-use/m-p/1098375
        vzeroupper
        ret
LBB758_9:
        # Check if `beta == 1 && alpha == 1`. alpha is re-tested because this
        # block is also reached when alpha == 1 but beta != 0.
        vucomiss xmm0, dword ptr [rip + LCPI758_0]
        jne     LBB758_12
        jp      LBB758_12
        vucomiss xmm1, dword ptr [rip + LCPI758_0]
        jne     LBB758_12
        jp      LBB758_12
        # alpha == 1, beta == 1: tile += tmp. zmm0 (alpha) is no longer
        # needed and is reused as scratch.
        vaddps  zmm0, zmm13, zmmword ptr [rdi]          # row 0
        vmovups zmmword ptr [rdi], zmm0
        vaddps  zmm0, zmm12, zmmword ptr [rdi + 64]
        vmovups zmmword ptr [rdi + 64], zmm0
        vaddps  zmm0, zmm11, zmmword ptr [rdi + 4*rsi]  # row 1
        vmovups zmmword ptr [rdi + 4*rsi], zmm0
        vaddps  zmm0, zmm10, zmmword ptr [rdi + 4*rsi + 64]
        vmovups zmmword ptr [rdi + 4*rsi + 64], zmm0
        vaddps  zmm0, zmm9, zmmword ptr [rdi + 8*rsi]   # row 2
        vmovups zmmword ptr [rdi + 8*rsi], zmm0
        vaddps  zmm0, zmm8, zmmword ptr [rdi + 8*rsi + 64]
        vmovups zmmword ptr [rdi + 8*rsi + 64], zmm0
        lea     rax, [rsi + 2*rsi]                      # rax = 3 * rsi
        vaddps  zmm0, zmm7, zmmword ptr [rdi + 4*rax]   # row 3
        vmovups zmmword ptr [rdi + 4*rax], zmm0
        vaddps  zmm0, zmm6, zmmword ptr [rdi + 4*rax + 64]
        vmovups zmmword ptr [rdi + 4*rax + 64], zmm0
        lea     rax, [rsi + 4*rsi]                      # rax = 5 * rsi
        shl     rsi, 4                                  # rsi = byte offset of row 4
        vaddps  zmm0, zmm5, zmmword ptr [rdi + rsi]     # row 4
        vmovups zmmword ptr [rdi + rsi], zmm0
        vaddps  zmm0, zmm4, zmmword ptr [rdi + rsi + 64]
        vmovups zmmword ptr [rdi + rsi + 64], zmm0
        vaddps  zmm0, zmm3, zmmword ptr [rdi + 4*rax]   # row 5
        vmovups zmmword ptr [rdi + 4*rax], zmm0
        vaddps  zmm0, zmm2, zmmword ptr [rdi + 4*rax + 64]
        vmovups zmmword ptr [rdi + 4*rax + 64], zmm0
        pop     rbp
        vzeroupper
        ret
LBB758_12:
        # General case: tile = alpha * tmp + beta * tile.
        # NOTE(review): when beta == 0 and alpha != 1, the tile is still read
        # and multiplied by beta. NaN/Inf already in the tile therefore yields
        # NaN (0 * NaN = 0 * Inf = NaN); only the alpha == 1 path skips the read.
        vbroadcastss zmm0, xmm0                         # zmm0 = splat(alpha)
        vbroadcastss zmm1, xmm1                         # zmm1 = splat(beta)
        vmulps  zmm14, zmm1, zmmword ptr [rdi]          # row 0: beta * tile
        vfmadd213ps zmm13, zmm0, zmm14                  #        alpha * tmp + beta * tile
        vmovups zmmword ptr [rdi], zmm13
        vmulps  zmm13, zmm1, zmmword ptr [rdi + 64]
        vfmadd213ps zmm12, zmm0, zmm13
        vmovups zmmword ptr [rdi + 64], zmm12
        vmulps  zmm12, zmm1, zmmword ptr [rdi + 4*rsi]  # row 1
        vfmadd213ps zmm11, zmm0, zmm12
        vmovups zmmword ptr [rdi + 4*rsi], zmm11
        vmulps  zmm11, zmm1, zmmword ptr [rdi + 4*rsi + 64]
        vfmadd213ps zmm10, zmm0, zmm11
        vmovups zmmword ptr [rdi + 4*rsi + 64], zmm10
        vmulps  zmm10, zmm1, zmmword ptr [rdi + 8*rsi]  # row 2
        vfmadd213ps zmm9, zmm0, zmm10
        vmovups zmmword ptr [rdi + 8*rsi], zmm9
        vmulps  zmm9, zmm1, zmmword ptr [rdi + 8*rsi + 64]
        vfmadd213ps zmm8, zmm0, zmm9
        vmovups zmmword ptr [rdi + 8*rsi + 64], zmm8
        lea     rax, [rsi + 2*rsi]                      # rax = 3 * rsi
        vmulps  zmm8, zmm1, zmmword ptr [rdi + 4*rax]   # row 3
        vfmadd213ps zmm7, zmm0, zmm8
        vmovups zmmword ptr [rdi + 4*rax], zmm7
        vmulps  zmm7, zmm1, zmmword ptr [rdi + 4*rax + 64]
        vfmadd213ps zmm6, zmm0, zmm7
        vmovups zmmword ptr [rdi + 4*rax + 64], zmm6
        lea     rax, [rsi + 4*rsi]                      # rax = 5 * rsi
        shl     rsi, 4                                  # rsi = byte offset of row 4
        vmulps  zmm6, zmm1, zmmword ptr [rdi + rsi]     # row 4
        vfmadd213ps zmm5, zmm0, zmm6
        vmovups zmmword ptr [rdi + rsi], zmm5
        vmulps  zmm5, zmm1, zmmword ptr [rdi + rsi + 64]
        vfmadd213ps zmm4, zmm0, zmm5
        vmovups zmmword ptr [rdi + rsi + 64], zmm4
        vmulps  zmm4, zmm1, zmmword ptr [rdi + 4*rax]   # row 5
        vfmadd213ps zmm3, zmm0, zmm4
        vmovups zmmword ptr [rdi + 4*rax], zmm3
        vmulps  zmm1, zmm1, zmmword ptr [rdi + 4*rax + 64]
        vfmadd213ps zmm2, zmm0, zmm1
        vmovups zmmword ptr [rdi + 4*rax + 64], zmm2
        pop     rbp
        vzeroupper
        ret
LBB758_15:
        # a.len() < depth * MR. Arguments: rdi = message ptr, esi = message
        # length (39), rdx = presumably the caller's panic `Location`. The
        # message is presumably "assertion failed: a.len() >= depth * MR".
        lea     rdi, [rip + l___unnamed_545]
        lea     rdx, [rip + l___unnamed_551]
        mov     esi, 39
        call    core::panicking::panic          # diverges (never returns)
LBB758_16:
        # b.len() < depth * NR. Same argument layout as the stub above; the
        # message is presumably "assertion failed: b.len() >= depth * NR".
        lea     rdi, [rip + l___unnamed_547]
        lea     rdx, [rip + l___unnamed_552]
        mov     esi, 39
        call    core::panicking::panic          # diverges (never returns)
# Sign up for free to join this conversation on GitHub. Already have an
# account? Sign in to comment.