Hackergame 2023 "🪐 小型大语言模型星球" (Small Large Language Model Planet) solution code
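Two standalone scripts, both run against a local copy of the challenge's TinyStories-33M model. The first brute-forces every single-token prompt in the vocabulary until the model's completion contains "accepted". The second makes the model emit the emoji 🐮 by gradient-optimizing a 30-token prompt in embedding space.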
import torch
import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# The original pulled model/tokenizer in via `from file import *`; loading them
# explicitly here (mirroring the second script) makes this file self-contained.
model = AutoModelForCausalLM.from_pretrained("./TinyStories-33M").eval()
tokenizer = AutoTokenizer.from_pretrained("./TinyStories-33M")
model.cuda()

# Try every single-token prompt, 64 at a time, and look for a greedy
# completion containing "accepted"/"Accepted".
for i in tqdm.tqdm(range(0, tokenizer.vocab_size, 64)):
    # Clamp the final batch so token ids never run past the vocabulary.
    batch = [[j] for j in range(i, min(i + 64, tokenizer.vocab_size))]
    model_inputs = torch.tensor(batch).cuda()
    model_outputs = model.generate(
        model_inputs,
        max_new_tokens=30,
        num_beams=1,
        pad_token_id=tokenizer.eos_token_id,
    )
    for input_, output in zip(model_inputs, model_outputs):
        output = output[len(input_):]  # keep only the newly generated tokens
        output = tokenizer.decode(output, skip_special_tokens=True)
        if 'accepted' in output or 'Accepted' in output:
            print(output)
            print(tokenizer.decode(input_))  # the winning one-token prompt
            break
    else:
        continue  # no hit in this batch; keep scanning
    break  # propagate the inner break and stop the search
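The second script targets a fixed output string, the emoji 🐮. It trains a 30-token prompt with Adam directly in the model's embedding space; on every step, each soft embedding is snapped to its nearest real token embedding by cosine similarity (tokens containing non-printable characters are excluded), using a straight-through estimator so the decoded prompt is always a genuine printable string that can be submitted.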
import torch
import tqdm
from torch import nn
from torch.nn import functional as F
from torch.optim import Adam
from string import printable
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("./TinyStories-33M").eval()
tokenizer = AutoTokenizer.from_pretrained("./TinyStories-33M")

# Collect token ids whose decoded text contains anything outside printable
# ASCII (printable[:-5] drops '\t\n\r\x0b\x0c' but keeps the space);
# these tokens are banned from the optimized prompt.
bad_tokens = []
for i in tqdm.tqdm(tokenizer.vocab.values()):
    s = tokenizer.decode([i])
    if not set(s).issubset(printable[:-5]):
        bad_tokens.append(i)

def get_closest_embedding(input_embedding, embedding, target):
    """Snap each soft prompt embedding to its nearest real token embedding
    (by cosine similarity), with a straight-through gradient."""
    embedding_weight = embedding.weight
    norm_embedding = F.normalize(embedding_weight, p=2, dim=1)
    norm_input_embedding = F.normalize(input_embedding, p=2, dim=1)
    # Embed the target prefix (all but its last token): the model must
    # predict each target token from the position before it.
    target_embedding = embedding(target[:, :-1])
    cosine_sim_mat = torch.mm(norm_input_embedding, norm_embedding.t())
    cosine_sim_mat[:, bad_tokens] = -1  # never pick a non-printable token
    chosen_idx = torch.argmax(cosine_sim_mat, dim=1)
    closest_embeddings = embedding_weight[chosen_idx]
    # Straight-through estimator: the forward pass uses the quantized
    # embeddings, the backward pass flows through input_embedding.
    closest_embeddings = input_embedding + (closest_embeddings - input_embedding).detach()
    return torch.cat([closest_embeddings[None], target_embedding], dim=1), chosen_idx

def predict(message):
    """Greedy 30-token completion for a candidate prompt."""
    model_inputs = tokenizer.encode(message, return_tensors="pt").cuda()
    model_outputs = model.generate(
        model_inputs,
        max_new_tokens=30,
        num_beams=1,
        pad_token_id=tokenizer.eos_token_id,
    )
    model_outputs = model_outputs[0, len(model_inputs[0]):]
    model_output_text = tokenizer.decode(model_outputs, skip_special_tokens=True)
    return model_output_text

s = '🐮'                  # the string the model must be made to emit
token_length = 30         # number of prompt tokens to optimize
max_message_length = 200  # maximum allowed length for the submitted prompt
model.cuda()
loss_fcn = nn.CrossEntropyLoss()
target = tokenizer.encode(s, return_tensors="pt").cuda()
# Create the trainable prompt embeddings (768 = the model's embedding width)
# directly on the GPU so the tensor stays a leaf and Adam accepts it; calling
# .cuda() on a leaf with requires_grad=True would yield a non-leaf tensor.
embedding_to_train = torch.randn(size=(token_length, 768), device='cuda', requires_grad=True)
optim = Adam(params=[embedding_to_train], lr=5e-2)

while True:
    embedding, idx = get_closest_embedding(embedding_to_train, model.transformer.wte, target)
    logits = model(inputs_embeds=embedding).logits
    # CrossEntropyLoss expects (batch, classes, seq); score the positions
    # whose next-token predictions should be the target tokens.
    loss = loss_fcn(logits[:, -len(target[0]):].swapaxes(1, 2), target)
    optim.zero_grad()
    loss.backward()
    optim.step()
    message = tokenizer.decode(idx, skip_special_tokens=True)
    prediction = predict(message)
    print(f"{loss.item()}, {prediction}")
    if s in prediction and len(message) <= max_message_length:
        print('Success!')
        print(f"Message: '{message}'")
        break
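The load-bearing trick above is the straight-through estimator inside get_closest_embedding: the forward pass uses the snapped, real-token embeddings, while the backward pass treats the snap as the identity, so gradients still reach the trainable embeddings even though argmax is not differentiable. A minimal self-contained sketch of the same pattern (all names here are illustrative, not part of the solution):

import torch

codebook = torch.randn(10, 4)                 # stands in for model.transformer.wte.weight
soft = torch.randn(3, 4, requires_grad=True)  # trainable "soft prompt" embeddings

idx = torch.cdist(soft, codebook).argmin(dim=1)  # nearest real embedding per position
hard = codebook[idx]

# Forward pass sees `hard`; backward pass pretends the snap was the identity.
ste = soft + (hard - soft).detach()

loss = ste.sum()
loss.backward()
print(soft.grad)  # non-zero: gradients flowed "through" the quantization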