|
|
"""Tests for FastAPI endpoints in app.py.""" |
|
|
|
|
|
from unittest.mock import patch |
|
|
|
|
|
import pytest |
|
|
from fastapi.testclient import TestClient |
|
|
|
|
|
from app import app |
|
|
from kg_services.ontology import MCPPrompt, MCPTool, PlannedStep |
|
|
|
|
|
|
|
|
@pytest.fixture
def client():
    """Return a ``TestClient`` that drives the FastAPI ``app`` in-process."""
    test_client = TestClient(app)
    return test_client
|
|
|
|
|
|
|
|
@pytest.fixture
def sample_tool():
    """Build a minimal ``MCPTool`` used as canned agent output in the tests."""
    tool_fields = {
        "tool_id": "test_tool_v1",
        "name": "Test Tool",
        "description": "A tool for testing",
        "tags": ["test", "utility"],
        "invocation_command_stub": "test_command --input {input}",
    }
    return MCPTool(**tool_fields)
|
|
|
|
|
|
|
|
@pytest.fixture
def sample_prompt():
    """Build a minimal ``MCPPrompt`` paired with the sample tool's id."""
    prompt_fields = {
        "prompt_id": "test_prompt_v1",
        "name": "Test Prompt",
        "description": "A prompt for testing",
        # Same id as the ``sample_tool`` fixture so the two can be combined.
        "target_tool_id": "test_tool_v1",
        "template_string": "Process this: {{input_text}}",
        "tags": ["test", "example"],
        "input_variables": ["input_text"],
        "difficulty_level": "beginner",
    }
    return MCPPrompt(**prompt_fields)
|
|
|
|
|
|
|
|
@pytest.fixture
def sample_planned_step(sample_tool, sample_prompt):
    """Combine the sample tool and prompt into a ``PlannedStep`` (score 0.85)."""
    step = PlannedStep(
        tool=sample_tool,
        prompt=sample_prompt,
        relevance_score=0.85,
    )
    return step
|
|
|
|
|
|
|
|
class TestHealthEndpoint:
    """Checks for the GET /health route."""

    def test_health_check_success(self, client):
        """/health answers 200 and reports a healthy status."""
        resp = client.get("/health")
        assert resp.status_code == 200

        body = resp.json()
        assert body["status"] == "healthy"
        for key in ("version", "environment"):
            assert key in body

    def test_health_check_response_format(self, client):
        """Every required /health field is present and is a string."""
        body = client.get("/health").json()

        for key in ("status", "version", "environment"):
            assert key in body
            assert isinstance(body[key], str)
|
|
|
|
|
|
|
|
class TestTaskEndpoints:
    """Exercise the /api/tasks collection and item routes."""

    def test_get_tasks_success(self, client):
        """GET /api/tasks returns exactly the one seeded task."""
        resp = client.get("/api/tasks")
        assert resp.status_code == 200

        tasks = resp.json()
        assert isinstance(tasks, list)
        assert len(tasks) == 1

        (only_task,) = tasks
        assert only_task["id"] == 1
        assert only_task["title"] == "Complete Sprint 0 setup"
        assert only_task["status"] == "In Progress"
        assert isinstance(only_task["dependencies"], list)

    def test_create_task_success(self, client):
        """POST /api/tasks echoes the payload back with id 42 and status Todo."""
        payload = {
            "title": "Test Task",
            "description": "A test task",
            "dependencies": [1, 2],
        }

        resp = client.post("/api/tasks", json=payload)
        assert resp.status_code == 200

        created = resp.json()
        assert created["id"] == 42
        assert created["status"] == "Todo"
        for field in ("title", "description", "dependencies"):
            assert created[field] == payload[field]

    def test_create_task_validation_error(self, client):
        """An empty JSON body is rejected with 422."""
        resp = client.post("/api/tasks", json={})
        assert resp.status_code == 422

    def test_get_task_by_id_success(self, client):
        """GET /api/tasks/1 returns the seeded task."""
        resp = client.get("/api/tasks/1")
        assert resp.status_code == 200

        task = resp.json()
        assert task["id"] == 1
        assert task["title"] == "Complete Sprint 0 setup"

    def test_get_task_by_id_not_found(self, client):
        """An unknown task id yields 404 with a 'not found' detail."""
        resp = client.get("/api/tasks/999")
        assert resp.status_code == 404

        detail = resp.json()["detail"]
        assert "not found" in detail.lower()
|
|
|
|
|
|
|
|
class TestToolSuggestionEndpoint:
    """Behaviour of POST /api/tools/suggest."""

    _URL = "/api/tools/suggest"

    @patch("app.planner_agent")
    def test_suggest_tools_success(self, fake_agent, client, sample_tool):
        """Suggestions from a mocked agent are serialized field-for-field."""
        fake_agent.suggest_tools.return_value = [sample_tool]

        resp = client.post(self._URL, json={"query": "test query", "top_k": 3})
        assert resp.status_code == 200

        suggestions = resp.json()
        assert isinstance(suggestions, list)
        assert len(suggestions) == 1

        returned = suggestions[0]
        for attr in (
            "tool_id",
            "name",
            "description",
            "tags",
            "invocation_command_stub",
        ):
            assert returned[attr] == getattr(sample_tool, attr)

        fake_agent.suggest_tools.assert_called_once_with("test query", top_k=3)

    @patch("app.planner_agent", None)
    def test_suggest_tools_agent_not_initialized(self, client):
        """With no agent configured the route answers 503."""
        resp = client.post(self._URL, json={"query": "test query", "top_k": 3})

        assert resp.status_code == 503
        detail = resp.json()["detail"]
        assert "not initialized" in detail.lower()

    @patch("app.planner_agent")
    def test_suggest_tools_empty_query(self, fake_agent, client):
        """An empty query string is rejected with 400."""
        resp = client.post(self._URL, json={"query": "", "top_k": 3})

        assert resp.status_code == 400
        detail = resp.json()["detail"]
        assert "empty" in detail.lower()

    @patch("app.planner_agent")
    def test_suggest_tools_exception_handling(self, fake_agent, client):
        """An exception raised by the agent surfaces as a 500."""
        fake_agent.suggest_tools.side_effect = Exception("Test error")

        resp = client.post(self._URL, json={"query": "test query", "top_k": 3})

        assert resp.status_code == 500
        detail = resp.json()["detail"]
        assert "error" in detail.lower()

    def test_suggest_tools_validation_errors(self, client):
        """Out-of-range top_k and a missing query both fail validation."""
        bad_bodies = (
            {"query": "test", "top_k": 0},
            {"top_k": 3},
        )
        for bad_body in bad_bodies:
            resp = client.post(self._URL, json=bad_body)
            assert resp.status_code == 422
|
|
|
|
|
|
|
|
class TestPlanGenerationEndpoint:
    """Test plan generation endpoint (POST /api/plan/generate)."""

    @patch("app.planner_agent")
    def test_generate_plan_success(self, mock_agent, client, sample_planned_step):
        """Test successful plan generation and the serialized step shape."""
        mock_agent.generate_plan.return_value = [sample_planned_step]

        request_data = {"query": "test query", "top_k": 3}
        response = client.post("/api/plan/generate", json=request_data)

        assert response.status_code == 200
        data = response.json()

        assert data["query"] == "test query"
        assert data["total_steps"] == 1
        assert isinstance(data["planned_steps"], list)
        assert len(data["planned_steps"]) == 1

        step = data["planned_steps"][0]
        assert "tool" in step
        assert "prompt" in step
        assert "relevance_score" in step
        assert "summary" in step

        tool = step["tool"]
        assert tool["tool_id"] == sample_planned_step.tool.tool_id
        assert tool["name"] == sample_planned_step.tool.name

        prompt = step["prompt"]
        assert prompt["prompt_id"] == sample_planned_step.prompt.prompt_id
        assert prompt["name"] == sample_planned_step.prompt.name
        assert prompt["template_string"] == sample_planned_step.prompt.template_string

        mock_agent.generate_plan.assert_called_once_with("test query", top_k=3)

    @patch("app.planner_agent", None)
    def test_generate_plan_agent_not_initialized(self, client):
        """Test plan generation when agent is not initialized (503)."""
        request_data = {"query": "test query", "top_k": 3}
        response = client.post("/api/plan/generate", json=request_data)

        assert response.status_code == 503
        assert "not initialized" in response.json()["detail"].lower()

    @patch("app.planner_agent")
    def test_generate_plan_empty_query(self, mock_agent, client):
        """Test plan generation with empty query (400)."""
        request_data = {"query": "", "top_k": 3}
        response = client.post("/api/plan/generate", json=request_data)

        assert response.status_code == 400
        assert "empty" in response.json()["detail"].lower()

    @patch("app.planner_agent")
    def test_generate_plan_no_results(self, mock_agent, client):
        """Test plan generation when the agent returns no steps."""
        mock_agent.generate_plan.return_value = []

        request_data = {"query": "test query", "top_k": 3}
        response = client.post("/api/plan/generate", json=request_data)

        assert response.status_code == 200
        data = response.json()

        assert data["query"] == "test query"
        assert data["total_steps"] == 0
        assert data["planned_steps"] == []

    @patch("app.planner_agent")
    def test_generate_plan_exception_handling(self, mock_agent, client):
        """Test that an agent exception is reported as a 500."""
        mock_agent.generate_plan.side_effect = Exception("Test error")

        request_data = {"query": "test query", "top_k": 3}
        response = client.post("/api/plan/generate", json=request_data)

        assert response.status_code == 500
        assert "error" in response.json()["detail"].lower()

    @pytest.mark.parametrize(
        ("invalid_data", "expected_status"),
        [
            # Rejected with 422 by request-body validation.
            ({"query": "test", "top_k": 11}, 422),
            ({"query": "test", "top_k": 0}, 422),
            ({"query": "test", "top_k": -1}, 422),
            ({"top_k": 3}, 422),
            ({}, 422),
            # Passes schema validation but is rejected by the handler's
            # empty-query check, matching test_generate_plan_empty_query.
            # (Previously this expected 503, which only held when the global
            # app.planner_agent happened to be None; the 503 path is covered
            # by test_generate_plan_agent_not_initialized.)
            ({"query": "", "top_k": 3}, 400),
        ],
    )
    @patch("app.planner_agent")
    def test_generate_plan_validation_errors(
        self, mock_agent, client, invalid_data, expected_status
    ):
        """Test plan generation with invalid data.

        The agent is patched so the outcome does not depend on whether the
        module-level ``app.planner_agent`` is initialized in this environment.
        """
        response = client.post("/api/plan/generate", json=invalid_data)
        assert response.status_code == expected_status
|
|
|
|
|
|
|
|
class TestEndpointIntegration:
    """Test endpoint integration scenarios."""

    def test_cors_headers(self, client):
        """Smoke-test that /health answers a plain GET.

        NOTE(review): despite its name, this test does not yet assert any CORS
        response headers, and the app's CORS configuration is not visible from
        this file. TODO: send an ``Origin`` header and assert the expected
        ``access-control-allow-origin`` value once the allowed origins are
        confirmed.
        """
        response = client.get("/health")
        assert response.status_code == 200

    def test_api_documentation_accessible(self, client):
        """Test the OpenAPI schema and the interactive docs page are served."""
        response = client.get("/openapi.json")
        assert response.status_code == 200

        response = client.get("/docs")
        assert response.status_code == 200

    @patch("app.planner_agent")
    def test_multiple_endpoints_sequence(
        self, mock_agent, client, sample_tool, sample_planned_step
    ):
        """Test using multiple endpoints in sequence against one client."""
        mock_agent.suggest_tools.return_value = [sample_tool]
        mock_agent.generate_plan.return_value = [sample_planned_step]

        response = client.get("/health")
        assert response.status_code == 200

        response = client.get("/api/tasks")
        assert response.status_code == 200

        response = client.post(
            "/api/tools/suggest", json={"query": "test query", "top_k": 3}
        )
        assert response.status_code == 200

        response = client.post(
            "/api/plan/generate", json={"query": "test query", "top_k": 3}
        )
        assert response.status_code == 200

        # Pin the forwarded arguments, matching the single-endpoint tests.
        mock_agent.suggest_tools.assert_called_once_with("test query", top_k=3)
        mock_agent.generate_plan.assert_called_once_with("test query", top_k=3)
|
|
|