import os
import json
import shutil
import threading

import chromadb
from fastembed import TextEmbedding
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler

PREFERRED_MODELS = [
    "jinaai/jina-embeddings-v2-base-zh",  # good for mixed Chinese/English text, ~0.64 GB
    "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",  # multilingual, ~50 languages
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",  # multilingual, lightweight
    "intfloat/multilingual-e5-large",  # stronger multilingual quality, ~2.2 GB
]

# The final model is detected at runtime from the supported list above
MODEL_NAME = None
COLLECTION_NAME = "brain_kb_v5"
BATCH_SIZE = 128  # batch upserts to avoid huge single writes

# Optional imports for different file types
try:
    from pypdf import PdfReader
except ImportError:
    PdfReader = None

try:
    from docx import Document
except ImportError:
    Document = None


class KnowledgeBase:
    def __init__(self, kb_path="knowledge", db_path="vector_db"):
        self.kb_path = os.path.abspath(kb_path)
        self.db_path = os.path.abspath(db_path)
        self.meta_path = os.path.join(self.db_path, "_meta.json")
        self.manifest_path = os.path.join(self.db_path, "_manifest.json")
        self._collection_reset_guard = False
        self._query_reset_guard = False
        # RLock (not Lock): the corruption-recovery path in sync_knowledge()
        # re-enters the lock from the same thread when it retries; a plain
        # Lock would make that retry bail out as "already running".
        self._sync_lock = threading.RLock()

        os.makedirs(self.kb_path, exist_ok=True)

        # Initialize the embedding model: pick the first entry from
        # PREFERRED_MODELS that this fastembed build supports.
        # Weights load from the local cache if already downloaded.
        _supported_raw = TextEmbedding.list_supported_models()
        supported = set()
        for item in _supported_raw:
            if isinstance(item, dict) and "model" in item:
                supported.add(item["model"])
            elif isinstance(item, str):
                supported.add(item)
        chosen = None
        for name in PREFERRED_MODELS:
            if name in supported:
                chosen = name
                break
        if not chosen:
            raise RuntimeError(
                "No preferred embedding models are supported by fastembed. "
                "Please check available models via TextEmbedding.list_supported_models()."
            )

        print(f"Loading Knowledge Base Embedding Model: {chosen} (may take some time on first run)...")
        try:
            self.model = TextEmbedding(model_name=chosen)
            print("Embedding Model loaded successfully.")
        except Exception as e:
            print(f"Error loading embedding model: {e}")
            raise

        # Store the chosen model name for reference
        global MODEL_NAME
        MODEL_NAME = chosen

        # Cache the embedding dimension (detects library/model changes that corrupt existing indexes)
        self.embed_dim = self._get_embedding_dim()
        self.chroma_version = getattr(chromadb, "__version__", "unknown")

        # If the stored index was built with a different model/dimension/chromadb version, wipe it
        self._maybe_reset_for_incompatibility(chosen, self.embed_dim, self.chroma_version)

        # Initialize the vector DB
        self._init_collection()
        self._healthcheck()

        # Initial sync
        self.sync_knowledge()

        # Start the file watcher
        self.start_watcher()

    def _init_collection(self, recreate: bool = False):
        """(Re)initialize the Chroma client/collection. If recreate=True, wipe the on-disk index."""
        if recreate and os.path.exists(self.db_path):
            shutil.rmtree(self.db_path, ignore_errors=True)
        try:
            self.client = chromadb.PersistentClient(path=self.db_path)
            self.collection = self.client.get_or_create_collection(
                name=COLLECTION_NAME,
                metadata={"hnsw:space": "cosine"}
            )
        except Exception:
            # If loading the collection itself fails, wipe and retry once to clear corrupted segments
            if not recreate:
                shutil.rmtree(self.db_path, ignore_errors=True)
                return self._init_collection(recreate=True)
            raise

        # Persist metadata about the embedding model used to build this index
        try:
            os.makedirs(self.db_path, exist_ok=True)
            with open(self.meta_path, "w", encoding="utf-8") as f:
                json.dump({
                    "model": MODEL_NAME,
                    "embed_dim": self.embed_dim,
                    "chroma_version": self.chroma_version,
                }, f)
        except Exception:
            pass  # A metadata write failure should not block runtime

    def _healthcheck(self):
        """Validate index readability right after startup; rebuild if corrupted."""
        try:
            _ = self.collection.count()
        except Exception as e:
            msg = str(e).lower()
            if any(x in msg for x in ["hnsw", "segment", "compaction", "backfill"]):
                print("Detected index corruption on startup. Rebuilding vector_db...")
                shutil.rmtree(self.db_path, ignore_errors=True)
                self._init_collection(recreate=True)
                self.sync_knowledge(allow_reset=False)
            else:
                print(f"Index healthcheck encountered an unexpected error: {e}")

    def _maybe_reset_for_incompatibility(self, chosen_model: str, embed_dim: int, chroma_version: str):
        """If the existing index meta differs (model/dimension/chromadb), wipe it."""
        if not os.path.exists(self.db_path):
            return
        try:
            with open(self.meta_path, "r", encoding="utf-8") as f:
                meta = json.load(f)
            prev_model = meta.get("model")
            prev_dim = meta.get("embed_dim")
            prev_chroma = meta.get("chroma_version")
            if prev_model != chosen_model or prev_dim != embed_dim or prev_chroma != chroma_version:
                shutil.rmtree(self.db_path, ignore_errors=True)
        except Exception:
            # If the meta cannot be read, assume it is stale/corrupted and rebuild
            shutil.rmtree(self.db_path, ignore_errors=True)

    def _get_embedding_dim(self) -> int:
        # Embed a throwaway probe string and measure the resulting vector length
        for vec in self.model.embed(["dimension_probe"]):
            try:
                return len(vec)
            except Exception:
                return len(list(vec))
        raise RuntimeError("Failed to determine embedding dimension")

    def sync_knowledge(self, allow_reset: bool = True):
        """Scans the knowledge folder and updates the vector database."""
        if not self._sync_lock.acquire(blocking=False):
            print("Sync already running, skip this trigger.")
            return

        print("Syncing knowledge base...")
        manifest = self._load_manifest()
        updated_manifest = {}
        supported_extensions = (".txt", ".md", ".pdf", ".docx", ".json")
        current_files = []
        try:
            for filename in os.listdir(self.kb_path):
                file_path = os.path.join(self.kb_path, filename)
                if os.path.isfile(file_path) and filename.lower().endswith(supported_extensions):
                    current_files.append(filename)
                    mtime = os.path.getmtime(file_path)
                    size = os.path.getsize(file_path)
                    prev_meta = manifest.get(filename)
                    # Skip unchanged files
                    if prev_meta and prev_meta.get("mtime") == mtime and prev_meta.get("size") == size:
                        updated_manifest[filename] = prev_meta
                        continue
                    try:
                        content = self._extract_text(file_path)
                        if content:
                            # Sliding-window chunking on the original text: windows start
                            # every chunk_size - overlap = 720 chars, so consecutive
                            # chunks share an 80-char overlap
                            chunk_size = 800
                            overlap = 80
                            original_chunks = []
                            for i in range(0, len(content), chunk_size - overlap):
                                chunk = content[i:i + chunk_size].strip()
                                if chunk:
                                    original_chunks.append(chunk)

                            if original_chunks:
                                # Normalize for embedding generation only (not for storage)
                                normalized_chunks = [c.lower().replace('_', ' ') for c in original_chunks]

                                ids = [f"{filename}_{i}" for i in range(len(original_chunks))]
                                metadatas = [{"source": filename, "chunk": i} for i in range(len(original_chunks))]

                                # Compute embeddings from the normalized text
                                embeddings = []
                                for v in self.model.embed(normalized_chunks):
                                    try:
                                        embeddings.append(v.tolist())
                                    except Exception:
                                        embeddings.append(list(v))

                                # Drop the file's old chunks first so a file that shrank
                                # does not leave stale high-index chunks in the collection
                                self.collection.delete(where={"source": filename})

                                # Store the ORIGINAL text (not normalized) so users see the real content
                                for start in range(0, len(original_chunks), BATCH_SIZE):
                                    end = start + BATCH_SIZE
                                    self.collection.upsert(
                                        documents=original_chunks[start:end],
                                        ids=ids[start:end],
                                        metadatas=metadatas[start:end],
                                        embeddings=embeddings[start:end]
                                    )
                                print(f" ✓ Indexed {filename}: {len(original_chunks)} chunks (batched)")
                        updated_manifest[filename] = {"mtime": mtime, "size": size}
                    except Exception as e:
                        err_msg = str(e)
                        print(f"Error processing {filename}: {err_msg}")
                        # Auto-recover if HNSW/compaction/index errors occur
                        if allow_reset and any(x in err_msg.lower() for x in ["hnsw", "compaction", "segment reader"]):
                            if not self._collection_reset_guard:
                                print("Detected index corruption. Rebuilding vector_db and retrying sync once...")
                                self._collection_reset_guard = True
                                self._init_collection(recreate=True)
                                # Safe to re-enter: _sync_lock is an RLock, so this
                                # same-thread retry can acquire it again
                                return self.sync_knowledge(allow_reset=False)
            # Remove deleted files from the index
            deleted_files = set(manifest.keys()) - set(current_files)
            for filename in deleted_files:
                try:
                    self.collection.delete(where={"source": filename})
                    print(f" ✓ Removed deleted file from index: {filename}")
                except Exception as e:
                    print(f" ! Failed to remove {filename}: {e}")
            # Persist the manifest
            self._save_manifest(updated_manifest)
            print("Knowledge base sync complete.")
        finally:
            self._sync_lock.release()

    def _extract_text(self, file_path):
        ext = os.path.splitext(file_path)[1].lower()
        if ext == ".txt":
            with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                return f.read()
        elif ext == ".md":
            # Treat Markdown as plain text for retrieval
            with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                return f.read()
        elif ext == ".pdf":
            if PdfReader:
                reader = PdfReader(file_path)
                text = ""
                for page in reader.pages:
                    # extract_text() can return None for image-only pages
                    text += (page.extract_text() or "") + "\n"
                return text
            else:
                print("pypdf not installed, skipping PDF.")
        elif ext == ".docx":
            if Document:
                doc = Document(file_path)
                return "\n".join([para.text for para in doc.paragraphs])
            else:
                print("python-docx not installed, skipping Word.")
        elif ext == ".json":
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
                return json.dumps(data, ensure_ascii=False, indent=2)
        return None

    def query(self, text, top_k=5, distance_threshold=0.8, allow_reset: bool = True):
        """Retrieves relevant snippets from the knowledge base.

        Uses cosine distance (lower is better). A result is treated as a hit only
        when best_distance <= distance_threshold.
        Returns:
            dict: {"hit": bool, "context": str, "hits": [{source, chunk, distance, text}, ...]}
        """
        try:
            # Normalize the query the same way as indexed content
            normalized_text = text.lower().replace('_', ' ')

            q_vec = None
            for v in self.model.embed([normalized_text]):
                try:
                    q_vec = v.tolist()
                except Exception:
                    q_vec = list(v)
                break
            if q_vec is None:
                return {"hit": False, "context": "", "hits": []}

            results = self.collection.query(
                query_embeddings=[q_vec],
                n_results=top_k,
                include=["documents", "metadatas", "distances"]
            )

            docs = (results or {}).get("documents") or []
            metas = (results or {}).get("metadatas") or []
            dists = (results or {}).get("distances") or []

            if not docs or not docs[0]:
                print("[KB Query] No results returned from collection")
                return {"hit": False, "context": "", "hits": []}

            docs0 = docs[0]
            metas0 = metas[0] if metas and metas[0] else [{} for _ in docs0]
            dists0 = dists[0] if dists and dists[0] else [None for _ in docs0]

            hits = []
            for doc_text, meta, dist in zip(docs0, metas0, dists0):
                hits.append({
                    "source": (meta or {}).get("source", ""),
                    "chunk": (meta or {}).get("chunk", None),
                    "distance": dist,
                    "text": doc_text,
                })

            best = hits[0].get("distance")
            is_hit = (best is not None) and (best <= distance_threshold)

            # Debug log
            best_str = f"{best:.4f}" if best is not None else "N/A"
            print(f"[KB Query] '{text[:50]}...' -> best_dist={best_str}, threshold={distance_threshold}, hit={is_hit}")
            if hits:
                top3_dists = [f"{h['distance']:.4f}" if h['distance'] is not None else "N/A" for h in hits[:3]]
                print(f"[KB Query] Top 3 distances: {top3_dists}")

            context = "\n---\n".join([h["text"] for h in hits]) if is_hit else ""
            return {"hit": is_hit, "context": context, "hits": hits}
        except Exception as e:
            err_msg = str(e)
            print(f"Query error: {err_msg}")
            import traceback
            traceback.print_exc()

            # Auto-recover if HNSW/compaction/backfill errors surface during query
            if allow_reset and any(x in err_msg.lower() for x in ["hnsw", "compaction", "segment reader", "backfill"]):
                if not self._query_reset_guard:
                    print("Detected index corruption during query. Rebuilding vector_db and retrying once...")
                    self._query_reset_guard = True
                    try:
                        self._init_collection(recreate=True)
                        self.sync_knowledge(allow_reset=False)
                        # Retry the query once with the guard disabled to avoid loops
                        self._query_reset_guard = False
                        return self.query(text, top_k=top_k, distance_threshold=distance_threshold, allow_reset=False)
                    except Exception as inner_e:
                        print(f"Auto-rebuild after query failure also failed: {inner_e}")
                        self._query_reset_guard = False
            return {"hit": False, "context": "", "hits": []}

    def start_watcher(self):
        # Watch the knowledge folder (non-recursive) and re-sync on changes
        event_handler = KBHandler(self)
        self.observer = Observer()
        self.observer.schedule(event_handler, self.kb_path, recursive=False)
        self.observer.start()
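
    # Added helper (not in the original file): a clean shutdown hook so callers
    # can stop the watchdog observer thread on application exit.
    def stop_watcher(self):
        if getattr(self, "observer", None):
            self.observer.stop()
            self.observer.join()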

    def _load_manifest(self):
        if not os.path.exists(self.manifest_path):
            return {}
        try:
            with open(self.manifest_path, "r", encoding="utf-8") as f:
                return json.load(f)
        except Exception:
            return {}

    def _save_manifest(self, data):
        try:
            os.makedirs(self.db_path, exist_ok=True)
            with open(self.manifest_path, "w", encoding="utf-8") as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
        except Exception as e:
            print(f" ! Failed to save manifest: {e}")


class KBHandler(FileSystemEventHandler):
    def __init__(self, kb_instance):
        self.kb = kb_instance
        self.supported_extensions = (".txt", ".md", ".pdf", ".docx", ".json")
        self._debounce_timer = None

    def _trigger_sync(self):
        # Debounce: if a sync is already scheduled, let it absorb this event
        def run():
            self.kb.sync_knowledge()
        if self._debounce_timer and self._debounce_timer.is_alive():
            return
        self._debounce_timer = threading.Timer(0.5, run)
        self._debounce_timer.start()

    def on_modified(self, event):
        if not event.is_directory and event.src_path.lower().endswith(self.supported_extensions):
            print(f"File modified: {event.src_path}. Re-syncing...")
            self._trigger_sync()

    def on_created(self, event):
        if not event.is_directory and event.src_path.lower().endswith(self.supported_extensions):
            print(f"File created: {event.src_path}. Syncing...")
            self._trigger_sync()
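
    # Added handler (not in the original file): without it, deletions are only
    # pruned from the index on the next created/modified event; triggering a
    # sync here lets sync_knowledge() remove the file's chunks promptly.
    def on_deleted(self, event):
        if not event.is_directory and event.src_path.lower().endswith(self.supported_extensions):
            print(f"File deleted: {event.src_path}. Re-syncing...")
            self._trigger_sync()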
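

# Minimal usage sketch (not part of the original module; paths are illustrative).
# Constructing KnowledgeBase builds or loads the index, runs an initial sync,
# and starts the folder watcher; query() returns a hit flag plus ranked snippets.
if __name__ == "__main__":
    kb = KnowledgeBase(kb_path="knowledge", db_path="vector_db")
    result = kb.query("what is this knowledge base about", top_k=3)
    if result["hit"]:
        print(result["context"])
    else:
        print("No sufficiently close match found.")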