- Add HF_ENDPOINT configuration option - Set HuggingFace mirror to https://hf-mirror.com - Fix 'cannot connect to huggingface.co' error - Update .env.example with HF_ENDPOINT setting Problem: - RAG engine uses sentence-transformers model - Model download requires connection to huggingface.co - China network cannot access huggingface.co - Error: 'We couldn't connect to https://huggingface.co' Solution: - Add HF_ENDPOINT environment variable support - Use hf-mirror.com as HuggingFace mirror - Set mirror before loading sentence-transformers - Document in .env.example Files: - backend/app/config.py: add HF_ENDPOINT config - backend/app/core/rag_engine.py: set HF_ENDPOINT before model load - backend/setup_hf_mirror.sh: setup script - backend/.env: configure mirror (not tracked)
258 lines
8.5 KiB
Python
258 lines
8.5 KiB
Python
"""RAG Engine for ERP AI Assistant.
|
|
|
|
This module provides the RAGEngine class that handles knowledge document
|
|
storage and retrieval using ChromaDB and sentence-transformers embeddings.
|
|
"""
|
|
|
|
import os
|
|
from typing import Optional
|
|
|
|
import chromadb
|
|
from chromadb.config import Settings as ChromaSettings
|
|
from sentence_transformers import SentenceTransformer
|
|
from loguru import logger
|
|
|
|
from app.config import get_settings
|
|
|
|
|
|
class RAGEngine:
    """RAG Engine for knowledge document retrieval.

    This class wraps the ChromaDB vector database and a sentence-transformers
    embedding model to provide semantic search over knowledge documents.
    Documents are split into overlapping chunks, embedded, and stored in a
    persistent ChromaDB collection named "documents".
    """

    # Class-level singleton for the embedding model (lazy loading): all
    # RAGEngine instances share one loaded model to avoid re-downloading /
    # re-initializing the weights per instance.
    _embedding_model: Optional[SentenceTransformer] = None

    def __init__(self) -> None:
        """Initialize RAG engine with ChromaDB and embedding model.

        Reads configuration from app settings: HF_ENDPOINT (optional mirror),
        CHROMA_DB_PATH, EMBEDDING_MODEL, CHUNK_SIZE, CHUNK_OVERLAP.
        """
        settings = get_settings()

        # Set HuggingFace mirror if configured (works around blocked access
        # to huggingface.co when downloading the embedding model).
        # NOTE(review): sentence_transformers / huggingface_hub are imported
        # at module load time, and some huggingface_hub versions read
        # HF_ENDPOINT at import time — setting it here may be too late on
        # those versions. Verify, or export HF_ENDPOINT before process start
        # (see setup_hf_mirror.sh).
        if settings.HF_ENDPOINT:
            os.environ['HF_ENDPOINT'] = settings.HF_ENDPOINT
            logger.info(f"Using HuggingFace mirror: {settings.HF_ENDPOINT}")

        # Initialize ChromaDB persistent client (telemetry disabled).
        logger.info(f"Initializing ChromaDB at: {settings.CHROMA_DB_PATH}")
        self.chroma_client = chromadb.PersistentClient(
            path=settings.CHROMA_DB_PATH,
            settings=ChromaSettings(anonymized_telemetry=False)
        )

        # Load sentence-transformers embedding model (lazy-loaded singleton).
        logger.info(f"Loading embedding model: {settings.EMBEDDING_MODEL}")
        self.embedding_model = self._get_embedding_model(settings.EMBEDDING_MODEL)

        # Get or create the documents collection.
        self.documents_collection = self.chroma_client.get_or_create_collection(
            name="documents"
        )

        # Store chunking settings used by _split_text.
        self.chunk_size = settings.CHUNK_SIZE
        self.chunk_overlap = settings.CHUNK_OVERLAP

        logger.info(
            f"RAG Engine initialized: chunk_size={self.chunk_size}, "
            f"chunk_overlap={self.chunk_overlap}"
        )

    @classmethod
    def _get_embedding_model(cls, model_name: str) -> SentenceTransformer:
        """Get or create the embedding model (lazy loading, singleton).

        Args:
            model_name: Name of the embedding model to load

        Returns:
            SentenceTransformer embedding model instance
        """
        if cls._embedding_model is None:
            logger.info(f"Loading embedding model: {model_name}")
            cls._embedding_model = SentenceTransformer(model_name)
        return cls._embedding_model

    def _split_text(self, text: str) -> list[str]:
        """Split text into overlapping chunks.

        Chunks are self.chunk_size characters long and consecutive chunks
        overlap by self.chunk_overlap characters. Whitespace-only chunks
        are skipped.

        Args:
            text: The text to split

        Returns:
            List of chunk strings
        """
        if not text:
            return []

        # Clamp the step to at least 1: if chunk_overlap >= chunk_size the
        # raw step (chunk_size - chunk_overlap) is zero or negative, which
        # previously drove the cursor backwards and caused an infinite loop
        # whenever chunk_overlap > chunk_size.
        step = max(1, self.chunk_size - self.chunk_overlap)

        chunks: list[str] = []
        text_length = len(text)
        start = 0
        while start < text_length:
            chunk = text[start:start + self.chunk_size]
            if chunk.strip():  # Only add non-empty chunks
                chunks.append(chunk)
            start += step

        return chunks

    def _delete_chunks_for_doc(self, doc_id: str) -> None:
        """Delete all chunks associated with a document.

        Best-effort: failures are logged as warnings, not raised, so that
        re-indexing can proceed even if the lookup/delete fails.

        Args:
            doc_id: The document ID to delete chunks for
        """
        try:
            # Find all chunk IDs for this document (include=[] skips payloads).
            results = self.documents_collection.get(
                where={"doc_id": doc_id},
                include=[]
            )
            if results and results.get("ids"):
                self.documents_collection.delete(ids=results["ids"])
                logger.debug(f"Deleted {len(results['ids'])} chunks for document '{doc_id}'")
        except Exception as e:
            logger.warning(f"Failed to delete chunks for document '{doc_id}': {e}")

    def add_document(
        self,
        doc_id: str,
        content: str,
        metadata: Optional[dict] = None
    ) -> int:
        """Add a document to the knowledge base.

        Existing chunks for the same doc_id are deleted first, so re-adding
        a document replaces its previous content.

        Args:
            doc_id: Unique identifier for the document
            content: The document content to index
            metadata: Optional metadata dict to store with the document

        Returns:
            Number of chunks added

        Raises:
            ValueError: If content is empty
        """
        if not content or not content.strip():
            raise ValueError("Cannot add empty document")

        try:
            # Delete existing chunks for this doc_id (handles duplicates).
            self._delete_chunks_for_doc(doc_id)

            # Split content into chunks.
            chunks = self._split_text(content)
            logger.info(f"Split document '{doc_id}' into {len(chunks)} chunks")

            if not chunks:
                return 0

            # Generate embeddings for all chunks in one batch call.
            logger.debug(f"Generating embeddings for {len(chunks)} chunks")
            embeddings = self.embedding_model.encode(chunks)

            # Deterministic chunk IDs derived from doc_id and position.
            chunk_ids = [f"{doc_id}_chunk_{i}" for i in range(len(chunks))]

            # Attach per-chunk metadata; doc_id enables later lookup/delete.
            base_metadata = metadata or {}
            chunk_metadata = [
                {
                    **base_metadata,
                    "doc_id": doc_id,
                    "chunk_index": i,
                    "total_chunks": len(chunks)
                }
                for i in range(len(chunks))
            ]

            # Add to ChromaDB.
            self.documents_collection.add(
                ids=chunk_ids,
                embeddings=embeddings.tolist(),
                documents=chunks,
                metadatas=chunk_metadata
            )

            logger.info(f"Added {len(chunks)} chunks for document '{doc_id}'")
            return len(chunks)

        except Exception as e:
            logger.error(f"Failed to add document '{doc_id}': {e}")
            raise

    def search(self, query: str, top_k: int = 3) -> list[dict]:
        """Search for relevant document chunks.

        Args:
            query: The search query
            top_k: Number of results to return (default: 3, max: 100)

        Returns:
            List of dicts with 'content', 'metadata', and 'distance'
            (distance semantics depend on the collection's metric)

        Raises:
            ValueError: If top_k exceeds maximum limit
        """
        # Validate top_k before doing any embedding work.
        if top_k > 100:
            raise ValueError(f"top_k cannot exceed 100 (got: {top_k})")

        if not query or not query.strip():
            logger.warning("Empty search query provided")
            return []

        try:
            # Generate embedding for the query.
            logger.debug(f"Generating embedding for query: {query[:50]}...")
            query_embedding = self.embedding_model.encode([query])

            # Query ChromaDB.
            results = self.documents_collection.query(
                query_embeddings=query_embedding.tolist(),
                n_results=top_k,
                include=["documents", "metadatas", "distances"]
            )

            # Format results; Chroma returns per-query lists, we sent one
            # query so index [0] holds its results.
            formatted_results = []

            if results and results.get("documents"):
                documents = results["documents"][0]
                metadatas = results["metadatas"][0] if results.get("metadatas") else []
                distances = results["distances"][0] if results.get("distances") else []

                for i, content in enumerate(documents):
                    formatted_results.append({
                        "content": content,
                        "metadata": metadatas[i] if i < len(metadatas) else {},
                        "distance": distances[i] if i < len(distances) else None
                    })

            logger.info(f"Found {len(formatted_results)} results for query")
            return formatted_results

        except Exception as e:
            logger.error(f"Search failed: {e}")
            raise

    def close(self) -> None:
        """Release resources and cleanup the RAG engine.

        This method should be called when the engine is no longer needed
        to free up memory and other resources.

        NOTE: the class-level _embedding_model singleton is intentionally
        left in place so other instances can keep using the loaded model;
        only this instance's references are dropped.
        """
        logger.info("Closing RAG engine and releasing resources")
        self.embedding_model = None
        self.documents_collection = None
        self.chroma_client = None