#!/usr/bin/env python3
"""
Example script that:
1. Converts a document to a Docling document.
2. Chunks it using a HybridChunker.
3. Embeds each chunk using a SentenceTransformer.
4. Stores the chunks in a LanceDB index.
5. Searches the index for a user-provided query and returns the best-matching
   chunk, or all matching chunks when the --all flag is set.
"""
import argparse
import subprocess
import sys
from pathlib import Path
from tempfile import mkdtemp

# Optional: install packages if not present. You can remove or comment out
# these checks if you already have everything installed in your environment.
def install_if_needed(package):
    # Derive the importable module name from the pip package name: strip any
    # extras ("[chunking]") and version pins ("==..."), and map hyphens to
    # underscores, since pip names and module names differ.
    module_name = package.split("[")[0].split("==")[0].replace("-", "_")
    try:
        __import__(module_name)
    except ImportError:
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])
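# Example: install_if_needed("docling-core[chunking]") probes for the
# "docling_core" module and, only if that import fails, pip-installs
# "docling-core[chunking]" into the current interpreter's environment.
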
# Install dependencies (comment out if you already have them).
required_packages = [
    "docling",                 # provides DocumentConverter
    "docling-core[chunking]",  # provides the chunking utilities
    "sentence-transformers",
    "transformers",
    "lancedb",
]

for pkg in required_packages:
    install_if_needed(pkg)

# Now import everything.
import lancedb
from docling.document_converter import DocumentConverter
from docling.chunking import HybridChunker
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer

def build_lancedb_index(
    db_path: str,
    index_name: str,
    chunks,
    chunker,
    embedding_model,
):
    """
    Embeds each chunk, then stores the embeddings and text in a LanceDB table.
    """
    # Connect to (or create) the LanceDB database.
    db = lancedb.connect(db_path)
    # Prepare data for LanceDB.
    data = []
    for chunk in chunks:
        # Serialize the chunk so the embedding model sees the same
        # contextualized text the chunker produced.
        serialized_text = chunker.serialize(chunk=chunk)
        embeddings = embedding_model.encode(serialized_text)
        data_item = {
            "vector": embeddings,
            "text": chunk.text,
            "headings": chunk.meta.headings,
            "captions": chunk.meta.captions,
        }
        data.append(data_item)
    # Create the table; exist_ok=True opens an existing table with this name
    # instead of raising (it does not overwrite it).
    tbl = db.create_table(index_name, data=data, exist_ok=True)
    return tbl

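# LanceDB infers the table schema from the first data item: "vector" becomes a
# fixed-size float vector column sized to the embedding dimension, and the
# remaining fields become ordinary columns.
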
def parse_arguments():
    """
    Parses command-line arguments.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Ingest a document, chunk it, embed the chunks, store them in "
            "LanceDB, and query for the best-matching chunk."
        )
    )
    parser.add_argument(
        "document_path",
        type=str,
        help="Path to the document to ingest (e.g., /path/to/document.md)",
    )
    parser.add_argument(
        "query",
        type=str,
        help="Query string to search for in the document chunks",
    )
    parser.add_argument(
        "--all",
        action="store_true",
        help="If set, print all matching chunks instead of only the best match",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=1,
        help="Number of top chunks to return (default: 1). Ignored if --all is set.",
    )
    return parser.parse_args()

def main():
    """
    Main entry point.

    Usage (examples):
        python hybrid_search.py <path_to_document> "my query"
        python hybrid_search.py <path_to_document> "my query" --all
        python hybrid_search.py <path_to_document> "my query" --limit 5
    """
    args = parse_arguments()
    doc_source = args.document_path
    user_query = args.query
    print_all = args.all
    limit = args.limit if not print_all else None  # With --all, the limit is ignored.

    # 1. Convert the file into a Docling document.
    #    Adjust this if you need a different document-conversion approach.
    print(f"Ingesting document at: {doc_source}")
    doc = DocumentConverter().convert(source=doc_source).document

    # 2. Choose an embedding model and create the chunker.
    EMBED_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
    MAX_TOKENS = 500
    print(f"Using embedding model: {EMBED_MODEL_ID}")

    # Initialize the tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(EMBED_MODEL_ID)

    # Create the hybrid chunker.
    chunker = HybridChunker(
        tokenizer=tokenizer,
        max_tokens=MAX_TOKENS,
        # merge_peers=True  # (default is True; adjust if desired)
    )
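    # HybridChunker first chunks along the document's structure (sections,
    # paragraphs, tables), then splits any chunk exceeding max_tokens and,
    # with merge_peers=True, merges undersized sibling chunks back together.
    # Using the embedding model's own tokenizer keeps its token counts
    # consistent with what the model will actually see.
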
    # 3. Chunk the document.
    chunks = list(chunker.chunk(dl_doc=doc))
    print(f"Number of chunks created: {len(chunks)}")

    # 4. Build the embedding model.
    embed_model = SentenceTransformer(EMBED_MODEL_ID)
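    # SentenceTransformer downloads the model weights from the Hugging Face
    # Hub on first use and caches them locally for subsequent runs.
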
    # 5. Create a temporary path for LanceDB, or specify your own.
    db_uri = Path(mkdtemp()) / "docling.db"
    index_name = "my_index"

    # 6. Build the LanceDB index.
    print(f"Building LanceDB index at: {db_uri}")
    table = build_lancedb_index(
        db_path=str(db_uri),
        index_name=index_name,
        chunks=chunks,
        chunker=chunker,
        embedding_model=embed_model,
    )

    # 7. Encode the user query.
    query_embedding = embed_model.encode(user_query)
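    # The query must be embedded with the same model as the chunks: vectors
    # from different models live in different spaces, so distances between
    # them would be meaningless.
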
    # 8. Determine the number of results to fetch.
    if print_all:
        search_limit = len(chunks)  # Fetch all chunks.
    else:
        search_limit = limit  # Fetch the specified number of top chunks.

    # 9. Search in LanceDB.
    print(
        f"Searching for {'all matching' if print_all else f'top {search_limit}'} "
        f'chunk(s) matching the query: "{user_query}"'
    )
    results = table.search(query_embedding).limit(search_limit)

    # 10. Convert the results to a pandas DataFrame.
    results_df = results.to_pandas()

    # 11. Print the matching chunks. The list-valued columns come back from
    # LanceDB as arrays, so check length explicitly rather than truthiness.
    print("\n=== MATCHING CHUNK(S) ===")
    for idx, row in results_df.iterrows():
        print(f"\n--- Chunk {idx + 1} ---")
        print(f"Text: {row['text']}")
        if row['headings'] is not None and len(row['headings']) > 0:
            print(f"Headings: {', '.join(row['headings'])}")
        if row['captions'] is not None and len(row['captions']) > 0:
            print(f"Captions: {', '.join(row['captions'])}")
        print(f"Distance: {row['_distance']:.4f}")

if __name__ == "__main__":
    main()