Skip to main content

04. 上下文增强 RAG

在传统 RAG 中,我们通常只检索最相关的文本块。然而,有时候单个文本块可能缺乏足够的上下文信息来回答问题。上下文增强 RAG通过检索相关文本块的邻近块来提供更丰富的上下文信息。

核心思想

传统 RAG 检索过程:

文档: [块A] [块B] [块C] [块D] [块E]
查询: "什么是机器学习?"
检索结果: [块C] (最相关)

上下文增强 RAG 检索过程:

文档: [块A] [块B] [块C] [块D] [块E]
查询: "什么是机器学习?"
检索结果: [块B] [块C] [块D] (块C最相关,加上前后邻居)

技术优势

🎯 更完整的上下文

  • 提供更多相关信息,避免信息断裂
  • 保持文档的逻辑连贯性

📈 提高回答质量

  • 减少因上下文不足导致的回答不完整
  • 增强语言模型对复杂问题的理解能力

🔄 保持文本连贯性

  • 维护原文档的叙述流畅性
  • 避免孤立文本块造成的理解偏差

完整代码实现

import fitz
import os
import numpy as np
import json
from openai import OpenAI
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

def extract_text_from_pdf(pdf_path):
"""
从PDF文件中提取文本内容

Args:
pdf_path (str): PDF文件路径

Returns:
str: 提取的文本内容
"""
mypdf = fitz.open(pdf_path)
all_text = ""

for page_num in range(mypdf.page_count):
page = mypdf[page_num]
text = page.get_text("text")
all_text += text

return all_text

def chunk_text(text, n, overlap):
"""
将文本分割成重叠的文本块

Args:
text (str): 要分割的文本
n (int): 每个文本块的字符数
overlap (int): 重叠字符数

Returns:
List[str]: 文本块列表
"""
chunks = []

for i in range(0, len(text), n - overlap):
chunks.append(text[i:i + n])

return chunks

# 初始化OpenAI客户端
client = OpenAI(
base_url="https://openrouter.ai/api/v1",
api_key=os.getenv("OPENROUTER_API_KEY")
)

def create_embeddings(text, model="BAAI/bge-base-en-v1.5"):
"""
为给定文本创建嵌入向量

Args:
text (str): 输入文本
model (str): 嵌入模型名称

Returns:
List[float]: 嵌入向量
"""
embedding_model = HuggingFaceEmbedding(model_name=model)

if isinstance(text, list):
response = embedding_model.get_text_embedding_batch(text)
else:
response = embedding_model.get_text_embedding(text)

return response

def cosine_similarity(vec1, vec2):
"""
计算两个向量的余弦相似度

Args:
vec1 (np.ndarray): 第一个向量
vec2 (np.ndarray): 第二个向量

Returns:
float: 余弦相似度值
"""
return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))

def context_enriched_search(query, text_chunks, embeddings, k=1, context_size=1):
"""
执行上下文增强检索

Args:
query (str): 查询问题
text_chunks (List[str]): 文本块列表
embeddings (List): 嵌入向量列表
k (int): 检索的相关块数量
context_size (int): 上下文邻居块数量

Returns:
List[str]: 包含上下文的相关文本块
"""
# 将查询转换为嵌入向量
query_embedding = create_embeddings(query)
similarity_scores = []

# 计算查询与每个文本块的相似度
for i, chunk_embedding in enumerate(embeddings):
similarity_score = cosine_similarity(
np.array(query_embedding),
np.array(chunk_embedding)
)
similarity_scores.append((i, similarity_score))

# 按相似度降序排序
similarity_scores.sort(key=lambda x: x[1], reverse=True)

# 获取最相关块的索引
top_index = similarity_scores[0][0]
print(f'最相关块索引: {top_index}')

# 确定上下文范围,确保不超出边界
start = max(0, top_index - context_size)
end = min(len(text_chunks), top_index + context_size + 1)

# 返回相关块及其邻近上下文
return [text_chunks[i] for i in range(start, end)]

def generate_response(system_prompt, user_message, model="meta-llama/Llama-3.2-3B-Instruct"):
"""
生成AI回答

Args:
system_prompt (str): 系统提示词
user_message (str): 用户消息
model (str): 使用的模型

Returns:
str: AI生成的回答
"""
response = client.chat.completions.create(
model=model,
temperature=0,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_message}
]
)

return response.choices[0].message.content

实际应用示例

# 1. 文档处理
pdf_path = "data/AI_Information.pdf"
extracted_text = extract_text_from_pdf(pdf_path)

# 2. 文本分块
text_chunks = chunk_text(extracted_text, 1000, 200)
print(f"创建了 {len(text_chunks)} 个文本块")

# 3. 创建嵌入向量
embeddings = create_embeddings(text_chunks)

# 4. 加载测试查询
with open('data/val.json') as f:
data = json.load(f)

query = data[0]['question']
print(f"查询: {query}")

# 5. 执行上下文增强检索
# context_size=1 表示包含前后各1个邻居块
top_chunks = context_enriched_search(
query,
text_chunks,
embeddings,
k=1,
context_size=1
)

print(f"检索到 {len(top_chunks)} 个上下文块")

# 6. 显示检索结果
for i, chunk in enumerate(top_chunks):
print(f"上下文 {i + 1}:\n{chunk}\n" + "="*50)

# 7. 生成最终回答
system_prompt = "你是一个AI助手,严格基于给定的上下文回答问题。如果无法从提供的上下文中得出答案,请回答:'我没有足够的信息来回答这个问题。'"

# 组合上下文
context = "\n\n".join([f"上下文{i+1}: {chunk}" for i, chunk in enumerate(top_chunks)])
user_message = f"上下文:\n{context}\n\n问题: {query}"

response = generate_response(system_prompt, user_message)
print(f"AI回答: {response}")

关键技术解析

1. 邻居块选择策略

# 确定上下文范围的核心逻辑
start = max(0, top_index - context_size) # 前邻居边界
end = min(len(text_chunks), top_index + context_size + 1) # 后邻居边界

# 示例:如果最相关块是索引5,context_size=2
# start = max(0, 5-2) = 3
# end = min(total_chunks, 5+2+1) = 8
# 结果:返回索引[3,4,5,6,7]的块

2. 上下文大小调优

不同的context_size会产生不同的效果:

  • context_size=0: 等同于传统 RAG,只返回最相关块
  • context_size=1: 返回相关块及前后各 1 个邻居(推荐)
  • context_size=2: 返回相关块及前后各 2 个邻居
  • 过大的 context_size: 可能引入噪声信息

3. 边界处理

# 处理文档开头和结尾的边界情况
start = max(0, top_index - context_size) # 防止负索引
end = min(len(text_chunks), top_index + 1 + context_size) # 防止超出范围

效果对比

传统 RAG vs 上下文增强 RAG

查询: "什么是深度学习的反向传播算法?"

传统 RAG 结果:

只返回最相关的1个块,可能只包含反向传播的定义,
缺乏算法步骤和实际应用的详细说明。

上下文增强 RAG 结果:

返回3个连续块:
1. 深度学习基础概念(背景)
2. 反向传播算法详解(核心)
3. 算法实现和应用(扩展)

提供更完整、连贯的信息。

参数调优建议

1. context_size 选择

# 根据文档类型调整
academic_papers = 1-2 # 学术论文,逻辑严密
technical_docs = 1 # 技术文档,条理清晰
narrative_text = 2-3 # 叙述性文本,连贯性强

2. 块大小优化

# 块大小与上下文大小的平衡
chunk_size = 1000 # 基础块大小
context_size = 1 # 总上下文 ≈ 3000字符
# 确保总上下文不超过模型限制

最佳实践

✅ 推荐做法

  1. 合理设置上下文大小: 从 1 开始测试,根据效果调整
  2. 考虑文档结构: 对于结构化文档,保持逻辑完整性
  3. 监控性能: 更多上下文意味着更高的计算成本
  4. 评估质量: 定期评估上下文是否提升了回答质量

❌ 避免问题

  1. 过大的上下文: 可能引入无关信息
  2. 忽略边界: 未处理文档开头结尾的特殊情况
  3. 固定参数: 未根据不同类型文档调整参数
  4. 缺乏评估: 未验证上下文增强的实际效果

扩展应用

上下文增强 RAG 可以进一步扩展为:

  1. 动态上下文大小: 根据查询复杂度动态调整
  2. 语义上下文: 不仅考虑位置邻近,还考虑语义相关性
  3. 多模态上下文: 结合图像、表格等多模态信息
  4. 递归上下文: 逐步扩展上下文范围直到找到满意答案

上下文增强 RAG 是改善传统 RAG 系统的重要技术,通过简单的邻居块包含策略就能显著提升回答的完整性和准确性。