feat: implement ERP AI Assistant Phase 1
Backend (FastAPI + SQLAlchemy + Claude API + RAG): - Config management with Pydantic v2 - Database engine with connection pooling and SQL injection prevention - AI engine with Claude API integration (support custom base URL) - RAG engine with ChromaDB and sentence-transformers - Requirement analysis service - Config generation service - Executor engine with SQL validation - REST API endpoints: /analyze, /generate, /execute Frontend (Vue 3 + Element Plus + Pinia): - Complete 3-step workflow: analyze → generate → execute - Step indicator with progress visualization - Analysis result display with field table - SQL preview with monospace font - Execute confirmation dialog with safety warning - Execution result display - State management with Pinia - API service integration Security: - SQL injection prevention with parameterized queries - Dangerous SQL operation blocking - Database password URL encoding - Transaction auto-rollback - Pydantic config validation Features: - Natural language requirement analysis - Automated SQL configuration generation - Safe execution with human review - LAN access support - Custom Claude API endpoint support Documentation: - README with quick start guide - Quick start guide - LAN access configuration - Dependency fixes guide - Claude API configuration - Git operation guide - Implementation report Dependencies fixed: - numpy<2.0.0 for chromadb compatibility - sentence-transformers==2.7.0 for huggingface_hub compatibility Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
1
backend/tests/__init__.py
Normal file
1
backend/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests"""
|
||||
26
backend/tests/conftest.py
Normal file
26
backend/tests/conftest.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import pytest
|
||||
from app.config import get_settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_settings():
|
||||
"""Test settings"""
|
||||
return get_settings()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_engine(mocker):
|
||||
"""Mock database engine"""
|
||||
from app.core.db_engine import DatabaseEngine
|
||||
return mocker.MagicMock(spec=DatabaseEngine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ai_engine(mocker):
|
||||
"""Mock AI engine with default parse_json_response behavior"""
|
||||
from app.core.ai_engine import ClaudeEngine
|
||||
mock_engine = mocker.MagicMock(spec=ClaudeEngine)
|
||||
# Default behavior: returns a test function dict
|
||||
# Can be overridden in individual tests via mock_ai_engine.parse_json_response.return_value = {...}
|
||||
mock_engine.parse_json_response.return_value = {"function_name": "test_function"}
|
||||
return mock_engine
|
||||
148
backend/tests/test_ai_engine.py
Normal file
148
backend/tests/test_ai_engine.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import pytest
|
||||
from app.core.ai_engine import ClaudeEngine
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings(mocker):
|
||||
"""Mock settings for test isolation."""
|
||||
mock_settings = mocker.MagicMock()
|
||||
mock_settings.ANTHROPIC_API_KEY = "test-key"
|
||||
mock_settings.CLAUDE_MODEL = "claude-sonnet-4-6"
|
||||
mock_settings.CLAUDE_MAX_TOKENS = 1024
|
||||
mock_settings.CLAUDE_TEMPERATURE = 0.7
|
||||
mocker.patch('app.core.ai_engine.get_settings', return_value=mock_settings)
|
||||
return mock_settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_anthropic_client(mocker):
|
||||
"""Mock Anthropic async client."""
|
||||
mock_client = mocker.AsyncMock()
|
||||
mocker.patch('app.core.ai_engine.anthropic.AsyncAnthropic', return_value=mock_client)
|
||||
return mock_client
|
||||
|
||||
|
||||
def test_claude_engine_init(mocker, mock_settings):
|
||||
"""测试 Claude 引擎初始化"""
|
||||
engine = ClaudeEngine()
|
||||
assert engine.client is not None
|
||||
assert engine.model == "claude-sonnet-4-6"
|
||||
assert engine.max_tokens == 1024
|
||||
assert engine.temperature == 0.7
|
||||
|
||||
|
||||
def test_parse_json_response(mocker, mock_settings):
|
||||
"""测试 JSON 解析"""
|
||||
engine = ClaudeEngine()
|
||||
|
||||
# 测试纯 JSON
|
||||
json_str = '{"name": "test", "value": 123}'
|
||||
result = engine.parse_json_response(json_str)
|
||||
assert result["name"] == "test"
|
||||
assert result["value"] == 123
|
||||
|
||||
# 测试 markdown 代码块
|
||||
md_str = '```json\n{"name": "test"}\n```'
|
||||
result = engine.parse_json_response(md_str)
|
||||
assert result["name"] == "test"
|
||||
|
||||
|
||||
def test_parse_json_response_empty_content(mocker, mock_settings):
|
||||
"""测试空内容错误处理"""
|
||||
engine = ClaudeEngine()
|
||||
|
||||
with pytest.raises(ValueError, match="Empty content provided"):
|
||||
engine.parse_json_response("")
|
||||
|
||||
with pytest.raises(ValueError, match="Empty content provided"):
|
||||
engine.parse_json_response(" ")
|
||||
|
||||
|
||||
def test_parse_json_response_invalid_json(mocker, mock_settings):
|
||||
"""测试无效 JSON 错误处理"""
|
||||
engine = ClaudeEngine()
|
||||
|
||||
# 无效 JSON 且无法提取任何代码块
|
||||
invalid_str = "This is not JSON at all"
|
||||
with pytest.raises(ValueError, match="无法解析 Claude 返回的 JSON"):
|
||||
engine.parse_json_response(invalid_str)
|
||||
|
||||
# 无效的 JSON 代码块
|
||||
invalid_json_block = '```json\n{invalid json}\n```'
|
||||
with pytest.raises(ValueError, match="无法解析 Claude 返回的 JSON"):
|
||||
engine.parse_json_response(invalid_json_block)
|
||||
|
||||
|
||||
def test_parse_json_response_code_block(mocker, mock_settings):
|
||||
"""测试代码块 JSON 解析"""
|
||||
engine = ClaudeEngine()
|
||||
|
||||
# 普通代码块(无 json 标签)
|
||||
code_block = '```\n{"status": "ok"}\n```'
|
||||
result = engine.parse_json_response(code_block)
|
||||
assert result["status"] == "ok"
|
||||
|
||||
|
||||
def test_parse_json_response_nested_json(mocker, mock_settings):
|
||||
"""测试嵌套 JSON 解析"""
|
||||
engine = ClaudeEngine()
|
||||
|
||||
# 带有一些额外文本的 JSON
|
||||
text_with_json = 'Some text before {"key": "value"} and after'
|
||||
result = engine.parse_json_response(text_with_json)
|
||||
assert result["key"] == "value"
|
||||
|
||||
# 嵌套 JSON
|
||||
nested_json = '{"outer": {"inner": "value"}}'
|
||||
result = engine.parse_json_response(nested_json)
|
||||
assert result["outer"]["inner"] == "value"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_claude(mocker, mock_settings, mock_anthropic_client):
|
||||
"""测试 call_claude 方法"""
|
||||
# 设置 mock 响应
|
||||
mock_response = mocker.MagicMock()
|
||||
mock_response.content = [mocker.MagicMock(text="Hello, I am Claude")]
|
||||
mock_anthropic_client.messages.create.return_value = mock_response
|
||||
|
||||
engine = ClaudeEngine()
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
result = await engine.call_claude(messages)
|
||||
|
||||
assert result == "Hello, I am Claude"
|
||||
mock_anthropic_client.messages.create.assert_called_once_with(
|
||||
model="claude-sonnet-4-6",
|
||||
max_tokens=1024,
|
||||
temperature=0.7,
|
||||
messages=messages
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_claude_with_temperature(mocker, mock_settings, mock_anthropic_client):
|
||||
"""测试 call_claude 带温度参数"""
|
||||
mock_response = mocker.MagicMock()
|
||||
mock_response.content = [mocker.MagicMock(text="Response")]
|
||||
mock_anthropic_client.messages.create.return_value = mock_response
|
||||
|
||||
engine = ClaudeEngine()
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
result = await engine.call_claude(messages, temperature=1.5)
|
||||
|
||||
assert result == "Response"
|
||||
call_args = mock_anthropic_client.messages.create.call_args
|
||||
assert call_args.kwargs["temperature"] == 1.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_claude_error(mocker, mock_settings, mock_anthropic_client):
|
||||
"""测试 call_claude 错误处理"""
|
||||
# 设置 mock 抛出异常
|
||||
mock_anthropic_client.messages.create.side_effect = Exception("API Error")
|
||||
|
||||
engine = ClaudeEngine()
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
|
||||
with pytest.raises(Exception, match="API Error"):
|
||||
await engine.call_claude(messages)
|
||||
94
backend/tests/test_config_service.py
Normal file
94
backend/tests/test_config_service.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""Tests for Config Service.
|
||||
|
||||
This module tests the ConfigService class for configuration generation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
from app.services.config_service import ConfigService
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_config():
|
||||
"""Test config generation with mocked dependencies."""
|
||||
with patch('app.services.config_service.ClaudeEngine') as MockClaudeEngine, \
|
||||
patch('app.services.config_service.RAGEngine') as MockRAGEngine, \
|
||||
patch('app.services.config_service.DatabaseEngine') as MockDBEngine:
|
||||
|
||||
# Setup mocks
|
||||
mock_ai_engine = MagicMock()
|
||||
mock_ai_engine.call_claude = AsyncMock(return_value='{"配置方案": {"sql_list": ["INSERT INTO SYS_FORM..."]}}')
|
||||
mock_ai_engine.parse_json_response = MagicMock(return_value={
|
||||
"配置方案": {
|
||||
"sql_list": ["INSERT INTO SYS_FORM VALUES (...)"]
|
||||
}
|
||||
})
|
||||
MockClaudeEngine.return_value = mock_ai_engine
|
||||
|
||||
mock_rag_engine = MagicMock()
|
||||
mock_rag_engine.search = MagicMock(return_value=[
|
||||
{"content": "Sample rule", "metadata": {}}
|
||||
])
|
||||
MockRAGEngine.return_value = mock_rag_engine
|
||||
|
||||
MockDBEngine.return_value = MagicMock()
|
||||
|
||||
# Create service and test
|
||||
service = ConfigService()
|
||||
|
||||
requirements = {
|
||||
"功能名称": "销售订单",
|
||||
"功能号建议": "11-001",
|
||||
"窗体类型": "5",
|
||||
"主表名建议": "SA_ORDER",
|
||||
"主表字段": [
|
||||
{"字段名": "订单号", "字段类型": "varchar(50)", "必填": True}
|
||||
]
|
||||
}
|
||||
|
||||
result = await service.generate(requirements, "test-session")
|
||||
|
||||
assert result is not None
|
||||
assert "配置方案" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_platform_rules():
|
||||
"""Test platform rules retrieval."""
|
||||
with patch('app.services.config_service.ClaudeEngine'), \
|
||||
patch('app.services.config_service.RAGEngine') as MockRAGEngine, \
|
||||
patch('app.services.config_service.DatabaseEngine'):
|
||||
|
||||
mock_rag_engine = MagicMock()
|
||||
mock_rag_engine.search = MagicMock(return_value=[
|
||||
{"content": "Rule 1"},
|
||||
{"content": "Rule 2"}
|
||||
])
|
||||
MockRAGEngine.return_value = mock_rag_engine
|
||||
|
||||
service = ConfigService()
|
||||
rules = service._get_platform_rules("5")
|
||||
|
||||
assert "Rule 1" in rules
|
||||
assert "Rule 2" in rules
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_similar_cases():
|
||||
"""Test similar cases retrieval."""
|
||||
with patch('app.services.config_service.ClaudeEngine'), \
|
||||
patch('app.services.config_service.RAGEngine') as MockRAGEngine, \
|
||||
patch('app.services.config_service.DatabaseEngine'):
|
||||
|
||||
mock_rag_engine = MagicMock()
|
||||
mock_rag_engine.search = MagicMock(return_value=[
|
||||
{"content": "Case 1"},
|
||||
{"content": "Case 2"}
|
||||
])
|
||||
MockRAGEngine.return_value = mock_rag_engine
|
||||
|
||||
service = ConfigService()
|
||||
cases = service._get_similar_cases("销售订单")
|
||||
|
||||
assert "Case 1" in cases
|
||||
assert "Case 2" in cases
|
||||
25
backend/tests/test_db_engine.py
Normal file
25
backend/tests/test_db_engine.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import pytest
|
||||
from app.core.db_engine import DatabaseEngine
|
||||
|
||||
|
||||
def test_database_engine_init():
|
||||
"""测试数据库引擎初始化"""
|
||||
engine = DatabaseEngine()
|
||||
assert engine.engine is not None
|
||||
assert engine.Session is not None
|
||||
|
||||
|
||||
def test_execute_sql_select():
|
||||
"""测试执行 SELECT 查询"""
|
||||
engine = DatabaseEngine()
|
||||
result = engine.execute_sql("SELECT 1 AS test")
|
||||
assert result is not None
|
||||
assert len(result) > 0
|
||||
|
||||
|
||||
def test_table_exists():
|
||||
"""测试表存在性检查"""
|
||||
engine = DatabaseEngine()
|
||||
# 假设 SYS_FORM 表存在
|
||||
exists = engine.table_exists("SYS_FORM")
|
||||
assert exists is True
|
||||
141
backend/tests/test_executor.py
Normal file
141
backend/tests/test_executor.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""Tests for Config Executor.
|
||||
|
||||
This module tests the ConfigExecutor class for SQL validation and execution.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from app.core.executor import ConfigExecutor
|
||||
|
||||
|
||||
def test_executor_init():
|
||||
"""Test executor initialization."""
|
||||
executor = ConfigExecutor()
|
||||
assert executor.db_engine is not None
|
||||
|
||||
|
||||
def test_validate_sql_safe():
|
||||
"""Test validation of safe SQL statements."""
|
||||
executor = ConfigExecutor()
|
||||
|
||||
# Test SELECT
|
||||
is_valid, msg = executor.validate_sql("SELECT * FROM SYS_FORM")
|
||||
assert is_valid is True
|
||||
assert "验证通过" in msg
|
||||
|
||||
# Test INSERT
|
||||
is_valid, msg = executor.validate_sql(
|
||||
"INSERT INTO SYS_FORM (IKEY, FORM_NAME) VALUES (1, 'Test')"
|
||||
)
|
||||
assert is_valid is True
|
||||
|
||||
# Test UPDATE with safe WHERE clause
|
||||
is_valid, msg = executor.validate_sql(
|
||||
"UPDATE SYS_FORM SET FORM_NAME = 'Test' WHERE IKEY = 1"
|
||||
)
|
||||
assert is_valid is True
|
||||
|
||||
|
||||
def test_validate_sql_dangerous():
|
||||
"""Test validation catches dangerous SQL statements."""
|
||||
executor = ConfigExecutor()
|
||||
|
||||
# Test DROP DATABASE
|
||||
is_valid, msg = executor.validate_sql("DROP DATABASE test_db")
|
||||
assert is_valid is False
|
||||
assert "危险操作" in msg
|
||||
|
||||
# Test DROP TABLE
|
||||
is_valid, msg = executor.validate_sql("DROP TABLE users")
|
||||
assert is_valid is False
|
||||
assert "危险操作" in msg
|
||||
|
||||
# Test TRUNCATE
|
||||
is_valid, msg = executor.validate_sql("TRUNCATE TABLE important_data")
|
||||
assert is_valid is False
|
||||
assert "危险操作" in msg
|
||||
|
||||
# Test DELETE without WHERE
|
||||
is_valid, msg = executor.validate_sql("DELETE FROM users")
|
||||
assert is_valid is False
|
||||
assert "危险操作" in msg
|
||||
|
||||
|
||||
def test_execute_config_success():
|
||||
"""Test successful execution of SQL list."""
|
||||
with patch('app.core.executor.DatabaseEngine') as MockDBEngine:
|
||||
mock_db_engine = MagicMock()
|
||||
mock_db_engine.execute_transaction = MagicMock(return_value=True)
|
||||
MockDBEngine.return_value = mock_db_engine
|
||||
|
||||
executor = ConfigExecutor()
|
||||
|
||||
sql_list = [
|
||||
"INSERT INTO SYS_FORM (IKEY, FORM_NAME) VALUES (1, 'Test1')",
|
||||
"INSERT INTO SYS_FORM (IKEY, FORM_NAME) VALUES (2, 'Test2')"
|
||||
]
|
||||
|
||||
result = executor.execute_config(sql_list, session_id="test-session")
|
||||
|
||||
assert result["success"] is True
|
||||
assert len(result["executed"]) == 2
|
||||
assert result["failed"] is None
|
||||
assert "成功执行" in result["message"]
|
||||
|
||||
|
||||
def test_execute_config_validation_failure():
|
||||
"""Test execution fails when SQL validation fails."""
|
||||
executor = ConfigExecutor()
|
||||
|
||||
sql_list = [
|
||||
"SELECT * FROM SYS_FORM",
|
||||
"DROP DATABASE test" # Dangerous SQL
|
||||
]
|
||||
|
||||
result = executor.execute_config(sql_list, session_id="test-session")
|
||||
|
||||
assert result["success"] is False
|
||||
assert result["failed"] is not None
|
||||
assert "验证失败" in result["message"]
|
||||
assert len(result["executed"]) == 0
|
||||
|
||||
|
||||
def test_execute_config_execution_failure():
|
||||
"""Test execution handles database errors."""
|
||||
with patch('app.core.executor.DatabaseEngine') as MockDBEngine:
|
||||
mock_db_engine = MagicMock()
|
||||
mock_db_engine.execute_transaction = MagicMock(
|
||||
side_effect=Exception("Database connection error")
|
||||
)
|
||||
MockDBEngine.return_value = mock_db_engine
|
||||
|
||||
executor = ConfigExecutor()
|
||||
|
||||
sql_list = ["SELECT * FROM SYS_FORM"]
|
||||
|
||||
result = executor.execute_config(sql_list, session_id="test-session")
|
||||
|
||||
assert result["success"] is False
|
||||
assert result["failed"] is not None
|
||||
assert "执行失败" in result["message"]
|
||||
|
||||
|
||||
def test_rollback_placeholder():
|
||||
"""Test rollback functionality placeholder."""
|
||||
executor = ConfigExecutor()
|
||||
|
||||
result = executor.rollback(session_id="test-session")
|
||||
|
||||
assert result["success"] is False
|
||||
assert "待实现" in result["message"]
|
||||
|
||||
|
||||
def test_dangerous_keywords_exist():
|
||||
"""Test that dangerous keywords list is properly defined."""
|
||||
executor = ConfigExecutor()
|
||||
|
||||
assert hasattr(executor, 'DANGEROUS_KEYWORDS')
|
||||
assert len(executor.DANGEROUS_KEYWORDS) > 0
|
||||
assert "DROP DATABASE" in executor.DANGEROUS_KEYWORDS
|
||||
assert "DROP TABLE" in executor.DANGEROUS_KEYWORDS
|
||||
assert "TRUNCATE TABLE" in executor.DANGEROUS_KEYWORDS
|
||||
36
backend/tests/test_prompts.py
Normal file
36
backend/tests/test_prompts.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""Tests for prompt templates."""
|
||||
|
||||
from app.core.prompts import SYSTEM_PROMPT, ANALYZE_PROMPT_TEMPLATE, GENERATE_PROMPT_TEMPLATE
|
||||
|
||||
|
||||
def test_system_prompt_exists():
|
||||
"""测试系统 Prompt 存在"""
|
||||
assert SYSTEM_PROMPT is not None
|
||||
assert len(SYSTEM_PROMPT) > 100
|
||||
# Test for stable characteristics rather than exact wording
|
||||
assert "ERP" in SYSTEM_PROMPT
|
||||
assert "配置" in SYSTEM_PROMPT
|
||||
|
||||
|
||||
def test_analyze_prompt_template():
|
||||
"""测试需求解析模板"""
|
||||
rendered = ANALYZE_PROMPT_TEMPLATE.format(
|
||||
user_input="创建销售订单",
|
||||
knowledge_context="测试知识",
|
||||
existing_tables="测试表"
|
||||
)
|
||||
assert "创建销售订单" in rendered
|
||||
assert "测试知识" in rendered
|
||||
assert "测试表" in rendered
|
||||
|
||||
|
||||
def test_generate_prompt_template():
|
||||
"""测试配置生成模板"""
|
||||
rendered = GENERATE_PROMPT_TEMPLATE.format(
|
||||
requirements="需求分析结果",
|
||||
platform_rules="平台规则",
|
||||
similar_cases="类似案例"
|
||||
)
|
||||
assert "需求分析结果" in rendered
|
||||
assert "平台规则" in rendered
|
||||
assert "类似案例" in rendered
|
||||
156
backend/tests/test_rag_engine.py
Normal file
156
backend/tests/test_rag_engine.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""Tests for RAG Engine.
|
||||
|
||||
This module tests the RAGEngine class for document indexing and retrieval.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from app.core.rag_engine import RAGEngine
|
||||
|
||||
|
||||
def test_rag_engine_init():
|
||||
"""Test RAG engine initialization."""
|
||||
engine = RAGEngine()
|
||||
assert engine.chroma_client is not None
|
||||
assert engine.documents_collection is not None
|
||||
assert engine.chunk_size > 0
|
||||
assert engine.chunk_overlap >= 0
|
||||
assert engine.chunk_overlap < engine.chunk_size
|
||||
|
||||
|
||||
def test_split_text_basic():
|
||||
"""Test basic text splitting functionality."""
|
||||
engine = RAGEngine()
|
||||
|
||||
# Test with text longer than chunk_size
|
||||
long_text = "A" * 1000
|
||||
chunks = engine._split_text(long_text)
|
||||
|
||||
assert len(chunks) > 0
|
||||
assert all(len(chunk) <= engine.chunk_size for chunk in chunks)
|
||||
assert all(chunk.strip() for chunk in chunks) # No empty chunks
|
||||
|
||||
|
||||
def test_split_text_empty():
|
||||
"""Test splitting empty text."""
|
||||
engine = RAGEngine()
|
||||
|
||||
# Test with empty text
|
||||
assert engine._split_text("") == []
|
||||
assert engine._split_text(" ") == []
|
||||
|
||||
|
||||
def test_split_text_overlap():
|
||||
"""Test text splitting with overlap."""
|
||||
engine = RAGEngine()
|
||||
|
||||
# Test that chunks overlap correctly
|
||||
text = "A" * 600
|
||||
chunks = engine._split_text(text)
|
||||
|
||||
if len(chunks) > 1:
|
||||
# Check overlap exists between consecutive chunks
|
||||
# (This is a basic check; actual overlap content depends on implementation)
|
||||
assert len(chunks) > 1
|
||||
|
||||
|
||||
def test_add_document_success():
|
||||
"""Test adding a document to the knowledge base."""
|
||||
engine = RAGEngine()
|
||||
|
||||
# Mock the collection's add method
|
||||
engine.documents_collection.add = MagicMock()
|
||||
|
||||
doc_id = "test_doc_1"
|
||||
content = "This is a test document for the knowledge base."
|
||||
metadata = {"source": "test", "type": "sample"}
|
||||
|
||||
num_chunks = engine.add_document(doc_id, content, metadata)
|
||||
|
||||
assert num_chunks > 0
|
||||
assert engine.documents_collection.add.called
|
||||
|
||||
# Verify add was called with correct parameters
|
||||
call_args = engine.documents_collection.add.call_args
|
||||
assert "ids" in call_args.kwargs
|
||||
assert "embeddings" in call_args.kwargs
|
||||
assert "documents" in call_args.kwargs
|
||||
assert "metadatas" in call_args.kwargs
|
||||
|
||||
|
||||
def test_add_document_empty_content():
|
||||
"""Test that adding empty document raises ValueError."""
|
||||
engine = RAGEngine()
|
||||
|
||||
with pytest.raises(ValueError, match="Cannot add empty document"):
|
||||
engine.add_document("test_doc", "")
|
||||
|
||||
with pytest.raises(ValueError, match="Cannot add empty document"):
|
||||
engine.add_document("test_doc", " ")
|
||||
|
||||
|
||||
def test_search_basic():
|
||||
"""Test basic search functionality."""
|
||||
engine = RAGEngine()
|
||||
|
||||
# Mock the collection's query method
|
||||
mock_results = {
|
||||
"documents": [["Result 1", "Result 2"]],
|
||||
"metadatas": [[{"doc_id": "doc1"}, {"doc_id": "doc2"}]],
|
||||
"distances": [[0.1, 0.2]]
|
||||
}
|
||||
engine.documents_collection.query = MagicMock(return_value=mock_results)
|
||||
|
||||
results = engine.search("test query", top_k=2)
|
||||
|
||||
assert len(results) == 2
|
||||
assert results[0]["content"] == "Result 1"
|
||||
assert results[0]["metadata"]["doc_id"] == "doc1"
|
||||
assert results[0]["distance"] == 0.1
|
||||
assert engine.documents_collection.query.called
|
||||
|
||||
|
||||
def test_search_empty_query():
|
||||
"""Test search with empty query returns empty results."""
|
||||
engine = RAGEngine()
|
||||
|
||||
results = engine.search("", top_k=3)
|
||||
assert results == []
|
||||
|
||||
results = engine.search(" ", top_k=3)
|
||||
assert results == []
|
||||
|
||||
|
||||
def test_search_invalid_top_k():
|
||||
"""Test that search with invalid top_k raises ValueError."""
|
||||
engine = RAGEngine()
|
||||
|
||||
with pytest.raises(ValueError, match="top_k cannot exceed 100"):
|
||||
engine.search("test", top_k=101)
|
||||
|
||||
|
||||
def test_delete_chunks_for_doc():
|
||||
"""Test deleting chunks for a document."""
|
||||
engine = RAGEngine()
|
||||
|
||||
# Mock the get and delete methods
|
||||
engine.documents_collection.get = MagicMock(return_value={
|
||||
"ids": ["doc1_chunk_0", "doc1_chunk_1"]
|
||||
})
|
||||
engine.documents_collection.delete = MagicMock()
|
||||
|
||||
engine._delete_chunks_for_doc("doc1")
|
||||
|
||||
assert engine.documents_collection.get.called
|
||||
assert engine.documents_collection.delete.called
|
||||
|
||||
|
||||
def test_close():
|
||||
"""Test closing the RAG engine releases resources."""
|
||||
engine = RAGEngine()
|
||||
|
||||
engine.close()
|
||||
|
||||
assert engine.embedding_model is None
|
||||
assert engine.documents_collection is None
|
||||
assert engine.chroma_client is None
|
||||
116
backend/tests/test_requirement_service.py
Normal file
116
backend/tests/test_requirement_service.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""Tests for Requirement Service.
|
||||
|
||||
This module tests the RequirementService class for requirement analysis.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
from app.services.requirement_service import RequirementService
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_requirement():
|
||||
"""Test requirement analysis with mocked dependencies."""
|
||||
# Create service with mocked engines
|
||||
with patch('app.services.requirement_service.ClaudeEngine') as MockClaudeEngine, \
|
||||
patch('app.services.requirement_service.RAGEngine') as MockRAGEngine, \
|
||||
patch('app.services.requirement_service.DatabaseEngine') as MockDBEngine:
|
||||
|
||||
# Setup mocks
|
||||
mock_ai_engine = MagicMock()
|
||||
mock_ai_engine.call_claude = AsyncMock(return_value='{"功能名称": "销售订单管理", "功能类型": "列表页面"}')
|
||||
mock_ai_engine.parse_json_response = MagicMock(return_value={
|
||||
"功能名称": "销售订单管理",
|
||||
"功能类型": "列表页面"
|
||||
})
|
||||
MockClaudeEngine.return_value = mock_ai_engine
|
||||
|
||||
mock_rag_engine = MagicMock()
|
||||
mock_rag_engine.search = MagicMock(return_value=[
|
||||
{"content": "Sample knowledge", "metadata": {"source": "docs"}}
|
||||
])
|
||||
MockRAGEngine.return_value = mock_rag_engine
|
||||
|
||||
mock_db_engine = MagicMock()
|
||||
mock_db_engine.execute_sql = MagicMock(return_value=[("SYS_FORM",), ("SYS_MENU",)])
|
||||
MockDBEngine.return_value = mock_db_engine
|
||||
|
||||
# Create service and test
|
||||
service = RequirementService()
|
||||
result = await service.analyze(
|
||||
user_input="创建一个销售订单管理页面",
|
||||
session_id="test-session"
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert "功能名称" in result
|
||||
assert result["功能名称"] == "销售订单管理"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_requirement_without_session_id():
|
||||
"""Test that session_id is auto-generated if not provided."""
|
||||
with patch('app.services.requirement_service.ClaudeEngine') as MockClaudeEngine, \
|
||||
patch('app.services.requirement_service.RAGEngine') as MockRAGEngine, \
|
||||
patch('app.services.requirement_service.DatabaseEngine') as MockDBEngine:
|
||||
|
||||
# Setup mocks
|
||||
mock_ai_engine = MagicMock()
|
||||
mock_ai_engine.call_claude = AsyncMock(return_value='{"功能名称": "测试功能"}')
|
||||
mock_ai_engine.parse_json_response = MagicMock(return_value={"功能名称": "测试功能"})
|
||||
MockClaudeEngine.return_value = mock_ai_engine
|
||||
|
||||
mock_rag_engine = MagicMock()
|
||||
mock_rag_engine.search = MagicMock(return_value=[])
|
||||
MockRAGEngine.return_value = mock_rag_engine
|
||||
|
||||
mock_db_engine = MagicMock()
|
||||
mock_db_engine.execute_sql = MagicMock(return_value=[])
|
||||
MockDBEngine.return_value = mock_db_engine
|
||||
|
||||
# Test without session_id
|
||||
service = RequirementService()
|
||||
result = await service.analyze(user_input="测试输入")
|
||||
|
||||
assert result is not None
|
||||
assert "功能名称" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_existing_tables_success():
|
||||
"""Test successful retrieval of existing tables."""
|
||||
with patch('app.services.requirement_service.ClaudeEngine'), \
|
||||
patch('app.services.requirement_service.RAGEngine'), \
|
||||
patch('app.services.requirement_service.DatabaseEngine') as MockDBEngine:
|
||||
|
||||
mock_db_engine = MagicMock()
|
||||
mock_db_engine.execute_sql = MagicMock(return_value=[
|
||||
("SYS_FORM",),
|
||||
("SYS_MENU",),
|
||||
("SYS_USER",)
|
||||
])
|
||||
MockDBEngine.return_value = mock_db_engine
|
||||
|
||||
service = RequirementService()
|
||||
tables = service._get_existing_tables("测试")
|
||||
|
||||
assert "SYS_FORM" in tables
|
||||
assert "SYS_MENU" in tables
|
||||
assert "SYS_USER" in tables
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_existing_tables_failure():
|
||||
"""Test handling of database query failure."""
|
||||
with patch('app.services.requirement_service.ClaudeEngine'), \
|
||||
patch('app.services.requirement_service.RAGEngine'), \
|
||||
patch('app.services.requirement_service.DatabaseEngine') as MockDBEngine:
|
||||
|
||||
mock_db_engine = MagicMock()
|
||||
mock_db_engine.execute_sql = MagicMock(side_effect=Exception("DB Error"))
|
||||
MockDBEngine.return_value = mock_db_engine
|
||||
|
||||
service = RequirementService()
|
||||
tables = service._get_existing_tables("测试")
|
||||
|
||||
assert "无法获取现有表信息" in tables
|
||||
Reference in New Issue
Block a user