Compare commits
7 Commits
e671242eca
...
29b111080f
| Author | SHA1 | Date | |
|---|---|---|---|
| 29b111080f | |||
| f869381283 | |||
| bc1dc8a11a | |||
| 4ac0389ce2 | |||
| 6654496379 | |||
| e2c18b07a5 | |||
| 31a9e868e9 |
@@ -59,23 +59,38 @@ def embed(text, input_type):
|
|||||||
assert resp.embeddings.float_ is not None
|
assert resp.embeddings.float_ is not None
|
||||||
return resp.embeddings.float_[0]
|
return resp.embeddings.float_[0]
|
||||||
|
|
||||||
def search(query, roles: list[str], limit: int = 3, max_content_length: int = 1500) -> list[dict]:
|
def search(query, roles: list[str], top_n: int = 3, max_content_length: int = 1500) -> list[dict]:
|
||||||
|
"""Search with vector similarity, then rerank with Cohere for better relevance."""
|
||||||
query_embedding = embed(query, 'search_query')
|
query_embedding = embed(query, 'search_query')
|
||||||
|
|
||||||
if not roles:
|
if not roles:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
# Fetch more candidates for reranking
|
||||||
|
candidate_limit = top_n * 4
|
||||||
rows = conn.execute(
|
rows = conn.execute(
|
||||||
'SELECT * FROM embeddings WHERE embedding IS NOT NULL AND role = ANY(%s) ORDER BY embedding <=> %s::vector LIMIT %s',
|
'SELECT * FROM embeddings WHERE embedding IS NOT NULL AND role = ANY(%s) ORDER BY embedding <=> %s::vector LIMIT %s',
|
||||||
(roles, query_embedding, limit)
|
(roles, query_embedding, candidate_limit)
|
||||||
).fetchall()
|
).fetchall()
|
||||||
|
|
||||||
|
if not rows:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Rerank with Cohere for better relevance
|
||||||
|
rerank_resp = co.rerank(
|
||||||
|
query=query,
|
||||||
|
documents=[row['content'] or '' for row in rows],
|
||||||
|
model='rerank-v3.5',
|
||||||
|
top_n=top_n,
|
||||||
|
)
|
||||||
|
|
||||||
docs = []
|
docs = []
|
||||||
for row in rows:
|
for result in rerank_resp.results:
|
||||||
|
row = rows[result.index]
|
||||||
content = row['content'] or ''
|
content = row['content'] or ''
|
||||||
if len(content) > max_content_length:
|
if len(content) > max_content_length:
|
||||||
content = content[:max_content_length] + '...[truncated, use get_document_page for full text]'
|
content = content[:max_content_length] + '...[truncated, use get_document_page for full text]'
|
||||||
docs.append({'key': row['key'], 'content': content})
|
docs.append({'key': row['key'], 'content': content, 'relevance': round(result.relevance_score, 3)})
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
@mcp.tool
|
@mcp.tool
|
||||||
|
|||||||
@@ -12,4 +12,6 @@ dependencies = [
|
|||||||
"ag-ui-protocol",
|
"ag-ui-protocol",
|
||||||
"python-dotenv",
|
"python-dotenv",
|
||||||
"httpx",
|
"httpx",
|
||||||
|
"logfire>=4.16.0",
|
||||||
|
"python-json-logger>=4.0.0",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -5,18 +5,32 @@ PydanticAI agent with MCP tools from Cavepedia server.
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
import httpx
|
import httpx
|
||||||
|
import logfire
|
||||||
|
|
||||||
from pydantic_ai import Agent, ModelMessage, RunContext
|
# Set up logging BEFORE logfire (otherwise basicConfig is ignored)
|
||||||
from pydantic_ai.settings import ModelSettings
|
from pythonjsonlogger import jsonlogger
|
||||||
|
|
||||||
# Set up logging based on environment
|
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
|
||||||
log_level = logging.DEBUG if os.getenv("DEBUG") else logging.INFO
|
handler = logging.StreamHandler()
|
||||||
|
handler.setFormatter(jsonlogger.JsonFormatter("%(asctime)s %(name)s %(levelname)s %(message)s"))
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=log_level,
|
level=getattr(logging, log_level, logging.INFO),
|
||||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
handlers=[handler],
|
||||||
)
|
)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Configure Logfire for observability
|
||||||
|
logfire.configure(
|
||||||
|
environment=os.getenv('ENVIRONMENT', 'development'),
|
||||||
|
)
|
||||||
|
logfire.instrument_pydantic_ai()
|
||||||
|
logfire.instrument_httpx()
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
from pydantic_ai import Agent, ModelMessage, RunContext
|
||||||
|
from pydantic_ai.settings import ModelSettings
|
||||||
|
from pydantic_ai.mcp import CallToolFunc
|
||||||
|
|
||||||
CAVE_MCP_URL = os.getenv("CAVE_MCP_URL", "https://mcp.caving.dev/mcp")
|
CAVE_MCP_URL = os.getenv("CAVE_MCP_URL", "https://mcp.caving.dev/mcp")
|
||||||
|
|
||||||
logger.info(f"Initializing Cavepedia agent with CAVE_MCP_URL={CAVE_MCP_URL}")
|
logger.info(f"Initializing Cavepedia agent with CAVE_MCP_URL={CAVE_MCP_URL}")
|
||||||
@@ -64,13 +78,36 @@ def check_mcp_available(url: str, timeout: float = 5.0) -> bool:
|
|||||||
AGENT_INSTRUCTIONS = """Caving assistant. Help with exploration, safety, surveying, locations, geology, equipment, history, conservation.
|
AGENT_INSTRUCTIONS = """Caving assistant. Help with exploration, safety, surveying, locations, geology, equipment, history, conservation.
|
||||||
|
|
||||||
Rules:
|
Rules:
|
||||||
1. ALWAYS cite sources at the end of every reply. Use the 'key' from search results (e.g., "Source: vpi/trog/2021-trog.pdf/page-19.pdf").
|
1. ALWAYS cite sources in a bulleted list at the end of every reply, even if there's only one. Format them human-readably (e.g., "- The Trog 2021, page 19" not "vpi/trog/2021-trog.pdf/page-19.pdf").
|
||||||
2. Say when uncertain. Never hallucinate.
|
2. Say when uncertain. Never hallucinate.
|
||||||
3. Be safety-conscious.
|
3. Be safety-conscious.
|
||||||
4. Can create ascii diagrams/maps.
|
4. Can create ascii diagrams/maps.
|
||||||
5. Be direct—no sycophantic phrases.
|
5. Be direct—no sycophantic phrases.
|
||||||
6. Keep responses concise.
|
6. Keep responses concise.
|
||||||
7. Use tools sparingly—one search usually suffices."""
|
7. Use tools sparingly—one search usually suffices.
|
||||||
|
8. If you hit the search limit, end your reply with an italicized note: *Your question may be too broad. Try asking something more specific.* Do NOT mention "tools" or "tool limits"—the user doesn't know what those are."""
|
||||||
|
|
||||||
|
|
||||||
|
def create_tool_call_limiter(max_calls: int = 3):
|
||||||
|
"""Create a process_tool_call callback that limits tool calls."""
|
||||||
|
call_count = [0] # Mutable container for closure
|
||||||
|
|
||||||
|
async def process_tool_call(
|
||||||
|
ctx: RunContext,
|
||||||
|
call_tool: CallToolFunc,
|
||||||
|
name: str,
|
||||||
|
tool_args: dict[str, Any],
|
||||||
|
):
|
||||||
|
call_count[0] += 1
|
||||||
|
if call_count[0] > max_calls:
|
||||||
|
return (
|
||||||
|
f"SEARCH LIMIT REACHED: You have made {max_calls} searches. "
|
||||||
|
"Stop searching and answer now with what you have. "
|
||||||
|
"End your reply with: *Your question may be too broad. Try asking something more specific.*"
|
||||||
|
)
|
||||||
|
return await call_tool(name, tool_args)
|
||||||
|
|
||||||
|
return process_tool_call
|
||||||
|
|
||||||
|
|
||||||
def create_agent(user_roles: list[str] | None = None):
|
def create_agent(user_roles: list[str] | None = None):
|
||||||
@@ -92,6 +129,7 @@ def create_agent(user_roles: list[str] | None = None):
|
|||||||
url=CAVE_MCP_URL,
|
url=CAVE_MCP_URL,
|
||||||
headers={"x-user-roles": roles_header},
|
headers={"x-user-roles": roles_header},
|
||||||
timeout=30.0,
|
timeout=30.0,
|
||||||
|
process_tool_call=create_tool_call_limiter(max_calls=3),
|
||||||
)
|
)
|
||||||
toolsets.append(mcp_server)
|
toolsets.append(mcp_server)
|
||||||
logger.info(f"MCP server configured with roles: {user_roles}")
|
logger.info(f"MCP server configured with roles: {user_roles}")
|
||||||
|
|||||||
@@ -14,13 +14,27 @@ from pydantic_ai.settings import ModelSettings
|
|||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
# Set up logging based on environment
|
# Set up logging based on environment
|
||||||
log_level = logging.DEBUG if os.getenv("DEBUG") else logging.INFO
|
from pythonjsonlogger import jsonlogger
|
||||||
|
|
||||||
|
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
|
||||||
|
json_formatter = jsonlogger.JsonFormatter("%(asctime)s %(name)s %(levelname)s %(message)s")
|
||||||
|
|
||||||
|
# Configure root logger with JSON
|
||||||
|
handler = logging.StreamHandler()
|
||||||
|
handler.setFormatter(json_formatter)
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=log_level,
|
level=getattr(logging, log_level, logging.INFO),
|
||||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
handlers=[handler],
|
||||||
)
|
)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Apply JSON formatter to uvicorn loggers (works even when run via `uvicorn src.main:app`)
|
||||||
|
for uvicorn_logger_name in ("uvicorn", "uvicorn.error", "uvicorn.access"):
|
||||||
|
uvicorn_logger = logging.getLogger(uvicorn_logger_name)
|
||||||
|
uvicorn_logger.handlers = [handler]
|
||||||
|
uvicorn_logger.setLevel(getattr(logging, log_level, logging.INFO))
|
||||||
|
uvicorn_logger.propagate = False
|
||||||
|
|
||||||
# Validate required environment variables
|
# Validate required environment variables
|
||||||
if not os.getenv("ANTHROPIC_API_KEY"):
|
if not os.getenv("ANTHROPIC_API_KEY"):
|
||||||
logger.error("ANTHROPIC_API_KEY environment variable is required")
|
logger.error("ANTHROPIC_API_KEY environment variable is required")
|
||||||
@@ -41,12 +55,9 @@ logger.info("Creating AG-UI app...")
|
|||||||
|
|
||||||
async def handle_agent_request(request: Request) -> Response:
|
async def handle_agent_request(request: Request) -> Response:
|
||||||
"""Handle incoming AG-UI requests with dynamic role-based MCP configuration."""
|
"""Handle incoming AG-UI requests with dynamic role-based MCP configuration."""
|
||||||
# Debug: log all incoming headers
|
|
||||||
logger.info(f"DEBUG: All request headers: {dict(request.headers)}")
|
|
||||||
|
|
||||||
# Extract user roles from request headers
|
# Extract user roles from request headers
|
||||||
roles_header = request.headers.get("x-user-roles", "")
|
roles_header = request.headers.get("x-user-roles", "")
|
||||||
logger.info(f"DEBUG: x-user-roles header value: '{roles_header}'")
|
|
||||||
user_roles = []
|
user_roles = []
|
||||||
|
|
||||||
if roles_header:
|
if roles_header:
|
||||||
@@ -59,13 +70,12 @@ async def handle_agent_request(request: Request) -> Response:
|
|||||||
# Create agent with the user's roles
|
# Create agent with the user's roles
|
||||||
agent = create_agent(user_roles)
|
agent = create_agent(user_roles)
|
||||||
|
|
||||||
# Dispatch the request using AGUIAdapter with usage limits
|
# Dispatch the request - tool limits handled by ToolCallLimiter in agent.py
|
||||||
return await AGUIAdapter.dispatch_request(
|
return await AGUIAdapter.dispatch_request(
|
||||||
request,
|
request,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
usage_limits=UsageLimits(
|
usage_limits=UsageLimits(
|
||||||
request_limit=5, # Max 5 LLM requests per query
|
request_limit=10, # Safety net for runaway requests
|
||||||
tool_calls_limit=3, # Max 3 tool calls per query
|
|
||||||
),
|
),
|
||||||
model_settings=ModelSettings(max_tokens=4096),
|
model_settings=ModelSettings(max_tokens=4096),
|
||||||
)
|
)
|
||||||
|
|||||||
4
web/agent/uv.lock
generated
4
web/agent/uv.lock
generated
@@ -231,10 +231,12 @@ source = { virtual = "." }
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "ag-ui-protocol" },
|
{ name = "ag-ui-protocol" },
|
||||||
{ name = "httpx" },
|
{ name = "httpx" },
|
||||||
|
{ name = "logfire" },
|
||||||
{ name = "mcp" },
|
{ name = "mcp" },
|
||||||
{ name = "openai" },
|
{ name = "openai" },
|
||||||
{ name = "pydantic-ai" },
|
{ name = "pydantic-ai" },
|
||||||
{ name = "python-dotenv" },
|
{ name = "python-dotenv" },
|
||||||
|
{ name = "python-json-logger" },
|
||||||
{ name = "starlette" },
|
{ name = "starlette" },
|
||||||
{ name = "uvicorn" },
|
{ name = "uvicorn" },
|
||||||
]
|
]
|
||||||
@@ -243,10 +245,12 @@ dependencies = [
|
|||||||
requires-dist = [
|
requires-dist = [
|
||||||
{ name = "ag-ui-protocol" },
|
{ name = "ag-ui-protocol" },
|
||||||
{ name = "httpx" },
|
{ name = "httpx" },
|
||||||
|
{ name = "logfire", specifier = ">=4.16.0" },
|
||||||
{ name = "mcp" },
|
{ name = "mcp" },
|
||||||
{ name = "openai" },
|
{ name = "openai" },
|
||||||
{ name = "pydantic-ai" },
|
{ name = "pydantic-ai" },
|
||||||
{ name = "python-dotenv" },
|
{ name = "python-dotenv" },
|
||||||
|
{ name = "python-json-logger", specifier = ">=4.0.0" },
|
||||||
{ name = "starlette" },
|
{ name = "starlette" },
|
||||||
{ name = "uvicorn" },
|
{ name = "uvicorn" },
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -7,20 +7,29 @@ import { useUser } from "@auth0/nextjs-auth0/client";
|
|||||||
import LoginButton from "@/components/LoginButton";
|
import LoginButton from "@/components/LoginButton";
|
||||||
import LogoutButton from "@/components/LogoutButton";
|
import LogoutButton from "@/components/LogoutButton";
|
||||||
|
|
||||||
// Separate component to safely use useCopilotChat hook
|
// Block input and show indicator while agent is processing
|
||||||
function ThinkingIndicator() {
|
function LoadingOverlay() {
|
||||||
try {
|
try {
|
||||||
const { isLoading } = useCopilotChat();
|
const { isLoading } = useCopilotChat();
|
||||||
if (!isLoading) return null;
|
if (!isLoading) return null;
|
||||||
return (
|
return (
|
||||||
<div className="absolute bottom-24 left-1/2 transform -translate-x-1/2 bg-white shadow-lg rounded-full px-4 py-2 flex items-center gap-2 z-50">
|
<>
|
||||||
<div className="flex gap-1">
|
{/* Overlay to block input area */}
|
||||||
<span className="w-2 h-2 bg-indigo-500 rounded-full animate-bounce" style={{ animationDelay: "0ms" }}></span>
|
<div
|
||||||
<span className="w-2 h-2 bg-indigo-500 rounded-full animate-bounce" style={{ animationDelay: "150ms" }}></span>
|
className="absolute bottom-0 left-0 right-0 h-24 z-40"
|
||||||
<span className="w-2 h-2 bg-indigo-500 rounded-full animate-bounce" style={{ animationDelay: "300ms" }}></span>
|
style={{ pointerEvents: 'all' }}
|
||||||
|
onClick={(e) => e.stopPropagation()}
|
||||||
|
/>
|
||||||
|
{/* Thinking indicator */}
|
||||||
|
<div className="absolute bottom-24 left-1/2 transform -translate-x-1/2 bg-white shadow-lg rounded-full px-4 py-2 flex items-center gap-2 z-50">
|
||||||
|
<div className="flex gap-1">
|
||||||
|
<span className="w-2 h-2 bg-indigo-500 rounded-full animate-bounce" style={{ animationDelay: "0ms" }}></span>
|
||||||
|
<span className="w-2 h-2 bg-indigo-500 rounded-full animate-bounce" style={{ animationDelay: "150ms" }}></span>
|
||||||
|
<span className="w-2 h-2 bg-indigo-500 rounded-full animate-bounce" style={{ animationDelay: "300ms" }}></span>
|
||||||
|
</div>
|
||||||
|
<span className="text-sm text-gray-600">Thinking...</span>
|
||||||
</div>
|
</div>
|
||||||
<span className="text-sm text-gray-600">Thinking...</span>
|
</>
|
||||||
</div>
|
|
||||||
);
|
);
|
||||||
} catch {
|
} catch {
|
||||||
return null;
|
return null;
|
||||||
@@ -121,7 +130,7 @@ export default function CopilotKitPage() {
|
|||||||
className="h-full w-full"
|
className="h-full w-full"
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
<ThinkingIndicator />
|
<LoadingOverlay />
|
||||||
</div>
|
</div>
|
||||||
</main>
|
</main>
|
||||||
);
|
);
|
||||||
|
|||||||
Reference in New Issue
Block a user