|
|
"""Tests for InMemoryKG prompt functionality.""" |
|
|
|
|
|
from unittest.mock import Mock |
|
|
|
|
|
import pytest |
|
|
|
|
|
from kg_services.knowledge_graph import InMemoryKG |
|
|
from kg_services.ontology import MCPPrompt, MCPTool |
|
|
|
|
|
|
|
|
@pytest.fixture
def sample_tool():
    """Provide a small MCPTool instance shared by the tests in this module."""
    tool_fields = {
        "tool_id": "test_tool_001",
        "name": "Test Tool",
        "description": "A tool for testing purposes",
        "tags": ["test", "utility"],
        "invocation_command_stub": "test --input {data}",
    }
    return MCPTool(**tool_fields)
|
|
|
|
|
|
|
|
@pytest.fixture
def sample_prompt():
    """Provide a fully populated MCPPrompt that targets ``test_tool_001``."""
    prompt_fields = {
        "prompt_id": "test_prompt_001",
        "name": "Test Prompt",
        "description": "A prompt for testing purposes",
        "target_tool_id": "test_tool_001",
        "template_string": "Please process this: {{input_data}}",
        "tags": ["test", "basic"],
        "input_variables": ["input_data"],
        "use_case": "Testing functionality",
        "difficulty_level": "beginner",
        "example_inputs": {"input_data": "sample data"},
    }
    return MCPPrompt(**prompt_fields)
|
|
|
|
|
|
|
|
@pytest.fixture
def kg_with_sample_data(sample_tool, sample_prompt):
    """Provide an InMemoryKG holding the sample tool and the sample prompt."""
    graph = InMemoryKG()
    # Register the tool first, then the prompt that targets it.
    graph.add_tool(sample_tool)
    graph.add_prompt(sample_prompt)
    return graph
|
|
|
|
|
|
|
|
def test_add_prompt(sample_prompt):
    """Adding a prompt stores it in ``prompts`` under its own prompt ID."""
    graph = InMemoryKG()
    graph.add_prompt(sample_prompt)

    stored = graph.prompts
    assert len(stored) == 1
    assert sample_prompt.prompt_id in stored
    assert stored[sample_prompt.prompt_id] == sample_prompt
|
|
|
|
|
|
|
|
def test_get_prompt_by_id(kg_with_sample_data, sample_prompt):
    """Lookup by ID returns the stored prompt, and None for an unknown ID."""
    graph = kg_with_sample_data

    # Known ID: the stored prompt comes back.
    found = graph.get_prompt_by_id(sample_prompt.prompt_id)
    assert found == sample_prompt

    # Unknown ID: the lookup yields None.
    missing = graph.get_prompt_by_id("non_existent")
    assert missing is None
|
|
|
|
|
|
|
|
def test_get_all_prompts(kg_with_sample_data, sample_prompt):
    """The full prompt listing contains exactly the one sample prompt."""
    listing = kg_with_sample_data.get_all_prompts()

    assert len(listing) == 1
    assert sample_prompt in listing
|
|
|
|
|
|
|
|
def test_find_prompts_by_tags():
    """Tag search returns every prompt carrying at least one requested tag."""

    def build(prompt_id, name, description, tags):
        # All prompts here share the same target tool and template.
        return MCPPrompt(
            prompt_id=prompt_id,
            name=name,
            description=description,
            target_tool_id="t1",
            template_string="{{input}}",
            tags=tags,
        )

    graph = InMemoryKG()

    text_prompt = build("p1", "P1", "Desc1", ["nlp", "text"])
    image_prompt = build("p2", "P2", "Desc2", ["vision", "image"])
    analysis_prompt = build("p3", "P3", "Desc3", ["nlp", "analysis"])

    for prompt in (text_prompt, image_prompt, analysis_prompt):
        graph.add_prompt(prompt)

    # Single tag: both prompts tagged "nlp" are found.
    nlp_matches = graph.find_prompts_by_tags(["nlp"])
    assert len(nlp_matches) == 2
    assert text_prompt in nlp_matches
    assert analysis_prompt in nlp_matches

    # Several tags: a prompt matches when it carries any one of them.
    either_matches = graph.find_prompts_by_tags(["vision", "analysis"])
    assert len(either_matches) == 2
    assert image_prompt in either_matches
    assert analysis_prompt in either_matches
|
|
|
|
|
|
|
|
def test_find_prompts_by_tool_id():
    """Tool-ID search returns only the prompts that target the given tool."""

    def build(prompt_id, name, description, target_tool_id):
        # Only the identifying fields vary; the template is shared.
        return MCPPrompt(
            prompt_id=prompt_id,
            name=name,
            description=description,
            target_tool_id=target_tool_id,
            template_string="{{input}}",
        )

    graph = InMemoryKG()

    first_a = build("p1", "P1", "Desc1", "tool_a")
    only_b = build("p2", "P2", "Desc2", "tool_b")
    second_a = build("p3", "P3", "Desc3", "tool_a")

    for prompt in (first_a, only_b, second_a):
        graph.add_prompt(prompt)

    # Two prompts target tool_a.
    for_tool_a = graph.find_prompts_by_tool_id("tool_a")
    assert len(for_tool_a) == 2
    assert first_a in for_tool_a
    assert second_a in for_tool_a

    # One prompt targets tool_b.
    for_tool_b = graph.find_prompts_by_tool_id("tool_b")
    assert len(for_tool_b) == 1
    assert only_b in for_tool_b

    # An unknown tool ID yields an empty result.
    for_unknown = graph.find_prompts_by_tool_id("non_existent")
    assert len(for_unknown) == 0
|
|
|
|
|
|
|
|
def test_find_prompts_by_difficulty():
    """Difficulty search returns only prompts with the requested level."""

    def build(prompt_id, name, description, difficulty_level):
        # All prompts here share the same target tool and template.
        return MCPPrompt(
            prompt_id=prompt_id,
            name=name,
            description=description,
            target_tool_id="t1",
            template_string="{{input}}",
            difficulty_level=difficulty_level,
        )

    graph = InMemoryKG()

    easy_one = build("p1", "P1", "Desc1", "beginner")
    hard_one = build("p2", "P2", "Desc2", "advanced")
    easy_two = build("p3", "P3", "Desc3", "beginner")

    for prompt in (easy_one, hard_one, easy_two):
        graph.add_prompt(prompt)

    # Two prompts are marked "beginner".
    beginner_matches = graph.find_prompts_by_difficulty("beginner")
    assert len(beginner_matches) == 2
    assert easy_one in beginner_matches
    assert easy_two in beginner_matches

    # One prompt is marked "advanced".
    advanced_matches = graph.find_prompts_by_difficulty("advanced")
    assert len(advanced_matches) == 1
    assert hard_one in advanced_matches
|
|
|
|
|
|
|
|
def test_get_all_prompt_tags():
    """The tag collection is the de-duplicated union of all prompt tags."""
    graph = InMemoryKG()

    text_prompt = MCPPrompt(
        prompt_id="p1",
        name="P1",
        description="Desc1",
        target_tool_id="t1",
        template_string="{{input}}",
        tags=["nlp", "text", "basic"],
    )
    image_prompt = MCPPrompt(
        prompt_id="p2",
        name="P2",
        description="Desc2",
        target_tool_id="t1",
        template_string="{{input}}",
        tags=["vision", "image", "basic"],
    )

    for prompt in (text_prompt, image_prompt):
        graph.add_prompt(prompt)

    # "basic" appears on both prompts but must be reported only once.
    expected = {"nlp", "text", "basic", "vision", "image"}
    collected = graph.get_all_prompt_tags()
    assert collected == expected
|
|
|
|
|
|
|
|
def test_find_similar_prompts():
    """A top-2 similarity query returns two IDs with ``p1`` ranked first."""
    graph = InMemoryKG()

    # Seed the vector index attributes directly instead of building them.
    graph.prompt_embeddings = [
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.5, 0.5, 0.0],
    ]
    graph.prompt_ids_for_vectors = ["p1", "p2", "p3"]

    # The query points almost along the first axis, i.e. closest to "p1".
    query = [0.9, 0.1, 0.0]
    ranked_ids = graph.find_similar_prompts(query, top_k=2)

    assert len(ranked_ids) == 2
    assert ranked_ids[0] == "p1"
|
|
|
|
|
|
|
|
def test_find_similar_prompts_for_tool():
    """Tool-scoped similarity search only returns prompts of that tool."""

    def build(prompt_id, name, description, target_tool_id):
        # Only the identifying fields vary; the template is shared.
        return MCPPrompt(
            prompt_id=prompt_id,
            name=name,
            description=description,
            target_tool_id=target_tool_id,
            template_string="{{input}}",
        )

    graph = InMemoryKG()

    first_a = build("p1", "P1", "Desc1", "tool_a")
    only_b = build("p2", "P2", "Desc2", "tool_b")
    second_a = build("p3", "P3", "Desc3", "tool_a")

    for prompt in (first_a, only_b, second_a):
        graph.add_prompt(prompt)

    # Seed the vector index attributes directly, one vector per prompt ID.
    graph.prompt_embeddings = [[1.0, 0.0], [0.0, 1.0], [0.8, 0.2]]
    graph.prompt_ids_for_vectors = ["p1", "p2", "p3"]

    query = [0.9, 0.1]
    scoped_ids = graph.find_similar_prompts_for_tool(query, "tool_a", top_k=2)

    # Both tool_a prompts are returned; the tool_b prompt is excluded.
    assert len(scoped_ids) == 2
    assert "p1" in scoped_ids
    assert "p3" in scoped_ids
    assert "p2" not in scoped_ids
|
|
|
|
|
|
|
|
def test_load_prompts_from_json(tmp_path):
    """Loading a valid JSON file registers the prompts it describes.

    Writes a one-entry prompt list to a temporary file, loads it into a
    fresh InMemoryKG, and checks the success flag plus the loaded fields.

    Args:
        tmp_path: pytest's per-test temporary directory fixture.
    """
    # Imported at the top of the function, before first use.
    import json

    prompt_data = [
        {
            "prompt_id": "test_prompt_001",
            "name": "Test Prompt",
            "description": "A test prompt",
            "target_tool_id": "test_tool",
            "template_string": "Process: {{input}}",
            "tags": ["test"],
            "input_variables": ["input"],
            "use_case": "Testing",
            "difficulty_level": "beginner",
            "example_inputs": {"input": "test data"},
        }
    ]

    # Explicit encoding keeps the fixture file independent of the platform's
    # default locale encoding.
    json_file = tmp_path / "test_prompts.json"
    json_file.write_text(json.dumps(prompt_data), encoding="utf-8")

    kg = InMemoryKG()
    success = kg.load_prompts_from_json(json_file)

    assert success is True
    assert len(kg.prompts) == 1
    assert "test_prompt_001" in kg.prompts

    loaded_prompt = kg.get_prompt_by_id("test_prompt_001")
    # Fail with a clear assertion rather than an AttributeError on None.
    assert loaded_prompt is not None
    assert loaded_prompt.name == "Test Prompt"
    assert loaded_prompt.target_tool_id == "test_tool"
|
|
|
|
|
|
|
|
def test_load_prompts_from_json_invalid_file():
    """Loading from a missing file reports failure and adds no prompts."""
    graph = InMemoryKG()

    loaded_ok = graph.load_prompts_from_json("non_existent.json")

    assert loaded_ok is False
    assert len(graph.prompts) == 0
|
|
|
|
|
|
|
|
def test_build_vector_index_with_prompts():
    """Building the vector index embeds the registered tool and prompt."""
    graph = InMemoryKG()

    graph.add_tool(
        MCPTool(
            tool_id="t1",
            name="Tool 1",
            description="Test tool",
            tags=["test"],
            invocation_command_stub="test",
        )
    )
    graph.add_prompt(
        MCPPrompt(
            prompt_id="p1",
            name="Prompt 1",
            description="Test prompt",
            target_tool_id="t1",
            template_string="{{input}}",
        )
    )

    # Stand-in embedder that returns the same fixed vector for every call.
    embedder = Mock()
    embedder.get_embedding.return_value = [0.1, 0.2, 0.3]

    built = graph.build_vector_index(embedder)

    assert built is True
    assert len(graph.tool_embeddings) == 1
    assert len(graph.prompt_embeddings) == 1
    assert len(graph.tool_ids_for_vectors) == 1
    assert len(graph.prompt_ids_for_vectors) == 1

    # One tool and one prompt are registered; two embedding calls are expected.
    assert embedder.get_embedding.call_count == 2
|
|
|