For huggingface transformers: sometimes you want to constrain the output to a JSON schema and record the probabilities of choices/enums. I use it when rating or judging; it's much more efficient than sampling multiple times.
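The snippet below does this by walking the prefix tree of the tokenized choices: at each position it finds which token ids could still extend the current prefix into a valid choice, takes a softmax over just those logits, and recurses down each branch, multiplying the branch probabilities along the way. Every leaf yields one completed choice together with its exact cumulative probability, so a handful of forward passes replaces many sampled generations.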
from typing import List, Optional

import torch
from torch import Tensor
from torch.nn import functional as F
from jaxtyping import Int
from transformers import AutoModelForCausalLM, AutoTokenizer


def get_valid_next_choices(
    choices_tokens: List[Int[Tensor, "seq"]],
    current_tokens: Int[Tensor, "seq"],
) -> Int[Tensor, "choices"]:
    """Return the token ids that could extend current_tokens toward one of the choices."""
    next_choices = []
    for choice_tokens in choices_tokens:
        # this choice still has tokens left to generate
        if len(current_tokens) < len(choice_tokens):
            # and its prefix matches what has been generated so far
            if (choice_tokens[: len(current_tokens)] == current_tokens).all():
                next_choices.append(choice_tokens[len(current_tokens)].item())
    # deduplicate: several choices can share the same next token
    return torch.LongTensor(sorted(set(next_choices)))


def choice_tree(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    input_ids: Int[Tensor, "seq"],
    choices_tokens: List[Int[Tensor, "seq"]],
    choice: Optional[Int[Tensor, ""]] = None,
    prob: float = 1.0,
    current_tokens: Optional[Int[Tensor, "seq"]] = None,
    z: Optional[list] = None,
):
    """Recursively walk the tree of allowed continuations, yielding each
    completed choice string with its cumulative probability."""
    # avoid mutable/shared default arguments
    if current_tokens is None:
        current_tokens = torch.LongTensor([])
    if z is None:
        z = []  # records the branch indices taken so far
    if choice is not None:
        # append the chosen token to both the running choice and the model input
        c = choice[None].to(current_tokens.device)
        current_tokens = torch.cat([current_tokens, c], dim=-1)
        c = choice[None].to(input_ids.device)
        input_ids = torch.cat([input_ids, c], dim=-1)
    next_choices = get_valid_next_choices(choices_tokens, current_tokens)
    if len(next_choices) == 0:
        # leaf: a choice is complete, decode it and report its cumulative probability
        s = tokenizer.decode(current_tokens)
        yield dict(prob=prob, choice=s)
    else:
        # one forward pass, then a softmax restricted to the valid next tokens
        with torch.no_grad():
            o = model(input_ids[None])
        logits_constrained = o.logits[0, -1][next_choices]
        probs = F.softmax(logits_constrained, dim=-1)
        for i in range(len(next_choices)):
            yield from choice_tree(
                model=model,
                tokenizer=tokenizer,
                choices_tokens=choices_tokens,
                input_ids=input_ids,
                choice=next_choices[i],
                prob=prob * probs[i].item(),
                current_tokens=current_tokens,
                z=z + [i],
            )
See https://github.com/wassname/prob_jsonformer.git