Script to benchmark the latency and memory consumption of different cache implementations
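Example invocation (a sketch, assuming a CUDA GPU and the pinned transformers branch referenced at the top of the script; the flags are the ones defined in the script's argument parser, and benchmark_cache.py is only a placeholder name since the gist's file name is not shown here):

    python benchmark_cache.py \
        --cache_implementation quantized \
        --nbits 4 \
        --model_name_or_path meta-llama/Llama-2-7b-chat-hf \
        --dtype fp16 \
        --feature batch_size \
        --output_path ./output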
# tested on https://github.com/zucchini-nlp/transformers/tree/quant (commit_id 5f3046a)
import os
import argparse
from pathlib import Path
from time import perf_counter

import numpy as np
from matplotlib import pyplot as plt

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache, QuantCache

os.environ["TOKENIZERS_PARALLELISM"] = "0"


class TorchTracemalloc():
    """Context manager that records the peak GPU memory allocated inside its block (in MiB)."""

    track_memory_consumption = []

    def __enter__(self):
        self.begin = torch.cuda.memory_allocated()
        torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
        return self

    def __exit__(self, *exc):
        peak = torch.cuda.max_memory_allocated()
        peaked = (peak - self.begin) // 1024 ** 2  # MiB allocated on top of what was already there
        TorchTracemalloc.track_memory_consumption.append(peaked)
        # print(f"peak: {peaked}; reserved: {torch.cuda.max_memory_reserved() // 1024 ** 2}")


@torch.no_grad()
def prefill(model, inputs, cache_implementation, nbits=4):
    """Run the pre-fill forward pass and prepare inputs/kwargs for the decoding stage."""
    if cache_implementation == "quantized":
        past_key_values = QuantCache(nbits=nbits)
    else:
        past_key_values = DynamicCache()

    input_length = inputs["input_ids"].shape[1]
    inputs["cache_position"] = torch.arange(input_length, device=inputs["input_ids"].device)
    outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)

    # Greedily pick the first generated token and append it to the prompt.
    next_token_logits = outputs.logits[:, -1, :]
    next_tokens = torch.argmax(next_token_logits, dim=-1)
    next_input_ids = torch.cat([inputs["input_ids"], next_tokens[:, None]], dim=-1)

    next_model_kwargs = model._update_model_kwargs_for_generation(
        outputs,
        inputs,
        is_encoder_decoder=False,
    )
    return next_input_ids, next_model_kwargs


def save_bar_chart(title, x, y, xlabel, ylabel, output_path):
    width = 0.4
    xs = np.arange(len(x))
    plt.figure()  # start a fresh figure so consecutive charts do not pile up on the same axes
    plt.bar(xs, height=y, width=width)
    plt.title(title)
    plt.xticks(xs, x)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.savefig(output_path)


def eval_generated_lengths(model, tokenizer, dataset, cache_implementation, nbits, feature, plot_title, output_path):
    # warm up
    generate_kwargs = {"do_sample": False, "temperature": 1.0, "top_p": 1.0}
    for _ in range(3):
        inputs_warmup = tokenizer(["Today a dragon flew over Paris"] * 2, return_tensors="pt").to(model.device)
        model.generate(**inputs_warmup, max_new_tokens=20, **generate_kwargs)

    memory_avg, tokens_per_sec_avg = [], []
    time_to_first_token_avg = []
    TTFT, TIME_PER_DECODING = [], []

    # set default values, only one of them will be changing
    parameters = {"max_new_tokens": 500, "batch_size": 1, "input_length": 100}
    num_batches = 2  # NOTE: only 200 samples total in the dataset

    if feature == "batch_size":
        x_iterable = [1, 20, 50, 100, 200]
    else:
        x_iterable = [500, 1000, 4000, 10_000]

    for item in x_iterable:
        parameters[feature] = item
        generate_kwargs_curr = generate_kwargs.copy()
        generate_kwargs_curr["min_new_tokens"] = parameters["max_new_tokens"]
        generate_kwargs_curr["max_new_tokens"] = parameters["max_new_tokens"]
        batch_size = parameters["batch_size"]

        with TorchTracemalloc() as tt:
            for batch in range(num_batches):
                torch.cuda.synchronize()  # make sure previously queued work has finished
                start = perf_counter()

                # chunk this way since we do not have many data samples
                curr_chunk = dataset[batch: batch + batch_size]
                inputs = tokenizer(
                    curr_chunk["prompt"],
                    padding="max_length",
                    max_length=parameters["input_length"],
                    truncation=True,
                    return_tensors="pt",
                ).to(model.device)

                # pre-fill stage
                next_input_ids, next_model_kwargs = prefill(model, inputs, cache_implementation, nbits)
                TTFT.append(perf_counter() - start)
                next_model_kwargs.pop("input_ids")
                torch.cuda.synchronize()

                # decoding stage
                out = model.generate(
                    next_input_ids,
                    **next_model_kwargs,
                    **generate_kwargs_curr,
                )
                TIME_PER_DECODING.append(
                    (perf_counter() - start - TTFT[-1]) / batch_size / parameters["max_new_tokens"]
                )
                del out
                torch.cuda.empty_cache()
                torch.cuda.synchronize()

        memory_avg.append(TorchTracemalloc.track_memory_consumption[-1])
        tokens_per_sec_avg.append(1 / (sum(TIME_PER_DECODING) / len(TIME_PER_DECODING)))
        time_to_first_token_avg.append(sum(TTFT) / len(TTFT))

    save_bar_chart(
        title=plot_title,
        x=x_iterable,
        y=memory_avg,
        xlabel=feature,
        ylabel="GPU memory consumption in MiB",
        output_path=f"{output_path}/memory.png",
    )
    save_bar_chart(
        title=plot_title,
        x=x_iterable,
        y=tokens_per_sec_avg,
        xlabel=feature,
        ylabel="Tokens per second",
        output_path=f"{output_path}/latency.png",
    )

    print(f"Tokens per sec (avg) - one per condition: {tokens_per_sec_avg}")
    print(f"Time to first token (avg) - one per condition: {time_to_first_token_avg}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--cache_implementation", type=str, default="quantized")
    parser.add_argument("--nbits", type=int, default=4)
    parser.add_argument("--model_name_or_path", type=str, default="meta-llama/Llama-2-7b-chat-hf")
    parser.add_argument("--trust_remote_code", action="store_true")
    parser.add_argument("--attn_implementation", type=str, default="sdpa")
    parser.add_argument("--dtype", type=str, default="fp16")
    parser.add_argument("--num_samples", type=int, default=5)
    parser.add_argument("--feature", type=str, default="batch_size", choices=["batch_size", "input_length", "max_new_tokens"])
    parser.add_argument("--output_path", type=str, default="./output")
    parser.add_argument("--plot_title", type=str, default="Quantized cache in int4")
    args = parser.parse_args()

    if args.dtype == "fp16":
        dtype = torch.float16
    elif args.dtype == "fp32":
        dtype = torch.float32
    elif args.dtype == "bf16":
        dtype = torch.bfloat16
    else:
        raise ValueError(f"Unknown dtype: {args.dtype}")

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=bool(args.trust_remote_code),
        attn_implementation=args.attn_implementation,
        torch_dtype=dtype,
    ).to("cuda:0")

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=bool(args.trust_remote_code), padding_side="left")
    tokenizer.pad_token_id = tokenizer.eos_token_id

    def collate_fn(example):
        prompt = f"Question: {example['input']}\nContext: {example['context']}\nAnswer:"
        example["prompt"] = prompt
        return example

    dataset = load_dataset("THUDM/LongBench", "samsum", split="test")
    dataset = dataset.map(collate_fn, batched=False)

    eval_generated_lengths(
        model,
        tokenizer,
        dataset,
        cache_implementation=args.cache_implementation,
        nbits=args.nbits,
        feature=args.feature,
        plot_title=args.plot_title,
        output_path=args.output_path,
    )


if __name__ == "__main__":
    main()
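Note that the script above targets a pre-release branch where the quantized cache was exposed as QuantCache. In released transformers versions the same functionality is reachable directly through generate(); a minimal sketch of that path (argument names as I understand the released API, so verify against the docs of your installed version):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16
    ).to("cuda:0")

    inputs = tokenizer("Today a dragon flew over Paris", return_tensors="pt").to(model.device)
    # Ask generate() for an int4-quantized KV cache instead of building the cache object by hand.
    out = model.generate(
        **inputs,
        do_sample=False,
        max_new_tokens=20,
        cache_implementation="quantized",
        cache_config={"backend": "quanto", "nbits": 4},
    )
    print(tokenizer.decode(out[0], skip_special_tokens=True))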
Holy....
Sorry, I forgot to add "cuda" in the last script. Can you move the dummy tensor to CUDA first, before doing all the quant part? On an A100 machine, if you can, please.
And if it doesn't make things faster, feel free to open an issue in transformers so we can track it there.
Like this? dummy_tensor_inputs = torch.randn(1, 32, 10_000, 128).to("cuda")
I tested it and it was still 80 seconds on an A100 in Colab.
Yes, and could you open an issue? I'll tag the people involved there to help you out.
okay, I made an issue here: huggingface/transformers#34096
I am getting 164 sec 🙃 on free Colab with a T4, but a local machine with an A100-80GB runs in 14 sec.
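For anyone reproducing the dummy-tensor timing discussed above, a minimal sketch of how one might time a GPU-side operation with proper synchronization; quantize_fn is a hypothetical placeholder for whatever quantization call is being benchmarked, not a function from the script:

    from time import perf_counter

    import torch

    def time_on_gpu(fn, *args, warmup=3, iters=10):
        # Warm-up runs so one-off compilation/allocator effects do not skew the measurement.
        for _ in range(warmup):
            fn(*args)
        torch.cuda.synchronize()
        start = perf_counter()
        for _ in range(iters):
            fn(*args)
        torch.cuda.synchronize()  # wait for all queued kernels before stopping the clock
        return (perf_counter() - start) / iters

    # Same shape as the dummy tensor in the comment above, already moved to the GPU.
    dummy_tensor_inputs = torch.randn(1, 32, 10_000, 128).to("cuda")
    # elapsed = time_on_gpu(quantize_fn, dummy_tensor_inputs)  # quantize_fn: hypothetical placeholder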