"""RAG Engine for ERP AI Assistant. This module provides the RAGEngine class that handles knowledge document storage and retrieval using ChromaDB and sentence-transformers embeddings. """ import os from typing import Optional import chromadb from chromadb.config import Settings as ChromaSettings from sentence_transformers import SentenceTransformer from loguru import logger from app.config import get_settings class RAGEngine: """RAG Engine for knowledge document retrieval. This class wraps ChromaDB vector database and sentence-transformers to provide semantic search over knowledge documents. """ # Class-level singleton for embedding model (lazy loading) _embedding_model: Optional[SentenceTransformer] = None def __init__(self) -> None: """Initialize RAG engine with ChromaDB and embedding model.""" settings = get_settings() # Set HuggingFace mirror if configured if settings.HF_ENDPOINT: os.environ['HF_ENDPOINT'] = settings.HF_ENDPOINT logger.info(f"Using HuggingFace mirror: {settings.HF_ENDPOINT}") # Initialize ChromaDB persistent client logger.info(f"Initializing ChromaDB at: {settings.CHROMA_DB_PATH}") self.chroma_client = chromadb.PersistentClient( path=settings.CHROMA_DB_PATH, settings=ChromaSettings(anonymized_telemetry=False) ) # Load sentence-transformers embedding model (lazy loading, singleton) logger.info(f"Loading embedding model: {settings.EMBEDDING_MODEL}") self.embedding_model = self._get_embedding_model(settings.EMBEDDING_MODEL) # Get or create documents collection self.documents_collection = self.chroma_client.get_or_create_collection( name="documents" ) # Store chunking settings self.chunk_size = settings.CHUNK_SIZE self.chunk_overlap = settings.CHUNK_OVERLAP logger.info( f"RAG Engine initialized: chunk_size={self.chunk_size}, " f"chunk_overlap={self.chunk_overlap}" ) @classmethod def _get_embedding_model(cls, model_name: str) -> SentenceTransformer: """Get or create the embedding model (lazy loading, singleton). Args: model_name: Name of the embedding model to load Returns: SentenceTransformer embedding model instance """ if cls._embedding_model is None: logger.info(f"Loading embedding model: {model_name}") cls._embedding_model = SentenceTransformer(model_name) return cls._embedding_model def _split_text(self, text: str) -> list[str]: """Split text into overlapping chunks. Args: text: The text to split Returns: List of chunk strings """ if not text: return [] chunks = [] start = 0 text_length = len(text) while start < text_length: end = start + self.chunk_size chunk = text[start:end] if chunk.strip(): # Only add non-empty chunks chunks.append(chunk) start += self.chunk_size - self.chunk_overlap # Avoid infinite loop if overlap >= chunk_size if self.chunk_overlap >= self.chunk_size: start += 1 return chunks def _delete_chunks_for_doc(self, doc_id: str) -> None: """Delete all chunks associated with a document. Args: doc_id: The document ID to delete chunks for """ try: # Find all chunks for this document results = self.documents_collection.get( where={"doc_id": doc_id}, include=[] ) if results and results.get("ids"): self.documents_collection.delete(ids=results["ids"]) logger.debug(f"Deleted {len(results['ids'])} chunks for document '{doc_id}'") except Exception as e: logger.warning(f"Failed to delete chunks for document '{doc_id}': {e}") def add_document( self, doc_id: str, content: str, metadata: Optional[dict] = None ) -> int: """Add a document to the knowledge base. Args: doc_id: Unique identifier for the document content: The document content to index metadata: Optional metadata dict to store with the document Returns: Number of chunks added Raises: ValueError: If content is empty """ if not content or not content.strip(): raise ValueError("Cannot add empty document") try: # Delete existing chunks for this doc_id (handles duplicates) self._delete_chunks_for_doc(doc_id) # Split content into chunks chunks = self._split_text(content) logger.info(f"Split document '{doc_id}' into {len(chunks)} chunks") if not chunks: return 0 # Generate embeddings for all chunks logger.debug(f"Generating embeddings for {len(chunks)} chunks") embeddings = self.embedding_model.encode(chunks) # Prepare chunk IDs and metadata chunk_ids = [f"{doc_id}_chunk_{i}" for i in range(len(chunks))] # Add metadata to each chunk chunk_metadata = [] base_metadata = metadata or {} for i, chunk in enumerate(chunks): meta = { **base_metadata, "doc_id": doc_id, "chunk_index": i, "total_chunks": len(chunks) } chunk_metadata.append(meta) # Add to ChromaDB self.documents_collection.add( ids=chunk_ids, embeddings=embeddings.tolist(), documents=chunks, metadatas=chunk_metadata ) logger.info(f"Added {len(chunks)} chunks for document '{doc_id}'") return len(chunks) except Exception as e: logger.error(f"Failed to add document '{doc_id}': {e}") raise def search(self, query: str, top_k: int = 3) -> list[dict]: """Search for relevant document chunks. Args: query: The search query top_k: Number of results to return (default: 3, max: 100) Returns: List of dicts with 'content', 'metadata', and 'distance' Raises: ValueError: If top_k exceeds maximum limit """ # Validate top_k if top_k > 100: raise ValueError(f"top_k cannot exceed 100 (got: {top_k})") if not query or not query.strip(): logger.warning("Empty search query provided") return [] try: # Generate embedding for query logger.debug(f"Generating embedding for query: {query[:50]}...") query_embedding = self.embedding_model.encode([query]) # Query ChromaDB results = self.documents_collection.query( query_embeddings=query_embedding.tolist(), n_results=top_k, include=["documents", "metadatas", "distances"] ) # Format results formatted_results = [] if results and results.get("documents"): documents = results["documents"][0] metadatas = results["metadatas"][0] if results.get("metadatas") else [] distances = results["distances"][0] if results.get("distances") else [] for i, content in enumerate(documents): formatted_results.append({ "content": content, "metadata": metadatas[i] if i < len(metadatas) else {}, "distance": distances[i] if i < len(distances) else None }) logger.info(f"Found {len(formatted_results)} results for query") return formatted_results except Exception as e: logger.error(f"Search failed: {e}") raise def close(self) -> None: """Release resources and cleanup the RAG engine. This method should be called when the engine is no longer needed to free up memory and other resources. """ logger.info("Closing RAG engine and releasing resources") self.embedding_model = None self.documents_collection = None self.chroma_client = None