|
|
""" |
|
|
Simple tests for the VisualizationAgent.generate_image method (the genai client is mocked in every test).
|
|
""" |
|
|
import os |
|
|
import sys |
|
|
from unittest.mock import Mock, patch |
|
|
from dotenv import load_dotenv |
|
|
|
|
|
|
|
|
# Load variables from a local .env file (if one exists) into the process
# environment before the agent module is imported.
load_dotenv()

# Make both this file's directory and its parent directory importable so the
# `from agent import VisualizationAgent` import resolves without installing a
# package. The parent directory is inserted last, so it ends up at index 0 of
# sys.path and is searched first.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
|
|
|
|
|
from agent import VisualizationAgent |
|
|
from PIL import Image |
|
|
import io |
|
|
|
|
|
|
|
|
def test_generate_image_with_mock():
    """Test generate_image with a mocked client that returns one image part.

    ``agent.genai.Client`` is patched so no real API call is made. The mocked
    response contains a single part whose ``as_image()`` yields a small red
    PIL image. The test checks that ``generate_image`` reports success and
    that the returned ``image_path`` exists on disk. The file is always
    removed afterwards, even when an assertion fails.

    Returns:
        True when every check passes (consumed by the ``__main__`` runner).

    Raises:
        Exception: any failure (including ``AssertionError``) is re-raised
            after its traceback is printed, so pytest records the failure
            instead of treating a returned ``False`` as a pass.
    """
    print("\n=== Testing generate_image Method ===")

    # Tracked outside the try so `finally` can clean up the written file.
    image_path = None
    try:
        # Small in-memory image standing in for the API's generated output.
        test_image = Image.new('RGB', (100, 100), color='red')

        # A response part that has inline data and converts to the test image.
        mock_part = Mock()
        mock_part.inline_data = Mock()
        mock_part.as_image = Mock(return_value=test_image)

        mock_response = Mock()
        mock_response.parts = [mock_part]

        with patch('agent.genai.Client') as MockClient:
            mock_client_instance = Mock()
            mock_client_instance.models.generate_content.return_value = mock_response
            MockClient.return_value = mock_client_instance

            agent = VisualizationAgent(api_key="test_key")
            result = agent.generate_image("Test prompt for disaster-resistant building")

            # Capture the path before asserting so a failing assertion below
            # does not leave the generated file behind.
            image_path = result.get("image_path")

            assert result["success"] is True, f"Expected success=True, got {result}"
            assert "image_path" in result, "Missing image_path in result"
            assert image_path is not None, "image_path is None"
            assert os.path.exists(image_path), f"Image file not created: {image_path}"

            print("β generate_image method works correctly")
            print(f" Image saved to: {image_path}")

        return True

    except Exception as e:
        print(f"β Test failed: {e}")
        import traceback
        traceback.print_exc()
        # Re-raise so test runners see the failure; the __main__ runner also
        # catches exceptions and counts them as failed tests.
        raise

    finally:
        if image_path and os.path.exists(image_path):
            os.remove(image_path)
            print(" Cleaned up test image")
|
|
|
|
|
|
|
|
def test_generate_image_empty_prompt():
    """Test that generate_image rejects an empty prompt.

    ``agent.genai.Client`` is patched so constructing the agent makes no real
    client. Calling ``generate_image("")`` is expected to return
    ``success=False`` with an error whose ``code`` is ``"INVALID_INPUT"``.

    Returns:
        True when every check passes (consumed by the ``__main__`` runner).

    Raises:
        Exception: any failure (including ``AssertionError``) is re-raised
            after its traceback is printed, so pytest records the failure
            instead of treating a returned ``False`` as a pass.
    """
    print("\n=== Testing generate_image with Empty Prompt ===")

    try:
        with patch('agent.genai.Client') as MockClient:
            mock_client_instance = Mock()
            MockClient.return_value = mock_client_instance

            agent = VisualizationAgent(api_key="test_key")
            result = agent.generate_image("")

            assert result["success"] is False, "Should fail with empty prompt"
            assert "error" in result, "Missing error field"
            assert result["error"]["code"] == "INVALID_INPUT", f"Wrong error code: {result['error']['code']}"

            print("β Empty prompt validation works")
            print(f" Error: {result['error']['message']}")

        return True

    except Exception as e:
        print(f"β Test failed: {e}")
        import traceback
        traceback.print_exc()
        # Re-raise so test runners see the failure; the __main__ runner also
        # catches exceptions and counts them as failed tests.
        raise
|
|
|
|
|
|
|
|
def test_generate_image_no_image_data():
    """Test generate_image when the API response contains no image data.

    The mocked response has a single part whose ``inline_data`` is ``None``.
    ``generate_image`` is expected to return ``success=False`` with an error
    whose ``code`` is ``"NO_IMAGE_DATA"``.

    Returns:
        True when every check passes (consumed by the ``__main__`` runner).

    Raises:
        Exception: any failure (including ``AssertionError``) is re-raised
            after its traceback is printed, so pytest records the failure
            instead of treating a returned ``False`` as a pass.
    """
    print("\n=== Testing generate_image with No Image Data ===")

    try:
        # A part without inline data, i.e. nothing image-like in the response.
        mock_part = Mock()
        mock_part.inline_data = None

        mock_response = Mock()
        mock_response.parts = [mock_part]

        with patch('agent.genai.Client') as MockClient:
            mock_client_instance = Mock()
            mock_client_instance.models.generate_content.return_value = mock_response
            MockClient.return_value = mock_client_instance

            agent = VisualizationAgent(api_key="test_key")
            result = agent.generate_image("Test prompt")

            assert result["success"] is False, "Should fail when no image data"
            assert "error" in result, "Missing error field"
            assert result["error"]["code"] == "NO_IMAGE_DATA", f"Wrong error code: {result['error']['code']}"

            print("β No image data handling works")
            print(f" Error: {result['error']['message']}")

        return True

    except Exception as e:
        print(f"β Test failed: {e}")
        import traceback
        traceback.print_exc()
        # Re-raise so test runners see the failure; the __main__ runner also
        # catches exceptions and counts them as failed tests.
        raise
|
|
|
|
|
|
|
|
def test_generate_image_with_config():
    """Test that generate_image forwards a custom model from ``config``.

    The mocked client returns one image part (a small blue PIL image). After
    calling ``generate_image`` with ``config={"model": ..., "aspect_ratio":
    "16:9"}``, the test checks that the call succeeded and that
    ``models.generate_content`` was invoked with ``model=`` set to the
    configured model. Any image file written is always removed afterwards.

    Returns:
        True when every check passes (consumed by the ``__main__`` runner).

    Raises:
        Exception: any failure (including ``AssertionError``) is re-raised
            after its traceback is printed, so pytest records the failure
            instead of treating a returned ``False`` as a pass.
    """
    print("\n=== Testing generate_image with Custom Config ===")

    # Tracked outside the try so `finally` can clean up the written file.
    image_path = None
    try:
        test_image = Image.new('RGB', (100, 100), color='blue')

        mock_part = Mock()
        mock_part.inline_data = Mock()
        mock_part.as_image = Mock(return_value=test_image)

        mock_response = Mock()
        mock_response.parts = [mock_part]

        with patch('agent.genai.Client') as MockClient:
            mock_client_instance = Mock()
            mock_client_instance.models.generate_content.return_value = mock_response
            MockClient.return_value = mock_client_instance

            agent = VisualizationAgent(api_key="test_key")

            config = {
                "model": "gemini-3-pro-image-preview",
                "aspect_ratio": "16:9",
            }

            result = agent.generate_image("Test prompt", config=config)
            # Capture before asserting so the file is removed even on failure.
            image_path = result.get("image_path")

            assert result["success"] is True, f"Expected success=True, got {result}"

            # call_args is None if the client was never called; index 1 holds
            # the keyword arguments of the most recent call.
            call_args = mock_client_instance.models.generate_content.call_args
            assert call_args is not None, "generate_content was not called"
            assert call_args[1].get("model") == "gemini-3-pro-image-preview", "Model not passed correctly"

            print("β Custom config handling works")
            print(f" Model used: {config['model']}")

        return True

    except Exception as e:
        print(f"β Test failed: {e}")
        import traceback
        traceback.print_exc()
        # Re-raise so test runners see the failure; the __main__ runner also
        # catches exceptions and counts them as failed tests.
        raise

    finally:
        if image_path and os.path.exists(image_path):
            os.remove(image_path)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Standalone runner: executes each test, tallies outcomes and exits
    # non-zero if anything failed. A test counts as passed when it returns a
    # truthy value; a falsy return or a raised exception counts as a failure.
    banner = "=" * 70
    print(banner)
    print("Testing generate_image Implementation")
    print(banner)

    suite = (
        test_generate_image_with_mock,
        test_generate_image_empty_prompt,
        test_generate_image_no_image_data,
        test_generate_image_with_config,
    )

    outcomes = []
    for case in suite:
        try:
            outcomes.append(bool(case()))
        except Exception as exc:
            print(f"β Test crashed: {exc}")
            outcomes.append(False)

    n_passed = sum(outcomes)
    n_failed = len(outcomes) - n_passed

    print("\n" + banner)
    print(f"Results: {n_passed} passed, {n_failed} failed")
    print(banner)

    sys.exit(1 if n_failed else 0)
|
|
|