MCP 最佳实践与安全指南
本指南总结了 MCP 开发、部署和使用的最佳实践,重点关注安全性、性能优化和可维护性。
🔐 安全性最佳实践
1. 输入验证与清理
原则:永远不要信任用户输入,所有外部数据都应该被视为不可信。
import re
from typing import Optional
# 正确:严格的输入验证
@server.tool()
async def query_database(query: str) -> str:
    """Run a read-only SQL query after strict validation of untrusted input.

    Args:
        query: Raw SQL string supplied by the caller (untrusted).

    Returns:
        The query result, or a Chinese error-message string when the input
        is rejected or execution fails.
    """
    # Reject empty or oversized input up front.
    if not query or len(query) > 10000:
        return "错误:查询长度超出限制"
    # Only read-only SELECT statements are allowed.
    if not re.match(r'^\s*SELECT\s+', query, re.IGNORECASE):
        return "错误:只允许 SELECT 查询"
    # Block dangerous keywords, matching on word boundaries so that
    # legitimate identifiers such as "UPDATED_AT" are not rejected
    # (the original substring check produced false positives).
    dangerous_keywords = ['DROP', 'DELETE', 'UPDATE', 'INSERT', 'ALTER', 'EXEC']
    upper_query = query.upper()
    for keyword in dangerous_keywords:
        if re.search(r'\b' + keyword + r'\b', upper_query):
            return f"错误:不允许使用 {keyword}"
    # Delegate to the parameterized-query helper; never build SQL by hand.
    try:
        result = await execute_safe_query(query)
        return result
    except Exception as e:
        return f"查询错误: {str(e)}"
# 错误:直接执行用户输入
@server.tool()
async def unsafe_query(query: str) -> str:
# 不要这样做!存在 SQL 注入风险
return await execute_raw_query(query)2. 权限最小化原则
原则:只授予必要的最小权限,避免过度授权。
import os
from pathlib import Path
# 配置允许的路径
# Whitelist of directories the server is permitted to touch.
ALLOWED_PATHS = [
    "/safe/directory1",
    "/safe/directory2"
]

def is_path_allowed(path: str) -> bool:
    """Return True when *path* resolves to a location inside ALLOWED_PATHS.

    Path.resolve() canonicalizes the path first, so ".." segments cannot
    escape the whitelist.
    """
    target = Path(path).resolve()

    def _inside(root: str) -> bool:
        # relative_to raises ValueError when target is outside root.
        try:
            target.relative_to(Path(root).resolve())
        except ValueError:
            return False
        return True

    return any(_inside(root) for root in ALLOWED_PATHS)
@server.tool()
async def read_file(file_path: str) -> str:
# 验证路径
if not is_path_allowed(file_path):
return "错误:路径不在允许的范围内"
# 检查路径遍历攻击
if ".." in file_path or file_path.startswith("/"):
return "错误:无效的路径"
try:
with open(file_path, 'r', encoding='utf-8') as f:
return f.read()
except Exception as e:
return f"读取文件失败: {str(e)}"3. 环境变量管理
原则:敏感信息(API 密钥、密码等)必须存储在环境变量中,绝不能硬编码。
import os
from typing import Optional
# 好的做法:从环境变量读取
class Config:
    """Runtime configuration loaded from environment variables."""

    def __init__(self):
        env = os.getenv
        # Secrets come from the environment -- never hard-code them.
        self.api_key = env("MY_API_KEY")
        self.db_password = env("DB_PASSWORD")
        self.encryption_key = env("ENCRYPTION_KEY")
        # Fail fast when a required secret is absent.
        if not self.api_key:
            raise ValueError("MY_API_KEY 环境变量未设置")
config = Config()
# 坏的做法:硬编码敏感信息
# API_KEY = "sk-1234567890abcdef"  # 永远不要这样做!

配置文件示例:
{
"mcp_servers": [
{
"name": "secure-service",
"command": "python",
"args": ["server.py"],
"env": {
"API_KEY": "${API_KEY}", # 使用环境变量引用
"DB_PASSWORD": "${DB_PASSWORD}"
}
}
]
}

4. 认证与授权
from typing import Dict, Optional
import jwt
# Static user -> granted-actions map; a real deployment would back this
# with a user store instead of a hard-coded dict.
USER_PERMISSIONS = {
    "user1": ["read", "write"],
    "admin": ["read", "write", "delete", "admin"]
}
def verify_token(token: str) -> Optional[Dict]:
    """Validate a JWT and return its payload, or None when invalid.

    The signing secret is read from the JWT_SECRET environment variable,
    and only HS256 is accepted (pinning the algorithm list guards against
    algorithm-confusion attacks).
    """
    try:
        payload = jwt.decode(
            token,
            os.getenv("JWT_SECRET"),
            algorithms=["HS256"]
        )
        return payload
    except jwt.InvalidTokenError:
        # Covers expiry, bad signature, malformed token, etc.
        return None
def check_permission(user: str, action: str) -> bool:
    """Return True when *user* has been granted *action*; unknown users get nothing."""
    granted = USER_PERMISSIONS.get(user, [])
    return action in granted
@server.tool()
async def admin_operation(token: str, data: str) -> str:
# 验证 token
payload = verify_token(token)
if not payload:
return "错误:无效的认证令牌"
user = payload.get("username")
# 检查权限
if not check_permission(user, "admin"):
return f"错误:用户 {user} 没有执行此操作的权限"
# 执行操作
return "操作成功"5. 敏感数据处理
import re
def mask_sensitive_data(data: str, patterns: list) -> str:
    """Apply each (regex, replacement) pair in *patterns* to *data*, in order."""
    result = data
    for regex, replacement in patterns:
        result = re.sub(regex, replacement, result)
    return result
@server.tool()
async def process_data(input_data: str) -> str:
# 处理数据
result = process(input_data)
# 掩码敏感信息
masked_result = mask_sensitive_data(
result,
[
(r'\b\d{16}\b', '****-****-****-****'), # 信用卡号
(r'\b\d{3}-\d{2}-\d{4}\b', '***-**-****'), # SSN
(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', '***@***.***'), # 邮箱
(r'["\']?password["\']?\s*[:=]\s*["\']?[^"\'\s]+["\']?', 'password=***'), # 密码
]
)
# 记录日志时使用掩码后的数据
logger.info(f"处理结果: {masked_result}")
return result⚡ 性能优化
1. 缓存策略
import hashlib
import json
from datetime import datetime, timedelta
from functools import lru_cache
from typing import Any, Optional
# 简单的内存缓存
class CacheManager:
    """Minimal in-memory TTL cache (single-process, unbounded, not thread-safe)."""

    def __init__(self):
        self.cache = {}        # key -> cached value
        self.expirations = {}  # key -> datetime after which the entry is stale

    def get(self, key: str) -> Optional[Any]:
        """Return the cached value, or None when absent or expired.

        The original annotation was Optional[any] -- the builtin function,
        not a type -- fixed to typing.Any.
        """
        if key not in self.cache:
            return None
        # Evict lazily on read once the TTL has passed.
        if datetime.now() > self.expirations[key]:
            del self.cache[key]
            del self.expirations[key]
            return None
        return self.cache[key]

    def set(self, key: str, value: Any, ttl_seconds: int = 300):
        """Store *value* under *key* for *ttl_seconds* seconds."""
        self.cache[key] = value
        self.expirations[key] = datetime.now() + timedelta(seconds=ttl_seconds)
cache = CacheManager()
@server.tool()
async def cached_api_call(endpoint: str) -> str:
    """Fetch *endpoint*, serving repeated calls from the 5-minute cache."""
    # md5 is fine here: the digest is only a cache key, not a security boundary.
    cache_key = hashlib.md5(endpoint.encode()).hexdigest()
    # Compare against None so legitimately-falsy cached payloads (e.g. "")
    # still count as hits; the original truthiness test treated them as misses.
    cached_result = cache.get(cache_key)
    if cached_result is not None:
        return f"[缓存] {cached_result}"
    result = await fetch_from_api(endpoint)
    cache.set(cache_key, result, ttl_seconds=300)
    return result
# 使用 lru_cache(适用于参数较少的场景)
@lru_cache(maxsize=128)
def expensive_computation(param1: int, param2: str) -> str:
"""缓存计算结果"""
# 耗时的计算...
return result2. 异步 I/O
import asyncio
import aiohttp
from concurrent.futures import ThreadPoolExecutor
# 异步 HTTP 请求
async def fetch_multiple_urls(urls: list) -> list:
    """Fetch every URL in *urls* concurrently; bodies are returned in input order."""
    async with aiohttp.ClientSession() as session:
        return await asyncio.gather(*(fetch_url(session, url) for url in urls))
async def fetch_url(session, url: str) -> str:
    """Fetch one URL with the shared aiohttp *session* and return the body text."""
    async with session.get(url) as response:
        return await response.text()
@server.tool()
async def parallel_fetch(urls: list) -> str:
    """Fetch several pages concurrently and return their bodies as a JSON array."""
    results = await fetch_multiple_urls(urls)
    return json.dumps(results, ensure_ascii=False)
# 使用线程池处理 CPU 密集型任务
executor = ThreadPoolExecutor(max_workers=4)
@server.tool()
async def cpu_intensive_task(data: str) -> str:
"""CPU 密集型任务(在线程池中执行)"""
loop = asyncio.get_event_loop()
result = await loop.run_in_executor(
executor,
process_heavy_computation,
data
)
return result3. 连接池管理
import asyncpg
from contextlib import asynccontextmanager
class DatabasePool:
    """Thin wrapper around an asyncpg connection pool."""

    def __init__(self, dsn: str):
        self.dsn = dsn
        self.pool = None  # created lazily by init()

    async def init(self):
        """Create the underlying pool; call once at server startup."""
        self.pool = await asyncpg.create_pool(
            self.dsn,
            min_size=2,
            max_size=10
        )

    async def close(self):
        """Dispose of the pool; safe to call even when init() never ran."""
        if self.pool:
            await self.pool.close()

    @asynccontextmanager
    async def acquire(self):
        """Yield a pooled connection.

        Raises:
            RuntimeError: when init() has not been called yet (clearer than
                the AttributeError the original produced).
        """
        if self.pool is None:
            raise RuntimeError("DatabasePool.init() must be called before acquire()")
        async with self.pool.acquire() as conn:
            yield conn
# 全局连接池
db_pool = DatabasePool(os.getenv("DATABASE_URL"))
# 在服务器启动时初始化
async def on_startup():
    """Server startup hook: create the shared database connection pool."""
    await db_pool.init()
@server.tool()
async def query_database(sql: str) -> str:
"""执行数据库查询"""
async with db_pool.acquire() as conn:
rows = await conn.fetch(sql)
return json.dumps([dict(row) for row in rows], ensure_ascii=False)4. 批量操作优化
@server.tool()
async def batch_operations(items: list) -> str:
    """Insert all *items* in a single executemany round trip."""
    # Good practice: one batched INSERT instead of one query per item.
    async with db_pool.acquire() as conn:
        await conn.executemany(
            "INSERT INTO items (name, value) VALUES ($1, $2)",
            [(item["name"], item["value"]) for item in items]
        )
    return f"成功插入 {len(items)} 条记录"
# 坏的做法:逐条插入
@server.tool()
async def inefficient_batch(items: list) -> str:
# 不要这样做!效率低下
for item in items:
await insert_item(item)
return "完成"🧪 测试最佳实践
1. 单元测试
import pytest
from your_server import add, subtract, multiply
class TestMathTools:
"""测试数学工具"""
@pytest.mark.asyncio
async def test_add_positive_numbers(self):
"""测试正数相加"""
result = await add(2, 3)
assert "5" in result
@pytest.mark.asyncio
async def test_add_negative_numbers(self):
"""测试负数相加"""
result = await add(-2, -3)
assert "-5" in result
@pytest.mark.asyncio
async def test_add_zero(self):
"""测试加零"""
result = await add(5, 0)
assert "5" in result
@pytest.mark.asyncio
async def test_invalid_input(self):
"""测试无效输入"""
result = await add("abc", 3) # 类型错误
assert "错误" in result2. 集成测试
import pytest
from testcontainers.postgres import PostgresContainer
@pytest.fixture
def postgres_container():
    """Fixture yielding a throwaway PostgreSQL test container.

    Declared as a plain (sync) generator fixture: nothing in the body
    awaits, and the original ``async def`` would have required an
    async-fixture plugin for no benefit.
    """
    with PostgresContainer("postgres:16") as postgres:
        yield postgres
@pytest.mark.asyncio
async def test_database_integration(postgres_container):
"""测试数据库集成"""
# 使用测试数据库
connection = await asyncpg.connect(postgres_container.get_connection_url())
# 创建测试表
await connection.execute("""
CREATE TABLE test_items (
id SERIAL PRIMARY KEY,
name TEXT NOT NULL
)
""")
# 插入测试数据
await connection.execute(
"INSERT INTO test_items (name) VALUES ($1)",
"test"
)
# 验证数据
result = await connection.fetchval(
"SELECT name FROM test_items WHERE id = 1"
)
assert result == "test"
await connection.close()3. Mock 测试
from unittest.mock import AsyncMock, patch
import aiohttp
@pytest.mark.asyncio
async def test_external_api_mock():
"""测试外部 API(使用 Mock)"""
# Mock HTTP 响应
mock_response = AsyncMock()
mock_response.text.return_value = '{"data": "test"}'
with patch.object(aiohttp.ClientSession, 'get', return_value=mock_response):
result = await fetch_from_api("http://example.com/api")
assert '{"data": "test"}' in result📊 监控与日志
1. 结构化日志
import logging
import json
from datetime import datetime
# 配置结构化日志
class StructuredFormatter(logging.Formatter):
    """Render each log record as one JSON object per line."""

    # Attributes present on every LogRecord; anything beyond these was
    # supplied via the ``extra=`` argument and belongs in the JSON output.
    _STANDARD_ATTRS = frozenset(
        vars(logging.LogRecord("", 0, "", 0, "", (), None))
    ) | {"message", "asctime"}

    def format(self, record):
        # NOTE(review): utcnow() is naive and deprecated since 3.12 --
        # consider datetime.now(timezone.utc) when the format change is acceptable.
        log_data = {
            "timestamp": datetime.utcnow().isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "function": record.funcName,
            "line": record.lineno
        }
        # The original checked hasattr(record, "extra"), which never fires:
        # logging merges the extra= dict into record.__dict__ as individual
        # attributes.  Collect those non-standard attributes instead.
        for key, value in record.__dict__.items():
            if key not in self._STANDARD_ATTRS and key not in log_data:
                log_data[key] = value
        # default=str keeps logging resilient to non-JSON-serializable extras.
        return json.dumps(log_data, ensure_ascii=False, default=str)
# 配置日志
logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
handler.setFormatter(StructuredFormatter())
logger.addHandler(handler)
logger.setLevel(logging.INFO)
# 使用日志
@server.tool()
async def logged_operation(input: str) -> str:
logger.info(
"开始执行操作",
extra={
"operation": "process",
"input_length": len(input),
"user_id": "user123"
}
)
try:
result = process(input)
logger.info("操作成功", extra={"result_length": len(result)})
return result
except Exception as e:
logger.error("操作失败", extra={"error": str(e)})
raise2. 性能监控
import time
from functools import wraps
from collections import defaultdict
# 性能指标收集器
class PerformanceMonitor:
    """Collects call durations per metric name and reports simple statistics."""

    def __init__(self):
        self.metrics = defaultdict(list)  # name -> list of durations (seconds)

    def record(self, name: str, duration: float):
        """Append one observed *duration* for *name*."""
        self.metrics[name].append(duration)

    def get_stats(self, name: str) -> dict:
        """Return count/min/max/avg for *name*, or {} when nothing was recorded.

        Uses .get() so that querying an unknown name does not insert an
        empty entry (the original indexed the defaultdict directly,
        mutating the metrics on a pure read).
        """
        durations = self.metrics.get(name)
        if not durations:
            return {}
        return {
            "count": len(durations),
            "min": min(durations),
            "max": max(durations),
            "avg": sum(durations) / len(durations)
        }
monitor = PerformanceMonitor()
def monitor_performance(func):
    """Decorator: record the wrapped coroutine's duration into *monitor*.

    Failures are recorded under "<name>_error" and re-raised unchanged.
    """
    @wraps(func)
    async def wrapper(*args, **kwargs):
        # perf_counter is monotonic -- immune to wall-clock adjustments,
        # unlike the original time.time().
        start = time.perf_counter()
        try:
            result = await func(*args, **kwargs)
        except Exception:
            monitor.record(f"{func.__name__}_error", time.perf_counter() - start)
            raise
        monitor.record(func.__name__, time.perf_counter() - start)
        return result
    return wrapper
@server.tool()
@monitor_performance
async def monitored_tool(input: str) -> str:
    """Example tool whose latency is recorded by @monitor_performance."""
    return process(input)
# 获取性能统计
@server.tool()
async def get_performance_stats() -> str:
"""获取性能统计"""
return json.dumps(monitor.metrics, ensure_ascii=False, indent=2)3. 健康检查
from fastapi import FastAPI
import psutil
import asyncpg
app = FastAPI()
@app.get("/health")
async def health_check():
"""健康检查端点"""
health_status = {
"status": "healthy",
"timestamp": datetime.utcnow().isoformat()
}
# 检查数据库连接
try:
async with db_pool.acquire() as conn:
await conn.fetchval("SELECT 1")
health_status["database"] = "connected"
except Exception as e:
health_status["status"] = "unhealthy"
health_status["database"] = f"disconnected: {str(e)}"
# 检查系统资源
health_status["cpu_usage"] = psutil.cpu_percent()
health_status["memory_usage"] = psutil.virtual_memory().percent
# 检查磁盘空间
disk = psutil.disk_usage('/')
health_status["disk_usage"] = {
"used": disk.used,
"free": disk.free,
"percent": disk.percent
}
return health_status🚀 错误处理最佳实践
1. 优雅的错误处理
from typing import Union
import traceback
class MCPError(Exception):
    """Application-level error carrying a machine-readable code."""

    def __init__(self, message: str, code: str = "ERROR"):
        super().__init__(message)
        self.message = message
        self.code = code


def handle_error(error: Exception) -> dict:
    """Map any exception to a uniform error-response dict."""
    if isinstance(error, MCPError):
        # Known, intentional failure: expose its own code and message.
        return {"success": False, "code": error.code, "message": error.message}
    # Unexpected failure: generic code plus a detail string for diagnostics.
    return {
        "success": False,
        "code": "INTERNAL_ERROR",
        "message": "内部服务器错误",
        "details": str(error),
    }
@server.tool()
async def robust_operation(input: str) -> str:
"""健壮的操作处理"""
try:
# 验证输入
if not input or len(input) > 1000:
raise MCPError("输入无效", "INVALID_INPUT")
# 执行操作
result = await process(input)
return json.dumps({
"success": True,
"data": result
}, ensure_ascii=False)
except MCPError as e:
logger.error(f"MCP 错误: {e.message}")
return json.dumps(handle_error(e), ensure_ascii=False)
except Exception as e:
logger.error(f"未预期的错误: {str(e)}")
logger.error(traceback.format_exc())
return json.dumps(handle_error(e), ensure_ascii=False)2. 重试机制
import asyncio
from typing import Callable, TypeVar
T = TypeVar('T')

async def retry(
    func: Callable[..., T],
    max_retries: int = 3,
    delay: float = 1.0,
    backoff: float = 2.0,
    exceptions: tuple = (Exception,)
) -> T:
    """Await func() up to *max_retries* times with exponential backoff.

    Only exceptions listed in *exceptions* trigger a retry; the final
    failure is logged and re-raised to the caller.
    """
    wait = delay
    attempt = 0
    while attempt < max_retries:
        try:
            return await func()
        except exceptions as exc:
            attempt += 1
            if attempt < max_retries:
                logger.warning(
                    f"重试 {attempt}/{max_retries}: {str(exc)}",
                    extra={"delay": wait}
                )
                await asyncio.sleep(wait)
                wait *= backoff
            else:
                logger.error(f"达到最大重试次数: {str(exc)}")
                raise
@server.tool()
async def operation_with_retry(url: str) -> str:
"""带重试机制的操作"""
async def fetch():
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
response.raise_for_status()
return await response.text()
try:
result = await retry(fetch, max_retries=3)
return f"成功: {result}"
except Exception as e:
return f"失败: {str(e)}"📚 文档和可维护性
1. 完整的 API 文档
@server.tool()
async def complex_tool(
param1: str,
param2: int,
optional_param: Optional[bool] = False
) -> str:
"""
复杂工具的完整文档
参数:
param1: 第一个参数,必须是有效的字符串,长度在1-100之间
param2: 第二个参数,必须是正整数
optional_param: 可选参数,默认为 False
返回:
操作结果的 JSON 字符串
示例:
>>> complex_tool("test", 42)
'{"success": true, "result": "..."}'
错误:
ValueError: 如果参数无效
ConnectionError: 如果连接失败
"""
# 实现逻辑
pass2. 配置管理
from dataclasses import dataclass
from typing import Optional
import json
@dataclass
class ServerConfig:
    """Server runtime settings with sensible defaults."""
    host: str = "0.0.0.0"
    port: int = 8080
    debug: bool = False
    max_connections: int = 100
    log_level: str = "INFO"

    @classmethod
    def from_dict(cls, data: dict) -> 'ServerConfig':
        """Build a config from *data*, silently dropping unknown keys."""
        known = cls.__annotations__
        return cls(**{key: value for key, value in data.items() if key in known})

    @classmethod
    def from_file(cls, filepath: str) -> 'ServerConfig':
        """Load a JSON file and build a config from its contents."""
        with open(filepath, 'r', encoding='utf-8') as f:
            return cls.from_dict(json.load(f))
# 使用配置
config = ServerConfig.from_file("config.json")
@server.tool()
async def configurable_operation() -> str:
"""使用配置的操作"""
logger.info(f"调试模式: {config.debug}")
# 实现逻辑
pass🎯 总结
安全检查清单
- 所有输入都经过验证和清理
- 敏感信息存储在环境变量中
- 实现了最小权限原则
- 使用认证和授权机制
- 实现了速率限制
- 日志中不记录敏感信息
性能检查清单
- 使用缓存减少重复计算
- 异步 I/O 操作并发执行
- 使用连接池管理数据库连接
- 批量操作而非逐条处理
- 监控性能指标
可维护性检查清单
- 编写完整的单元测试
- 实现结构化日志
- 编写清晰的 API 文档
- 使用配置文件管理设置
- 实现优雅的错误处理
- 监控系统健康状态
遵循这些最佳实践,可以构建出安全、高效、可维护的 MCP 服务器。