simple_retraining.py
from unsloth import FastLanguageModel, is_bfloat16_supported
from unsloth.chat_templates import get_chat_template, standardize_sharegpt, train_on_responses_only
from datasets import Dataset
from trl import SFTTrainer
from transformers import TrainingArguments, DataCollatorForSeq2Seq, TextStreamer
import torch
import wandb
from datetime import datetime
import json
# 1. Configuration
config = {
    "model_name": "unsloth/Phi-4",
    "max_seq_length": 2048,
    "load_in_4bit": True,
    "lora_r": 16,
    "lora_alpha": 16,
    "lora_dropout": 0,
    "learning_rate": 2e-4,
    "batch_size": 2,
    "gradient_accumulation_steps": 4,
    "num_train_epochs": 1,
    "logging_steps": 10,
}
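# These values mirror common QLoRA-style defaults (4-bit base model, r=16 LoRA,
# 2e-4 learning rate); tune them for your own dataset and hardware.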
# 2. Initialize wandb
wandb.init(project="regulatory-finetuning", config=config)
# 3. Model Initialization
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=config["model_name"],
    max_seq_length=config["max_seq_length"],
    load_in_4bit=config["load_in_4bit"],
    device_map="auto",
)
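# load_in_4bit quantizes the frozen base weights (QLoRA-style) to reduce VRAM;
# device_map="auto" places the model's layers on the available GPU(s) automatically.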
# 4. LoRA configuration
model = FastLanguageModel.get_peft_model(
    model,
    r=config["lora_r"],
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_alpha=config["lora_alpha"],
    lora_dropout=config["lora_dropout"],
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    use_rslora=False,
    loftq_config=None,
)
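# Only the LoRA adapter matrices injected into the listed attention and MLP
# projections are trainable; the quantized base model stays frozen.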
# 5. Configure chat template
tokenizer = get_chat_template(
    tokenizer,
    chat_template="phi-4",
)
# 6. Data preparation
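# training.jsonl is expected to hold one JSON object per line with
# "text_input" and "output" fields, e.g. (illustrative line):
# {"text_input": "What does the reporting rule require?", "output": "It requires ..."}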
def load_jsonl_data(file_path):
    data = []
    with open(file_path, 'r', encoding='utf-8') as file:
        for line_num, line in enumerate(file, 1):
            try:
                line = line.strip()
                if not line:  # Skip empty lines
                    continue
                item = json.loads(line)
                if "text_input" not in item or "output" not in item:
                    print(f"Missing required fields in line {line_num}")
                    continue
                conversation = [
                    {"role": "user", "content": item["text_input"]},
                    {"role": "assistant", "content": item["output"]},
                ]
                data.append({"conversations": conversation})
            except json.JSONDecodeError as e:
                print(f"Error parsing line {line_num} in JSONL: {e}")
                continue
            except Exception as e:
                print(f"Unexpected error at line {line_num}: {str(e)}")
                continue
    if not data:
        raise ValueError("No valid data was loaded from the JSONL file")
    print(f"Successfully loaded {len(data)} valid conversations")
    return Dataset.from_list(data)
def formatting_prompts_func(examples):
    convos = examples["conversations"]
    texts = [
        tokenizer.apply_chat_template(
            convo,
            tokenize=False,
            add_generation_prompt=False,
        )
        for convo in convos
    ]
    return {"text": texts}
# 7. Load and process dataset
dataset = load_jsonl_data("training.jsonl")
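# standardize_sharegpt normalizes ShareGPT-style keys ("from"/"value") into
# "role"/"content"; the loader above already emits role/content dicts, so this
# step should effectively be a pass-through here.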
dataset = standardize_sharegpt(dataset)
dataset = dataset.map(formatting_prompts_func, batched=True)
# 8. Training configuration
training_args = TrainingArguments(
    output_dir="outputs",
    per_device_train_batch_size=config["batch_size"],
    gradient_accumulation_steps=config["gradient_accumulation_steps"],
    warmup_steps=5,
    num_train_epochs=config["num_train_epochs"],
    learning_rate=config["learning_rate"],
    fp16=not is_bfloat16_supported(),
    bf16=is_bfloat16_supported(),
    logging_steps=config["logging_steps"],
    optim="adamw_8bit",
    weight_decay=0.01,
    lr_scheduler_type="linear",
    seed=3407,
    report_to="wandb",
)
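# Effective batch size = per_device_train_batch_size (2) x gradient_accumulation_steps (4) = 8.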
# 9. Initialize trainer
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=config["max_seq_length"],
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
    dataset_num_proc=2,
    packing=False,
    args=training_args,
)
# 10. Apply response-only training
trainer = train_on_responses_only(
    trainer,
    instruction_part="<|im_start|>user<|im_sep|>",
    response_part="<|im_start|>assistant<|im_sep|>",
)
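# Masks the prompt portion of each example (everything through the user turn,
# delimited by Phi-4's <|im_start|>/<|im_sep|> markers) so the loss is computed
# only on the assistant's response tokens.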
# 11. Training with memory tracking
print("Show current memory stats")
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")
try:
    trainer_stats = trainer.train()
    print("Final memory and time stats:")
    used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
    used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
    used_percentage = round(used_memory / max_memory * 100, 3)
    lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
    print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
    print(f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.")
    print(f"Peak reserved memory = {used_memory} GB.")
    print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
    print(f"Peak reserved memory % of max memory = {used_percentage} %.")
    print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")
except Exception as e:
    print(f"An error occurred during training: {e}")
    wandb.log({"training_error": str(e)})
    raise
# 12. Save model
model_name = f"regulatory-model-{datetime.now().strftime('%Y%m%d_%H%M%S')}"
save_path = f"lora_model/{model_name}"
try:
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)
    print(f"Model saved successfully to {save_path}")
except Exception as e:
    print(f"Error saving model: {e}")
    wandb.log({"saving_error": str(e)})
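# Note: save_pretrained on a PEFT model writes only the LoRA adapter weights;
# Unsloth also provides merged-save helpers (e.g. save_pretrained_merged) if a
# standalone merged checkpoint is needed.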
# 13. Enable fast inference and test
FastLanguageModel.for_inference(model)
test_messages = [
    {"role": "user", "content": "What are the key regulatory requirements for financial institutions?"},
]
inputs = tokenizer.apply_chat_template(
    test_messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
).to("cuda")
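# add_generation_prompt=True appends the assistant header so the model
# continues the conversation as the assistant.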
text_streamer = TextStreamer(tokenizer, skip_prompt=True)
_ = model.generate(
    input_ids=inputs,
    streamer=text_streamer,
    max_new_tokens=128,
    use_cache=True,
    temperature=1.5,
    min_p=0.1,
)
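# temperature=1.5 with min_p=0.1 samples broadly while discarding low-probability
# tokens; lower the temperature for more deterministic answers.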
wandb.finish()