Skip to content

Instantly share code, notes, and snippets.

@spullara
Last active January 1, 2024 11:18
Show Gist options
  • Save spullara/0038469eebefe27f8e9835fedb14363b to your computer and use it in GitHub Desktop.
Save spullara/0038469eebefe27f8e9835fedb14363b to your computer and use it in GitHub Desktop.
CLI-based embedding search using OpenAI
import openai
import tiktoken
import os
from openai.embeddings_utils import get_embedding
# OpenAI credentials come from the environment; an unset variable yields None.
openai.api_key = os.environ.get("OPENAI_API_KEY")
openai.organization = os.environ.get("OPENAI_ORG")

# Module-wide tokenizer (cl100k_base encoding) used when chunking text
tokenizer = tiktoken.get_encoding("cl100k_base")
def num_tokens_from_string(string: str, encoding_name: str) -> int:
    """Return the number of tokens in a text string.

    Args:
        string: The text to tokenize.
        encoding_name: Name of the tiktoken encoding to count with,
            for example ``"cl100k_base"``.

    Returns:
        The number of tokens ``string`` encodes to under that encoding.
    """
    encoding = tiktoken.get_encoding(encoding_name)
    # Treat special-token markers such as "<|endoftext|>" as ordinary text so
    # arbitrary input is counted instead of raising ValueError; this matches
    # how get_text_chunks encodes its input.
    return len(encoding.encode(string, disallowed_special=()))
# Constants read by get_text_chunks
CHUNK_SIZE = 512 # Target size of each text chunk, in tokens
MIN_CHUNK_SIZE_CHARS = 350 # A chunk is cut back to its last punctuation mark only if that mark lies beyond this character index
MIN_CHUNK_LENGTH_TO_EMBED = 5 # Chunks whose stripped text is this many characters or fewer are discarded
MAX_NUM_CHUNKS = 1000 # Maximum number of chunks the main loop processes; leftover text becomes one final chunk
# From https://github.com/openai/chatgpt-retrieval-plugin/blob/main/services/chunks.py
def get_text_chunks(text: str, chunk_token_size: "int | None" = None) -> list:
    """
    Split a text into chunks of ~CHUNK_SIZE tokens, based on punctuation and newline boundaries.

    Args:
        text: The text to split into chunks.
        chunk_token_size: The target size of each chunk in tokens, or None to use the default CHUNK_SIZE.

    Returns:
        A list of text chunks, each of which is a string of roughly the target
        number of tokens. Newlines inside a chunk are replaced by spaces and
        the chunk is stripped. Once MAX_NUM_CHUNKS chunks have been processed,
        all remaining text is returned as one final chunk.

    Raises:
        ValueError: If chunk_token_size is given but is not positive.
    """
    # Return an empty list if the text is empty or whitespace
    if not text or text.isspace():
        return []

    # A non-positive size would never consume any tokens, so reject it early
    if chunk_token_size is not None and chunk_token_size <= 0:
        raise ValueError("chunk_token_size must be a positive integer")

    # Use the provided chunk token size or the default one
    chunk_size = CHUNK_SIZE if chunk_token_size is None else chunk_token_size

    # Tokenize the text; special-token markers are encoded as plain text
    tokens = tokenizer.encode(text, disallowed_special=())

    chunks = []
    num_chunks = 0

    # Loop until all tokens are consumed or the chunk limit is reached
    while tokens and num_chunks < MAX_NUM_CHUNKS:
        # Take the first chunk_size tokens as a chunk and decode them to text
        chunk = tokens[:chunk_size]
        chunk_text = tokenizer.decode(chunk)

        # Skip the chunk if it is empty or whitespace (not counted as a chunk)
        if not chunk_text or chunk_text.isspace():
            tokens = tokens[len(chunk):]
            continue

        # Find the last sentence-ending mark or newline in the chunk
        last_punctuation = max(
            chunk_text.rfind(mark) for mark in (".", "?", "!", "\n")
        )

        # Cut the chunk at that mark only when it lies beyond
        # MIN_CHUNK_SIZE_CHARS, so that chunks do not become too short
        if last_punctuation != -1 and last_punctuation > MIN_CHUNK_SIZE_CHARS:
            chunk_text = chunk_text[: last_punctuation + 1]

        # Remove any newline characters and strip surrounding whitespace
        chunk_text_to_append = chunk_text.replace("\n", " ").strip()
        if len(chunk_text_to_append) > MIN_CHUNK_LENGTH_TO_EMBED:
            chunks.append(chunk_text_to_append)

        # Drop as many tokens as the (possibly truncated) chunk text encodes to
        tokens = tokens[len(tokenizer.encode(chunk_text, disallowed_special=())):]

        num_chunks += 1

    # Whatever is left after the loop becomes one last chunk
    if tokens:
        remaining_text = tokenizer.decode(tokens).replace("\n", " ").strip()
        if len(remaining_text) > MIN_CHUNK_LENGTH_TO_EMBED:
            chunks.append(remaining_text)

    return chunks
import json
# Load the crawl output. JSON text is UTF-8, so do not depend on the
# platform's default encoding.
with open("content.json", "r", encoding="utf-8") as f:
    crawl = json.load(f)

# Number of pages embedded so far
pages = 0
# One record per embedded chunk
embeddings = []

for page in crawl:
    # Only pages that carry both a title and body text are embedded
    if "text" not in page or "title" not in page:
        continue

    title = page["title"]
    print(f"{pages} {title}")
    # The guard above does not require "url"; fall back to an empty string
    # rather than aborting the whole run on one page without it
    pageUrl = page.get("url", "")
    text = title + "\n\n" + page["text"]
    chunks = get_text_chunks(text)

    # Running total of the lengths of this page's chunks emitted so far
    offset = 0
    for chunk in chunks:
        embedding = {
            "page": pages,
            "title": title,
            "pageUrl": pageUrl,
            "text": chunk,
            "offset": offset,
            "embedding": get_embedding(chunk, engine="text-embedding-ada-002"),
        }
        embeddings.append(embedding)
        offset += len(chunk)

    # Count the page once, after all of its chunks have been embedded
    pages += 1

# Write the embeddings to a JSON file
with open("embeddings.json", "w") as f:
    json.dump(embeddings, f)
# Tokenizer for the cl100k_base encoding
tokenizer = tiktoken.get_encoding("cl100k_base")
import json
from autofaiss import build_index
import numpy as np
# Read the chunk records from embeddings.json
with open("embeddings.json", "r") as f:
    data = json.load(f)

# Stack every record's "embedding" vector into one float32 matrix
embeddings = np.array(
    [record["embedding"] for record in data],
    dtype=np.float32,
)

# Build the index, passing the paths for the index file and its metadata
index, index_infos = build_index(
    embeddings,
    index_path="knn.index",
    index_infos_path="index_infos.json",
)
from openai.embeddings_utils import get_embedding
import openai
import faiss
import numpy as np
import json
import tiktoken
import os
# OpenAI credentials come from the environment; an unset variable yields None.
openai.api_key = os.environ.get("OPENAI_API_KEY")
openai.organization = os.environ.get("OPENAI_ORG")

# Open the saved index with the read-only and memory-map I/O flags
index = faiss.read_index(
    "knn.index",
    faiss.IO_FLAG_READ_ONLY | faiss.IO_FLAG_MMAP,
)

# Load the chunk records; search results are used as positions in this list
with open("embeddings.json", "r") as f:
    data = json.load(f)
def num_tokens_from_string(string: str, encoding_name: str) -> int:
    """Return the number of tokens in a text string.

    Args:
        string: The text to tokenize.
        encoding_name: Name of the tiktoken encoding to count with,
            for example ``"cl100k_base"``.

    Returns:
        The number of tokens ``string`` encodes to under that encoding.
    """
    encoding = tiktoken.get_encoding(encoding_name)
    # Stored page text is arbitrary and may contain special-token markers
    # such as "<|endoftext|>"; count them as ordinary text instead of letting
    # encode() raise ValueError.
    return len(encoding.encode(string, disallowed_special=()))
while True:
    # Read the next query; leave the loop cleanly on Ctrl-D or Ctrl-C
    try:
        query = input("Query: ")
    except (EOFError, KeyboardInterrupt):
        print()
        break
    # A blank line has nothing to embed; prompt again
    if not query.strip():
        continue

    # Embed the query and search the index for the 10 nearest chunks.
    # faiss works on 2-D float32 arrays: one row per query vector.
    query_embedding = get_embedding(query, engine="text-embedding-ada-002")
    query_matrix = np.array([query_embedding], dtype=np.float32)
    results = index.search(query_matrix, k=10)

    # Collect hits in ranked order while the context stays under 3500 tokens,
    # which leaves headroom for the instructions, the question and the answer.
    pieces = []
    for result in results[1][0]:
        # faiss pads the label row with -1 when fewer than k neighbours
        # exist; data[-1] would silently pull in the last record instead.
        if result < 0:
            continue
        record = data[int(result)]
        piece = "URL: " + record["pageUrl"] + ":\n" + record["text"]
        candidate = "\n\n".join(pieces + [piece])
        if num_tokens_from_string(candidate, "cl100k_base") >= 3500:
            break
        pieces.append(piece)
    # A blank line between hits keeps one result's text from running into
    # the next result's "URL:" header.
    context = "\n\n".join(pieces)

    chat = openai.ChatCompletion.create(
        model="gpt-3.5-turbo-0301",
        temperature=0.0,
        messages=[
            {"role": "system",
             "content": "You are a helpful assistant answering questions about a website given relevant text from it. If you don't know the answer from the relevant text respond that there is no answer on the site."},
            {"role": "user", "content": "Here is the text of the relevant pages prefixed by their URL, reference these URLs in your answer:\n" + context + "\n"},
            {"role": "user", "content": "User: " + query},
        ]
    )
    print(chat["choices"][0]["message"]["content"])
aiohttp==3.8.4
aiosignal==1.3.1
async-timeout==4.0.2
attrs==22.2.0
autofaiss==2.15.5
blobfile==2.0.1
certifi==2022.12.7
charset-normalizer==3.0.1
contourpy==1.0.7
cycler==0.11.0
embedding-reader==1.5.0
faiss-cpu==1.7.3
filelock==3.9.0
fire==0.4.0
fonttools==4.38.0
frozenlist==1.3.3
fsspec==2023.1.0
idna==3.4
importlib-resources==5.12.0
joblib==1.2.0
kiwisolver==1.4.4
lxml==4.9.2
matplotlib==3.7.0
multidict==6.0.4
numpy==1.24.2
openai==0.27.0
packaging==23.0
pandas==1.5.3
Pillow==9.4.0
plotly==5.13.1
pyarrow==7.0.0
pycryptodomex==3.17
pyparsing==3.0.9
python-dateutil==2.8.2
pytz==2022.7.1
regex==2022.10.31
requests==2.28.2
scikit-learn==1.2.1
scipy==1.10.1
six==1.16.0
tenacity==8.2.2
termcolor==2.2.0
threadpoolctl==3.1.0
tiktoken==0.2.0
tqdm==4.64.1
urllib3==1.26.14
yarl==1.8.2
zipp==3.15.0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment