"""Config Executor for ERP AI Assistant.

This module provides the ConfigExecutor class for validating and executing
SQL configuration statements with safety checks.
"""

import re
from typing import List, Tuple, Dict, Any

from loguru import logger

from app.core.db_engine import DatabaseEngine


class _SQLValidationError(ValueError):
    """Raised internally when a statement in a batch fails safety validation.

    Using a dedicated subclass keeps validation failures distinguishable from
    a ``ValueError`` raised by the database engine during execution.
    """


class ConfigExecutor:
    """Executor for SQL configuration statements with safety validation.

    This class validates SQL statements against dangerous operations before
    execution and runs them through the database engine's transaction API.
    Rollback of already-committed operations is not implemented yet (see
    ``rollback``).
    """

    # Dangerous SQL keywords that should be blocked. Patterns are matched
    # against the upper-cased statement with re.DOTALL, so ".*" may span the
    # line breaks of a multi-line statement.
    DANGEROUS_KEYWORDS = [
        r"DROP\s+DATABASE",
        r"DROP\s+TABLE",
        r"TRUNCATE\s+TABLE",
        r"DELETE\s+FROM",
        r"UPDATE\s+.*\s+SET",
        r"ALTER\s+TABLE\s+.*\s+DROP",
    ]

    # SQL block comments (/* ... */) and line comments (-- ... to end of line).
    _SQL_COMMENT_RE = re.compile(r"/\*.*?\*/|--[^\n]*", re.DOTALL)

    def __init__(self) -> None:
        """Initialize executor with database engine."""
        self.db_engine = DatabaseEngine()
        logger.info("ConfigExecutor initialized")

    def validate_sql(self, sql: str) -> Tuple[bool, str]:
        """Validate SQL statement for safety.

        Checks SQL against a list of dangerous keywords/patterns to prevent
        destructive operations. Matching is case-insensitive (the statement
        is upper-cased), may span line breaks, and also sees through SQL
        comments used to separate keywords (e.g. ``DROP/**/TABLE``).

        Args:
            sql: SQL statement to validate

        Returns:
            Tuple of (is_valid, message) where is_valid indicates if SQL is
            safe and message is a user-facing explanation.
        """
        if not sql or not sql.strip():
            return False, "SQL语句为空"

        sql_upper = sql.upper().strip()
        # Scan both the raw text and a copy with comments collapsed to a single
        # space. The collapsed copy catches comment-obfuscated keywords; the raw
        # copy guards against a "--" inside a string literal hiding code that
        # follows it on the same line. Scanning both can only add rejections.
        candidates = (sql_upper, self._SQL_COMMENT_RE.sub(" ", sql_upper))

        # Check for dangerous operations
        for pattern in self.DANGEROUS_KEYWORDS:
            for text in candidates:
                match = re.search(pattern, text, re.DOTALL)
                if match is None:
                    continue
                # Collapse whitespace so a multi-line match reads on one line.
                matched_keyword = " ".join(match.group(0).split())
                logger.warning(f"SQL validation failed: dangerous operation '{matched_keyword}' detected")
                return False, f"危险操作被拦截: {matched_keyword}"

        logger.debug(f"SQL validation passed: {sql[:50]}...")
        return True, "SQL验证通过"

    @staticmethod
    def _record_failure(results: Dict[str, Any], error: Exception) -> None:
        """Mark ``results`` as failed, storing the text of ``error``."""
        results["success"] = False
        results["failed"] = str(error)
        results["message"] = f"执行失败: {error}"

    def execute_config(
        self,
        sql_list: List[str],
        session_id: str
    ) -> Dict[str, Any]:
        """Execute a list of SQL statements in a transaction.

        Validates all SQL statements before execution. If any validation
        fails, no statements are executed. Execution is delegated to
        ``DatabaseEngine.execute_transaction``; whether a partial failure is
        rolled back depends on that engine (presumably all-or-nothing --
        confirm against DatabaseEngine).

        Args:
            sql_list: List of SQL statements to execute
            session_id: Session ID for logging and tracking

        Returns:
            Dictionary containing:
                - success: Boolean indicating overall success
                - executed: List (a copy) of executed SQL statements
                - failed: Error message if execution failed, None otherwise
                - message: Human-readable result message
        """
        logger.info(f"[{session_id}] Starting config execution with {len(sql_list)} SQL statements")

        results: Dict[str, Any] = {
            "success": True,
            "executed": [],
            "failed": None,
            "message": ""
        }

        try:
            # Step 1: Validate every statement before touching the database,
            # so a single unsafe statement blocks the whole batch.
            logger.debug(f"[{session_id}] Validating {len(sql_list)} SQL statements")
            for index, sql in enumerate(sql_list, start=1):
                is_valid, msg = self.validate_sql(sql)
                if not is_valid:
                    raise _SQLValidationError(f"SQL #{index} 验证失败: {msg}")

            # Step 2: Execute transaction
            logger.debug(f"[{session_id}] Executing transaction")
            self.db_engine.execute_transaction(sql_list)

        except _SQLValidationError as e:
            # Validation failure: nothing was sent to the database, so a
            # traceback adds no information.
            self._record_failure(results, e)
            logger.error(f"[{session_id}] {results['message']}")
            return results
        except Exception as e:
            # Execution (or unexpected) failure: this is the API boundary, so
            # report it in the result dict and log the full traceback.
            self._record_failure(results, e)
            logger.exception(f"[{session_id}] {results['message']}")
            return results

        # Step 3: Record success. Copy the list so later mutation of the
        # caller's list does not alter the reported result.
        results["executed"] = list(sql_list)
        results["message"] = f"成功执行 {len(sql_list)} 条SQL"
        logger.success(f"[{session_id}] {results['message']}")
        return results

    def rollback(self, session_id: str) -> Dict[str, Any]:
        """Rollback executed operations for a session.

        This is a placeholder for rollback functionality. Actual
        implementation would require recording inverse SQL operations
        during execution.

        Args:
            session_id: Session ID to rollback

        Returns:
            Dictionary with success status (always False for now) and message
        """
        logger.warning(f"[{session_id}] Rollback requested but not yet implemented")
        return {
            "success": False,
            "message": "回滚功能待实现"
        }