Published on

LLM - Retrieval Augmented Generation

Authors
  • avatar
    Name
    Ryan Chung
    Twitter

目錄

本篇是 LLM 系列文的第三章,著重在 RAG 的原理和實作方法。
其他內容請參考以下連結:

  1. LLM 首部曲: Building Applications with APIs
  2. LLM 二部曲: Prompting and Alignment (SFT+RL)
  3. LLM 三部曲: Retrieval Augmented Generation (RAG)

檢索增強生成 RAG

在前文中,我們已經示範如何使用 supervised fine-tuning (SFT) 與 reinforcement learning (RL) 的方式來對 LLM 做 alignment。 然而面對逐漸龐大的語言模型,使用微調的方式不見得是最好的做法。 首先,訓練模型需要足夠的硬體資源。就算是透過 LoRA 與量化技術,也難以在一般 GPU 上訓練像 GPT3 這種上百 billion 參數的模型。 再來,LLM 流暢的對答與廣泛的知識能力都儲存在參數裡,任意微調並變動這些參數可能會傷害其泛用性。 最後,如果訓練資料必須經常更新(例如前文的股票預測),每此更新都需要重新微調模型,未免有些不切實際。

一種簡單且有效的方式是採用 in-context learning,將資料直接丟進去 prompt 裡與使用者輸入一同互動,就可以根據資料直接進行回答。 然而,目前所有的 LLM 都有輸入字串的上限。以 GPT-3.5 為例,只能同時輸入 4,096 個 tokens,大約 2,000 字。 這體現在兩個面向,第一是無法同時輸入大量內容(例如整篇論文),第二是無法保有長期記憶(超過 2,000 字的舊對話無法繼續作為輸入)。 會這樣是因為 self-attention 機制需要對上下文每個字之間做運算,一旦輸入變長,計算量 O(n2) 就會非常可觀。

過去對長文摘要的一個常見作法,是將每個章節或段落個別摘要,再將所有摘要做一個整體摘要。 然而這樣並無法正確解讀摘要以外的細節,是故我們必須發展出一個更有效的資料存取方式,將輸入問題與文章重點進行匹配。 檢索增強生成 Retrieval Augmented Generation (RAG) 就是為此而發展的技術。 我們首先將資料(如網頁、pdf、json)分割成數個區塊 (chunk),再將這些區塊以 embedding vector 的形式儲存至向量資料庫。 要搜尋時,我們會透過資料檢索系統將問題與資料庫內容做「相似度匹配」,選出最適合的資料後一起丟給 generative model (eg. LLM) 做答案的生成。 整個流程可以參考以下圖片。

Indexing, retrieval and generation. Source: LangChain, Build a RAG App

過去研究表明,RAG 在知識性問答(Jeopardy question)的情境中能顯著提升答案的事實性 (Factuality 42.7%) 和具體性 (Specificity 37.4%) 1,大幅降低幻覺 (hallucination) 的發生。 這項技術宛如幫 LLM 裝上一個外接硬碟,能夠隨時更新不同領域的專業知識,推動其在企業內部的運用。 例如:基於內部文件與 QA 的客服聊天機器人、技術文件的檢索與問答等。

如果我們進一步思考,可以發現這項技術成功的關鍵,必須仰賴優良的資料檢索系統 (retrieval system)。 前面說的「相似度匹配」,具體做法是將輸入問題取 embedding,接著與向量資料庫的 chunks 做內積, 就可以從最大值找出前幾接近的內容,這種方式稱作 maximum inner product search (MIPS)。 然而,對於像 Wikipedia 這樣龐大的資料庫,如果直接使用暴力搜尋的方式將問題與所有 chunks 做內積,將會非常花時間。 我們可以用 approximate nearest neighbor (ANN) search 的方式,預先將資料庫做分類或量化,就可以快速搜尋出相似內容,再去進一步排序。

RAG Workflow

接下來我們將實際操作一次 RAG 流程。 以下內容修改自 Nvidia Deep Learning Institute 的線上課程 Building RAG Agents with LLMs, 如果有時間也推薦大家自行上網學習。

在實作之前,我們可以先來了解什麼是 LangChain。 LangChain 是專門用來開發 LLM 應用程式的框架,整合各種 model、retrieval、agent 模組,對於開發 RAG agent 非常有幫助。 正如其名,它可以用 chain 的方式鏈結不同輸入與元件,將複雜的應用透過小零件逐步搭建起來。

例如我們要開發一個簡單的 LLM chatbot,讓所有對答必須要押韻。 可以使用 LangChain 以及第一篇文曾提過的 Google Gemini 服務。 範例如下:

pip install langchain langchain-google-genai
from langchain_google_genai import GoogleGenerativeAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

llm = GoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=API_KEY)
prompt = ChatPromptTemplate.from_messages([
    ("system", "Only respond in rhymes"),
    ("user", "{input}")
])
rhyme_chain = prompt | llm | StrOutputParser()
print(rhyme_chain.invoke({"input" : "Tell me about yourself!"}))

我們將 prompt 模板與 LLM model 透過 | 符號串連 (pipe) 起來,再透過 LangChain invoke() 函式將輸入丟進去 rhyme_chain,得到押韻結果如下:

A language model, here I stand,
With words at hand, I'll do my best to understand.
To learn and grow, my purpose is to serve,
To answer questions, and to make you observe. 
I'm a tool for knowledge, a friend to you,
With information vast, I'll see you through. 
So tell me your desires, your thoughts, your fears,
And I'll respond with rhymes, through all the years. 

LangChain 整合了目前所有常見的 AI 供應商,除了 Google Generative AI,還有 OpenAI、Anthropic、Hugging Face、Nvidia AI Endpoints 等等,儼然已成為龐大的社群。 不論公司習慣使用什麼服務,都有對應的接口可以直接套用。

儲存與檢索

接下來我們要挑選適合的 encoder model 來將文件轉換成 embedding vector。 理想上好的 embedder 可以支援不同語系,並將原文壓縮且保留足夠重要的資訊。 我們可以參考 Hugging Face 上 Massive Text Embedding Benchmark (MTEB) 排行榜來挑選合適的模型, 或是參考 這篇 不錯的文章。

這裡我們先用 Google GenAI 預設的 embedding-001 做示範:

from langchain_google_genai import GoogleGenerativeAIEmbeddings

embedder = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=API_KEY)
vector = embedder.embed_query("Hello World!")
vector[:5]

如果我們將整串文字 Hello World! 丟入 embedder,會產生如下的 embedding vector:

[0.05169878527522087,
 -0.033477481454610825,
 -0.031893402338027954,
 -0.029319265857338905,
 0.019925475120544434]

接著我們要挑選合適的向量資料庫 (Vector DB)。 目前市面上常見的有 ChromaMilvus 等專業資料庫,支援 ANN search。 一些經典選擇如 PostgreSQLMongoDB 也開始支援向量儲存與搜尋功能。 這裡我們使用 Meta 開發的 FAISS (Facebook AI Similarity Search) 套件, 它不完全是一個資料庫,而是將資料暫存在記憶體裡供高效搜尋與擴充。

pip install langchain-community faiss-cpu # 如果使用 GPU 可以安裝 faiss-gpu
from langchain.vectorstores import FAISS

conversation = [
    "[User]  Hello! My name is Beras, and I'm a big blue bear! Can you please tell me about the rocky mountains?",
    "[Agent] The Rocky Mountains are a beautiful and majestic range of mountains that stretch across North America",
    "[Beras] Wow, that sounds amazing! Ive never been to the Rocky Mountains before, but Ive heard many great things about them.",
    "[Agent] I hope you get to visit them someday, Beras! It would be a great adventure for you!"
    "[Beras] Thank you for the suggestion! Ill definitely keep it in mind for the future.",
    "[Agent] In the meantime, you can learn more about the Rocky Mountains by doing some research online or watching documentaries about them."
    "[Beras] I live in the arctic, so I'm not used to the warm climate there. I was just curious, ya know!",
    "[Agent] Absolutely! Lets continue the conversation and explore more about the Rocky Mountains and their significance!"
]

vector_store = FAISS.from_texts(conversation, embedding=embedder)
retriever = vector_store.as_retriever()

print(retriever.invoke("What is your name?"))

在上面的範例中,我們提供一段 chatbot 的對話,將這段對話透過 FAISS 轉成向量資料,並建立檢索器 (retriever)。 接著我們只需要輸入任意字串,就可以從資料中匹配出最接近的內容:

[Document(metadata={}, page_content="[User]  Hello! My name is Beras, and I'm a big blue bear! Can you please tell me about the rocky mountains?"), ...]

了解流程後,我們就可以自行建立 RAG chain,將使用者的問題與檢索結果一起丟給 decoder model (LLM),生成合適的回答。

context_prompt = ChatPromptTemplate.from_template(
    "Answer the question using only the context"
    "\n\nRetrieved Context: {context}"
    "\n\nUser Question: {question}"
    "\nAnswer the user conversationally. User is not aware of context."
)

chain = (
    {'context': vector_store.as_retriever(),
     'question': (lambda x: x)}
    | context_prompt
    | llm
    | StrOutputParser()
)

print(chain.invoke("Where does Beras live?"))

'''
[Output]
Based on the conversation, Beras lives in the arctic!
'''

值得注意的是,如果我們的問題無法在原始資料裡找到答案,因為 system prompt 闡明「Answer the question using only the context」, 所以語言模型只會基於我們的資料做回答,不會自行腦補出幻覺。 這對於基於真實資料的 chatbot (如客服系統)來說非常重要。

print(chain.invoke("How far away is Beras from the Rocky Mountains?"))

'''
[Output]
We don't have enough information to know how far Beras is from the Rocky Mountains.
We know Beras lives in the arctic, but we don't know exactly where that is.
'''

Building RAG Agents

這一節,我們將從頭打造自己的 chatbot 服務,並利用 RAG 方式來檢索論文並回答相關問題。 LangChain 提供多種方式載入文件,包含 word、pdf、csv、Google Drive 等。 這裡我們使用 arxiv loader 來下載 arXiv 上的論文,再使用 PyMuPDF 來讀取 pdf 檔。 注意儘管 LangChain 已整合相關模組,依然要安裝原始套件才能使用這些功能。

pip install langchain-community arxiv pymupdf
from langchain.document_loaders import ArxivLoader

docs = [
    ArxivLoader(query="1706.03762").load(),  # Attention Is All You Need Paper
    ArxivLoader(query="1810.04805").load(),  # BERT Paper
    ArxivLoader(query="2005.11401").load(),  # RAG Paper
    ArxivLoader(query="2205.00445").load(),  # MRKL Paper
    ArxivLoader(query="2310.06825").load(),  # Mistral Paper
    ArxivLoader(query="2306.05685").load(),  # LLM-as-a-Judge Paper
]

接著我們移除 references,並將剩餘的文件分割成適當大小的 chunks。 可以使用 LangChain 的 RecursiveCharacterTextSplitter(),會參考 separators 將文章做更有意義的分割。

import json
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Cut the paper short if references is included.
for doc in docs:
    content = json.dumps(doc[0].page_content)
    if "References" in content:
        doc[0].page_content = content[:content.index("References")]

# Split the documents and also filter out stubs (overly short chunks)
print("Chunking Documents\n")
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000, chunk_overlap=100,
    separators=["\n\n", "\n", ".", ";", ",", " "],
)
docs_chunks = [ text_splitter.split_documents(doc) for doc in docs ]

最後我們將不同論文的 chunks 整合並儲存進 FAISS 向量空間裡。

docstore = FAISS()
for chunk in docs_chunks:
    vecstore = FAISS.from_documents(chunk, embedder)
    docstore.merge_from(vecstore)

建立 retrieval chain 來檢索向量空間,同時建立 stream chain 來將 retrieval 內容丟到 LLM 生成輸出。

from langchain_core.runnables.passthrough import RunnableAssign

initial_msg = (
    "Hello! I am a document chat agent here to help the user!"
    f" I have access to the following documents: {doc_string}\n\nHow can I help you?"
)

chat_prompt = ChatPromptTemplate.from_messages([("system",
    "You are a document chatbot. Help the user as they ask questions about documents."
    " User messaged just asked: {input}\n\n"
    " From this, we have retrieved the following potentially-useful info: "
    " Conversation History Retrieval:\n{history}\n\n"
    " Document Retrieval:\n{context}\n\n"
    " (Answer only from retrieval. Only cite sources that are used. Make your response conversational.)"
), ('user', '{input}')])

stream_chain = chat_prompt | llm | StrOutputParser()
retrieval_chain = (
    {'input' : (lambda x: x)}
    | RunnableAssign({'history' : (lambda x: x['input']) | convstore.as_retriever()})
    | RunnableAssign({'context' : (lambda x: x['input']) | docstore.as_retriever()})
)

def chat_gen(message, history=[], return_buffer=True):
    buffer = ""
    retrieval = retrieval_chain.invoke(message)
    line_buffer = ""

    for token in stream_chain.stream(retrieval):
        buffer += token
        yield buffer if return_buffer else token

    save_memory_and_get_output({'input':  message, 'output': buffer}, convstore)

最後當我們 invoke 問題,RAG agent 就會檢索內容並生成回覆。 例如我提問「Tell me about RAG」,系統會根據我提供的 RAG paper 作回答:

chat_gen("Tell me about RAG!", return_buffer=False)

'''
[Output]
RAG stands for Retrieval-Augmented Generation. It's a pretty cool technique used in natural language processing (NLP) that combines the strengths of two different types of models:  parametric and non-parametric. 
Think of it like this: The parametric model is like a really smart person who has learned a lot from reading lots of books, but might not know everything. The non-parametric model is like a massive library, full of information but needing someone to find the right books. 
'''

我們可以進一步使用 Gradio, 或是直接透過 LangServe(整合 FastAPI)來建立 REST API server。

pip install gradio
import gradio as gr

chatbot = gr.Chatbot(value = [[None, initial_msg]])
demo = gr.ChatInterface(chat_gen, chatbot=chatbot).queue()
demo.launch(debug=True, share=True, show_api=False)
demo.close()

這樣就可以快速打造我們的 RAG 應用。 當我們提問「Tell me about "Mistral 7B"」時,系統會檢索論文並還原出 Github 跟網頁連結。 上述完整程式碼可以參考我的 Github 專案

Footnotes

  1. Patrick Lewis, et al. "Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks." Advances in Neural Information Processing Systems 33, 9459-9474, 2020.