Simple perplexity for huggingface models, similar to llama.cpp
# Directly taken from https://huggingface.co/spaces/evaluate-measurement/perplexity/blob/main/perplexity.py
# TODO replace with a strided version https://github.com/huggingface/transformers/issues/9648#issuecomment-812981524
import itertools

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from tqdm.auto import tqdm
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM


def nll_loss_no_mean(logits, labels):
    # Same shift-and-flatten as the Llama modelling code, but returning per-token losses instead of the mean:
    # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1228
    logits = logits.float()
    # Shift so that tokens < n predict n
    vocab_size = logits.shape[-1]
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens
    loss_fct = CrossEntropyLoss(ignore_index=-100, reduction="none")
    shift_logits = shift_logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)
    return loss_fct(shift_logits, shift_labels)


def create_batch(input_ids, loss_mask, batch_i, batch_size, stride):
    text_len = input_ids.size(1)

    # create batch indices: up to batch_size consecutive windows of `stride` tokens starting at batch_i
    begin_locs, end_locs, trg_lens = [], [], []
    for j in range(batch_size):
        j = batch_i + j * stride
        if j >= text_len:
            break
        begin_loc = max(j, 0)
        end_loc = min(j + stride, text_len)
        trg_len = end_loc - j  # may be shorter than stride on the last window
        begin_locs.append(begin_loc)
        end_locs.append(end_loc)
        trg_lens.append(trg_len)

    # create batch
    b_input_ids = [input_ids[:, b:e] for b, e in zip(begin_locs, end_locs)]
    b_input_ids = torch.stack(b_input_ids, dim=1).squeeze(0)
    b_loss_mask = [loss_mask[:, b:e] for b, e in zip(begin_locs, end_locs)]
    b_loss_mask = torch.stack(b_loss_mask, dim=1).squeeze(0)

    # create targets: -100 is the default ignore_index value in torch.nn.CrossEntropyLoss
    target_ids = torch.ones_like(b_input_ids) * -100
    target_end_locs = [sen.size(-1) for sen in b_input_ids]
    for i, (b, e) in enumerate(zip(trg_lens, target_end_locs)):
        labels = b_input_ids[i, -b:e].clone()
        target_ids[i, -b:e] = labels
    target_ids[~b_loss_mask] = -100

    return b_input_ids, target_ids
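
# Worked example of the windowing (the numbers here are illustrative, not taken from the gist):
# with batch_i=0, batch_size=4, stride=512 and text_len >= 2048, create_batch slices the flat
# token stream into four consecutive windows [0:512), [512:1024), [1024:1536), [1536:2048),
# stacks them into a (4, 512) b_input_ids tensor, and builds target_ids as a copy of the inputs
# with special-token positions set to -100 so they are ignored by the loss.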
@torch.no_grad()
def batched_perplexity(model: AutoModelForCausalLM, tokenizer: AutoTokenizer, dataset: Dataset = None, batch_size=32, stride=512):
    """
    Better perplexity calculation for causal language models.

    Args:
        model: a pretrained causal language model
        tokenizer: the tokenizer used to preprocess the data
        dataset: a list of strings (or text column) to calculate perplexity on. If None, 10% of the wikitext-2 test set is used.
        batch_size: the batch size to use for the forward passes
        stride: the window length. Important: changing this will change your results.

    Comparison against other implementations:
    - https://huggingface.co/docs/transformers/perplexity - takes the mean of means, giving the wrong value
    - https://github.com/huggingface/evaluate/blob/main/metrics/perplexity/perplexity.py - complex, and crops sentences, so it's not comparable
    - https://github.com/ggerganov/llama.cpp/tree/master/examples/perplexity - good, but in C++
    - https://github.com/huggingface/transformers/issues/9648#issuecomment-812981524 - doesn't use special tokens

    Limitations of this implementation:
    - if a token is at the start of a strided window it has no context, so its perplexity is higher. TODO: use overlapping windows
    - uses special tokens, so it's hard to compare against scores that do not
    - the trailing partial window (fewer than `stride` tokens) is dropped so all windows stack to the same length
    """
    if dataset is None:
        ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="test[:10%]")
        ds = ds.filter(lambda x: len(x["text"]) > 0)  # filter before taking the text column; a plain list has no .filter
        dataset = ds["text"]
    device = next(iter(model.parameters())).device

    enc = tokenizer(dataset, add_special_tokens=True, return_special_tokens_mask=True)
    input_ids = torch.tensor(list(itertools.chain(*enc.input_ids))).to(torch.long).unsqueeze(0)
    # without padding or truncation we don't need the attention mask, but we do need the special tokens mask
    attention_mask = torch.tensor(list(itertools.chain(*enc.attention_mask))).to(torch.bool).unsqueeze(0)
    special_tokens_mask = torch.tensor(list(itertools.chain(*enc.special_tokens_mask))).to(torch.bool).unsqueeze(0)
    # let's not calc the perplexity on special tokens
    loss_mask = attention_mask & ~special_tokens_mask

    # drop the trailing partial window so every window is exactly `stride` tokens long and can be stacked into one batch
    text_len = (input_ids.size(1) // stride) * stride
    input_ids = input_ids[:, :text_len]
    loss_mask = loss_mask[:, :text_len]

    lls = []
    for i in tqdm(range(0, text_len, batch_size * stride)):
        b_input_ids, target_ids = create_batch(input_ids, loss_mask, i, batch_size, stride)
        b_input_ids = b_input_ids.to(device)
        target_ids = target_ids.to(device)

        logits = model(b_input_ids).logits
        log_likelihood = nll_loss_no_mean(logits, target_ids)
        # masked positions (target -100, i.e. special tokens) come back with a loss of 0;
        # drop them so they don't dilute the mean below
        keep = (target_ids[..., 1:].reshape(-1) != -100).cpu()
        lls.extend(log_likelihood.cpu()[keep].tolist())

    lls = torch.tensor(lls)
    ppl = lls.mean().exp()
    return ppl.cpu().item()
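
# Example usage: a minimal sketch, not part of the original gist. "gpt2" is an arbitrary
# illustrative checkpoint; any causal LM on the Hub should work, and omitting `dataset`
# falls back to 10% of the wikitext-2 test split.
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
    ppl = batched_perplexity(model, tokenizer, batch_size=8, stride=512)
    print(f"perplexity: {ppl:.2f}")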