Skip to content

Instantly share code, notes, and snippets.

@ahmadasjad
Last active November 4, 2024 17:57
Show Gist options
  • Save ahmadasjad/35446c46e8e127a90b21501c2ce95d64 to your computer and use it in GitHub Desktop.
from flask import Flask, request, jsonify
import os
import logging

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Module-level logger (PEP 282 convention) rather than the root logger,
# so log records carry this module's name.
logger = logging.getLogger(__name__)

# Load the CodeParrot model and tokenizer once at import time so every
# request reuses the same in-memory model.
tokenizer = AutoTokenizer.from_pretrained("codeparrot/codeparrot-small")
model = AutoModelForCausalLM.from_pretrained("codeparrot/codeparrot-small")

# Run on GPU when available; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

app = Flask(__name__)
@app.route('/generate/', methods=['POST'])
def generate():
    """Generate a code completion for the prompt in the JSON request body.

    Expects a JSON payload of the form ``{"inputs": "<prompt text>"}``.
    Returns ``{"response": "<generated text>"}`` on success, or a 400
    error response when the ``inputs`` field is missing.
    """
    data = request.get_json()
    # Lazy %-style args so formatting only happens if INFO is enabled.
    logger.info('request payload: %s', data)

    if not data or 'inputs' not in data:
        # Fail fast with a client error instead of letting a KeyError
        # surface as an opaque 500.
        return jsonify({"error": "JSON body must contain an 'inputs' field"}), 400

    input_text = data['inputs']

    # Tokenize and move the tensors to the same device as the model.
    inputs = tokenizer(input_text, return_tensors="pt").to(device)

    # Pass the attention mask explicitly so generation respects it and
    # transformers does not warn about an inferred mask.
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=50,
    )
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    response = {"response": generated_text}
    logger.info('response: %s', response)
    return jsonify(response)
# To run the server locally, uncomment:
# if __name__ == '__main__':
#     app.run(host='0.0.0.0', port=5000)
#
# To expose the server publicly (e.g. from a notebook), also uncomment:
# from pyngrok import ngrok
# public_url = ngrok.connect(5000)
# print("Public URL:", public_url)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment