simple_retraining.py
from unsloth import FastLanguageModel, is_bfloat16_supported
from unsloth.chat_templates import get_chat_template, standardize_sharegpt, train_on_responses_only
from datasets import Dataset
from trl import SFTTrainer
from transformers import TrainingArguments, DataCollatorForSeq2Seq, TextStreamer
import torch
import wandb
from datetime import datetime
import json

# 1. Configuration
config = {
    "model_name": "unsloth/Phi-4",
    "max_seq_length": 2048,
    "load_in_4bit": True,
    "lora_r": 16,
    "lora_alpha": 16,
    "lora_dropout": 0,
    "learning_rate": 2e-4,
    "batch_size": 2,
    "gradient_accumulation_steps": 4,
    "num_train_epochs": 1,
    "logging_steps": 10,
}
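
# With these settings the effective batch size per optimizer step is
# batch_size * gradient_accumulation_steps = 2 * 4 = 8 examples.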

# 2. Initialize wandb
wandb.init(project="regulatory-finetuning", config=config)

# 3. Model Initialization
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=config["model_name"],
    max_seq_length=config["max_seq_length"],
    load_in_4bit=config["load_in_4bit"],
    device_map="auto",
)

# 4. LoRA configuration
model = FastLanguageModel.get_peft_model(
    model,
    r=config["lora_r"],
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_alpha=config["lora_alpha"],
    lora_dropout=config["lora_dropout"],
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    use_rslora=False,
    loftq_config=None,
)
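
# Note: with lora_alpha == lora_r == 16, the LoRA scaling factor (alpha / r)
# is 1.0, so adapter updates are applied without additional up- or down-scaling.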

# 5. Configure chat template
tokenizer = get_chat_template(
    tokenizer,
    chat_template="phi-4",
)
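
# For orientation, the phi-4 template wraps each turn in the same special tokens
# that train_on_responses_only splits on below, roughly (illustrative sketch,
# not an exact template dump):
#   <|im_start|>user<|im_sep|>...question...<|im_end|>
#   <|im_start|>assistant<|im_sep|>...answer...<|im_end|>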

# 6. Data preparation
def load_jsonl_data(file_path):
    data = []
    with open(file_path, 'r', encoding='utf-8') as file:
        for line_num, line in enumerate(file, 1):
            try:
                line = line.strip()
                if not line:  # Skip empty lines
                    continue
                item = json.loads(line)
                if "text_input" not in item or "output" not in item:
                    print(f"Missing required fields in line {line_num}")
                    continue
                conversation = [
                    {"role": "user", "content": item["text_input"]},
                    {"role": "assistant", "content": item["output"]}
                ]
                data.append({"conversations": conversation})
            except json.JSONDecodeError as e:
                print(f"Error parsing line {line_num} in JSONL: {e}")
                continue
            except Exception as e:
                print(f"Unexpected error at line {line_num}: {str(e)}")
                continue
    if not data:
        raise ValueError("No valid data was loaded from the JSONL file")
    print(f"Successfully loaded {len(data)} valid conversations")
    return Dataset.from_list(data)
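
# Each line of training.jsonl is expected to be a standalone JSON object with
# "text_input" and "output" fields, e.g. (hypothetical record):
#   {"text_input": "What does the disclosure rule require?", "output": "It requires ..."}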

def formatting_prompts_func(examples):
    convos = examples["conversations"]
    texts = [
        tokenizer.apply_chat_template(
            convo,
            tokenize=False,
            add_generation_prompt=False
        )
        for convo in convos
    ]
    return {"text": texts}

# 7. Load and process dataset
dataset = load_jsonl_data("training.jsonl")
dataset = standardize_sharegpt(dataset)
dataset = dataset.map(formatting_prompts_func, batched=True)
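
# At this point each example carries a rendered "text" column, which is the
# field SFTTrainer consumes via dataset_text_field below.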

# 8. Training configuration
training_args = TrainingArguments(
    output_dir="outputs",
    per_device_train_batch_size=config["batch_size"],
    gradient_accumulation_steps=config["gradient_accumulation_steps"],
    warmup_steps=5,
    num_train_epochs=config["num_train_epochs"],
    learning_rate=config["learning_rate"],
    fp16=not is_bfloat16_supported(),
    bf16=is_bfloat16_supported(),
    logging_steps=config["logging_steps"],
    optim="adamw_8bit",
    weight_decay=0.01,
    lr_scheduler_type="linear",
    seed=3407,
    report_to="wandb",
)
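
# fp16 and bf16 are mutually exclusive here: bf16 is used when the GPU supports
# it (e.g. Ampere or newer), otherwise training falls back to fp16.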

# 9. Initialize trainer
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=config["max_seq_length"],
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
    dataset_num_proc=2,
    packing=False,
    args=training_args,
)

# 10. Apply response-only training
trainer = train_on_responses_only(
    trainer,
    instruction_part="<|im_start|>user<|im_sep|>",
    response_part="<|im_start|>assistant<|im_sep|>",
)
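
# train_on_responses_only masks the loss over the user/instruction tokens so the
# model is optimized only on the assistant responses delimited above.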

# 11. Training with memory tracking
print("Current GPU memory stats:")
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

try:
    trainer_stats = trainer.train()
    print("Final memory and time stats:")
    used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
    used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
    used_percentage = round(used_memory / max_memory * 100, 3)
    lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
    print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
    print(f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.")
    print(f"Peak reserved memory = {used_memory} GB.")
    print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
    print(f"Peak reserved memory % of max memory = {used_percentage} %.")
    print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")
except Exception as e:
    print(f"An error occurred during training: {e}")
    wandb.log({"training_error": str(e)})
    raise

# 12. Save model
model_name = f"regulatory-model-{datetime.now().strftime('%Y%m%d_%H%M%S')}"
save_path = f"lora_model/{model_name}"
try:
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)
    print(f"Model saved successfully to {save_path}")
except Exception as e:
    print(f"Error saving model: {e}")
    wandb.log({"saving_error": str(e)})
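
# The saved adapter can be reloaded later for inference by pointing
# FastLanguageModel.from_pretrained at the save directory (sketch, assuming the
# save above succeeded; guarded so it does not run in this script):
if False:
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=save_path,
        max_seq_length=config["max_seq_length"],
        load_in_4bit=config["load_in_4bit"],
    )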

# 13. Enable fast inference and test
FastLanguageModel.for_inference(model)
test_messages = [
    {"role": "user", "content": "What are the key regulatory requirements for financial institutions?"},
]
inputs = tokenizer.apply_chat_template(
    test_messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt"
).to("cuda")
text_streamer = TextStreamer(tokenizer, skip_prompt=True)
_ = model.generate(
    input_ids=inputs,
    streamer=text_streamer,
    max_new_tokens=128,
    use_cache=True,
    temperature=1.5,
    min_p=0.1
)

wandb.finish()