← 返回文章列表

从零到一实现生产级 MCP Gateway(三):服务注册与发现

2025-02-28·8 分钟阅读

从零到一实现生产级 MCP Gateway(三):注册中心实现

前言

注册中心(Registry)是 MCP Gateway 的核心组件,负责管理工具(Tools)、资源(Resources)和提示词(Prompts)的完整生命周期。它就像一个"能力市场",AI Agent 可以在这里发现、查询和调用各种能力。本章将深入介绍三大注册中心的设计与实现。

设计思路:为什么需要 Provider 注册中心?

问题背景

当系统需要管理多个 MCP Provider 时,直接硬编码连接信息会带来问题:

┌─────────────────────────────────────────────────────────────────────┐
│                    硬编码配置的问题                                   │
├─────────────────────────────────────────────────────────────────────┤
│                                                                      │
│  方式 A:配置文件中写死                                              │
│  providers:                                                         │
│    - name: "filesystem"                                             │
│      url: "http://localhost:8001"                                   │
│    - name: "git"                                                    │
│      url: "http://localhost:8002"                                   │
│                                                                      │
│  问题:                                                              │
│  1. Provider 地址变化需要重启服务                                    │
│  2. 无法动态增删 Provider                                           │
│  3. 无法感知 Provider 健康状态                                      │
│  4. 扩展困难                                                        │
│                                                                      │
│  方式 B:数据库存储                                                  │
│  providers 表: id, name, url, status                                │
│                                                                      │
│  优点:                                                              │
│  1. 支持动态配置                                                    │
│  2. 可以查询历史状态                                                │
│  3. 支持多实例部署                                                  │
│                                                                      │
│  方式 C:注册中心                                                    │
│  Provider 启动时注册,心跳保活,自动发现                             │
│                                                                      │
│  优点:                                                              │
│  1. 自动健康检查                                                    │
│  2. 自动服务发现                                                    │
│  3. 负载均衡支持                                                    │
│  4. 故障自动摘除                                                    │
│                                                                      │
└─────────────────────────────────────────────────────────────────────┘

为什么选择简化版注册中心?

考虑因素

┌─────────────────────────────────────────────────────────────────────┐
│                    注册中心选型                                       │
├─────────────────────────────────────────────────────────────────────┤
│                                                                      │
│  选项 A:etcd / Consul                                               │
│  - 优点:成熟稳定,生产级可靠性                                      │
│  - 缺点:部署复杂,学习成本高                                        │
│  - 适用:大规模分布式系统                                            │
│                                                                      │
│  选项 B:Redis                                                       │
│  - 优点:轻量,可能已有 Redis 基础设施                               │
│  - 缺点:非专为服务注册设计                                          │
│  - 适用:中小规模系统                                                │
│                                                                      │
│  选项 C:自建简化版(本文方案)                                      │
│  - 优点:完全可控,便于理解原理                                      │
│  - 缺点:可靠性需自行保证                                            │
│  - 适用:教学演示、小型项目                                          │
│                                                                      │
│  关键能力:                                                          │
│  1. Provider 注册与注销                                             │
│  2. 健康检查(心跳机制)                                            │
│  3. 服务发现                                                        │
│  4. 状态变更通知                                                    │
│                                                                      │
└─────────────────────────────────────────────────────────────────────┘

心跳机制的设计考量

┌─────────────────────────────────────────────────────────────────────┐
│                    心跳参数设计                                       │
├─────────────────────────────────────────────────────────────────────┤
│                                                                      │
│  关键参数:                                                          │
│  - heartbeat_interval: 心跳间隔(Provider 发送频率)                │
│  - heartbeat_timeout: 心跳超时(判定死亡的时间)                     │
│                                                                      │
│  经验值:                                                            │
│  - heartbeat_interval = 5-10 秒                                     │
│  - heartbeat_timeout = 3 * heartbeat_interval                       │
│                                                                      │
│  为什么是这个比例?                                                  │
│  - 3 次心跳机会:允许网络抖动丢包                                    │
│  - 不会太长:故障能较快被发现                                        │
│  - 不会太短:避免误判                                               │
│                                                                      │
│  示例:                                                              │
│  interval = 10 秒                                                    │
│  timeout = 30 秒                                                     │
│                                                                      │
│  时间线:                                                            │
│  T+0s:  Provider 发送心跳 → 注册中心记录 last_seen                  │
│  T+10s: Provider 发送心跳 → 更新 last_seen                          │
│  T+20s: 网络抖动,心跳丢失                                           │
│  T+30s: 网络恢复,Provider 发送心跳 → 更新 last_seen(正常)         │
│                                                                      │
│  如果 T+30s 还没收到心跳:                                           │
│  → 判定 Provider 不可用,从服务列表移除                              │
│                                                                      │
└─────────────────────────────────────────────────────────────────────┘

方案对比:健康检查策略

方案一:被动检查(本文方案)

# Provider 主动发送心跳
@app.post("/heartbeat")
async def heartbeat(provider_id: str):
    registry.update_last_seen(provider_id)

# 注册中心定期检查超时
async def check_timeouts():
    while True:
        for provider in registry.list_all():
            if time.now() - provider.last_seen > TIMEOUT:
                registry.mark_unhealthy(provider)
        await asyncio.sleep(CHECK_INTERVAL)

优点:实现简单,Provider 控制心跳频率
缺点:Provider 可能过载仍能发心跳
适用:网络质量较好的环境

方案二:主动探测

# 注册中心主动探测
async def health_check():
    for provider in registry.list_all():
        try:
            response = await httpx.get(f"{provider.url}/health", timeout=5)
            if response.status_code == 200:
                registry.mark_healthy(provider)
            else:
                registry.mark_unhealthy(provider)
        except:
            registry.mark_unhealthy(provider)

优点:能检测更多故障类型
缺点:增加网络开销,探测频率需权衡
适用:关键服务

方案三:混合模式(推荐)

# 结合心跳 + 主动探测
async def check_provider_health(provider):
    # 1. 心跳超时检查
    if time.now() - provider.last_seen > TIMEOUT:
        return "unhealthy"
    
    # 2. 定期主动探测
    if time.now() - provider.last_probe > PROBE_INTERVAL:
        try:
            await probe_health(provider)
            provider.last_probe = time.now()
        except:
            return "unhealthy"
    
    return "healthy"

优点:双重保障,更可靠
缺点:实现复杂
适用:生产环境

常见陷阱与解决方案

陷阱一:心跳间隔设置不合理

问题描述

# 心跳间隔太短
HEARTBEAT_INTERVAL = 1  # 秒
# 问题:网络开销大,Provider 负担重

# 心跳间隔太长
HEARTBEAT_INTERVAL = 60  # 秒
# 问题:故障发现慢,影响用户体验

解决方案:根据场景选择合适值

# 开发环境:快速发现问题
HEARTBEAT_INTERVAL = 5
HEARTBEAT_TIMEOUT = 15

# 生产环境:减少网络开销
HEARTBEAT_INTERVAL = 10
HEARTBEAT_TIMEOUT = 30

# 关键服务:快速故障转移
HEARTBEAT_INTERVAL = 3
HEARTBEAT_TIMEOUT = 10

陷阱二:并发心跳请求导致状态不一致

问题描述

# 多个请求同时更新 last_seen
async def heartbeat(provider_id: str):
    provider = await get_provider(provider_id)
    provider.last_seen = time.now()  # 竞态条件
    await save_provider(provider)

解决方案:使用原子操作

# 使用 Redis 原子操作
await redis.hset(f"provider:{provider_id}", "last_seen", time.now())

# 或使用数据库乐观锁
UPDATE providers SET last_seen = ?, version = version + 1 
WHERE id = ? AND version = ?

陷阱三:Provider 重启后 ID 变化

问题描述

Provider A 启动 → 分配 ID: "provider-123"
Provider A 重启 → 分配 ID: "provider-456"  # 新 ID

问题:
- 历史记录关联丢失
- Gateway 缓存失效
- 客户端需要重新发现

解决方案:使用稳定的 ID 生成策略

# 基于 Provider 名称生成稳定 ID
def generate_provider_id(name: str) -> str:
    return f"provider-{hashlib.sha256(name.encode()).hexdigest()[:8]}"

# 或使用配置文件中的固定 ID
provider:
  id: "filesystem-provider-01"  # 固定 ID
  name: "filesystem"

陷阱四:忘记清理已注销 Provider 的资源

问题描述

# Provider 注销后
registry.unregister(provider_id)

# 但忘记清理:
# - WebSocket 连接
# - 缓存数据
# - 会话状态

解决方案:实现完整的清理流程

async def unregister_provider(provider_id: str):
    provider = await registry.get(provider_id)
    
    # 1. 关闭连接
    if provider.websocket:
        await provider.websocket.close()
    
    # 2. 清理缓存
    await cache.delete(f"provider:{provider_id}:*")
    
    # 3. 通知客户端
    await notify_clients({"type": "provider_removed", "id": provider_id})
    
    # 4. 从注册表移除
    await registry.remove(provider_id)

陷阱五:Gateway 重启后 Provider 状态丢失

问题描述

Gateway 重启 → 内存中的注册表清空
Provider 仍在运行 → 不知道需要重新注册

结果:所有 Provider "丢失"

解决方案:持久化 + 恢复机制

# 方案 1:启动时从数据库恢复
async def startup():
    providers = await db.query("SELECT * FROM providers WHERE status = 'healthy'")
    for p in providers:
        registry.add(p)
    
    # 验证恢复的 Provider 是否存活
    await verify_all_providers()

# 方案 2:Provider 定时重注册
async def provider_loop():
    while True:
        await register_to_gateway()
        await asyncio.sleep(HEARTBEAT_INTERVAL)

# 方案 3:使用外部存储(Redis)
# Gateway 无状态,数据都在 Redis 中

注册中心架构

┌─────────────────────────────────────────────────────────────────────┐
│                       Registry Architecture                          │
│                                                                      │
│  ┌───────────────────────────────────────────────────────────────┐  │
│  │                     MCP Protocol Layer                         │  │
│  │  tools/list │ tools/call │ resources/* │ prompts/*            │  │
│  └───────────────────────────┬───────────────────────────────────┘  │
│                              │                                       │
│  ┌───────────────────────────┴───────────────────────────────────┐  │
│  │                     Registry Layer                             │  │
│  │                                                                │  │
│  │  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐           │  │
│  │  │    Tool     │  │  Resource   │  │   Prompt    │           │  │
│  │  │  Registry   │  │  Registry   │  │  Registry   │           │  │
│  │  ├─────────────┤  ├─────────────┤  ├─────────────┤           │  │
│  │  │ • register  │  │ • register  │  │ • register  │           │  │
│  │  │ • deregister│  │ • read      │  │ • get       │           │  │
│  │  │ • list      │  │ • subscribe │  │ • list      │           │  │
│  │  │ • execute   │  │ • notify    │  │             │           │  │
│  │  └─────────────┘  └─────────────┘  └─────────────┘           │  │
│  │                                                                │  │
│  └───────────────────────────────────────────────────────────────┘  │
│                              │                                       │
│  ┌───────────────────────────┴───────────────────────────────────┐  │
│  │                     Handler Layer                              │  │
│  │  Async Functions │ REST Adapters │ MCP Servers                │  │
│  └───────────────────────────────────────────────────────────────┘  │
│                                                                      │
└─────────────────────────────────────────────────────────────────────┘

工具注册中心

核心数据结构

# registry/tool_registry.py

from __future__ import annotations
import asyncio
import json
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Callable, Awaitable

import jsonschema
from jsonschema import ValidationError as JsonSchemaValidationError

from ..protocol import (
    Tool,
    ToolAnnotations,
    CallToolResult,
    TextContent,
    ContentBlock,
)

logger = logging.getLogger(__name__)


@dataclass
class RegisteredTool:
    """已注册工具的完整定义"""
    tool: Tool                              # MCP 工具定义
    handler: Callable[..., Awaitable[Any]]  # 异步执行函数
    server_id: str | None = None            # 所属服务器 ID(联邦场景)
    metadata: dict[str, Any] = field(default_factory=dict)


class ToolExecutor(ABC):
    """工具执行器抽象基类"""
    
    @abstractmethod
    async def execute(
        self, 
        name: str, 
        arguments: dict[str, Any]
    ) -> CallToolResult:
        """执行工具"""
        pass


class ToolRegistry:
    """工具注册中心
    
    功能:
    - 工具注册与注销
    - 工具发现(按名称、模式、服务器过滤)
    - 工具执行与参数验证
    - 多服务器联邦支持
    """
    
    def __init__(self):
        self._tools: dict[str, RegisteredTool] = {}
        self._tools_by_server: dict[str, list[str]] = {}
        self._lock = asyncio.Lock()

工具注册

    async def register(
        self,
        name: str,
        description: str,
        input_schema: dict[str, Any],
        handler: Callable[..., Awaitable[Any]],
        annotations: ToolAnnotations | None = None,
        server_id: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> Tool:
        """注册新工具
        
        Args:
            name: 工具唯一标识名
            description: 工具描述(供 AI 理解)
            input_schema: JSON Schema 格式的参数定义
            handler: 异步执行函数
            annotations: 工具注解(只读、破坏性等提示)
            server_id: 所属服务器 ID(联邦场景)
            metadata: 扩展元数据
            
        Returns:
            注册成功的 Tool 对象
            
        Raises:
            ValueError: 工具名已存在
        """
        async with self._lock:
            if name in self._tools:
                raise ValueError(f"Tool already registered: {name}")
            
            tool = Tool(
                name=name,
                description=description,
                input_schema=input_schema,
                annotations=annotations,
            )
            
            registered_tool = RegisteredTool(
                tool=tool,
                handler=handler,
                server_id=server_id,
                metadata=metadata or {},
            )
            
            self._tools[name] = registered_tool
            
            # 按服务器分组(联邦场景)
            if server_id:
                if server_id not in self._tools_by_server:
                    self._tools_by_server[server_id] = []
                self._tools_by_server[server_id].append(name)
            
            logger.info(f"Registered tool: {name}" + 
                       (f" (server: {server_id})" if server_id else ""))
            
            return tool

工具注销

    async def deregister(self, name: str) -> bool:
        """注销工具
        
        Args:
            name: 工具名
            
        Returns:
            True 注销成功,False 工具不存在
        """
        async with self._lock:
            if name not in self._tools:
                return False
            
            registered_tool = self._tools.pop(name)
            
            # 清理服务器分组
            if registered_tool.server_id:
                server_tools = self._tools_by_server.get(
                    registered_tool.server_id, []
                )
                if name in server_tools:
                    server_tools.remove(name)
            
            logger.info(f"Deregistered tool: {name}")
            return True
    
    async def deregister_server(self, server_id: str) -> int:
        """注销服务器的所有工具(联邦场景)
        
        Args:
            server_id: 服务器 ID
            
        Returns:
            注销的工具数量
        """
        async with self._lock:
            tool_names = self._tools_by_server.pop(server_id, [])
            count = 0
            
            for name in tool_names:
                if name in self._tools:
                    del self._tools[name]
                    count += 1
            
            logger.info(f"Deregistered {count} tools from server: {server_id}")
            return count

参数验证

    def validate_arguments(
        self, 
        tool: Tool, 
        arguments: dict[str, Any]
    ) -> tuple[bool, str | None]:
        """验证工具参数
        
        使用 JSON Schema 验证参数格式
        
        Args:
            tool: 工具定义
            arguments: 待验证参数
            
        Returns:
            (是否有效, 错误信息)
        """
        schema = tool.input_schema
        
        # 空 schema 表示无验证
        if not schema or schema == {"type": "object"}:
            return True, None
        
        try:
            jsonschema.validate(arguments, schema)
            return True, None
        except JsonSchemaValidationError as e:
            # 构建友好的错误信息
            path = ".".join(str(p) for p in e.absolute_path) 
            if e.absolute_path else "root"
            error_msg = f"Validation error at '{path}': {e.message}"
            return False, error_msg

工具执行

    async def execute(
        self, 
        name: str, 
        arguments: dict[str, Any]
    ) -> CallToolResult:
        """执行工具
        
        Args:
            name: 工具名
            arguments: 工具参数(将被验证)
            
        Returns:
            CallToolResult 包含内容或错误
        """
        registered = self._tools.get(name)
        
        if registered is None:
            return CallToolResult(
                content=[TextContent(text=f"Tool not found: {name}")],
                is_error=True,
            )
        
        # 验证参数
        is_valid, error_msg = self.validate_arguments(
            registered.tool, arguments
        )
        if not is_valid:
            logger.warning(f"Invalid arguments for tool '{name}': {error_msg}")
            return CallToolResult(
                content=[TextContent(text=f"Invalid arguments: {error_msg}")],
                is_error=True,
            )
        
        try:
            # 执行处理器
            result = await registered.handler(**arguments)
            
            # 转换结果为内容块
            content = self._convert_result(result)
            
            return CallToolResult(content=content, is_error=False)
            
        except TypeError as e:
            logger.warning(f"Type error executing tool '{name}': {e}")
            return CallToolResult(
                content=[TextContent(text=f"Argument error: {str(e)}")],
                is_error=True,
            )
        except Exception as e:
            logger.exception(f"Error executing tool: {name}")
            return CallToolResult(
                content=[TextContent(text=f"Error: {str(e)}")],
                is_error=True,
            )
    
    def _convert_result(self, result: Any) -> list[ContentBlock]:
        """将各种类型结果转换为 MCP Content 块"""
        if isinstance(result, list):
            return [
                item if isinstance(item, (TextContent, ContentBlock)) 
                else TextContent(text=str(item))
                for item in result
            ]
        
        if isinstance(result, str):
            return [TextContent(text=result)]
        
        if isinstance(result, dict):
            return [TextContent(text=json.dumps(result, indent=2))]
        
        return [TextContent(text=str(result))]

工具发现

    def get(self, name: str) -> RegisteredTool | None:
        """按名称获取工具"""
        return self._tools.get(name)
    
    def list_tools(
        self, 
        server_id: str | None = None,
        pattern: str | None = None,
    ) -> list[Tool]:
        """列出工具
        
        Args:
            server_id: 按服务器过滤
            pattern: 按名称模式过滤(支持通配符)
            
        Returns:
            Tool 对象列表
        """
        tools = []
        
        for name, registered in self._tools.items():
            # 服务器过滤
            if server_id and registered.server_id != server_id:
                continue
            
            # 模式匹配
            if pattern:
                import fnmatch
                if not fnmatch.fnmatch(name, pattern):
                    continue
            
            tools.append(registered.tool)
        
        return tools

全局实例

# 全局注册中心实例
_registry: ToolRegistry | None = None


def get_registry() -> ToolRegistry:
    """获取全局工具注册中心"""
    global _registry
    if _registry is None:
        _registry = ToolRegistry()
    return _registry

资源注册中心

核心数据结构

# registry/resource_registry.py

from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass, field
from typing import Any, Callable, Awaitable

from ..protocol import (
    Resource,
    ResourceTemplate,
    ResourceContents,
    TextResourceContents,
    BlobResourceContents,
)

logger = logging.getLogger(__name__)


@dataclass
class RegisteredResource:
    """已注册资源的完整定义"""
    resource: Resource
    read_handler: Callable[[str], Awaitable[ResourceContents]]
    server_id: str | None = None
    metadata: dict[str, Any] = field(default_factory=dict)


@dataclass 
class Subscription:
    """资源订阅"""
    uri: str
    subscriber_id: str
    callback: Callable[[ResourceContents], Awaitable[None]] | None = None


class ResourceRegistry:
    """资源注册中心
    
    功能:
    - 资源注册与发现
    - 资源读取
    - 订阅/通知机制
    - 资源模板支持
    """
    
    def __init__(self):
        self._resources: dict[str, RegisteredResource] = {}
        self._templates: list[ResourceTemplate] = []
        self._subscriptions: dict[str, list[Subscription]] = {}
        self._lock = asyncio.Lock()

资源注册

    async def register(
        self,
        uri: str,
        name: str,
        read_handler: Callable[[str], Awaitable[ResourceContents]],
        description: str | None = None,
        mime_type: str | None = None,
        server_id: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> Resource:
        """注册资源
        
        Args:
            uri: 资源 URI
            name: 资源名称
            read_handler: 读取处理器
            description: 资源描述
            mime_type: MIME 类型
            server_id: 所属服务器 ID
            metadata: 扩展元数据
            
        Returns:
            注册成功的 Resource 对象
        """
        async with self._lock:
            if uri in self._resources:
                raise ValueError(f"Resource already registered: {uri}")
            
            resource = Resource(
                uri=uri,
                name=name,
                description=description,
                mime_type=mime_type,
            )
            
            registered = RegisteredResource(
                resource=resource,
                read_handler=read_handler,
                server_id=server_id,
                metadata=metadata or {},
            )
            
            self._resources[uri] = registered
            logger.info(f"Registered resource: {uri}")
            
            return resource

资源读取

    async def read(self, uri: str) -> list[ResourceContents]:
        """读取资源内容
        
        Args:
            uri: 资源 URI
            
        Returns:
            资源内容列表
        """
        registered = self._resources.get(uri)
        
        if registered is None:
            # 尝试匹配模板
            registered = self._match_template(uri)
        
        if registered is None:
            raise ValueError(f"Resource not found: {uri}")
        
        try:
            contents = await registered.read_handler(uri)
            return [contents] if not isinstance(contents, list) else contents
        except Exception as e:
            logger.exception(f"Error reading resource: {uri}")
            raise
    
    def _match_template(self, uri: str) -> RegisteredResource | None:
        """匹配资源模板"""
        # TODO: 实现 URI 模板匹配
        return None

订阅机制

    async def subscribe(
        self,
        uri: str,
        subscriber_id: str,
        callback: Callable[[ResourceContents], Awaitable[None]] | None = None,
    ) -> bool:
        """订阅资源变更
        
        Args:
            uri: 资源 URI
            subscriber_id: 订阅者 ID
            callback: 可选的变更回调
            
        Returns:
            订阅是否成功
        """
        if uri not in self._resources:
            return False
        
        async with self._lock:
            if uri not in self._subscriptions:
                self._subscriptions[uri] = []
            
            subscription = Subscription(
                uri=uri,
                subscriber_id=subscriber_id,
                callback=callback,
            )
            
            self._subscriptions[uri].append(subscription)
            logger.info(f"Subscribed to resource: {uri}")
            
            return True
    
    async def unsubscribe(
        self,
        uri: str,
        subscriber_id: str,
    ) -> bool:
        """取消订阅"""
        async with self._lock:
            if uri not in self._subscriptions:
                return False
            
            self._subscriptions[uri] = [
                s for s in self._subscriptions[uri]
                if s.subscriber_id != subscriber_id
            ]
            
            logger.info(f"Unsubscribed from resource: {uri}")
            return True
    
    async def notify_update(self, uri: str) -> None:
        """通知资源更新"""
        subscriptions = self._subscriptions.get(uri, [])
        
        for sub in subscriptions:
            if sub.callback:
                try:
                    contents = await self.read(uri)
                    for content in contents:
                        await sub.callback(content)
                except Exception as e:
                    logger.error(f"Error notifying subscriber: {e}")

全局实例

_resource_registry: ResourceRegistry | None = None


def get_resource_registry() -> ResourceRegistry:
    """获取全局资源注册中心"""
    global _resource_registry
    if _resource_registry is None:
        _resource_registry = ResourceRegistry()
    return _resource_registry

提示词注册中心

核心数据结构

# registry/prompt_registry.py

from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass, field
from typing import Any, Callable, Awaitable

from ..protocol import (
    Prompt,
    PromptArgument,
    PromptMessage,
    GetPromptResult,
    TextContent,
)

logger = logging.getLogger(__name__)


@dataclass
class RegisteredPrompt:
    """已注册提示词的完整定义"""
    prompt: Prompt
    template: str | Callable[[dict[str, str]], Awaitable[str]]
    server_id: str | None = None
    metadata: dict[str, Any] = field(default_factory=dict)


class PromptRegistry:
    """提示词注册中心
    
    功能:
    - 提示词注册与发现
    - 参数渲染
    - 模板支持
    """
    
    def __init__(self):
        self._prompts: dict[str, RegisteredPrompt] = {}
        self._lock = asyncio.Lock()

提示词注册

    async def register(
        self,
        name: str,
        template: str | Callable[[dict[str, str]], Awaitable[str]],
        description: str | None = None,
        arguments: list[PromptArgument] | None = None,
        server_id: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> Prompt:
        """注册提示词
        
        Args:
            name: 提示词名称
            template: 模板字符串或渲染函数
            description: 提示词描述
            arguments: 参数定义
            server_id: 所属服务器 ID
            metadata: 扩展元数据
            
        Returns:
            注册成功的 Prompt 对象
        """
        async with self._lock:
            if name in self._prompts:
                raise ValueError(f"Prompt already registered: {name}")
            
            prompt = Prompt(
                name=name,
                description=description,
                arguments=arguments,
            )
            
            registered = RegisteredPrompt(
                prompt=prompt,
                template=template,
                server_id=server_id,
                metadata=metadata or {},
            )
            
            self._prompts[name] = registered
            logger.info(f"Registered prompt: {name}")
            
            return prompt

提示词获取与渲染

    async def get(
        self,
        name: str,
        arguments: dict[str, str] | None = None,
    ) -> GetPromptResult:
        """获取渲染后的提示词
        
        Args:
            name: 提示词名称
            arguments: 渲染参数
            
        Returns:
            GetPromptResult 包含消息列表
        """
        registered = self._prompts.get(name)
        
        if registered is None:
            raise ValueError(f"Prompt not found: {name}")
        
        # 验证必需参数
        if registered.prompt.arguments:
            for arg in registered.prompt.arguments:
                if arg.required and (not arguments or arg.name not in arguments):
                    raise ValueError(f"Missing required argument: {arg.name}")
        
        # 渲染模板
        if callable(registered.template):
            text = await registered.template(arguments or {})
        else:
            text = self._render_template(registered.template, arguments or {})
        
        return GetPromptResult(
            description=registered.prompt.description,
            messages=[
                PromptMessage(
                    role="user",
                    content=TextContent(text=text),
                )
            ],
        )
    
    def _render_template(
        self, 
        template: str, 
        arguments: dict[str, str]
    ) -> str:
        """渲染模板字符串"""
        result = template
        for key, value in arguments.items():
            result = result.replace(f"{{{key}}}", value)
        return result

全局实例

_prompt_registry: PromptRegistry | None = None


def get_prompt_registry() -> PromptRegistry:
    """获取全局提示词注册中心"""
    global _prompt_registry
    if _prompt_registry is None:
        _prompt_registry = PromptRegistry()
    return _prompt_registry

内置工具注册示例

# 注册内置工具
async def _register_builtin_tools(registry: ToolRegistry):
    """注册网关内置工具"""
    
    # Echo 工具 - 用于测试
    async def handle_echo(message: str) -> str:
        return message
    
    await registry.register(
        name="echo",
        description="Echo back the input message",
        input_schema={
            "type": "object",
            "properties": {
                "message": {
                    "type": "string",
                    "description": "Message to echo back"
                }
            },
            "required": ["message"]
        },
        handler=handle_echo,
        annotations=ToolAnnotations(
            read_only_hint=True,
            destructive_hint=False,
            idempotent_hint=True,
        ),
    )
    
    # 服务器信息工具
    async def handle_server_info() -> dict[str, Any]:
        return {
            "name": "mcp-gateway-core",
            "version": "1.0.0",
            "tools_count": len(registry._tools),
        }
    
    await registry.register(
        name="server_info",
        description="Get information about the MCP Gateway server",
        input_schema={"type": "object", "properties": {}},
        handler=handle_server_info,
        annotations=ToolAnnotations(
            read_only_hint=True,
            destructive_hint=False,
        ),
    )

使用示例

注册自定义工具

from mcp_gateway_core import get_registry

async def my_tool(name: str, count: int = 1) -> str:
    """自定义工具实现"""
    return f"Hello, {name}! " * count

async def register():
    registry = get_registry()
    await registry.register(
        name="greet",
        description="Greet a person multiple times",
        input_schema={
            "type": "object",
            "properties": {
                "name": {
                    "type": "string",
                    "description": "Name to greet"
                },
                "count": {
                    "type": "integer",
                    "description": "Number of greetings",
                    "default": 1,
                    "minimum": 1,
                    "maximum": 10
                }
            },
            "required": ["name"]
        },
        handler=my_tool,
    )

注册资源

from mcp_gateway_core.registry import get_resource_registry
from mcp_gateway_core.protocol import TextResourceContents

async def read_status(uri: str) -> TextResourceContents:
    """读取应用状态"""
    return TextResourceContents(
        uri=uri,
        mime_type="application/json",
        text='{"status": "ok", "uptime": 3600}'
    )

async def register():
    registry = get_resource_registry()
    await registry.register(
        uri="myapp://status",
        name="Application Status",
        read_handler=read_status,
        description="Current application status",
    )

注册提示词

from mcp_gateway_core.registry import get_prompt_registry
from mcp_gateway_core.protocol import PromptArgument

async def register():
    registry = get_prompt_registry()
    await registry.register(
        name="summarize",
        description="Generate a summary prompt",
        arguments=[
            PromptArgument(
                name="text",
                description="Text to summarize",
                required=True
            ),
            PromptArgument(
                name="length",
                description="Summary length (short/medium/long)",
                required=False
            ),
        ],
        template="Please summarize the following text in a {length} manner:\n\n{text}",
    )

设计亮点

特性说明面试价值
异步安全使用 asyncio.Lock 保护并发访问并发编程能力
JSON Schema 验证完整的参数验证机制API 设计规范
结果自动转换将各种类型转换为 MCP Content 格式协议适配能力
联邦支持按 server_id 分组管理分布式设计思维
订阅通知资源变更通知机制观察者模式

小结

本章详细实现了工具、资源、提示词三大注册中心。它们是 MCP Gateway 的核心组件,提供了完整的能力管理机制。

关键要点

  1. ToolRegistry 负责工具注册、发现、执行和参数验证
  2. ResourceRegistry 支持资源读取和订阅通知机制
  3. PromptRegistry 提供提示词模板和参数渲染
  4. 使用 asyncio.Lock 保证并发安全
  5. JSON Schema 提供严格的参数验证

下一章我们将实现认证授权模块,包括 JWT、API Key 和 RBAC 权限模型。

参考资料

分享: