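"""indexer.py: index a note into a local vector store.

Splits the note text into overlapping word chunks, embeds them with a
SentenceTransformer model, and persists the result to either a Chroma
collection or a FAISS index on disk.
"""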
import os
import json
import argparse
from pathlib import Path
from typing import List, Dict, Tuple
from sentence_transformers import SentenceTransformer
import chromadb
from chromadb.config import Settings as ChromaSettings
import faiss
import pickle

# Approximate "token" sizes for whitespace-based chunking.
DEFAULT_CHUNK_TOKENS = 200
DEFAULT_OVERLAP_TOKENS = 50


def approx_tokenize(text: str) -> List[str]:
    """Approximate tokenization: split on whitespace."""
    return text.split()


def detokenize(tokens: List[str]) -> str:
    """Inverse of approx_tokenize: rejoin tokens with spaces."""
    return " ".join(tokens)


def chunk_text(text: str, chunk_tokens: int, overlap_tokens: int) -> List[str]:
    """Split text into overlapping chunks of roughly chunk_tokens words."""
    tokens = approx_tokenize(text)
    chunks = []
    i = 0
    n = len(tokens)
    # Guard against a non-positive step (overlap >= chunk size), which would loop forever.
    step = max(chunk_tokens - overlap_tokens, 1)
    while i < n:
        j = min(i + chunk_tokens, n)
        chunk = detokenize(tokens[i:j])
        if chunk.strip():
            chunks.append(chunk)
        if j == n:
            break
        i += step
    return chunks


def index_note(
    text: str,
    note_id: str = "temp_note",
    persist_dir: str = "./data/vector_store",
    db_type: str = "chroma",
    model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
    collection: str = "notes",
) -> str:
    """Chunk, embed, and persist a note into the chosen vector store."""
    os.makedirs(persist_dir, exist_ok=True)
    model = SentenceTransformer(model_name)
    chunks = chunk_text(text, DEFAULT_CHUNK_TOKENS, DEFAULT_OVERLAP_TOKENS)
    if not chunks:
        # Nothing to index for an empty or whitespace-only note.
        return note_id
    chunk_ids = [f"{note_id}::chunk_{i}" for i in range(len(chunks))]
    metadatas = [{"note_id": note_id, "chunk_index": i} for i in range(len(chunks))]
    # Normalized embeddings, so inner product equals cosine similarity.
    vectors = model.encode(chunks, show_progress_bar=False, convert_to_numpy=True, normalize_embeddings=True)

    if db_type == "chroma":
        # FIX: Use get_or_create with consistent settings
        client = chromadb.PersistentClient(
            path=persist_dir,
            settings=ChromaSettings(
                allow_reset=False,  # Changed to False for consistency
                anonymized_telemetry=False,
            ),
        )
        coll = client.get_or_create_collection(collection)
        # Upsert keeps chunk IDs stable: re-indexing the same note overwrites
        # its previous chunks instead of raising duplicate-ID errors.
        coll.upsert(
            ids=chunk_ids,
            embeddings=vectors.tolist(),
            documents=chunks,
            metadatas=metadatas,
        )
    elif db_type == "faiss":
        d = vectors.shape[1]
        # Inner-product index over normalized vectors, i.e. cosine similarity.
        index = faiss.IndexFlatIP(d)
        index.add(vectors)
        vectors_meta = [
            {"id": chunk_ids[k], "text": chunks[k], "meta": metadatas[k]}
            for k in range(len(chunks))
        ]
        # Note: this overwrites any existing FAISS index in persist_dir,
        # so only the most recently indexed note is kept.
        faiss_path = os.path.join(persist_dir, "index.faiss")
        meta_path = os.path.join(persist_dir, "meta.pkl")
        faiss.write_index(index, faiss_path)
        with open(meta_path, "wb") as f:
            pickle.dump(vectors_meta, f)
    return note_id
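

if __name__ == "__main__":
    # Minimal CLI sketch: argparse is imported above but otherwise unused here.
    # The flag names and the choice to read the note from a UTF-8 text file are
    # assumptions for illustration, not part of the original app.
    parser = argparse.ArgumentParser(description="Index a note into a local vector store.")
    parser.add_argument("note_file", help="Path to a UTF-8 text file containing the note")
    parser.add_argument("--note-id", default="temp_note")
    parser.add_argument("--persist-dir", default="./data/vector_store")
    parser.add_argument("--db-type", choices=["chroma", "faiss"], default="chroma")
    args = parser.parse_args()

    with open(args.note_file, "r", encoding="utf-8") as fh:
        note_text = fh.read()

    indexed_id = index_note(
        note_text,
        note_id=args.note_id,
        persist_dir=args.persist_dir,
        db_type=args.db_type,
    )
    print(f"Indexed note: {indexed_id}")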