|
|
import os |
|
|
import json |
|
|
import argparse |
|
|
from pathlib import Path |
|
|
from typing import List, Dict, Tuple |
|
|
from sentence_transformers import SentenceTransformer |
|
|
import chromadb |
|
|
from chromadb.config import Settings as ChromaSettings |
|
|
import faiss |
|
|
import pickle |
|
|
|
|
|
# Chunking defaults, measured in whitespace-split "tokens" (see approx_tokenize):
# target size of each chunk, and how many tokens consecutive chunks share so
# sentences straddling a boundary appear in both neighbors.
DEFAULT_CHUNK_TOKENS = 200
DEFAULT_OVERLAP_TOKENS = 50
|
|
|
|
|
def approx_tokenize(text: str) -> List[str]:
    """Cheap stand-in for a real tokenizer: split *text* on any whitespace run."""
    words = text.split()
    return words
|
|
|
|
|
def detokenize(tokens: List[str]) -> str:
    """Inverse of approx_tokenize: rejoin *tokens* with single spaces."""
    joined = " ".join(tokens)
    return joined
|
|
|
|
|
def chunk_text(text: str, chunk_tokens: int, overlap_tokens: int) -> List[str]:
    """Split *text* into overlapping chunks of at most ``chunk_tokens`` tokens.

    Tokens are whitespace-split words (see approx_tokenize). Consecutive
    chunks share ``overlap_tokens`` tokens so content near a boundary is
    retrievable from either side.

    Bug fix: the original set ``i = j - overlap_tokens`` unconditionally,
    which never advanced the window when ``overlap_tokens >= chunk_tokens``
    (or when ``chunk_tokens <= 0``), looping forever. The start index is now
    forced forward by at least one token per iteration.

    Args:
        text: Raw text to split; blank input yields an empty list.
        chunk_tokens: Maximum number of tokens per chunk.
        overlap_tokens: Number of tokens shared between adjacent chunks.

    Returns:
        List of non-empty chunk strings, in document order.
    """
    tokens = approx_tokenize(text)
    chunks: List[str] = []
    i = 0
    n = len(tokens)
    while i < n:
        j = min(i + chunk_tokens, n)
        chunk = detokenize(tokens[i:j])
        if chunk.strip():
            chunks.append(chunk)
        if j == n:
            break
        # Advance by at least one token so the loop always terminates,
        # even when overlap_tokens >= chunk_tokens.
        i = max(j - overlap_tokens, i + 1)
    return chunks
|
|
|
|
|
def index_note(
    text: str,
    note_id: str = "temp_note",
    persist_dir: str = "./data/vector_store",
    db_type: str = "chroma",
    model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
    collection: str = "notes"
) -> str:
    """Chunk *text*, embed the chunks, and persist them in a vector store.

    Fixes over the original:
      * blank/empty input no longer crashes the FAISS path (``vectors.shape[1]``
        raised IndexError on an empty embedding array) — it now returns early,
        before the (expensive) model load;
      * an unrecognized ``db_type`` previously fell through silently and
        "succeeded" without indexing anything — it now raises ``ValueError``.

    Args:
        text: Note body to index.
        note_id: Stable identifier; chunk ids are ``{note_id}::chunk_{i}``.
        persist_dir: Directory for the on-disk store (created if missing).
        db_type: ``"chroma"`` or ``"faiss"``.
        model_name: SentenceTransformer model used for embeddings.
        collection: Chroma collection name (ignored for FAISS).

    Returns:
        The ``note_id`` that was indexed.

    Raises:
        ValueError: If ``db_type`` is not ``"chroma"`` or ``"faiss"``.
    """
    # Validate before any filesystem or model work.
    if db_type not in ("chroma", "faiss"):
        raise ValueError(f"unsupported db_type: {db_type!r} (expected 'chroma' or 'faiss')")

    os.makedirs(persist_dir, exist_ok=True)

    chunks = chunk_text(text, DEFAULT_CHUNK_TOKENS, DEFAULT_OVERLAP_TOKENS)
    if not chunks:
        # Nothing to embed; skip the model load entirely.
        return note_id

    model = SentenceTransformer(model_name)
    chunk_ids = [f"{note_id}::chunk_{i}" for i in range(len(chunks))]
    metadatas = [{"note_id": note_id, "chunk_index": i} for i in range(len(chunks))]
    # Normalized embeddings so inner product == cosine similarity (matches
    # the FAISS IndexFlatIP choice below).
    vectors = model.encode(chunks, show_progress_bar=False, convert_to_numpy=True, normalize_embeddings=True)

    if db_type == "chroma":
        client = chromadb.PersistentClient(
            path=persist_dir,
            settings=ChromaSettings(
                allow_reset=False,
                anonymized_telemetry=False
            )
        )
        coll = client.get_or_create_collection(collection)
        # upsert: re-indexing the same note_id replaces its chunks in place.
        coll.upsert(
            ids=chunk_ids,
            embeddings=vectors.tolist(),
            documents=chunks,
            metadatas=metadatas,
        )
    else:  # db_type == "faiss" (validated above)
        d = vectors.shape[1]
        index = faiss.IndexFlatIP(d)
        index.add(vectors)
        vectors_meta = [
            {"id": chunk_ids[k], "text": chunks[k], "meta": metadatas[k]}
            for k in range(len(chunks))
        ]
        # NOTE(review): this rebuilds the index from the current note's chunks
        # only and overwrites any existing index.faiss/meta.pkl in persist_dir —
        # confirm single-note indexing is intended before indexing many notes.
        faiss_path = os.path.join(persist_dir, "index.faiss")
        meta_path = os.path.join(persist_dir, "meta.pkl")
        faiss.write_index(index, faiss_path)
        with open(meta_path, "wb") as f:
            pickle.dump(vectors_meta, f)

    return note_id
|
|
|