|
|
""" |
|
|
Test script for VisualizationAgent initialization (Task 3) |
|
|
Tests Requirements: 2.1, 3.1, 3.2 |
|
|
""" |
|
|
import os |
|
|
import sys |
|
|
import tempfile |
|
|
import shutil |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
|
|
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..')) |
|
|
|
|
|
from agent import VisualizationAgent |
|
|
|
|
|
|
|
|
def test_initialization_with_api_key():
    """Check that an explicitly passed API key is stored and defaults apply.

    Builds an agent from a literal key, then verifies the stored key, that a
    client object exists, the default model name and the default output
    directory.

    Returns:
        bool: True when every check holds; False otherwise, after printing
        the reason.
    """
    print("\n=== Test 1: Initialization with explicit API key ===")
    try:
        viz = VisualizationAgent(api_key="test_api_key_123")
        assert viz.api_key == "test_api_key_123", "API key not set correctly"
        assert viz.client is not None, "Client not initialized"
        assert viz.model == "gemini-2.5-flash-image", "Default model not set correctly"
        assert viz.output_dir == "./generated_images", "Default output_dir not set correctly"
        print("β Initialization with explicit API key successful")
    except Exception as err:
        print(f"β Test failed: {err}")
        return False
    return True
|
|
|
|
|
|
|
|
def test_initialization_from_env_gemini():
    """Test initialization from the GEMINI_API_KEY environment variable.

    ``os.environ`` is patched only for the duration of the check.
    ``GEMINI_API_KEY`` is set to a dummy value and ``GOOGLE_API_KEY`` is
    removed, so a real key already exported in the developer's shell cannot
    shadow the value under test. The whole environment, including any
    pre-existing ``GEMINI_API_KEY``, is restored afterwards, whether the
    test passes or fails.

    Returns:
        bool: True if the agent took its key from GEMINI_API_KEY; False
        otherwise (the failure reason is printed).
    """
    from unittest import mock

    print("\n=== Test 2: Initialization from GEMINI_API_KEY ===")
    try:
        # patch.dict snapshots os.environ on entry and restores it on exit.
        with mock.patch.dict(os.environ, {"GEMINI_API_KEY": "env_gemini_key_456"}):
            # Hide the alternate variable so only GEMINI_API_KEY can supply the key.
            os.environ.pop("GOOGLE_API_KEY", None)

            agent = VisualizationAgent()
            assert agent.api_key == "env_gemini_key_456", "API key not read from GEMINI_API_KEY"
            assert agent.client is not None, "Client not initialized"

        print("β Initialization from GEMINI_API_KEY successful")
        return True
    except Exception as e:
        print(f"β Test failed: {e}")
        return False
|
|
|
|
|
|
|
|
def test_initialization_from_env_google():
    """Test initialization from the GOOGLE_API_KEY environment variable.

    ``os.environ`` is patched only for the duration of the check.
    ``GOOGLE_API_KEY`` is set to a dummy value and ``GEMINI_API_KEY`` is
    removed, so a real key already exported in the developer's shell cannot
    shadow the value under test. The whole environment, including any
    pre-existing ``GOOGLE_API_KEY``, is restored afterwards, whether the
    test passes or fails.

    Returns:
        bool: True if the agent took its key from GOOGLE_API_KEY; False
        otherwise (the failure reason is printed).
    """
    from unittest import mock

    print("\n=== Test 3: Initialization from GOOGLE_API_KEY ===")
    try:
        # patch.dict snapshots os.environ on entry and restores it on exit.
        with mock.patch.dict(os.environ, {"GOOGLE_API_KEY": "env_google_key_789"}):
            # Hide the alternate variable so only GOOGLE_API_KEY can supply the key.
            os.environ.pop("GEMINI_API_KEY", None)

            agent = VisualizationAgent()
            assert agent.api_key == "env_google_key_789", "API key not read from GOOGLE_API_KEY"
            assert agent.client is not None, "Client not initialized"

        print("β Initialization from GOOGLE_API_KEY successful")
        return True
    except Exception as e:
        print(f"β Test failed: {e}")
        return False
|
|
|
|
|
|
|
|
def test_initialization_no_api_key():
    """Test that initialization fails with ValueError when no API key exists.

    Both ``GEMINI_API_KEY`` and ``GOOGLE_API_KEY`` are removed from
    ``os.environ`` inside a ``patch.dict`` context. The context restores the
    original environment on exit, including variables whose value is the
    empty string; the old ``if key:`` check silently dropped those.

    Returns:
        bool: True if a ValueError mentioning "API key is required" is
        raised; False otherwise (the reason is printed).
    """
    from unittest import mock

    print("\n=== Test 4: Initialization without API key (should fail) ===")
    try:
        with mock.patch.dict(os.environ):
            os.environ.pop("GEMINI_API_KEY", None)
            os.environ.pop("GOOGLE_API_KEY", None)

            try:
                VisualizationAgent()
            except ValueError as e:
                assert "API key is required" in str(e), f"Wrong error message: {e}"
                print("β Correctly raises ValueError when API key is missing")
                return True

            # Construction succeeded, which is the failure case for this test.
            print("β Should have raised ValueError for missing API key")
            return False
    except Exception as e:
        # Covers a wrong error message and any non-ValueError exception.
        print(f"β Test failed: {e}")
        return False
|
|
|
|
|
|
|
|
def test_custom_model():
    """Check that a caller-supplied model name overrides the default.

    Returns:
        bool: True if the agent keeps the requested model name; False
        otherwise, after printing the reason.
    """
    print("\n=== Test 5: Initialization with custom model ===")
    requested_model = "gemini-3-pro-image-preview"
    try:
        viz = VisualizationAgent(api_key="test_key", model=requested_model)
        assert viz.model == requested_model, "Custom model not set correctly"
        print("β Custom model set successfully")
    except Exception as err:
        print(f"β Test failed: {err}")
        return False
    return True
|
|
|
|
|
|
|
|
def test_default_model():
    """Check that the model defaults to "gemini-2.5-flash-image" when omitted.

    Returns:
        bool: True if the default model name matches; False otherwise, after
        printing the reason.
    """
    print("\n=== Test 6: Default model is gemini-2.5-flash-image ===")
    try:
        viz = VisualizationAgent(api_key="test_key")
        assert viz.model == "gemini-2.5-flash-image", f"Default model should be gemini-2.5-flash-image, got {viz.model}"
        print("β Default model is gemini-2.5-flash-image")
    except Exception as err:
        print(f"β Test failed: {err}")
        return False
    return True
|
|
|
|
|
|
|
|
def test_custom_output_dir():
    """Test initialization with a custom output directory.

    The agent is pointed at a not-yet-existing sub-directory of a fresh
    temporary directory. The test checks that the path is stored verbatim
    and that the directory exists after construction. The temporary tree is
    removed by ``TemporaryDirectory`` even when a check fails; previously
    ``rmtree`` ran only on success and leaked the directory on failure.

    Returns:
        bool: True on success; False otherwise (the reason is printed).
    """
    print("\n=== Test 7: Initialization with custom output directory ===")
    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            custom_dir = os.path.join(temp_dir, "custom_output")

            agent = VisualizationAgent(api_key="test_key", output_dir=custom_dir)

            assert agent.output_dir == custom_dir, "Custom output_dir not set correctly"
            assert os.path.exists(custom_dir), "Output directory was not created"
            assert os.path.isdir(custom_dir), "Output directory is not a directory"
            print("β Custom output directory set and created successfully")
        return True
    except Exception as e:
        print(f"β Test failed: {e}")
        return False
|
|
|
|
|
|
|
|
def test_default_output_dir():
    """Check that output_dir defaults to "./generated_images" when omitted.

    NOTE(review): Test 9 shows the agent creates its output directory, so
    this presumably leaves ./generated_images in the current working
    directory; confirm against the agent implementation.

    Returns:
        bool: True if the default output directory matches; False otherwise,
        after printing the reason.
    """
    print("\n=== Test 8: Default output directory ===")
    try:
        viz = VisualizationAgent(api_key="test_key")
        assert viz.output_dir == "./generated_images", f"Default output_dir should be ./generated_images, got {viz.output_dir}"
        print("β Default output directory is ./generated_images")
    except Exception as err:
        print(f"β Test failed: {err}")
        return False
    return True
|
|
|
|
|
|
|
|
def test_output_dir_creation():
    """Test that a missing, nested output directory is created.

    A three-level path under a fresh temporary directory is verified to be
    absent first, then checked to exist as a directory after the agent is
    constructed. ``TemporaryDirectory`` removes the tree even when a check
    fails; previously ``rmtree`` ran only on success.

    Returns:
        bool: True on success; False otherwise (the reason is printed).
    """
    print("\n=== Test 9: Output directory creation ===")
    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            test_dir = os.path.join(temp_dir, "test_output", "nested", "dir")

            # Precondition: the agent, not the test, must create this path.
            assert not os.path.exists(test_dir), "Test directory should not exist yet"

            VisualizationAgent(api_key="test_key", output_dir=test_dir)

            assert os.path.exists(test_dir), "Output directory was not created"
            assert os.path.isdir(test_dir), "Output directory is not a directory"
            print("β Output directory created successfully (including nested directories)")
        return True
    except Exception as e:
        print(f"β Test failed: {e}")
        return False
|
|
|
|
|
|
|
|
def test_client_initialization():
    """Check that the agent exposes a client object with a ``models`` attribute.

    Returns:
        bool: True if ``client`` is set and has ``models``; False otherwise,
        after printing the reason.
    """
    print("\n=== Test 10: genai.Client initialization ===")
    try:
        client = VisualizationAgent(api_key="test_key_abc").client
        assert client is not None, "Client should not be None"
        assert hasattr(client, 'models'), "Client should have 'models' attribute"
        print("β genai.Client initialized correctly")
    except Exception as err:
        print(f"β Test failed: {err}")
        return False
    return True
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run every initialization test in order, print a
    # summary table, and exit with status 0 if all passed, 1 otherwise.
    print("=" * 70)
    print("VisualizationAgent Initialization Tests (Task 3)")
    print("Testing Requirements: 2.1, 3.1, 3.2")
    print("=" * 70)

    # Each test prints its own progress and returns True (pass) / False (fail).
    # The list literal evaluates its entries top to bottom, preserving run order.
    test_results = [
        ("Explicit API key", test_initialization_with_api_key()),
        ("GEMINI_API_KEY env var", test_initialization_from_env_gemini()),
        ("GOOGLE_API_KEY env var", test_initialization_from_env_google()),
        ("Missing API key (should fail)", test_initialization_no_api_key()),
        ("Custom model", test_custom_model()),
        ("Default model", test_default_model()),
        ("Custom output directory", test_custom_output_dir()),
        ("Default output directory", test_default_output_dir()),
        ("Output directory creation", test_output_dir_creation()),
        ("genai.Client initialization", test_client_initialization()),
    ]

    print("\n" + "=" * 70)
    print("TEST SUMMARY")
    print("=" * 70)

    passed = sum(1 for _, result in test_results if result)
    total = len(test_results)

    for test_name, result in test_results:
        # These literals were previously mojibake (the UTF-8 bytes of the emoji
        # decoded in a legacy codepage) that split the string across lines.
        status = "✅ PASS" if result else "❌ FAIL"
        print(f" {status}: {test_name}")

    print(f"\n{'=' * 70}")
    print(f"Total: {passed}/{total} tests passed")
    print("=" * 70)

    if passed == total:
        print("\n✅ All initialization tests passed!")
        sys.exit(0)
    else:
        print(f"\n❌ {total - passed} test(s) failed")
        sys.exit(1)
|
|
|