pass roles to mcp
This commit is contained in:
@@ -3,7 +3,8 @@ This is the main entry point for the agent.
|
||||
It defines the workflow graph, state, tools, nodes and edges.
|
||||
"""
|
||||
|
||||
from typing import Any, List
|
||||
from typing import Any, List, Callable, Awaitable
|
||||
import json
|
||||
|
||||
from langchain.tools import tool
|
||||
from langchain_core.messages import BaseMessage, SystemMessage
|
||||
@@ -13,6 +14,7 @@ from langgraph.graph import END, MessagesState, StateGraph, START
|
||||
from langgraph.prebuilt import ToolNode, tools_condition
|
||||
from langgraph.types import Command
|
||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||
from langchain_mcp_adapters.interceptors import MCPToolCallRequest, MCPToolCallResult
|
||||
|
||||
|
||||
class AgentState(MessagesState):
|
||||
@@ -37,41 +39,55 @@ backend_tools = [
|
||||
# your_tool_here
|
||||
]
|
||||
|
||||
def get_mcp_client(access_token: str = None):
|
||||
"""Create MCP client with optional authentication headers."""
|
||||
headers = {}
|
||||
if access_token:
|
||||
headers["Authorization"] = f"Bearer {access_token}"
|
||||
class RolesHeaderInterceptor:
|
||||
"""Interceptor that injects user roles header into MCP tool calls."""
|
||||
|
||||
def __init__(self, user_roles: list = None):
|
||||
self.user_roles = user_roles or []
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
request: MCPToolCallRequest,
|
||||
handler: Callable[[MCPToolCallRequest], Awaitable[MCPToolCallResult]]
|
||||
) -> MCPToolCallResult:
|
||||
headers = dict(request.headers or {})
|
||||
if self.user_roles:
|
||||
headers["X-User-Roles"] = json.dumps(self.user_roles)
|
||||
|
||||
modified_request = request.override(headers=headers)
|
||||
return await handler(modified_request)
|
||||
|
||||
def get_mcp_client(user_roles: list = None):
|
||||
"""Create MCP client with user roles header."""
|
||||
return MultiServerMCPClient(
|
||||
{
|
||||
"cavepedia": {
|
||||
"transport": "streamable_http",
|
||||
"url": "https://mcp.caving.dev/mcp",
|
||||
"timeout": 10.0,
|
||||
"headers": headers,
|
||||
}
|
||||
}
|
||||
},
|
||||
tool_interceptors=[RolesHeaderInterceptor(user_roles)]
|
||||
)
|
||||
|
||||
# Cache for MCP tools per access token
|
||||
_mcp_tools_cache = {}
|
||||
|
||||
async def get_mcp_tools(access_token: str = None):
|
||||
"""Lazy load MCP tools with authentication."""
|
||||
cache_key = access_token or "default"
|
||||
async def get_mcp_tools(user_roles: list = None):
|
||||
"""Lazy load MCP tools with user roles."""
|
||||
roles_key = ",".join(sorted(user_roles)) if user_roles else "default"
|
||||
|
||||
if cache_key not in _mcp_tools_cache:
|
||||
if roles_key not in _mcp_tools_cache:
|
||||
try:
|
||||
mcp_client = get_mcp_client(access_token)
|
||||
mcp_client = get_mcp_client(user_roles)
|
||||
tools = await mcp_client.get_tools()
|
||||
_mcp_tools_cache[cache_key] = tools
|
||||
print(f"Loaded {len(tools)} tools from MCP server with auth: {bool(access_token)}")
|
||||
_mcp_tools_cache[roles_key] = tools
|
||||
print(f"Loaded {len(tools)} tools from MCP server with roles: {user_roles}")
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load MCP tools: {e}")
|
||||
_mcp_tools_cache[cache_key] = []
|
||||
_mcp_tools_cache[roles_key] = []
|
||||
|
||||
return _mcp_tools_cache[cache_key]
|
||||
return _mcp_tools_cache[roles_key]
|
||||
|
||||
|
||||
async def chat_node(state: AgentState, config: RunnableConfig) -> dict:
|
||||
@@ -86,18 +102,18 @@ async def chat_node(state: AgentState, config: RunnableConfig) -> dict:
|
||||
https://www.perplexity.ai/search/react-agents-NcXLQhreS0WDzpVaS4m9Cg
|
||||
"""
|
||||
|
||||
# 0. Extract Auth0 access token from config
|
||||
# 0. Extract user roles from config.configurable.context
|
||||
configurable = config.get("configurable", {})
|
||||
access_token = configurable.get("auth0_access_token")
|
||||
user_roles = configurable.get("auth0_user_roles", [])
|
||||
context = configurable.get("context", {})
|
||||
user_roles = context.get("auth0_user_roles", [])
|
||||
|
||||
print(f"Chat node invoked with auth token: {bool(access_token)}, roles: {user_roles}")
|
||||
print(f"Chat node invoked with roles: {user_roles}")
|
||||
|
||||
# 1. Define the model
|
||||
model = ChatAnthropic(model="claude-sonnet-4-5-20250929")
|
||||
|
||||
# 1.5 Load MCP tools from the cavepedia server with authentication
|
||||
mcp_tools = await get_mcp_tools(access_token)
|
||||
# 1.5 Load MCP tools from the cavepedia server with roles
|
||||
mcp_tools = await get_mcp_tools(user_roles)
|
||||
|
||||
# 2. Bind the tools to the model
|
||||
model_with_tools = model.bind_tools(
|
||||
@@ -135,12 +151,13 @@ async def tool_node_wrapper(state: AgentState, config: RunnableConfig) -> dict:
|
||||
"""
|
||||
Custom tool node that handles both backend tools and MCP tools.
|
||||
"""
|
||||
# Extract Auth0 access token from config
|
||||
# Extract user roles from config.configurable.context
|
||||
configurable = config.get("configurable", {})
|
||||
access_token = configurable.get("auth0_access_token")
|
||||
context = configurable.get("context", {})
|
||||
user_roles = context.get("auth0_user_roles", [])
|
||||
|
||||
# Load MCP tools with authentication and combine with backend tools
|
||||
mcp_tools = await get_mcp_tools(access_token)
|
||||
# Load MCP tools with roles
|
||||
mcp_tools = await get_mcp_tools(user_roles)
|
||||
all_tools = [*backend_tools, *mcp_tools]
|
||||
|
||||
# Use the standard ToolNode with all tools
|
||||
|
||||
Reference in New Issue
Block a user