Created
October 10, 2023 06:07
-
-
Save GuyARoss/cd449590aebedce8591c0282da83d414 to your computer and use it in GitHub Desktop.
worst program ever
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import tiktoken | |
import pinecone | |
import os | |
import uuid | |
import re | |
import openai | |
# --- API configuration -------------------------------------------------
# NOTE(review): keys are hard-coded placeholders; load real secrets from
# the environment instead of committing them.
openai.api_key = "#TODO"

# Tokenizer matching the chat model used below.
encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")

# Context-window budget for the chat model, and the fixed vector length
# every encoding is padded/truncated to before hitting Pinecone.
MAX_TOKENS = 4096
MAX_VEC_CTX = 20

# Pinecone vector store (index name is also a placeholder).
API_KEY = "#TODO pinecone"
pinecone.init(api_key=API_KEY, environment="us-east4-gcp")
pinecode_index = pinecone.Index(index_name="#todo")
def pad(vec):
    """Clamp *vec* to exactly MAX_VEC_CTX entries.

    Longer vectors are truncated; shorter ones are right-padded with
    zeros so every vector stored in / queried against the index has the
    same dimensionality.
    """
    # The original had a separate `>= MAX_VEC_CTX: return vec` branch
    # after the `>` check; it could only match on equality, where the
    # slice below is a no-op, so one `>=` branch covers both cases.
    if len(vec) >= MAX_VEC_CTX:
        return vec[:MAX_VEC_CTX]
    return vec + [0] * (MAX_VEC_CTX - len(vec))
def split_sentences(st):
    """Split *st* into sentences on '.', '?' or '!' (plus trailing whitespace).

    The empty trailing piece that re.split produces when the text ends
    with punctuation is dropped, so callers only see real sentences.
    """
    pieces = re.split(r'[.?!]\s*', st)
    return pieces if pieces[-1] else pieces[:-1]
def insert(enc, fr, content=None):
    """Upsert one encoded sentence into the Pinecone index.

    Args:
        enc: token-id vector (already padded) used as the embedding.
        fr: origin tag stored in metadata ("user" or "ai").
        content: text stored alongside the vector. Defaults to the
            module-level ``prompt`` global, which is what the original
            code (implicitly) stored — callers should pass the actual
            sentence explicitly, since the global holds the whole CLI
            prompt rather than the sentence being inserted.
    """
    if content is None:
        # Backward-compatible fallback to the global the original read.
        content = prompt
    # Random UUID as the vector id; renamed so it no longer shadows the
    # ``id`` builtin.
    vector_id = str(uuid.uuid4())
    pinecode_index.upsert(
        vectors=[
            (
                vector_id,
                enc,
                {
                    "prompt": content,
                    "from": fr,
                    "len": len(enc),
                },
            ),
        ],
    )
def completion(prompt, prev) -> str:
    """Ask gpt-3.5-turbo to answer *prompt*, seeding the system message
    with *prev* (retrieved background context; may be empty)."""
    # Local renamed from ``completion`` so it no longer shadows the
    # enclosing function's own name.
    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "You are an assistant." + prev},
            {"role": "user", "content": prompt},
        ],
    )
    # Only the first candidate message is used.
    return response.choices[0].message.content
def main(prompt: str):
    """Answer *prompt* using context retrieved from the vector index,
    print the answer, then store both sides of the exchange back into
    the index one sentence at a time."""
    prompt_encoding = encoding.encode(prompt)

    # Retrieve the 5 nearest stored sentences to use as background.
    search_response = pinecode_index.query(
        vector=pad(prompt_encoding),
        include_metadata=True,
        top_k=5,
    )

    # Reserve half the context window for the reply; whatever remains
    # after the prompt itself may be filled with retrieved sentences.
    # (// instead of / so the budget stays an int — same value.)
    left_over = MAX_TOKENS - (len(prompt_encoding) + MAX_TOKENS // 2)

    prev_statement = "background info:"
    matches = search_response.to_dict().get('matches', [])
    for m in matches:
        meta = m['metadata']
        if meta['len'] > left_over:
            # Skip sentences that would overflow the token budget.
            continue
        left_over -= meta['len']
        prev_statement += f"{meta['from']}: {meta['prompt']};"

    # BUG FIX: the original compared against "prev_statement:", a string
    # that is never produced, so an empty retrieval still sent the bare
    # "background info:" prefix to the model. Compare against the actual
    # sentinel so no-context runs use an empty system suffix.
    if prev_statement == "background info:":
        prev_statement = ""

    output = completion(prompt, prev_statement)
    print(output)

    # Persist the user prompt then the AI reply, sentence by sentence,
    # skipping fragments of 2 tokens or fewer.
    for origin, text in (("user", prompt), ("ai", output)):
        for sentence in split_sentences(text):
            enc = encoding.encode(sentence)
            if len(enc) > 2:
                insert(pad(enc), origin)
if __name__ == "__main__":
    # Join every CLI argument into a single prompt string. The module
    # global ``prompt`` must be assigned (not inlined into the call):
    # insert() reads it as a global.
    prompt = " ".join(sys.argv[1:])
    main(prompt)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment