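"""Local knowledge base with automatic file watching.

Indexes supported files (.txt, .md, .pdf, .docx, .json) from a knowledge
folder into a persistent ChromaDB collection using fastembed embeddings,
keeps the index in sync via a watchdog observer, and answers queries by
cosine distance, rebuilding the index automatically if corruption is
detected.
"""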
import os
import json
import shutil
import threading

import chromadb
from fastembed import TextEmbedding
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler

PREFERRED_MODELS = [
    "jinaai/jina-embeddings-v2-base-zh",  # good for mixed Chinese/English text, ~0.64GB
    "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",  # multilingual, ~50 languages
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",  # multilingual, lightweight version
    "intfloat/multilingual-e5-large",  # multilingual, stronger, ~2.2GB on disk
]

# The final model is chosen at runtime from the supported list
MODEL_NAME = None
COLLECTION_NAME = "brain_kb_v5"
BATCH_SIZE = 128  # batch upserts to avoid huge single writes

# Optional imports for different file types
try:
    from pypdf import PdfReader
except ImportError:
    PdfReader = None

try:
    from docx import Document
except ImportError:
    Document = None

class KnowledgeBase:
    def __init__(self, kb_path="knowledge", db_path="vector_db"):
        self.kb_path = os.path.abspath(kb_path)
        self.db_path = os.path.abspath(db_path)
        self.meta_path = os.path.join(self.db_path, "_meta.json")
        self.manifest_path = os.path.join(self.db_path, "_manifest.json")
        self._collection_reset_guard = False
        self._query_reset_guard = False
        # RLock so the corruption-recovery retry inside sync_knowledge() can
        # re-enter the lock from the same thread instead of being silently skipped
        self._sync_lock = threading.RLock()
        if not os.path.exists(self.kb_path):
            os.makedirs(self.kb_path)

        # Initialize the embedding model: pick the first available model from
        # the preferred list (loads from cache if already downloaded)
        _supported_raw = TextEmbedding.list_supported_models()
        supported = set()
        for item in _supported_raw:
            if isinstance(item, dict) and "model" in item:
                supported.add(item["model"])
            elif isinstance(item, str):
                supported.add(item)
        chosen = None
        for name in PREFERRED_MODELS:
            if name in supported:
                chosen = name
                break
        if not chosen:
            raise RuntimeError(
                "No preferred embedding models are supported by fastembed. "
                "Please check available models via TextEmbedding.list_supported_models()."
            )
        print(f"Loading Knowledge Base Embedding Model: {chosen} (may take some time on first run)...")
        try:
            self.model = TextEmbedding(model_name=chosen)
            print("Embedding Model loaded successfully.")
        except Exception as e:
            print(f"Error loading embedding model: {e}")
            raise

        # Store the chosen model name for reference
        global MODEL_NAME
        MODEL_NAME = chosen

        # Cache the embedding dimension (detects library/model changes that corrupt existing indexes)
        self.embed_dim = self._get_embedding_dim()
        self.chroma_version = getattr(chromadb, "__version__", "unknown")
        # If the stored index was built with a different model/dimension/chromadb version, wipe it
        self._maybe_reset_for_incompatibility(chosen, self.embed_dim, self.chroma_version)

        # Initialize the vector DB, verify it is readable, and run an initial sync
        self._init_collection()
        self._healthcheck()
        self.sync_knowledge()

        # Start watching the knowledge folder for changes
        self.start_watcher()

    def _init_collection(self, recreate: bool = False):
        """(Re)initialize the Chroma client/collection. If recreate=True, wipe the on-disk index."""
        if recreate and os.path.exists(self.db_path):
            shutil.rmtree(self.db_path, ignore_errors=True)
        try:
            self.client = chromadb.PersistentClient(path=self.db_path)
            self.collection = self.client.get_or_create_collection(
                name=COLLECTION_NAME,
                metadata={"hnsw:space": "cosine"}
            )
        except Exception:
            # If loading the collection itself fails, wipe and retry once to clear corrupted segments
            if not recreate:
                shutil.rmtree(self.db_path, ignore_errors=True)
                return self._init_collection(recreate=True)
            raise

        # Persist metadata about the embedding model used to build this index
        try:
            os.makedirs(self.db_path, exist_ok=True)
            with open(self.meta_path, "w", encoding="utf-8") as f:
                json.dump({
                    "model": MODEL_NAME,
                    "embed_dim": self.embed_dim,
                    "chroma_version": self.chroma_version,
                }, f)
        except Exception:
            pass  # Metadata failure should not block runtime

    def _healthcheck(self):
        """Validate index readability right after startup; rebuild if corrupted."""
        try:
            _ = self.collection.count()
        except Exception as e:
            msg = str(e).lower()
            if any(x in msg for x in ["hnsw", "segment", "compaction", "backfill"]):
                print("Detected index corruption on startup. Rebuilding vector_db...")
                shutil.rmtree(self.db_path, ignore_errors=True)
                self._init_collection(recreate=True)
                self.sync_knowledge(allow_reset=False)
            else:
                print(f"Index healthcheck encountered an unexpected error: {e}")

    def _maybe_reset_for_incompatibility(self, chosen_model: str, embed_dim: int, chroma_version: str):
        """If the existing index metadata differs (model/dimension/chromadb), wipe it."""
        if not os.path.exists(self.db_path):
            return
        try:
            with open(self.meta_path, "r", encoding="utf-8") as f:
                meta = json.load(f)
            prev_model = meta.get("model")
            prev_dim = meta.get("embed_dim")
            prev_chroma = meta.get("chroma_version")
            if prev_model != chosen_model or prev_dim != embed_dim or prev_chroma != chroma_version:
                shutil.rmtree(self.db_path, ignore_errors=True)
        except Exception:
            # If the metadata cannot be read, assume it is stale/corrupted and rebuild
            shutil.rmtree(self.db_path, ignore_errors=True)

    def _get_embedding_dim(self) -> int:
        # Embed a throwaway string and measure the first vector that comes back
        for vec in self.model.embed(["dimension_probe"]):
            try:
                return len(vec)
            except Exception:
                return len(list(vec))
        raise RuntimeError("Failed to determine embedding dimension")

    def sync_knowledge(self, allow_reset: bool = True):
        """Scans the knowledge folder and updates the vector database."""
        if not self._sync_lock.acquire(blocking=False):
            print("Sync already running, skipping this trigger.")
            return
        try:
            print("Syncing knowledge base...")
            manifest = self._load_manifest()
            updated_manifest = {}
            supported_extensions = (".txt", ".md", ".pdf", ".docx", ".json")
            current_files = []
            for filename in os.listdir(self.kb_path):
                file_path = os.path.join(self.kb_path, filename)
                if os.path.isfile(file_path) and filename.lower().endswith(supported_extensions):
                    current_files.append(filename)
                    mtime = os.path.getmtime(file_path)
                    size = os.path.getsize(file_path)
                    prev_meta = manifest.get(filename)
                    # Skip unchanged files
                    if prev_meta and prev_meta.get("mtime") == mtime and prev_meta.get("size") == size:
                        updated_manifest[filename] = prev_meta
                        continue
                    try:
                        content = self._extract_text(file_path)
                        if content:
                            # Sliding-window chunking on the original text
                            chunk_size = 800
                            overlap = 80
                            original_chunks = []
                            for i in range(0, len(content), chunk_size - overlap):
                                chunk = content[i:i + chunk_size].strip()
                                if chunk:
                                    original_chunks.append(chunk)
                            if original_chunks:
                                # Normalize for embedding generation only (not for storage)
                                normalized_chunks = [c.lower().replace('_', ' ') for c in original_chunks]
                                ids = [f"{filename}_{i}" for i in range(len(original_chunks))]
                                metadatas = [{"source": filename, "chunk": i} for i in range(len(original_chunks))]
                                # Compute embeddings from the normalized text
                                embeddings = []
                                for v in self.model.embed(normalized_chunks):
                                    try:
                                        embeddings.append(v.tolist())
                                    except Exception:
                                        embeddings.append(list(v))
                                # Store the ORIGINAL text (not normalized) so users see the real content
                                for start in range(0, len(original_chunks), BATCH_SIZE):
                                    end = start + BATCH_SIZE
                                    self.collection.upsert(
                                        documents=original_chunks[start:end],
                                        ids=ids[start:end],
                                        metadatas=metadatas[start:end],
                                        embeddings=embeddings[start:end]
                                    )
                                print(f"  ✓ Indexed {filename}: {len(original_chunks)} chunks (batched)")
                        updated_manifest[filename] = {"mtime": mtime, "size": size}
                    except Exception as e:
                        err_msg = str(e)
                        print(f"Error processing {filename}: {err_msg}")
                        # Auto-recover if HNSW/compaction/index errors occur; the RLock
                        # allows this same-thread re-entrant retry
                        if allow_reset and any(x in err_msg.lower() for x in ["hnsw", "compaction", "segment reader"]):
                            if not self._collection_reset_guard:
                                print("Detected index corruption. Rebuilding vector_db and retrying sync once...")
                                self._collection_reset_guard = True
                                self._init_collection(recreate=True)
                                return self.sync_knowledge(allow_reset=False)

            # Remove deleted files from the index
            deleted_files = set(manifest.keys()) - set(current_files)
            for filename in deleted_files:
                try:
                    self.collection.delete(where={"source": filename})
                    print(f"  ✓ Removed deleted file from index: {filename}")
                except Exception as e:
                    print(f"  ! Failed to remove {filename}: {e}")

            # Persist the manifest
            self._save_manifest(updated_manifest)
            print("Knowledge base sync complete.")
        finally:
            self._sync_lock.release()

    def _extract_text(self, file_path):
        ext = os.path.splitext(file_path)[1].lower()
        if ext == ".txt":
            with open(file_path, 'r', encoding='utf-8') as f:
                return f.read()
        elif ext == ".md":
            # Treat Markdown as plain text for retrieval
            with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                return f.read()
        elif ext == ".pdf":
            if PdfReader:
                reader = PdfReader(file_path)
                text = ""
                for page in reader.pages:
                    # extract_text() can return None for image-only pages
                    text += (page.extract_text() or "") + "\n"
                return text
            else:
                print("pypdf not installed, skipping PDF.")
        elif ext == ".docx":
            if Document:
                doc = Document(file_path)
                return "\n".join(para.text for para in doc.paragraphs)
            else:
                print("python-docx not installed, skipping Word.")
        elif ext == ".json":
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            return json.dumps(data, ensure_ascii=False, indent=2)
        return None

    def query(self, text, top_k=5, distance_threshold=0.8, allow_reset: bool = True):
        """Retrieves relevant snippets from the knowledge base.

        Uses cosine distance (lower is better). A result is treated as a hit only
        when best_distance <= distance_threshold.

        Returns:
            dict: {"hit": bool, "context": str, "hits": [{source, chunk, distance, text}, ...]}
        """
        try:
            # Normalize the query the same way as the indexed content
            normalized_text = text.lower().replace('_', ' ')
            q_vec = None
            for v in self.model.embed([normalized_text]):
                try:
                    q_vec = v.tolist()
                except Exception:
                    q_vec = list(v)
                break
            if q_vec is None:
                return {"hit": False, "context": "", "hits": []}

            results = self.collection.query(
                query_embeddings=[q_vec],
                n_results=top_k,
                include=["documents", "metadatas", "distances"]
            )
            docs = (results or {}).get("documents") or []
            metas = (results or {}).get("metadatas") or []
            dists = (results or {}).get("distances") or []
            if not docs or not docs[0]:
                print("[KB Query] No results returned from collection")
                return {"hit": False, "context": "", "hits": []}

            docs0 = docs[0]
            metas0 = metas[0] if metas and metas[0] else [{} for _ in docs0]
            dists0 = dists[0] if dists and dists[0] else [None for _ in docs0]
            hits = []
            for doc_text, meta, dist in zip(docs0, metas0, dists0):
                hits.append({
                    "source": (meta or {}).get("source", ""),
                    "chunk": (meta or {}).get("chunk", None),
                    "distance": dist,
                    "text": doc_text,
                })

            best = hits[0].get("distance")
            is_hit = (best is not None) and (best <= distance_threshold)

            # Debug log
            best_str = f"{best:.4f}" if best is not None else "N/A"
            print(f"[KB Query] '{text[:50]}...' -> best_dist={best_str}, threshold={distance_threshold}, hit={is_hit}")
            if hits:
                top3_dists = [f"{h['distance']:.4f}" if h['distance'] is not None else "N/A" for h in hits[:3]]
                print(f"[KB Query] Top 3 distances: {top3_dists}")

            context = "\n---\n".join(h["text"] for h in hits) if is_hit else ""
            return {"hit": is_hit, "context": context, "hits": hits}
        except Exception as e:
            err_msg = str(e)
            print(f"Query error: {err_msg}")
            import traceback
            traceback.print_exc()
            # Auto-recover if HNSW/compaction/backfill errors surface during a query
            if allow_reset and any(x in err_msg.lower() for x in ["hnsw", "compaction", "segment reader", "backfill"]):
                if not self._query_reset_guard:
                    print("Detected index corruption during query. Rebuilding vector_db and retrying once...")
                    self._query_reset_guard = True
                    try:
                        self._init_collection(recreate=True)
                        self.sync_knowledge(allow_reset=False)
                        # Retry the query once with the guard cleared to avoid loops
                        self._query_reset_guard = False
                        return self.query(text, top_k=top_k, distance_threshold=distance_threshold, allow_reset=False)
                    except Exception as inner_e:
                        print(f"Auto-rebuild after query failure also failed: {inner_e}")
                        self._query_reset_guard = False
            return {"hit": False, "context": "", "hits": []}

    def start_watcher(self):
        event_handler = KBHandler(self)
        self.observer = Observer()
        self.observer.schedule(event_handler, self.kb_path, recursive=False)
        self.observer.start()

    def _load_manifest(self):
        if not os.path.exists(self.manifest_path):
            return {}
        try:
            with open(self.manifest_path, "r", encoding="utf-8") as f:
                return json.load(f)
        except Exception:
            return {}

    def _save_manifest(self, data):
        try:
            os.makedirs(self.db_path, exist_ok=True)
            with open(self.manifest_path, "w", encoding="utf-8") as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
        except Exception as e:
            print(f"  ! Failed to save manifest: {e}")


class KBHandler(FileSystemEventHandler):
    def __init__(self, kb_instance):
        self.kb = kb_instance
        self.supported_extensions = (".txt", ".md", ".pdf", ".docx", ".json")
        self._debounce_timer = None

    def _trigger_sync(self):
        # Debounce: restart the timer on every event so a burst of writes
        # (e.g. an editor saving temp files) triggers only one sync
        if self._debounce_timer and self._debounce_timer.is_alive():
            self._debounce_timer.cancel()
        self._debounce_timer = threading.Timer(0.5, self.kb.sync_knowledge)
        self._debounce_timer.start()

    def on_modified(self, event):
        if not event.is_directory and event.src_path.lower().endswith(self.supported_extensions):
            print(f"File modified: {event.src_path}. Re-syncing...")
            self._trigger_sync()

    def on_created(self, event):
        if not event.is_directory and event.src_path.lower().endswith(self.supported_extensions):
            print(f"File created: {event.src_path}. Syncing...")
            self._trigger_sync()
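
# Minimal manual smoke test; a sketch, not part of the module's API.
# Assumptions (not in the original source): a local "knowledge" folder with a
# few supported files, and network access for the first-run model download.
if __name__ == "__main__":
    kb = KnowledgeBase(kb_path="knowledge", db_path="vector_db")
    result = kb.query("what does this knowledge base contain?", top_k=3)
    print("hit:", result["hit"])
    for h in result["hits"]:
        dist = f"{h['distance']:.4f}" if h["distance"] is not None else "N/A"
        print(f"  [{dist}] {h['source']}#{h['chunk']}: {h['text'][:80]!r}")
    # Stop the watchdog observer so the process can exit cleanly
    kb.observer.stop()
    kb.observer.join()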