04. 上下文增强 RAG
在传统 RAG 中,我们通常只检索最相关的文本块。然而,有时候单个文本块可能缺乏足够的上下文信息来回答问题。上下文增强 RAG通过检索相关文本块的邻近块来提供更丰富的上下文信息。
核心思想
传统 RAG 检索过程:
文档: [块A] [块B] [块C] [块D] [块E]
查询: "什么是机器学习?"
检索结果: [块C] (最相关)
上下文增强 RAG 检索过程:
文档: [块A] [块B] [块C] [块D] [块E]
查询: "什么是机器学习?"
检索结果: [块B] [块C] [块D] (块C最相关,加上前后邻居)
技术优势
🎯 更完整的上下文
- 提供更多相关信息,避免信息断裂
- 保持文档的逻辑连贯性
📈 提高回答质量
- 减少因上下文不足导致的回答不完整
- 增强语言模型对复杂问题的理解能力
🔄 保持文本连贯性
- 维护原文档的叙述流畅性
- 避免孤立文本块造成的理解偏差
完整代码实现
import fitz
import os
import numpy as np
import json
from openai import OpenAI
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
def extract_text_from_pdf(pdf_path):
"""
从PDF文件中提取文本内容
Args:
pdf_path (str): PDF文件路径
Returns:
str: 提取的文本内容
"""
mypdf = fitz.open(pdf_path)
all_text = ""
for page_num in range(mypdf.page_count):
page = mypdf[page_num]
text = page.get_text("text")
all_text += text
return all_text
def chunk_text(text, n, overlap):
"""
将文本分割成重叠的文本块
Args:
text (str): 要分割的文本
n (int): 每个文本块的字符数
overlap (int): 重叠字符数
Returns:
List[str]: 文本块列表
"""
chunks = []
for i in range(0, len(text), n - overlap):
chunks.append(text[i:i + n])
return chunks
# 初始化OpenAI客户端
client = OpenAI(
base_url="https://openrouter.ai/api/v1",
api_key=os.getenv("OPENROUTER_API_KEY")
)
def create_embeddings(text, model="BAAI/bge-base-en-v1.5"):
"""
为给定文本创建嵌入向量
Args:
text (str): 输入文本
model (str): 嵌入模型名称
Returns:
List[float]: 嵌入向量
"""
embedding_model = HuggingFaceEmbedding(model_name=model)
if isinstance(text, list):
response = embedding_model.get_text_embedding_batch(text)
else:
response = embedding_model.get_text_embedding(text)
return response
def cosine_similarity(vec1, vec2):
"""
计算两个向量的余弦相似度
Args:
vec1 (np.ndarray): 第一个向量
vec2 (np.ndarray): 第二个向量
Returns:
float: 余弦相似度值
"""
return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
def context_enriched_search(query, text_chunks, embeddings, k=1, context_size=1):
"""
执行上下文增强检索
Args:
query (str): 查询问题
text_chunks (List[str]): 文本块列表
embeddings (List): 嵌入向量列表
k (int): 检索的相关块数量
context_size (int): 上下文邻居块数量
Returns:
List[str]: 包含上下文的相关文本块
"""
# 将查询转换为嵌入向量
query_embedding = create_embeddings(query)
similarity_scores = []
# 计算查询与每个文本块的相似度
for i, chunk_embedding in enumerate(embeddings):
similarity_score = cosine_similarity(
np.array(query_embedding),
np.array(chunk_embedding)
)
similarity_scores.append((i, similarity_score))
# 按相似度降序排序
similarity_scores.sort(key=lambda x: x[1], reverse=True)
# 获取最相关块的索引
top_index = similarity_scores[0][0]
print(f'最相关块索引: {top_index}')
# 确定上下文范围,确保不超出边界
start = max(0, top_index - context_size)
end = min(len(text_chunks), top_index + context_size + 1)
# 返回相关块及其邻近上下文
return [text_chunks[i] for i in range(start, end)]
def generate_response(system_prompt, user_message, model="meta-llama/Llama-3.2-3B-Instruct"):
"""
生成AI回答
Args:
system_prompt (str): 系统提示词
user_message (str): 用户消息
model (str): 使用的模型
Returns:
str: AI生成的回答
"""
response = client.chat.completions.create(
model=model,
temperature=0,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_message}
]
)
return response.choices[0].message.content
实际应用示例
# 1. 文档处理
pdf_path = "data/AI_Information.pdf"
extracted_text = extract_text_from_pdf(pdf_path)
# 2. 文本分块
text_chunks = chunk_text(extracted_text, 1000, 200)
print(f"创建了 {len(text_chunks)} 个文本块")
# 3. 创建嵌入向量
embeddings = create_embeddings(text_chunks)
# 4. 加载测试查询
with open('data/val.json') as f:
data = json.load(f)
query = data[0]['question']
print(f"查询: {query}")
# 5. 执行上下文增强检索
# context_size=1 表示包含前后各1个邻居块
top_chunks = context_enriched_search(
query,
text_chunks,
embeddings,
k=1,
context_size=1
)
print(f"检索到 {len(top_chunks)} 个上下文块")
# 6. 显示检索结果
for i, chunk in enumerate(top_chunks):
print(f"上下文 {i + 1}:\n{chunk}\n" + "="*50)
# 7. 生成最终回答
system_prompt = "你是一个AI助手,严格基于给定的上下文回答问题。如果无法从提供的上下文中得出答案,请回答:'我没有足够的信息来回答这个问题。'"
# 组合上下文
context = "\n\n".join([f"上下文{i+1}: {chunk}" for i, chunk in enumerate(top_chunks)])
user_message = f"上下文:\n{context}\n\n问题: {query}"
response = generate_response(system_prompt, user_message)
print(f"AI回答: {response}")
关键技术解析
1. 邻居块选择策略
# 确定上下文范围的核心逻辑
start = max(0, top_index - context_size) # 前邻居边界
end = min(len(text_chunks), top_index + context_size + 1) # 后邻居边界
# 示例:如果最相关块是索引5,context_size=2
# start = max(0, 5-2) = 3
# end = min(total_chunks, 5+2+1) = 8
# 结果:返回索引[3,4,5,6,7]的块
2. 上下文大小调优
不同的context_size
会产生不同的效果:
- context_size=0: 等同于传统 RAG,只返回最相关块
- context_size=1: 返回相关块及前后各 1 个邻居(推荐)
- context_size=2: 返回相关块及前后各 2 个邻居
- 过大的 context_size: 可能引入噪声信息
3. 边界处理
# 处理文档开头和结尾的边界情况
start = max(0, top_index - context_size) # 防止负索引
end = min(len(text_chunks), top_index + 1 + context_size) # 防止超出范围
效果对比
传统 RAG vs 上下文增强 RAG
查询: "什么是深度学习的反向传播算法?"
传统 RAG 结果:
只返回最相关的1个块,可能只包含反向传播的定义,
缺乏算法步骤和实际应用的详细说明。
上下文增强 RAG 结果:
返回3个连续块:
1. 深度学习基础概念(背景)
2. 反向传播算法详解(核心)
3. 算法实现和应用(扩展)
提供更完整、连贯的信息。
参数调优建议
1. context_size 选择
# 根据文档类型调整
academic_papers = 1-2 # 学术论文,逻辑严密
technical_docs = 1 # 技术文档,条理清晰
narrative_text = 2-3 # 叙述性文本,连贯性强
2. 块大小优化
# 块大小与上下文大小的平衡
chunk_size = 1000 # 基础块大小
context_size = 1 # 总上下文 ≈ 3000字符
# 确保总上下文不超过模型限制
最佳实践
✅ 推荐做法
- 合理设置上下文大小: 从 1 开始测试,根据效果调整
- 考虑文档结构: 对于结构化文档,保持逻辑完整性
- 监控性能: 更多上下文意味着更高的计算成本
- 评估质量: 定期评估上下文是否提升了回答质量
❌ 避免问题
- 过大的上下文: 可能引入无关信息
- 忽略边界: 未处理文档开头结尾的特殊情况
- 固定参数: 未根据不同类型文档调整参数
- 缺乏评估: 未验证上下文增强的实际效果
扩展应用
上下文增强 RAG 可以进一步扩展为:
- 动态上下文大小: 根据查询复杂度动态调整
- 语义上下文: 不仅考虑位置邻近,还考虑语义相关性
- 多模态上下文: 结合图像、表格等多模态信息
- 递归上下文: 逐步扩展上下文范围直到找到满意答案
上下文增强 RAG 是改善传统 RAG 系统的重要技术,通过简单的邻居块包含策略就能显著提升回答的完整性和准确性。