Benchmark Mistral Graph Construction
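The script below builds a 4-bit quantized, Mistral-7B-sized model with freshly initialized weights and repeatedly runs a single decode step through it. Because MLX evaluates lazily and the timing loop never calls mx.eval, what is measured is the time to construct the forward compute graph, not the time to execute it.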
import time

from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn

@dataclass
class ModelArgs:
    # Defaults match the Mistral 7B configuration.
    hidden_size: int = 4096
    num_hidden_layers: int = 32
    intermediate_size: int = 14336
    num_attention_heads: int = 32
    rms_norm_eps: float = 1e-5
    vocab_size: int = 32000
    num_key_value_heads: int = 8
    rope_theta: float = 10000
    rope_traditional: bool = False

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads

        head_dim = args.hidden_size // n_heads
        self.scale = head_dim**-0.5

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
        self.rope = nn.RoPE(
            head_dim,
            traditional=args.rope_traditional,
            base=args.rope_theta,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            # Rotate by the current sequence offset, then extend the KV cache.
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output), (keys, values)

class MLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)

    def __call__(self, x) -> mx.array:
        # SwiGLU feed-forward: silu(gate(x)) * up(x), projected back down.
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))

class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = args.hidden_size
        self.self_attn = Attention(args)
        self.mlp = MLP(args.hidden_size, args.intermediate_size)
        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
        # Pre-norm attention and MLP, each with a residual connection.
        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        r = self.mlp(self.post_attention_layernorm(h))
        out = h + r
        return out, cache

class Mistral(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [
            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        h = self.embed_tokens(inputs)

        # A causal mask is only needed when processing more than one token.
        mask = None
        if h.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
            mask = mask.astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        for e, layer in enumerate(self.layers):
            h, cache[e] = layer(h, mask, cache[e])

        return self.norm(h), cache

class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.model = Mistral(args)
        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        out, cache = self.model(inputs, cache)
        return self.lm_head(out), cache

def generate_step(prompt, model):
    # Greedy (argmax) decoding loop; not used by the timing code below.
    y = prompt
    cache = None
    while True:
        logits, cache = model(y[None], cache=cache)
        logits = logits[:, -1, :]
        y = mx.argmax(logits, axis=-1)
        yield y

def main():
    model = Model(ModelArgs())

    # Quantize all linear layers to 4 bits with a group size of 32.
    nn.QuantizedLinear.quantize_module(
        model,
        bits=4,
        group_size=32,
    )

    # Run a dummy 500-token prompt to get a KV cache and a next token.
    # MLX is lazy, so without mx.eval none of this is actually computed.
    prompt = mx.array([0] * 500)
    logits, cache = model(prompt[None], cache=None)
    logits = mx.argmax(logits[:, -1, :], axis=-1)

    def step(logits, cache):
        return model(logits[None], cache=cache)

    print("Warmup", flush=True)
    for _ in range(5):
        y, _ = step(logits, cache)

    # Time how long it takes to build the forward graph for a single decode
    # step; since mx.eval is never called, only graph construction is timed.
    print("Timing", flush=True)
    tic = time.time()
    for _ in range(200):
        y, _ = step(logits, cache)
    toc = time.time()

    tps = 200 / (toc - tic)
    ms = 1e3 * (toc - tic) / 200
    print(f"Graph construction: {tps:.3f} steps/sec")
    print(f"Graph construction: {ms:.3f} ms/step")


if __name__ == "__main__":
    main()
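Since no mx.eval appears in main, the numbers printed above reflect graph construction only. As a rough sketch of how the same setup could also time fully evaluated decode steps, reusing the imports and Model defined above (the helper name and iteration count are illustrative, not part of the original benchmark), one could append:

def benchmark_eval_step(model, y, cache, iters=50):
    # Evaluate once up front so one-time initialization is not timed.
    logits, _ = model(y[None], cache=cache)
    mx.eval(logits)

    tic = time.time()
    for _ in range(iters):
        logits, _ = model(y[None], cache=cache)
        mx.eval(logits)  # force the lazy graph to actually execute
    return 1e3 * (time.time() - tic) / iters  # ms per evaluated step

Calling this as benchmark_eval_step(model, logits, cache) right after the timing loop in main would give an end-to-end per-step latency to compare against the graph-construction time.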