MLX Export Llama

Export Llama Inference from Python to run directly in C++.

To run, first install the requirements:

pip install -U mlx transformers fire

Then generate text from Python with:

python llama.py generate "How tall is K2?"

To export the generation function run:

python llama.py export
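
Optionally, you can sanity-check the exported file from Python before building the C++ side. A minimal sketch (it assumes the default export path above and uses placeholder token ids; mx.import_function loads the recorded computation):

import mlx.core as mx
import mlx.nn as nn

# Load the exported step function; the prompt-processing signature takes the
# token ids plus a "mask" keyword argument and returns [logits, *kv_cache].
step = mx.import_function("llama3.1-instruct-4bit.mlxfn")

tokens = mx.array([[0, 1, 2]], mx.uint32)  # placeholder token ids
mask = nn.MultiHeadAttention.create_additive_causal_mask(tokens.size)
outputs = step(tokens, mask=mask)
print(outputs[0].shape)  # (1, sequence_length, vocab_size)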

Then build the C++ code (requires CMake):

cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build

And run the generation from C++ with:

./llama ../llama3.1-instruct-4bit.mlxfn <path/to/tokenizer/dir> "How tall is K2?"

Here <path/to/tokenizer/dir> is a directory containing the model's tokenizer.json (for example, the Hugging Face snapshot directory that llama.py downloads).
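
The CMake build also produces a small test binary (from test.cpp below) that exercises the BPE tokenizer; it expects the model's tokenizer.json to be in the current working directory when run.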

CMakeLists.txt:

cmake_minimum_required(VERSION 3.27)
project(import_mlx LANGUAGES CXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
find_package(
Python 3.9
COMPONENTS Interpreter Development.Module
REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m pip show mlx
COMMAND grep location
COMMAND awk "{print $4 \"/mlx\"}"
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE MLX_ROOT)
find_package(MLX CONFIG REQUIRED)
add_executable(llama llama.cpp)
target_link_libraries(llama PRIVATE mlx)
include(FetchContent)
FetchContent_Declare(
json
URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
FetchContent_MakeAvailable(json)
target_include_directories(
llama PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)
add_executable(test test.cpp)
target_include_directories(
test PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)

llama.cpp:

// Copyright © 2024 Apple Inc.
#include "tokenizer.h"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <mlx/mlx.h>
#include <mlx/utils.h>
namespace mx = mlx::core;
#define seconds(x) \
(std::chrono::duration_cast<std::chrono::nanoseconds>(x).count() / 1e9)
#define time_now() std::chrono::high_resolution_clock::now()
mx::array create_additive_causal_mask(const mx::array &y,
mx::Dtype dtype = mx::float32) {
auto indices = mx::arange(y.shape(-1));
auto mask = mx::expand_dims(indices, 1) < mx::expand_dims(indices, 0);
return mx::astype(mask, dtype) * mx::finfo(dtype).min;
}
int main(int argc, char *argv[]) {
if (argc < 4) {
std::cerr << "Must provide model path, tokenizer path, and prompt."
<< std::endl;
return 1;
}
auto path = std::string(argv[1]);
auto tokenizer = BPETokenizer(std::string(argv[2]));
auto prompt = std::string(argv[3]);
int max_tokens = 100;
auto generate_fn = mx::import_function(path);
auto prompt_tokens = tokenizer.encode(prompt);
auto y = mx::array(prompt_tokens.data(),
{1, static_cast<int>(prompt_tokens.size())}, mx::uint32);
auto inputs = std::vector<mx::array>{y};
auto tic = time_now();
float prompt_time;
int n = 0;
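// Prefill: run the full prompt through the exported function with an explicit
// causal mask. The outputs come back as [logits, flattened KV cache...], and the
// first generated token is taken greedily from the last position's logits.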
{
auto mask = create_additive_causal_mask(y);
inputs = generate_fn(inputs, {{"mask", mask}});
auto logits = inputs[0];
logits = slice(logits, {0, -1, 0}, logits.shape());
y = argmax(logits, -1);
async_eval(y);
}
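// Decode loop: the previous call's outputs are reused as the next call's inputs.
// Slot 0 is overwritten with the current token, the flattened KV cache rides along
// in the remaining slots, and the position offset is appended as the last input.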
auto step_fn = generate_fn;
auto offset = mx::array(prompt_tokens.size(), mx::uint32);
std::vector<int> tokens;
for (; n < max_tokens; ++n) {
inputs[0] = y;
if (n < max_tokens - 1) {
inputs.push_back(offset);
inputs = step_fn(inputs);
inputs[0] = argmax(inputs[0], -1);
offset = offset + 1u;
async_eval(inputs[0]);
}
auto token = y.item<int>();
if (token == tokenizer.eos_token_id()) {
break;
}
tokens.push_back(token);
auto [result, complete] = tokenizer.try_decode(tokens);
if (complete) {
std::cout << result << std::flush;
tokens.clear();
}
if (n == 0) {
prompt_time = seconds(time_now() - tic);
tic = time_now();
}
if (n < max_tokens - 1) {
y = inputs[0];
}
}
auto result = tokenizer.decode(tokens);
std::cout << result << std::flush;
auto gen_time = seconds(time_now() - tic);
std::cout << std::endl;
std::cout << std::setprecision(5) << "Prompt toks/sec "
<< prompt_tokens.size() / prompt_time << "\nGeneration toks/sec "
<< (n + 1) / gen_time << std::endl;
return 0;
}

llama.py:

import fire
import json
import glob
from huggingface_hub import snapshot_download
import mlx.core as mx
import mlx.nn as nn
from pathlib import Path
import time
from transformers import AutoTokenizer
from types import SimpleNamespace
class DynamicNTKScalingRoPE(nn.Module):
def __init__(
self,
dims,
rope_scaling,
max_position_embeddings=2048,
base=10000,
):
super().__init__()
self.dims = dims
self.max_position_embeddings = max_position_embeddings
factor = rope_scaling["factor"]
low_freq_factor = rope_scaling["low_freq_factor"]
high_freq_factor = rope_scaling["high_freq_factor"]
old_context_len = rope_scaling["original_max_position_embeddings"]
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
freqs = base ** (mx.arange(0, self.dims, 2) / self.dims)
wavelens = 2 * mx.pi * freqs
freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs)
is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen)
smooth_factors = (old_context_len / wavelens - low_freq_factor) / (
high_freq_factor - low_freq_factor
)
smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)
self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs)
def __call__(self, x, offset=0):
return mx.fast.rope(
x,
self.dims,
traditional=False,
base=None,
scale=1.0,
offset=offset,
freqs=self._freqs,
)
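# Illustrative only (not called anywhere): a quick shape check of the RoPE module
# above. The rope_scaling values are in the style of a Llama 3.1 config.json; the
# real ones come from the downloaded model config.
def _rope_scaling_demo():
    rope = DynamicNTKScalingRoPE(
        dims=128,
        rope_scaling={
            "factor": 8.0,
            "low_freq_factor": 1.0,
            "high_freq_factor": 4.0,
            "original_max_position_embeddings": 8192,
        },
        max_position_embeddings=131072,
        base=500000.0,
    )
    x = mx.zeros((1, 8, 16, 128))  # (batch, heads, sequence, head_dim)
    return rope(x, offset=0)  # rotated queries/keys, same shape as x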
class Attention(nn.Module):
def __init__(self, args):
super().__init__()
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
head_dim = args.hidden_size // n_heads
self.scale = head_dim ** (-0.5)
self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
self.rope = DynamicNTKScalingRoPE(
dims=head_dim,
rope_scaling=args.rope_scaling,
max_position_embeddings=args.max_position_embeddings,
base=args.rope_theta,
)
def __call__(self, x, mask=None, cache=None, offset=None):
B, L, _ = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
queries = mx.unflatten(queries, -1, (self.n_heads, -1)).transpose(0, 2, 1, 3)
keys = mx.unflatten(keys, -1, (self.n_kv_heads, -1)).transpose(0, 2, 1, 3)
values = mx.unflatten(values, -1, (self.n_kv_heads, -1)).transpose(0, 2, 1, 3)
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=offset)
keys = self.rope(keys, offset=offset)
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, mask=mask, scale=self.scale
)
output = output.transpose(0, 2, 1, 3).flatten(-2, -1)
return self.o_proj(output), (keys, values)
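# Note: the per-layer cache returned above is a (keys, values) pair, each with
# shape (batch, n_kv_heads, total_sequence_length, head_dim); the sequence axis
# grows by one on every decoding step via the concatenation above.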
class MLP(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
def __call__(self, x):
return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, args):
super().__init__()
self.self_attn = Attention(args)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
def __call__(self, x, mask=None, cache=None, offset=None):
r, cache = self.self_attn(self.input_layernorm(x), mask, cache, offset)
h = x + r
out = h + self.mlp(self.post_attention_layernorm(h))
return out, cache
class LlamaModel(nn.Module):
def __init__(self, args):
super().__init__()
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(self, inputs, mask=None, cache=None, offset=None):
h = self.embed_tokens(inputs)
if not cache:
cache = [None] * len(self.layers)
offset = mx.array(0, mx.uint32)
if mask is not None:
mask = mask.astype(dtype=h.dtype)
for e, l in enumerate(self.layers):
h, cache[e] = l(h, mask, cache=cache[e], offset=offset)
return self.norm(h), cache
class Model(nn.Module):
def __init__(self, args):
super().__init__()
self.args = args
self.model = LlamaModel(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(self, inputs, mask=None, cache=None, offset=None):
out, *state = self.model(inputs, mask, cache, offset)
return self.lm_head(out), *state
def load(hf_repo):
model_path = Path(
snapshot_download(
repo_id=hf_repo,
allow_patterns=["*.json", "*.safetensors"],
)
)
with open(model_path / "config.json", "r") as f:
config = json.load(f)
weight_files = glob.glob(str(model_path / "model*.safetensors"))
weights = {}
for wf in weight_files:
weights.update(mx.load(wf))
model = Model(SimpleNamespace(**config))
if (quantization := config.get("quantization", None)) is not None:
nn.quantize(model, **quantization)
model.load_weights(list(weights.items()))
mx.eval(model)
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.decode([0])
return model, tokenizer
def generate_step(prompt, model):
def _step(y, *state, mask=None):
if state:
cache, offset = state
else:
cache = offset = None
logits, *state = model(y, cache=cache, offset=offset, mask=mask)
return mx.argmax(logits[:, -1], axis=-1), *state
mask = nn.MultiHeadAttention.create_additive_causal_mask(prompt.size)
y, *state = _step(prompt, mask=mask)
mx.async_eval(y)
offset = mx.array(prompt.size, mx.uint32)
while True:
next_y, *state = _step(y[None], *state, offset)
offset += 1
mx.async_eval(next_y)
yield y.item()
y = next_y
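# Note: the async_eval calls above pipeline the work: the next token's graph is
# dispatched before the previous token is read back with .item(), so computation
# overlaps with detokenization and printing in the caller.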
def export(
model="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
path="llama3.1-instruct-4bit.mlxfn",
):
model, _ = load(model)
def step(y, *state, mask=None):
# Unflatten the cache if provided
if len(state) > 0:
cache, offset = state[:-1], state[-1]
cache = list(zip(cache[::2], cache[1::2]))
else:
cache = offset = None
logits, cache = model(y, cache=cache, offset=offset, mask=mask)
flat_cache = [y for x in cache for y in x]
return logits, *flat_cache
# Make example inputs
y_prompt = mx.array([[0, 0]], mx.uint32)
y_gen = mx.array([[0]], mx.uint32)
mask = nn.MultiHeadAttention.create_additive_causal_mask(y_prompt.size)
offset = mx.array([0], mx.uint32)
_, *state = step(y_prompt, mask=mask)
with mx.exporter(path, step, shapeless=True) as exporter:
exporter(y_prompt, mask=mask)
exporter(y_gen, *state, offset)
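# The exporter records two call signatures into the same .mlxfn file:
#   1. step(prompt_tokens, mask=...)              -> logits, *flattened_cache
#   2. step(next_token, *flattened_cache, offset) -> logits, *flattened_cache
# With shapeless=True the traced graphs accept any prompt length and any cache
# length, which is what the C++ driver relies on when it threads the cache
# through successive calls.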
def generate(
prompt,
model="mlx-community/Meta-Llama-3.1-8B-Instruct-4bit",
max_tokens=128,
):
print("[INFO] Loading model from disk.")
model, tokenizer = load(model)
prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
add_generation_prompt=True,
return_tensors="mlx",
)
print("[INFO] Starting generation...")
tic = time.time()
tokens = []
for token, n in zip(generate_step(prompt, model), range(max_tokens)):
if n == 0:
prompt_tps = prompt.size / (time.time() - tic)
tic = time.time()
tokens.append(token)
if token == tokenizer.eos_token_id:
break
text = tokenizer.decode(tokens)
if text.endswith("\ufffd") or (len(text) == 1 and text[0] == ' '):
continue
else:
tokens = []
print(text, end="", flush=True)
print(tokenizer.decode(tokens), flush=True)
gen_tps = (n + 1) / (time.time() - tic)
print("=" * 10)
print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
print(f"Generation: {gen_tps:.3f} tokens-per-sec")
if __name__ == "__main__":
fire.Fire(
{
"generate": generate,
"export": export,
}
)

test.cpp:

// Copyright © 2024 Apple Inc.
#include <iostream>
#include "tokenizer.h"
template <typename T, typename U = T>
void check(const T& x, const U& y) {
if (x != y) {
std::cerr << "Mismatch" << std::endl;
}
}
void test_tokenizer(const std::string& path) {
BPETokenizer tokenizer(path);
check(tokenizer.encode("hello world!"), {128000, 15339, 1917, 0});
check(tokenizer.decode({15339}), "hello");
check(tokenizer.decode({0}), "!");
check(tokenizer.decode({1917}), " world");
}
int main(int argc, char *argv[]) {
test_tokenizer(".");
}

tokenizer.h:

// Copyright © 2024 Apple Inc.
#pragma once
#include <codecvt>
#include <cstdint>
#include <filesystem>
#include <fstream>
#include <locale>
#include <string>
#include <unordered_map>
#include <vector>
#include <json.hpp>
/** BPE Tokenizer API */
class BPETokenizer {
public:
BPETokenizer(const std::string& path);
/** Encode a string of text to token integer ids. */
std::vector<int> encode(std::string text) const;
/** Try to decode the vector of ids to text. The text is truncated to
* include only the fully decodable tokens. */
std::string decode(const std::vector<int>& ids) const;
/** Try to decode the vector of ids to text. The second return value
* indicates if the decoding completed. The text is truncated to include
* only the fully decodable tokens. */
std::pair<std::string, bool> try_decode(const std::vector<int>& ids) const;
int eos_token_id() const;
private:
std::unordered_map<std::string, int> token_to_id_;
std::vector<std::string> id_to_token_;
std::unordered_map<std::string, int> merges_;
int bos_id_;
int eos_id_;
static std::unordered_map<uint16_t, char> byte_decoder_;
std::string id_to_bytes(int id) const;
};
using json = nlohmann::json;
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
std::pair<std::wstring, int> utf8_to_utf16(const std::string& s) {
static std::string replace_str = std::string(1, 0xFF);
static std::wstring replace_wstr = std::wstring(1, 0xFFFD);
std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> cvt(replace_str, replace_wstr);
auto out = cvt.from_bytes(s);
return {out, cvt.converted()};
}
#pragma GCC diagnostic pop
auto make_byte_decoder() {
std::unordered_map<uint16_t, char> byte_decoder;
std::vector<uint16_t> limits = {0, '!', '~' + 1, L'¡', L'¬' + 1, L'®', L'ÿ' + 1};
char n = 0;
for (int i = 0; i < limits.size() - 1; ++i) {
auto start = limits[i];
auto stop = limits[i + 1];
if (i % 2 == 0) {
for (int b = start; b < stop; ++b) {
byte_decoder[256 + n++] = b;
}
} else {
for (int b = start; b < stop; ++b) {
byte_decoder[b] = b;
}
}
}
return byte_decoder;
}
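// The map above inverts the GPT-2 style byte-to-unicode mapping used when the
// vocab was built: printable byte ranges map to themselves, and every other
// byte was remapped to a code point starting at 256.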
auto BPETokenizer::byte_decoder_ = make_byte_decoder();
BPETokenizer::BPETokenizer(const std::string& path_) {
auto path = std::filesystem::path(path_);
std::ifstream ifs(path / "tokenizer.json");
auto tokenizer = json::parse(ifs);
auto model = tokenizer["model"];
token_to_id_ = model["vocab"];
id_to_token_.resize(token_to_id_.size());
for (auto& [s, id] : token_to_id_) {
if (id >= id_to_token_.size()) {
id_to_token_.resize(id + 1);
}
id_to_token_[id] = s;
}
std::string type = model["type"];
std::vector<std::string> merge_vec = model["merges"];
for (auto& s : merge_vec) {
merges_.emplace(std::move(s), merges_.size());
}
auto added_tokens = tokenizer["added_tokens"];
for (auto& added_token : added_tokens) {
int id = added_token["id"];
if (id >= id_to_token_.size()) {
id_to_token_.resize(id + 1);
}
id_to_token_[id] = added_token["content"];
if (id_to_token_[id] == "<|begin_of_text|>") {
bos_id_ = id;
} else if (id_to_token_[id] == "<|eot_id|>") {
eos_id_ = id;
}
}
}
std::vector<int> BPETokenizer::encode(std::string text) const {
std::vector<std::string> tokens;
tokens.reserve(text.size());
for (auto c : text) {
tokens.push_back(c == ' ' ? "Ġ" : std::string(1, c));
}
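// Greedy BPE: repeatedly apply the lowest-rank (highest-priority) merge found
// anywhere in the token sequence until no adjacent pair is in the merge table.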
auto one_step_merge = [&tokens, this]() {
std::string merge_l;
std::string merge_r;
int rank = INT32_MAX;
for (int i = 0; i + 1 < static_cast<int>(tokens.size()); ++i) {
std::string candidate = tokens[i];
candidate += " ";
candidate += tokens[i + 1];
if (auto it = merges_.find(candidate); it != merges_.end()) {
if (it->second < rank) {
merge_l = tokens[i];
merge_r = tokens[i + 1];
rank = it->second;
}
}
}
if (rank == INT32_MAX) {
return false;
}
for (int i = tokens.size() - 2; i >= 0; --i) {
if (tokens[i] == merge_l && tokens[i+1] == merge_r) {
tokens.erase(tokens.begin() + i + 1);
tokens[i] = merge_l + merge_r;
i -= 1;
}
}
return true;
};
while (one_step_merge()) { };
std::vector<int> ids;
ids.reserve(tokens.size() + 1);
ids.push_back(bos_id_);
for (auto& s : tokens) {
if (auto it = token_to_id_.find(s); it != token_to_id_.end()) {
ids.push_back(it->second);
} else {
throw std::runtime_error("UNK ENCOUNTERED");
}
}
return ids;
}
std::string BPETokenizer::id_to_bytes(int id) const {
std::string token;
auto [wide_token, _] = utf8_to_utf16(id_to_token_[id]);
token.resize(wide_token.size());
for (int i = 0; i < wide_token.size(); ++i) {
token[i] = byte_decoder_[wide_token[i]];
}
return token;
}
std::pair<std::string, bool> BPETokenizer::try_decode(const std::vector<int>& ids) const {
std::string text;
for (auto id : ids) {
text += id_to_bytes(id);
}
auto [_, converted] = utf8_to_utf16(text);
bool complete = converted == text.size();
text.resize(converted);
return {text, complete};
}
std::string BPETokenizer::decode(const std::vector<int>& ids) const {
return try_decode(ids).first;
}
int BPETokenizer::eos_token_id() const { return eos_id_; }