Compare commits
2 Commits
07392c13a4
...
09939111a8
| Author | SHA1 | Date | |
|---|---|---|---|
| 09939111a8 | |||
| 4928a894fe |
@@ -59,10 +59,11 @@ def embed(text, input_type):
|
|||||||
assert resp.embeddings.float_ is not None
|
assert resp.embeddings.float_ is not None
|
||||||
return resp.embeddings.float_[0]
|
return resp.embeddings.float_[0]
|
||||||
|
|
||||||
def search(query, roles: list[str], limit: int = 3, max_content_length: int = 1500) -> list[dict]:
|
def search(query, roles: list[str], limit: int = 5) -> list[dict]:
|
||||||
query_embedding = embed(query, 'search_query')
|
query_embedding = embed(query, 'search_query')
|
||||||
|
|
||||||
if not roles:
|
if not roles:
|
||||||
|
# No roles = no results
|
||||||
return []
|
return []
|
||||||
|
|
||||||
rows = conn.execute(
|
rows = conn.execute(
|
||||||
@@ -70,13 +71,7 @@ def search(query, roles: list[str], limit: int = 3, max_content_length: int = 15
|
|||||||
(roles, query_embedding, limit)
|
(roles, query_embedding, limit)
|
||||||
).fetchall()
|
).fetchall()
|
||||||
|
|
||||||
docs = []
|
return [{'key': row['key'], 'content': row['content']} for row in rows]
|
||||||
for row in rows:
|
|
||||||
content = row['content'] or ''
|
|
||||||
if len(content) > max_content_length:
|
|
||||||
content = content[:max_content_length] + '...[truncated, use get_document_page for full text]'
|
|
||||||
docs.append({'key': row['key'], 'content': content})
|
|
||||||
return docs
|
|
||||||
|
|
||||||
@mcp.tool
|
@mcp.tool
|
||||||
def get_cave_location(cave: str, state: str, county: str) -> list[dict]:
|
def get_cave_location(cave: str, state: str, county: str) -> list[dict]:
|
||||||
|
|||||||
@@ -19,19 +19,37 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
CAVE_MCP_URL = os.getenv("CAVE_MCP_URL", "https://mcp.caving.dev/mcp")
|
CAVE_MCP_URL = os.getenv("CAVE_MCP_URL", "https://mcp.caving.dev/mcp")
|
||||||
|
|
||||||
logger.info("Initializing Cavepedia agent...")
|
logger.info(f"Initializing Cavepedia agent with CAVE_MCP_URL={CAVE_MCP_URL}")
|
||||||
|
|
||||||
|
|
||||||
def limit_history(ctx: RunContext[None], messages: list[ModelMessage]) -> list[ModelMessage]:
|
def limit_history(ctx: RunContext[None], messages: list[ModelMessage]) -> list[ModelMessage]:
|
||||||
"""Limit conversation history to manage token usage."""
|
"""Limit history and clean up orphaned tool calls to prevent API errors."""
|
||||||
# Keep last 8 messages for context, but not unlimited
|
from pydantic_ai.messages import ModelResponse, ToolCallPart
|
||||||
return messages[-8:]
|
|
||||||
|
if not messages:
|
||||||
|
return messages
|
||||||
|
|
||||||
|
# Keep only the last 4 messages
|
||||||
|
messages = messages[-4:]
|
||||||
|
|
||||||
|
# Check if the last message is an assistant response with a tool call
|
||||||
|
# If so, remove it - it's orphaned (no tool result followed)
|
||||||
|
if messages:
|
||||||
|
last_msg = messages[-1]
|
||||||
|
if isinstance(last_msg, ModelResponse):
|
||||||
|
has_tool_call = any(isinstance(part, ToolCallPart) for part in last_msg.parts)
|
||||||
|
if has_tool_call:
|
||||||
|
logger.warning("Removing orphaned tool call from history")
|
||||||
|
return messages[:-1]
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
def check_mcp_available(url: str, timeout: float = 5.0) -> bool:
|
def check_mcp_available(url: str, timeout: float = 5.0) -> bool:
|
||||||
"""Check if MCP server is reachable via health endpoint."""
|
"""Check if MCP server is reachable via health endpoint."""
|
||||||
try:
|
try:
|
||||||
# Use the health endpoint instead of the MCP endpoint
|
# Use the health endpoint instead of the MCP endpoint
|
||||||
health_url = url.rsplit("/", 1)[0] + "/health"
|
health_url = url.rsplit("/", 1)[0] + "/health"
|
||||||
|
logger.info(f"Checking MCP health at: {health_url}")
|
||||||
response = httpx.get(health_url, timeout=timeout, follow_redirects=True)
|
response = httpx.get(health_url, timeout=timeout, follow_redirects=True)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
return True
|
return True
|
||||||
@@ -41,9 +59,7 @@ def check_mcp_available(url: str, timeout: float = 5.0) -> bool:
|
|||||||
logger.warning(f"MCP server not reachable: {e}")
|
logger.warning(f"MCP server not reachable: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check if MCP is available at startup
|
# MCP availability is checked lazily in create_agent()
|
||||||
MCP_AVAILABLE = check_mcp_available(CAVE_MCP_URL)
|
|
||||||
logger.info(f"MCP server available: {MCP_AVAILABLE}")
|
|
||||||
|
|
||||||
AGENT_INSTRUCTIONS = """Caving assistant. Help with exploration, safety, surveying, locations, geology, equipment, history, conservation.
|
AGENT_INSTRUCTIONS = """Caving assistant. Help with exploration, safety, surveying, locations, geology, equipment, history, conservation.
|
||||||
|
|
||||||
@@ -54,14 +70,17 @@ Rules:
|
|||||||
4. Can create ascii diagrams/maps.
|
4. Can create ascii diagrams/maps.
|
||||||
5. Be direct—no sycophantic phrases.
|
5. Be direct—no sycophantic phrases.
|
||||||
6. Keep responses concise.
|
6. Keep responses concise.
|
||||||
7. Search ONCE, then answer with what you found. Do not search repeatedly for the same topic."""
|
7. Use tools sparingly—one search usually suffices. Answer from your knowledge when possible."""
|
||||||
|
|
||||||
|
|
||||||
def create_agent(user_roles: list[str] | None = None):
|
def create_agent(user_roles: list[str] | None = None):
|
||||||
"""Create an agent with MCP tools configured for the given user roles."""
|
"""Create an agent with MCP tools configured for the given user roles."""
|
||||||
toolsets = []
|
toolsets = []
|
||||||
|
|
||||||
if MCP_AVAILABLE and user_roles:
|
# Check MCP availability lazily (each request) to handle startup race conditions
|
||||||
|
mcp_available = check_mcp_available(CAVE_MCP_URL) if user_roles else False
|
||||||
|
|
||||||
|
if mcp_available and user_roles:
|
||||||
try:
|
try:
|
||||||
import json
|
import json
|
||||||
from pydantic_ai.mcp import MCPServerStreamableHTTP
|
from pydantic_ai.mcp import MCPServerStreamableHTTP
|
||||||
@@ -92,7 +111,4 @@ def create_agent(user_roles: list[str] | None = None):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Create a default agent for health checks etc
|
|
||||||
agent = create_agent()
|
|
||||||
|
|
||||||
logger.info("Agent module initialized successfully")
|
logger.info("Agent module initialized successfully")
|
||||||
|
|||||||
Reference in New Issue
Block a user