diff --git a/mcp/server.py b/mcp/server.py index 4cb7816..b92451e 100644 --- a/mcp/server.py +++ b/mcp/server.py @@ -26,9 +26,7 @@ mcp = FastMCP("Cavepedia MCP") def get_user_roles() -> list[str]: """Extract user roles from the X-User-Roles header.""" headers = get_http_headers() - print(f"[MCP] All headers: {dict(headers)}") roles_header = headers.get("x-user-roles", "") - print(f"[MCP] X-User-Roles header: {roles_header}") if roles_header: try: return json.loads(roles_header) @@ -45,10 +43,17 @@ def embed(text, input_type): ) return resp.embeddings.float[0] -def search(query) -> list[dict]: +def search(query, roles: list[str]) -> list[dict]: query_embedding = embed(query, 'search_query') - rows = conn.execute('SELECT * FROM embeddings WHERE embedding IS NOT NULL ORDER BY embedding <=> %s::vector LIMIT 5', (query_embedding,)).fetchall() + if not roles: + # No roles = no results + return [] + + rows = conn.execute( + 'SELECT * FROM embeddings WHERE embedding IS NOT NULL AND role = ANY(%s) ORDER BY embedding <=> %s::vector LIMIT 5', + (roles, query_embedding) + ).fetchall() docs = [] for row in rows: docs.append({ 'key': row['key'], 'content': row['content']}) @@ -58,15 +63,14 @@ def search(query) -> list[dict]: def get_cave_location(cave: str, state: str, county: str) -> list[dict]: """Lookup cave location as coordinates. Returns up to 5 matches, ordered by most to least relevant.""" roles = get_user_roles() - print(f"get_cave_location called with roles: {roles}") - return search(f'{cave} Location, latitude, Longitude. Located in {state} and {county} county.') + return search(f'{cave} Location, latitude, Longitude. Located in {state} and {county} county.', roles) @mcp.tool def general_caving_information(query: str) -> list[dict]: """General purpose endpoint for any topic related to caves. Returns up to 5 matches, ordered by most to least relevant.""" roles = get_user_roles() print(f"general_caving_information called with roles: {roles}") - return search(query) + return search(query, roles) @mcp.tool def get_user_info() -> dict: diff --git a/web/agent/main.py b/web/agent/main.py index 730a3ae..18a83b4 100644 --- a/web/agent/main.py +++ b/web/agent/main.py @@ -106,8 +106,6 @@ async def chat_node(state: AgentState, config: RunnableConfig) -> dict: context = configurable.get("context", {}) user_roles = context.get("auth0_user_roles", []) - print(f"Chat node invoked with roles: {user_roles}") - # 1. Define the model model = ChatGoogleGenerativeAI(model="gemini-3-pro-preview", max_output_tokens=65536)