告別機械切割:基于sentence-transformers語義分塊如何讓文本理解更智能?

傳統(tǒng)的文本分塊機制就像用尺子切割布料 —— 不管布料的花紋如何,只按固定長度下刀。這種 "一刀切" 的方式常常導(dǎo)致語義割裂:明明是一個完整的論點,卻被硬生生拆成兩半;本該分開的兩個主題,反而被塞進同一個塊里。
而語義分塊的核心思路是:讓意思相近的內(nèi)容 "抱團"。它借助文本嵌入(Embedding)技術(shù),像磁鐵一樣把語義相似的句子吸在一起,從而實現(xiàn)更符合人類理解習(xí)慣的分塊效果。本文將詳解兩種語義分塊實現(xiàn)方法,從原理到代碼帶你全面掌握這一技術(shù)。
一、方法 1:按語義流自然分割
實現(xiàn)原理
這種方法保留原文的句子順序,通過分析相鄰句子的語義關(guān)聯(lián)度,找到最適合分割的 "自然斷點"。
實現(xiàn)步驟
1. 文本加載與預(yù)處理
先把文本拆成獨立句子,為后續(xù)分析做準(zhǔn)備。
加載文本:以 Paul Graham 的 MIT 論文為例。
with open('../../data/PGEssays/mit.txt') as file:
essay = file.read()分割句子import re
single_sentences_list = re.split(r'(?<=[.?!])\s+', essay)結(jié)構(gòu)化處理:轉(zhuǎn)換成字典列表,方便后續(xù)添加嵌入、距離等信息。
sentences = [{'sentence': x, 'index': i} for i, x in enumerate(single_sentences_list)]2.句子合并:給每個句子 "加前后綴"
單個句子的語義往往不完整(比如含代詞、省略成分),因此需要結(jié)合上下文。通過滑動窗口合并相鄰句子,讓每個句子 "帶上" 前后文信息:
def combine_sentences(sentences, buffer_size=1):
for i in range(len(sentences)):
combined_sentence = ''
for j in range(i - buffer_size, i):
if j >= 0:
combined_sentence += sentences[j]['sentence'] + ' '
combined_sentence += sentences[i]['sentence']
for j in range(i + 1, i + 1 + buffer_size):
if j < len(sentences):
combined_sentence += ' ' + sentences[j]['sentence']
sentences[i]['combined_sentence'] = combined_sentence
return sentences
sentences = combine_sentences(sentences)3. 生成句子嵌入:把文字變成 "語義向量"
使用SentenceTransformer的嵌入模型all-MiniLM-L6-v2對組合后的句子進行嵌入,并將嵌入結(jié)果添加到字典列表中,這些嵌入捕捉了句子的語義信息。
model_path="sentence-transformers/all-MiniLM-L6-v2"
max_tokens = 100
cluster_threshold = 0.4
similarity_threshold = 0.4
semantic_chunker = SemanticChunker(
model_name=model_path,
max_tokens=max_tokens,
cluster_threshold=cluster_threshold,
similarity_threshold=similarity_threshold)
embeddings = semantic_chunker.get_embeddings([{'text': x['combined_sentence']} for x in sentences])
for i, sentence in enumerate(sentences):
sentence['combined_sentence_embedding'] = embeddings[i]4. 計算語義距離:找到句子間的 "隱形鴻溝"
通過余弦相似度計算相鄰句子的語義關(guān)聯(lián)度,較大的距離表示句子之間的語義變化較大,可能是一個自然的分割點。
from sklearn.metrics.pairwise import cosine_similarity
def calculate_cosine_distances(sentences):
distances = []
for i in range(len(sentences) - 1):
# 取當(dāng)前句和下一句的嵌入
embedding_current = sentences[i]['combined_sentence_embedding']
embedding_next = sentences[i + 1]['combined_sentence_embedding']
# 計算相似度(1-相似度=距離)
similarity = cosine_similarity([embedding_current], [embedding_next])[0][0]
distance = 1 - similarity
distances.append(distance)
sentences[i]['distance_to_next'] = distance
return distances, sentences5. 識別斷點:用數(shù)據(jù)可視化定位分割點
把語義距離繪制成折線圖,超過閾值的峰值就是最佳分割點:
import matplotlib.pyplot as plt
import numpy as np
plt.plot(distances)
y_upper_bound = .2
plt.ylim(0, y_upper_bound)
plt.xlim(0, len(distances))
breakpoint_percentile_threshold = 95
breakpoint_distance_threshold = np.percentile(distances, breakpoint_percentile_threshold)
plt.axhline(y=breakpoint_distance_threshold, color='r', linestyle='-')
num_distances_above_theshold = len([x for x in distances if x > breakpoint_distance_threshold])
plt.text(x=(len(distances)*.01), y=y_upper_bound/50, s=f"{num_distances_above_theshold + 1} Chunks")
indices_above_thresh = [i for i, x in enumerate(distances) if x > breakpoint_distance_threshold]
colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k']
for i, breakpoint_index in enumerate(indices_above_thresh):
start_index = 0 if i == 0 else indices_above_thresh[i - 1]
end_index = breakpoint_index if i < len(indices_above_thresh) - 1 else len(distances)
plt.axvspan(start_index, end_index, facecolor=colors[i % len(colors)], alpha=0.25)
plt.text(x=np.average([start_index, end_index]),
y=breakpoint_distance_threshold + (y_upper_bound)/ 20,
s=f"Chunk #{i}", horiznotallow='center',
rotatinotallow='vertical')
if indices_above_thresh:
last_breakpoint = indices_above_thresh[-1]
if last_breakpoint < len(distances):
plt.axvspan(last_breakpoint, len(distances), facecolor=colors[len(indices_above_thresh) % len(colors)], alpha=0.25)
plt.text(x=np.average([last_breakpoint, len(distances)]),
y=breakpoint_distance_threshold + (y_upper_bound)/ 20,
s=f"Chunk #{i+1}",
rotatinotallow='vertical')
plt.title("PG Essay Chunks Based On Embedding Breakpoints")
plt.xlabel("Index of sentences in essay (Sentence Position)")
plt.ylabel("Cosine distance between sequential sentences")
plt.show()6. 合并成塊:按斷點組裝語義完整的文本塊
根據(jù)識別出的斷點,將句子合并成長度適中的語義塊,確保每個塊的長度不超過設(shè)定的最大長度。
start_index = 0
chunks = []
for index in indices_above_thresh:
end_index = index
group = sentences[start_index:end_index + 1]
combined_text = ' '.join([d['sentence'] for d in group])
chunks.append(combined_text)
start_index = index + 1
if start_index < len(sentences):
combined_text = ' '.join([d['sentence'] for d in sentences[start_index:]])
chunks.append(combined_text)
for i, chunk in enumerate(chunks[:2]):
buffer = 200
print (f"Chunk #{i}")
print (chunk[:buffer].strip())
print ("...")
print (chunk[-buffer:].strip())
print ("\n")二、方法2:讓相似語義 "跨位置相聚"
如果說方法 1 是 "順流而下" 的自然分割,方法 2 則像 "主題拼圖"—— 不管句子在原文中的位置,只要語義相似就放到一起。
實現(xiàn)原理
通過聚類算法,把所有句子按語義相似度分組,打破原文順序限制,適合需要跨段落整合同類信息的場景(如文獻綜述、主題歸納)。
核心流程
1. 生成嵌入與相似度矩陣
與方法 1 類似,但需要計算所有句子之間的相似度(而非僅相鄰句子):
# 生成所有句子的嵌入
embeddings = np.array(model.encode([s['sentence'] for s in sentences]))
# 計算相似度矩陣(n×n,n為句子數(shù))
similarity_matrix = cosine_similarity(embeddings)2. 用并查集算法聚類
通過 Union-Find(并查集)算法,把相似度高于閾值的句子歸為一類:
def cluster_chunks(similarity_matrix, threshold=0.5):
n = similarity_matrix.shape[0]
parent = list(range(n)) # 每個句子初始為自己的"根節(jié)點"
def find(x): # 找根節(jié)點
while parent[x] != x:
parent[x] = parent[parent[x]] # 路徑壓縮
x = parent[x]
return x
def union(x, y): # 合并兩個集合
parent[find(x)] = find(y)
# 遍歷矩陣,合并相似度高的句子
for i in range(n):
for j in range(i + 1, n):
if similarity_matrix[i, j] >= threshold:
union(i, j)
# 為每個句子分配聚類ID
clusters = [find(i) for i in range(n)]
cluster_map = {cid: idx for idx, cid in enumerate(sorted(set(clusters)))}
return [cluster_map[c] for c in clusters]3. 合并聚類結(jié)果
將同一類的句子合并成塊,同時控制塊大小不超過最大token數(shù):
def merge_chunks(chunks, clusters, max_tokens=512):
cluster_map = defaultdict(list)
for idx, cluster_id in enumerate(clusters):
cluster_map[cluster_id].append(chunks[idx])
merged_chunks = []
for chunk_list in cluster_map.values():
current_text = ""
for chunk in chunk_list:
next_text = (current_text + " " + chunk["text"]).strip()
# 檢查令牌數(shù)是否超標(biāo)
if len(tokenizer.encode(next_text)) > max_tokens and current_text:
merged_chunks.append({"text": current_text})
current_text = chunk["text"]
else:
current_text = next_text
if current_text:
merged_chunks.append({"text": current_text})
return merged_chunks完整流程
- 獲取嵌入
get_embeddings方法接收一個包含文本塊的列表(每個文本塊是一個字典,包含text和其他元數(shù)據(jù)),并使用 Sentence Transformer 模型為每個文本塊生成嵌入向量。 - 計算相似性矩陣
compute_similarity方法使用cosine_similarity函數(shù)計算嵌入向量之間的相似性矩陣。相似性矩陣是一個二維數(shù)組,其中每個元素表示兩個文本塊之間的相似性。 - 聚類文本塊
cluster_chunks方法基于相似性矩陣和給定的閾值(cluster_threshold),使用并查集(Union-Find)算法將文本塊聚類。具體步驟如下:
- 初始化每個文本塊為一個獨立的集合。
- 遍歷相似性矩陣,如果兩個文本塊之間的相似性大于或等于閾值,則將它們合并到同一個集合中。
- 最終,每個文本塊被分配到一個聚類中。
- 合并文本塊
merge_chunks方法根據(jù)聚類結(jié)果將文本塊合并。它會嘗試將同一聚類中的文本塊合并為一個更大的文本塊,同時確保合并后的文本塊不超過max_tokens的限制。如果合并后的文本塊超出限制,則會創(chuàng)建一個新的文本塊。 - 查找語義對
find_top_semantic_pairs方法從相似性矩陣中查找語義相似的文本塊對。它會篩選出相似性大于或等于similarity_threshold的文本塊對,并按相似性降序排列,返回前top_k對。
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from typing import List, Dict, Any
from collections import defaultdict
# 定義 SemanticChunker 類,用于對文本塊進行語義分塊和聚類
class SemanticChunker:
# 類的構(gòu)造函數(shù),初始化各種參數(shù)
def __init__(
self,
# 用于生成文本嵌入的模型名稱,默認(rèn)為 all-MiniLM-L6-v2
model_name='all-MiniLM-L6-v2',
# 合并后每個文本塊允許的最大令牌數(shù),默認(rèn)為 512
max_tokens=512,
# 聚類的閾值,控制聚類的粒度,默認(rèn)為 0.5
cluster_threshold=0.5,
# 確定語義對的最小相似度閾值,默認(rèn)為 0.4
similarity_threshold=0.4
):
# 選擇可用的設(shè)備進行計算,優(yōu)先選擇 CUDA,其次是 MPS,最后是 CPU
self.device = (
"cuda" if torch.cuda.is_available() else
"mps" if torch.backends.mps.is_available() else
"cpu"
)
# 打印當(dāng)前使用的設(shè)備信息
print(f"[Info] Using device: {self.device}")
# 加載指定名稱的 SentenceTransformer 模型,并將其移動到指定設(shè)備上
self.model = SentenceTransformer(model_name, device=self.device)
# 存儲最大令牌數(shù)
self.max_tokens = max_tokens
# 存儲聚類閾值
self.cluster_threshold = cluster_threshold
# 存儲相似度閾值
self.similarity_threshold = similarity_threshold
# 如果模型有分詞器,則存儲該分詞器,否則存儲 None
self.tokenizer = self.model.tokenizer if hasattr(self.model, "tokenizer") else None
# 計算文本塊的嵌入表示
def get_embeddings(self, chunks: List[Dict[str, Any]]):
# 如果輸入的文本塊列表為空,則返回空的 numpy 數(shù)組
if not chunks:
return np.array([])
# 從文本塊列表中提取文本內(nèi)容
texts = [chunk["text"] for chunk in chunks]
# 使用模型對文本內(nèi)容進行編碼,并將結(jié)果轉(zhuǎn)換為 numpy 數(shù)組
return np.array(self.model.encode(texts, show_progress_bar=False))
# 計算嵌入表示之間的余弦相似度矩陣
def compute_similarity(self, embeddings):
# 如果嵌入數(shù)組為空,則返回一個 0x0 的矩陣
if embeddings.size == 0:
return np.zeros((0, 0))
# 計算嵌入之間的余弦相似度矩陣
return cosine_similarity(embeddings)
# 對文本塊進行聚類
def cluster_chunks(self, similarity_matrix, threshold=0.5):
# 獲取相似度矩陣的行數(shù),即文本塊的數(shù)量
n = similarity_matrix.shape[0]
# 初始化每個文本塊的父節(jié)點為自身
parent = list(range(n))
# 查找節(jié)點的根節(jié)點(并查集的查找操作)
def find(x):
# 路徑壓縮,優(yōu)化查找效率
while parent[x] != x:
parent[x] = parent[parent[x]]
x = parent[x]
return x
# 合并兩個節(jié)點所在的集合(并查集的合并操作)
def union(x, y):
# 將 x 所在集合的根節(jié)點指向 y 所在集合的根節(jié)點
parent[find(x)] = find(y)
# 遍歷相似度矩陣,將相似度大于等于閾值的文本塊合并到同一個集合中
for i in range(n):
for j in range(i + 1, n):
if similarity_matrix[i, j] >= threshold:
union(i, j)
# 為每個文本塊找到其所在集合的根節(jié)點
clusters = [find(i) for i in range(n)]
# 為每個集合分配一個唯一的整數(shù) ID
cluster_map = {cid: idx for idx, cid in enumerate(sorted(set(clusters)))}
# 將每個文本塊的集合 ID 映射為唯一的整數(shù) ID
return [cluster_map[c] for c in clusters]
# 合并屬于同一聚類的文本塊
def merge_chunks(self, chunks: List[Dict[str, Any]], clusters: List[int]) -> List[Dict[str, Any]]:
# 如果輸入的文本塊列表或聚類列表為空,則返回空列表
if not chunks or not clusters:
return []
# 使用 defaultdict 存儲每個聚類對應(yīng)的文本塊列表
cluster_map = defaultdict(list)
# 將每個文本塊添加到其所屬的聚類列表中
for idx, cluster_id in enumerate(clusters):
cluster_map[cluster_id].append(chunks[idx])
# 存儲合并后的文本塊
merged_chunks = []
# 遍歷每個聚類的文本塊列表
for chunk_list in cluster_map.values():
# 初始化當(dāng)前合并文本為空
current_text = ""
# 初始化當(dāng)前合并文本的元數(shù)據(jù)為空列表
current_meta = []
# 遍歷當(dāng)前聚類中的每個文本塊
for chunk in chunk_list:
# 嘗試將當(dāng)前文本和下一個文本塊合并
next_text = (current_text + " " + chunk["text"]).strip()
# 如果有分詞器,則計算合并后文本的令牌數(shù)
if self.tokenizer:
num_tokens = len(self.tokenizer.encode(next_text))
# 否則,簡單地按空格分割文本計算令牌數(shù)
else:
num_tokens = len(next_text.split())
# 如果合并后文本的令牌數(shù)超過最大令牌數(shù),且當(dāng)前文本不為空
if num_tokens > self.max_tokens and current_text:
# 將當(dāng)前合并文本添加到合并后的文本塊列表中
merged_chunks.append({
"text": current_text,
"metadata": current_meta
})
# 重置當(dāng)前合并文本為下一個文本塊
current_text = chunk["text"]
# 重置當(dāng)前合并文本的元數(shù)據(jù)為下一個文本塊
current_meta = [chunk]
else:
# 否則,更新當(dāng)前合并文本
current_text = next_text
# 更新當(dāng)前合并文本的元數(shù)據(jù)
current_meta.append(chunk)
# 如果當(dāng)前合并文本不為空,將其添加到合并后的文本塊列表中
if current_text:
merged_chunks.append({
"text": current_text,
"metadata": current_meta
})
return merged_chunks
# 找到相似度最高的語義對
def find_top_semantic_pairs(self, similarity_matrix, min_similarity=0.4, top_k=50):
# 存儲語義對的列表
pairs = []
# 獲取相似度矩陣的行數(shù),即文本塊的數(shù)量
n = similarity_matrix.shape[0]
# 遍歷相似度矩陣的上三角部分
for i in range(n):
for j in range(i + 1, n):
# 獲取當(dāng)前文本塊對的相似度
sim = similarity_matrix[i, j]
# 如果相似度大于等于最小相似度閾值
if sim >= min_similarity:
# 將文本塊對及其相似度添加到列表中
pairs.append((i, j, sim))
# 按相似度降序排序
pairs.sort(key=lambda x: x[2], reverse=True)
# 返回前 top_k 個語義對
return pairs[:top_k]
# 對輸入的文本塊進行分塊和聚類操作
def chunk(self, chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
# 計算文本塊的嵌入表示
embeddings = self.get_embeddings(chunks)
# 計算嵌入表示之間的相似度矩陣
similarity_matrix = self.compute_similarity(embeddings)
# 對文本塊進行聚類
clusters = self.cluster_chunks(similarity_matrix, threshold=self.cluster_threshold)
# 合并屬于同一聚類的文本塊
return self.merge_chunks(chunks, clusters)
# 獲取調(diào)試信息,用于可視化或?qū)С稣{(diào)試數(shù)據(jù)
def get_debug_info(self, chunks: List[Dict[str, Any]]) -> Dict[str, Any]:
"""Optional: for visualization or export/debug purposes."""
# 計算文本塊的嵌入表示
embeddings = self.get_embeddings(chunks)
# 計算嵌入表示之間的相似度矩陣
similarity_matrix = self.compute_similarity(embeddings)
# 對文本塊進行聚類
clusters = self.cluster_chunks(similarity_matrix, threshold=self.cluster_threshold)
# 合并屬于同一聚類的文本塊
merged_chunks = self.merge_chunks(chunks, clusters)
# 找到相似度最高的語義對
semantic_pairs = self.find_top_semantic_pairs(similarity_matrix, min_similarity=self.similarity_threshold)
# 返回包含各種調(diào)試信息的字典
return {
"original_chunks": chunks,
"embeddings": embeddings,
"similarity_matrix": similarity_matrix,
"clusters": clusters,
"semantic_pairs": semantic_pairs,
"merged_chunks": merged_chunks
}三、兩種方法對比與適用場景
SemanticTextChunker 更側(cè)重于文本的分割,通過識別語義斷點來實現(xiàn),而 SemanticChunker 更側(cè)重于文本塊的合并,通過聚類和合并策略來實現(xiàn)。這兩個類都利用了語義嵌入和相似度計算,但在具體實現(xiàn)和應(yīng)用場景上有所不同。
方法 | 核心邏輯 | 優(yōu)勢 | 適用場景 |
方法1:不改變文本順序 | 按語義流找斷點,保留原文結(jié)構(gòu) | 符合閱讀習(xí)慣,適合長文分段 | 文檔摘要、上下文問答 |
方法2:改變文本順序 | 按語義聚類重組,打破原文順序 | 聚焦主題整合,適合跨段歸納 | 主題分析、文獻綜述 |
總結(jié)
這種語義分塊方法通過嵌入技術(shù)讓文本分割從 "按長度算" 進化到 "按意思分",但目前仍有優(yōu)化空間:
- 動態(tài)閾值調(diào)整:不同類型文本(如論文 / 小說)的語義密度不同,需自適應(yīng)調(diào)整閾值
- 遞歸分塊:對超大型文本塊進行二次分割,平衡語義完整性和長度限制
- 解決代詞歧義:結(jié)合指代消解技術(shù),避免 "他 / 它" 等代詞因分塊導(dǎo)致的指代混亂
最終,語義分塊的效果需要通過 RAG(檢索增強生成)評估來驗證 —— 能讓 AI 回答更準(zhǔn)確的分塊,才是好的分塊。





























