Three Advanced Techniques to Improve RAG Retrieval Quality (Query Expansion, Cross-Encoder Re-ranking, and Embedding Adaptors)
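In off-the-shelf RAG implementations, the retrieved documents often lack a complete answer or contain redundant and irrelevant information, and they may be ranked in a poor order. As a result, the generated answer can drift away from the intent of the user's query.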
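This article introduces three techniques that can effectively improve retrieval: query expansion, cross-encoder re-ranking, and embedding adaptors. They help retrieve more relevant documents that closely match the user's query, which in turn improves the quality of the generated answer.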
1. Query expansion
Query expansion refers to a family of techniques that rewrite the original query. There are two common approaches:
1) Query expansion with a generated answer
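Given an input query, this approach first asks the LLM to produce a hypothetical answer, regardless of whether it is correct. The query and the generated answer are then combined into a single prompt and sent to the retrieval system.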
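This method works well. The basic goal is to retrieve documents that look more like answers. Whether the hypothetical answer is correct does not matter much, because what we care about is its structure and wording. You can think of the hypothetical answer as a template that helps locate the relevant neighborhood in the embedding space. For details, see the paper "Precise Zero-Shot Dense Retrieval without Relevance Labels" [1].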
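The idea fits in a few lines of Python. This is a minimal sketch, not the author's code: `llm` is a hypothetical callable standing in for any chat-completion call (for example, a thin wrapper around the OpenAI client used later in this article), and the combined string is what would be embedded and sent to the retriever.

```python
from typing import Callable

# Hypothetical interface: llm(system_prompt, user_message) -> generated text
LLM = Callable[[str, str], str]


def expand_with_hypothetical_answer(query: str, system_prompt: str, llm: LLM) -> str:
    """Return the original query joined with an LLM-written hypothetical answer."""
    # The answer may be factually wrong; only its structure and wording matter here
    hypothetical_answer = llm(system_prompt, query)
    # The combined text is what gets embedded and sent to the retrieval system
    return f"{query} {hypothetical_answer}"


# Stand-in LLM for illustration only (replace with a real chat-completion call)
def fake_llm(system_prompt: str, user_message: str) -> str:
    return "Example answer text in the style of an annual report."


augmented = expand_with_hypothetical_answer(
    "Was there significant revenue growth?",
    "You are a helpful expert financial research assistant.",
    fake_llm,
)
print(augmented)
```

Keeping the LLM behind a plain callable makes it easy to swap providers or to test the expansion logic without network access.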
Below is an example prompt used to augment the queries sent to a RAG system that answers questions about financial reports:
```text
You are a helpful expert financial research assistant.
Provide an example answer to the given question, that might
be found in a document like an annual report.
```
2) Query expansion with multiple related questions
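Use the LLM to generate N questions related to the original query, then send all of them (plus the original query) to the retrieval system. This retrieves more documents from the vector store. However, some of them will be duplicates, so post-processing is needed to remove them.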
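The idea behind this approach is to expand an initial query that may be incomplete or ambiguous, so that it covers related aspects that may turn out to be relevant and complementary.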
Here is the prompt used to generate the related questions:
```text
You are a helpful expert financial research assistant.
Your users are asking questions about an annual report.
Suggest up to five additional related questions to help them
find the information they need, for the provided question.
Suggest only short questions without compound sentences.
Suggest a variety of questions that cover different aspects of the topic.
Make sure they are complete questions, and that they are related to
the original question.
Output one question per line. Do not number the questions.
```
For details, see the paper "Query Expansion by Prompting Large Language Models" [2].
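A minimal sketch of the surrounding plumbing, under two assumptions: the LLM's reply follows the prompt's one-question-per-line format, and the retrieval results arrive as one list of documents per query, which is the shape of `results["documents"]` when a Chroma collection is queried with several `query_texts`. The `deduplicate` step is the post-processing mentioned above; all function names are hypothetical.

```python
def parse_questions(llm_output: str) -> list[str]:
    """Split the LLM reply into questions, one per non-empty line."""
    return [line.strip() for line in llm_output.split("\n") if line.strip()]


def build_queries(original_query: str, llm_output: str) -> list[str]:
    """Original query first, followed by the generated related questions."""
    return [original_query] + parse_questions(llm_output)


def deduplicate(documents_per_query: list[list[str]]) -> list[str]:
    """Flatten per-query results and drop repeated documents, keeping first-seen order."""
    seen: set[str] = set()
    unique: list[str] = []
    for documents in documents_per_query:
        for doc in documents:
            if doc not in seen:
                seen.add(doc)
                unique.append(doc)
    return unique


queries = build_queries(
    "What was the total revenue?",
    "What were the main sources of revenue?\n\nHow did revenue change year over year?\n",
)
# With Chroma, this would come from something like:
# chroma_collection.query(query_texts=queries, n_results=5)["documents"]
documents_per_query = [["doc A", "doc B"], ["doc B", "doc C"], ["doc A", "doc D"]]
print(queries)
print(deduplicate(documents_per_query))
```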
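One drawback of the methods above is that they return many documents, which can distract the LLM and prevent it from producing a useful answer. This is where re-ranking comes in: the documents are re-ordered and the less relevant ones are discarded.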
2. Cross-encoder re-ranking
This method re-ranks the retrieved documents by a score that measures how relevant each one is to the input query. The score is computed with a cross-encoder.

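A cross-encoder is a deep neural network that processes two input sequences jointly, as a single input. This lets the model directly compare and contrast the two inputs and capture their relationship in a more holistic and fine-grained way.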
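Cross-encoders can be used for information retrieval: given a query, encode it together with each retrieved document to obtain a relevance score, then sort the documents in descending order of score. The highest-scoring documents are the most relevant.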
For details, see SBERT.net Retrieve & Re-rank [3].
Here is how to get started quickly with cross-encoder re-ranking:
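```shell
pip install -U sentence-transformers
```

```python
import numpy as np
from sentence_transformers import CrossEncoder

# Import and load the cross-encoder
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

# Score each (query, document) pair.
# `query` is the user question and `retrieved_documents` is the list of
# document strings returned by the first-stage retriever.
pairs = [[query, doc] for doc in retrieved_documents]
scores = cross_encoder.predict(pairs)

print("Scores:")
for score in scores:
    print(score)
# Scores:
# 0.98693466
# 2.644579
# -0.26802942
# -10.73159
# -7.7066045
# -5.6469955
# -4.297035
# -10.933233
# -7.0384283
# -7.3246956

# Re-order the documents by descending score (positions printed 1-based)
print("New Ordering:")
for o in np.argsort(scores)[::-1]:
    print(o + 1)
```
Cross-encoder re-ranking can be combined with query expansion: after generating several related questions and retrieving the corresponding documents (say, M documents in total), re-rank them and keep the top K (K < M). This reduces the context size while keeping the most important parts.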
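Selecting the top K after re-ranking is a short step. A minimal sketch, assuming `scores` is the array returned by `cross_encoder.predict(pairs)` for the same document order; `top_k_documents` is a hypothetical helper name and the four scores are the first four values printed above:

```python
import numpy as np


def top_k_documents(documents: list[str], scores, k: int) -> list[str]:
    """Return the k documents with the highest cross-encoder scores, best first."""
    order = np.argsort(scores)[::-1]  # indices sorted by descending score
    return [documents[i] for i in order[:k]]


documents = ["doc A", "doc B", "doc C", "doc D"]
scores = np.array([0.98693466, 2.644579, -0.26802942, -10.73159])
print(top_k_documents(documents, scores, k=2))
```

Deduplicating before scoring, as in the query-expansion step, avoids spending cross-encoder calls on repeated documents.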
3. Embedding adaptors
This is a powerful yet simple technique for rescaling embeddings so that they align better with the user's task. The adaptor is trained on user feedback about the relevance of the retrieved documents.
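Adapters are a lightweight alternative to fully fine-tuning a pre-trained model. They are commonly implemented as small feed-forward neural networks inserted between the layers of a pre-trained model. Here, the point of training the adaptor is to transform the query embedding so that it yields better retrieval results for a specific task. The embedding adaptor is a stage inserted after the embedding step and before retrieval. You can think of it as a matrix (with trained weights) that takes the original embedding and scales it.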
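To make the "matrix that scales the embedding" picture concrete, here is a small NumPy illustration (not the article's implementation): only the query embedding passes through the adaptor, an identity matrix leaves similarities unchanged, and a diagonal matrix re-weights individual dimensions. The 2-dimensional vectors and the `emphasize_first` matrix are made-up examples.

```python
import numpy as np


def cosine(a: np.ndarray, b: np.ndarray) -> float:
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


def adapted_similarity(adaptor: np.ndarray, query_emb: np.ndarray, doc_emb: np.ndarray) -> float:
    # The adaptor transforms the query embedding only; document embeddings stay as indexed
    return cosine(adaptor @ query_emb, doc_emb)


query_emb = np.array([1.0, 1.0])
doc_emb = np.array([1.0, 0.0])

identity = np.eye(2)                   # no-op adaptor
emphasize_first = np.diag([3.0, 1.0])  # hypothetical adaptor that up-weights dimension 0

print(adapted_similarity(identity, query_emb, doc_emb))         # ≈ 0.7071
print(adapted_similarity(emphasize_first, query_emb, doc_emb))  # ≈ 0.9487
```

Because the documents are untouched, an adaptor can be trained and swapped without re-indexing the vector store.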
The training steps are as follows:
1) Prepare the training data
To train an embedding adaptor, you need training data about document relevance. This data can be labeled by hand or generated by an LLM. It must consist of (query, document) tuples with their corresponding labels (1 if the document is relevant to the query, -1 otherwise). For simplicity, we will create a synthetic dataset; in a real-world setting, you would need to design a way to collect user feedback (for example, letting users rate the relevance of documents in the interface).
To create some training data, we first use an LLM to generate sample questions that a financial analyst might ask when analyzing a financial report.
```python
import os

from dotenv import load_dotenv, find_dotenv
from openai import OpenAI

_ = load_dotenv(find_dotenv())  # load OPENAI_API_KEY from a local .env file
openai_client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

PROMPT_DATASET = """
You are a helpful expert financial research assistant.
You help users analyze financial statements to better understand companies.
Suggest 10 to 15 short questions that are important to ask when analyzing
an annual report.
Do not output any compound questions (questions with multiple sentences
or conjunctions).
Output each question on a separate line divided by a newline.
"""


def generate_queries(model="gpt-3.5-turbo"):
    messages = [
        {
            "role": "system",
            "content": PROMPT_DATASET,
        },
    ]
    response = openai_client.chat.completions.create(
        model=model,
        messages=messages,
    )
    content = response.choices[0].message.content
    # One question per line; drop empty lines
    return [line for line in content.split("\n") if line.strip()]


generated_queries = generate_queries()
for query in generated_queries:
    print(query)
# 1. What is the company's revenue growth rate over the past three years?
# 2. What are the company's total assets and total liabilities?
# 3. How much debt does the company have? Is it increasing or decreasing?
# 4. What is the company's profit margin? Is it improving or declining?
# 5. What are the company's cash flow from operations, investing, and financing activities?
# 6. What are the company's major sources of revenue?
# 7. Does the company have any pending litigation or legal issues?
# 8. What is the company's market share compared to its competitors?
# 9. How much cash does the company have on hand?
# 10. Are there any major changes in the company's executive team or board of directors?
# 11. What is the company's dividend history and policy?
# 12. Are there any related party transactions?
# 13. What are the company's major risks and uncertainties?
# 14. What is the company's current ratio and quick ratio?
# 15. How has the company's stock price performed over the past year?
```
Next, we retrieve documents for each generated question. To do this, we query a Chroma collection in which a financial report was previously indexed.
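```python
# chroma_collection: an existing Chroma collection in which the annual report was indexed.
# Query it once per generated question and also return the stored embeddings.
results = chroma_collection.query(
    query_texts=generated_queries,
    n_results=10,
    include=["documents", "embeddings"],
)
retrieved_documents = results["documents"]  # one list of 10 documents per question
```
We then use the LLM again to evaluate how relevant each retrieved document is to its question: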
```python
PROMPT_EVALUATION = """
You are a helpful expert financial research assistant.
You help users analyze financial statements to better understand companies.
For the given query, evaluate whether the following statement is relevant.
Output only 'yes' or 'no'.
"""


def evaluate_results(query, statement, model="gpt-3.5-turbo"):
    messages = [
        {
            "role": "system",
            "content": PROMPT_EVALUATION,
        },
        {
            "role": "user",
            "content": f"Query: {query}, Statement: {statement}",
        },
    ]
    response = openai_client.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=1,  # a single token is enough for "yes" / "no"
    )
    content = response.choices[0].message.content
    # Normalize case and whitespace so that "Yes" is not mislabeled as irrelevant
    if content.strip().lower() == "yes":
        return 1
    return -1
```
Next, we structure the training data as tuples. Each tuple contains the query embedding, the document embedding, and the evaluation label (1 or -1).
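```python
from tqdm import tqdm

# Document embeddings returned by Chroma (one list per question)
retrieved_embeddings = results["embeddings"]
# embedding_function: assumed to be the same embedding function the Chroma collection was built with
query_embeddings = embedding_function(generated_queries)

adapter_query_embeddings = []
adapter_doc_embeddings = []
adapter_labels = []

for q, query in enumerate(tqdm(generated_queries)):
    for d, document in enumerate(retrieved_documents[q]):
        adapter_query_embeddings.append(query_embeddings[q])
        adapter_doc_embeddings.append(retrieved_embeddings[q][d])
        adapter_labels.append(evaluate_results(query, document))
```
Finally, once the training tuples have been generated, we put them into a torch dataset to prepare for training.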
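The article does not show this step. One way to do it, sketched here under the assumption that the three lists hold equal-length vectors, is `torch.utils.data.TensorDataset`: each item is a `(query_embedding, document_embedding, label)` tuple, which matches how the training loop later iterates over `dataset`. `build_adapter_dataset` is a hypothetical helper name.

```python
import numpy as np
import torch
from torch.utils.data import TensorDataset


def build_adapter_dataset(query_embeddings, doc_embeddings, labels) -> TensorDataset:
    """Stack the training tuples into float32 tensors wrapped in a TensorDataset."""
    queries = torch.tensor(np.asarray(query_embeddings), dtype=torch.float32)
    documents = torch.tensor(np.asarray(doc_embeddings), dtype=torch.float32)
    # Float labels so that MSELoss can compare them with the cosine similarity
    targets = torch.tensor(np.asarray(labels), dtype=torch.float32)
    return TensorDataset(queries, documents, targets)


# With the lists built above:
# dataset = build_adapter_dataset(adapter_query_embeddings, adapter_doc_embeddings, adapter_labels)

# Small synthetic example: 3 tuples of 4-dimensional embeddings
example = build_adapter_dataset(
    [[0.1, 0.2, 0.3, 0.4]] * 3,
    [[0.4, 0.3, 0.2, 0.1]] * 3,
    [1, -1, 1],
)
query_embedding, document_embedding, label = example[0]
print(len(example), query_embedding.shape, label)
```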
2) Define the model
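We define a function that takes a query embedding, a document embedding, and the adaptor matrix as input. It first multiplies the query embedding by the adaptor matrix, then computes the cosine similarity between the result and the document embedding.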
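```python
import torch


def model(query_embedding, document_embedding, adaptor_matrix):
    # Transform the query embedding with the adaptor matrix
    updated_query_embedding = torch.matmul(adaptor_matrix, query_embedding)
    # Cosine similarity between the adapted query and the unchanged document embedding
    return torch.cosine_similarity(updated_query_embedding, document_embedding, dim=0)
```
3) Define the loss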
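The goal is to make the cosine similarity computed by the previous function match the label: close to 1 for relevant documents and close to -1 for irrelevant ones. To do this, we optimize the weights of the adaptor matrix with a mean squared error (MSE) loss between the similarity and the label.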
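Written out for a single training tuple, with query embedding $q$, document embedding $d$, label $y \in \{1, -1\}$ and adaptor matrix $A$, the quantity that `mse_loss` computes is the squared difference between the adapted cosine similarity and the label:

```latex
\mathcal{L}(A) = \left( \frac{(A q)^\top d}{\lVert A q \rVert \, \lVert d \rVert} - y \right)^2
```

Since the cosine similarity lies in $[-1, 1]$, the loss is 0 when the similarity equals the label exactly and reaches its maximum of 4 when the similarity equals $-y$.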
```python
def mse_loss(query_embedding, document_embedding, adaptor_matrix, label):
    # Squared error between the adapted cosine similarity and the 1 / -1 label
    return torch.nn.MSELoss()(model(query_embedding, document_embedding, adaptor_matrix), label)
```
4) Training
Initialize the adaptor matrix and train for 100 epochs.
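```python
import torch
from tqdm import tqdm

# Initialize the adaptor matrix with random weights
mat_size = len(adapter_query_embeddings[0])
adapter_matrix = torch.randn(mat_size, mat_size, requires_grad=True)
min_loss = float("inf")
best_matrix = None

for epoch in tqdm(range(100)):
    # dataset: the torch dataset of (query_embedding, document_embedding, label) tuples
    for query_embedding, document_embedding, label in dataset:
        loss = mse_loss(query_embedding, document_embedding, adapter_matrix, label)
        # Snapshot the matrix with the lowest single-sample loss seen so far
        # (tracking the mean loss per epoch would be a steadier criterion)
        if loss.item() < min_loss:
            min_loss = loss.item()
            best_matrix = adapter_matrix.clone().detach().numpy()
        loss.backward()
        # Plain gradient-descent step (learning rate 0.01), then reset the gradient
        with torch.no_grad():
            adapter_matrix -= 0.01 * adapter_matrix.grad
            adapter_matrix.grad.zero_()
```
Once trained, the adaptor can be used to scale the original embeddings and adapt them to the user's task.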
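```python
# Apply the trained adaptor to a test vector: the embedding dimension is preserved
test_vector = torch.ones((mat_size, 1))
scaled_vector = np.matmul(best_matrix, test_vector.numpy())

test_vector.shape
# torch.Size([384, 1])
scaled_vector.shape
# (384, 1)
best_matrix.shape
# (384, 384)
```
At retrieval time, simply multiply the original query embedding by the adaptor matrix and feed the result into the retrieval system.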
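A minimal NumPy sketch of that retrieval step, under assumptions: `best_matrix` is the trained `(d, d)` adaptor, query embeddings come as an `(n, d)` array, and retrieval is a brute-force cosine search; `adapt_queries` and `cosine_top_k` are hypothetical helper names. Against the Chroma collection used earlier, the adapted vectors could instead be passed as `chroma_collection.query(query_embeddings=adapted.tolist(), n_results=10)`. Chroma uses L2 distance by default and the adaptor changes vector norms, so normalizing the adapted vectors first may be worthwhile.

```python
import numpy as np


def adapt_queries(adaptor_matrix: np.ndarray, query_embeddings: np.ndarray) -> np.ndarray:
    """Apply the adaptor to each row: (n, d) -> (n, d), i.e. adaptor_matrix @ q for every query q."""
    return query_embeddings @ adaptor_matrix.T


def cosine_top_k(query_vectors: np.ndarray, doc_vectors: np.ndarray, k: int) -> np.ndarray:
    """Indices of the k most cosine-similar documents for each query, best first."""
    q = query_vectors / np.linalg.norm(query_vectors, axis=1, keepdims=True)
    d = doc_vectors / np.linalg.norm(doc_vectors, axis=1, keepdims=True)
    similarities = q @ d.T  # shape (n_queries, n_docs)
    return np.argsort(-similarities, axis=1)[:, :k]


doc_embeddings = np.array([[1.0, 0.0], [0.0, 1.0]])
query_embeddings = np.array([[1.0, 0.0]])

identity = np.eye(2)
swap = np.array([[0.0, 1.0], [1.0, 0.0]])  # toy adaptor that swaps the two dimensions

print(cosine_top_k(adapt_queries(identity, query_embeddings), doc_embeddings, k=1))  # [[0]]
print(cosine_top_k(adapt_queries(swap, query_embeddings), doc_embeddings, k=1))      # [[1]]
```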
All three methods are straightforward to put into practice. Interested readers can apply them to an existing RAG application and evaluate how effective they are for their own use case.
References:
[1] https://arxiv.org/pdf/2212.10496.pdf
[2] https://arxiv.org/pdf/2305.03653.pdf
[3] https://www.sbert.net/examples/applications/retrieve_rerank/README.html
Original article: Ahmed Besbes, "3 Advanced Document Retrieval Techniques To Improve RAG Systems"