This commit is contained in:
2025-12-11 18:17:06 +01:00
parent cdf998dded
commit 0f9e6d51f7
2 changed files with 11 additions and 9 deletions

View File

@@ -26,9 +26,7 @@ mcp = FastMCP("Cavepedia MCP")
def get_user_roles() -> list[str]:
"""Extract user roles from the X-User-Roles header."""
headers = get_http_headers()
print(f"[MCP] All headers: {dict(headers)}")
roles_header = headers.get("x-user-roles", "")
print(f"[MCP] X-User-Roles header: {roles_header}")
if roles_header:
try:
return json.loads(roles_header)
@@ -45,10 +43,17 @@ def embed(text, input_type):
)
return resp.embeddings.float[0]
def search(query) -> list[dict]:
def search(query, roles: list[str]) -> list[dict]:
query_embedding = embed(query, 'search_query')
rows = conn.execute('SELECT * FROM embeddings WHERE embedding IS NOT NULL ORDER BY embedding <=> %s::vector LIMIT 5', (query_embedding,)).fetchall()
if not roles:
# No roles = no results
return []
rows = conn.execute(
'SELECT * FROM embeddings WHERE embedding IS NOT NULL AND role = ANY(%s) ORDER BY embedding <=> %s::vector LIMIT 5',
(roles, query_embedding)
).fetchall()
docs = []
for row in rows:
docs.append({ 'key': row['key'], 'content': row['content']})
@@ -58,15 +63,14 @@ def search(query) -> list[dict]:
def get_cave_location(cave: str, state: str, county: str) -> list[dict]:
"""Lookup cave location as coordinates. Returns up to 5 matches, ordered by most to least relevant."""
roles = get_user_roles()
print(f"get_cave_location called with roles: {roles}")
return search(f'{cave} Location, latitude, Longitude. Located in {state} and {county} county.')
return search(f'{cave} Location, latitude, Longitude. Located in {state} and {county} county.', roles)
@mcp.tool
def general_caving_information(query: str) -> list[dict]:
"""General purpose endpoint for any topic related to caves. Returns up to 5 matches, ordered by most to least relevant."""
roles = get_user_roles()
print(f"general_caving_information called with roles: {roles}")
return search(query)
return search(query, roles)
@mcp.tool
def get_user_info() -> dict:

View File

@@ -106,8 +106,6 @@ async def chat_node(state: AgentState, config: RunnableConfig) -> dict:
context = configurable.get("context", {})
user_roles = context.get("auth0_user_roles", [])
print(f"Chat node invoked with roles: {user_roles}")
# 1. Define the model
model = ChatGoogleGenerativeAI(model="gemini-3-pro-preview", max_output_tokens=65536)