Files
erp-ass/backend/app/core/rag_engine.py
dazhuang 04f7d372ea fix: add HuggingFace mirror support for China network
- Add HF_ENDPOINT configuration option
- Set HuggingFace mirror to https://hf-mirror.com
- Fix 'cannot connect to huggingface.co' error
- Update .env.example with HF_ENDPOINT setting

Problem:
- RAG engine uses sentence-transformers model
- Model download requires connection to huggingface.co
- Networks in mainland China cannot reliably reach huggingface.co
- Error: 'We couldn't connect to https://huggingface.co'

Solution:
- Add HF_ENDPOINT environment variable support
- Use hf-mirror.com as HuggingFace mirror
- Set mirror before loading sentence-transformers
- Document in .env.example

Files:
- backend/app/config.py: add HF_ENDPOINT config
- backend/app/core/rag_engine.py: set HF_ENDPOINT before model load
- backend/setup_hf_mirror.sh: setup script
- backend/.env: configure mirror (not tracked)
2026-03-22 03:33:17 +00:00

258 lines
8.5 KiB
Python

"""RAG Engine for ERP AI Assistant.
This module provides the RAGEngine class that handles knowledge document
storage and retrieval using ChromaDB and sentence-transformers embeddings.
"""
import os
from typing import Optional
import chromadb
from chromadb.config import Settings as ChromaSettings
from sentence_transformers import SentenceTransformer
from loguru import logger
from app.config import get_settings
class RAGEngine:
    """RAG Engine for knowledge document retrieval.

    This class wraps ChromaDB vector database and sentence-transformers
    to provide semantic search over knowledge documents.
    """
    # Class-level singleton for the embedding model, shared by all RAGEngine
    # instances in this process. Loaded lazily on first use (see
    # _get_embedding_model) because model loading is slow and may download
    # weights from HuggingFace.
    _embedding_model: Optional[SentenceTransformer] = None
def __init__(self) -> None:
"""Initialize RAG engine with ChromaDB and embedding model."""
settings = get_settings()
# Set HuggingFace mirror if configured
if settings.HF_ENDPOINT:
os.environ['HF_ENDPOINT'] = settings.HF_ENDPOINT
logger.info(f"Using HuggingFace mirror: {settings.HF_ENDPOINT}")
# Initialize ChromaDB persistent client
logger.info(f"Initializing ChromaDB at: {settings.CHROMA_DB_PATH}")
self.chroma_client = chromadb.PersistentClient(
path=settings.CHROMA_DB_PATH,
settings=ChromaSettings(anonymized_telemetry=False)
)
# Load sentence-transformers embedding model (lazy loading, singleton)
logger.info(f"Loading embedding model: {settings.EMBEDDING_MODEL}")
self.embedding_model = self._get_embedding_model(settings.EMBEDDING_MODEL)
# Get or create documents collection
self.documents_collection = self.chroma_client.get_or_create_collection(
name="documents"
)
# Store chunking settings
self.chunk_size = settings.CHUNK_SIZE
self.chunk_overlap = settings.CHUNK_OVERLAP
logger.info(
f"RAG Engine initialized: chunk_size={self.chunk_size}, "
f"chunk_overlap={self.chunk_overlap}"
)
@classmethod
def _get_embedding_model(cls, model_name: str) -> SentenceTransformer:
"""Get or create the embedding model (lazy loading, singleton).
Args:
model_name: Name of the embedding model to load
Returns:
SentenceTransformer embedding model instance
"""
if cls._embedding_model is None:
logger.info(f"Loading embedding model: {model_name}")
cls._embedding_model = SentenceTransformer(model_name)
return cls._embedding_model
def _split_text(self, text: str) -> list[str]:
"""Split text into overlapping chunks.
Args:
text: The text to split
Returns:
List of chunk strings
"""
if not text:
return []
chunks = []
start = 0
text_length = len(text)
while start < text_length:
end = start + self.chunk_size
chunk = text[start:end]
if chunk.strip(): # Only add non-empty chunks
chunks.append(chunk)
start += self.chunk_size - self.chunk_overlap
# Avoid infinite loop if overlap >= chunk_size
if self.chunk_overlap >= self.chunk_size:
start += 1
return chunks
def _delete_chunks_for_doc(self, doc_id: str) -> None:
"""Delete all chunks associated with a document.
Args:
doc_id: The document ID to delete chunks for
"""
try:
# Find all chunks for this document
results = self.documents_collection.get(
where={"doc_id": doc_id},
include=[]
)
if results and results.get("ids"):
self.documents_collection.delete(ids=results["ids"])
logger.debug(f"Deleted {len(results['ids'])} chunks for document '{doc_id}'")
except Exception as e:
logger.warning(f"Failed to delete chunks for document '{doc_id}': {e}")
def add_document(
self,
doc_id: str,
content: str,
metadata: Optional[dict] = None
) -> int:
"""Add a document to the knowledge base.
Args:
doc_id: Unique identifier for the document
content: The document content to index
metadata: Optional metadata dict to store with the document
Returns:
Number of chunks added
Raises:
ValueError: If content is empty
"""
if not content or not content.strip():
raise ValueError("Cannot add empty document")
try:
# Delete existing chunks for this doc_id (handles duplicates)
self._delete_chunks_for_doc(doc_id)
# Split content into chunks
chunks = self._split_text(content)
logger.info(f"Split document '{doc_id}' into {len(chunks)} chunks")
if not chunks:
return 0
# Generate embeddings for all chunks
logger.debug(f"Generating embeddings for {len(chunks)} chunks")
embeddings = self.embedding_model.encode(chunks)
# Prepare chunk IDs and metadata
chunk_ids = [f"{doc_id}_chunk_{i}" for i in range(len(chunks))]
# Add metadata to each chunk
chunk_metadata = []
base_metadata = metadata or {}
for i, chunk in enumerate(chunks):
meta = {
**base_metadata,
"doc_id": doc_id,
"chunk_index": i,
"total_chunks": len(chunks)
}
chunk_metadata.append(meta)
# Add to ChromaDB
self.documents_collection.add(
ids=chunk_ids,
embeddings=embeddings.tolist(),
documents=chunks,
metadatas=chunk_metadata
)
logger.info(f"Added {len(chunks)} chunks for document '{doc_id}'")
return len(chunks)
except Exception as e:
logger.error(f"Failed to add document '{doc_id}': {e}")
raise
def search(self, query: str, top_k: int = 3) -> list[dict]:
"""Search for relevant document chunks.
Args:
query: The search query
top_k: Number of results to return (default: 3, max: 100)
Returns:
List of dicts with 'content', 'metadata', and 'distance'
Raises:
ValueError: If top_k exceeds maximum limit
"""
# Validate top_k
if top_k > 100:
raise ValueError(f"top_k cannot exceed 100 (got: {top_k})")
if not query or not query.strip():
logger.warning("Empty search query provided")
return []
try:
# Generate embedding for query
logger.debug(f"Generating embedding for query: {query[:50]}...")
query_embedding = self.embedding_model.encode([query])
# Query ChromaDB
results = self.documents_collection.query(
query_embeddings=query_embedding.tolist(),
n_results=top_k,
include=["documents", "metadatas", "distances"]
)
# Format results
formatted_results = []
if results and results.get("documents"):
documents = results["documents"][0]
metadatas = results["metadatas"][0] if results.get("metadatas") else []
distances = results["distances"][0] if results.get("distances") else []
for i, content in enumerate(documents):
formatted_results.append({
"content": content,
"metadata": metadatas[i] if i < len(metadatas) else {},
"distance": distances[i] if i < len(distances) else None
})
logger.info(f"Found {len(formatted_results)} results for query")
return formatted_results
except Exception as e:
logger.error(f"Search failed: {e}")
raise
def close(self) -> None:
"""Release resources and cleanup the RAG engine.
This method should be called when the engine is no longer needed
to free up memory and other resources.
"""
logger.info("Closing RAG engine and releasing resources")
self.embedding_model = None
self.documents_collection = None
self.chroma_client = None