inference.py
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template
from peft import PeftModel
import re
def load_model(model_path):
    """Load the fine-tuned LoRA model and tokenizer"""
    # Initialize base model and tokenizer
    base_model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="unsloth/Phi-4",
        max_seq_length=512,
        load_in_4bit=True,
        device_map="auto"
    )
    # Load the trained LoRA weights directly onto the base model.
    # (Calling get_peft_model() first and then PeftModel.from_pretrained()
    # would stack a fresh, untrained adapter alongside the saved one; the
    # saved adapter already carries its r/alpha/target_modules config.)
    model = PeftModel.from_pretrained(
        base_model,
        model_path,
        is_trainable=False
    )
    tokenizer = get_chat_template(tokenizer, chat_template="phi-4")
    # Set model to inference mode (enables Unsloth's fast inference path)
    FastLanguageModel.for_inference(model)
    return model, tokenizer
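

# Alternative loader, as a sketch rather than part of the original flow:
# Unsloth can usually load a saved LoRA adapter directory in one call,
# resolving the base model from adapter_config.json. Whether that applies
# to this particular checkpoint is an assumption; load_model() above is
# the tested path.
def load_model_direct(model_path):
    """One-step adapter load via Unsloth (hedged alternative)."""
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_path,  # adapter dir; base model taken from its config
        max_seq_length=512,
        load_in_4bit=True,
    )
    tokenizer = get_chat_template(tokenizer, chat_template="phi-4")
    FastLanguageModel.for_inference(model)
    return model, tokenizer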
def get_regulatory_section(model, tokenizer, intent_text):
    """Get regulatory section for a single intent statement"""
    # Format the input with an explicit instruction
    messages = [
        {"role": "system", "content": "You are a regulatory matching system. Given an intent statement, respond only with the matching regulatory section number in the format: Section XXX.XX"},
        {"role": "user", "content": intent_text}
    ]
    # Tokenize with the phi-4 chat template; return_dict=True also returns
    # the attention mask, which generate() expects alongside the input ids
    encoded = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt"
    ).to(model.device)
    # Generate response; do_sample=True is required for temperature/min_p
    # to take effect (min_p this high keeps only near-top tokens, so the
    # output is close to greedy)
    outputs = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_new_tokens=16,
        do_sample=True,
        temperature=0.1,
        min_p=0.9,
        pad_token_id=tokenizer.eos_token_id,
        repetition_penalty=1.2
    )
    # Decode only the newly generated tokens, not the echoed prompt
    prompt_len = encoded["input_ids"].shape[-1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    # Extract the section number (expected format "Section XXX.XX")
    section_match = re.search(r'Section (\d+\.\d+)', response)
    if section_match:
        return section_match.group(1)
    return "No valid section number found in response: " + response.strip()
def main():
    # Path to your saved model
    MODEL_PATH = "models/regulatory-matcher-20250112_223223"  # Update this to your model path
    # Load model
    print("Loading model...")
    try:
        model, tokenizer = load_model(MODEL_PATH)
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        return
    # Interactive loop
    print("\nEnter regulatory intent statements (type 'quit' to exit):")
    while True:
        # Get input from user
        intent = input("\nEnter intent statement: ").strip()
        # Check for quit command
        if intent.lower() == 'quit':
            break
        if intent:
            print("\nAnalyzing...")
            try:
                section = get_regulatory_section(model, tokenizer, intent)
                print(f"Regulatory section: {section}")
            except Exception as e:
                print(f"Error processing intent: {e}")
        else:
            print("Please enter a valid intent statement.")


if __name__ == "__main__":
    main()
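
# Example non-interactive usage (hypothetical session; the path and the
# returned value are placeholders showing the expected shape, not real output):
#   model, tokenizer = load_model("models/regulatory-matcher-20250112_223223")
#   get_regulatory_section(model, tokenizer, "Operators must report incidents within 24 hours")
#   -> "101.23"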