- grammar_hack.py is completely new.
- All the other files are minimally modified, with additions marked by `#####`.
- Bring your own `grammar.ebnf` (an illustrative sketch follows this list).
- The LRU cache is very sketchy, but prevented having to plumb through a correct lifetime.
- Not sure what happens if it gets backed into a corner where there's no valid next state for the state machine.
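For reference, a grammar.ebnf could look something like the sketch below. This is only an illustrative guess, not a tested grammar: the hack passes "root" as the start rule to GrammarSampler, and torch-grammar expects llama.cpp-style rule syntax, so check the torch-grammar docs for the exact dialect it accepts.

# Illustrative only: constrain output to a tiny JSON-ish object.
root   ::= "{" ws pair ws "}"
pair   ::= string ws ":" ws string
string ::= "\"" [a-zA-Z0-9_ -]* "\""
ws     ::= [ \t\n]*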
Dockerfile:
FROM ghcr.io/huggingface/text-generation-inference:0.8
COPY grammar.ebnf /opt/grammar.ebnf
COPY flash_causal_lm.py /opt/conda/lib/python3.9/site-packages/text_generation_server/models/
COPY flash_llama.py /opt/conda/lib/python3.9/site-packages/text_generation_server/models/
COPY flash_santacoder.py /opt/conda/lib/python3.9/site-packages/text_generation_server/models/
COPY grammar_hack.py /opt/conda/lib/python3.9/site-packages/text_generation_server/
RUN rm /opt/conda/lib/python3.9/site-packages/text_generation_server/models/__pycache__/*
RUN pip install --no-deps torch-grammar==0.3.3
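The image is then built and launched like a stock text-generation-inference container. The following is a hedged sketch: the image tag, model id, port, and shard count are placeholders, not part of the gist. Note that grammar_hack_init is only wired into the sharded FlashLlama/FlashSantacoder classes below, so a sharded launch appears to be the intended path.

docker build -t tgi-grammar-hack .
docker run --gpus all --shm-size 1g -p 8080:80 -v $PWD/data:/data \
    tgi-grammar-hack \
    --model-id huggyllama/llama-7b \
    --num-shard 2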
flash_causal_lm.py:
import torch | |
import torch.distributed | |
import numpy as np | |
from torch.nn import functional as F | |
from dataclasses import dataclass | |
from opentelemetry import trace | |
from transformers import AutoTokenizer, PreTrainedTokenizerBase, PreTrainedModel | |
from typing import Optional, Tuple, List, Type, Union, Dict | |
from text_generation_server.models import Model | |
from text_generation_server.models.types import ( | |
Batch, | |
PrefillTokens, | |
Generation, | |
GeneratedText, | |
) | |
from text_generation_server.pb import generate_pb2 | |
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser | |
tracer = trace.get_tracer(__name__) | |
##### hack-text-generation-inference ##### | |
from text_generation_server.grammar_hack import grammar_hack_prepare, grammar_hack_accept_tokens, grammar_hack_commit | |
@dataclass | |
class FlashCausalLMBatch(Batch): | |
batch_id: int | |
requests: List[generate_pb2.Request] | |
# request id -> idx in list mapping | |
requests_idx_mapping: Dict[int, int] | |
# Decoder values | |
input_ids: torch.Tensor | |
position_ids: torch.Tensor | |
# cumulative sequence lengths | |
cu_seqlens: torch.Tensor | |
# cumulative query sequence lengths, only used in decode | |
cu_seqlens_q: Optional[torch.Tensor] | |
# past key values, only used in decode | |
past_key_values: Optional[torch.Tensor] | |
max_seqlen: int | |
# All tokens | |
all_input_ids: List[List[int]] | |
all_input_ids_tensor: torch.Tensor | |
# Lengths of all generations present in the batch | |
input_lengths: List[int] | |
prefix_offsets: List[Optional[int]] | |
read_offsets: List[Optional[int]] | |
# Generation helpers | |
next_token_chooser: HeterogeneousNextTokenChooser | |
stopping_criterias: List[StoppingCriteria] | |
# Maximum number of tokens this batch will grow to | |
max_tokens: int | |
def to_pb(self) -> generate_pb2.CachedBatch: | |
return generate_pb2.CachedBatch( | |
id=self.batch_id, | |
request_ids=[r.id for r in self.requests], | |
size=len(self), | |
max_tokens=self.max_tokens, | |
) | |
@classmethod | |
def from_pb( | |
cls, | |
pb: generate_pb2.Batch, | |
tokenizer: PreTrainedTokenizerBase, | |
dtype: torch.dtype, | |
device: torch.device, | |
) -> "FlashCausalLMBatch": | |
position_ids = [] | |
cu_seqlens = [0] | |
max_seqlen = 0 | |
input_lengths = [] | |
prefix_offsets = [] | |
read_offsets = [] | |
all_input_ids = [] | |
requests_idx_mapping = {} | |
next_token_chooser_parameters = [] | |
stopping_criterias = [] | |
# Cumulative length | |
cumulative_length = 0 | |
max_tokens = 0 | |
max_length = 0 | |
# Parse batch | |
for i, r in enumerate(pb.requests): | |
# request id -> idx in list mapping | |
requests_idx_mapping[r.id] = i | |
tokenized_input = tokenizer( | |
r.inputs, truncation=True, max_length=r.truncate | |
)["input_ids"] | |
input_length = len(tokenized_input) | |
max_seqlen = max(max_seqlen, input_length) | |
input_lengths.append(input_length) | |
prefix_offsets.append(0) | |
read_offsets.append(input_length) | |
all_input_ids.append(tokenized_input) | |
# Position ids | |
position_ids.append(np.arange(0, input_length)) | |
# Add cumulative lengths of all previous inputs | |
cu_seqlens.append(cumulative_length + input_length) | |
next_token_chooser_parameters.append(r.parameters) | |
stopping_criteria = StoppingCriteria.from_pb( | |
r.stopping_parameters, tokenizer | |
) | |
max_new_tokens = stopping_criteria.max_new_tokens | |
stopping_criterias.append(stopping_criteria) | |
# Update | |
cumulative_length += input_length | |
max_tokens += input_length + max_new_tokens | |
max_length = max(max_length, input_length + max_new_tokens) | |
next_token_chooser = HeterogeneousNextTokenChooser.from_pb( | |
next_token_chooser_parameters, dtype, device | |
) | |
# Padded all_input_ids_tensor | |
all_input_ids_tensor = np.zeros( | |
(len(all_input_ids), max_length), dtype=np.int64 | |
) | |
for i, input_ids in enumerate(all_input_ids): | |
all_input_ids_tensor[i, : len(input_ids)] = input_ids | |
# Create tensors on device | |
input_ids = torch.tensor( | |
np.concatenate(all_input_ids), dtype=torch.int64, device=device | |
) | |
all_input_ids_tensor = torch.tensor( | |
all_input_ids_tensor, dtype=torch.int64, device=device | |
) | |
position_ids = torch.tensor( | |
np.concatenate(position_ids), dtype=torch.int32, device=device | |
) | |
cu_seqlens = torch.tensor(cu_seqlens, device=device, dtype=torch.int32) | |
return cls( | |
batch_id=pb.id, | |
requests=pb.requests, | |
requests_idx_mapping=requests_idx_mapping, | |
input_ids=input_ids, | |
position_ids=position_ids, | |
cu_seqlens=cu_seqlens, | |
cu_seqlens_q=None, | |
max_seqlen=max_seqlen, | |
past_key_values=None, | |
input_lengths=input_lengths, | |
prefix_offsets=prefix_offsets, | |
read_offsets=read_offsets, | |
all_input_ids=all_input_ids, | |
all_input_ids_tensor=all_input_ids_tensor, | |
next_token_chooser=next_token_chooser, | |
stopping_criterias=stopping_criterias, | |
max_tokens=max_tokens, | |
) | |
@tracer.start_as_current_span("filter") | |
def filter(self, request_ids: List[int]) -> "FlashCausalLMBatch": | |
if len(request_ids) == 0: | |
raise ValueError("Batch must have at least one request") | |
# We assume that if len(requests) == len(self) then the requests are the same | |
if len(request_ids) == len(self): | |
return self | |
single_request = len(request_ids) == 1 | |
# Cumulative length | |
cumulative_length = 0 | |
# New values after filtering | |
requests_idx_mapping = {} | |
# Used to index into tensors | |
indices = [] | |
# Create on CPU to only move to GPU once instead of at every copy | |
cu_seqlens = torch.zeros(len(request_ids) + 1, dtype=torch.int32) | |
cu_seqlens_q = self.cu_seqlens_q[: len(request_ids) + 1] | |
max_seqlen = 0 | |
past_key_values = [] | |
requests = [] | |
all_input_ids = [] | |
input_lengths = [] | |
prefix_offsets = [] | |
read_offsets = [] | |
stopping_criterias = [] | |
max_tokens = 0 | |
for i, request_id in enumerate(request_ids): | |
idx = self.requests_idx_mapping[request_id] | |
indices.append(idx) | |
requests_idx_mapping[request_id] = i | |
requests.append(self.requests[idx]) | |
# Get length | |
request_input_length = self.input_lengths[idx] | |
# Copy to tensor (CPU) | |
cu_seqlens[i + 1] = cumulative_length + request_input_length | |
max_seqlen = max(max_seqlen, request_input_length) | |
# Slice from past | |
past_key_values.append( | |
self.past_key_values[:, self.cu_seqlens[idx] : self.cu_seqlens[idx + 1]] | |
) | |
all_input_ids.append(self.all_input_ids[idx]) | |
input_lengths.append(request_input_length) | |
prefix_offsets.append(self.prefix_offsets[idx]) | |
read_offsets.append(self.read_offsets[idx]) | |
stopping_criteria = self.stopping_criterias[idx] | |
stopping_criterias.append(stopping_criteria) | |
cumulative_length += request_input_length | |
max_tokens += request_input_length + ( | |
stopping_criteria.max_new_tokens - stopping_criteria.current_tokens | |
) | |
if single_request: | |
# Preallocate tensor for bs = 1 case | |
past_key_values = F.pad( | |
past_key_values[0], | |
( | |
0, | |
0, | |
0, | |
0, | |
0, | |
0, | |
0, | |
stopping_criterias[0].max_new_tokens | |
- stopping_criterias[0].current_tokens, | |
), | |
) | |
else: | |
# Cat all past | |
past_key_values = torch.cat(past_key_values, dim=1) | |
# Index into tensors | |
input_ids = self.input_ids[indices] | |
position_ids = self.position_ids[indices] | |
all_input_ids_tensor = self.all_input_ids_tensor[indices] | |
next_token_chooser = self.next_token_chooser.filter(indices) | |
# Move to GPU now that we have the whole tensor | |
cu_seqlens = cu_seqlens.to(self.cu_seqlens.device) | |
return FlashCausalLMBatch( | |
batch_id=self.batch_id, | |
requests=requests, | |
requests_idx_mapping=requests_idx_mapping, | |
input_ids=input_ids, | |
position_ids=position_ids, | |
cu_seqlens=cu_seqlens, | |
cu_seqlens_q=cu_seqlens_q, | |
max_seqlen=max_seqlen, | |
past_key_values=past_key_values, | |
input_lengths=input_lengths, | |
prefix_offsets=prefix_offsets, | |
read_offsets=read_offsets, | |
all_input_ids=all_input_ids, | |
all_input_ids_tensor=all_input_ids_tensor, | |
next_token_chooser=next_token_chooser, | |
stopping_criterias=stopping_criterias, | |
max_tokens=max_tokens, | |
) | |
@classmethod | |
@tracer.start_as_current_span("concatenate") | |
def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch": | |
# Batch attributes | |
requests = [] | |
requests_idx_mapping = {} | |
total_batch_size = sum([len(b) for b in batches]) | |
dtype = batches[0].past_key_values.dtype | |
device = batches[0].input_ids.device | |
input_ids = batches[0].input_ids.new_empty(total_batch_size) | |
position_ids = batches[0].position_ids.new_empty(total_batch_size) | |
cu_seqlens = [0] | |
cu_seqlens_q = torch.arange( | |
0, total_batch_size + 1, device=device, dtype=torch.int32 | |
) | |
max_seqlen = 0 | |
past_key_values = [] | |
all_input_ids = [] | |
input_lengths = [] | |
prefix_offsets = [] | |
read_offsets = [] | |
next_token_chooser_parameters = [] | |
stopping_criterias = [] | |
# Cumulative length | |
cumulative_batch_size = 0 | |
cumulative_length = 0 | |
max_tokens = 0 | |
max_length = 0 | |
for i, batch in enumerate(batches): | |
requests.extend(batch.requests) | |
if i == 0: | |
requests_idx_mapping = batch.requests_idx_mapping | |
else: | |
# We need to offset the mapping for each batch by the cumulative batch size | |
for k, v in batch.requests_idx_mapping.items(): | |
requests_idx_mapping[k] = v + cumulative_batch_size | |
start_index = cumulative_batch_size | |
end_index = cumulative_batch_size + len(batch) | |
# Copy tensors (GPU) | |
input_ids[start_index:end_index] = batch.input_ids | |
position_ids[start_index:end_index] = batch.position_ids | |
# Add cumulative lengths of all previous inputs | |
cu_seqlens.extend([l + cumulative_length for l in batch.cu_seqlens[1:]]) | |
max_seqlen = max(max_seqlen, batch.max_seqlen) | |
if len(batch) != 1: | |
past_key_values.append(batch.past_key_values) | |
else: | |
# past was pre-allocated for this batch | |
# We need to slice to remove the padding | |
past_key_values.append( | |
batch.past_key_values[:, : batch.input_lengths[0]] | |
) | |
all_input_ids.extend(batch.all_input_ids) | |
input_lengths.extend(batch.input_lengths) | |
prefix_offsets.extend(batch.prefix_offsets) | |
read_offsets.extend(batch.read_offsets) | |
next_token_chooser_parameters.extend([r.parameters for r in batch.requests]) | |
stopping_criterias.extend(batch.stopping_criterias) | |
# Update | |
cumulative_length += batch.cu_seqlens[-1] | |
cumulative_batch_size += len(batch) | |
max_tokens += batch.max_tokens | |
max_length = max( | |
max_length, | |
max( | |
input_length | |
+ stopping_criteria.max_new_tokens | |
- stopping_criteria.current_tokens | |
for input_length, stopping_criteria in zip( | |
batch.input_lengths, batch.stopping_criterias | |
) | |
), | |
) | |
all_input_ids_tensor = torch.zeros( | |
(total_batch_size, max_length), dtype=torch.int64, device=device | |
) | |
cumulative_batch_size = 0 | |
for i, batch in enumerate(batches): | |
start_index = cumulative_batch_size | |
end_index = cumulative_batch_size + len(batch) | |
all_input_ids_tensor[ | |
start_index:end_index, : batch.all_input_ids_tensor.shape[1] | |
] = batch.all_input_ids_tensor[:, :max_length] | |
cumulative_batch_size += len(batch) | |
# Cat past | |
past_key_values = torch.cat(past_key_values, dim=1) | |
# Create final tensor on GPU | |
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) | |
next_token_chooser = HeterogeneousNextTokenChooser.from_pb( | |
next_token_chooser_parameters, dtype=dtype, device=device | |
) | |
return FlashCausalLMBatch( | |
batch_id=batches[0].batch_id, | |
requests=requests, | |
requests_idx_mapping=requests_idx_mapping, | |
input_ids=input_ids, | |
position_ids=position_ids, | |
cu_seqlens=cu_seqlens, | |
cu_seqlens_q=cu_seqlens_q, | |
max_seqlen=max_seqlen, | |
past_key_values=past_key_values, | |
input_lengths=input_lengths, | |
prefix_offsets=prefix_offsets, | |
read_offsets=read_offsets, | |
all_input_ids=all_input_ids, | |
all_input_ids_tensor=all_input_ids_tensor, | |
next_token_chooser=next_token_chooser, | |
stopping_criterias=stopping_criterias, | |
max_tokens=max_tokens, | |
) | |
def __len__(self): | |
return len(self.requests) | |
class FlashCausalLM(Model): | |
def __init__( | |
self, | |
model_cls: Type[PreTrainedModel], | |
model_id: str, | |
revision: Optional[str] = None, | |
quantize: Optional[str] = None, | |
trust_remote_code: bool = False, | |
): | |
if torch.cuda.is_available(): | |
device = torch.device("cuda") | |
dtype = torch.float16 | |
else: | |
raise NotImplementedError("FlashCausalLM is only available on GPU") | |
tokenizer = AutoTokenizer.from_pretrained( | |
model_id, | |
revision=revision, | |
padding_side="left", | |
truncation_side="left", | |
trust_remote_code=trust_remote_code, | |
) | |
model = model_cls.from_pretrained( | |
model_id, | |
revision=revision, | |
torch_dtype=dtype, | |
load_in_8bit=quantize == "bitsandbytes", | |
trust_remote_code=trust_remote_code, | |
).to(device) | |
super(FlashCausalLM, self).__init__( | |
model=model, | |
tokenizer=tokenizer, | |
requires_padding=False, | |
dtype=dtype, | |
device=device, | |
) | |
@property | |
def batch_type(self) -> Type[FlashCausalLMBatch]: | |
return FlashCausalLMBatch | |
def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str: | |
return self.tokenizer.decode( | |
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False | |
) | |
def forward( | |
self, | |
input_ids: torch.Tensor, | |
position_ids: torch.Tensor, | |
cu_seqlens: torch.Tensor, | |
cu_seqlens_q: Optional[torch.Tensor], | |
max_s: int, | |
past_key_values: Optional = None, | |
pre_allocate_past_size: Optional[int] = None, | |
) -> Tuple[torch.Tensor, torch.Tensor]: | |
# Model Forward | |
return self.model.forward( | |
input_ids=input_ids, | |
position_ids=position_ids, | |
cu_seqlens=cu_seqlens, | |
cu_seqlens_q=cu_seqlens_q, | |
max_s=max_s, | |
past_key_values=past_key_values, | |
pre_allocate_past_size=pre_allocate_past_size, | |
) | |
@tracer.start_as_current_span("generate_token") | |
def generate_token( | |
self, batch: FlashCausalLMBatch | |
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]: | |
prefill = batch.past_key_values is None | |
single_request = len(batch) == 1 | |
if prefill and len(batch) == 1: | |
# Ask to pre-allocate kv to its max size | |
# == number of tokens + max_new_tokens | |
pre_allocate_past_size = ( | |
batch.input_lengths[0] + batch.stopping_criterias[0].max_new_tokens | |
) | |
else: | |
pre_allocate_past_size = None | |
out, present = self.forward( | |
batch.input_ids, | |
batch.position_ids, | |
batch.cu_seqlens, | |
batch.cu_seqlens_q, | |
batch.max_seqlen, | |
batch.past_key_values, | |
pre_allocate_past_size, | |
) | |
##### hack-text-generation-inference ##### | |
idx_stacks = grammar_hack_prepare(batch.requests_idx_mapping, self.stackstore, self.grammar) | |
if prefill: | |
next_token_logits = ( | |
out[-1:] if single_request else out[batch.cu_seqlens[1:] - 1] | |
) | |
else: | |
##### hack-text-generation-inference ##### | |
grammar_hack_accept_tokens(self.grammar, batch.input_ids, idx_stacks) | |
next_token_logits = out | |
##### hack-text-generation-inference ##### | |
grammar_hack_commit(self.grammar, self.stackstore, idx_stacks, batch.requests_idx_mapping, next_token_logits) | |
next_input_ids, next_token_logprobs = batch.next_token_chooser( | |
batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits | |
) | |
if prefill: | |
if len(batch) > 1: | |
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs | |
# When batch == 1, we will just use the batch.input_ids values directly | |
prefill_tokens_indices = batch.input_ids.new_zeros(len(batch.input_ids)) | |
# Create batch.cu_seqlens_q for decode | |
batch.cu_seqlens_q = torch.arange( | |
0, len(batch) + 1, device=self.device, dtype=torch.int32 | |
) | |
next_position_ids = batch.position_ids.new_empty(len(batch)) | |
else: | |
prefill_logprobs = None | |
next_position_ids = batch.position_ids | |
# Prepare past for next decode | |
if len(batch) > 1: | |
# Used to slice next batch past | |
past_indices = torch.empty( | |
present.shape[1], dtype=torch.int64, device=self.device | |
) | |
batch.past_key_values = present.new_empty( | |
( | |
present.shape[0], | |
present.shape[1] + len(batch.requests), | |
*present.shape[2:], | |
) | |
) | |
# It is actually faster to do a whole other for loop here as the copy from present to past is fairly slow | |
# and will run asynchronously while we do the next for loop | |
cumulative_length = 0 | |
for i, input_length in enumerate(batch.input_lengths): | |
# Indexing metadata | |
start_index = cumulative_length | |
end_index = cumulative_length + input_length | |
# Indices to copy present at the correct place in past_key_values | |
torch.arange( | |
start_index + i, | |
end_index + i, | |
dtype=torch.int64, | |
device=self.device, | |
out=past_indices[start_index:end_index], | |
) | |
cumulative_length += input_length | |
# Copy from present to past_key_values | |
batch.past_key_values[:, past_indices] = present | |
# Initialize past_key_values in prefill for len(batch) == 1 | |
elif prefill: | |
# present is already pre-padded | |
batch.past_key_values = present | |
# Cumulative length | |
cumulative_length = 0 | |
# Results | |
generations: List[Generation] = [] | |
stopped = True | |
# Zipped iterator | |
iterator = zip( | |
batch.input_lengths, | |
batch.stopping_criterias, | |
batch.all_input_ids, | |
) | |
# We do two for loops as the first one can run completely asynchronously from the GPU while for the second | |
# one, we need to first do a GPU <-> CPU sync | |
# It is faster if we delay this sync for the maximum amount of time | |
# For each member of the batch | |
for i, ( | |
input_length, | |
stopping_criteria, | |
all_input_ids, | |
) in enumerate(iterator): | |
# Indexing metadata | |
start_index = cumulative_length | |
end_index = cumulative_length + input_length | |
if prefill: | |
# Initialize position_ids | |
# In decode, we do not need this as we can just increment position ids | |
next_position_ids[i] = batch.position_ids[end_index - 1] | |
# Used to gather prefill logprobs | |
# Copy batch.input_ids to prefill_token_indices | |
if len(batch) > 1: | |
prefill_tokens_indices[ | |
start_index : end_index - 1 | |
] = batch.input_ids[start_index + 1 : end_index] | |
else: | |
# Set prefill_tokens_indices to the correct slice | |
prefill_tokens_indices = batch.input_ids[ | |
start_index + 1 : end_index | |
] | |
batch.all_input_ids_tensor[i, input_length] = next_input_ids[i] | |
cumulative_length += input_length | |
# Set values in batch | |
batch.input_ids = next_input_ids | |
batch.position_ids = next_position_ids + 1 | |
batch.cu_seqlens = batch.cu_seqlens + batch.cu_seqlens_q | |
if prefill: | |
# Get prefill logprobs | |
prefill_logprobs_tensor = torch.log_softmax(out, -1) | |
prefill_logprobs = torch.gather( | |
prefill_logprobs_tensor, 1, prefill_tokens_indices.view(-1, 1) | |
) | |
# GPU <-> CPU sync | |
prefill_logprobs = prefill_logprobs.view(-1).tolist() | |
# GPU <-> CPU sync | |
next_token_logprobs = next_token_logprobs.tolist() | |
next_token_ids = batch.input_ids.tolist() | |
cumulative_length = 0 | |
# Zipped iterator | |
iterator = zip( | |
batch.requests, | |
batch.input_lengths, | |
batch.prefix_offsets, | |
batch.read_offsets, | |
batch.stopping_criterias, | |
batch.all_input_ids, | |
batch.all_input_ids_tensor, | |
batch.next_token_chooser.do_sample, | |
batch.next_token_chooser.seeds, | |
next_token_ids, | |
next_token_logprobs, | |
) | |
# For each member of the batch | |
for i, ( | |
request, | |
input_length, | |
prefix_offset, | |
read_offset, | |
stopping_criteria, | |
all_input_ids, | |
all_input_ids_tensor, | |
do_sample, | |
seed, | |
next_token_id, | |
next_token_logprob, | |
) in enumerate(iterator): | |
start_index = cumulative_length | |
end_index = cumulative_length + input_length | |
# Append next token to all tokens | |
all_input_ids.append(next_token_id) | |
# Generated token | |
next_token_text, prefix_offset, read_offset = self.decode_token( | |
all_input_ids, | |
prefix_offset, | |
read_offset, | |
) | |
# Evaluate stopping criteria | |
stop, reason = stopping_criteria( | |
next_token_id, | |
next_token_text, | |
) | |
if not stop: | |
stopped = False | |
# Shard generations | |
# All generations will be appended in the rust sharded client | |
if i % self.world_size == self.rank: | |
if stop: | |
# Decode generated tokens | |
output_text = self.decode( | |
all_input_ids[-stopping_criteria.current_tokens :] | |
) | |
generated_text = GeneratedText( | |
output_text, | |
stopping_criteria.current_tokens, | |
reason, | |
seed if do_sample else None, | |
) | |
else: | |
generated_text = None | |
# Prefill | |
if prefill: | |
# Remove generated token to only have prefill and add nan for first prompt token | |
request_prefill_logprobs = [float("nan")] + prefill_logprobs[ | |
start_index : end_index - 1 | |
] | |
prefill_token_ids = all_input_ids[:-1] | |
prefill_texts = self.tokenizer.batch_decode( | |
prefill_token_ids, | |
clean_up_tokenization_spaces=False, | |
skip_special_tokens=False, | |
) | |
prefill_tokens = PrefillTokens( | |
prefill_token_ids, request_prefill_logprobs, prefill_texts | |
) | |
else: | |
prefill_tokens = None | |
generation = Generation( | |
request.id, | |
prefill_tokens, | |
next_token_id, | |
next_token_logprob, | |
next_token_text, | |
next_token_id in self.all_special_ids, | |
generated_text, | |
) | |
generations.append(generation) | |
new_input_length = input_length + 1 | |
# Update values | |
batch.input_lengths[i] = new_input_length | |
batch.prefix_offsets[i] = prefix_offset | |
batch.read_offsets[i] = read_offset | |
batch.all_input_ids[i] = all_input_ids | |
cumulative_length += input_length | |
batch.max_seqlen = batch.max_seqlen + 1 | |
# No need to return a batch if we know that all requests stopped | |
return generations, batch if not stopped else None |
flash_llama.py:
import torch | |
import torch.distributed | |
from accelerate import init_empty_weights | |
from opentelemetry import trace | |
from pathlib import Path | |
from safetensors import safe_open | |
from transformers import AutoConfig | |
from transformers.models.llama import LlamaTokenizer | |
from typing import Optional, List | |
from text_generation_server.models import FlashCausalLM | |
from text_generation_server.models.custom_modeling.flash_llama_modeling import ( | |
FlashLlamaForCausalLM, | |
TensorParallelEmbedding, | |
TensorParallelRowLinear, | |
TensorParallelColumnLinear, | |
) | |
from text_generation_server.utils import ( | |
initialize_torch_distributed, | |
weight_files, | |
download_weights, | |
weight_hub_files, | |
LocalEntryNotFoundError, | |
) | |
tracer = trace.get_tracer(__name__) | |
##### hack-text-generation-inference ##### | |
from text_generation_server.grammar_hack import grammar_hack_init | |
class FlashLlama(FlashCausalLM): | |
def __init__( | |
self, | |
model_id: str, | |
revision: Optional[str] = None, | |
quantize: Optional[str] = None, | |
trust_remote_code: bool = False, | |
): | |
if torch.cuda.is_available(): | |
device = torch.device("cuda") | |
dtype = torch.float16 | |
else: | |
raise NotImplementedError("FlashLlama is only available on GPU") | |
tokenizer = LlamaTokenizer.from_pretrained( | |
model_id, | |
revision=revision, | |
padding_side="left", | |
truncation_side="left", | |
trust_remote_code=trust_remote_code, | |
) | |
config = AutoConfig.from_pretrained( | |
model_id, revision=revision, trust_remote_code=trust_remote_code | |
) | |
# We do not use from_pretrained as we modified the model internal module layout | |
try: | |
filenames = weight_files(model_id, revision, ".bin") | |
# Local files not found | |
except LocalEntryNotFoundError: | |
hub_files = weight_hub_files(model_id, revision, ".bin") | |
filenames = download_weights(hub_files, model_id, revision) | |
with init_empty_weights(): | |
model = FlashLlamaForCausalLM(config) | |
self.load_weights(model, filenames, quantize, device, dtype) | |
super(FlashCausalLM, self).__init__( | |
model=model.to(device), | |
tokenizer=tokenizer, | |
requires_padding=False, | |
dtype=dtype, | |
device=device, | |
) | |
@staticmethod | |
def load_weights( | |
model, | |
filenames: List[Path], | |
quantize: Optional[str], | |
device: torch.device, | |
dtype: torch.dtype, | |
): | |
for filename in filenames: | |
state_dict = torch.load(filename, map_location="cpu") | |
for key, value in state_dict.items(): | |
value = value.to(device if quantize is None else "cpu").to(dtype) | |
layer_name = ".".join(key.split(".")[:4]) | |
# Fused qkv | |
if "q_proj" in key or "k_proj" in key or "v_proj" in key: | |
final_key = layer_name + ".query_key_value.weight" | |
# Fused gate and up projs | |
elif "gate_proj" in key or "up_proj" in key: | |
final_key = layer_name + ".gate_up_proj.weight" | |
else: | |
final_key = key | |
module_name, param_name = final_key.rsplit(".", 1) | |
module = model.get_submodule(module_name) | |
try: | |
current_parameter_tensor = module._parameters[param_name] | |
except KeyError: | |
current_parameter_tensor = None | |
if current_parameter_tensor is not None: | |
if current_parameter_tensor.device == torch.device("meta"): | |
# Init qkv | |
if "query_key_value" in final_key: | |
module._parameters[param_name] = value.new_empty( | |
(value.shape[0] * 3, value.shape[1]) | |
) | |
# Init gate and up proj | |
elif "gate_up_proj" in final_key: | |
module._parameters[param_name] = value.new_empty( | |
(value.shape[0] * 2, value.shape[1]) | |
) | |
# Copy to correct slice | |
if "q_proj" in key: | |
module._parameters[param_name][: value.shape[0]] = value | |
elif "k_proj" in key: | |
module._parameters[param_name][ | |
value.shape[0] : value.shape[0] * 2 | |
] = value | |
elif "v_proj" in key: | |
module._parameters[param_name][value.shape[0] * 2 :] = value | |
elif "gate_proj" in key: | |
module._parameters[param_name][: value.shape[0]] = value | |
elif "up_proj" in key: | |
module._parameters[param_name][value.shape[0] :] = value | |
else: | |
if current_parameter_tensor.shape != value.shape: | |
raise ValueError( | |
f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}" | |
) | |
module._parameters[param_name] = value | |
else: | |
module._buffers[param_name] = value | |
del value | |
torch.cuda.empty_cache() | |
model.post_load_weights(quantize) | |
class FlashLlamaSharded(FlashLlama): | |
def __init__( | |
self, | |
model_id: str, | |
revision: Optional[str] = None, | |
quantize: Optional[str] = None, | |
trust_remote_code: bool = False, | |
): | |
self.process_group, rank, world_size = initialize_torch_distributed() | |
if torch.cuda.is_available(): | |
device = torch.device(f"cuda:{rank}") | |
dtype = torch.float16 | |
else: | |
raise NotImplementedError("FlashLlama is only available on GPU") | |
tokenizer = LlamaTokenizer.from_pretrained( | |
model_id, | |
revision=revision, | |
padding_side="left", | |
truncation_side="left", | |
trust_remote_code=trust_remote_code, | |
) | |
config = AutoConfig.from_pretrained( | |
model_id, revision=revision, trust_remote_code=trust_remote_code | |
) | |
torch.distributed.barrier(group=self.process_group) | |
filenames = weight_files(model_id, revision=revision, extension=".safetensors") | |
with init_empty_weights(): | |
model = FlashLlamaForCausalLM(config, process_group=self.process_group) | |
torch.distributed.barrier(group=self.process_group) | |
self.load_weights( | |
model, | |
filenames, | |
quantize=quantize, | |
device=device, | |
dtype=dtype, | |
rank=rank, | |
world_size=world_size, | |
) | |
torch.distributed.barrier(group=self.process_group) | |
##### hack-text-generation-inference ##### | |
self.grammar, self.stackstore = grammar_hack_init(tokenizer) | |
super(FlashCausalLM, self).__init__( | |
model=model.to(device), | |
tokenizer=tokenizer, | |
requires_padding=False, | |
dtype=dtype, | |
device=device, | |
rank=rank, | |
world_size=world_size, | |
) | |
@staticmethod | |
def load_weights( | |
model, | |
filenames: List[str], | |
quantize: Optional[str], | |
device: torch.device, | |
dtype: torch.dtype, | |
rank: int, | |
world_size: int, | |
): | |
for file in filenames: | |
with safe_open( | |
file, framework="pt", device=str(device) if quantize is None else "cpu" | |
) as f: | |
for name in f.keys(): | |
slice_ = f.get_slice(name) | |
layer_name = ".".join(name.split(".")[:4]) | |
# Fused qkv | |
if "q_proj" in name or "k_proj" in name or "v_proj" in name: | |
final_name = layer_name + ".query_key_value.weight" | |
# Fused gate and up projs | |
elif "gate_proj" in name or "up_proj" in name: | |
final_name = layer_name + ".gate_up_proj.weight" | |
else: | |
final_name = name | |
module_name, param_name = final_name.rsplit(".", 1) | |
module = model.get_submodule(module_name) | |
if isinstance(module, TensorParallelColumnLinear): | |
size = slice_.get_shape()[0] | |
block_size = size // world_size | |
start = rank * block_size | |
stop = (rank + 1) * block_size | |
tensor = slice_[start:stop] | |
elif isinstance(module, TensorParallelRowLinear): | |
size = slice_.get_shape()[1] | |
block_size = size // world_size | |
start = rank * block_size | |
stop = (rank + 1) * block_size | |
tensor = slice_[:, start:stop] | |
elif isinstance(module, TensorParallelEmbedding): | |
size = slice_.get_shape()[0] | |
block_size = size // world_size | |
start = rank * block_size | |
stop = (rank + 1) * block_size | |
tensor = slice_[start:stop] | |
elif name == "lm_head.weight" and model.model.tp_embeddings: | |
size = slice_.get_shape()[0] | |
block_size = size // world_size | |
start = rank * block_size | |
stop = (rank + 1) * block_size | |
tensor = slice_[start:stop] | |
else: | |
try: | |
tensor = slice_[:] | |
except: | |
tensor = f.get_tensor(name) | |
tensor = tensor.contiguous().to(dtype) | |
try: | |
current_parameter_tensor = module._parameters[param_name] | |
except KeyError: | |
current_parameter_tensor = None | |
if current_parameter_tensor is not None: | |
if current_parameter_tensor.device == torch.device("meta"): | |
# Init qkv | |
if "query_key_value" in final_name: | |
module._parameters[param_name] = tensor.new_empty( | |
(tensor.shape[0] * 3, tensor.shape[1]) | |
) | |
# Init gate and up proj | |
elif "gate_up_proj" in final_name: | |
module._parameters[param_name] = tensor.new_empty( | |
(tensor.shape[0] * 2, tensor.shape[1]) | |
) | |
# Copy to correct slice
if "q_proj" in name: | |
module._parameters[param_name][: tensor.shape[0]] = tensor | |
elif "k_proj" in name: | |
module._parameters[param_name][ | |
tensor.shape[0] : tensor.shape[0] * 2 | |
] = tensor | |
elif "v_proj" in name: | |
module._parameters[param_name][ | |
tensor.shape[0] * 2 : | |
] = tensor | |
elif "gate_proj" in name: | |
module._parameters[param_name][: tensor.shape[0]] = tensor | |
elif "up_proj" in name: | |
module._parameters[param_name][tensor.shape[0] :] = tensor | |
else: | |
if current_parameter_tensor.shape != tensor.shape: | |
raise ValueError( | |
f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}" | |
) | |
module._parameters[param_name] = tensor | |
else: | |
module._buffers[param_name] = tensor | |
torch.cuda.empty_cache() | |
model.post_load_weights(quantize) |
flash_santacoder.py:
import torch | |
import torch.distributed | |
from accelerate import init_empty_weights | |
from opentelemetry import trace | |
from safetensors import safe_open | |
from pathlib import Path | |
from transformers import AutoTokenizer, GPT2Config | |
from typing import Optional, List | |
from text_generation_server.models import FlashCausalLM | |
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import ( | |
FlashSantacoderForCausalLM, | |
TensorParallelRowLinear, | |
TensorParallelColumnLinear, | |
TensorParallelEmbedding, | |
) | |
from text_generation_server.utils import ( | |
initialize_torch_distributed, | |
weight_files, | |
download_weights, | |
weight_hub_files, | |
LocalEntryNotFoundError, | |
) | |
tracer = trace.get_tracer(__name__) | |
##### hack-text-generation-inference ##### | |
from text_generation_server.grammar_hack import grammar_hack_init | |
class FlashSantacoder(FlashCausalLM): | |
def __init__( | |
self, | |
model_id: str, | |
revision: Optional[str] = None, | |
quantize: Optional[str] = None, | |
trust_remote_code: bool = False, | |
): | |
if torch.cuda.is_available(): | |
device = torch.device("cuda") | |
dtype = torch.float16 | |
else: | |
raise NotImplementedError("FlashSantacoder is only available on GPU") | |
tokenizer = AutoTokenizer.from_pretrained( | |
model_id, | |
revision=revision, | |
padding_side="left", | |
truncation_side="left", | |
trust_remote_code=trust_remote_code, | |
) | |
config = GPT2Config.from_pretrained( | |
model_id, | |
revision=revision, | |
) | |
# We do not use from_pretrained as we modified the model internal module layout | |
filenames = weight_files(model_id, revision, ".safetensors") | |
with init_empty_weights(): | |
model = FlashSantacoderForCausalLM(config) | |
self.load_weights( | |
model, | |
filenames, | |
quantize, | |
device, | |
dtype, | |
config.architectures[0].startswith("GPT2"), | |
) | |
super(FlashCausalLM, self).__init__( | |
model=model.to(device), | |
tokenizer=tokenizer, | |
requires_padding=False, | |
dtype=dtype, | |
device=device, | |
) | |
@staticmethod | |
def load_weights( | |
model: FlashSantacoderForCausalLM, | |
filenames: List[Path], | |
quantize: Optional[str], | |
device: torch.device, | |
dtype: torch.dtype, | |
transpose: bool, | |
): | |
for filename in filenames: | |
with safe_open( | |
filename, | |
framework="pt", | |
device=str(device) if quantize is None else "cpu", | |
) as f: | |
for key in f.keys(): | |
value = f.get_tensor(key) | |
value = value.to(device if quantize is None else "cpu").to(dtype) | |
layer_name = ".".join(key.split(".")[:4]) | |
# Fused qkv | |
if "q_attn.weight" in key or "kv_attn.weight" in key: | |
final_key = layer_name + ".c_attn.weight" | |
elif "q_attn.bias" in key or "kv_attn.bias" in key: | |
final_key = layer_name + ".c_attn.bias" | |
else: | |
final_key = key | |
module_name, param_name = final_key.rsplit(".", 1) | |
module = model.get_submodule(module_name) | |
try: | |
current_parameter_tensor = module._parameters[param_name] | |
except KeyError: | |
current_parameter_tensor = None | |
if current_parameter_tensor is not None: | |
if transpose and ( | |
"c_fc.weight" in key | |
or "c_proj.weight" in key | |
or "q_attn.weight" in key | |
or "kv_attn.weight" in key | |
or "c_attn.weight" in key | |
): | |
# Transpose as we use nn.Linear instead of Conv1D
value = value.T | |
if current_parameter_tensor.device == torch.device("meta"): | |
# Init qkv | |
if "c_attn.weight" in final_key: | |
module._parameters[param_name] = value.new_empty( | |
( | |
model.transformer.head_size | |
* (model.transformer.num_heads + 2), | |
value.shape[1], | |
) | |
) | |
elif "c_attn.bias" in final_key: | |
module._parameters[param_name] = value.new_empty( | |
( | |
model.transformer.head_size | |
* (model.transformer.num_heads + 2) | |
) | |
) | |
# Copy to correct slice | |
if "q_attn.weight" in key: | |
module._parameters[param_name][: value.shape[0]] = value | |
elif "q_attn.bias" in key: | |
module._parameters[param_name][: value.shape[0]] = value | |
elif "kv_attn.weight" in key: | |
module._parameters[param_name][ | |
model.transformer.head_size | |
* model.transformer.num_heads : | |
] = value | |
elif "kv_attn.bias" in key: | |
module._parameters[param_name][ | |
model.transformer.head_size | |
* model.transformer.num_heads : | |
] = value | |
else: | |
if current_parameter_tensor.shape != value.shape: | |
raise ValueError( | |
f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}" | |
) | |
module._parameters[param_name] = value | |
else: | |
module._buffers[param_name] = value | |
del value | |
if model.lm_head.weight.device == torch.device("meta"): | |
model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight) | |
torch.cuda.empty_cache() | |
model.post_load_weights(quantize) | |
uninitialized_parameters = [] | |
for n, p in model.named_parameters(): | |
if p.data.device == torch.device("meta"): | |
uninitialized_parameters.append(n) | |
if uninitialized_parameters: | |
raise RuntimeError( | |
f"found uninitialized parameters in model : {uninitialized_parameters}" | |
) | |
def decode(self, generated_ids: List[int]) -> str: | |
# Do not skip special tokens as they are used for custom parsing rules of the generated text | |
return self.tokenizer.decode( | |
generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False | |
) | |
class FlashSantacoderSharded(FlashSantacoder): | |
def __init__( | |
self, | |
model_id: str, | |
revision: Optional[str] = None, | |
quantize: Optional[str] = None, | |
trust_remote_code: bool = False, | |
): | |
self.process_group, rank, world_size = initialize_torch_distributed() | |
if torch.cuda.is_available(): | |
device = torch.device(f"cuda:{rank}") | |
dtype = torch.float16 | |
else: | |
raise NotImplementedError("FlashSantacoderSharded is only available on GPU") | |
tokenizer = AutoTokenizer.from_pretrained( | |
model_id, | |
revision=revision, | |
padding_side="left", | |
truncation_side="left", | |
trust_remote_code=trust_remote_code, | |
) | |
config = GPT2Config.from_pretrained( | |
model_id, | |
revision=revision, | |
) | |
torch.distributed.barrier(group=self.process_group) | |
filenames = weight_files(model_id, revision=revision, extension=".safetensors") | |
with init_empty_weights(): | |
model = FlashSantacoderForCausalLM(config, self.process_group) | |
torch.distributed.barrier(group=self.process_group) | |
self.load_weights( | |
model, | |
filenames, | |
quantize=quantize, | |
device=device, | |
dtype=dtype, | |
rank=rank, | |
world_size=world_size, | |
transpose=config.architectures[0].startswith("GPT2"), | |
) | |
torch.distributed.barrier(group=self.process_group) | |
##### hack-text-generation-inference ##### | |
self.grammar, self.stackstore = grammar_hack_init(tokenizer) | |
super(FlashCausalLM, self).__init__( | |
model=model.to(device), | |
tokenizer=tokenizer, | |
requires_padding=False, | |
dtype=dtype, | |
device=device, | |
rank=rank, | |
world_size=world_size, | |
) | |
@staticmethod | |
def load_weights( | |
model, | |
filenames: List[str], | |
quantize: Optional[str], | |
device: torch.device, | |
dtype: torch.dtype, | |
rank: int, | |
world_size: int, | |
transpose: bool, | |
): | |
for file in filenames: | |
with safe_open( | |
file, framework="pt", device=str(device) if quantize is None else "cpu" | |
) as f: | |
for key in f.keys(): | |
slice_ = f.get_slice(key) | |
layer_name = ".".join(key.split(".")[:4]) | |
# Fused qkv | |
if "q_attn.weight" in key or "kv_attn.weight" in key: | |
final_key = layer_name + ".c_attn.weight" | |
elif "q_attn.bias" in key or "kv_attn.bias" in key: | |
final_key = layer_name + ".c_attn.bias" | |
else: | |
final_key = key | |
module_name, param_name = final_key.rsplit(".", 1) | |
module = model.get_submodule(module_name) | |
if isinstance(module, TensorParallelColumnLinear): | |
dim = 1 if transpose and "weight" in param_name else 0 | |
size = slice_.get_shape()[dim] | |
block_size = size // world_size | |
start = rank * block_size | |
stop = (rank + 1) * block_size | |
tensor = ( | |
slice_[start:stop] if dim == 0 else slice_[:, start:stop] | |
) | |
elif isinstance(module, TensorParallelRowLinear): | |
if param_name == "weight": | |
dim = 0 if transpose else 1 | |
size = slice_.get_shape()[dim] | |
block_size = size // world_size | |
start = rank * block_size | |
stop = (rank + 1) * block_size | |
tensor = ( | |
slice_[start:stop] | |
if dim == 0 | |
else slice_[:, start:stop] | |
) | |
else: | |
tensor = slice_[:] | |
# XXX: Hack for Rowlinear to add the bias only once. | |
if rank != 0: | |
tensor = torch.zeros_like(tensor) | |
elif isinstance(module, TensorParallelEmbedding): | |
size = slice_.get_shape()[0] | |
block_size = size // world_size | |
start = rank * block_size | |
stop = (rank + 1) * block_size | |
tensor = slice_[start:stop] | |
elif key == "lm_head.weight" and model.transformer.tp_embeddings: | |
size = slice_.get_shape()[0] | |
block_size = size // world_size | |
start = rank * block_size | |
stop = (rank + 1) * block_size | |
tensor = slice_[start:stop] | |
else: | |
try: | |
tensor = slice_[:] | |
except: | |
tensor = f.get_tensor(key) | |
tensor = tensor.contiguous().to(dtype) | |
try: | |
current_parameter_tensor = module._parameters[param_name] | |
except KeyError: | |
current_parameter_tensor = None | |
if current_parameter_tensor is not None: | |
if transpose and ( | |
"c_fc.weight" in key | |
or "c_proj.weight" in key | |
or "q_attn.weight" in key | |
or "kv_attn.weight" in key | |
or "c_attn.weight" in key | |
): | |
# Transpose as we use nn.Linear instead of Conv1D
tensor = tensor.T | |
if current_parameter_tensor.device == torch.device("meta"): | |
# Init qkv | |
if "c_attn.weight" in final_key: | |
module._parameters[param_name] = tensor.new_empty( | |
( | |
model.transformer.head_size | |
* (model.transformer.num_heads + 2), | |
tensor.shape[1], | |
) | |
) | |
elif "c_attn.bias" in final_key: | |
module._parameters[param_name] = tensor.new_empty( | |
( | |
model.transformer.head_size | |
* (model.transformer.num_heads + 2) | |
) | |
) | |
# Copy to correct slice | |
if "q_attn" in key: | |
size = tensor.shape[0] | |
block_size = size // world_size | |
start = rank * block_size | |
stop = (rank + 1) * block_size | |
tensor = tensor[start:stop] | |
module._parameters[param_name][: tensor.shape[0]] = tensor | |
elif "kv_attn.weight" in key: | |
module._parameters[param_name][ | |
model.transformer.head_size | |
* model.transformer.num_heads : | |
] = tensor | |
elif "kv_attn.bias" in key: | |
module._parameters[param_name][ | |
model.transformer.head_size | |
* model.transformer.num_heads : | |
] = tensor | |
elif "c_attn" in key: | |
# Slice q_tensor by shard | |
q_tensor = tensor[: -2 * model.transformer.head_size] | |
block_size = q_tensor.shape[0] // world_size | |
start = rank * block_size | |
stop = (rank + 1) * block_size | |
q_tensor = q_tensor[start:stop] | |
module._parameters[param_name][ | |
: q_tensor.shape[0] | |
] = q_tensor | |
# Kv tensor is copied for every shard | |
kv_tensor = tensor[-2 * model.transformer.head_size :] | |
module._parameters[param_name][ | |
q_tensor.shape[0] : | |
] = kv_tensor | |
else: | |
if current_parameter_tensor.shape != tensor.shape: | |
raise ValueError( | |
f"Name {key} -- Current {current_parameter_tensor.shape} and got {tensor.shape}" | |
) | |
module._parameters[param_name] = tensor | |
else: | |
module._buffers[param_name] = tensor | |
if model.lm_head.weight.device == torch.device("meta"): | |
model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight) | |
torch.cuda.empty_cache() | |
model.post_load_weights(quantize) |
grammar_hack.py:
import collections

from torch_grammar import GrammarSampler


class LRUCache:
    """Tiny LRU cache holding per-request grammar stacks between generate_token calls."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.cache = collections.OrderedDict()

    def get(self, key):
        try:
            value = self.cache.pop(key)
            self.cache[key] = value
            return value
        except KeyError:
            # -1 is used as a "miss" sentinel by grammar_hack_prepare.
            return -1

    def set(self, key, value):
        try:
            self.cache.pop(key)
        except KeyError:
            if len(self.cache) >= self.capacity:
                self.cache.popitem(last=False)
        self.cache[key] = value


def grammar_hack_init(tokenizer):
    # Build a GrammarSampler from the user-supplied grammar, plus a stack store
    # keyed by request id.
    with open("/opt/grammar.ebnf", "r") as f:
        ebnf = f.read()
    grammar = GrammarSampler(ebnf, "root", tokenizer)
    stackstore = LRUCache(10000)
    return grammar, stackstore


def grammar_hack_prepare(requests_idx_mapping, stackstore, grammar):
    # Fetch (or initialize) the grammar stacks for each request, ordered by the
    # request's index in the batch tensors.
    idx_stacks = [None] * len(requests_idx_mapping)
    for request_id, idx in requests_idx_mapping.items():
        stacks = stackstore.get(request_id)
        if stacks == -1:
            stacks = grammar.init_stacks()
        idx_stacks[idx] = stacks
    return idx_stacks


def grammar_hack_accept_tokens(grammar, input_ids, idx_stacks):
    # Advance each request's stacks with the token it just generated (decode steps only).
    for idx, input_id in enumerate(input_ids):
        idx_stacks[idx] = grammar.accept_token(input_id, idx_stacks[idx])


def grammar_hack_commit(grammar, stackstore, idx_stacks, requests_idx_mapping, next_token_logits):
    # Mask logits that would leave the grammar, then persist the stacks for the next step.
    for idx, stack in enumerate(idx_stacks):
        grammar.filter_logits(next_token_logits[idx], stack, next_token_logits.device)
    for request_id, idx in requests_idx_mapping.items():
        stackstore.set(request_id, idx_stacks[idx])
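To make the control flow concrete, here is a minimal sketch of one decoding step using the hooks above, mirroring the call sites in flash_causal_lm.py. It assumes torch-grammar==0.3.3 and transformers are installed, a grammar.ebnf in the working directory, and a placeholder tokenizer and request id; none of those specifics come from the gist.

import torch
from transformers import AutoTokenizer
from torch_grammar import GrammarSampler
from text_generation_server.grammar_hack import (
    LRUCache,
    grammar_hack_prepare,
    grammar_hack_accept_tokens,
    grammar_hack_commit,
)

# Placeholder tokenizer; use whichever model the grammar was written for.
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")

# Equivalent to grammar_hack_init, minus the hardcoded /opt/grammar.ebnf path.
with open("grammar.ebnf") as f:
    grammar = GrammarSampler(f.read(), "root", tokenizer)
stackstore = LRUCache(10000)

# One in-flight request, tracked the way TGI does: request id -> index in the batch.
requests_idx_mapping = {42: 0}

# Fake logits for a single step; in the server these come from the model forward.
next_token_logits = torch.zeros((1, len(tokenizer)))

# Per step: fetch stacks, (on decode steps) advance them with the last generated
# token, then mask the logits so only grammar-legal tokens survive, and persist.
idx_stacks = grammar_hack_prepare(requests_idx_mapping, stackstore, grammar)
# On a decode step, batch.input_ids holds each request's last token:
# grammar_hack_accept_tokens(grammar, last_token_ids, idx_stacks)
grammar_hack_commit(grammar, stackstore, idx_stacks, requests_idx_mapping, next_token_logits)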