simplify mcp

This commit is contained in:
2025-12-26 02:18:00 +01:00
parent 383e452322
commit 337540496a
2 changed files with 35 additions and 66 deletions

View File

@@ -59,14 +59,22 @@ def embed(text, input_type):
assert resp.embeddings.float_ is not None
return resp.embeddings.float_[0]
def search(query, roles: list[str], top_n: int = 3, max_content_length: int = 1500, priority_prefixes: list[str] | None = None) -> list[dict]:
"""Search with vector similarity, then rerank with Cohere for better relevance."""
@mcp.tool
def search_caving_documents(query: str, priority_prefixes: list[str] | None = None) -> dict:
"""Search caving documents for information about caves, techniques, safety, accidents, history, and more.
Args:
query: Search query
priority_prefixes: Optional list of key prefixes to prioritize (e.g., ['nss/aca'] for rescue topics)
"""
roles = get_user_roles()
if not roles:
return {"results": [], "note": "No results. Answer based on your knowledge."}
query_embedding = embed(query, 'search_query')
if not roles:
return []
# Fetch more candidates for reranking
top_n = 3
candidate_limit = top_n * 4
rows = conn.execute(
'SELECT * FROM embeddings WHERE embedding IS NOT NULL AND role = ANY(%s) ORDER BY embedding <=> %s::vector LIMIT %s',
@@ -74,14 +82,14 @@ def search(query, roles: list[str], top_n: int = 3, max_content_length: int = 15
).fetchall()
if not rows:
return []
return {"results": [], "note": "No results found. Answer based on your knowledge."}
# Rerank with Cohere for better relevance
rerank_resp = co.rerank(
query=query,
documents=[row['content'] or '' for row in rows],
model='rerank-v3.5',
top_n=min(top_n * 2, len(rows)), # Get more for re-sorting after boost
top_n=min(top_n * 2, len(rows)),
)
# Build results with optional priority boost
@@ -94,49 +102,17 @@ def search(query, roles: list[str], top_n: int = 3, max_content_length: int = 15
if priority_prefixes:
key = row['key'] or ''
if any(key.startswith(prefix) for prefix in priority_prefixes):
score = min(1.0, score * 1.3) # 30% boost, capped at 1.0
score = min(1.0, score * 1.3)
content = row['content'] or ''
if len(content) > max_content_length:
content = content[:max_content_length] + '...[truncated, use get_document_page for full text]'
docs.append({'key': row['key'], 'content': content, 'relevance': round(score, 3)})
# Re-sort by boosted score and return top_n
docs.sort(key=lambda x: x['relevance'], reverse=True)
return docs[:top_n]
@mcp.tool
def get_cave_location(cave: str, state: str, county: str) -> list[dict]:
    """Lookup cave location as coordinates."""
    # Build a location-focused query and delegate to the shared vector search.
    # NOTE(review): query wording is prompt-engineered for the embedding model —
    # reproduce it exactly.
    location_query = (
        f'{cave} Location, latitude, Longitude. Located in {state} and {county} county.'
    )
    return search(location_query, get_user_roles())
@mcp.tool
def general_caving_information(query: str, priority_prefixes: list[str] | None = None) -> list[dict]:
    """General purpose search for any topic related to caves.
    Args:
        query: Search query
        priority_prefixes: Optional list of key prefixes to prioritize in results (e.g., ['nss/aca'] for rescue topics)
    """
    # Resolve the caller's roles first, then forward everything to the shared
    # search helper; role filtering happens inside search().
    user_roles = get_user_roles()
    return search(query, user_roles, priority_prefixes=priority_prefixes)
@mcp.tool
def get_document_page(key: str) -> dict:
    """Fetch full content for a document page. Pass the exact 'key' value from search results."""
    # Guard: without roles the caller is not authorized to read anything.
    roles = get_user_roles()
    if not roles:
        return {"error": "No roles assigned"}
    # Role filter is enforced in the SQL itself so a user can only fetch
    # pages belonging to one of their assigned roles.
    row = conn.execute(
        'SELECT key, content FROM embeddings WHERE key = %s AND role = ANY(%s)',
        (key, roles)
    ).fetchone()
    if row is None:
        return {"error": f"Page not found: {key}"}
    return {"key": row["key"], "content": row["content"]}
return {
"results": docs[:top_n],
"note": "These are ALL available results. Do NOT search again - answer using these results now."
}
@mcp.tool
def get_user_info() -> dict:

View File

@@ -26,7 +26,6 @@ logfire.configure(
logfire.instrument_pydantic_ai()
logfire.instrument_httpx()
from typing import Any
from pydantic_ai import Agent, ModelMessage, RunContext
from pydantic_ai.settings import ModelSettings
from pydantic_ai.mcp import CallToolFunc
@@ -43,8 +42,8 @@ def limit_history(ctx: RunContext[None], messages: list[ModelMessage]) -> list[M
if not messages:
return messages
# Keep only the last 4 messages
messages = messages[-4:]
# Keep last 10 messages
messages = messages[-10:]
# Check if the last message is an assistant response with a tool call
# If so, remove it - it's orphaned (no tool result followed)
@@ -80,34 +79,28 @@ AGENT_INSTRUCTIONS = """Caving assistant. Help with exploration, safety, surveyi
Rules:
1. ALWAYS cite sources in a bulleted list at the end of every reply, even if there's only one. Format them human-readably (e.g., "- The Trog 2021, page 19" not "vpi/trog/2021-trog.pdf/page-19.pdf").
2. Say when uncertain. Never hallucinate.
3. Be safety-conscious.
4. Can create ascii diagrams/maps.
5. Be direct—no sycophantic phrases.
6. Keep responses concise.
7. Use tools sparingly—one search usually suffices.
8. If you hit the search limit, end your reply with an italicized note: *Your question may be too broad. Try asking something more specific.* Do NOT mention "tools" or "tool limits"—the user doesn't know what those are.
9. For rescue, accident, or emergency-related queries, use priority_prefixes=['nss/aca'] when searching to prioritize official accident reports."""
3. Be direct—no sycophantic phrases.
4. Keep responses concise.
5. SEARCH EXACTLY ONCE. After searching, IMMEDIATELY answer using those results. NEVER search again - additional searches are blocked and waste resources.
6. For rescue, accident, or emergency-related queries, use priority_prefixes=['nss/aca'] when searching to prioritize official accident reports."""
SOURCES_ONLY_INSTRUCTIONS = """SOURCES ONLY MODE: Give exactly ONE sentence summary. Then list sources with specific page numbers (e.g., "- The Trog 2021, page 19"). No explanations."""
def create_tool_call_limiter(max_calls: int = 3):
"""Create a process_tool_call callback that limits tool calls."""
call_count = [0] # Mutable container for closure
def create_search_limiter():
"""Block searches after the first one."""
searched = [False]
async def process_tool_call(
ctx: RunContext,
call_tool: CallToolFunc,
name: str,
tool_args: dict[str, Any],
tool_args: dict,
):
call_count[0] += 1
if call_count[0] > max_calls:
return (
f"SEARCH LIMIT REACHED: You have made {max_calls} searches. "
"Stop searching and answer now with what you have. "
"End your reply with: *Your question may be too broad. Try asking something more specific.*"
)
if name == "search_caving_documents":
if searched[0]:
return "You have already searched. Use the results you have."
searched[0] = True
return await call_tool(name, tool_args)
return process_tool_call
@@ -132,7 +125,7 @@ def create_agent(user_roles: list[str] | None = None, sources_only: bool = False
url=CAVE_MCP_URL,
headers={"x-user-roles": roles_header},
timeout=30.0,
process_tool_call=create_tool_call_limiter(max_calls=3),
process_tool_call=create_search_limiter(),
)
toolsets.append(mcp_server)
logger.info(f"MCP server configured with roles: {user_roles}")