feat: implement ERP AI Assistant Phase 1
Backend (FastAPI + SQLAlchemy + Claude API + RAG): - Config management with Pydantic v2 - Database engine with connection pooling and SQL injection prevention - AI engine with Claude API integration (support custom base URL) - RAG engine with ChromaDB and sentence-transformers - Requirement analysis service - Config generation service - Executor engine with SQL validation - REST API endpoints: /analyze, /generate, /execute Frontend (Vue 3 + Element Plus + Pinia): - Complete 3-step workflow: analyze → generate → execute - Step indicator with progress visualization - Analysis result display with field table - SQL preview with monospace font - Execute confirmation dialog with safety warning - Execution result display - State management with Pinia - API service integration Security: - SQL injection prevention with parameterized queries - Dangerous SQL operation blocking - Database password URL encoding - Transaction auto-rollback - Pydantic config validation Features: - Natural language requirement analysis - Automated SQL configuration generation - Safe execution with human review - LAN access support - Custom Claude API endpoint support Documentation: - README with quick start guide - Quick start guide - LAN access configuration - Dependency fixes guide - Claude API configuration - Git operation guide - Implementation report Dependencies fixed: - numpy<2.0.0 for chromadb compatibility - sentence-transformers==2.7.0 for huggingface_hub compatibility Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2
backend/app/__init__.py
Normal file
2
backend/app/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""ERP AI Assistant Backend"""
|
||||
__version__ = "1.0.0"
|
||||
1
backend/app/api/__init__.py
Normal file
1
backend/app/api/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""API routes for ERP AI Assistant."""
|
||||
113
backend/app/api/analyze.py
Normal file
113
backend/app/api/analyze.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""Analyze API endpoint for requirement analysis.
|
||||
|
||||
This module provides the /analyze endpoint for analyzing user requirements.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import Dict
|
||||
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from loguru import logger
|
||||
|
||||
from app.models.request import AnalyzeRequest
|
||||
from app.models.response import AnalyzeResponse, ErrorResponse
|
||||
from app.services.requirement_service import RequirementService
|
||||
|
||||
# Create router
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/analyze",
|
||||
response_model=AnalyzeResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse, "description": "Invalid request"},
|
||||
500: {"model": ErrorResponse, "description": "Internal server error"}
|
||||
},
|
||||
summary="Analyze user requirement",
|
||||
description="Analyze natural language or structured requirement and return structured specification"
|
||||
)
|
||||
async def analyze_requirement(request: AnalyzeRequest) -> AnalyzeResponse:
|
||||
"""Analyze user requirement and return structured specification.
|
||||
|
||||
This endpoint accepts either natural language or structured input,
|
||||
processes it through Claude AI with RAG knowledge retrieval, and
|
||||
returns a structured requirement specification.
|
||||
|
||||
Args:
|
||||
request: AnalyzeRequest containing input_type, content, and optional session_id
|
||||
|
||||
Returns:
|
||||
AnalyzeResponse with session_id, status, and structured data
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 for invalid input, 500 for processing errors
|
||||
"""
|
||||
# Generate session ID if not provided
|
||||
session_id = request.session_id or str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
# Validate input type
|
||||
if request.input_type not in ["natural_language", "structured"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"code": "INVALID_INPUT_TYPE",
|
||||
"message": "input_type must be 'natural_language' or 'structured'",
|
||||
"session_id": session_id
|
||||
}
|
||||
)
|
||||
|
||||
# Validate content
|
||||
if not request.content or not request.content.strip():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"code": "EMPTY_CONTENT",
|
||||
"message": "content cannot be empty",
|
||||
"session_id": session_id
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"[{session_id}] Processing analyze request: {request.content[:50]}...")
|
||||
|
||||
# Create service and analyze
|
||||
service = RequirementService()
|
||||
result = await service.analyze(
|
||||
user_input=request.content,
|
||||
session_id=session_id
|
||||
)
|
||||
|
||||
logger.success(f"[{session_id}] Analysis completed successfully")
|
||||
|
||||
return AnalyzeResponse(
|
||||
session_id=session_id,
|
||||
status="success",
|
||||
data=result
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTP exceptions
|
||||
raise
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"[{session_id}] Validation error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"code": "VALIDATION_ERROR",
|
||||
"message": str(e),
|
||||
"session_id": session_id
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{session_id}] Analysis failed: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={
|
||||
"code": "ANALYSIS_FAILED",
|
||||
"message": f"Failed to analyze requirement: {str(e)}",
|
||||
"session_id": session_id
|
||||
}
|
||||
)
|
||||
151
backend/app/api/execute.py
Normal file
151
backend/app/api/execute.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""Execute API endpoint for SQL configuration execution.
|
||||
|
||||
This module provides the /execute endpoint for executing SQL configuration.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import Dict, Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from loguru import logger
|
||||
|
||||
from app.models.request import ExecuteRequest
|
||||
from app.models.response import ExecuteResponse, ErrorResponse
|
||||
from app.core.executor import ConfigExecutor
|
||||
|
||||
# Create router
|
||||
router = APIRouter()
|
||||
|
||||
# In-memory storage for SQL lists (should use Redis/database in production)
|
||||
_session_sql_store: Dict[str, list] = {}
|
||||
|
||||
|
||||
def store_session_sql(session_id: str, sql_list: list) -> None:
|
||||
"""Store SQL list for a session.
|
||||
|
||||
Args:
|
||||
session_id: Session ID
|
||||
sql_list: List of SQL statements
|
||||
"""
|
||||
_session_sql_store[session_id] = sql_list
|
||||
logger.debug(f"Stored {len(sql_list)} SQL statements for session {session_id}")
|
||||
|
||||
|
||||
def get_session_sql(session_id: str) -> list:
|
||||
"""Retrieve SQL list for a session.
|
||||
|
||||
Args:
|
||||
session_id: Session ID
|
||||
|
||||
Returns:
|
||||
List of SQL statements (empty list if not found)
|
||||
"""
|
||||
sql_list = _session_sql_store.get(session_id, [])
|
||||
logger.debug(f"Retrieved {len(sql_list)} SQL statements for session {session_id}")
|
||||
return sql_list
|
||||
|
||||
|
||||
@router.post(
|
||||
"/execute",
|
||||
response_model=ExecuteResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse, "description": "Invalid request"},
|
||||
500: {"model": ErrorResponse, "description": "Internal server error"}
|
||||
},
|
||||
summary="Execute SQL configuration",
|
||||
description="Execute SQL configuration after user confirmation"
|
||||
)
|
||||
async def execute_config(request: ExecuteRequest) -> ExecuteResponse:
|
||||
"""Execute SQL configuration after user confirmation.
|
||||
|
||||
This endpoint executes the SQL statements associated with the session.
|
||||
User must set confirmed=True to proceed with execution.
|
||||
|
||||
Args:
|
||||
request: ExecuteRequest with session_id, confirmed, and backup_enabled
|
||||
|
||||
Returns:
|
||||
ExecuteResponse with execution_id, status, and message
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 if not confirmed or invalid, 500 for execution errors
|
||||
"""
|
||||
# Generate execution ID
|
||||
execution_id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
# Check user confirmation
|
||||
if not request.confirmed:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"code": "NOT_CONFIRMED",
|
||||
"message": "User must confirm execution by setting confirmed=True",
|
||||
"session_id": request.session_id,
|
||||
"execution_id": execution_id
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"[{request.session_id}] Processing execute request")
|
||||
|
||||
# Retrieve SQL list for session
|
||||
sql_list = get_session_sql(request.session_id)
|
||||
|
||||
if not sql_list:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"code": "NO_SQL_FOUND",
|
||||
"message": "No SQL statements found for this session",
|
||||
"session_id": request.session_id,
|
||||
"execution_id": execution_id
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"[{request.session_id}] Retrieved {len(sql_list)} SQL statements")
|
||||
|
||||
# Create executor
|
||||
executor = ConfigExecutor()
|
||||
|
||||
# Execute configuration
|
||||
result = executor.execute_config(sql_list, request.session_id)
|
||||
|
||||
if result["success"]:
|
||||
logger.success(
|
||||
f"[{request.session_id}] Execution completed: {result['message']}"
|
||||
)
|
||||
return ExecuteResponse(
|
||||
execution_id=execution_id,
|
||||
status="success",
|
||||
message=result["message"]
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"[{request.session_id}] Execution failed: {result['failed']}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={
|
||||
"code": "EXECUTION_FAILED",
|
||||
"message": result["message"],
|
||||
"error": result["failed"],
|
||||
"session_id": request.session_id,
|
||||
"execution_id": execution_id
|
||||
}
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTP exceptions
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{request.session_id}] Execution failed: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={
|
||||
"code": "EXECUTION_ERROR",
|
||||
"message": f"Failed to execute config: {str(e)}",
|
||||
"session_id": request.session_id,
|
||||
"execution_id": execution_id
|
||||
}
|
||||
)
|
||||
102
backend/app/api/generate.py
Normal file
102
backend/app/api/generate.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""Generate API endpoint for configuration generation.
|
||||
|
||||
This module provides the /generate endpoint for generating SQL configuration.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from loguru import logger
|
||||
|
||||
from app.models.request import GenerateRequest
|
||||
from app.models.response import GenerateResponse, ErrorResponse
|
||||
from app.services.config_service import ConfigService
|
||||
from app.api.execute import store_session_sql # Import SQL storage function
|
||||
|
||||
# Create router
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/generate",
|
||||
response_model=GenerateResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse, "description": "Invalid request"},
|
||||
500: {"model": ErrorResponse, "description": "Internal server error"}
|
||||
},
|
||||
summary="Generate SQL configuration",
|
||||
description="Generate SQL configuration based on structured requirements"
|
||||
)
|
||||
async def generate_config(request: GenerateRequest) -> GenerateResponse:
|
||||
"""Generate SQL configuration based on structured requirements.
|
||||
|
||||
This endpoint takes structured requirements from the analysis phase
|
||||
and generates SQL configuration statements using Claude AI.
|
||||
|
||||
Args:
|
||||
request: GenerateRequest with session_id and requirements
|
||||
|
||||
Returns:
|
||||
GenerateResponse with session_id, status, and generated config
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 for invalid input, 500 for processing errors
|
||||
"""
|
||||
try:
|
||||
# Validate requirements
|
||||
if not request.requirements:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"code": "EMPTY_REQUIREMENTS",
|
||||
"message": "requirements cannot be empty",
|
||||
"session_id": request.session_id
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(f"[{request.session_id}] Processing generate request")
|
||||
|
||||
# Create service and generate config
|
||||
service = ConfigService()
|
||||
result = await service.generate(
|
||||
requirements=request.requirements,
|
||||
session_id=request.session_id
|
||||
)
|
||||
|
||||
# Store generated SQL for later execution
|
||||
if result and result.get("配置方案") and result["配置方案"].get("sql_list"):
|
||||
sql_list = result["配置方案"]["sql_list"]
|
||||
store_session_sql(request.session_id, sql_list)
|
||||
logger.info(f"[{request.session_id}] Stored {len(sql_list)} SQL statements for execution")
|
||||
|
||||
logger.success(f"[{request.session_id}] Config generation completed")
|
||||
|
||||
return GenerateResponse(
|
||||
session_id=request.session_id,
|
||||
status="success",
|
||||
data=result
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTP exceptions
|
||||
raise
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"[{request.session_id}] Validation error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"code": "VALIDATION_ERROR",
|
||||
"message": str(e),
|
||||
"session_id": request.session_id
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{request.session_id}] Config generation failed: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={
|
||||
"code": "GENERATION_FAILED",
|
||||
"message": f"Failed to generate config: {str(e)}",
|
||||
"session_id": request.session_id
|
||||
}
|
||||
)
|
||||
67
backend/app/config.py
Normal file
67
backend/app/config.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import ConfigDict, field_validator
|
||||
from functools import lru_cache
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
# Application
|
||||
APP_NAME: str = "ERP AI Assistant"
|
||||
APP_ENV: str = "development"
|
||||
DEBUG: bool = True
|
||||
SECRET_KEY: str
|
||||
|
||||
# Database
|
||||
DB_DRIVER: str
|
||||
DB_SERVER: str
|
||||
DB_PORT: int = 1433
|
||||
DB_NAME: str
|
||||
DB_USER: str
|
||||
DB_PASSWORD: str
|
||||
|
||||
# Claude API
|
||||
ANTHROPIC_API_KEY: str
|
||||
ANTHROPIC_BASE_URL: str | None = None # Optional custom base URL for proxy/self-hosted
|
||||
CLAUDE_MODEL: str = "claude-sonnet-4-6"
|
||||
CLAUDE_MAX_TOKENS: int = 8192
|
||||
CLAUDE_TEMPERATURE: float = 0.7
|
||||
|
||||
# Knowledge Base
|
||||
KNOWLEDGE_BASE_PATH: str = "./knowledge_base"
|
||||
CHROMA_DB_PATH: str = "./knowledge_base/chroma_db"
|
||||
EMBEDDING_MODEL: str = "all-MiniLM-L6-v2"
|
||||
CHUNK_SIZE: int = 500
|
||||
CHUNK_OVERLAP: int = 50
|
||||
|
||||
@property
|
||||
def DATABASE_URL(self) -> str:
|
||||
"""构建数据库连接 URL(密码安全编码)"""
|
||||
password = quote_plus(self.DB_PASSWORD)
|
||||
return (
|
||||
f"mssql+pyodbc://{self.DB_USER}:{password}"
|
||||
f"@{self.DB_SERVER}:{self.DB_PORT}/{self.DB_NAME}"
|
||||
f"?driver={quote_plus(self.DB_DRIVER)}"
|
||||
)
|
||||
|
||||
@field_validator('CLAUDE_TEMPERATURE')
|
||||
@classmethod
|
||||
def validate_temperature(cls, v):
|
||||
if not 0 <= v <= 2:
|
||||
raise ValueError('CLAUDE_TEMPERATURE must be between 0 and 2')
|
||||
return v
|
||||
|
||||
@field_validator('CHUNK_OVERLAP')
|
||||
@classmethod
|
||||
def validate_chunk_overlap(cls, v, info):
|
||||
chunk_size = info.data.get('CHUNK_SIZE', 500)
|
||||
if v >= chunk_size:
|
||||
raise ValueError('CHUNK_OVERLAP must be less than CHUNK_SIZE')
|
||||
return v
|
||||
|
||||
model_config = ConfigDict(env_file=".env", case_sensitive=True)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_settings() -> Settings:
|
||||
"""获取配置单例"""
|
||||
return Settings()
|
||||
1
backend/app/core/__init__.py
Normal file
1
backend/app/core/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Core modules"""
|
||||
120
backend/app/core/ai_engine.py
Normal file
120
backend/app/core/ai_engine.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""AI Engine for ERP AI Assistant.
|
||||
|
||||
This module provides the ClaudeEngine class that wraps Claude API calls
|
||||
and provides JSON parsing utilities.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import anthropic
|
||||
from loguru import logger
|
||||
|
||||
from app.config import get_settings
|
||||
|
||||
|
||||
class ClaudeEngine:
|
||||
"""Engine for interacting with Claude API.
|
||||
|
||||
This class wraps the Anthropic Claude API client and provides
|
||||
utilities for parsing JSON responses from Claude.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize Claude engine with settings."""
|
||||
settings = get_settings()
|
||||
|
||||
# Initialize Anthropic client with optional custom base_url
|
||||
client_kwargs = {"api_key": settings.ANTHROPIC_API_KEY}
|
||||
if settings.ANTHROPIC_BASE_URL:
|
||||
client_kwargs["base_url"] = settings.ANTHROPIC_BASE_URL
|
||||
logger.info(f"Using custom Anthropic base URL: {settings.ANTHROPIC_BASE_URL}")
|
||||
|
||||
self.client = anthropic.AsyncAnthropic(**client_kwargs)
|
||||
self.model = settings.CLAUDE_MODEL
|
||||
self.max_tokens = settings.CLAUDE_MAX_TOKENS
|
||||
self.temperature = settings.CLAUDE_TEMPERATURE
|
||||
|
||||
def parse_json_response(self, content: str) -> dict[str, Any]:
|
||||
"""Parse JSON from Claude responses.
|
||||
|
||||
Attempts multiple parsing strategies:
|
||||
1. Direct JSON parse
|
||||
2. Extract from markdown code blocks
|
||||
3. Extract any {...} block
|
||||
|
||||
Args:
|
||||
content: The response content from Claude
|
||||
|
||||
Returns:
|
||||
Parsed JSON as a dictionary
|
||||
|
||||
Raises:
|
||||
ValueError: If JSON cannot be parsed using any strategy
|
||||
"""
|
||||
if not content or not content.strip():
|
||||
raise ValueError("Empty content provided")
|
||||
|
||||
# Strategy 1: Try direct JSON parse
|
||||
try:
|
||||
return json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Strategy 2: Try extracting from markdown code blocks
|
||||
json_code_block_pattern = r'```json\s*(\{.*?\})\s*```'
|
||||
json_match = re.search(json_code_block_pattern, content, re.DOTALL)
|
||||
if json_match:
|
||||
try:
|
||||
return json.loads(json_match.group(1))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Also try any code block (not just json tagged)
|
||||
code_block_pattern = r'```\s*(\{.*?\})\s*```'
|
||||
code_block_match = re.search(code_block_pattern, content, re.DOTALL)
|
||||
if code_block_match:
|
||||
try:
|
||||
return json.loads(code_block_match.group(1))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Strategy 3: Try extracting any {...} block
|
||||
# Find balanced braces
|
||||
brace_pattern = r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}'
|
||||
json_blocks = re.findall(brace_pattern, content, re.DOTALL)
|
||||
for json_block in json_blocks:
|
||||
try:
|
||||
return json.loads(json_block)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
# All strategies failed
|
||||
logger.error(f"无法解析 Claude 返回的 JSON: {content[:200]}")
|
||||
raise ValueError("无法解析 Claude 返回的 JSON,请检查响应格式")
|
||||
|
||||
async def call_claude(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
temperature: float | None = None
|
||||
) -> str:
|
||||
"""Call Claude API.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries with 'role' and 'content'
|
||||
temperature: Optional temperature override (0-2)
|
||||
|
||||
Returns:
|
||||
The text content from Claude's response
|
||||
|
||||
Raises:
|
||||
Exception: If the API call fails
|
||||
"""
|
||||
response = await self.client.messages.create(
|
||||
model=self.model,
|
||||
max_tokens=self.max_tokens,
|
||||
temperature=temperature if temperature is not None else self.temperature,
|
||||
messages=messages
|
||||
)
|
||||
return response.content[0].text
|
||||
78
backend/app/core/db_engine.py
Normal file
78
backend/app/core/db_engine.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from typing import Optional
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from contextlib import contextmanager
|
||||
from loguru import logger
|
||||
from app.config import get_settings
|
||||
|
||||
|
||||
class DatabaseEngine:
|
||||
"""数据库操作引擎"""
|
||||
|
||||
def __init__(self):
|
||||
settings = get_settings()
|
||||
self.engine = create_engine(
|
||||
settings.DATABASE_URL,
|
||||
pool_size=20,
|
||||
max_overflow=10,
|
||||
pool_pre_ping=True,
|
||||
echo=settings.DEBUG
|
||||
)
|
||||
self.Session = sessionmaker(bind=self.engine)
|
||||
|
||||
@contextmanager
|
||||
def get_session(self):
|
||||
"""获取数据库会话(上下文管理器)"""
|
||||
session = self.Session()
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"数据库操作失败:{e}")
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def execute_sql(self, sql: str, params: Optional[dict] = None) -> list:
|
||||
"""执行单条 SQL"""
|
||||
with self.get_session() as session:
|
||||
result = session.execute(text(sql), params or {})
|
||||
return result.fetchall()
|
||||
|
||||
def execute_transaction(self, sql_list: list, params_list: Optional[list] = None) -> bool:
|
||||
"""执行事务(多条 SQL)"""
|
||||
params_list = params_list or [None] * len(sql_list)
|
||||
with self.get_session() as session:
|
||||
for sql, params in zip(sql_list, params_list):
|
||||
session.execute(text(sql), params or {})
|
||||
return True
|
||||
|
||||
def get_table_structure(self, table_name: str):
|
||||
"""获取表结构(安全参数化查询)"""
|
||||
sql = """
|
||||
SELECT
|
||||
COLUMN_NAME,
|
||||
DATA_TYPE,
|
||||
CHARACTER_MAXIMUM_LENGTH,
|
||||
IS_NULLABLE,
|
||||
COLUMN_DEFAULT
|
||||
FROM INFORMATION_SCHEMA.COLUMNS
|
||||
WHERE TABLE_NAME = :table_name
|
||||
ORDER BY ORDINAL_POSITION
|
||||
"""
|
||||
return self.execute_sql(sql, {"table_name": table_name})
|
||||
|
||||
def table_exists(self, table_name: str) -> bool:
|
||||
"""检查表是否存在(安全参数化查询)"""
|
||||
sql = """
|
||||
SELECT COUNT(*)
|
||||
FROM INFORMATION_SCHEMA.TABLES
|
||||
WHERE TABLE_NAME = :table_name
|
||||
"""
|
||||
result = self.execute_sql(sql, {"table_name": table_name})
|
||||
return result[0][0] > 0
|
||||
|
||||
def dispose(self):
|
||||
"""关闭连接池,释放资源"""
|
||||
self.engine.dispose()
|
||||
147
backend/app/core/executor.py
Normal file
147
backend/app/core/executor.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""Config Executor for ERP AI Assistant.
|
||||
|
||||
This module provides the ConfigExecutor class for validating and executing
|
||||
SQL configuration statements with safety checks.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import List, Tuple, Dict, Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from app.core.db_engine import DatabaseEngine
|
||||
|
||||
|
||||
class ConfigExecutor:
|
||||
"""Executor for SQL configuration statements with safety validation.
|
||||
|
||||
This class validates SQL statements against dangerous operations before
|
||||
execution and provides transaction-based execution with rollback support.
|
||||
"""
|
||||
|
||||
# Dangerous SQL keywords that should be blocked
|
||||
DANGEROUS_KEYWORDS = [
|
||||
r"DROP\s+DATABASE",
|
||||
r"DROP\s+TABLE",
|
||||
r"TRUNCATE\s+TABLE",
|
||||
r"DELETE\s+FROM",
|
||||
r"UPDATE\s+.*\s+SET",
|
||||
r"ALTER\s+TABLE\s+.*\s+DROP"
|
||||
]
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize executor with database engine."""
|
||||
self.db_engine = DatabaseEngine()
|
||||
logger.info("ConfigExecutor initialized")
|
||||
|
||||
def validate_sql(self, sql: str) -> Tuple[bool, str]:
|
||||
"""Validate SQL statement for safety.
|
||||
|
||||
Checks SQL against a list of dangerous keywords/patterns to prevent
|
||||
destructive operations.
|
||||
|
||||
Args:
|
||||
sql: SQL statement to validate
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, message) where is_valid indicates if SQL is safe
|
||||
"""
|
||||
if not sql or not sql.strip():
|
||||
return False, "SQL语句为空"
|
||||
|
||||
sql_upper = sql.upper().strip()
|
||||
|
||||
# Check for dangerous operations
|
||||
for pattern in self.DANGEROUS_KEYWORDS:
|
||||
if re.search(pattern, sql_upper):
|
||||
# Extract matched keyword for error message
|
||||
match = re.search(pattern, sql_upper)
|
||||
matched_keyword = match.group(0) if match else pattern
|
||||
logger.warning(f"SQL validation failed: dangerous operation '{matched_keyword}' detected")
|
||||
return False, f"危险操作被拦截: {matched_keyword}"
|
||||
|
||||
logger.debug(f"SQL validation passed: {sql[:50]}...")
|
||||
return True, "SQL验证通过"
|
||||
|
||||
def execute_config(
|
||||
self,
|
||||
sql_list: List[str],
|
||||
session_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute a list of SQL statements in a transaction.
|
||||
|
||||
Validates all SQL statements before execution. If any validation fails,
|
||||
no statements are executed.
|
||||
|
||||
Args:
|
||||
sql_list: List of SQL statements to execute
|
||||
session_id: Session ID for logging and tracking
|
||||
|
||||
Returns:
|
||||
Dictionary containing:
|
||||
- success: Boolean indicating overall success
|
||||
- executed: List of executed SQL statements
|
||||
- failed: Error message if execution failed, None otherwise
|
||||
- message: Human-readable result message
|
||||
"""
|
||||
logger.info(f"[{session_id}] Starting config execution with {len(sql_list)} SQL statements")
|
||||
|
||||
results: Dict[str, Any] = {
|
||||
"success": True,
|
||||
"executed": [],
|
||||
"failed": None,
|
||||
"message": ""
|
||||
}
|
||||
|
||||
try:
|
||||
# Step 1: Validate all SQL statements
|
||||
logger.debug(f"[{session_id}] Validating {len(sql_list)} SQL statements")
|
||||
for i, sql in enumerate(sql_list):
|
||||
is_valid, msg = self.validate_sql(sql)
|
||||
if not is_valid:
|
||||
error_msg = f"SQL #{i+1} 验证失败: {msg}"
|
||||
logger.error(f"[{session_id}] {error_msg}")
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# Step 2: Execute transaction
|
||||
logger.debug(f"[{session_id}] Executing transaction")
|
||||
self.db_engine.execute_transaction(sql_list)
|
||||
|
||||
# Step 3: Record success
|
||||
results["executed"] = sql_list
|
||||
results["message"] = f"成功执行 {len(sql_list)} 条SQL"
|
||||
logger.success(f"[{session_id}] {results['message']}")
|
||||
|
||||
except ValueError as e:
|
||||
# Validation failure
|
||||
results["success"] = False
|
||||
results["failed"] = str(e)
|
||||
results["message"] = f"执行失败: {e}"
|
||||
logger.error(f"[{session_id}] {results['message']}")
|
||||
|
||||
except Exception as e:
|
||||
# Execution failure
|
||||
results["success"] = False
|
||||
results["failed"] = str(e)
|
||||
results["message"] = f"执行失败: {e}"
|
||||
logger.error(f"[{session_id}] {results['message']}")
|
||||
|
||||
return results
|
||||
|
||||
def rollback(self, session_id: str) -> Dict[str, Any]:
|
||||
"""Rollback executed operations for a session.
|
||||
|
||||
This is a placeholder for rollback functionality. Actual implementation
|
||||
would require recording inverse SQL operations during execution.
|
||||
|
||||
Args:
|
||||
session_id: Session ID to rollback
|
||||
|
||||
Returns:
|
||||
Dictionary with success status and message
|
||||
"""
|
||||
logger.warning(f"[{session_id}] Rollback requested but not yet implemented")
|
||||
return {
|
||||
"success": False,
|
||||
"message": "回滚功能待实现"
|
||||
}
|
||||
144
backend/app/core/prompts.py
Normal file
144
backend/app/core/prompts.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
Prompt 模板定义
|
||||
|
||||
模板说明:
|
||||
- SYSTEM_PROMPT: 系统提示词,定义 Claude 的角色和专业领域
|
||||
- ANALYZE_PROMPT_TEMPLATE: 需求解析模板,占位符:user_input, knowledge_context, existing_tables
|
||||
- GENERATE_PROMPT_TEMPLATE: 配置生成模板,占位符:requirements, platform_rules, similar_cases
|
||||
"""
|
||||
|
||||
SYSTEM_PROMPT = """你是一个 ERP 平台配置专家助手,专门帮助开发人员配置一零软件结构化开发平台。
|
||||
|
||||
## 你的职责
|
||||
|
||||
你是 ERP 系统配置和开发的专业顾问,负责:
|
||||
1. 分析用户需求,理解业务场景
|
||||
2. 设计合理的数据库表结构
|
||||
3. 生成符合平台规范的配置方案
|
||||
4. 提供完整的 SQL 脚本和配置说明
|
||||
|
||||
## 平台知识
|
||||
|
||||
你熟悉以下平台概念:
|
||||
- 窗体类型:基础资料、单据、报表、系统设置等
|
||||
- 标准字段命名规范:F 开头的主键、FPrefix 前缀的自定义字段
|
||||
- 配置流程:需求分析 → 表结构设计 → 功能配置 → 页面配置 → 菜单配置
|
||||
- 命名约定:表名以 T_开头,功能号以功能类别前缀开头
|
||||
|
||||
## 输出要求
|
||||
|
||||
1. 提供完整的 SQL 脚本,包括建表语句、函数配置、页面配置等
|
||||
2. 确保配置符合平台规范和最佳实践
|
||||
3. 进行风险评估,提示潜在问题
|
||||
4. 使用 JSON 格式输出结构化结果
|
||||
5. 所有字段和表名使用英文,注释使用中文
|
||||
|
||||
请始终保持专业、严谨的工作态度,确保输出的配置方案可落地执行。"""
|
||||
|
||||
|
||||
ANALYZE_PROMPT_TEMPLATE = """请分析以下用户需求,生成结构化的需求分析文档。
|
||||
|
||||
## 用户输入
|
||||
{user_input}
|
||||
|
||||
## 相关知识上下文
|
||||
{knowledge_context}
|
||||
|
||||
## 现有表结构
|
||||
{existing_tables}
|
||||
|
||||
## 分析要求
|
||||
|
||||
请输出结构化的需求分析文档,使用 JSON 格式,包含以下字段:
|
||||
|
||||
# Note: Use {{ and }} to escape braces for .format() - rendered as literal { and }
|
||||
```json
|
||||
{{
|
||||
"功能名称": "功能的中文名称",
|
||||
"功能号建议": "建议的功能编号,如 SAL001",
|
||||
"窗体类型": "基础资料/单据/报表/系统设置",
|
||||
"主表名建议": "建议的主表名,如 T_SAL_Order",
|
||||
"从表名建议": "建议的从表名,如 T_SAL_OrderEntry",
|
||||
"主表字段": [
|
||||
{{"字段名": "FOrderId", "字段类型": "varchar(50)", "中文名称": "订单编号", "必填": true}},
|
||||
...
|
||||
],
|
||||
"从表字段": [
|
||||
{{"字段名": "FEntryId", "字段类型": "int", "中文名称": "分录 ID", "必填": true}},
|
||||
...
|
||||
],
|
||||
"业务需求": "详细的业务需求描述",
|
||||
"关联表": ["相关表 1", "相关表 2"],
|
||||
"风险提示": ["潜在风险 1", "潜在风险 2"]
|
||||
}}
|
||||
```
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. 字段命名遵循平台规范:主键以 F 开头,使用 PascalCase
|
||||
2. 表名以 T_开头,使用模块前缀
|
||||
3. 考虑必填字段、默认值、数据长度等约束
|
||||
4. 识别必要的业务关联关系
|
||||
5. 评估潜在的数据一致性和性能风险"""
|
||||
|
||||
|
||||
GENERATE_PROMPT_TEMPLATE = """请根据需求分析结果,生成完整的平台配置方案。
|
||||
|
||||
## 需求分析结果
|
||||
{requirements}
|
||||
|
||||
## 平台规则
|
||||
{platform_rules}
|
||||
|
||||
## 类似案例参考
|
||||
{similar_cases}
|
||||
|
||||
## 生成要求
|
||||
|
||||
请生成完整的配置方案,使用 JSON 格式,包含以下内容:
|
||||
|
||||
# Note: Use {{ and }} to escape braces for .format() - rendered as literal { and }
|
||||
```json
|
||||
{{
|
||||
"table_sql": "建表 SQL 语句,包括主表和从表",
|
||||
"function_config_sql": "功能配置 SQL 语句",
|
||||
"page_config_sql": "页面配置 SQL 语句",
|
||||
"menu_config_sql": "菜单配置 SQL 语句",
|
||||
"ikey_config_sql": "IKEY 配置 SQL 语句",
|
||||
"config_summary": {{
|
||||
"created_tables": ["表 1", "表 2"],
|
||||
"main_entities": ["主要实体 1", "主要实体 2"],
|
||||
"relationships": "表间关系说明"
|
||||
}},
|
||||
"implementation_notes": "实施注意事项",
|
||||
"validation_rules": ["验证规则 1", "验证规则 2"]
|
||||
}}
|
||||
```
|
||||
|
||||
## 配置规范
|
||||
|
||||
1. **建表 SQL**:
|
||||
- 主键使用 FId 或 F+ 表名缩写 + Id
|
||||
- 包含创建时间、创建人、更新时间、更新人等审计字段
|
||||
- 使用合适的索引提高查询性能
|
||||
|
||||
2. **功能配置**:
|
||||
- 定义功能号、功能名称、功能类型
|
||||
- 配置数据权限和操作权限
|
||||
|
||||
3. **页面配置**:
|
||||
- 配置表单布局、字段顺序
|
||||
- 设置字段属性(必填、只读、可见性)
|
||||
|
||||
4. **菜单配置**:
|
||||
- 配置菜单层级、图标、排序
|
||||
|
||||
5. **IKEY 配置**:
|
||||
- 配置编码规则、生成策略
|
||||
|
||||
## 注意事项
|
||||
|
||||
- 所有 SQL 语句需要语法正确、可直接执行
|
||||
- 配置需要符合平台规范
|
||||
- 考虑扩展性和维护性
|
||||
- 提供必要的注释说明"""
|
||||
251
backend/app/core/rag_engine.py
Normal file
251
backend/app/core/rag_engine.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""RAG Engine for ERP AI Assistant.
|
||||
|
||||
This module provides the RAGEngine class that handles knowledge document
|
||||
storage and retrieval using ChromaDB and sentence-transformers embeddings.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import chromadb
|
||||
from chromadb.config import Settings as ChromaSettings
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from loguru import logger
|
||||
|
||||
from app.config import get_settings
|
||||
|
||||
|
||||
class RAGEngine:
|
||||
"""RAG Engine for knowledge document retrieval.
|
||||
|
||||
This class wraps ChromaDB vector database and sentence-transformers
|
||||
to provide semantic search over knowledge documents.
|
||||
"""
|
||||
|
||||
# Class-level singleton for embedding model (lazy loading)
|
||||
_embedding_model: Optional[SentenceTransformer] = None
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize RAG engine with ChromaDB and embedding model."""
|
||||
settings = get_settings()
|
||||
|
||||
# Initialize ChromaDB persistent client
|
||||
logger.info(f"Initializing ChromaDB at: {settings.CHROMA_DB_PATH}")
|
||||
self.chroma_client = chromadb.PersistentClient(
|
||||
path=settings.CHROMA_DB_PATH,
|
||||
settings=ChromaSettings(anonymized_telemetry=False)
|
||||
)
|
||||
|
||||
# Load sentence-transformers embedding model (lazy loading, singleton)
|
||||
logger.info(f"Loading embedding model: {settings.EMBEDDING_MODEL}")
|
||||
self.embedding_model = self._get_embedding_model(settings.EMBEDDING_MODEL)
|
||||
|
||||
# Get or create documents collection
|
||||
self.documents_collection = self.chroma_client.get_or_create_collection(
|
||||
name="documents"
|
||||
)
|
||||
|
||||
# Store chunking settings
|
||||
self.chunk_size = settings.CHUNK_SIZE
|
||||
self.chunk_overlap = settings.CHUNK_OVERLAP
|
||||
|
||||
logger.info(
|
||||
f"RAG Engine initialized: chunk_size={self.chunk_size}, "
|
||||
f"chunk_overlap={self.chunk_overlap}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_embedding_model(cls, model_name: str) -> SentenceTransformer:
|
||||
"""Get or create the embedding model (lazy loading, singleton).
|
||||
|
||||
Args:
|
||||
model_name: Name of the embedding model to load
|
||||
|
||||
Returns:
|
||||
SentenceTransformer embedding model instance
|
||||
"""
|
||||
if cls._embedding_model is None:
|
||||
logger.info(f"Loading embedding model: {model_name}")
|
||||
cls._embedding_model = SentenceTransformer(model_name)
|
||||
return cls._embedding_model
|
||||
|
||||
def _split_text(self, text: str) -> list[str]:
|
||||
"""Split text into overlapping chunks.
|
||||
|
||||
Args:
|
||||
text: The text to split
|
||||
|
||||
Returns:
|
||||
List of chunk strings
|
||||
"""
|
||||
if not text:
|
||||
return []
|
||||
|
||||
chunks = []
|
||||
start = 0
|
||||
text_length = len(text)
|
||||
|
||||
while start < text_length:
|
||||
end = start + self.chunk_size
|
||||
chunk = text[start:end]
|
||||
|
||||
if chunk.strip(): # Only add non-empty chunks
|
||||
chunks.append(chunk)
|
||||
|
||||
start += self.chunk_size - self.chunk_overlap
|
||||
|
||||
# Avoid infinite loop if overlap >= chunk_size
|
||||
if self.chunk_overlap >= self.chunk_size:
|
||||
start += 1
|
||||
|
||||
return chunks
|
||||
|
||||
def _delete_chunks_for_doc(self, doc_id: str) -> None:
|
||||
"""Delete all chunks associated with a document.
|
||||
|
||||
Args:
|
||||
doc_id: The document ID to delete chunks for
|
||||
"""
|
||||
try:
|
||||
# Find all chunks for this document
|
||||
results = self.documents_collection.get(
|
||||
where={"doc_id": doc_id},
|
||||
include=[]
|
||||
)
|
||||
if results and results.get("ids"):
|
||||
self.documents_collection.delete(ids=results["ids"])
|
||||
logger.debug(f"Deleted {len(results['ids'])} chunks for document '{doc_id}'")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete chunks for document '{doc_id}': {e}")
|
||||
|
||||
def add_document(
|
||||
self,
|
||||
doc_id: str,
|
||||
content: str,
|
||||
metadata: Optional[dict] = None
|
||||
) -> int:
|
||||
"""Add a document to the knowledge base.
|
||||
|
||||
Args:
|
||||
doc_id: Unique identifier for the document
|
||||
content: The document content to index
|
||||
metadata: Optional metadata dict to store with the document
|
||||
|
||||
Returns:
|
||||
Number of chunks added
|
||||
|
||||
Raises:
|
||||
ValueError: If content is empty
|
||||
"""
|
||||
if not content or not content.strip():
|
||||
raise ValueError("Cannot add empty document")
|
||||
|
||||
try:
|
||||
# Delete existing chunks for this doc_id (handles duplicates)
|
||||
self._delete_chunks_for_doc(doc_id)
|
||||
|
||||
# Split content into chunks
|
||||
chunks = self._split_text(content)
|
||||
logger.info(f"Split document '{doc_id}' into {len(chunks)} chunks")
|
||||
|
||||
if not chunks:
|
||||
return 0
|
||||
|
||||
# Generate embeddings for all chunks
|
||||
logger.debug(f"Generating embeddings for {len(chunks)} chunks")
|
||||
embeddings = self.embedding_model.encode(chunks)
|
||||
|
||||
# Prepare chunk IDs and metadata
|
||||
chunk_ids = [f"{doc_id}_chunk_{i}" for i in range(len(chunks))]
|
||||
|
||||
# Add metadata to each chunk
|
||||
chunk_metadata = []
|
||||
base_metadata = metadata or {}
|
||||
for i, chunk in enumerate(chunks):
|
||||
meta = {
|
||||
**base_metadata,
|
||||
"doc_id": doc_id,
|
||||
"chunk_index": i,
|
||||
"total_chunks": len(chunks)
|
||||
}
|
||||
chunk_metadata.append(meta)
|
||||
|
||||
# Add to ChromaDB
|
||||
self.documents_collection.add(
|
||||
ids=chunk_ids,
|
||||
embeddings=embeddings.tolist(),
|
||||
documents=chunks,
|
||||
metadatas=chunk_metadata
|
||||
)
|
||||
|
||||
logger.info(f"Added {len(chunks)} chunks for document '{doc_id}'")
|
||||
return len(chunks)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add document '{doc_id}': {e}")
|
||||
raise
|
||||
|
||||
def search(self, query: str, top_k: int = 3) -> list[dict]:
|
||||
"""Search for relevant document chunks.
|
||||
|
||||
Args:
|
||||
query: The search query
|
||||
top_k: Number of results to return (default: 3, max: 100)
|
||||
|
||||
Returns:
|
||||
List of dicts with 'content', 'metadata', and 'distance'
|
||||
|
||||
Raises:
|
||||
ValueError: If top_k exceeds maximum limit
|
||||
"""
|
||||
# Validate top_k
|
||||
if top_k > 100:
|
||||
raise ValueError(f"top_k cannot exceed 100 (got: {top_k})")
|
||||
|
||||
if not query or not query.strip():
|
||||
logger.warning("Empty search query provided")
|
||||
return []
|
||||
|
||||
try:
|
||||
# Generate embedding for query
|
||||
logger.debug(f"Generating embedding for query: {query[:50]}...")
|
||||
query_embedding = self.embedding_model.encode([query])
|
||||
|
||||
# Query ChromaDB
|
||||
results = self.documents_collection.query(
|
||||
query_embeddings=query_embedding.tolist(),
|
||||
n_results=top_k,
|
||||
include=["documents", "metadatas", "distances"]
|
||||
)
|
||||
|
||||
# Format results
|
||||
formatted_results = []
|
||||
|
||||
if results and results.get("documents"):
|
||||
documents = results["documents"][0]
|
||||
metadatas = results["metadatas"][0] if results.get("metadatas") else []
|
||||
distances = results["distances"][0] if results.get("distances") else []
|
||||
|
||||
for i, content in enumerate(documents):
|
||||
formatted_results.append({
|
||||
"content": content,
|
||||
"metadata": metadatas[i] if i < len(metadatas) else {},
|
||||
"distance": distances[i] if i < len(distances) else None
|
||||
})
|
||||
|
||||
logger.info(f"Found {len(formatted_results)} results for query")
|
||||
return formatted_results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Search failed: {e}")
|
||||
raise
|
||||
|
||||
def close(self) -> None:
|
||||
"""Release resources and cleanup the RAG engine.
|
||||
|
||||
This method should be called when the engine is no longer needed
|
||||
to free up memory and other resources.
|
||||
"""
|
||||
logger.info("Closing RAG engine and releasing resources")
|
||||
self.embedding_model = None
|
||||
self.documents_collection = None
|
||||
self.chroma_client = None
|
||||
78
backend/app/main.py
Normal file
78
backend/app/main.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""FastAPI application entry point for ERP AI Assistant.
|
||||
|
||||
This module creates and configures the main FastAPI application instance.
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.config import get_settings
|
||||
from app.api import analyze, generate, execute
|
||||
|
||||
# Get application settings
|
||||
settings = get_settings()
|
||||
|
||||
# Create FastAPI application instance
|
||||
app = FastAPI(
|
||||
title=settings.APP_NAME,
|
||||
version="1.0.0",
|
||||
debug=settings.DEBUG,
|
||||
description="AI-powered assistant for ERP platform configuration"
|
||||
)
|
||||
|
||||
# Configure CORS middleware for frontend communication
|
||||
# For development: allow all origins with port 5173
|
||||
# For production: configure specific origins in environment
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"] if settings.DEBUG else [
|
||||
"http://localhost:5173",
|
||||
"http://127.0.0.1:5173",
|
||||
],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Register API routers
|
||||
app.include_router(analyze.router, prefix="/api/v1", tags=["Analysis"])
|
||||
app.include_router(generate.router, prefix="/api/v1", tags=["Generation"])
|
||||
app.include_router(execute.router, prefix="/api/v1", tags=["Execution"])
|
||||
|
||||
|
||||
@app.get("/", tags=["Root"])
|
||||
async def root() -> dict:
|
||||
"""Root endpoint returning application info.
|
||||
|
||||
Returns:
|
||||
Dictionary with application name, version, and status
|
||||
"""
|
||||
return {
|
||||
"message": settings.APP_NAME,
|
||||
"version": "1.0.0",
|
||||
"status": "running"
|
||||
}
|
||||
|
||||
|
||||
@app.get("/health", tags=["Health"])
|
||||
async def health_check() -> dict:
|
||||
"""Health check endpoint for monitoring.
|
||||
|
||||
Returns:
|
||||
Dictionary with health status
|
||||
"""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(
|
||||
"app.main:app",
|
||||
host="0.0.0.0",
|
||||
port=8000,
|
||||
reload=settings.DEBUG
|
||||
)
|
||||
1
backend/app/models/__init__.py
Normal file
1
backend/app/models/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Pydantic models for ERP AI Assistant."""
|
||||
89
backend/app/models/request.py
Normal file
89
backend/app/models/request.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Request models for ERP AI Assistant API.
|
||||
|
||||
This module defines Pydantic models for API request validation.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AnalyzeRequest(BaseModel):
|
||||
"""Request model for requirement analysis.
|
||||
|
||||
Attributes:
|
||||
input_type: Type of input - 'natural_language' or 'structured'
|
||||
content: Requirement content text
|
||||
session_id: Optional session ID for context continuity
|
||||
"""
|
||||
|
||||
input_type: str = Field(
|
||||
...,
|
||||
description="输入类型: natural_language | structured"
|
||||
)
|
||||
content: str = Field(..., description="需求内容")
|
||||
session_id: Optional[str] = Field(None, description="会话ID")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"examples": [
|
||||
{
|
||||
"input_type": "natural_language",
|
||||
"content": "创建一个销售订单管理页面",
|
||||
"session_id": "session-123"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
"""Request model for config generation.
|
||||
|
||||
Attributes:
|
||||
session_id: Session ID from previous analysis
|
||||
requirements: Structured requirements from analysis
|
||||
"""
|
||||
|
||||
session_id: str = Field(..., description="会话ID")
|
||||
requirements: dict = Field(..., description="结构化需求")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"examples": [
|
||||
{
|
||||
"session_id": "session-123",
|
||||
"requirements": {
|
||||
"功能名称": "销售订单管理",
|
||||
"功能类型": "列表页面"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ExecuteRequest(BaseModel):
|
||||
"""Request model for config execution.
|
||||
|
||||
Attributes:
|
||||
session_id: Session ID for tracking
|
||||
confirmed: User confirmation flag
|
||||
backup_enabled: Whether to create backup before execution
|
||||
"""
|
||||
|
||||
session_id: str = Field(..., description="会话ID")
|
||||
confirmed: bool = Field(False, description="用户确认标识")
|
||||
backup_enabled: bool = Field(True, description="是否启用备份")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"examples": [
|
||||
{
|
||||
"session_id": "session-123",
|
||||
"confirmed": True,
|
||||
"backup_enabled": True
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
133
backend/app/models/response.py
Normal file
133
backend/app/models/response.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""Response models for ERP AI Assistant API.
|
||||
|
||||
This module defines Pydantic models for API response formatting.
|
||||
"""
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AnalyzeResponse(BaseModel):
|
||||
"""Response model for requirement analysis.
|
||||
|
||||
Attributes:
|
||||
session_id: Session ID for this analysis
|
||||
status: Processing status
|
||||
data: Structured requirement analysis result
|
||||
"""
|
||||
|
||||
session_id: str = Field(..., description="会话ID")
|
||||
status: str = Field(..., description="处理状态")
|
||||
data: dict = Field(..., description="结构化需求分析结果")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"examples": [
|
||||
{
|
||||
"session_id": "session-123",
|
||||
"status": "success",
|
||||
"data": {
|
||||
"功能名称": "销售订单管理",
|
||||
"功能类型": "列表页面"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class GenerateResponse(BaseModel):
|
||||
"""Response model for config generation.
|
||||
|
||||
Attributes:
|
||||
session_id: Session ID
|
||||
status: Processing status
|
||||
data: Generated SQL configuration
|
||||
"""
|
||||
|
||||
session_id: str = Field(..., description="会话ID")
|
||||
status: str = Field(..., description="处理状态")
|
||||
data: dict = Field(..., description="生成的SQL配置")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"examples": [
|
||||
{
|
||||
"session_id": "session-123",
|
||||
"status": "success",
|
||||
"data": {
|
||||
"sql_list": ["INSERT INTO SYS_FORM ...", "INSERT INTO SYS_MENU ..."]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ExecuteResponse(BaseModel):
|
||||
"""Response model for config execution.
|
||||
|
||||
Attributes:
|
||||
execution_id: Unique execution ID
|
||||
status: Execution status
|
||||
message: Human-readable result message
|
||||
"""
|
||||
|
||||
execution_id: str = Field(..., description="执行ID")
|
||||
status: str = Field(..., description="执行状态")
|
||||
message: str = Field(..., description="执行结果消息")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"examples": [
|
||||
{
|
||||
"execution_id": "exec-456",
|
||||
"status": "success",
|
||||
"message": "成功执行 5 条SQL"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
"""Response model for errors.
|
||||
|
||||
Attributes:
|
||||
error: Error details dictionary
|
||||
"""
|
||||
|
||||
error: dict = Field(..., description="错误详情")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"examples": [
|
||||
{
|
||||
"error": {
|
||||
"code": "VALIDATION_ERROR",
|
||||
"message": "Invalid input",
|
||||
"details": "Field 'input_type' must be 'natural_language' or 'structured'"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
"""Response model for health check.
|
||||
|
||||
Attributes:
|
||||
status: Service health status
|
||||
version: Application version
|
||||
"""
|
||||
|
||||
status: str = Field(..., description="服务状态")
|
||||
version: str = Field(default="1.0.0", description="版本号")
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"examples": [{"status": "healthy", "version": "1.0.0"}]
|
||||
}
|
||||
}
|
||||
1
backend/app/services/__init__.py
Normal file
1
backend/app/services/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Service modules for ERP AI Assistant."""
|
||||
135
backend/app/services/config_service.py
Normal file
135
backend/app/services/config_service.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Config Generation Service.
|
||||
|
||||
This module provides the ConfigService class for generating ERP platform
|
||||
configuration SQL based on structured requirements.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from app.core.ai_engine import ClaudeEngine
|
||||
from app.core.rag_engine import RAGEngine
|
||||
from app.core.prompts import SYSTEM_PROMPT, GENERATE_PROMPT_TEMPLATE
|
||||
from app.core.db_engine import DatabaseEngine
|
||||
|
||||
|
||||
class ConfigService:
|
||||
"""Service for generating ERP platform configuration.
|
||||
|
||||
This service uses Claude AI with RAG knowledge retrieval to generate
|
||||
SQL configuration statements based on structured requirements.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize config service with required engines."""
|
||||
self.ai_engine = ClaudeEngine()
|
||||
self.rag_engine = RAGEngine()
|
||||
self.db_engine = DatabaseEngine()
|
||||
logger.info("ConfigService initialized")
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
requirements: Dict[str, Any],
|
||||
session_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate configuration SQL based on requirements.
|
||||
|
||||
Args:
|
||||
requirements: Structured requirement specification
|
||||
session_id: Session ID for tracking
|
||||
|
||||
Returns:
|
||||
Configuration plan with SQL statements
|
||||
|
||||
Raises:
|
||||
ValueError: If requirements are invalid
|
||||
Exception: If generation fails
|
||||
"""
|
||||
if not requirements:
|
||||
raise ValueError("Requirements cannot be empty")
|
||||
|
||||
function_name = requirements.get("功能名称", "Unknown")
|
||||
logger.info(f"[{session_id}] Starting config generation for: {function_name}")
|
||||
|
||||
try:
|
||||
# Step 1: Retrieve platform rules for form type
|
||||
form_type = requirements.get("窗体类型", "0")
|
||||
logger.debug(f"[{session_id}] Retrieving platform rules for form type: {form_type}")
|
||||
platform_rules = self._get_platform_rules(form_type)
|
||||
logger.info(f"[{session_id}] Retrieved platform rules")
|
||||
|
||||
# Step 2: Retrieve similar cases
|
||||
logger.debug(f"[{session_id}] Retrieving similar cases")
|
||||
similar_cases = self._get_similar_cases(function_name)
|
||||
logger.info(f"[{session_id}] Retrieved similar cases")
|
||||
|
||||
# Step 3: Build prompt
|
||||
prompt = GENERATE_PROMPT_TEMPLATE.format(
|
||||
requirements=str(requirements),
|
||||
platform_rules=platform_rules,
|
||||
similar_cases=similar_cases
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": SYSTEM_PROMPT},
|
||||
{"role": "assistant", "content": "我已了解,请提供需求信息。"},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
# Step 4: Call Claude API
|
||||
logger.debug(f"[{session_id}] Calling Claude API for config generation")
|
||||
response = await self.ai_engine.call_claude(messages, temperature=0.5)
|
||||
|
||||
# Step 5: Parse JSON response
|
||||
result = self.ai_engine.parse_json_response(response)
|
||||
|
||||
logger.success(f"[{session_id}] Config generation completed")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{session_id}] Config generation failed: {e}")
|
||||
raise
|
||||
|
||||
def _get_platform_rules(self, form_type: str) -> str:
|
||||
"""Retrieve platform configuration rules for specific form type.
|
||||
|
||||
Args:
|
||||
form_type: Form type code
|
||||
|
||||
Returns:
|
||||
Platform rules text
|
||||
"""
|
||||
try:
|
||||
results = self.rag_engine.search(
|
||||
f"窗体类型{form_type}配置规则",
|
||||
top_k=2
|
||||
)
|
||||
if not results:
|
||||
return "未找到相关配置规则"
|
||||
|
||||
return "\n\n".join([r["content"] for r in results])
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to retrieve platform rules: {e}")
|
||||
return "无法获取平台配置规则"
|
||||
|
||||
def _get_similar_cases(self, keywords: str) -> str:
|
||||
"""Retrieve similar configuration cases from knowledge base.
|
||||
|
||||
Args:
|
||||
keywords: Search keywords
|
||||
|
||||
Returns:
|
||||
Similar cases text
|
||||
"""
|
||||
try:
|
||||
results = self.rag_engine.search(keywords, top_k=2)
|
||||
if not results:
|
||||
return "未找到相似案例"
|
||||
|
||||
return "\n\n".join([r["content"] for r in results])
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to retrieve similar cases: {e}")
|
||||
return "无法获取相似案例"
|
||||
147
backend/app/services/requirement_service.py
Normal file
147
backend/app/services/requirement_service.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""Requirement Analysis Service.
|
||||
|
||||
This module provides the RequirementService class for analyzing user requirements
|
||||
using Claude AI with RAG knowledge retrieval.
|
||||
"""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
import uuid
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from app.core.ai_engine import ClaudeEngine
|
||||
from app.core.rag_engine import RAGEngine
|
||||
from app.core.prompts import SYSTEM_PROMPT, ANALYZE_PROMPT_TEMPLATE
|
||||
from app.core.db_engine import DatabaseEngine
|
||||
|
||||
|
||||
class RequirementService:
|
||||
"""Service for analyzing user requirements with AI assistance.
|
||||
|
||||
This service integrates Claude AI, RAG knowledge retrieval, and database
|
||||
metadata to provide comprehensive requirement analysis.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize requirement service with required engines."""
|
||||
self.ai_engine = ClaudeEngine()
|
||||
self.rag_engine = RAGEngine()
|
||||
self.db_engine = DatabaseEngine()
|
||||
logger.info("RequirementService initialized")
|
||||
|
||||
async def analyze(
|
||||
self,
|
||||
user_input: str,
|
||||
session_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Analyze user requirement and generate structured specification.
|
||||
|
||||
Args:
|
||||
user_input: Natural language requirement from user
|
||||
session_id: Session ID for context management (auto-generated if None)
|
||||
|
||||
Returns:
|
||||
Structured requirement document as dictionary
|
||||
|
||||
Raises:
|
||||
ValueError: If user_input is empty
|
||||
Exception: If AI analysis fails
|
||||
"""
|
||||
# Validate input
|
||||
if not user_input or not user_input.strip():
|
||||
raise ValueError("User input cannot be empty")
|
||||
|
||||
# Generate session ID if not provided
|
||||
session_id = session_id or str(uuid.uuid4())
|
||||
logger.info(f"[{session_id}] Starting requirement analysis: {user_input[:50]}...")
|
||||
|
||||
try:
|
||||
# Step 1: Retrieve relevant knowledge from RAG
|
||||
logger.debug(f"[{session_id}] Searching knowledge base")
|
||||
knowledge_results = self.rag_engine.search(user_input, top_k=3)
|
||||
knowledge_context = self._format_knowledge_context(knowledge_results)
|
||||
logger.info(f"[{session_id}] Retrieved {len(knowledge_results)} knowledge chunks")
|
||||
|
||||
# Step 2: Query existing database tables
|
||||
logger.debug(f"[{session_id}] Querying existing tables")
|
||||
existing_tables = self._get_existing_tables(user_input)
|
||||
logger.info(f"[{session_id}] Retrieved existing table information")
|
||||
|
||||
# Step 3: Build prompt
|
||||
prompt = ANALYZE_PROMPT_TEMPLATE.format(
|
||||
user_input=user_input,
|
||||
knowledge_context=knowledge_context,
|
||||
existing_tables=existing_tables
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": SYSTEM_PROMPT},
|
||||
{"role": "assistant", "content": "我已了解平台配置规范,请告诉我您的需求。"},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
# Step 4: Call Claude API
|
||||
logger.debug(f"[{session_id}] Calling Claude API")
|
||||
response = await self.ai_engine.call_claude(messages, temperature=0.7)
|
||||
|
||||
# Step 5: Parse JSON response
|
||||
result = self.ai_engine.parse_json_response(response)
|
||||
|
||||
function_name = result.get("功能名称", "Unknown")
|
||||
logger.success(f"[{session_id}] Requirement analysis completed: {function_name}")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{session_id}] Requirement analysis failed: {e}")
|
||||
raise
|
||||
|
||||
def _format_knowledge_context(self, knowledge_results: list) -> str:
|
||||
"""Format knowledge search results into context string.
|
||||
|
||||
Args:
|
||||
knowledge_results: List of knowledge search results
|
||||
|
||||
Returns:
|
||||
Formatted knowledge context string
|
||||
"""
|
||||
if not knowledge_results:
|
||||
return "未找到相关知识库内容"
|
||||
|
||||
context_parts = []
|
||||
for result in knowledge_results:
|
||||
source = result.get("metadata", {}).get("source", "文档")
|
||||
content = result.get("content", "")
|
||||
if content:
|
||||
context_parts.append(f"【{source}】\n{content}")
|
||||
|
||||
return "\n\n".join(context_parts) if context_parts else "未找到相关知识库内容"
|
||||
|
||||
def _get_existing_tables(self, user_input: str) -> str:
|
||||
"""Query existing database tables relevant to user input.
|
||||
|
||||
Args:
|
||||
user_input: User requirement text
|
||||
|
||||
Returns:
|
||||
Formatted string listing existing tables
|
||||
"""
|
||||
try:
|
||||
# Query top 10 tables (simplified version - could be enhanced with relevance matching)
|
||||
sql = """
|
||||
SELECT TOP 10 TABLE_NAME
|
||||
FROM INFORMATION_SCHEMA.TABLES
|
||||
WHERE TABLE_TYPE = 'BASE TABLE'
|
||||
ORDER BY TABLE_NAME
|
||||
"""
|
||||
tables = self.db_engine.execute_sql(sql)
|
||||
|
||||
if not tables:
|
||||
return "未找到现有数据表"
|
||||
|
||||
table_list = [f"- {t[0]}" for t in tables]
|
||||
return "\n".join(table_list)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to query existing tables: {e}")
|
||||
return "无法获取现有表信息"
|
||||
Reference in New Issue
Block a user