@bigsnarfdude
Last active January 13, 2025 01:10
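Interactive command-line tool that queries a LoRA fine-tuned Phi-4 model (loaded in 4-bit via Unsloth) and maps regulatory intent statements to regulatory section numbers.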
inference.py
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template
from peft import PeftModel
import re

def load_model(model_path):
    """Load the fine-tuned LoRA model and tokenizer."""
    # Load the 4-bit quantized base model and its tokenizer
    base_model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="unsloth/Phi-4",
        max_seq_length=512,
        load_in_4bit=True,
        device_map="auto",
    )

    # Attach the trained LoRA adapter weights to the base model; the adapter
    # configuration (rank, target modules, etc.) is read from the saved
    # adapter directory
    model = PeftModel.from_pretrained(
        base_model,
        model_path,
        is_trainable=False,
    )

    # Apply the Phi-4 chat template to the tokenizer
    tokenizer = get_chat_template(tokenizer, chat_template="phi-4")

    # Switch the model to Unsloth's optimized inference mode
    FastLanguageModel.for_inference(model)

    return model, tokenizer
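
# Note: if the adapter directory was produced by save_pretrained(), Unsloth can
# usually load the base model and adapter in a single call as well. A minimal
# sketch, assuming the saved adapter config records the base model name:
#     model, tokenizer = FastLanguageModel.from_pretrained(
#         model_name=model_path, max_seq_length=512, load_in_4bit=True
#     )
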
def get_regulatory_section(model, tokenizer, intent_text):
    """Get the regulatory section for a single intent statement."""
    # Format the input with an explicit instruction
    messages = [
        {"role": "system", "content": "You are a regulatory matching system. Given an intent statement, respond only with the matching regulatory section number in the format: Section XXX.XX"},
        {"role": "user", "content": intent_text},
    ]

    # Tokenize, returning both input_ids and the attention mask
    encoded = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    # Generate a short, near-deterministic response
    outputs = model.generate(
        **encoded,
        max_new_tokens=16,
        do_sample=True,
        temperature=0.1,
        min_p=0.9,
        pad_token_id=tokenizer.eos_token_id,
        repetition_penalty=1.2,
    )

    # Decode only the newly generated tokens, skipping the prompt
    new_tokens = outputs[0][encoded["input_ids"].shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # Extract the section number (expected format: "Section XXX.XX")
    section_match = re.search(r"Section (\d+\.\d+)", response)
    if section_match:
        return section_match.group(1)
    return "No valid section number found in response: " + response.strip()

def main():
    # Path to your saved model
    MODEL_PATH = "models/regulatory-matcher-20250112_223223"  # Update this to your model path

    # Load the model
    print("Loading model...")
    try:
        model, tokenizer = load_model(MODEL_PATH)
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        return

    # Interactive loop
    print("\nEnter regulatory intent statements (type 'quit' to exit):")
    while True:
        # Get input from the user
        intent = input("\nEnter intent statement: ").strip()

        # Check for the quit command
        if intent.lower() == "quit":
            break

        if intent:
            print("\nAnalyzing...")
            try:
                section = get_regulatory_section(model, tokenizer, intent)
                print(f"Regulatory section: {section}")
            except Exception as e:
                print(f"Error processing intent: {e}")
        else:
            print("Please enter a valid intent statement.")


if __name__ == "__main__":
    main()
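
Example session (illustrative only: MODEL_PATH must point at your trained adapter, and the intent statement and section number below are hypothetical):

$ python inference.py
Loading model...
Model loaded successfully!

Enter regulatory intent statements (type 'quit' to exit):

Enter intent statement: The operator must keep records of all maintenance performed.
Analyzing...
Regulatory section: 121.380

Enter intent statement: quit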