@nerdalert
Created December 23, 2024 05:53
#!/usr/bin/env python3
"""
Example script that:
1. Converts a document into a Docling document.
2. Chunks it using a HybridChunker.
3. Embeds each chunk using a SentenceTransformer.
4. Stores them in a LanceDB index.
5. Searches the index for a user-provided query and returns the best matching chunk, or all matching chunks when --all is set.
"""
import sys
import os
import subprocess
from pathlib import Path
from tempfile import mkdtemp
import argparse
# Optional: Install packages if not present. You can remove or comment out these checks
# if you already have everything installed in your environment.
def install_if_needed(package):
    """Install a package via pip if its module cannot be imported."""
    # Derive the importable module name: strip extras/version pins and swap dashes
    # for underscores (e.g. "docling-core[chunking]" -> "docling_core").
    module_name = package.split("[")[0].split("==")[0].replace("-", "_")
    try:
        __import__(module_name)
    except ImportError:
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])
# Install dependencies (comment out if you already have them)
required_packages = [
    "docling",  # provides DocumentConverter
    "docling-core[chunking]",
    "sentence-transformers",
    "transformers",
    "lancedb",
]
for pkg in required_packages:
    install_if_needed(pkg)
# Now import everything
import lancedb
from docling.document_converter import DocumentConverter
from docling.chunking import HybridChunker
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer

def build_lancedb_index(
    db_path: str,
    index_name: str,
    chunks,
    chunker,
    embedding_model
):
    """
    Embeds each chunk, then stores embeddings + text into a LanceDB table.
    """
    # Connect to (or create) the LanceDB database
    db = lancedb.connect(db_path)

    # Prepare data for LanceDB
    data = []
    for chunk in chunks:
        # Serialize the chunk so the chunker/embedding model sees the same text
        serialized_text = chunker.serialize(chunk=chunk)
        embeddings = embedding_model.encode(serialized_text)
        data_item = {
            "vector": embeddings,
            "text": chunk.text,
            "headings": chunk.meta.headings,
            "captions": chunk.meta.captions,
        }
        data.append(data_item)
    # Create the table (exist_ok=True reuses an existing table rather than overwriting it)
    tbl = db.create_table(index_name, data=data, exist_ok=True)
    return tbl

def parse_arguments():
    """
    Parses command-line arguments.
    """
    parser = argparse.ArgumentParser(
        description="Ingest a document, chunk it, embed chunks, store in LanceDB, and query for the best matching chunk."
    )
    parser.add_argument(
        "document_path",
        type=str,
        help="Path to the document to ingest (e.g., /path/to/document.md)"
    )
    parser.add_argument(
        "query",
        type=str,
        help="Query string to search for in the document chunks"
    )
    parser.add_argument(
        "--all",
        action="store_true",
        help="If set, print all matching chunks instead of only the best match"
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=1,
        help="Number of top chunks to return (default: 1). Ignored if --all is set."
    )
    return parser.parse_args()

def main():
    """
    Main entry point.

    Usage (examples):
        python hybrid_search.py <path_to_document> "my query"
        python hybrid_search.py <path_to_document> "my query" --all
        python hybrid_search.py <path_to_document> "my query" --limit 5
    """
    args = parse_arguments()
    doc_source = args.document_path
    user_query = args.query
    print_all = args.all
    limit = args.limit if not print_all else None  # If --all is set, limit is None (fetch all)

    # 1. Convert the file into a Docling document
    #    Adjust this if you need a different document conversion approach.
    print(f"Ingesting document at: {doc_source}")
    doc = DocumentConverter().convert(source=doc_source).document

    # 2. Choose an embedding model and create the chunker
    EMBED_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
    MAX_TOKENS = 500
    print(f"Using embedding model: {EMBED_MODEL_ID}")

    # Initialize the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(EMBED_MODEL_ID)

    # Create the hybrid chunker
    chunker = HybridChunker(
        tokenizer=tokenizer,
        max_tokens=MAX_TOKENS,
        # merge_peers=True  # (default is True; adjust if desired)
    )

    # 3. Chunk the document
    chunks = list(chunker.chunk(dl_doc=doc))
    print(f"Number of chunks created: {len(chunks)}")

    # 4. Build the embedding model
    embed_model = SentenceTransformer(EMBED_MODEL_ID)

    # 5. Create a temporary path for LanceDB or specify your own
    db_uri = Path(mkdtemp()) / "docling.db"
    index_name = "my_index"

    # 6. Build the LanceDB index
    print(f"Building LanceDB index at: {db_uri}")
    table = build_lancedb_index(
        db_path=str(db_uri),
        index_name=index_name,
        chunks=chunks,
        chunker=chunker,
        embedding_model=embed_model
    )

    # 7. Encode the user query
    query_embedding = embed_model.encode(user_query)

    # 8. Determine the number of results to fetch
    if print_all:
        search_limit = len(chunks)  # Fetch all chunks
    else:
        search_limit = limit  # Fetch the specified number of top chunks

    # 9. Search in LanceDB
    print(f"Searching for {'all matching' if print_all else f'top {search_limit}'} chunk(s) matching the query: \"{user_query}\"")
    results = table.search(query_embedding).limit(search_limit)

    # 10. Convert results to a pandas DataFrame
    results_df = results.to_pandas()

    # 11. Print out the matching chunks
    print("\n=== MATCHING CHUNK(S) ===")
    for idx, row in results_df.iterrows():
        print(f"\n--- Chunk {idx + 1} ---")
        print(f"Text: {row['text']}")
        # headings/captions come back as list-like columns; check length explicitly
        headings = row['headings']
        if headings is not None and len(headings) > 0:
            print(f"Headings: {', '.join(headings)}")
        captions = row['captions']
        if captions is not None and len(captions) > 0:
            print(f"Captions: {', '.join(captions)}")
print(f"Distance: {row['_distance']:.4f}")
if __name__ == "__main__":
main()
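
Because db_uri points at a mkdtemp() directory, the index is discarded once the run finishes. As a minimal follow-up sketch, assuming db_uri were swapped for a persistent path such as "./docling.db" (that path and the follow-up query string are illustrative; the table name and model ID match the values used in main()), the stored table could be re-opened and queried in a later session without re-ingesting the document:

import lancedb
from sentence_transformers import SentenceTransformer

# Re-open the persisted index (assumes the script wrote to "./docling.db" instead of a temp dir)
db = lancedb.connect("./docling.db")
table = db.open_table("my_index")  # same index_name the script used

# Embed a new query with the same model and search the existing table
embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
query_embedding = embed_model.encode("my follow-up query")  # illustrative query
results_df = table.search(query_embedding).limit(3).to_pandas()
for _, row in results_df.iterrows():
    print(row["text"], row["_distance"])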