# tested on https://github.com/zucchini-nlp/transformers/tree/quant (commit_id 5f3046a)

import os
import argparse
from pathlib import Path
from time import perf_counter

import numpy as np
from matplotlib import pyplot as plt
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache, QuantCache

os.environ["TOKENIZERS_PARALLELISM"] = "0"


class TorchTracemalloc():
    track_memory_consumption = []

    def __enter__(self):
        self.begin = torch.cuda.memory_allocated()
        torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
        return self

    def __exit__(self, *exc):
        peak = torch.cuda.max_memory_allocated()
        peaked = (peak - self.begin) // 1024 ** 2
        TorchTracemalloc.track_memory_consumption.append(peaked)
        # print(f"peak: {peaked}; reserved: {torch.cuda.max_memory_reserved() // 1024 ** 2}")


@torch.no_grad()
def prefill(model, inputs, cache_implementation, nbits=4):
    if cache_implementation == "quantized":
        past_key_values = QuantCache(nbits=nbits)
    else:
        past_key_values = DynamicCache()

    input_length = inputs["input_ids"].shape[1]
    inputs["cache_position"] = torch.arange(input_length, device=inputs["input_ids"].device)
    outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)

    next_token_logits = outputs.logits[:, -1, :]
    next_tokens = torch.argmax(next_token_logits, dim=-1)
    next_input_ids = torch.cat([inputs["input_ids"], next_tokens[:, None]], dim=-1)
    next_model_kwargs = model._update_model_kwargs_for_generation(
        outputs,
        inputs,
        is_encoder_decoder=False,
    )
    return next_input_ids, next_model_kwargs


def save_bar_chart(title, x, y, xlabel, ylabel, output_path):
    width = 0.4
    xs = np.arange(len(x))
    plt.figure()  # start a fresh figure so consecutive charts do not overlap
    plt.bar(xs, height=y, width=width)
    plt.title(title)
    plt.xticks(xs, x)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.savefig(output_path)


def eval_generated_lengths(model, tokenizer, dataset, cache_implementation, nbits, feature, plot_title, output_path):
    # warm up
    generate_kwargs = {"do_sample": False, "temperature": 1.0, "top_p": 1.0}
    for _ in range(3):
        inputs_warmup = tokenizer(["Today a dragon flew over Paris"] * 2, return_tensors="pt").to(model.device)
        model.generate(**inputs_warmup, max_new_tokens=20, **generate_kwargs)

    memory_avg, tokens_per_sec_avg = [], []
    time_to_first_token_avg = []
    TTFT, TIME_PER_DECODING = [], []

    # set default values, only one of them will be changing
    parameters = {"max_new_tokens": 500, "batch_size": 1, "input_length": 100}
    num_batches = 2  # NOTE: 200 samples total only in dataset

    if feature == "batch_size":
        x_iterable = [1, 20, 50, 100, 200]
    else:
        x_iterable = [500, 1000, 4000, 10_000]

    for item in x_iterable:
        parameters[feature] = item
        generate_kwargs_curr = generate_kwargs.copy()
        generate_kwargs_curr["min_new_tokens"] = parameters["max_new_tokens"]
        generate_kwargs_curr["max_new_tokens"] = parameters["max_new_tokens"]
        batch_size = parameters["batch_size"]

        with TorchTracemalloc() as tt:
            for batch in range(num_batches):
                torch.cuda.synchronize()
                start = perf_counter()

                # chunk this way since we do not have many data samples
                curr_chunk = dataset[batch: batch + batch_size]
                inputs = tokenizer(
                    curr_chunk["prompt"],
                    padding="max_length",
                    max_length=parameters["input_length"],
                    truncation=True,
                    return_tensors="pt",
                ).to(model.device)

                # pre-fill stage
                next_input_ids, next_model_kwargs = prefill(model, inputs, cache_implementation, nbits)
                TTFT.append(perf_counter() - start)
                next_model_kwargs.pop("input_ids")

                torch.cuda.synchronize()
                # decoding stage
                out = model.generate(
                    next_input_ids,
                    **next_model_kwargs,
                    **generate_kwargs_curr,
                )
                TIME_PER_DECODING.append((perf_counter() - start - TTFT[-1]) / batch_size / parameters["max_new_tokens"])
                del out
                torch.cuda.empty_cache()

        torch.cuda.synchronize()
        memory_avg.append(TorchTracemalloc.track_memory_consumption[-1])
        # average over the batches of the current condition only
        tokens_per_sec_avg.append(1 / (sum(TIME_PER_DECODING[-num_batches:]) / num_batches))
        time_to_first_token_avg.append(sum(TTFT[-num_batches:]) / num_batches)

    save_bar_chart(
        title=plot_title,
        x=x_iterable,
        y=memory_avg,
        xlabel=feature,
        ylabel="GPU memory consumption in MiB",
        output_path=f"{output_path}/memory.png",
    )
    save_bar_chart(
        title=plot_title,
        x=x_iterable,
        y=tokens_per_sec_avg,
        xlabel=feature,
        ylabel="Tokens per second",
        output_path=f"{output_path}/latency.png",
    )

    print(f"Tokens per sec (avg) - one per condition: {tokens_per_sec_avg}")
    print(f"Time to first token (avg) - one per condition: {time_to_first_token_avg}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--cache_implementation", type=str, default="quantized")
    parser.add_argument("--nbits", type=int, default=4)
    parser.add_argument("--model_name_or_path", type=str, default="meta-llama/Llama-2-7b-chat-hf")
    parser.add_argument("--trust_remote_code", action="store_true")
    parser.add_argument("--attn_implementation", type=str, default="sdpa")
    parser.add_argument("--dtype", type=str, default="fp16")
    parser.add_argument("--num_samples", type=int, default=5)
    parser.add_argument("--feature", type=str, default="batch_size", choices=["batch_size", "input_length", "max_new_tokens"])
    parser.add_argument("--output_path", type=str, default="./output")
    parser.add_argument("--plot_title", type=str, default="Quantized cache in int4")
    args = parser.parse_args()

    if args.dtype == "fp16":
        dtype = torch.float16
    elif args.dtype == "fp32":
        dtype = torch.float32
    elif args.dtype == "bf16":
        dtype = torch.bfloat16
    else:
        raise ValueError(f"Unknown dtype: {args.dtype}")

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=bool(args.trust_remote_code),
        attn_implementation=args.attn_implementation,
        torch_dtype=dtype,
    ).to("cuda:0")

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=bool(args.trust_remote_code), padding_side="left")
    tokenizer.pad_token_id = tokenizer.eos_token_id

    def collate_fn(example):
        prompt = f"Question: {example['input']}\nContext: {example['context']}\nAnswer:"
        example["prompt"] = prompt
        return example

    dataset = load_dataset("THUDM/LongBench", "samsum", split="test")
    dataset = dataset.map(collate_fn, batched=False)

    eval_generated_lengths(
        model,
        tokenizer,
        dataset,
        cache_implementation=args.cache_implementation,
        nbits=args.nbits,
        feature=args.feature,
        plot_title=args.plot_title,
        output_path=args.output_path,
    )


if __name__ == "__main__":
    main()
@kazunator I meant the gist script, sorry. I can't change the colab.
@zucchini-nlp Thank you! The code runs now!
@zucchini-nlp Just wanted to ask: is it normal that it takes about 100 seconds to generate with quantized caching for a 50k context window? Normal generation takes about 10 to 15 seconds.
It is expected to be slower in general. I didn't try with a 50k context length, though with smaller lengths the tokens/second latency was only slightly worse with the quantized cache.
What batch size are you using, and did you enable FA2 for long-context generation? Did you try a smaller batch size or context length to compare the latency?
I tried with 20k, 50k, and 100k, and they were all around 80 to 100 seconds per generation. I used the simpler code from the colab to test it out; it's probably lacking.
Code:
tokenizer = AutoTokenizer.from_pretrained("aifeifei798/DarkIdol-Llama-3.1-8B-Instruct-1.2-Uncensored", padding_side="left")
model = AutoModelForCausalLM.from_pretrained("aifeifei798/DarkIdol-Llama-3.1-8B-Instruct-1.2-Uncensored", torch_dtype=torch.float16, attn_implementation="flash_attention_2", device_map="auto")
tokenizer.pad_token_id = tokenizer.eos_token_id
dataset = load_dataset('THUDM/LongBench', "samsum", split='test')
very_long_context = " ".join(dataset["context"])
inputs = tokenizer(very_long_context, max_length=100000, truncation="only_first", return_tensors="pt").to(model.device)
out = model.generate(
**inputs,
max_new_tokens=20,
cache_implementation="quantized",
cache_config={"backend": "quanto", "nbits": 4, "q_group_size": 64, "residual_length": 128}
)
generated_text = tokenizer.batch_decode(out)
I also found something very weird. When using your fork, normal generation has only a slightly worse memory footprint than the quantized version (topping out at 80k on a 40 GB A100, while the quantized version tops out at 100k). But when using the main branch, normal generation has a massive memory footprint, filling the entire 40 GB of VRAM with just 20k tokens. That quite confused me.
Hmm, it is interesting that the memory is almost the same. It should give a 2.5-3x smaller memory footprint in general. Maybe you can provide a reproducer for the normal cache that shows the memory footprint is higher in the last release? I'll try to dig into that, as it might be an issue we need to solve.
For the quant cache, I remember there being no memory saving beyond a certain token count, since the bottleneck in those cases was the matmul in the attention module. That is the reason I suggest FA2, but I'd also suggest measuring memory for the decoding stage only, to get an actual sense of how much is saved by the cache. Otherwise the overall consumption might not reflect the memory actually used by the cache.
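Something like this, for example (a rough, untested sketch that reuses prefill() and the model/tokenizer/very_long_context from the script and snippets above, so that only the decoding stage is counted against the peak gauge):

import torch

inputs = tokenizer(very_long_context, max_length=20000, truncation="only_first", return_tensors="pt").to(model.device)

# pre-fill stage: build the (quantized) cache with a single forward pass
next_input_ids, next_model_kwargs = prefill(model, inputs, cache_implementation="quantized", nbits=4)
next_model_kwargs.pop("input_ids")

# reset the peak gauge so that only decoding is measured
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
baseline = torch.cuda.memory_allocated()

out = model.generate(next_input_ids, **next_model_kwargs, max_new_tokens=20, do_sample=False)

torch.cuda.synchronize()
print(f"decode-stage peak: {(torch.cuda.max_memory_allocated() - baseline) // 1024 ** 2} MiB")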
Here is the colab to reproduce the normal cache with an abnormally high memory footprint on the last release: https://colab.research.google.com/drive/1tNDHC7-z2pOEQRLIDf4YfzY339HPl2Q3?usp=sharing
As you can see, it almost fills up the entire VRAM with just 20k context. For comparison, here is how the memory looks using the same context and the same code, but changing the library to your fork.
I'm also using FA2 like this in all my tests, whether on the fork or main, and with the quant or normal cache: model = AutoModelForCausalLM.from_pretrained("aifeifei798/DarkIdol-Llama-3.1-8B-Instruct-1.2-Uncensored", torch_dtype=torch.float16, attn_implementation="flash_attention_2", device_map="auto")
Also, here is a link to documentation of all the tests I've done on your fork and the main lib, using both quant and normal generation (on 20k, 50k, 80k, and 100k; normal generation on main transformers was only tested on 20k since the other context lengths result in OOM, and only the quant cache on the forked lib was able to reach 100k): https://docs.google.com/document/d/1TgSIhjwVg04V20mCLqZZcNNcOMhHuSP5loAavfBRXPI/edit?usp=sharing
Honestly, these tests left me thinking that there is some voodoo magic trick in your fork's normal generation, because when I tested vLLM without quant using the same model on 20k context, it gave me results similar to normal generation with the main transformers library. But this could be a massive coincidence of the two libraries having a bad PR merge lately (or maybe vLLM is just inefficient).
Thanks a lot for the detailed info! I'll take a few hours tomorrow to dig into this.
You're welcome!
I used the following code on my local A100 (I don't have colab GPUs) and got sensible memory usage. The normal cache uses almost twice as much as the quantized cache, but still doesn't require as much memory as you got. If we assume the model weights need 16 GiB, then the total memory needed is around 21 GiB for the normal cache.
That is exactly what you got on the fork branch, and I got it on the latest main branch (same as the release, as we didn't change anything afaik):
Normal cache: 4783 MiB VRAM used
Quant cache: 2987 MiB VRAM used
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("aifeifei798/DarkIdol-Llama-3.1-8B-Instruct-1.2-Uncensored", padding_side="left")
model = AutoModelForCausalLM.from_pretrained("aifeifei798/DarkIdol-Llama-3.1-8B-Instruct-1.2-Uncensored", torch_dtype=torch.float16, attn_implementation="flash_attention_2", device_map="auto")
tokenizer.pad_token_id = tokenizer.eos_token_id
dataset = load_dataset('THUDM/LongBench', "samsum", split='test')
very_long_context = " ".join(dataset["context"])
inputs = tokenizer(very_long_context, max_length=20000, truncation="only_first", return_tensors="pt").to(model.device)
generation_kwargs = {"do_sample": False, "temperature": 1.0, "top_p": 1.0, "max_new_tokens": 20, "min_new_tokens": 20, "cache_implementation": "quantized"}
begin_mem = torch.cuda.memory_allocated()
out_fp16 = model.generate(**inputs, **generation_kwargs)
generated_text = tokenizer.batch_decode(out_fp16)
end_mem = torch.cuda.max_memory_allocated()
print(f"{(end_mem - begin_mem) // 1024 ** 2} MiB VRAM used")
Hmm, I think I found the issue. When I just pip install the normal transformers package, the high memory footprint persists, but when I clone the main branch, I get similar results to yours. The transformers version where the issue exists is 4.44.2 (at least on colab). I do remember trying the main branch with a git clone yesterday and still getting the memory footprint issue, but maybe that's just my faulty memory. In any case, this new test was with the pip package, and it's still reproducible on my end.
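For what it's worth, this is how I'm double-checking which transformers build the colab runtime actually imports:

import transformers
print(transformers.__version__)  # 4.44.2 for the pip release vs 4.46.0.dev0 for a checkout of main
print(transformers.__file__)     # shows which installed copy is being imported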
Other than that, I also wanted to ask you about the time it takes to generate with quantized cache. Was I doing something egregiously wrong with my code? Or is that expected behavior?
This code:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import time
tokenizer = AutoTokenizer.from_pretrained("aifeifei798/DarkIdol-Llama-3.1-8B-Instruct-1.2-Uncensored", padding_side="left")
model = AutoModelForCausalLM.from_pretrained("aifeifei798/DarkIdol-Llama-3.1-8B-Instruct-1.2-Uncensored", torch_dtype=torch.float16, attn_implementation="flash_attention_2", device_map="auto")
tokenizer.pad_token_id = tokenizer.eos_token_id
dataset = load_dataset('THUDM/LongBench', "samsum", split='test')
very_long_context = " ".join(dataset["context"])
inputs = tokenizer(very_long_context, max_length=20000, truncation="only_first", return_tensors="pt").to(model.device)
generation_kwargs = {"do_sample": False, "temperature": 1.0, "top_p": 1.0, "max_new_tokens": 20, "min_new_tokens": 20, "cache_implementation": "quantized"}
begin_mem = torch.cuda.memory_allocated()
start_time = time.time()
out_fp16 = model.generate(**inputs, **generation_kwargs)
generated_text = tokenizer.batch_decode(out_fp16)
end_time = time.time()
time_taken = end_time - start_time
end_mem = torch.cuda.max_memory_allocated()
print(f"Time taken: {time_taken:.2f} seconds")
print(f"{(end_mem - begin_mem) // 1024 ** 2} MiB VRAM used")
results in a generation time of 93.68 seconds.
Normal generation, on the other hand, takes just 4.62 seconds.
> The transformers version where the issue exists is 4.44.2 (at least on colab).
Hmm, I remember we had some memory issues with the normal cache a while ago and it was fixed, though I don't remember the exact release. That might be the problem. I'll check it out, but as long as the latest branch is not leaking memory we should be okay :)
> Was I doing something egregiously wrong with my code? Or is that expected behavior?
Which branch was this latency test done on? On the main branch I am getting 9 sec and 16 sec for the two cache types, which is expected (a 2x latency difference at most).
Huh, is that with optimum-quanto? I just used normal quanto for my tests because when using optimum, I get this weird error for the same code: ValueError: shift must be specified for qtypes lower than 8-bit
Code:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import time
import psutil
import gc

def get_gpu_memory():
    return torch.cuda.memory_allocated() / 1024**2  # Convert to MB

def get_ram_usage():
    return psutil.Process().memory_info().rss / 1024**2  # Convert to MB
# Memory usage before generation
gpu_memory_before = get_gpu_memory()
ram_before = get_ram_usage()
tokenizer = AutoTokenizer.from_pretrained("aifeifei798/DarkIdol-Llama-3.1-8B-Instruct-1.2-Uncensored", padding_side="left")
model = AutoModelForCausalLM.from_pretrained("aifeifei798/DarkIdol-Llama-3.1-8B-Instruct-1.2-Uncensored", torch_dtype=torch.float16, attn_implementation="flash_attention_2", device_map="auto")
tokenizer.pad_token_id = tokenizer.eos_token_id
dataset = load_dataset('THUDM/LongBench', "samsum", split='test')
very_long_context = " ".join(dataset["context"])
inputs = tokenizer(very_long_context, max_length=20000, truncation="only_first", return_tensors="pt").to(model.device)
generation_kwargs = {"do_sample": False, "temperature": 1.0, "top_p": 1.0, "max_new_tokens": 20, "min_new_tokens": 20, "cache_implementation": "quantized"}
# Time the generation
start_time = time.time()
out_fp16 = model.generate(**inputs, **generation_kwargs)
generated_text = tokenizer.batch_decode(out_fp16)
end_time = time.time()
# Memory usage after generation
gpu_memory_after = get_gpu_memory()
ram_after = get_ram_usage()
# Calculate differences
gpu_memory_used = gpu_memory_after - gpu_memory_before
ram_used = ram_after - ram_before
time_taken = end_time - start_time
print(f"Generated text: {generated_text}")
print(f"Time taken: {time_taken:.2f} seconds")
print(f"GPU memory used: {gpu_memory_used:.2f} MB")
print(f"RAM used: {ram_used:.2f} MB")
The installation for optimum-quanto was done like this:
!pip install optimum-quanto
With normal quanto, this code takes 90s
I still can't get 90s for the quant cache. My env is as follows. For optimum-quanto you have to install the 2.4.0 version (not a higher one) and have transformers from main; I recently fixed a bug there (a minimal install sketch follows the env listing below). But I didn't notice much difference between quanto and optimum-quanto.
- quanto==0.2.0
- `transformers` version: 4.46.0.dev0
- Platform: Linux-5.4.0-166-generic-x86_64-with-glibc2.29
- Python version: 3.8.10
- Huggingface_hub version: 0.24.3
- Safetensors version: 0.4.3
- Accelerate version: 0.34.2
- Accelerate config: not found
- PyTorch version (GPU?): 2.4.1+cu121 (True)
- Tensorflow version (GPU?): 2.13.1 (False)
- Flax version (CPU?/GPU?/TPU?): 0.7.0 (cpu)
- Jax version: 0.4.13
- JaxLib version: 0.4.13
- GPU type: NVIDIA A100-SXM4-80GB
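Concretely, the installs would look something like this (notebook-style, matching the !pip commands used elsewhere in the thread; the optimum-quanto pin is copied verbatim from my note above, adjust it if your index names the release differently):

!pip install "optimum-quanto==2.4.0"   # the exact version mentioned above, not a newer one
!pip install git+https://github.com/huggingface/transformers   # transformers from main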
That's so weird. Here are my current settings, where I still get a 90s generation time:
- Quanto version: 0.2.0
- Platform: Linux-6.1.85+-x86_64-with-glibc2.35
- Python version: 3.10.12
- PyTorch version (GPU?): 2.4.1+cu121 (True)
- Tensorflow version (GPU?): 2.17.0 (True)
- Flax version (CPU?/GPU?/TPU?): 0.8.5 (cpu)
- Jax version: 0.4.33
- JaxLib version: 0.4.33
- Huggingface_hub version: 0.24.7
- Safetensors version: 0.4.5
- Transformers version: 4.46.0.dev0
- Accelerate version: 0.34.2
- GPU type: NVIDIA A100-SXM4-40GB
- `Accelerate` default config: Not found
I also tried optimum-quanto with the right version, and it's still taking 90s to generate.
Does colab provide free GPUs? I could restructure the code a bit, use a 1B model, and test the quantized version on like 10k context length if that fits in memory. Then we can compare with similar environments
Yes, colab has a free 14 GiB T4 GPU, and it would be nice if you can get it working there. I was thinking it was hardware related, since the quantization can get better performance on more recent GPUs. But you had a 40 GB A100 in your setting, so I don't have an idea currently.
You can't use flash attention on the T4, so it's hard to do this comparison, but this code also leads to an 80s generation time:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import time
import psutil
import gc
def get_gpu_memory():
    return torch.cuda.memory_allocated() / 1024**2  # Convert to MB

def get_ram_usage():
    return psutil.Process().memory_info().rss / 1024**2  # Convert to MB
# Memory usage before generation
gpu_memory_before = get_gpu_memory()
ram_before = get_ram_usage()
tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-1B-Instruct", padding_side="left")
model = AutoModelForCausalLM.from_pretrained("unsloth/Llama-3.2-1B-Instruct", torch_dtype=torch.float16, device_map="auto")
tokenizer.pad_token_id = tokenizer.eos_token_id
dataset = load_dataset('THUDM/LongBench', "samsum", split='test')
very_long_context = " ".join(dataset["context"])
inputs = tokenizer(very_long_context, max_length=10000, truncation="only_first", return_tensors="pt").to(model.device)
generation_kwargs = {"do_sample": False, "temperature": 1.0, "top_p": 1.0, "max_new_tokens": 20, "min_new_tokens": 20, "cache_implementation": "quantized"}
# Time the generation
start_time = time.time()
out_fp16 = model.generate(**inputs, **generation_kwargs)
generated_text = tokenizer.batch_decode(out_fp16)
end_time = time.time()
# Memory usage after generation
gpu_memory_after = get_gpu_memory()
ram_after = get_ram_usage()
# Calculate differences
gpu_memory_used = gpu_memory_after - gpu_memory_before
ram_used = ram_after - ram_before
time_taken = end_time - start_time
print(f"Generated text: {generated_text}")
print(f"Time taken: {time_taken:.2f} seconds")
print(f"GPU memory used: {gpu_memory_used:.2f} MB")
print(f"RAM used: {ram_used:.2f} MB")
Btw, these are my pip installs:
!pip install -q git+https://github.com/huggingface/transformers
!pip install datasets accelerate
!pip install -q flash-attn --no-build-isolation
!pip install quanto
Can you try out that code to see your generation time on colab? It would also be interesting to test with and without flash attention on your A100 to see if you get a similar generation time (just in case flash attention isn't working well on the colab A100 for me).
Can you try this code please, so we can verify that it is the quanto quantization that is slow? I am getting 14 sec on this. Basically it is all the quantization and dequantization that happens while generating with your input.
I will ask the quanto maintainers internally what the reason could be.
from quanto import AffineQuantizer, MaxOptimizer, qint2, qint4
import time
import torch
dummy_tensor_inputs = torch.randn(1, 32, 10_000, 128)
optimizer = MaxOptimizer()
qtype = qint4
q_group_size = 64
axis = 0
# quantize once per layer
for _ in range(16):
    scale, zeropoint = optimizer(dummy_tensor_inputs, qtype.bits, axis, q_group_size)
    qtensor = AffineQuantizer.apply(dummy_tensor_inputs, qtype, axis, q_group_size, scale, zeropoint)

start = time.perf_counter()
for _ in range(16 * 20):
    dequant_tensor = qtensor.dequantize()
end = time.perf_counter()
print(f"Time taken: {(end - start):.2f} seconds")
I got this on colab A100: Time taken: 52.24 seconds
OK, so it is something related to how quanto works internally. I will ask what can affect latency, aside from hardware.
Did you get 14s on the colab T4? I got 60 seconds on T4 on my end
I am getting 164 sec 🙃 on free colab with a T4, but my local machine with an A100-80GB runs it in 14 sec.
Holy....
Sorry, I forgot to add "cuda" in the last script. Can you move the dummy tensor to cuda before doing all the quant part? On an A100 machine if you can, please.
And if it doesn't make things faster, feel free to open an issue in transformers so we can track it there.
Like this? dummy_tensor_inputs = torch.randn(1, 32, 10_000, 128).to("cuda")
I tested it and it was still 80 seconds on A100 Colab
Yes, and could you open an issue, where I'll tag the people involved to help you out?
okay, I made an issue here: huggingface/transformers#34096
@zucchini-nlp Somehow I don't see your changes on my end, colab being colab. But if the change looked like this:
cache_position = torch.arange(inputs["input_ids"].shape[1], device="cuda:0")
next_model_kwargs = model._update_model_kwargs_for_generation(
    outputs,
    inputs,
    is_encoder_decoder=False,
    cache_position=cache_position,
)
Then it will give this error:
in prefill(model, inputs, cache_implementation, nbits)
41 next_input_ids = torch.cat([inputs["input_ids"], next_tokens[:, None]], dim=-1)
42 cache_position = torch.arange(inputs["input_ids"].shape[1], device="cuda:0")
---> 43 next_model_kwargs = model._update_model_kwargs_for_generation(
44 outputs,
45 inputs,
TypeError: GenerationMixin._update_model_kwargs_for_generation() got an unexpected keyword argument 'cache_position'
Honestly, what I really want from this code is to learn how to prefill the KV cache, because I'm not sure how to do that with the new transformers code, or at least how to fully utilize KV cache quantization for long-context generation (we want to use it for our long-context agent).
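For context, what I have in mind looks roughly like this (a sketch based on the prefix-caching pattern in the transformers docs, assuming a recent release where generate() accepts a pre-built cache via past_key_values; QuantizedCacheConfig / QuantoQuantizedCache are the release classes rather than the fork's QuantCache, and the names may differ by version):

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, QuantizedCacheConfig
from transformers.cache_utils import QuantoQuantizedCache

tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-1B-Instruct", padding_side="left")
model = AutoModelForCausalLM.from_pretrained("unsloth/Llama-3.2-1B-Instruct", torch_dtype=torch.float16, device_map="auto")
tokenizer.pad_token_id = tokenizer.eos_token_id

long_context = "..."  # the long document the agent should keep cached

# pre-fill: quantize and cache the long context once with a single forward pass
context_inputs = tokenizer(long_context, return_tensors="pt").to(model.device)
cache = QuantoQuantizedCache(cache_config=QuantizedCacheConfig(backend="quanto", nbits=4, q_group_size=64, residual_length=128))
with torch.no_grad():
    model(**context_inputs, past_key_values=cache, use_cache=True)

# decode: append the new question and reuse the pre-filled cache
full_inputs = tokenizer(long_context + "\nQuestion: ...\nAnswer:", return_tensors="pt").to(model.device)
out = model.generate(**full_inputs, past_key_values=cache, max_new_tokens=20)
print(tokenizer.batch_decode(out))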