@wassname
Last active May 9, 2024 05:00
Simple perplexity for huggingface models, similar to llama.cpp
# Directly taken from https://huggingface.co/spaces/evaluate-measurement/perplexity/blob/main/perplexity.py
# TODO: replace with a strided version https://github.com/huggingface/transformers/issues/9648#issuecomment-812981524
import numpy as np
import torch
import itertools
from torch.nn import CrossEntropyLoss
from tqdm.auto import tqdm
import torch.nn.functional as F
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM


def nll_loss_no_mean(logits, labels):
    # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1228
    logits = logits.float()
    # Shift so that tokens < n predict n
    vocab_size = logits.shape[-1]
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens
    loss_fct = CrossEntropyLoss(ignore_index=-100, reduction="none")
    shift_logits = shift_logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)
    return loss_fct(shift_logits, shift_labels)
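

# Note (added for clarity, not part of the original gist): given logits of shape
# (batch, seq, vocab) and labels of shape (batch, seq), nll_loss_no_mean returns a
# flat tensor of batch * (seq - 1) per-token losses; positions whose shifted label
# is -100 come back as zero loss rather than being dropped from the output.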


def create_batch(input_ids, loss_mask, batch_i, batch_size, stride):
    text_len = input_ids.size(1)

    # create batch indices
    begin_locs, end_locs, trg_lens = [], [], []
    for j in range(batch_size):
        j = batch_i + j * stride
        if j >= text_len:
            break
        begin_loc = max(j, 0)
        end_loc = min(j + stride, text_len)
        trg_len = end_loc - j  # may be different from stride on the last loop
        begin_locs.append(begin_loc)
        end_locs.append(end_loc)
        trg_lens.append(trg_len)

    # create batch
    b_input_ids = [input_ids[:, b:e] for b, e in zip(begin_locs, end_locs)]
    b_input_ids = torch.stack(b_input_ids, dim=1).squeeze(0)
    b_loss_mask = [loss_mask[:, b:e] for b, e in zip(begin_locs, end_locs)]
    b_loss_mask = torch.stack(b_loss_mask, dim=1).squeeze(0)

    # create target
    target_ids = torch.ones_like(b_input_ids) * -100  # -100 is the default ignore_index value in torch.nn.CrossEntropyLoss
    target_end_locs = [sen.size(-1) for sen in b_input_ids]
    for i, (b, e) in enumerate(zip(trg_lens, target_end_locs)):
        labels = b_input_ids[i, -b:e].clone()
        target_ids[i, -b:e] = labels
    target_ids[~b_loss_mask] = -100
    return b_input_ids, target_ids
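

# Illustration (added, not part of the original gist): with batch_i=0,
# batch_size=3, stride=4 and a 12-token input, create_batch slices the windows
# [0:4], [4:8] and [8:12] and stacks them into a (3, 4) batch; target_ids is a
# copy of b_input_ids with masked positions (e.g. special tokens) set to -100 so
# they are ignored by CrossEntropyLoss.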


@torch.no_grad()
def batched_perplexity(model: AutoModelForCausalLM, tokenizer: AutoTokenizer, dataset: Dataset = None, batch_size=32, stride=512):
    """
    Better perplexity calculation for causal language models.

    Args:
        model: A pretrained causal language model
        tokenizer: The tokenizer used to preprocess the data
        dataset: A list of texts to calculate perplexity on. If None, the first 10% of the wikitext-2 test set is used.
        batch_size: The batch size to use for perplexity calculation
        stride: The stride (window length) to use for perplexity calculation - important: changing this will change your results

    Comparison against other implementations:
    - https://huggingface.co/docs/transformers/perplexity - takes the mean of means, giving the wrong value
    - https://github.com/huggingface/evaluate/blob/main/metrics/perplexity/perplexity.py - complex, and crops sentences, so it's not comparable
    - https://github.com/ggerganov/llama.cpp/tree/master/examples/perplexity - good, but in cpp
    - https://github.com/huggingface/transformers/issues/9648#issuecomment-812981524 - doesn't use special tokens

    Limitations of this implementation:
    - a token at the start of a strided window has no context, so its perplexity is higher. TODO: use overlapping windows
    - uses special tokens, which makes it hard to compare against scores that do not
    """
    if dataset is None:
        # filter empty rows before extracting the text column, so .filter is called on a Dataset
        dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test[:10%]")
        dataset = dataset.filter(lambda x: len(x["text"]) > 0)["text"]
    device = next(iter(model.parameters())).device

    enc = tokenizer(dataset, add_special_tokens=True, return_special_tokens_mask=True)
    input_ids = torch.tensor(list(itertools.chain(*enc.input_ids))).to(torch.long).unsqueeze(0)
    # without padding or truncation we don't need the attention mask, but we do need the special tokens mask
    attention_mask = torch.tensor(list(itertools.chain(*enc.attention_mask))).to(torch.bool).unsqueeze(0)
    special_tokens_mask = torch.tensor(list(itertools.chain(*enc.special_tokens_mask))).to(torch.bool).unsqueeze(0)
    # don't calculate perplexity on special tokens
    loss_mask = attention_mask & ~special_tokens_mask

    # drop any trailing partial window so every window in a batch has the same length
    # (torch.stack in create_batch requires equal-sized slices)
    text_len = (input_ids.size(1) // stride) * stride
    input_ids = input_ids[:, :text_len]
    loss_mask = loss_mask[:, :text_len]

    lls = []
    for i in tqdm(range(0, text_len, batch_size * stride)):
        b_input_ids, target_ids = create_batch(input_ids, loss_mask, i, batch_size, stride)
        b_input_ids = b_input_ids.to(device)
        target_ids = target_ids.to(device)

        logits = model(b_input_ids).logits
        log_likelihood = nll_loss_no_mean(logits, target_ids)

        # keep only positions with a real label; ignored (-100) positions come back
        # as zero loss and would otherwise dilute the mean
        shift_labels = target_ids[..., 1:].reshape(-1)
        log_likelihood = log_likelihood[shift_labels != -100]
        lls.extend(log_likelihood.cpu().tolist())

    lls = torch.tensor(lls)
    ppl = lls.mean().exp()
    return ppl.cpu().item()
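

if __name__ == "__main__":
    # Minimal usage sketch (added, not part of the original gist): "gpt2" is only an
    # illustrative checkpoint; any causal LM with a matching tokenizer should work.
    # Expect different absolute numbers for different stride values, as noted in the
    # docstring above.
    model_id = "gpt2"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id).eval()
    ppl = batched_perplexity(model, tokenizer, batch_size=8, stride=512)
    print(f"perplexity: {ppl:.2f}")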