CLI based embedding search using OpenAI. This gist (spullara/0038469eebefe27f8e9835fedb14363b) contains three scripts plus a requirements file: the first chunks crawled pages from content.json and embeds each chunk with text-embedding-ada-002, the second builds a FAISS index over those embeddings with autofaiss, and the third is an interactive loop that embeds a query, retrieves the nearest chunks, and has gpt-3.5-turbo answer from them.
import openai
import tiktoken
import os

from openai.embeddings_utils import get_embedding

openai.organization = os.environ.get("OPENAI_ORG")
openai.api_key = os.environ.get("OPENAI_API_KEY")

# Global variables
tokenizer = tiktoken.get_encoding(
    "cl100k_base"
)  # The encoding scheme to use for tokenization
def num_tokens_from_string(string: str, encoding_name: str) -> int:
    """Returns the number of tokens in a text string."""
    encoding = tiktoken.get_encoding(encoding_name)
    num_tokens = len(encoding.encode(string))
    return num_tokens
# Constants
CHUNK_SIZE = 512  # The target size of each text chunk in tokens
MIN_CHUNK_SIZE_CHARS = 350  # The minimum size of each text chunk in characters
MIN_CHUNK_LENGTH_TO_EMBED = 5  # Discard chunks shorter than this
MAX_NUM_CHUNKS = 1000  # The maximum number of chunks to generate

# From https://github.com/openai/chatgpt-retrieval-plugin/blob/main/services/chunks.py
def get_text_chunks(text: str):
    """
    Split a text into chunks of ~CHUNK_SIZE tokens, based on punctuation and newline boundaries.

    Args:
        text: The text to split into chunks.

    Returns:
        A list of text chunks, each of which is a string of ~CHUNK_SIZE tokens.
    """
    # Return an empty list if the text is empty or whitespace
    if not text or text.isspace():
        return []

    # Tokenize the text
    tokens = tokenizer.encode(text, disallowed_special=())

    # Initialize an empty list of chunks
    chunks = []

    # Use the default chunk token size
    chunk_size = CHUNK_SIZE

    # Initialize a counter for the number of chunks
    num_chunks = 0

    # Loop until all tokens are consumed
    while tokens and num_chunks < MAX_NUM_CHUNKS:
        # Take the first chunk_size tokens as a chunk
        chunk = tokens[:chunk_size]

        # Decode the chunk into text
        chunk_text = tokenizer.decode(chunk)

        # Skip the chunk if it is empty or whitespace
        if not chunk_text or chunk_text.isspace():
            # Remove the tokens corresponding to the chunk text from the remaining tokens
            tokens = tokens[len(chunk):]
            # Continue to the next iteration of the loop
            continue

        # Find the last period or punctuation mark in the chunk
        last_punctuation = max(
            chunk_text.rfind("."),
            chunk_text.rfind("?"),
            chunk_text.rfind("!"),
            chunk_text.rfind("\n"),
        )

        # If there is a punctuation mark, and it falls after the first MIN_CHUNK_SIZE_CHARS characters
        if last_punctuation != -1 and last_punctuation > MIN_CHUNK_SIZE_CHARS:
            # Truncate the chunk text at the punctuation mark
            chunk_text = chunk_text[: last_punctuation + 1]

        # Remove any newline characters and strip any leading or trailing whitespace
        chunk_text_to_append = chunk_text.replace("\n", " ").strip()

        if len(chunk_text_to_append) > MIN_CHUNK_LENGTH_TO_EMBED:
            # Append the chunk text to the list of chunks
            chunks.append(chunk_text_to_append)

        # Remove the tokens corresponding to the chunk text from the remaining tokens
        tokens = tokens[len(tokenizer.encode(chunk_text, disallowed_special=())):]

        # Increment the number of chunks
        num_chunks += 1

    # Handle the remaining tokens
    if tokens:
        remaining_text = tokenizer.decode(tokens).replace("\n", " ").strip()
        if len(remaining_text) > MIN_CHUNK_LENGTH_TO_EMBED:
            chunks.append(remaining_text)

    return chunks
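# Example (illustrative, not part of the original file): a short text that is
# well under CHUNK_SIZE tokens and MIN_CHUNK_SIZE_CHARS characters comes back
# as a single cleaned-up chunk, with newlines collapsed to spaces:
#
#   >>> get_text_chunks("First sentence. Second sentence.\nThird line.")
#   ['First sentence. Second sentence. Third line.']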
import json

# load the json file
with open("content.json", "r") as f:
    crawl = json.load(f)

# count the pages
pages = 0

# embeddings
embeddings = []

# iterate through the json file
for page in crawl:
    # check if the page has a text and a title
    if "text" in page and "title" in page:
        title = page["title"]
        print(str(pages) + " " + title)
        pageUrl = page["url"]
        text = title + "\n\n" + page["text"]
        chunks = get_text_chunks(text)
        offset = 0
        for chunk in chunks:
            embedding = {
                "page": pages,
                "title": title,
                "pageUrl": pageUrl,
                "text": chunk,
                "offset": offset,
                "embedding": get_embedding(chunk, engine="text-embedding-ada-002"),
            }
            embeddings.append(embedding)
            offset += len(chunk)
        # count this page once all of its chunks have been embedded
        pages += 1

# write the embeddings to a json file
with open("embeddings.json", "w") as f:
    json.dump(embeddings, f)
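Each record written to embeddings.json carries the page index, title, URL, chunk text, character offset, and the 1536-dimensional text-embedding-ada-002 vector. A minimal sanity check of the output, sketched here rather than taken from the gist:

import json

# Illustrative check of the embeddings file produced above.
with open("embeddings.json", "r") as f:
    records = json.load(f)

print(len(records), "chunks embedded")
print(records[0]["title"], records[0]["pageUrl"], records[0]["offset"])
print(len(records[0]["embedding"]))  # 1536 for text-embedding-ada-002

The next script builds an approximate nearest-neighbour index over these vectors with autofaiss.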
import json
from autofaiss import build_index
import numpy as np

# load the embeddings.json file
with open("embeddings.json", "r") as f:
    data = json.load(f)

# create a numpy array from "embedding" values in the json file
embeddings = np.float32([np.float32(embedding["embedding"]) for embedding in data])

index, index_infos = build_index(embeddings, index_path="knn.index", index_infos_path="index_infos.json")
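build_index chooses an index type and its hyperparameters automatically, writes the index to knn.index, and records the tuning details in index_infos.json. A minimal sketch (not from the gist) to confirm the index lines up with embeddings.json:

import json
import faiss

# Illustrative sanity check: one vector per embedded chunk, 1536 dimensions for ada-002.
index = faiss.read_index("knn.index")
with open("embeddings.json", "r") as f:
    data = json.load(f)

assert index.ntotal == len(data)
print("vectors:", index.ntotal, "dimension:", index.d)

The query script below memory-maps this index, embeds each question with the same model, and asks gpt-3.5-turbo to answer from the retrieved chunks.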
from openai.embeddings_utils import get_embedding
import openai
import faiss
import numpy as np
import json
import tiktoken
import os

openai.organization = os.environ.get("OPENAI_ORG")
openai.api_key = os.environ.get("OPENAI_API_KEY")

index = faiss.read_index("knn.index", faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY)

# load the embeddings.json file
with open("embeddings.json", "r") as f:
    data = json.load(f)


def num_tokens_from_string(string: str, encoding_name: str) -> int:
    """Returns the number of tokens in a text string."""
    encoding = tiktoken.get_encoding(encoding_name)
    num_tokens = len(encoding.encode(string))
    return num_tokens


while True:
    # get the query off the command line, embed it and search the index
    query = input("Query: ")
    query_embedding = get_embedding(query, engine="text-embedding-ada-002")
    # faiss expects float32 input; results is a (distances, indices) pair
    results = index.search(np.array([query_embedding], dtype=np.float32), k=10)

    # build a context from the retrieved chunks, capped at 3500 tokens to stay within the 4096-token limit
    context = ""
    for result in results[1][0]:
        pageUrl = data[result]["pageUrl"]
        text = "URL: " + pageUrl + ":\n" + data[result]["text"]
        if num_tokens_from_string(context + text, "cl100k_base") < 3500:
            context += text
        else:
            break

    chat = openai.ChatCompletion.create(
        model="gpt-3.5-turbo-0301",
        temperature=0.0,
        messages=[
            {"role": "system",
             "content": "You are a helpful assistant answering questions about a website given relevant text from it. If you don't know the answer from the relevant text respond that there is no answer on the site."},
            {"role": "user", "content": "Here is the text of the relevant pages prefixed by their URL, reference these URLs in your answer:\n" + context + "\n"},
            {"role": "user", "content": "User: " + query},
        ]
    )

    print(chat["choices"][0]["message"]["content"])
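The context is capped at 3500 tokens so that the prompt and the model's reply both fit in gpt-3.5-turbo's 4096-token window. To see which chunks an answer is grounded on, one could print the top hits inside the loop before calling the chat model; a minimal sketch (not in the original gist) that reuses the results and data objects defined above:

# Illustrative: list the retrieved chunks with their source pages.
for rank, idx in enumerate(results[1][0]):
    record = data[idx]
    print(f"{rank + 1}. {record['title']} ({record['pageUrl']}) offset={record['offset']}")

The pinned dependencies for all three scripts follow.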
aiohttp==3.8.4
aiosignal==1.3.1
async-timeout==4.0.2
attrs==22.2.0
autofaiss==2.15.5
blobfile==2.0.1
certifi==2022.12.7
charset-normalizer==3.0.1
contourpy==1.0.7
cycler==0.11.0
embedding-reader==1.5.0
faiss-cpu==1.7.3
filelock==3.9.0
fire==0.4.0
fonttools==4.38.0
frozenlist==1.3.3
fsspec==2023.1.0
idna==3.4
importlib-resources==5.12.0
joblib==1.2.0
kiwisolver==1.4.4
lxml==4.9.2
matplotlib==3.7.0
multidict==6.0.4
numpy==1.24.2
openai==0.27.0
packaging==23.0
pandas==1.5.3
Pillow==9.4.0
plotly==5.13.1
pyarrow==7.0.0
pycryptodomex==3.17
pyparsing==3.0.9
python-dateutil==2.8.2
pytz==2022.7.1
regex==2022.10.31
requests==2.28.2
scikit-learn==1.2.1
scipy==1.10.1
six==1.16.0
tenacity==8.2.2
termcolor==2.2.0
threadpoolctl==3.1.0
tiktoken==0.2.0
tqdm==4.64.1
urllib3==1.26.14
yarl==1.8.2
zipp==3.15.0