from flask import Flask, request, jsonify
import logging

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the CodeParrot model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("codeparrot/codeparrot-small")
model = AutoModelForCausalLM.from_pretrained("codeparrot/codeparrot-small")

# Run on GPU when available, otherwise fall back to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

app = Flask(__name__)


@app.route('/generate/', methods=['POST'])
def generate():
    data = request.get_json()
    logger.info('data: %s', data)

    input_text = data['inputs']

    # Tokenize the prompt and move the tensors to the same device as the model
    inputs = tokenizer(input_text, return_tensors="pt").to(device)

    # Generate a completion and decode it back to text
    outputs = model.generate(inputs["input_ids"], max_length=50)
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

    response = {"response": generated_text}
    logger.info('response: %s', response)
    return jsonify(response)


if __name__ == '__main__':
    # Optionally expose the server through an ngrok tunnel:
    # from pyngrok import ngrok
    # public_url = ngrok.connect(5000)
    # print("Public URL:", public_url)
    app.run(host='0.0.0.0', port=5000)