Chat2BI系统架构设计
本文详细介绍如何设计和实现一个生产级的 Chat2BI 系统架构。
系统架构概览
graph TB
A[用户界面层] --> B[API网关]
B --> C[对话管理服务]
C --> D[NL2SQL服务]
C --> E[查询执行服务]
C --> F[可视化服务]
D --> G[Schema服务]
D --> H[LLM服务]
E --> I[数据库集群]
F --> J[图表引擎]
C --> K[缓存层]
C --> L[监控服务]
1. 分层架构设计
用户界面层
提供多种交互方式(Web 界面、API 调用等),下面以 React 前端组件为例:
// React front-end component
interface Chat2BIProps {
  sessionId: string;
  datasource: string;
}

/**
 * Chat panel for Chat2BI.
 *
 * Posts the user's question to `/api/chat2bi/query` and appends both the
 * question and the assistant reply (answer, visualization, insights) to the
 * message list. A non-2xx response is shown as an assistant message instead
 * of appending a reply whose fields are all `undefined`.
 */
const Chat2BIComponent: React.FC<Chat2BIProps> = ({
  sessionId,
  datasource
}) => {
  const [messages, setMessages] = useState<Message[]>([]);
  const [loading, setLoading] = useState(false);

  const sendQuestion = async (question: string) => {
    setLoading(true);
    try {
      const response = await fetch('/api/chat2bi/query', {
        method: 'POST',
        // fetch() labels a string body as text/plain by default; the JSON API
        // needs application/json to parse the body into its request model.
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({
          session_id: sessionId,
          question: question,
          datasource: datasource
        })
      });

      if (!response.ok) {
        // Surface the failure in the conversation.
        setMessages(prev => [...prev, {
          type: 'user',
          content: question
        }, {
          type: 'assistant',
          content: `查询失败 (HTTP ${response.status})`
        }]);
        return;
      }

      const result = await response.json();
      // Append the question and the answer. Use the functional update form:
      // `messages` captured by this closure can be stale (e.g. two questions
      // in flight), which would drop earlier messages.
      setMessages(prev => [...prev, {
        type: 'user',
        content: question
      }, {
        type: 'assistant',
        content: result.answer,
        visualization: result.visualization,
        insights: result.insights
      }]);
    } finally {
      setLoading(false);
    }
  };

  return (
    <div className="chat2bi-container">
      <MessageList messages={messages} />
      <InputBox onSend={sendQuestion} loading={loading} />
    </div>
  );
};
API服务层
import asyncio
import hashlib
import math
import re
import time

import pandas as pd
from fastapi import FastAPI, HTTPException, Depends
from pydantic import BaseModel
app = FastAPI()


class QueryRequest(BaseModel):
    """Request body for POST /api/chat2bi/query."""

    session_id: str  # conversation session identifier sent by the client
    question: str  # natural-language question
    datasource: str  # datasource the question should be answered from
    # Extra options. NOTE(review): not forwarded by the visible query()
    # handler. The mutable {} default relies on pydantic copying defaults per
    # instance — confirm for the pydantic version in use.
    options: dict = {}


class QueryResponse(BaseModel):
    """Response body returned by POST /api/chat2bi/query."""

    answer: str  # text answer shown to the user
    sql: str  # SQL generated for the question
    visualization: dict  # presumably a chart configuration — confirm
    insights: list
    execution_time: float  # elapsed time measured by the endpoint
@app.post("/api/chat2bi/query", response_model=QueryResponse)
async def query(
    request: QueryRequest,
    user=Depends(get_current_user)
):
    """Handle one Chat2BI question end to end.

    Steps: permission check for the requested datasource, the Chat2BI
    pipeline (``chat2bi_service.process_query``), an audit-log entry, and a
    response carrying the elapsed time in seconds.

    Raises:
        HTTPException: an HTTPException raised inside the handler (e.g. by
            ``check_permission``, if it reports denial that way) propagates
            with its own status code. Any other error becomes a 500 with a
            generic detail, so internal messages (SQL text, driver errors) are
            not returned to the client.
    """
    # perf_counter is monotonic, so the duration is not skewed by clock changes.
    start_time = time.perf_counter()
    try:
        # 1. Permission check
        check_permission(user, request.datasource)

        # 2. Run the query pipeline
        result = await chat2bi_service.process_query(
            question=request.question,
            session_id=request.session_id,
            datasource=request.datasource,
            user_id=user.id
        )

        # 3. Audit log (datasource included so the record shows what was queried)
        audit_logger.log({
            'user_id': user.id,
            'question': request.question,
            'datasource': request.datasource,
            'sql': result['sql'],
            'timestamp': time.time()
        })

        execution_time = time.perf_counter() - start_time
        return QueryResponse(
            answer=result['answer'],
            sql=result['sql'],
            visualization=result['visualization'],
            insights=result['insights'],
            execution_time=execution_time
        )
    except HTTPException:
        # Keep deliberate HTTP errors (e.g. 403) instead of rewriting them to 500.
        raise
    except Exception as e:
        # Full traceback goes to the server log; the client gets a generic message.
        logger.exception(f"Query failed: {e}")
        raise HTTPException(status_code=500, detail="Query failed") from e
2. NL2SQL服务架构
Schema管理服务
class SchemaService:
    """Find the tables relevant to a question and render them as prompt text.

    Relevant tables come from a vector search. Each one is described by its
    columns, foreign keys and a few sample rows. The rendered text is cached
    per (datasource, question).
    """

    def __init__(self, db_connections, vector_store):
        # datasource name -> connection object (indexed by name below)
        self.db_connections = db_connections
        self.vector_store = vector_store
        # maxsize=1000, ttl=3600 (assumed seconds, cachetools semantics — confirm)
        self.schema_cache = TTLCache(maxsize=1000, ttl=3600)

    async def get_relevant_schema(
        self,
        question: str,
        datasource: str
    ):
        """Return the schema context for the tables relevant to ``question``.

        Args:
            question: user question; it is embedded and used for vector search.
            datasource: restricts the search and selects the connection used
                to describe the tables.

        Returns:
            The text produced by ``build_schema_context`` (possibly cached).
        """
        # Key on the exact question text: hash(question) can collide, and a
        # collision would silently serve another question's schema.
        cache_key = (datasource, question)
        # One get() instead of `in` + `[]`: a TTL entry can expire between the
        # two calls and make the indexing raise KeyError.
        cached = self.schema_cache.get(cache_key)
        if cached is not None:
            return cached

        # Vector search for relevant tables within this datasource (top_k=10).
        question_embedding = await self.embed_text(question)
        relevant_tables = await self.vector_store.search(
            embedding=question_embedding,
            filter={'datasource': datasource},
            top_k=10
        )

        schema_context = await self.build_schema_context(
            relevant_tables,
            datasource
        )
        self.schema_cache[cache_key] = schema_context
        return schema_context

    async def build_schema_context(self, tables, datasource):
        """Render ``tables`` as text sections joined by ``---`` separator lines.

        Each section lists the table name, description, columns, foreign-key
        relationships and sample rows (limit=3). Formatting is done by the
        ``format_*`` helpers.
        """
        db = self.db_connections[datasource]
        schema_parts = []
        for table in tables:
            # Awaited one after another. NOTE(review): keep this sequential
            # unless the connection is known to allow concurrent queries.
            columns = await db.get_columns(table.name)
            foreign_keys = await db.get_foreign_keys(table.name)
            samples = await db.get_sample_data(
                table.name,
                limit=3
            )
            # A missing description would otherwise show up as the literal
            # text "None" in the LLM prompt.
            description = table.description or ""
            schema_parts.append("\n".join([
                "",
                f"表名:{table.name}",
                f"描述:{description}",
                "列:",
                f"{self.format_columns(columns)}",
                "关系:",
                f"{self.format_foreign_keys(foreign_keys)}",
                "示例数据:",
                f"{self.format_samples(samples)}",
                "",
            ]))
        return "\n---\n".join(schema_parts)
SQL生成服务
class SQLGenerationService:
    """Turn a natural-language question into validated SQL using an LLM.

    Results are memoized in an LRU cache keyed by ``build_cache_key``.
    """

    def __init__(self, llm_client, schema_service):
        self.llm = llm_client
        self.schema_service = schema_service
        self.sql_cache = LRUCache(maxsize=10000)

    async def generate_sql(
        self,
        question: str,
        datasource: str,
        context: dict = None
    ):
        """Generate (or fetch from cache) validated SQL for ``question``.

        Args:
            question: user question in natural language.
            datasource: datasource used for schema lookup and validation.
            context: optional dict; the last 3 items of ``context['history']``
                (each with ``'question'`` and ``'sql'`` keys) go into the prompt.

        Returns:
            The SQL returned by ``validate_sql``.
        """
        cache_key = self.build_cache_key(
            question, datasource, context
        )
        if cache_key in self.sql_cache:
            return self.sql_cache[cache_key]

        schema = await self.schema_service.get_relevant_schema(
            question, datasource
        )
        prompt = self.build_nl2sql_prompt(
            question, schema, context
        )
        # Generate SQL, retrying on syntax errors.
        sql = await self.generate_with_retry(prompt)
        validated_sql = await self.validate_sql(sql, datasource)

        self.sql_cache[cache_key] = validated_sql
        return validated_sql

    @staticmethod
    def _strip_code_fence(text):
        """Remove a surrounding Markdown code fence (``` or ```sql) from ``text``.

        The prompt ends with an opened ```sql fence, so completions commonly
        end with a closing ``` (or repeat the whole fence). Non-string values
        are returned unchanged.
        """
        if not isinstance(text, str):
            return text
        stripped = text.strip()
        if stripped.startswith("```"):
            # Drop the opening fence line (``` or ```sql).
            newline = stripped.find("\n")
            stripped = "" if newline == -1 else stripped[newline + 1:]
        if stripped.endswith("```"):
            stripped = stripped[:-3]
        return stripped.strip()

    async def generate_with_retry(
        self,
        prompt: str,
        max_retries: int = 3
    ):
        """Ask the LLM for SQL, feeding syntax errors back on each retry.

        Args:
            prompt: the NL2SQL prompt.
            max_retries: total number of attempts; must be >= 1.

        Returns:
            The SQL text (code fence removed) that passed ``check_syntax``.

        Raises:
            ValueError: if ``max_retries`` < 1; otherwise the loop would run
                zero times and return None.
            SQLSyntaxError: re-raised when the last attempt is still invalid.
        """
        if max_retries < 1:
            raise ValueError(f"max_retries must be >= 1, got {max_retries}")
        for attempt in range(max_retries):
            try:
                sql = self._strip_code_fence(await self.llm.generate(prompt))
                # Basic syntax check
                self.check_syntax(sql)
                return sql
            except SQLSyntaxError as e:
                if attempt == max_retries - 1:
                    raise
                # Feed the error back into the prompt for the next attempt.
                # NOTE(review): this text lands after the prompt's opened ```sql
                # fence; consider restructuring the retry prompt.
                prompt += f"\n\n上次生成的SQL有错误:{e}\n请修正并重新生成。"

    def build_nl2sql_prompt(self, question, schema, context):
        """Build the NL2SQL prompt.

        Includes the schema, up to the last 3 (question, SQL) pairs from
        ``context['history']``, the question and the generation rules. The
        prompt ends with an opened ```sql fence.
        """
        context_str = ""
        if context and context.get('history'):
            recent_history = context['history'][-3:]
            context_str = "\n".join(
                f"Q: {h['question']}\nSQL: {h['sql']}"
                for h in recent_history
            )

        return f"""
你是一个SQL专家。请根据以下信息生成准确的SQL查询。
# 数据库Schema
{schema}
# 对话历史
{context_str}
# 用户问题
{question}
# 要求
1. 只返回SQL语句,不要有任何解释
2. 使用标准SQL语法
3. 添加适当的WHERE条件进行数据过滤
4. 如果需要聚合,使用GROUP BY
5. 考虑性能优化,添加LIMIT
6. 使用表别名简化查询
7. 注意日期格式和时区
# SQL查询
```sql
"""
SQL验证与优化
class SQLValidator:
    """Security, syntax and cost checks applied to generated SQL."""

    def __init__(self, db_connections):
        self.db_connections = db_connections
        # Statements that write data or change schema; matched as whole words.
        self.blocked_keywords = [
            'DROP', 'DELETE', 'TRUNCATE',
            'UPDATE', 'INSERT', 'ALTER'
        ]

    async def validate(self, sql: str, datasource: str):
        """Check ``sql`` and return it, possibly rewritten by ``optimize_sql``.

        Raises:
            SecurityError: from ``check_security``.
        """
        # 1. Security check
        self.check_security(sql)
        # 2. Syntax validation
        await self.check_syntax(sql, datasource)
        # 3. Performance check
        estimated_cost = await self.estimate_cost(sql, datasource)
        if estimated_cost > 1000:  # cost threshold
            sql = self.optimize_sql(sql)
            # The rewritten statement is what will run, so vet it again.
            self.check_security(sql)
        return sql

    def check_security(self, sql: str):
        """Raise SecurityError if ``sql`` contains blocked keywords or suspicious patterns.

        Matching is case-insensitive (the SQL is upper-cased first). This is a
        conservative heuristic: the patterns also match inside string
        literals, so e.g. a literal containing ';' or '--' is rejected too.
        """
        sql_upper = sql.upper()

        for keyword in self.blocked_keywords:
            # \b boundaries: reject the keyword DELETE but not identifiers
            # such as is_deleted / updated_at (letters and '_' are word chars).
            if re.search(rf"\b{re.escape(keyword)}\b", sql_upper):
                raise SecurityError(
                    f"不允许使用关键字: {keyword}"
                )

        injection_patterns = [
            r";\s*\S",                     # stacked statements (trailing ';' is fine)
            r"UNION\s+(?:ALL\s+)?SELECT",  # UNION / UNION ALL exfiltration
            r"--",                         # line comment
            r"/\*.*?\*/",                  # block comment, possibly multi-line
        ]
        for pattern in injection_patterns:
            # DOTALL lets '.' cross newlines so multi-line /* ... */ is caught.
            if re.search(pattern, sql_upper, re.DOTALL):
                raise SecurityError("检测到可疑的SQL模式")

    async def estimate_cost(self, sql: str, datasource: str):
        """Estimate query cost from the EXPLAIN plan of ``sql``."""
        db = self.db_connections[datasource]
        explain_result = await db.execute(f"EXPLAIN {sql}")
        return self.calculate_cost_from_plan(explain_result)
3. 查询执行服务
异步查询执行器
class QueryExecutor:
    """Execute SQL on pooled connections, with a timeout and a result cache.

    Results are cached as DataFrames (ttl=3600 via ``result_cache.set``),
    keyed by an MD5 of ``"<datasource>:<sql>"``.
    """

    def __init__(self, db_pool, result_cache):
        self.db_pool = db_pool
        self.result_cache = result_cache
        # NOTE(review): neither query_queue nor max_concurrent is read by
        # execute(), so this class does not enforce a concurrency limit.
        self.query_queue = asyncio.Queue()
        self.max_concurrent = 10

    @staticmethod
    def _to_dataframe(rows):
        """Build a DataFrame, keeping column names for mapping-like rows.

        pandas does not treat row objects such as asyncpg Records as mappings
        and would number the columns 0..n-1. Converting rows that expose
        ``keys()`` to dicts preserves the real column names.
        """
        if rows and hasattr(rows[0], "keys"):
            return pd.DataFrame([dict(row) for row in rows])
        return pd.DataFrame(rows)

    @staticmethod
    async def _reset_statement_timeout(conn):
        """Best-effort RESET so the pooled connection does not keep this query's timeout."""
        try:
            await conn.execute("RESET statement_timeout")
        except Exception as e:
            # Do not let a failed reset mask the query's own outcome.
            logger.warning(f"重置 statement_timeout 失败: {e}")

    async def execute(
        self,
        sql: str,
        datasource: str,
        timeout: int = 30
    ):
        """Run ``sql`` on ``datasource`` and return the rows as a DataFrame.

        Args:
            sql: query text (expected to be validated already).
            datasource: name passed to ``db_pool.acquire``.
            timeout: limit in seconds. It is enforced client-side with
                ``asyncio.wait_for`` and server-side via ``statement_timeout``
                (milliseconds).

        Returns:
            A pandas DataFrame, possibly served from the cache.

        Raises:
            QueryTimeoutError: the client-side timeout elapsed.
            QueryExecutionError: any other failure while executing.
        """
        # MD5 is only used as a cache key here, not for security.
        cache_key = hashlib.md5(
            f"{datasource}:{sql}".encode()
        ).hexdigest()
        if cache_key in self.result_cache:
            return self.result_cache[cache_key]

        async with self.db_pool.acquire(datasource) as conn:
            try:
                await conn.execute(
                    f"SET statement_timeout = {int(timeout * 1000)}"
                )
                try:
                    # wait_for is what raises asyncio.TimeoutError; without it
                    # the timeout branch below could never trigger.
                    result = await asyncio.wait_for(conn.fetch(sql), timeout=timeout)
                finally:
                    # statement_timeout is session-scoped; undo it before the
                    # connection goes back to the pool.
                    await self._reset_statement_timeout(conn)

                df = self._to_dataframe(result)
                self.result_cache.set(
                    cache_key,
                    df,
                    ttl=3600
                )
                return df
            except asyncio.TimeoutError:
                raise QueryTimeoutError(
                    f"查询超时({timeout}秒)"
                )
            except Exception as e:
                logger.error(f"查询执行失败: {e}")
                raise QueryExecutionError(str(e)) from e
结果分页
class PaginationHandler:
    """Split a sliceable sequence (list, DataFrame, ...) into fixed-size pages."""

    def __init__(self, page_size=100):
        """
        Raises:
            ValueError: if ``page_size`` < 1; it is used as a divisor below.
        """
        if page_size < 1:
            raise ValueError(f"page_size must be >= 1, got {page_size}")
        self.page_size = page_size

    def paginate(self, data, page=1):
        """Return page ``page`` (1-based) of ``data`` plus pagination metadata.

        Pages past the end yield empty data. ``data`` must support ``len()``
        and slicing.

        Raises:
            ValueError: if ``page`` < 1; a negative start index would
                otherwise silently return rows counted from the end.
        """
        if page < 1:
            raise ValueError(f"page must be >= 1, got {page}")

        total_rows = len(data)
        total_pages = math.ceil(total_rows / self.page_size)
        start_idx = (page - 1) * self.page_size
        # Slicing clamps the end to the data length, so no min() is needed.
        page_data = data[start_idx:start_idx + self.page_size]

        return {
            'data': page_data,
            'pagination': {
                'current_page': page,
                'page_size': self.page_size,
                'total_pages': total_pages,
                'total_rows': total_rows
            }
        }
4. 可视化服务架构
图表推荐引擎
class ChartRecommendationEngine:
    """Recommend a chart type from the structure of a query result.

    Rule-based recommendations are always computed. When an LLM client is
    configured and a question is given, the LLM suggestion is merged in.
    """

    def __init__(self, llm_client=None):
        self.llm = llm_client
        self.rules = self.load_chart_rules()

    def recommend(self, data, question=None):
        """Return the recommended chart for ``data``.

        Without an LLM (or question) this is the top rule-based entry, a dict
        with 'chart_type', 'confidence' and 'reason'. Otherwise it is whatever
        ``merge_recommendations`` returns.
        """
        # 1. Profile the data
        profile = self.analyze_data_profile(data)
        # 2. Rule-based recommendations (never empty; falls back to 'table')
        rule_recommendations = self.rule_based_recommend(profile)
        # 3. Optional LLM-based recommendation
        if self.llm and question:
            llm_recommendation = self.llm_based_recommend(
                profile, question
            )
            final_recommendation = self.merge_recommendations(
                rule_recommendations,
                llm_recommendation
            )
        else:
            final_recommendation = rule_recommendations[0]
        return final_recommendation

    def analyze_data_profile(self, data):
        """Profile a DataFrame's columns.

        Returns a dict with:
        - num_rows, num_columns
        - column_types: column -> 'numeric' | 'datetime' | 'categorical'
        - has_time_column and time_column (the last datetime column, only
          present when there is one)
        - categorical_columns, numeric_columns
        - cardinality: distinct-value count of each categorical column
        """
        profile = {
            'num_rows': len(data),
            'num_columns': len(data.columns),
            'column_types': {},
            'has_time_column': False,
            'categorical_columns': [],
            'numeric_columns': []
        }

        for col in data.columns:
            dtype = data[col].dtype
            if pd.api.types.is_numeric_dtype(dtype):
                profile['numeric_columns'].append(col)
                profile['column_types'][col] = 'numeric'
            elif pd.api.types.is_datetime64_any_dtype(dtype):
                profile['has_time_column'] = True
                profile['time_column'] = col
                profile['column_types'][col] = 'datetime'
            else:
                profile['categorical_columns'].append(col)
                profile['column_types'][col] = 'categorical'

        profile['cardinality'] = {
            col: data[col].nunique()
            for col in profile['categorical_columns']
        }
        return profile

    def rule_based_recommend(self, profile):
        """Return chart candidates sorted by confidence (highest first).

        Line, bar and pie all need at least one numeric column to plot. A
        single 'table' entry is returned when no rule applies.
        """
        recommendations = []
        has_numeric = len(profile['numeric_columns']) >= 1

        # Time series -> line chart
        if profile['has_time_column'] and has_numeric:
            recommendations.append({
                'chart_type': 'line',
                'confidence': 0.9,
                'reason': '数据包含时间维度'
            })

        # Category comparison -> bar chart (first categorical column, <= 20 values)
        if len(profile['categorical_columns']) >= 1 and has_numeric:
            cat_col = profile['categorical_columns'][0]
            if profile['cardinality'][cat_col] <= 20:
                recommendations.append({
                    'chart_type': 'bar',
                    'confidence': 0.85,
                    'reason': '适合分类对比'
                })

        # Share of total -> pie chart (one categorical column, <= 7 values)
        if (len(profile['categorical_columns']) == 1 and has_numeric and
                profile['cardinality'][profile['categorical_columns'][0]] <= 7):
            recommendations.append({
                'chart_type': 'pie',
                'confidence': 0.8,
                'reason': '适合展示占比'
            })

        # Correlation -> scatter plot
        if len(profile['numeric_columns']) >= 2:
            recommendations.append({
                'chart_type': 'scatter',
                'confidence': 0.75,
                'reason': '可以展示相关性'
            })

        recommendations.sort(
            key=lambda x: x['confidence'],
            reverse=True
        )
        return recommendations if recommendations else [{
            'chart_type': 'table',
            'confidence': 1.0,
            'reason': '默认表格展示'
        }]
图表配置生成器
class ChartConfigGenerator:
    """Build ECharts option dicts from a DataFrame and a chart type."""

    def generate_echarts_config(
        self,
        data,
        chart_type,
        options=None
    ):
        """Dispatch to the builder for ``chart_type``.

        Handles 'line', 'bar', 'pie' and 'scatter'; any other value falls back
        to the table config.

        Args:
            data: DataFrame holding the query result.
            chart_type: chart type string.
            options: optional dict of display options; None is treated as {}.
        """
        # Builders read options with .get(), which fails on the None default.
        options = {} if options is None else options
        if chart_type == 'line':
            return self.generate_line_config(data, options)
        elif chart_type == 'bar':
            return self.generate_bar_config(data, options)
        elif chart_type == 'pie':
            return self.generate_pie_config(data, options)
        elif chart_type == 'scatter':
            return self.generate_scatter_config(data, options)
        else:
            return self.generate_table_config(data, options)

    def generate_line_config(self, data, options):
        """Build a line-chart config.

        The x axis is the column chosen by ``identify_time_column``. There is
        one smoothed series per column from ``identify_numeric_columns``. The
        title comes from ``options['title']`` (default '趋势分析'). ``options``
        may be None.
        """
        options = {} if options is None else options
        time_col = self.identify_time_column(data)
        value_cols = self.identify_numeric_columns(data)

        return {
            'title': {
                'text': options.get('title', '趋势分析')
            },
            'tooltip': {
                'trigger': 'axis'
            },
            'legend': {
                'data': value_cols
            },
            'xAxis': {
                'type': 'category',
                'data': data[time_col].tolist()
            },
            'yAxis': {
                'type': 'value'
            },
            'series': [
                {
                    'name': col,
                    'type': 'line',
                    'data': data[col].tolist(),
                    'smooth': True
                }
                for col in value_cols
            ]
        }
5. 缓存架构
多级缓存策略
class MultiLevelCache:
    """Read-through cache with three tiers: in-process LRU, Redis, precomputed.

    A hit in a lower tier is copied into the faster tiers above it.
    """

    def __init__(self):
        # L1: in-process LRU of recently used entries (maxsize=100).
        # NOTE(review): L1 entries carry no TTL, so they can outlive the Redis
        # expiry given to set(); consider a TTL-aware cache here.
        self.memory_cache = LRUCache(maxsize=100)
        # L2: Redis for hot data. get/set are awaited below, so this must be
        # an asyncio-capable client — NOTE(review): confirm which Redis is imported.
        self.redis_cache = Redis()
        # L3: precomputed results for common queries.
        self.precomputed_cache = PrecomputedCache()

    async def get(self, key):
        """Return the cached value for ``key``, or None if no tier has it."""
        if key in self.memory_cache:
            return self.memory_cache[key]

        # Compare with None, not truthiness: 0, '', b'' and empty containers
        # are legitimate cached values, not misses.
        redis_value = await self.redis_cache.get(key)
        if redis_value is not None:
            self.memory_cache[key] = redis_value  # backfill L1
            return redis_value

        precomputed = await self.precomputed_cache.get(key)
        if precomputed is not None:
            # Backfill L1 and L2 (L2 with a fixed ex=3600).
            self.memory_cache[key] = precomputed
            await self.redis_cache.set(key, precomputed, ex=3600)
            return precomputed

        return None

    async def set(self, key, value, ttl=3600):
        """Write ``value`` to L1 and L2 (``ttl`` is passed to Redis as ``ex``); L3 is not written."""
        self.memory_cache[key] = value
        await self.redis_cache.set(key, value, ex=ttl)
6. 监控与日志
性能监控
class PerformanceMonitor:
    """Publish per-query latency and outcome metrics."""

    def __init__(self):
        self.metrics = PrometheusMetrics()

    def track_query(self, query_info):
        """Record the metrics for one finished query.

        Emits three histograms (overall latency, labelled by datasource and
        chart type; NL2SQL latency; SQL execution latency) and increments the
        query counter labelled 'success' or 'error'.
        """
        metrics = self.metrics

        overall_labels = {
            'datasource': query_info.datasource,
            'chart_type': query_info.chart_type
        }
        metrics.histogram(
            'chat2bi_query_latency',
            query_info.total_time,
            labels=overall_labels
        )

        # Per-stage latencies carry no labels.
        stage_latencies = (
            ('chat2bi_nl2sql_latency', query_info.nl2sql_time),
            ('chat2bi_sql_execution_latency', query_info.sql_execution_time),
        )
        for metric_name, value in stage_latencies:
            metrics.histogram(metric_name, value)

        outcome = 'success' if query_info.success else 'error'
        metrics.counter('chat2bi_queries_total', labels={'status': outcome})
审计日志
class AuditLogger:
    """Write one audit_logs row per Chat2BI query."""

    def __init__(self, db_connection):
        self.db = db_connection

    async def log_query(self, log_entry):
        """Insert an audit record.

        Args:
            log_entry: dict with the required keys 'user_id', 'question', 'sql'
                and 'timestamp'. The keys 'datasource', 'execution_time' and
                'result_rows' are optional and are bound as None when absent
                (NOTE(review): assumes those columns are nullable — confirm).

        Raises:
            KeyError: if a required key is missing.
        """
        # Values are bound as $1..$7 parameters, never interpolated into the SQL.
        await self.db.execute(
            """
            INSERT INTO audit_logs (
                user_id,
                question,
                generated_sql,
                datasource,
                execution_time,
                result_rows,
                timestamp
            ) VALUES ($1, $2, $3, $4, $5, $6, $7)
            """,
            log_entry['user_id'],
            log_entry['question'],
            log_entry['sql'],
            log_entry.get('datasource'),
            log_entry.get('execution_time'),
            log_entry.get('result_rows'),
            log_entry['timestamp']
        )
7. 扩展性设计
数据源插件架构
class DataSourcePlugin:
    """Base class for datasource plugins.

    Keeps the configuration it was created with. Subclasses override the
    async hooks below; each default implementation raises
    NotImplementedError when awaited.
    """

    def __init__(self, config):
        # Configuration object handed in by the caller, stored as-is.
        self.config = config

    async def connect(self):
        """Open the connection to the datasource (subclass hook)."""
        raise NotImplementedError()

    async def execute_query(self, sql):
        """Run ``sql`` and return its rows (subclass hook)."""
        raise NotImplementedError()

    async def get_schema(self):
        """Return the datasource schema (subclass hook)."""
        raise NotImplementedError()
class MySQLPlugin(DataSourcePlugin):
    """MySQL datasource plugin using aiomysql."""

    async def connect(self):
        """Open an aiomysql connection from ``self.config`` minus the 'type' key."""
        # DataSourceManager hands over the whole datasource config, including
        # its 'type' routing key. aiomysql.connect() has no such parameter and
        # would raise TypeError, so strip it before forwarding.
        connect_kwargs = {k: v for k, v in self.config.items() if k != 'type'}
        self.conn = await aiomysql.connect(**connect_kwargs)

    async def execute_query(self, sql):
        """Execute ``sql`` on a cursor and return ``fetchall()``'s rows."""
        async with self.conn.cursor() as cursor:
            await cursor.execute(sql)
            return await cursor.fetchall()
class PostgreSQLPlugin(DataSourcePlugin):
    """PostgreSQL datasource plugin using asyncpg."""

    async def connect(self):
        """Open an asyncpg connection from ``self.config`` minus the 'type' key."""
        # DataSourceManager hands over the whole datasource config, including
        # its 'type' routing key. asyncpg.connect() has no such parameter and
        # would raise TypeError, so strip it before forwarding.
        connect_kwargs = {k: v for k, v in self.config.items() if k != 'type'}
        self.conn = await asyncpg.connect(**connect_kwargs)

    async def execute_query(self, sql):
        """Run ``sql`` and return the rows from ``Connection.fetch``."""
        return await self.conn.fetch(sql)
# Datasource manager
class DataSourceManager:
    """Registry mapping a datasource type name to its plugin class."""

    def __init__(self):
        self.plugins = {}

    def register_plugin(self, name, plugin_class):
        """Register ``plugin_class`` under ``name``; re-registering replaces it."""
        self.plugins[name] = plugin_class

    async def get_connection(self, datasource_config):
        """Instantiate and connect the plugin selected by ``datasource_config['type']``.

        The full config (including 'type') is passed to the plugin constructor.

        Returns:
            The connected plugin instance.

        Raises:
            KeyError: if 'type' is missing or no plugin is registered for it.
                The message lists the registered types.
        """
        ds_type = datasource_config.get('type')
        try:
            plugin_class = self.plugins[ds_type]
        except KeyError:
            raise KeyError(
                f"no datasource plugin registered for type {ds_type!r}; "
                f"registered types: {list(self.plugins)}"
            ) from None
        plugin = plugin_class(datasource_config)
        await plugin.connect()
        return plugin
8. 安全架构
数据权限控制
class DataAccessControl:
    """Per-user table allow/deny checks and row-level filter injection."""

    def __init__(self, permission_db):
        self.permission_db = permission_db

    async def check_table_access(self, user_id, table_name):
        """Return True if ``user_id`` may access ``table_name``.

        Deny rules win. If an allow-list exists, the table must be on it.
        Otherwise access is allowed. Names are compared case-insensitively.
        """
        permissions = await self.permission_db.get_user_permissions(
            user_id
        )
        # Unquoted identifiers are case-insensitive in many engines (e.g.
        # PostgreSQL folds them to lower case). An exact-case comparison
        # would let `ORDERS` slip past a deny rule for `orders`.
        # NOTE(review): schema-qualified names (e.g. public.orders) are still
        # compared literally — confirm callers normalize them.
        wanted = table_name.casefold()
        allowed_tables = {t.casefold() for t in permissions.get('allowed_tables') or ()}
        denied_tables = {t.casefold() for t in permissions.get('denied_tables') or ()}

        # Deny list takes precedence
        if wanted in denied_tables:
            return False
        # With an allow-list, only listed tables are permitted
        if allowed_tables:
            return wanted in allowed_tables
        # NOTE(review): fail-open default — without an allow-list every table
        # not explicitly denied is accessible. Confirm this is intended.
        return True

    async def apply_row_level_security(
        self,
        sql,
        user_id
    ):
        """Return ``sql`` with the user's row filters injected, or unchanged if there are none."""
        user_filters = await self.permission_db.get_user_filters(
            user_id
        )
        if user_filters:
            return self.inject_filters(sql, user_filters)
        return sql
总结
一个生产级的 Chat2BI 系统需要:
✅ 分层架构:清晰的职责划分
✅ 高性能:多级缓存、异步执行
✅ 可扩展:插件化数据源、灵活配置
✅ 安全可靠:权限控制、SQL验证
✅ 可观测:完善的监控和日志
✅ 智能推荐:自动图表选择和洞察生成
通过合理的架构设计,可以构建出高效、安全、易用的 Chat2BI 系统。