feat: implement ERP AI Assistant Phase 1

Backend (FastAPI + SQLAlchemy + Claude API + RAG):
- Config management with Pydantic v2
- Database engine with connection pooling and SQL injection prevention
- AI engine with Claude API integration (support custom base URL)
- RAG engine with ChromaDB and sentence-transformers
- Requirement analysis service
- Config generation service
- Executor engine with SQL validation
- REST API endpoints: /analyze, /generate, /execute

Frontend (Vue 3 + Element Plus + Pinia):
- Complete 3-step workflow: analyze → generate → execute
- Step indicator with progress visualization
- Analysis result display with field table
- SQL preview with monospace font
- Execute confirmation dialog with safety warning
- Execution result display
- State management with Pinia
- API service integration

Security:
- SQL injection prevention with parameterized queries
- Dangerous SQL operation blocking
- Database password URL encoding
- Transaction auto-rollback
- Pydantic config validation

Features:
- Natural language requirement analysis
- Automated SQL configuration generation
- Safe execution with human review
- LAN access support
- Custom Claude API endpoint support

Documentation:
- README with quick start guide
- Quick start guide
- LAN access configuration
- Dependency fixes guide
- Claude API configuration
- Git operation guide
- Implementation report

Dependencies fixed:
- numpy<2.0.0 for chromadb compatibility
- sentence-transformers==2.7.0 for huggingface_hub compatibility

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-21 14:23:20 +00:00
commit acd73431ae
60 changed files with 11284 additions and 0 deletions

26
backend/.env.example Normal file
View File

@@ -0,0 +1,26 @@
APP_NAME=ERP AI Assistant
APP_ENV=development
DEBUG=True
SECRET_KEY=change-this-in-production
# Database
DB_DRIVER=ODBC Driver 17 for SQL Server
DB_SERVER=192.168.120.19
DB_PORT=1433
DB_NAME=DMPF_HY
DB_USER=sa
DB_PASSWORD=your-password
# Claude API
ANTHROPIC_API_KEY=your-claude-api-key
# ANTHROPIC_BASE_URL=https://api.anthropic.com # Optional: uncomment to use custom base URL (for proxy or self-hosted)
CLAUDE_MODEL=claude-sonnet-4-6
CLAUDE_MAX_TOKENS=8192
CLAUDE_TEMPERATURE=0.7
# Knowledge Base
KNOWLEDGE_BASE_PATH=./knowledge_base
CHROMA_DB_PATH=./knowledge_base/chroma_db
EMBEDDING_MODEL=all-MiniLM-L6-v2
CHUNK_SIZE=500
CHUNK_OVERLAP=50

2
backend/app/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
"""ERP AI Assistant Backend"""
__version__ = "1.0.0"

View File

@@ -0,0 +1 @@
"""API routes for ERP AI Assistant."""

113
backend/app/api/analyze.py Normal file
View File

@@ -0,0 +1,113 @@
"""Analyze API endpoint for requirement analysis.
This module provides the /analyze endpoint for analyzing user requirements.
"""
import uuid
from typing import Dict
from fastapi import APIRouter, HTTPException, status
from loguru import logger
from app.models.request import AnalyzeRequest
from app.models.response import AnalyzeResponse, ErrorResponse
from app.services.requirement_service import RequirementService
# Create router
router = APIRouter()
@router.post(
    "/analyze",
    response_model=AnalyzeResponse,
    responses={
        400: {"model": ErrorResponse, "description": "Invalid request"},
        500: {"model": ErrorResponse, "description": "Internal server error"},
    },
    summary="Analyze user requirement",
    description="Analyze natural language or structured requirement and return structured specification",
)
async def analyze_requirement(request: AnalyzeRequest) -> AnalyzeResponse:
    """Turn a user's requirement text into a structured specification.

    The request is checked first: ``input_type`` must be one of the two
    supported values and ``content`` must contain non-whitespace text. The
    content is then passed to ``RequirementService.analyze``. ``input_type``
    is only validated here; it is not forwarded to the service.

    Args:
        request: Carries ``input_type``, ``content`` and an optional
            ``session_id``. A fresh UUID4 is generated when no session id
            is supplied.

    Returns:
        AnalyzeResponse with the session id, ``status="success"`` and the
        value returned by the service as ``data``.

    Raises:
        HTTPException: 400 for an unsupported ``input_type``, empty content
            or a ``ValueError`` raised during analysis; 500 for any other
            failure.
    """
    session_id = request.session_id or str(uuid.uuid4())

    def _failure(http_status: int, code: str, message: str) -> HTTPException:
        """Build an HTTPException carrying the shared error payload."""
        return HTTPException(
            status_code=http_status,
            detail={"code": code, "message": message, "session_id": session_id},
        )

    try:
        if request.input_type not in ("natural_language", "structured"):
            raise _failure(
                status.HTTP_400_BAD_REQUEST,
                "INVALID_INPUT_TYPE",
                "input_type must be 'natural_language' or 'structured'",
            )

        has_text = bool(request.content) and bool(request.content.strip())
        if not has_text:
            raise _failure(
                status.HTTP_400_BAD_REQUEST,
                "EMPTY_CONTENT",
                "content cannot be empty",
            )

        logger.info(f"[{session_id}] Processing analyze request: {request.content[:50]}...")

        analysis = await RequirementService().analyze(
            user_input=request.content,
            session_id=session_id,
        )

        logger.success(f"[{session_id}] Analysis completed successfully")
        return AnalyzeResponse(session_id=session_id, status="success", data=analysis)
    except HTTPException:
        # Already shaped for the client; pass it through untouched.
        raise
    except ValueError as e:
        logger.error(f"[{session_id}] Validation error: {e}")
        raise _failure(status.HTTP_400_BAD_REQUEST, "VALIDATION_ERROR", str(e))
    except Exception as e:
        logger.error(f"[{session_id}] Analysis failed: {e}")
        raise _failure(
            status.HTTP_500_INTERNAL_SERVER_ERROR,
            "ANALYSIS_FAILED",
            f"Failed to analyze requirement: {str(e)}",
        )

151
backend/app/api/execute.py Normal file
View File

@@ -0,0 +1,151 @@
"""Execute API endpoint for SQL configuration execution.
This module provides the /execute endpoint for executing SQL configuration.
"""
import uuid
from typing import Dict, Any
from fastapi import APIRouter, HTTPException, status
from loguru import logger
from app.models.request import ExecuteRequest
from app.models.response import ExecuteResponse, ErrorResponse
from app.core.executor import ConfigExecutor
# Create router
router = APIRouter()
# Process-local map of session id -> generated SQL statements.
# Should be replaced by Redis or a database in production.
_session_sql_store: Dict[str, list] = {}


def store_session_sql(session_id: str, sql_list: list) -> None:
    """Remember the SQL statements generated for a session.

    Any list previously stored under the same session id is replaced.

    Args:
        session_id: Session ID used as the lookup key
        sql_list: SQL statements to keep for later execution
    """
    _session_sql_store.update({session_id: sql_list})
    statement_count = len(sql_list)
    logger.debug(f"Stored {statement_count} SQL statements for session {session_id}")
def get_session_sql(session_id: str) -> list:
    """Look up the SQL statements stored for a session.

    Args:
        session_id: Session ID used as the lookup key

    Returns:
        The stored list of SQL statements, or a new empty list when the
        session has nothing stored.
    """
    try:
        statements = _session_sql_store[session_id]
    except KeyError:
        statements = []
    logger.debug(f"Retrieved {len(statements)} SQL statements for session {session_id}")
    return statements
@router.post(
    "/execute",
    response_model=ExecuteResponse,
    responses={
        400: {"model": ErrorResponse, "description": "Invalid request"},
        500: {"model": ErrorResponse, "description": "Internal server error"},
    },
    summary="Execute SQL configuration",
    description="Execute SQL configuration after user confirmation",
)
async def execute_config(request: ExecuteRequest) -> ExecuteResponse:
    """Execute the SQL stored for a session once the user has confirmed.

    The statements are looked up with ``get_session_sql`` and run through
    ``ConfigExecutor.execute_config``. The blocking executor call runs in a
    worker thread so the event loop stays responsive, and the executor's
    database engine is disposed afterwards so a connection pool is not left
    behind per request. After a successful run the session's SQL is removed
    from the store, so a repeated request cannot replay the same statements.
    On failure the SQL is kept so the request can be retried.

    Args:
        request: ExecuteRequest with session_id, confirmed, and backup_enabled

    Returns:
        ExecuteResponse with execution_id, status, and message

    Raises:
        HTTPException: 400 if not confirmed or no SQL is stored for the
            session, 500 when execution fails.
    """
    # Function-scope import keeps this block self-contained.
    import asyncio

    execution_id = str(uuid.uuid4())

    def _run_blocking(statements: list) -> Dict[str, Any]:
        """Create an executor, run the statements, and release its engine."""
        executor = ConfigExecutor()
        try:
            return executor.execute_config(statements, request.session_id)
        finally:
            # ConfigExecutor builds its own DatabaseEngine; dispose it so the
            # pooled connections do not outlive this request.
            executor.db_engine.dispose()

    try:
        if not request.confirmed:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail={
                    "code": "NOT_CONFIRMED",
                    "message": "User must confirm execution by setting confirmed=True",
                    "session_id": request.session_id,
                    "execution_id": execution_id,
                },
            )

        logger.info(f"[{request.session_id}] Processing execute request")

        sql_list = get_session_sql(request.session_id)
        if not sql_list:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail={
                    "code": "NO_SQL_FOUND",
                    "message": "No SQL statements found for this session",
                    "session_id": request.session_id,
                    "execution_id": execution_id,
                },
            )

        logger.info(f"[{request.session_id}] Retrieved {len(sql_list)} SQL statements")

        result = await asyncio.to_thread(_run_blocking, sql_list)

        if not result["success"]:
            logger.error(
                f"[{request.session_id}] Execution failed: {result['failed']}"
            )
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={
                    "code": "EXECUTION_FAILED",
                    "message": result["message"],
                    "error": result["failed"],
                    "session_id": request.session_id,
                    "execution_id": execution_id,
                },
            )

        # The statements have been applied; drop them so they cannot be
        # executed a second time for the same session.
        _session_sql_store.pop(request.session_id, None)

        logger.success(
            f"[{request.session_id}] Execution completed: {result['message']}"
        )
        return ExecuteResponse(
            execution_id=execution_id,
            status="success",
            message=result["message"],
        )
    except HTTPException:
        # Already shaped for the client; pass it through untouched.
        raise
    except Exception as e:
        logger.error(f"[{request.session_id}] Execution failed: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={
                "code": "EXECUTION_ERROR",
                "message": f"Failed to execute config: {str(e)}",
                "session_id": request.session_id,
                "execution_id": execution_id,
            },
        ) from e

102
backend/app/api/generate.py Normal file
View File

@@ -0,0 +1,102 @@
"""Generate API endpoint for configuration generation.
This module provides the /generate endpoint for generating SQL configuration.
"""
from fastapi import APIRouter, HTTPException, status
from loguru import logger
from app.models.request import GenerateRequest
from app.models.response import GenerateResponse, ErrorResponse
from app.services.config_service import ConfigService
from app.api.execute import store_session_sql # Import SQL storage function
# Create router
router = APIRouter()
@router.post(
    "/generate",
    response_model=GenerateResponse,
    responses={
        400: {"model": ErrorResponse, "description": "Invalid request"},
        500: {"model": ErrorResponse, "description": "Internal server error"},
    },
    summary="Generate SQL configuration",
    description="Generate SQL configuration based on structured requirements",
)
async def generate_config(request: GenerateRequest) -> GenerateResponse:
    """Generate SQL configuration from structured requirements.

    The requirements are passed to ``ConfigService.generate``. Whatever SQL
    list the result carries under ``result["配置方案"]["sql_list"]`` is then
    stored for the session with ``store_session_sql``. The store is always
    overwritten, with an empty list when the result has no SQL, so that a
    later /execute call can never run statements left over from an earlier
    generation of the same session.

    Args:
        request: GenerateRequest with session_id and requirements

    Returns:
        GenerateResponse with session_id, status, and generated config

    Raises:
        HTTPException: 400 for empty requirements or a ``ValueError`` raised
            during generation, 500 for any other failure.
    """
    try:
        if not request.requirements:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail={
                    "code": "EMPTY_REQUIREMENTS",
                    "message": "requirements cannot be empty",
                    "session_id": request.session_id,
                },
            )

        logger.info(f"[{request.session_id}] Processing generate request")

        service = ConfigService()
        result = await service.generate(
            requirements=request.requirements,
            session_id=request.session_id,
        )

        # Extract the SQL list defensively: any missing level yields [].
        config_plan = result.get("配置方案") if result else None
        sql_list = (config_plan.get("sql_list") if config_plan else None) or []

        # Always overwrite, so stale SQL from a previous run is discarded.
        store_session_sql(request.session_id, sql_list)
        if sql_list:
            logger.info(f"[{request.session_id}] Stored {len(sql_list)} SQL statements for execution")
        else:
            logger.warning(f"[{request.session_id}] Generated config contains no SQL statements")

        logger.success(f"[{request.session_id}] Config generation completed")
        return GenerateResponse(
            session_id=request.session_id,
            status="success",
            data=result,
        )
    except HTTPException:
        # Already shaped for the client; pass it through untouched.
        raise
    except ValueError as e:
        logger.error(f"[{request.session_id}] Validation error: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={
                "code": "VALIDATION_ERROR",
                "message": str(e),
                "session_id": request.session_id,
            },
        ) from e
    except Exception as e:
        logger.error(f"[{request.session_id}] Config generation failed: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={
                "code": "GENERATION_FAILED",
                "message": f"Failed to generate config: {str(e)}",
                "session_id": request.session_id,
            },
        ) from e

67
backend/app/config.py Normal file
View File

@@ -0,0 +1,67 @@
from pydantic_settings import BaseSettings
from pydantic import ConfigDict, field_validator
from functools import lru_cache
from urllib.parse import quote_plus
class Settings(BaseSettings):
    """Application settings read from the environment and the ``.env`` file.

    Fields declared without a default (``SECRET_KEY``, the ``DB_*``
    connection fields and ``ANTHROPIC_API_KEY``) have no fallback value here
    and must be supplied by the environment.
    """

    # Application
    APP_NAME: str = "ERP AI Assistant"
    APP_ENV: str = "development"
    DEBUG: bool = True
    SECRET_KEY: str

    # Database
    DB_DRIVER: str
    DB_SERVER: str
    DB_PORT: int = 1433
    DB_NAME: str
    DB_USER: str
    DB_PASSWORD: str

    # Claude API
    ANTHROPIC_API_KEY: str
    ANTHROPIC_BASE_URL: str | None = None  # Optional custom base URL for proxy/self-hosted
    CLAUDE_MODEL: str = "claude-sonnet-4-6"
    CLAUDE_MAX_TOKENS: int = 8192
    CLAUDE_TEMPERATURE: float = 0.7

    # Knowledge Base
    KNOWLEDGE_BASE_PATH: str = "./knowledge_base"
    CHROMA_DB_PATH: str = "./knowledge_base/chroma_db"
    EMBEDDING_MODEL: str = "all-MiniLM-L6-v2"
    CHUNK_SIZE: int = 500
    CHUNK_OVERLAP: int = 50

    @property
    def DATABASE_URL(self) -> str:
        """Build the ``mssql+pyodbc`` connection URL.

        The user name, password and driver name are URL-encoded, so
        characters such as ``@``, ``:``, ``/`` or ``\\`` in the credentials
        cannot corrupt the URL structure.
        """
        user = quote_plus(self.DB_USER)
        password = quote_plus(self.DB_PASSWORD)
        return (
            f"mssql+pyodbc://{user}:{password}"
            f"@{self.DB_SERVER}:{self.DB_PORT}/{self.DB_NAME}"
            f"?driver={quote_plus(self.DB_DRIVER)}"
        )

    @field_validator('CLAUDE_TEMPERATURE')
    @classmethod
    def validate_temperature(cls, v: float) -> float:
        """Reject temperatures outside the closed range [0, 2]."""
        if not 0 <= v <= 2:
            raise ValueError('CLAUDE_TEMPERATURE must be between 0 and 2')
        return v

    @field_validator('CHUNK_OVERLAP')
    @classmethod
    def validate_chunk_overlap(cls, v: int, info) -> int:
        """Require the overlap to be strictly smaller than the chunk size.

        ``CHUNK_SIZE`` is read from the already-validated field data and
        falls back to 500 when it is not present there.
        """
        chunk_size = info.data.get('CHUNK_SIZE', 500)
        if v >= chunk_size:
            raise ValueError('CHUNK_OVERLAP must be less than CHUNK_SIZE')
        return v

    model_config = ConfigDict(env_file=".env", case_sensitive=True)
@lru_cache(maxsize=128)
def get_settings() -> Settings:
    """Return the cached ``Settings`` instance, building it on first call."""
    cached_settings = Settings()
    return cached_settings

View File

@@ -0,0 +1 @@
"""Core modules"""

View File

@@ -0,0 +1,120 @@
"""AI Engine for ERP AI Assistant.
This module provides the ClaudeEngine class that wraps Claude API calls
and provides JSON parsing utilities.
"""
import json
import re
from typing import Any
import anthropic
from loguru import logger
from app.config import get_settings
class ClaudeEngine:
"""Engine for interacting with Claude API.
This class wraps the Anthropic Claude API client and provides
utilities for parsing JSON responses from Claude.
"""
def __init__(self) -> None:
"""Initialize Claude engine with settings."""
settings = get_settings()
# Initialize Anthropic client with optional custom base_url
client_kwargs = {"api_key": settings.ANTHROPIC_API_KEY}
if settings.ANTHROPIC_BASE_URL:
client_kwargs["base_url"] = settings.ANTHROPIC_BASE_URL
logger.info(f"Using custom Anthropic base URL: {settings.ANTHROPIC_BASE_URL}")
self.client = anthropic.AsyncAnthropic(**client_kwargs)
self.model = settings.CLAUDE_MODEL
self.max_tokens = settings.CLAUDE_MAX_TOKENS
self.temperature = settings.CLAUDE_TEMPERATURE
def parse_json_response(self, content: str) -> dict[str, Any]:
"""Parse JSON from Claude responses.
Attempts multiple parsing strategies:
1. Direct JSON parse
2. Extract from markdown code blocks
3. Extract any {...} block
Args:
content: The response content from Claude
Returns:
Parsed JSON as a dictionary
Raises:
ValueError: If JSON cannot be parsed using any strategy
"""
if not content or not content.strip():
raise ValueError("Empty content provided")
# Strategy 1: Try direct JSON parse
try:
return json.loads(content)
except json.JSONDecodeError:
pass
# Strategy 2: Try extracting from markdown code blocks
json_code_block_pattern = r'```json\s*(\{.*?\})\s*```'
json_match = re.search(json_code_block_pattern, content, re.DOTALL)
if json_match:
try:
return json.loads(json_match.group(1))
except json.JSONDecodeError:
pass
# Also try any code block (not just json tagged)
code_block_pattern = r'```\s*(\{.*?\})\s*```'
code_block_match = re.search(code_block_pattern, content, re.DOTALL)
if code_block_match:
try:
return json.loads(code_block_match.group(1))
except json.JSONDecodeError:
pass
# Strategy 3: Try extracting any {...} block
# Find balanced braces
brace_pattern = r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}'
json_blocks = re.findall(brace_pattern, content, re.DOTALL)
for json_block in json_blocks:
try:
return json.loads(json_block)
except json.JSONDecodeError:
continue
# All strategies failed
logger.error(f"无法解析 Claude 返回的 JSON: {content[:200]}")
raise ValueError("无法解析 Claude 返回的 JSON请检查响应格式")
async def call_claude(
self,
messages: list[dict[str, str]],
temperature: float | None = None
) -> str:
"""Call Claude API.
Args:
messages: List of message dictionaries with 'role' and 'content'
temperature: Optional temperature override (0-2)
Returns:
The text content from Claude's response
Raises:
Exception: If the API call fails
"""
response = await self.client.messages.create(
model=self.model,
max_tokens=self.max_tokens,
temperature=temperature if temperature is not None else self.temperature,
messages=messages
)
return response.content[0].text

View File

@@ -0,0 +1,78 @@
from typing import Optional
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
from contextlib import contextmanager
from loguru import logger
from app.config import get_settings
class DatabaseEngine:
    """Database access helper built on a SQLAlchemy engine and session factory."""

    def __init__(self):
        """Create the pooled engine and the session factory from settings."""
        settings = get_settings()
        self.engine = create_engine(
            settings.DATABASE_URL,
            pool_size=20,
            max_overflow=10,
            pool_pre_ping=True,
            echo=settings.DEBUG
        )
        self.Session = sessionmaker(bind=self.engine)

    @contextmanager
    def get_session(self):
        """Yield a session that commits on success and rolls back on error.

        The session is always closed on exit. Any exception is logged and
        re-raised after the rollback.
        """
        session = self.Session()
        try:
            yield session
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"数据库操作失败：{e}")
            raise
        finally:
            session.close()

    def execute_sql(self, sql: str, params: Optional[dict] = None) -> list:
        """Execute a single SQL statement with optional bound parameters.

        Args:
            sql: SQL text; values should be passed through ``params``
            params: Bound parameter values, or None for no parameters

        Returns:
            All fetched rows, or an empty list for statements that return
            no rows (for example INSERT or DDL), where calling ``fetchall()``
            would raise.
        """
        with self.get_session() as session:
            result = session.execute(text(sql), params or {})
            if not result.returns_rows:
                return []
            return result.fetchall()

    def execute_transaction(self, sql_list: list, params_list: Optional[list] = None) -> bool:
        """Execute several SQL statements in one transaction.

        Args:
            sql_list: Statements to execute in order
            params_list: One parameter dict (or None) per statement; when
                omitted or empty, every statement runs without parameters

        Returns:
            True when all statements were executed and committed

        Raises:
            ValueError: If ``params_list`` is given with a different length
                than ``sql_list``. Pairing them anyway would silently skip
                the statements that have no counterpart.
        """
        if not params_list:
            params_list = [None] * len(sql_list)
        elif len(params_list) != len(sql_list):
            raise ValueError(
                f"params_list has {len(params_list)} entries but sql_list has {len(sql_list)}"
            )
        with self.get_session() as session:
            for sql, params in zip(sql_list, params_list):
                session.execute(text(sql), params or {})
        return True

    def get_table_structure(self, table_name: str):
        """Return column metadata for a table using a parameterized query."""
        sql = """
        SELECT
        COLUMN_NAME,
        DATA_TYPE,
        CHARACTER_MAXIMUM_LENGTH,
        IS_NULLABLE,
        COLUMN_DEFAULT
        FROM INFORMATION_SCHEMA.COLUMNS
        WHERE TABLE_NAME = :table_name
        ORDER BY ORDINAL_POSITION
        """
        return self.execute_sql(sql, {"table_name": table_name})

    def table_exists(self, table_name: str) -> bool:
        """Report whether a table with this name exists (parameterized query)."""
        sql = """
        SELECT COUNT(*)
        FROM INFORMATION_SCHEMA.TABLES
        WHERE TABLE_NAME = :table_name
        """
        result = self.execute_sql(sql, {"table_name": table_name})
        return result[0][0] > 0

    def dispose(self):
        """Dispose of the engine's connection pool and release its resources."""
        self.engine.dispose()

View File

@@ -0,0 +1,147 @@
"""Config Executor for ERP AI Assistant.
This module provides the ConfigExecutor class for validating and executing
SQL configuration statements with safety checks.
"""
import re
from typing import List, Tuple, Dict, Any
from loguru import logger
from app.core.db_engine import DatabaseEngine
class ConfigExecutor:
    """Executor for SQL configuration statements with safety validation.

    Statements are checked against a blocklist of destructive patterns before
    anything is executed. Execution is delegated to
    ``DatabaseEngine.execute_transaction``.
    """

    # Regex patterns (matched against upper-cased SQL) that must be blocked.
    DANGEROUS_KEYWORDS = [
        r"DROP\s+DATABASE",
        r"DROP\s+TABLE",
        r"TRUNCATE\s+TABLE",
        r"DELETE\s+FROM",
        r"UPDATE\s+.*\s+SET",
        r"ALTER\s+TABLE\s+.*\s+DROP"
    ]

    # SQL block comments and line comments; both can act as whitespace.
    _COMMENT_RE = re.compile(r"/\*.*?\*/|--[^\r\n]*", re.DOTALL)

    def __init__(self) -> None:
        """Initialize executor with database engine."""
        self.db_engine = DatabaseEngine()
        logger.info("ConfigExecutor initialized")

    @classmethod
    def _find_dangerous_operation(cls, sql: str) -> str:
        """Return the text of the first blocked pattern found, or ``""``.

        Two views of the statement are checked:
        - the raw upper-cased text, and
        - a normalized view with comments replaced by a space and all
          whitespace collapsed, so that ``DROP/**/TABLE`` or an
          ``UPDATE ... SET`` spread over several lines is still detected.

        The raw view is always checked as well, because removing ``--``
        comments naively could hide code that follows a ``'--'`` string
        literal on the same line.
        """
        raw_view = sql.upper().strip()
        normalized_view = " ".join(cls._COMMENT_RE.sub(" ", sql).split()).upper()
        for pattern in cls.DANGEROUS_KEYWORDS:
            for candidate in (raw_view, normalized_view):
                match = re.search(pattern, candidate)
                if match:
                    return match.group(0)
        return ""

    def validate_sql(self, sql: str) -> Tuple[bool, str]:
        """Validate SQL statement for safety.

        Args:
            sql: SQL statement to validate

        Returns:
            Tuple of (is_valid, message). ``is_valid`` is False for an empty
            statement or when a blocked pattern is detected.
        """
        if not sql or not sql.strip():
            return False, "SQL语句为空"

        matched_keyword = self._find_dangerous_operation(sql)
        if matched_keyword:
            logger.warning(f"SQL validation failed: dangerous operation '{matched_keyword}' detected")
            return False, f"危险操作被拦截: {matched_keyword}"

        logger.debug(f"SQL validation passed: {sql[:50]}...")
        return True, "SQL验证通过"

    def execute_config(
        self,
        sql_list: List[str],
        session_id: str
    ) -> Dict[str, Any]:
        """Execute a list of SQL statements in a transaction.

        Every statement is validated first. If any validation fails, no
        statement is executed. Errors are captured in the returned dictionary
        rather than raised.

        Args:
            sql_list: List of SQL statements to execute
            session_id: Session ID for logging and tracking

        Returns:
            Dictionary containing:
            - success: Boolean indicating overall success
            - executed: List of executed SQL statements
            - failed: Error message if execution failed, None otherwise
            - message: Human-readable result message
        """
        logger.info(f"[{session_id}] Starting config execution with {len(sql_list)} SQL statements")
        results: Dict[str, Any] = {
            "success": True,
            "executed": [],
            "failed": None,
            "message": ""
        }
        try:
            # Step 1: validate everything before touching the database.
            logger.debug(f"[{session_id}] Validating {len(sql_list)} SQL statements")
            for i, sql in enumerate(sql_list):
                is_valid, msg = self.validate_sql(sql)
                if not is_valid:
                    error_msg = f"SQL #{i+1} 验证失败: {msg}"
                    logger.error(f"[{session_id}] {error_msg}")
                    raise ValueError(error_msg)

            # Step 2: run all statements in a single transaction.
            logger.debug(f"[{session_id}] Executing transaction")
            self.db_engine.execute_transaction(sql_list)

            # Step 3: record success.
            results["executed"] = sql_list
            results["message"] = f"成功执行 {len(sql_list)} 条SQL"
            logger.success(f"[{session_id}] {results['message']}")
        except Exception as e:
            # Validation failures (ValueError) and execution failures are
            # reported to the caller in the same shape.
            results["success"] = False
            results["failed"] = str(e)
            results["message"] = f"执行失败: {e}"
            logger.error(f"[{session_id}] {results['message']}")
        return results

    def rollback(self, session_id: str) -> Dict[str, Any]:
        """Placeholder for rolling back a session's executed operations.

        Not implemented: it logs a warning and reports failure.

        Args:
            session_id: Session ID to rollback

        Returns:
            Dictionary with ``success`` set to False and a message
        """
        logger.warning(f"[{session_id}] Rollback requested but not yet implemented")
        return {
            "success": False,
            "message": "回滚功能待实现"
        }

144
backend/app/core/prompts.py Normal file
View File

@@ -0,0 +1,144 @@
"""
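Prompt template definitions.
Templates:
- SYSTEM_PROMPT: system prompt that defines Claude's role and area of expertise
- ANALYZE_PROMPT_TEMPLATE: requirement-analysis template; placeholders: user_input, knowledge_context, existing_tables
- GENERATE_PROMPT_TEMPLATE: configuration-generation template; placeholders: requirements, platform_rules, similar_cases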
"""
SYSTEM_PROMPT = """你是一个 ERP 平台配置专家助手,专门帮助开发人员配置一零软件结构化开发平台。
## 你的职责
你是 ERP 系统配置和开发的专业顾问,负责:
1. 分析用户需求,理解业务场景
2. 设计合理的数据库表结构
3. 生成符合平台规范的配置方案
4. 提供完整的 SQL 脚本和配置说明
## 平台知识
你熟悉以下平台概念:
- 窗体类型:基础资料、单据、报表、系统设置等
- 标准字段命名规范F 开头的主键、FPrefix 前缀的自定义字段
- 配置流程:需求分析 → 表结构设计 → 功能配置 → 页面配置 → 菜单配置
- 命名约定:表名以 T_开头功能号以功能类别前缀开头
## 输出要求
1. 提供完整的 SQL 脚本,包括建表语句、函数配置、页面配置等
2. 确保配置符合平台规范和最佳实践
3. 进行风险评估,提示潜在问题
4. 使用 JSON 格式输出结构化结果
5. 所有字段和表名使用英文,注释使用中文
请始终保持专业、严谨的工作态度,确保输出的配置方案可落地执行。"""
# Requirement-analysis prompt, designed for str.format() with the fields
# user_input, knowledge_context and existing_tables. Literal braces in the
# JSON example are doubled. Keep explanatory notes out of the string: a
# single-brace pair inside it is parsed by str.format() as a replacement
# field and raises KeyError.
ANALYZE_PROMPT_TEMPLATE = """请分析以下用户需求，生成结构化的需求分析文档。
## 用户输入
{user_input}
## 相关知识上下文
{knowledge_context}
## 现有表结构
{existing_tables}
## 分析要求
请输出结构化的需求分析文档，使用 JSON 格式，包含以下字段：
```json
{{
"功能名称": "功能的中文名称",
"功能号建议": "建议的功能编号，如 SAL001",
"窗体类型": "基础资料/单据/报表/系统设置",
"主表名建议": "建议的主表名，如 T_SAL_Order",
"从表名建议": "建议的从表名，如 T_SAL_OrderEntry",
"主表字段": [
{{"字段名": "FOrderId", "字段类型": "varchar(50)", "中文名称": "订单编号", "必填": true}},
...
],
"从表字段": [
{{"字段名": "FEntryId", "字段类型": "int", "中文名称": "分录 ID", "必填": true}},
...
],
"业务需求": "详细的业务需求描述",
"关联表": ["相关表 1", "相关表 2"],
"风险提示": ["潜在风险 1", "潜在风险 2"]
}}
```
## 注意事项
1. 字段命名遵循平台规范：主键以 F 开头，使用 PascalCase
2. 表名以 T_开头使用模块前缀
3. 考虑必填字段、默认值、数据长度等约束
4. 识别必要的业务关联关系
5. 评估潜在的数据一致性和性能风险"""
# Configuration-generation prompt, designed for str.format() with the fields
# requirements, platform_rules and similar_cases. Literal braces in the JSON
# example are doubled. Keep explanatory notes out of the string: a
# single-brace pair inside it is parsed by str.format() as a replacement
# field and raises KeyError.
GENERATE_PROMPT_TEMPLATE = """请根据需求分析结果，生成完整的平台配置方案。
## 需求分析结果
{requirements}
## 平台规则
{platform_rules}
## 类似案例参考
{similar_cases}
## 生成要求
请生成完整的配置方案，使用 JSON 格式，包含以下内容：
```json
{{
"table_sql": "建表 SQL 语句，包括主表和从表",
"function_config_sql": "功能配置 SQL 语句",
"page_config_sql": "页面配置 SQL 语句",
"menu_config_sql": "菜单配置 SQL 语句",
"ikey_config_sql": "IKEY 配置 SQL 语句",
"config_summary": {{
"created_tables": ["表 1", "表 2"],
"main_entities": ["主要实体 1", "主要实体 2"],
"relationships": "表间关系说明"
}},
"implementation_notes": "实施注意事项",
"validation_rules": ["验证规则 1", "验证规则 2"]
}}
```
## 配置规范
1. **建表 SQL**:
- 主键使用 FId 或 F+ 表名缩写 + Id
- 包含创建时间、创建人、更新时间、更新人等审计字段
- 使用合适的索引提高查询性能
2. **功能配置**:
- 定义功能号、功能名称、功能类型
- 配置数据权限和操作权限
3. **页面配置**:
- 配置表单布局、字段顺序
- 设置字段属性（必填、只读、可见性）
4. **菜单配置**:
- 配置菜单层级、图标、排序
5. **IKEY 配置**:
- 配置编码规则、生成策略
## 注意事项
- 所有 SQL 语句需要语法正确、可直接执行
- 配置需要符合平台规范
- 考虑扩展性和维护性
- 提供必要的注释说明"""

View File

@@ -0,0 +1,251 @@
"""RAG Engine for ERP AI Assistant.
This module provides the RAGEngine class that handles knowledge document
storage and retrieval using ChromaDB and sentence-transformers embeddings.
"""
from typing import Optional
import chromadb
from chromadb.config import Settings as ChromaSettings
from sentence_transformers import SentenceTransformer
from loguru import logger
from app.config import get_settings
class RAGEngine:
    """RAG engine for knowledge document retrieval.

    Stores document chunks with their embeddings in a persistent ChromaDB
    collection named ``documents`` and searches them by embedding the query
    with a sentence-transformers model.
    """

    # Embedding model shared by all instances; loaded on first use.
    _embedding_model: Optional["SentenceTransformer"] = None

    def __init__(self) -> None:
        """Initialize RAG engine with ChromaDB and embedding model."""
        settings = get_settings()

        logger.info(f"Initializing ChromaDB at: {settings.CHROMA_DB_PATH}")
        self.chroma_client = chromadb.PersistentClient(
            path=settings.CHROMA_DB_PATH,
            settings=ChromaSettings(anonymized_telemetry=False)
        )

        logger.info(f"Loading embedding model: {settings.EMBEDDING_MODEL}")
        self.embedding_model = self._get_embedding_model(settings.EMBEDDING_MODEL)

        self.documents_collection = self.chroma_client.get_or_create_collection(
            name="documents"
        )

        self.chunk_size = settings.CHUNK_SIZE
        self.chunk_overlap = settings.CHUNK_OVERLAP
        logger.info(
            f"RAG Engine initialized: chunk_size={self.chunk_size}, "
            f"chunk_overlap={self.chunk_overlap}"
        )

    @classmethod
    def _get_embedding_model(cls, model_name: str) -> "SentenceTransformer":
        """Return the shared embedding model, loading it on first call.

        The model is cached on the class. Once a model has been loaded,
        later calls return it regardless of ``model_name``.

        Args:
            model_name: Name of the embedding model to load

        Returns:
            SentenceTransformer embedding model instance
        """
        if cls._embedding_model is None:
            logger.info(f"Loading embedding model: {model_name}")
            cls._embedding_model = SentenceTransformer(model_name)
        return cls._embedding_model

    def _split_text(self, text: str) -> list[str]:
        """Split text into overlapping chunks.

        Consecutive chunks start ``chunk_size - chunk_overlap`` characters
        apart. The step is clamped to at least 1, so the loop always
        terminates, even when the overlap is equal to or larger than the
        chunk size.

        Args:
            text: The text to split

        Returns:
            List of chunk strings; whitespace-only chunks are skipped
        """
        if not text:
            return []

        step = max(1, self.chunk_size - self.chunk_overlap)
        chunks = []
        for start in range(0, len(text), step):
            chunk = text[start:start + self.chunk_size]
            if chunk.strip():
                chunks.append(chunk)
        return chunks

    def _delete_chunks_for_doc(self, doc_id: str) -> None:
        """Delete all chunks associated with a document (best effort).

        Failures are logged as warnings and not raised.

        Args:
            doc_id: The document ID to delete chunks for
        """
        try:
            results = self.documents_collection.get(
                where={"doc_id": doc_id},
                include=[]
            )
            if results and results.get("ids"):
                self.documents_collection.delete(ids=results["ids"])
                logger.debug(f"Deleted {len(results['ids'])} chunks for document '{doc_id}'")
        except Exception as e:
            logger.warning(f"Failed to delete chunks for document '{doc_id}': {e}")

    def add_document(
        self,
        doc_id: str,
        content: str,
        metadata: Optional[dict] = None
    ) -> int:
        """Add a document to the knowledge base.

        Existing chunks for ``doc_id`` are deleted first, so re-adding a
        document replaces it.

        Args:
            doc_id: Unique identifier for the document
            content: The document content to index
            metadata: Optional metadata dict copied onto every chunk

        Returns:
            Number of chunks added

        Raises:
            ValueError: If content is empty
        """
        if not content or not content.strip():
            raise ValueError("Cannot add empty document")

        try:
            self._delete_chunks_for_doc(doc_id)

            chunks = self._split_text(content)
            logger.info(f"Split document '{doc_id}' into {len(chunks)} chunks")
            if not chunks:
                return 0

            logger.debug(f"Generating embeddings for {len(chunks)} chunks")
            embeddings = self.embedding_model.encode(chunks)

            chunk_ids = [f"{doc_id}_chunk_{i}" for i in range(len(chunks))]
            base_metadata = metadata or {}
            # The per-chunk keys come last so they win over caller metadata.
            chunk_metadata = [
                {
                    **base_metadata,
                    "doc_id": doc_id,
                    "chunk_index": i,
                    "total_chunks": len(chunks)
                }
                for i in range(len(chunks))
            ]

            self.documents_collection.add(
                ids=chunk_ids,
                embeddings=embeddings.tolist(),
                documents=chunks,
                metadatas=chunk_metadata
            )
            logger.info(f"Added {len(chunks)} chunks for document '{doc_id}'")
            return len(chunks)
        except Exception as e:
            logger.error(f"Failed to add document '{doc_id}': {e}")
            raise

    def search(self, query: str, top_k: int = 3) -> list[dict]:
        """Search for relevant document chunks.

        Args:
            query: The search query
            top_k: Number of results to return (default: 3, min: 1, max: 100)

        Returns:
            List of dicts with 'content', 'metadata', and 'distance'. The
            list is empty for a blank query.

        Raises:
            ValueError: If top_k is below 1 or above 100
        """
        if top_k < 1:
            raise ValueError(f"top_k must be at least 1 (got: {top_k})")
        if top_k > 100:
            raise ValueError(f"top_k cannot exceed 100 (got: {top_k})")

        if not query or not query.strip():
            logger.warning("Empty search query provided")
            return []

        try:
            logger.debug(f"Generating embedding for query: {query[:50]}...")
            query_embedding = self.embedding_model.encode([query])

            results = self.documents_collection.query(
                query_embeddings=query_embedding.tolist(),
                n_results=top_k,
                include=["documents", "metadatas", "distances"]
            )

            formatted_results = []
            if results and results.get("documents"):
                # One query was sent, so only the first result list is read.
                documents = results["documents"][0]
                metadatas = results["metadatas"][0] if results.get("metadatas") else []
                distances = results["distances"][0] if results.get("distances") else []
                for i, content in enumerate(documents):
                    formatted_results.append({
                        "content": content,
                        "metadata": metadatas[i] if i < len(metadatas) else {},
                        "distance": distances[i] if i < len(distances) else None
                    })

            logger.info(f"Found {len(formatted_results)} results for query")
            return formatted_results
        except Exception as e:
            logger.error(f"Search failed: {e}")
            raise

    def close(self) -> None:
        """Drop this instance's references to the model, collection and client.

        The class-level cached embedding model is not cleared.
        """
        logger.info("Closing RAG engine and releasing resources")
        self.embedding_model = None
        self.documents_collection = None
        self.chroma_client = None

78
backend/app/main.py Normal file
View File

@@ -0,0 +1,78 @@
"""FastAPI application entry point for ERP AI Assistant.
This module creates and configures the main FastAPI application instance.
"""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.config import get_settings
from app.api import analyze, generate, execute
# Get application settings
settings = get_settings()

# Create FastAPI application instance
app = FastAPI(
    title=settings.APP_NAME,
    version="1.0.0",
    debug=settings.DEBUG,
    description="AI-powered assistant for ERP platform configuration"
)

# Configure CORS middleware for frontend communication.
# Credentials are enabled, and a wildcard origin must not be combined with
# credentialed requests, so origins are always matched explicitly:
# - development (DEBUG on): any host on the frontend dev port 5173
# - otherwise: the fixed localhost origins listed below
if settings.DEBUG:
    _cors_origin_options = {"allow_origin_regex": r"https?://[^/]+:5173"}
else:
    _cors_origin_options = {
        "allow_origins": [
            "http://localhost:5173",
            "http://127.0.0.1:5173",
        ]
    }

app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    **_cors_origin_options,
)

# Register API routers (all endpoints live under the /api/v1 prefix)
app.include_router(analyze.router, prefix="/api/v1", tags=["Analysis"])
app.include_router(generate.router, prefix="/api/v1", tags=["Generation"])
app.include_router(execute.router, prefix="/api/v1", tags=["Execution"])
@app.get("/", tags=["Root"])
async def root() -> dict:
    """Serve basic application info at the root path.

    Returns:
        Dictionary with the configured application name under 'message',
        plus the fixed 'version' and 'status' values.
    """
    return dict(
        message=settings.APP_NAME,
        version="1.0.0",
        status="running",
    )
@app.get("/health", tags=["Health"])
async def health_check() -> dict:
    """Report service health for monitoring.

    Returns:
        Dictionary with a constant 'healthy' status and the version string.
    """
    return dict(status="healthy", version="1.0.0")
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces,
    # with auto-reload enabled only when DEBUG is on.
    import uvicorn

    run_options = dict(host="0.0.0.0", port=8000, reload=settings.DEBUG)
    uvicorn.run("app.main:app", **run_options)

View File

@@ -0,0 +1 @@
"""Pydantic models for ERP AI Assistant."""

View File

@@ -0,0 +1,89 @@
"""Request models for ERP AI Assistant API.
This module defines Pydantic models for API request validation.
"""
from typing import Optional
from pydantic import BaseModel, Field
class AnalyzeRequest(BaseModel):
    """Request model for requirement analysis.

    Attributes:
        input_type: Type of input; must be exactly 'natural_language' or
            'structured' (any other value fails validation).
        content: Requirement content text.
        session_id: Optional session ID for context continuity.
    """

    # Enforce the two documented values instead of accepting any string
    input_type: str = Field(
        ...,
        pattern=r"^(natural_language|structured)$",
        description="输入类型: natural_language | structured"
    )
    content: str = Field(..., description="需求内容")
    session_id: Optional[str] = Field(None, description="会话ID")

    # Example payload shown in the generated JSON schema
    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "input_type": "natural_language",
                    "content": "创建一个销售订单管理页面",
                    "session_id": "session-123"
                }
            ]
        }
    }
class GenerateRequest(BaseModel):
    """Request body for configuration generation.

    Attributes:
        session_id: Session ID from the previous analysis step.
        requirements: Structured requirements produced by the analysis.
    """

    session_id: str = Field(description="会话ID")
    requirements: dict = Field(description="结构化需求")

    # Example payload shown in the generated JSON schema
    model_config = dict(
        json_schema_extra=dict(
            examples=[
                dict(
                    session_id="session-123",
                    requirements={
                        "功能名称": "销售订单管理",
                        "功能类型": "列表页面",
                    },
                )
            ]
        )
    )
class ExecuteRequest(BaseModel):
    """Request body for configuration execution.

    Attributes:
        session_id: Session ID used for tracking.
        confirmed: User confirmation flag; defaults to False.
        backup_enabled: Whether a backup is requested before execution;
            defaults to True.
    """

    session_id: str = Field(description="会话ID")
    confirmed: bool = Field(default=False, description="用户确认标识")
    backup_enabled: bool = Field(default=True, description="是否启用备份")

    # Example payload shown in the generated JSON schema
    model_config = dict(
        json_schema_extra=dict(
            examples=[
                dict(
                    session_id="session-123",
                    confirmed=True,
                    backup_enabled=True,
                )
            ]
        )
    )

View File

@@ -0,0 +1,133 @@
"""Response models for ERP AI Assistant API.
This module defines Pydantic models for API response formatting.
"""
from typing import Any, Optional
from pydantic import BaseModel, Field
class AnalyzeResponse(BaseModel):
    """Response body for requirement analysis.

    Attributes:
        session_id: Session ID for this analysis.
        status: Processing status string.
        data: Structured requirement analysis result.
    """

    session_id: str = Field(description="会话ID")
    status: str = Field(description="处理状态")
    data: dict = Field(description="结构化需求分析结果")

    # Example payload shown in the generated JSON schema
    model_config = dict(
        json_schema_extra=dict(
            examples=[
                dict(
                    session_id="session-123",
                    status="success",
                    data={
                        "功能名称": "销售订单管理",
                        "功能类型": "列表页面",
                    },
                )
            ]
        )
    )
class GenerateResponse(BaseModel):
    """Response body for configuration generation.

    Attributes:
        session_id: Session ID.
        status: Processing status string.
        data: Generated SQL configuration.
    """

    session_id: str = Field(description="会话ID")
    status: str = Field(description="处理状态")
    data: dict = Field(description="生成的SQL配置")

    # Example payload shown in the generated JSON schema
    model_config = dict(
        json_schema_extra=dict(
            examples=[
                dict(
                    session_id="session-123",
                    status="success",
                    data=dict(
                        sql_list=["INSERT INTO SYS_FORM ...", "INSERT INTO SYS_MENU ..."]
                    ),
                )
            ]
        )
    )
class ExecuteResponse(BaseModel):
    """Response body for configuration execution.

    Attributes:
        execution_id: Unique execution ID.
        status: Execution status string.
        message: Human-readable result message.
    """

    execution_id: str = Field(description="执行ID")
    status: str = Field(description="执行状态")
    message: str = Field(description="执行结果消息")

    # Example payload shown in the generated JSON schema
    model_config = dict(
        json_schema_extra=dict(
            examples=[
                dict(
                    execution_id="exec-456",
                    status="success",
                    message="成功执行 5 条SQL",
                )
            ]
        )
    )
class ErrorResponse(BaseModel):
    """Response body for errors.

    Attributes:
        error: Dictionary of error details; the schema example uses the
            keys 'code', 'message' and 'details'.
    """

    error: dict = Field(description="错误详情")

    # Example payload shown in the generated JSON schema
    model_config = dict(
        json_schema_extra=dict(
            examples=[
                dict(
                    error=dict(
                        code="VALIDATION_ERROR",
                        message="Invalid input",
                        details="Field 'input_type' must be 'natural_language' or 'structured'",
                    )
                )
            ]
        )
    )
class HealthResponse(BaseModel):
    """Response body for the health check.

    Attributes:
        status: Service health status.
        version: Application version; defaults to "1.0.0".
    """

    status: str = Field(description="服务状态")
    version: str = Field("1.0.0", description="版本号")

    # Example payload shown in the generated JSON schema
    model_config = dict(
        json_schema_extra=dict(
            examples=[dict(status="healthy", version="1.0.0")]
        )
    )

View File

@@ -0,0 +1 @@
"""Service modules for ERP AI Assistant."""

View File

@@ -0,0 +1,135 @@
"""Config Generation Service.
This module provides the ConfigService class for generating ERP platform
configuration SQL based on structured requirements.
"""
from typing import Dict, Any
from loguru import logger
from app.core.ai_engine import ClaudeEngine
from app.core.rag_engine import RAGEngine
from app.core.prompts import SYSTEM_PROMPT, GENERATE_PROMPT_TEMPLATE
from app.core.db_engine import DatabaseEngine
class ConfigService:
    """Generate ERP platform configuration from structured requirements.

    Retrieved platform rules and similar cases from the knowledge base are
    inserted into the generation prompt, the prompt is sent to Claude, and
    the parsed JSON reply is handed back to the caller.
    """

    def __init__(self) -> None:
        """Create the AI, RAG and database engines held by the service."""
        self.ai_engine = ClaudeEngine()
        self.rag_engine = RAGEngine()
        self.db_engine = DatabaseEngine()
        logger.info("ConfigService initialized")

    async def generate(
        self,
        requirements: Dict[str, Any],
        session_id: str
    ) -> Dict[str, Any]:
        """Produce a configuration plan for the given requirements.

        Args:
            requirements: Structured requirement specification.
            session_id: Session ID, used as a prefix in log messages.

        Returns:
            The parsed JSON reply from the AI engine.

        Raises:
            ValueError: If requirements is empty.
            Exception: Any failure during generation is logged and re-raised.
        """
        if not requirements:
            raise ValueError("Requirements cannot be empty")

        function_name = requirements.get("功能名称", "Unknown")
        logger.info(f"[{session_id}] Starting config generation for: {function_name}")

        try:
            # Knowledge lookup 1: rules for the requested form type
            form_type = requirements.get("窗体类型", "0")
            logger.debug(f"[{session_id}] Retrieving platform rules for form type: {form_type}")
            rules_text = self._get_platform_rules(form_type)
            logger.info(f"[{session_id}] Retrieved platform rules")

            # Knowledge lookup 2: cases similar to the requested function
            logger.debug(f"[{session_id}] Retrieving similar cases")
            cases_text = self._get_similar_cases(function_name)
            logger.info(f"[{session_id}] Retrieved similar cases")

            # The system prompt is sent as the first user turn, followed by
            # a canned assistant acknowledgement and the actual request.
            conversation = [
                {"role": "user", "content": SYSTEM_PROMPT},
                {"role": "assistant", "content": "我已了解,请提供需求信息。"},
                {
                    "role": "user",
                    "content": GENERATE_PROMPT_TEMPLATE.format(
                        requirements=str(requirements),
                        platform_rules=rules_text,
                        similar_cases=cases_text
                    ),
                },
            ]

            logger.debug(f"[{session_id}] Calling Claude API for config generation")
            raw_reply = await self.ai_engine.call_claude(conversation, temperature=0.5)

            plan = self.ai_engine.parse_json_response(raw_reply)
            logger.success(f"[{session_id}] Config generation completed")
            return plan
        except Exception as e:
            logger.error(f"[{session_id}] Config generation failed: {e}")
            raise

    @staticmethod
    def _join_contents(hits: list) -> str:
        """Join the 'content' of each search hit with blank lines."""
        return "\n\n".join([hit["content"] for hit in hits])

    def _get_platform_rules(self, form_type: str) -> str:
        """Look up configuration rules for a form type in the knowledge base.

        Args:
            form_type: Form type code.

        Returns:
            The joined rule texts, or a fallback message when nothing is
            found or the lookup fails.
        """
        try:
            hits = self.rag_engine.search(
                f"窗体类型{form_type}配置规则",
                top_k=2
            )
            return self._join_contents(hits) if hits else "未找到相关配置规则"
        except Exception as e:
            # Best effort: generation continues with a placeholder text
            logger.warning(f"Failed to retrieve platform rules: {e}")
            return "无法获取平台配置规则"

    def _get_similar_cases(self, keywords: str) -> str:
        """Look up similar configuration cases in the knowledge base.

        Args:
            keywords: Search keywords.

        Returns:
            The joined case texts, or a fallback message when nothing is
            found or the lookup fails.
        """
        try:
            hits = self.rag_engine.search(keywords, top_k=2)
            return self._join_contents(hits) if hits else "未找到相似案例"
        except Exception as e:
            # Best effort: generation continues with a placeholder text
            logger.warning(f"Failed to retrieve similar cases: {e}")
            return "无法获取相似案例"

View File

@@ -0,0 +1,147 @@
"""Requirement Analysis Service.
This module provides the RequirementService class for analyzing user requirements
using Claude AI with RAG knowledge retrieval.
"""
from typing import Optional, Dict, Any
import uuid
from loguru import logger
from app.core.ai_engine import ClaudeEngine
from app.core.rag_engine import RAGEngine
from app.core.prompts import SYSTEM_PROMPT, ANALYZE_PROMPT_TEMPLATE
from app.core.db_engine import DatabaseEngine
class RequirementService:
    """Service for analyzing user requirements with AI assistance.

    Knowledge-base search results and a listing of existing database tables
    are inserted into the analysis prompt, the prompt is sent to Claude, and
    the parsed JSON reply is returned as the structured requirement.
    """

    def __init__(self) -> None:
        """Initialize requirement service with required engines."""
        self.ai_engine = ClaudeEngine()
        self.rag_engine = RAGEngine()
        self.db_engine = DatabaseEngine()
        logger.info("RequirementService initialized")

    async def analyze(
        self,
        user_input: str,
        session_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Analyze user requirement and generate structured specification.

        Args:
            user_input: Natural language requirement from user.
            session_id: Session ID used as a log prefix (a UUID is generated
                if None).

        Returns:
            Structured requirement document as dictionary.

        Raises:
            ValueError: If user_input is empty or whitespace only.
            Exception: Any failure during analysis is logged and re-raised.
        """
        # Validate input
        if not user_input or not user_input.strip():
            raise ValueError("User input cannot be empty")

        # Generate session ID if not provided
        session_id = session_id or str(uuid.uuid4())
        logger.info(f"[{session_id}] Starting requirement analysis: {user_input[:50]}...")

        try:
            # Step 1: Retrieve relevant knowledge from RAG
            logger.debug(f"[{session_id}] Searching knowledge base")
            knowledge_results = self.rag_engine.search(user_input, top_k=3)
            knowledge_context = self._format_knowledge_context(knowledge_results)
            logger.info(f"[{session_id}] Retrieved {len(knowledge_results)} knowledge chunks")

            # Step 2: Query existing database tables
            logger.debug(f"[{session_id}] Querying existing tables")
            existing_tables = self._get_existing_tables(user_input)
            logger.info(f"[{session_id}] Retrieved existing table information")

            # Step 3: Build prompt
            prompt = ANALYZE_PROMPT_TEMPLATE.format(
                user_input=user_input,
                knowledge_context=knowledge_context,
                existing_tables=existing_tables
            )
            # The system prompt is sent as the first user turn, followed by
            # a canned assistant acknowledgement and the actual request.
            messages = [
                {"role": "user", "content": SYSTEM_PROMPT},
                {"role": "assistant", "content": "我已了解平台配置规范,请告诉我您的需求。"},
                {"role": "user", "content": prompt}
            ]

            # Step 4: Call Claude API
            logger.debug(f"[{session_id}] Calling Claude API")
            response = await self.ai_engine.call_claude(messages, temperature=0.7)

            # Step 5: Parse JSON response
            result = self.ai_engine.parse_json_response(response)
            function_name = result.get("功能名称", "Unknown")
            logger.success(f"[{session_id}] Requirement analysis completed: {function_name}")
            return result
        except Exception as e:
            logger.error(f"[{session_id}] Requirement analysis failed: {e}")
            raise

    def _format_knowledge_context(self, knowledge_results: list) -> str:
        """Format knowledge search results into context string.

        Each result with non-empty content becomes a "source\\ncontent"
        entry; entries are separated by blank lines.

        Args:
            knowledge_results: List of knowledge search results (dicts with
                'content' and optionally 'metadata').

        Returns:
            Formatted knowledge context string, or a fallback message when
            no result carries content.
        """
        if not knowledge_results:
            return "未找到相关知识库内容"

        context_parts = []
        for result in knowledge_results:
            # 'metadata' may be missing or explicitly None; both fall back
            # to the default source label instead of raising.
            metadata = result.get("metadata") or {}
            source = metadata.get("source", "文档")
            content = result.get("content", "")
            if content:
                context_parts.append(f"{source}\n{content}")

        return "\n\n".join(context_parts) if context_parts else "未找到相关知识库内容"

    def _get_existing_tables(self, user_input: str) -> str:
        """Query existing database tables for the analysis prompt.

        Args:
            user_input: User requirement text (currently not used to filter
                the table list).

        Returns:
            One "- TABLE_NAME" line per table, or a fallback message when no
            table is found or the query fails.
        """
        try:
            # Query top 10 tables (simplified version - could be enhanced with relevance matching)
            sql = """
                SELECT TOP 10 TABLE_NAME
                FROM INFORMATION_SCHEMA.TABLES
                WHERE TABLE_TYPE = 'BASE TABLE'
                ORDER BY TABLE_NAME
            """
            tables = self.db_engine.execute_sql(sql)
            if not tables:
                return "未找到现有数据表"

            # The table name is the first column of each returned row
            return "\n".join(f"- {row[0]}" for row in tables)
        except Exception as e:
            # Best effort: analysis continues with a placeholder text
            logger.warning(f"Failed to query existing tables: {e}")
            return "无法获取现有表信息"

View File

@@ -0,0 +1,31 @@
#!/bin/bash
# One-shot script that repairs the backend's Python dependency conflicts.
# Run it from the directory that contains the backend/ folder.

# Abort on the first failing command so a broken step never cascades
set -e

echo "🔧 开始修复依赖问题..."

# Without this guard a failed cd would let the venv/pip steps run in
# whatever directory the caller happened to be in.
if ! cd backend; then
    echo "✗ 未找到 backend 目录,请在项目根目录下运行此脚本" >&2
    exit 1
fi

# Activate the virtual environment, creating it first when it is missing
if [ -d "venv" ]; then
    echo "✓ 找到虚拟环境"
    source venv/bin/activate
else
    echo "✗ 未找到虚拟环境,正在创建..."
    python3 -m venv venv
    source venv/bin/activate
fi

echo "📦 卸载冲突的包..."
pip uninstall -y numpy sentence-transformers huggingface-hub

echo "📥 重新安装所有依赖..."
pip install -r requirements.txt

# Import each package once to prove the reinstall worked
echo "✅ 验证安装..."
python -c "import numpy; print(f'NumPy version: {numpy.__version__}')"
python -c "import sentence_transformers; print('SentenceTransformers installed successfully')"
python -c "import chromadb; print('ChromaDB installed successfully')"

echo ""
echo "✨ 修复完成!现在可以运行后端服务:"
echo " python -m app.main"

8
backend/pytest.ini Normal file
View File

@@ -0,0 +1,8 @@
[pytest]
asyncio_mode = auto
asyncio_default_fixture_loop_scope = function
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts = -v --cov=app --cov-report=term-missing

21
backend/requirements.txt Normal file
View File

@@ -0,0 +1,21 @@
fastapi==0.104.1
uvicorn[standard]==0.24.0
sqlalchemy==2.0.23
pyodbc==5.0.1
anthropic==0.18.1
chromadb==0.4.18
sentence-transformers==2.7.0
pydantic==2.5.0
pydantic-settings==2.1.0
python-dotenv==1.0.0
loguru==0.7.2
tenacity==8.2.3
python-jose[cryptography]==3.3.0
pytest==7.4.3
pytest-asyncio==0.21.1
httpx==0.25.2
pytest-cov==4.1.0
pytest-mock==3.12.0
# Fix NumPy compatibility issue with chromadb
numpy<2.0.0

View File

@@ -0,0 +1,178 @@
#!/usr/bin/env python3
"""Initialize knowledge base with sample documents.
This script adds sample knowledge documents to the RAG engine
for the ERP AI Assistant.
"""
import sys
from pathlib import Path
# NOTE: Development workaround to enable direct script execution.
# For production, use: python -m backend.scripts.init_knowledge
sys.path.insert(0, str(Path(__file__).parent.parent))
from loguru import logger
from app.core.rag_engine import RAGEngine
# Sample document: Platform basics
PLATFORM_BASICS_CONTENT = """
# ERP 平台基础知识
## 窗体类型 (Form Types)
ERP 系统支持以下几种窗体类型:
1. **标准窗体 (Standard Form)**
- 用于单一数据实体的 CRUD 操作
- 包含字段、按钮、表格等基本控件
- 适用于简单的数据录入和查询场景
2. **列表窗体 (List Form)**
- 用于展示多条记录的列表
- 支持排序、筛选、分页功能
- 可配置列显示和隐藏
3. **报表窗体 (Report Form)**
- 用于生成统计报表
- 支持图表展示
- 可导出 Excel、PDF 格式
4. **流程窗体 (Workflow Form)**
- 用于业务流程处理
- 包含审批、流转、会签等功能
- 支持流程状态跟踪
## 标准字段 (Standard Fields)
系统预定义以下标准字段类型:
1. **文本字段**
- ShortText: 短文本 (最多 255 字符)
- LongText: 长文本 (最多 4000 字符)
- Memo: 备注文本 (不限长度)
2. **数值字段**
- Integer: 整数
- Decimal: 小数 (可配置精度)
- Currency: 货币 (带币种符号)
3. **日期字段**
- Date: 日期
- DateTime: 日期时间
- Time: 时间
4. **选择字段**
- Dropdown: 下拉选择
- Radio: 单选
- Checkbox: 复选框
- MultiSelect: 多选
5. **关联字段**
- Lookup: 查找关联
- Reference: 引用关联
- Master-Detail: 主从关联
## 配置流程 (Configuration Process)
### 1. 需求分析
- 明确业务场景
- 确定窗体类型
- 梳理字段清单
### 2. 窗体设计
- 创建新窗体
- 配置窗体属性
- 添加字段控件
### 3. 字段配置
- 选择字段类型
- 设置字段属性 (必填、只读、默认值等)
- 配置验证规则
### 4. 权限设置
- 配置角色权限
- 设置数据访问范围
- 配置操作权限
### 5. 测试验证
- 功能测试
- 权限测试
- 性能测试
### 6. 发布上线
- 提交发布申请
- 通过审批流程
- 正式发布
## 常用术语 (Common Terms)
- **窗体 (Form)**: 用户界面的基本单元,用于数据展示和操作
- **字段 (Field)**: 窗体中的数据项,对应数据库列
- **控件 (Control)**: 窗体上的可视化元素
- **数据源 (Data Source)**: 窗体绑定的数据表或查询
- **动作 (Action)**: 窗体上的操作按钮
- **验证 (Validation)**: 数据输入的合法性检查
- **权限 (Permission)**: 用户对资源的访问控制
- **工作流 (Workflow)**: 业务流程的自动化流转
## 最佳实践 (Best Practices)
1. **字段命名规范**
- 使用英文命名,遵循下划线分隔
- 字段名应清晰表达业务含义
- 避免使用系统保留字
2. **性能优化**
- 列表窗体配置合理的分页大小
- 为常用查询字段建立索引
- 避免在窗体中加载过多数据
3. **用户体验**
- 必填字段应明确标识
- 提供清晰的错误提示
- 常用操作应放在明显位置
4. **安全性**
- 敏感数据应设置访问权限
- 用户输入应进行验证
- 定期审计权限配置
"""
def main() -> None:
    """Load the sample platform-basics document into the knowledge base.

    Creates a RAGEngine, adds PLATFORM_BASICS_CONTENT under the document ID
    'platform_basics', and logs how many chunks were stored. Any failure is
    logged and re-raised.
    """
    logger.info("Starting knowledge base initialization...")

    try:
        logger.info("Initializing RAG engine...")
        engine = RAGEngine()

        logger.info("Adding platform basics document...")
        chunks_added = engine.add_document(
            doc_id="platform_basics",
            content=PLATFORM_BASICS_CONTENT,
            metadata=dict(
                title="平台基础知识",
                category="platform",
                language="zh-CN",
                version="1.0",
            ),
        )

        logger.success(
            f"Knowledge base initialized successfully! "
            f"Added {chunks_added} chunks from 'platform_basics' document."
        )
    except Exception as e:
        logger.error(f"Failed to initialize knowledge base: {e}")
        raise


# Run the initialization only when executed as a script
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1 @@
"""Tests"""

26
backend/tests/conftest.py Normal file
View File

@@ -0,0 +1,26 @@
import pytest
from app.config import get_settings
@pytest.fixture
def test_settings():
    """Provide the settings object returned by get_settings()."""
    settings = get_settings()
    return settings
@pytest.fixture
def mock_db_engine(mocker):
    """Provide a MagicMock constrained to the DatabaseEngine interface."""
    from app.core.db_engine import DatabaseEngine

    db_double = mocker.MagicMock(spec=DatabaseEngine)
    return db_double
@pytest.fixture
def mock_ai_engine(mocker):
    """Provide a ClaudeEngine-shaped mock with a canned JSON parse result.

    parse_json_response returns {"function_name": "test_function"} by
    default; a test can replace it through
    mock_ai_engine.parse_json_response.return_value.
    """
    from app.core.ai_engine import ClaudeEngine

    engine_double = mocker.MagicMock(spec=ClaudeEngine)
    engine_double.configure_mock(**{
        "parse_json_response.return_value": {"function_name": "test_function"},
    })
    return engine_double

View File

@@ -0,0 +1,148 @@
import pytest
from app.core.ai_engine import ClaudeEngine
@pytest.fixture
def mock_settings(mocker):
    """Patch get_settings in app.core.ai_engine with fixed test values."""
    fake_settings = mocker.MagicMock(
        ANTHROPIC_API_KEY="test-key",
        CLAUDE_MODEL="claude-sonnet-4-6",
        CLAUDE_MAX_TOKENS=1024,
        CLAUDE_TEMPERATURE=0.7,
    )
    mocker.patch('app.core.ai_engine.get_settings', return_value=fake_settings)
    return fake_settings
@pytest.fixture
def mock_anthropic_client(mocker):
    """Replace anthropic.AsyncAnthropic in app.core.ai_engine with an AsyncMock."""
    client_double = mocker.AsyncMock()
    mocker.patch(
        'app.core.ai_engine.anthropic.AsyncAnthropic',
        return_value=client_double,
    )
    return client_double
def test_claude_engine_init(mocker, mock_settings):
    """Engine construction creates a client and copies the model settings."""
    engine = ClaudeEngine()

    assert engine.client is not None
    expected_attributes = {
        "model": "claude-sonnet-4-6",
        "max_tokens": 1024,
        "temperature": 0.7,
    }
    for attribute, value in expected_attributes.items():
        assert getattr(engine, attribute) == value
def test_parse_json_response(mocker, mock_settings):
    """JSON is parsed both as a bare string and inside a ```json fence."""
    engine = ClaudeEngine()

    # Plain JSON
    plain = engine.parse_json_response('{"name": "test", "value": 123}')
    assert plain["name"] == "test"
    assert plain["value"] == 123

    # Markdown code block
    fenced = engine.parse_json_response('```json\n{"name": "test"}\n```')
    assert fenced["name"] == "test"
def test_parse_json_response_empty_content(mocker, mock_settings):
    """Empty and whitespace-only content is rejected with ValueError."""
    engine = ClaudeEngine()

    for blank in ("", " "):
        with pytest.raises(ValueError, match="Empty content provided"):
            engine.parse_json_response(blank)
def test_parse_json_response_invalid_json(mocker, mock_settings):
    """Unparseable content is rejected with ValueError."""
    engine = ClaudeEngine()

    unparseable_inputs = (
        # Not JSON, and no code block to extract from
        "This is not JSON at all",
        # A JSON code block whose content is invalid
        '```json\n{invalid json}\n```',
    )
    for raw in unparseable_inputs:
        with pytest.raises(ValueError, match="无法解析 Claude 返回的 JSON"):
            engine.parse_json_response(raw)
def test_parse_json_response_code_block(mocker, mock_settings):
    """JSON inside a plain code fence (no json tag) is parsed."""
    engine = ClaudeEngine()

    parsed = engine.parse_json_response('```\n{"status": "ok"}\n```')

    assert parsed["status"] == "ok"
def test_parse_json_response_nested_json(mocker, mock_settings):
    """JSON embedded in text and nested JSON objects are both parsed."""
    engine = ClaudeEngine()

    # JSON surrounded by extra text
    embedded = engine.parse_json_response('Some text before {"key": "value"} and after')
    assert embedded["key"] == "value"

    # Nested JSON object
    nested = engine.parse_json_response('{"outer": {"inner": "value"}}')
    assert nested["outer"]["inner"] == "value"
@pytest.mark.asyncio
async def test_call_claude(mocker, mock_settings, mock_anthropic_client):
    """call_claude returns the reply text and forwards the configured settings."""
    # Canned API response carrying a single text item
    reply = mocker.MagicMock(content=[mocker.MagicMock(text="Hello, I am Claude")])
    mock_anthropic_client.messages.create.return_value = reply

    messages = [{"role": "user", "content": "Hello"}]
    answer = await ClaudeEngine().call_claude(messages)

    assert answer == "Hello, I am Claude"
    expected_call = {
        "model": "claude-sonnet-4-6",
        "max_tokens": 1024,
        "temperature": 0.7,
        "messages": messages,
    }
    mock_anthropic_client.messages.create.assert_called_once_with(**expected_call)
@pytest.mark.asyncio
async def test_call_claude_with_temperature(mocker, mock_settings, mock_anthropic_client):
    """An explicit temperature argument is passed through to the API call."""
    reply = mocker.MagicMock(content=[mocker.MagicMock(text="Response")])
    mock_anthropic_client.messages.create.return_value = reply

    engine = ClaudeEngine()
    answer = await engine.call_claude(
        [{"role": "user", "content": "Hello"}],
        temperature=1.5,
    )

    assert answer == "Response"
    sent_kwargs = mock_anthropic_client.messages.create.call_args.kwargs
    assert sent_kwargs["temperature"] == 1.5
@pytest.mark.asyncio
async def test_call_claude_error(mocker, mock_settings, mock_anthropic_client):
    """An exception raised by the API client propagates out of call_claude."""
    # Make the mocked client raise
    mock_anthropic_client.messages.create.side_effect = Exception("API Error")

    engine = ClaudeEngine()
    request = [{"role": "user", "content": "Hello"}]

    with pytest.raises(Exception, match="API Error"):
        await engine.call_claude(request)

View File

@@ -0,0 +1,94 @@
"""Tests for Config Service.
This module tests the ConfigService class for configuration generation.
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from app.services.config_service import ConfigService
@pytest.mark.asyncio
async def test_generate_config():
    """generate() returns the parsed plan when all three engines are mocked."""
    # AI engine double: async Claude call plus a canned parse result
    ai_double = MagicMock()
    ai_double.call_claude = AsyncMock(return_value='{"配置方案": {"sql_list": ["INSERT INTO SYS_FORM..."]}}')
    ai_double.parse_json_response = MagicMock(return_value={
        "配置方案": {
            "sql_list": ["INSERT INTO SYS_FORM VALUES (...)"]
        }
    })

    # RAG engine double returning one sample hit for every search
    rag_double = MagicMock()
    rag_double.search.return_value = [
        {"content": "Sample rule", "metadata": {}}
    ]

    requirements = {
        "功能名称": "销售订单",
        "功能号建议": "11-001",
        "窗体类型": "5",
        "主表名建议": "SA_ORDER",
        "主表字段": [
            {"字段名": "订单号", "字段类型": "varchar(50)", "必填": True}
        ]
    }

    with patch('app.services.config_service.ClaudeEngine', return_value=ai_double), \
            patch('app.services.config_service.RAGEngine', return_value=rag_double), \
            patch('app.services.config_service.DatabaseEngine', return_value=MagicMock()):
        service = ConfigService()
        result = await service.generate(requirements, "test-session")

    assert result is not None
    assert "配置方案" in result
@pytest.mark.asyncio
async def test_get_platform_rules():
    """_get_platform_rules() includes the content of every search hit."""
    rag_double = MagicMock()
    rag_double.search.return_value = [
        {"content": "Rule 1"},
        {"content": "Rule 2"}
    ]

    with patch('app.services.config_service.ClaudeEngine'), \
            patch('app.services.config_service.RAGEngine', return_value=rag_double), \
            patch('app.services.config_service.DatabaseEngine'):
        service = ConfigService()
        rules = service._get_platform_rules("5")

    for expected in ("Rule 1", "Rule 2"):
        assert expected in rules
@pytest.mark.asyncio
async def test_get_similar_cases():
    """_get_similar_cases() includes the content of every search hit."""
    rag_double = MagicMock()
    rag_double.search.return_value = [
        {"content": "Case 1"},
        {"content": "Case 2"}
    ]

    with patch('app.services.config_service.ClaudeEngine'), \
            patch('app.services.config_service.RAGEngine', return_value=rag_double), \
            patch('app.services.config_service.DatabaseEngine'):
        service = ConfigService()
        cases = service._get_similar_cases("销售订单")

    for expected in ("Case 1", "Case 2"):
        assert expected in cases

View File

@@ -0,0 +1,25 @@
import pytest
from app.core.db_engine import DatabaseEngine
def test_database_engine_init():
    """A new DatabaseEngine exposes both an engine and a Session attribute."""
    db = DatabaseEngine()

    for attribute in ("engine", "Session"):
        assert getattr(db, attribute) is not None
def test_execute_sql_select():
    """A trivial SELECT returns at least one row.

    Runs against whatever database an unpatched DatabaseEngine() uses.
    """
    rows = DatabaseEngine().execute_sql("SELECT 1 AS test")

    assert rows is not None
    assert len(rows) > 0
def test_table_exists():
    """table_exists() reports True for the SYS_FORM table.

    Runs against whatever database an unpatched DatabaseEngine() uses and
    assumes the SYS_FORM table exists there.
    """
    db = DatabaseEngine()

    assert db.table_exists("SYS_FORM") is True

View File

@@ -0,0 +1,141 @@
"""Tests for Config Executor.
This module tests the ConfigExecutor class for SQL validation and execution.
"""
import pytest
from unittest.mock import MagicMock, patch
from app.core.executor import ConfigExecutor
def test_executor_init():
    """A new ConfigExecutor holds a database engine."""
    assert ConfigExecutor().db_engine is not None
def test_validate_sql_safe():
    """SELECT, INSERT and UPDATE-with-WHERE statements pass validation."""
    executor = ConfigExecutor()

    # SELECT: also check the success message
    select_ok, select_msg = executor.validate_sql("SELECT * FROM SYS_FORM")
    assert select_ok is True
    assert "验证通过" in select_msg

    # INSERT and UPDATE with a WHERE clause: only the verdict is checked
    other_safe_statements = (
        "INSERT INTO SYS_FORM (IKEY, FORM_NAME) VALUES (1, 'Test')",
        "UPDATE SYS_FORM SET FORM_NAME = 'Test' WHERE IKEY = 1",
    )
    for statement in other_safe_statements:
        is_valid, _ = executor.validate_sql(statement)
        assert is_valid is True
def test_validate_sql_dangerous():
    """Destructive statements are rejected with a dangerous-operation message."""
    executor = ConfigExecutor()

    dangerous_statements = (
        "DROP DATABASE test_db",
        "DROP TABLE users",
        "TRUNCATE TABLE important_data",
        # DELETE without a WHERE clause
        "DELETE FROM users",
    )
    for statement in dangerous_statements:
        is_valid, msg = executor.validate_sql(statement)
        assert is_valid is False
        assert "危险操作" in msg
def test_execute_config_success():
    """A list of valid statements executes and is reported as successful."""
    db_double = MagicMock()
    db_double.execute_transaction.return_value = True
    statements = [
        "INSERT INTO SYS_FORM (IKEY, FORM_NAME) VALUES (1, 'Test1')",
        "INSERT INTO SYS_FORM (IKEY, FORM_NAME) VALUES (2, 'Test2')"
    ]

    with patch('app.core.executor.DatabaseEngine', return_value=db_double):
        executor = ConfigExecutor()
        outcome = executor.execute_config(statements, session_id="test-session")

    assert outcome["success"] is True
    assert len(outcome["executed"]) == 2
    assert outcome["failed"] is None
    assert "成功执行" in outcome["message"]
def test_execute_config_validation_failure():
    """Nothing is executed when one statement in the list fails validation."""
    statements = [
        "SELECT * FROM SYS_FORM",
        "DROP DATABASE test"  # Dangerous SQL
    ]

    outcome = ConfigExecutor().execute_config(statements, session_id="test-session")

    assert outcome["success"] is False
    assert outcome["failed"] is not None
    assert "验证失败" in outcome["message"]
    assert len(outcome["executed"]) == 0
def test_execute_config_execution_failure():
    """A database error during execution is reported as a failure result."""
    db_double = MagicMock()
    db_double.execute_transaction.side_effect = Exception("Database connection error")

    with patch('app.core.executor.DatabaseEngine', return_value=db_double):
        executor = ConfigExecutor()
        outcome = executor.execute_config(
            ["SELECT * FROM SYS_FORM"],
            session_id="test-session",
        )

    assert outcome["success"] is False
    assert outcome["failed"] is not None
    assert "执行失败" in outcome["message"]
def test_rollback_placeholder():
    """rollback() currently reports failure with a not-yet-implemented message."""
    outcome = ConfigExecutor().rollback(session_id="test-session")

    assert outcome["success"] is False
    assert "待实现" in outcome["message"]
def test_dangerous_keywords_exist():
    """DANGEROUS_KEYWORDS is defined, non-empty and covers the core operations."""
    executor = ConfigExecutor()

    assert hasattr(executor, 'DANGEROUS_KEYWORDS')
    assert len(executor.DANGEROUS_KEYWORDS) > 0
    for keyword in ("DROP DATABASE", "DROP TABLE", "TRUNCATE TABLE"):
        assert keyword in executor.DANGEROUS_KEYWORDS

View File

@@ -0,0 +1,36 @@
"""Tests for prompt templates."""
from app.core.prompts import SYSTEM_PROMPT, ANALYZE_PROMPT_TEMPLATE, GENERATE_PROMPT_TEMPLATE
def test_system_prompt_exists():
    """The system prompt is defined, substantial, and mentions its key topics."""
    assert SYSTEM_PROMPT is not None
    assert len(SYSTEM_PROMPT) > 100
    # Check stable characteristics rather than exact wording
    for marker in ("ERP", "配置"):
        assert marker in SYSTEM_PROMPT
def test_analyze_prompt_template():
    """Every placeholder value appears in the rendered analysis prompt."""
    values = {
        "user_input": "创建销售订单",
        "knowledge_context": "测试知识",
        "existing_tables": "测试表",
    }

    rendered = ANALYZE_PROMPT_TEMPLATE.format(**values)

    for text in values.values():
        assert text in rendered
def test_generate_prompt_template():
    """Every placeholder value appears in the rendered generation prompt."""
    values = {
        "requirements": "需求分析结果",
        "platform_rules": "平台规则",
        "similar_cases": "类似案例",
    }

    rendered = GENERATE_PROMPT_TEMPLATE.format(**values)

    for text in values.values():
        assert text in rendered

View File

@@ -0,0 +1,156 @@
"""Tests for RAG Engine.
This module tests the RAGEngine class for document indexing and retrieval.
"""
import pytest
from unittest.mock import MagicMock, patch
from app.core.rag_engine import RAGEngine
def test_rag_engine_init():
    """A new RAGEngine has a client, a collection and sane chunk settings."""
    rag = RAGEngine()

    for attribute in ("chroma_client", "documents_collection"):
        assert getattr(rag, attribute) is not None
    assert rag.chunk_size > 0
    # Overlap must be non-negative and strictly smaller than the chunk size
    assert 0 <= rag.chunk_overlap < rag.chunk_size
def test_split_text_basic():
    """A 1000-character text splits into non-empty chunks within chunk_size."""
    rag = RAGEngine()

    pieces = rag._split_text("A" * 1000)

    assert len(pieces) > 0
    for piece in pieces:
        assert len(piece) <= rag.chunk_size
    for piece in pieces:
        assert piece.strip()  # No empty chunks
def test_split_text_empty():
    """Empty and whitespace-only text both split into an empty list."""
    rag = RAGEngine()

    for blank in ("", " "):
        assert rag._split_text(blank) == []
def test_split_text_overlap():
    """Test that a multi-chunk split actually carries overlap between chunks.

    The input length is derived from the engine's own ``chunk_size`` so the
    text spans several chunks whatever size is configured. A previous
    version only asserted ``len(chunks) > 1`` inside an
    ``if len(chunks) > 1`` guard, which could never fail.
    """
    engine = RAGEngine()
    text = "A" * (engine.chunk_size * 2)
    chunks = engine._split_text(text)

    # A text twice the chunk size must be split (assuming each chunk is
    # capped at chunk_size characters).
    assert len(chunks) > 1

    total_length = sum(len(chunk) for chunk in chunks)
    # No characters may be dropped; the input contains no whitespace to trim.
    assert total_length >= len(text)
    if engine.chunk_overlap > 0:
        # Overlapping regions are emitted in two consecutive chunks, so the
        # combined chunk length has to exceed the input length.
        assert total_length > len(text)
def test_add_document_success():
    """Verify add_document reports chunks and forwards data to the collection."""
    rag = RAGEngine()
    # Replace the collection's add method with a mock that records the call.
    rag.documents_collection.add = MagicMock()

    chunk_count = rag.add_document(
        "test_doc_1",
        "This is a test document for the knowledge base.",
        {"source": "test", "type": "sample"},
    )

    assert chunk_count > 0
    recorded_add = rag.documents_collection.add
    assert recorded_add.called

    # The collection must have received all four keyword arguments.
    passed_kwargs = recorded_add.call_args.kwargs
    for key in ("ids", "embeddings", "documents", "metadatas"):
        assert key in passed_kwargs
def test_add_document_empty_content():
    """Verify empty or whitespace-only content is rejected with ValueError."""
    rag = RAGEngine()
    for blank in ("", " "):
        with pytest.raises(ValueError, match="Cannot add empty document"):
            rag.add_document("test_doc", blank)
def test_search_basic():
    """Verify search turns the collection's raw query output into hit dicts."""
    rag = RAGEngine()
    # Canned response standing in for the collection's query method.
    canned_response = {
        "documents": [["Result 1", "Result 2"]],
        "metadatas": [[{"doc_id": "doc1"}, {"doc_id": "doc2"}]],
        "distances": [[0.1, 0.2]],
    }
    rag.documents_collection.query = MagicMock(return_value=canned_response)

    hits = rag.search("test query", top_k=2)

    assert len(hits) == 2
    first_hit = hits[0]
    assert first_hit["content"] == "Result 1"
    assert first_hit["metadata"]["doc_id"] == "doc1"
    assert first_hit["distance"] == 0.1
    assert rag.documents_collection.query.called
def test_search_empty_query():
    """Verify empty and whitespace-only queries return an empty result list."""
    rag = RAGEngine()
    for blank_query in ("", " "):
        hits = rag.search(blank_query, top_k=3)
        assert hits == []
def test_search_invalid_top_k():
    """Verify a top_k of 101 is rejected with ValueError."""
    rag = RAGEngine()
    oversized_top_k = 101
    with pytest.raises(ValueError, match="top_k cannot exceed 100"):
        rag.search("test", top_k=oversized_top_k)
def test_delete_chunks_for_doc():
    """Verify deleting a document's chunks calls the collection's get and delete."""
    rag = RAGEngine()
    # Mock the lookup (returning two chunk ids) and the deletion call.
    stored_ids = {"ids": ["doc1_chunk_0", "doc1_chunk_1"]}
    rag.documents_collection.get = MagicMock(return_value=stored_ids)
    rag.documents_collection.delete = MagicMock()

    rag._delete_chunks_for_doc("doc1")

    assert rag.documents_collection.get.called
    assert rag.documents_collection.delete.called
def test_close():
    """Verify close() resets the model, collection and client attributes to None."""
    rag = RAGEngine()
    rag.close()
    for attribute in ("embedding_model", "documents_collection", "chroma_client"):
        assert getattr(rag, attribute) is None

View File

@@ -0,0 +1,116 @@
"""Tests for Requirement Service.
This module tests the RequirementService class for requirement analysis.
"""
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
from app.services.requirement_service import RequirementService
@pytest.mark.asyncio
async def test_analyze_requirement():
    """Run analyze() with the AI, RAG and database engines replaced by mocks."""
    with patch('app.services.requirement_service.ClaudeEngine') as claude_cls, \
            patch('app.services.requirement_service.RAGEngine') as rag_cls, \
            patch('app.services.requirement_service.DatabaseEngine') as db_cls:
        # Database stub: returns two one-column rows.
        db_stub = MagicMock()
        db_stub.execute_sql.return_value = [("SYS_FORM",), ("SYS_MENU",)]
        db_cls.return_value = db_stub

        # RAG stub: returns a single knowledge hit.
        rag_stub = MagicMock()
        rag_stub.search.return_value = [
            {"content": "Sample knowledge", "metadata": {"source": "docs"}}
        ]
        rag_cls.return_value = rag_stub

        # AI stub: async call returns raw JSON text, parser returns the dict.
        ai_stub = MagicMock()
        ai_stub.call_claude = AsyncMock(
            return_value='{"功能名称": "销售订单管理", "功能类型": "列表页面"}'
        )
        ai_stub.parse_json_response.return_value = {
            "功能名称": "销售订单管理",
            "功能类型": "列表页面",
        }
        claude_cls.return_value = ai_stub

        # Build the service while the patches are active, then analyze.
        service = RequirementService()
        analysis = await service.analyze(
            user_input="创建一个销售订单管理页面",
            session_id="test-session",
        )

        assert analysis is not None
        assert "功能名称" in analysis
        assert analysis["功能名称"] == "销售订单管理"
@pytest.mark.asyncio
async def test_analyze_requirement_without_session_id():
    """Run analyze() without a session_id and check it still returns a result.

    Only the returned analysis is inspected; whether a session id gets
    generated internally is not asserted here.
    """
    with patch('app.services.requirement_service.ClaudeEngine') as claude_cls, \
            patch('app.services.requirement_service.RAGEngine') as rag_cls, \
            patch('app.services.requirement_service.DatabaseEngine') as db_cls:
        # Database and RAG stubs both return nothing.
        db_stub = MagicMock()
        db_stub.execute_sql.return_value = []
        db_cls.return_value = db_stub

        rag_stub = MagicMock()
        rag_stub.search.return_value = []
        rag_cls.return_value = rag_stub

        # AI stub: async call returns raw JSON text, parser returns the dict.
        ai_stub = MagicMock()
        ai_stub.call_claude = AsyncMock(return_value='{"功能名称": "测试功能"}')
        ai_stub.parse_json_response.return_value = {"功能名称": "测试功能"}
        claude_cls.return_value = ai_stub

        # Call analyze with user_input only.
        service = RequirementService()
        analysis = await service.analyze(user_input="测试输入")

        assert analysis is not None
        assert "功能名称" in analysis
@pytest.mark.asyncio
async def test_get_existing_tables_success():
    """Verify table names returned by the database appear in the helper's output."""
    table_rows = [
        ("SYS_FORM",),
        ("SYS_MENU",),
        ("SYS_USER",),
    ]
    with patch('app.services.requirement_service.ClaudeEngine'), \
            patch('app.services.requirement_service.RAGEngine'), \
            patch('app.services.requirement_service.DatabaseEngine') as db_cls:
        db_stub = MagicMock()
        db_stub.execute_sql.return_value = table_rows
        db_cls.return_value = db_stub

        service = RequirementService()
        # The helper is called without await, i.e. as a plain function.
        tables = service._get_existing_tables("测试")

        for table_name in ("SYS_FORM", "SYS_MENU", "SYS_USER"):
            assert table_name in tables
@pytest.mark.asyncio
async def test_get_existing_tables_failure():
    """Verify a failing database query yields the fallback message, not an error."""
    with patch('app.services.requirement_service.ClaudeEngine'), \
            patch('app.services.requirement_service.RAGEngine'), \
            patch('app.services.requirement_service.DatabaseEngine') as db_cls:
        # Every execute_sql call raises, simulating a broken database.
        db_stub = MagicMock()
        db_stub.execute_sql.side_effect = Exception("DB Error")
        db_cls.return_value = db_stub

        service = RequirementService()
        # The helper is called without await, i.e. as a plain function.
        tables = service._get_existing_tables("测试")

        assert "无法获取现有表信息" in tables