#!/usr/bin/env python
"""
Calculate the KL-divergence of two models' output logits on a data set.
First call the program with write_path and text_path using the fp16 model:
./llama_kl.py -m <fp16 model> -t <wiki.test.raw> -w <logits.gz>
This writes the logits to a file. Then call the program with the quantized model and the read path:
./llama_kl.py -m <quantized model> -r <logits.gz>
The KL-divergence against the first run is calculated.
See ./llama_kl.py --help for more options.
"""
import llama_cpp
import numpy as np
import sys
import argparse
import os.path
import struct
import ast
from scipy.special import rel_entr, softmax
import gzip
import pickle
from scipy.stats.mstats import mquantiles_cimj
from scipy.stats import bayes_mvs
from scipy.stats import t as student_t
import random
import time
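# KL divergence of the softmax distributions of two logit vectors:
# KL(p || q) = sum_i p_i * log(p_i / q_i), computed elementwise with scipy's
# rel_entr and then summed.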
def kl_div(p, q):
    p = softmax(p)
    q = softmax(q)
    return np.sum(rel_entr(p, q))
def write_header(f, args, ctx, vocab_len, batch):
    f.write("llama_kl_divergence_v1\n".encode('utf-8'))
    d = vars(args)
    d["n_ctx"] = ctx
    d["n_vocab"] = vocab_len
    d["n_batch"] = batch
    f.write((str(d) + "\n").encode('utf-8'))

def read_header(f):
    header = "llama_kl_divergence_v1\n".encode('utf-8')
    if f.read(len(header)) != header:
        raise ValueError("Invalid header in input logit file")
    args = ast.literal_eval(f.readline().decode('utf-8').strip())
    return args
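# Binary record layout per chunk: three little-endian uint32 values
# (n_tokens, n_logits, n_vocab), followed by the token ids as uint32 and
# the logits as float32, all written through the gzip stream.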
def write_logits(f, tokens, logits):
    f.write(struct.pack("<I", len(tokens)))
    f.write(struct.pack("<I", len(logits)))
    f.write(struct.pack("<I", len(logits[0])))
    t = np.array(tokens, dtype=np.uint32).tobytes()
    assert len(t) == 4 * len(tokens)
    f.write(t)
    l = np.array(logits, dtype=np.float32).tobytes()
    assert len(l) == 4 * len(logits) * len(logits[0])
    f.write(l)

def read_logits(f):
    n_tokens = f.read(4)
    if len(n_tokens) != 4:
        # EOF
        return None, None
    n_tokens = struct.unpack("<I", n_tokens)[0]
    n_logits = struct.unpack("<I", f.read(4))[0]
    n_vocab = struct.unpack("<I", f.read(4))[0]
    tokens = [int(i) for i in np.frombuffer(f.read(n_tokens * 4), dtype=np.uint32)]
    logits = np.frombuffer(f.read(n_logits * n_vocab * 4), dtype=np.float32).reshape(n_logits, n_vocab)
    return tokens, logits
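# Main driver: either computes logits for context-sized chunks of a text file
# and writes them out (write mode), or reads previously stored logits back and
# compares the current model's logits against them (compare mode).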
def main(args):
    ctx = args.n_ctx
    read_file = None
    if args.read_path is not None:
        print(f"Computing KL-divergence against: {args.read_path}")
        read_file = gzip.open(args.read_path, "rb")
        input_args = read_header(read_file)
        ctx = input_args["n_ctx"]
    model = llama_cpp.Llama(model_path=args.model, n_ctx=ctx, n_batch=args.n_batch,
                            logits_all=True, n_gpu_layers=args.n_gpu_layers, verbose=args.verbose)
    model_name = os.path.split(args.model)[1]
    tokens = None
    if args.text_path and args.read_path is None:
        with open(args.text_path, "r") as f:
            prompt = f.read()
        print(f"Computing logits from text file: {args.text_path}")
        tokens = model.tokenize(prompt.encode('utf-8'))
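        # Split the token stream into context-sized chunks, reserving one
        # position for the BOS token when the model defines one.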
        bos = model.token_bos()
        b = 1 if bos is not None else 0
        tokens = [tokens[i:i+ctx-b] for i in range(0, len(tokens), ctx-b)]
        random.seed(123)
        if bos is not None:
            for i in range(len(tokens)):
                tokens[i].insert(0, bos)
        # Improves error estimation during the calculation, since correlation with the
        # previous context is reduced compared to the unshuffled order. Doesn't affect
        # the final result.
        random.shuffle(tokens)
    write_file = None
    if args.write_path is not None:
        write_file = gzip.open(args.write_path, "wb")
        write_header(write_file, args, model.n_ctx(), model.n_vocab(), model.n_batch)
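    # Yields (reference_logits, token_chunk) pairs: reference logits come from
    # the stored file in compare mode, or are None when only writing.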
    def next_sample():
        if read_file is not None:
            while True:
                try:
                    t, logits = read_logits(read_file)
                except EOFError:
                    print("EOF at unexpected location")
                    return
                if t is None:
                    return
                yield logits, t
        elif tokens is not None:
            for t in tokens:
                yield None, t
    # Confidence interval bound
    alpha = 0.01
    kls = []
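    # Agreement counters: how often the reference model's top token falls in the
    # evaluated model's top-1/5/10, and vice versa.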
    top1 = 0
    top5 = 0
    top10 = 0
    eval_top5 = 0
    eval_top10 = 0
    samples = 0
    written = 0
    written_tokens = 0
    i = 0
    errors = 0
    max_tokens = args.n_tokens
    if max_tokens < 0:
        max_tokens = float('inf')
    try:
        for logits, chunk in next_sample():
            #print(model.detokenize(chunk))
            model.reset()
            output = model.eval(chunk)
            eval_logits = model.eval_logits
            if np.any(np.isnan(eval_logits)):
                errors += 1
                print("Nan in logits!")
                eval_logits = np.nan_to_num(eval_logits)
            if write_file:
                write_logits(write_file, model.eval_tokens, eval_logits)
                written_tokens += len(model.eval_tokens)
                written += 1
                print(f"[{written}/{len(tokens)}] tokens {written_tokens}")
            if logits is not None:
                # It would probably be better to throw away at least the first two tokens
                # in the context window since those are always the same. Unlike in a
                # perplexity calculation it doesn't matter much here, though, since we
                # are comparing against a reference.
                # This is really slow.
                new_kls = [kl_div(eval_logits[i], logits[i]) for i in range(len(logits))]
                if np.any(np.isnan(new_kls)):
                    errors += 1
                    print("Nan in computed kls!")
                    new_kls = np.nan_to_num(new_kls)
                kls.extend(new_kls)
                samples += len(logits)
                # This is even slower.
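                # np.argpartition(x, -k, axis=-1)[..., -k:] gives the indices of the
                # k largest logits per position (in arbitrary order).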
                eval_argmax = np.argmax(eval_logits, axis=-1)
                ref_argmax = np.argmax(logits, axis=-1)
                eval_part5 = np.argpartition(eval_logits, -5, axis=-1)[:,-5:]
                ref_part5 = np.argpartition(logits, -5, axis=-1)[:,-5:]
                eval_part10 = np.argpartition(eval_logits, -10, axis=-1)[:,-10:]
                ref_part10 = np.argpartition(logits, -10, axis=-1)[:,-10:]
                top1 += sum([eval_argmax[i] == ref_argmax[i] for i in range(len(logits))])
                top5 += sum([ref_argmax[i] in eval_part5[i] for i in range(len(logits))])
                top10 += sum([ref_argmax[i] in eval_part10[i] for i in range(len(logits))])
                eval_top5 += sum([eval_argmax[i] in ref_part5[i] for i in range(len(logits))])
                eval_top10 += sum([eval_argmax[i] in ref_part10[i] for i in range(len(logits))])
                print(f"[{i}] kl {np.mean(kls):.4g}, top1 {top1 / samples:.4g}", flush=True)
            i += 1
            if samples >= max_tokens:
                print("Token limit reached")
                break
    except KeyboardInterrupt:
        print("Interrupted")
    if write_file:
        write_file.close()
        print(f"Finished writing file: {args.write_path}")
    if read_file:
        read_file.close()
        print(f"Finished reading file: {args.read_path}")
    def bin_conf(p, n, z):
        # Binomial distribution confidence bounds
        # Bayes estimator when p is degenerate
        if p == 0:
            p = 1 / (n + 2)
        if p == 1:
            p = 1 - 1 / (n + 2)
        return z * np.sqrt(p*(1-p)/n)
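    # Summary statistics over all evaluated token positions (only populated in
    # compare mode).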
    if len(kls) > 0:
        z = student_t.ppf(1 - alpha/2, samples)
        print()
        print("Model:", model_name)
        bpw = 8 * llama_cpp.llama_model_size(model.model) / llama_cpp.llama_model_n_params(model.model)
        print(f"Size: {llama_cpp.llama_model_size(model.model) / 1024**3:.3g} GiB, (BPW {bpw:.2f})")
        print("Tokens:", samples)
        print("KL-divergence:")
        # Confidence interval assuming i.i.d., but that likely isn't true.
        m_conf = z*np.sqrt(np.mean([k**2 for k in kls])/len(kls))
        m, _, __ = bayes_mvs(kls, 1-alpha)
        print(f"mean: {m[0]:.6g}, [{m[1][0]:.6g} - {m[1][1]:.6g}]")
        q90 = np.quantile(kls, 0.90)
        q95 = np.quantile(kls, 0.95)
        q99 = np.quantile(kls, 0.99)
        q_bounds = mquantiles_cimj(kls, prob=[0.90, 0.95, 0.99])
        print(f"q90: {q90:.4g}, [{q_bounds[0][0]:.4g} - {q_bounds[1][0]:.4g}]")
        print(f"q95: {q95:.4g}, [{q_bounds[0][1]:.4g} - {q_bounds[1][1]:.4g}]")
        print(f"q99: {q99:.4g}, [{q_bounds[0][2]:.4g} - {q_bounds[1][2]:.4g}]")
        print(f"max: {np.max(kls):.4g}")
        print("Reference top token in eval top-n probability:")
        print(f"ref_top1: {top1 / samples:.4g} ± {bin_conf(top1/samples, samples, z):.4g}")
        print(f"ref_top5: {top5 / samples:.4g} ± {bin_conf(top5/samples, samples, z):.4g}")
        print(f"ref_top10: {top10 / samples:.4g} ± {bin_conf(top10/samples, samples, z):.4g}")
        print("Eval top token in reference top-n probability:")
        print(f"eval_top5: {eval_top5 / samples:.4g} ± {bin_conf(eval_top5/samples, samples, z):.4g}")
        print(f"eval_top10: {eval_top10 / samples:.4g} ± {bin_conf(eval_top10/samples, samples, z):.4g}")
        print(f"errors: {errors}")
        with open(model_name + ".kls.p", 'wb') as f:
            pickle.dump(kls, f)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog='llama.cpp KL-divergence',
        description="Calculate the KL-divergence of two models' output logits on a data set.\n"
                    "First call the program with write_path and text_path using the fp16 model.\n"
                    "This writes the logits to a file. Then call the program with the quantized model and the read path.\n"
                    "The KL-divergence against the first run is calculated.\n")
    parser.add_argument('-m', '--model', help="Model path", required=True)
    parser.add_argument('-t', '--text_path', help="Text dataset path", required=False)
    parser.add_argument('-c', '--n_ctx', help="Context size", default=512, type=int, required=False)
    parser.add_argument('-b', '--n_batch', help="Batch size", default=512, type=int, required=False)
    parser.add_argument('-w', '--write_path', help="Output logits file", required=False)
    parser.add_argument('-r', '--read_path', help="Input logits file", required=False)
    parser.add_argument('-n', '--n_tokens', help="Number of tokens to evaluate. (-1 = whole file)", default=-1, type=int, required=False)
    parser.add_argument('-ngl', '--n-gpu-layers', help="Number of GPU layers", default=0, type=int, required=False)
    parser.add_argument('-v', '--verbose', help="Verbose output", action="store_true")
    args = parser.parse_args()
    if args.read_path is None and args.text_path is None:
        print("Either a text dataset or an input logit file must be specified")
        sys.exit(1)
    if args.write_path is None and args.read_path is None:
        print("At least one of read_path or write_path needs to be specified")
        sys.exit(1)
    if args.write_path is not None and os.path.exists(args.write_path):
        print(f"write_path {args.write_path} already exists")
        sys.exit(1)
    main(args)