In our last post, we explored the basics of Retrieval-Augmented Generation (RAG), a method that combines the retrieval of relevant external data with the generative power of large language models (LLMs). This combination allows AI systems to produce responses that are not only accurate but also contextually rich and up-to-date. Today, we’ll dive into two practical implementations of RAG, showing how to integrate external data sources to enhance your AI capabilities.
RAG Architecture
At its core, RAG enriches language models by incorporating relevant, real-time information from external sources, augmenting the model’s responses with the data it needs to generate more informed and precise answers. Here’s a breakdown of the RAG architecture:
1. User Query Submission
The process begins when a user submits a query. This could be anything from a straightforward question to a complex problem that requires detailed context. The query is sent to an orchestrator, which acts as the system’s command center, managing the flow of information throughout the RAG process.
2. Retrieval Phase
Orchestrator to Retriever
The orchestrator sends the query to the retriever, whose job is to gather relevant information that will enhance the language model’s response.
Gathering Relevant Information
The retriever accesses a knowledge source — whether it’s a database, an API, a document repository, or another data source rich in factual information. It processes the query and returns the most relevant documents, snippets, or data points. The quality of this retrieved information is crucial, as it directly influences the relevance and accuracy of the final output.
3. Augmentation Phase
Orchestrator Augments the Query
With the relevant information in hand, the orchestrator enriches the original query with this new context. This step amplifies the input given to the language model, preparing it to generate a more detailed and contextually appropriate response.
Forming the Augmented Query
The augmented query is essentially an enhanced version of the original query. It’s the user’s question, but now supplemented with additional, retrieved information, providing a rich prompt for the language model to work from.
4. Generation Phase
Passing to the Language Model
The orchestrator forwards this augmented query to the LLM. With the added context, the model is primed to process the query more effectively.
Generating the Response
Using its advanced natural language understanding and generation capabilities, the language model generates a response. Thanks to the retrieval phase, the answer is not only accurate but also finely tuned to the user’s needs.
5. Response Delivery
Finally, the orchestrator delivers the generated response back to the user, completing a more informed and context-aware interaction.
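Before moving on, here’s a minimal sketch of this flow in Python. Everything in it (ToyRetriever, Orchestrator, toy_llm) is an illustrative placeholder rather than the code used in the implementations below; it only shows how the five phases hand off to each other.

# A toy end-to-end RAG loop: retrieval -> augmentation -> generation -> delivery.
# All names here are illustrative placeholders, not a real retriever or LLM.

class ToyRetriever:
    def __init__(self, knowledge):
        self.knowledge = knowledge  # a list of snippets standing in for a knowledge source

    def retrieve(self, query):
        # Stand-in for real retrieval: keep snippets that share words with the query
        query_words = set(query.lower().strip("?!.").split())
        return [doc for doc in self.knowledge if query_words & set(doc.lower().split())]


class Orchestrator:
    def __init__(self, retriever, llm):
        self.retriever = retriever
        self.llm = llm  # any callable that maps a prompt to a response

    def handle(self, query):
        # Retrieval phase: gather relevant context for the query
        context = "\n".join(self.retriever.retrieve(query))
        # Augmentation phase: enrich the original query with the retrieved context
        augmented_query = f"{context}\n\nAnswer based on the context above:\n{query}"
        # Generation phase: the language model produces the response, which is then delivered
        return self.llm(augmented_query)


def toy_llm(prompt):
    return f"(model response to a {len(prompt)}-character prompt)"


if __name__ == "__main__":
    retriever = ToyRetriever(["Tesla reported record deliveries last quarter."])
    print(Orchestrator(retriever, toy_llm).handle("What is the outlook for Tesla?"))

The two implementations that follow replace these placeholders with real components: a TF-IDF retriever plus a Hugging Face model in the first, and a FAISS vector store plus LangChain in the second.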
Practical Implementations of RAG
To put the RAG architecture into perspective, let’s explore two practical implementations.
📈 Combining Stock Data and News Articles
The implementation below shows a chatbot that leverages stock data and relevant news articles to answer user queries with fresh information. It retrieves and processes external information, augments the user’s query with this data, and generates a response using a pre-trained LLM. Imagine a user asks, "What is the current market outlook for Tesla?"
We’ll walk through each component and explain how they fit together to create a powerful RAG system.
Setting Up the Retriever
The retriever’s job is to gather information that will enrich the LLM’s responses. In this case, we’ll pull stock data from Yahoo Finance and scrape related news articles.
1. Fetching Stock Data
To retrieve comprehensive stock information, including financials and news articles, we use the yfinance library. Here's how it works:
import yfinance as yf  # Yahoo Finance client used by the retriever


class Retriever:
    @staticmethod
    def get_stock_data(symbol):
        # Fetch the company profile, recent news, and quarterly financials for a ticker
        stock = yf.Ticker(symbol)
        stock_info = stock.info
        news = stock.news if hasattr(stock, "news") else []
        financials = stock.quarterly_financials
        return stock_info, news, financials
This code snippet fetches the stock information for a given symbol, along with news articles and financial data.
2. Scraping News Articles
Next, we scrape the content of news articles using BeautifulSoup. This provides additional context that can be used to augment user queries:
    # Uses the requests library and BeautifulSoup from bs4
    @staticmethod
    def fetch_story_content(link):
        # Download an article and return the text of its <p> tags
        try:
            response = requests.get(link)
            response.raise_for_status()
            soup = BeautifulSoup(response.content, "html.parser")
            paragraphs = soup.find_all("p")
            content = " ".join([para.get_text() for para in paragraphs])
            return content
        except Exception:
            # Articles that fail to download or parse contribute no content
            return ""
This function retrieves the text from the paragraphs of an article given its URL.
Preprocessing and Document Retrieval
To ensure we’re retrieving the most relevant documents, we preprocess them using TF-IDF vectorization and cosine similarity.
1. Preprocessing Documents
Convert documents into TF-IDF vectors for similarity computation:
    # Uses TfidfVectorizer from sklearn.feature_extraction.text
    @staticmethod
    def preprocess_documents(documents):
        # Scrape each linked article and build a TF-IDF matrix over the texts
        texts = [Retriever.fetch_story_content(doc.get("link", "")) for doc in documents]
        vectorizer = TfidfVectorizer()
        tfidf_matrix = vectorizer.fit_transform(texts)
        return vectorizer, tfidf_matrix, documents
This function creates TF-IDF vectors from the scraped article content, preparing them for similarity analysis. TF-IDF helps us identify the most relevant documents by giving more weight to words that are unique to each document, thereby reducing the influence of common words.
2. Retrieving Relevant Documents
Retrieve the most relevant documents using cosine similarity, ensuring the augmented query is highly contextual:
    # Uses cosine_similarity from sklearn.metrics.pairwise
    @staticmethod
    def retrieve_documents(query, vectorizer, tfidf_matrix, documents, k=NUM_DOCS):
        # Rank the documents by cosine similarity to the query and return the top k
        query_vec = vectorizer.transform([query])
        similarities = cosine_similarity(query_vec, tfidf_matrix).flatten()
        related_docs_indices = similarities.argsort()[-k:][::-1]
        return [documents[i] for i in related_docs_indices]
This function finds the top k documents most similar to the query, providing the best contextual fit for augmentation. Cosine similarity measures the angle between vectors, helping us identify documents that are most aligned with the query.
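If you’d like to see this ranking in isolation, here’s a small self-contained example with made-up sentences; it mirrors what preprocess_documents and retrieve_documents do with scikit-learn:

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

# Made-up example documents
docs = [
    "Tesla stock rallies after record quarterly deliveries.",
    "New espresso machines are selling well this spring.",
    "Analysts debate Tesla margins and the outlook for EV demand.",
]

vectorizer = TfidfVectorizer()
tfidf_matrix = vectorizer.fit_transform(docs)  # one TF-IDF vector per document

query_vec = vectorizer.transform(["What is the market outlook for Tesla?"])
similarities = cosine_similarity(query_vec, tfidf_matrix).flatten()

# Indices of the two most similar documents, best first
for i in similarities.argsort()[-2:][::-1]:
    print(f"{similarities[i]:.2f}  {docs[i]}")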
Augmenting the Query
Next, we’ll augment the user’s query with the retrieved information to provide the LLM with a richer context for generating responses.
1. Augmenting the User's Query
    # Uses AutoTokenizer from transformers
    @staticmethod
    def augment_query_with_documents(query, documents, stock_info, financials, max_length=MAX_LENGTH):
        # Build a prompt from the retrieved articles and the stock info
        context_docs = Augmenter.generate_context_docs(documents)
        stock_info_text = Augmenter.stock_info_to_text(stock_info)
        augmented_query = (
            f"{context_docs}\n{stock_info_text}\nAnswer the following question based on the above context:\n{query}\nAnswer:"
        )
        # Truncate the prompt if it exceeds the model's maximum input length
        tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_MODEL)
        inputs = tokenizer(augmented_query, return_tensors="pt")
        if inputs["input_ids"].size(1) > max_length:
            truncated_inputs = tokenizer.encode_plus(
                augmented_query, max_length=max_length, truncation=True, return_tensors="pt"
            )
            augmented_query = tokenizer.decode(
                truncated_inputs["input_ids"][0],
                skip_special_tokens=True
            )
        return augmented_query
In this code, the AutoTokenizer is used to prepare the augmented query for the LLM, ensuring it fits within the model's input constraints.
Generating Responses
Finally, we generate responses using a pre-trained LLM. This step involves using the augmented query to produce informative answers.
1. Initializing the Generator
Load and set up the pre-trained model:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


class Generator:
    def __init__(self, model_name=GENERATOR_MODEL):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        # Prefer Apple's Metal Performance Shaders (MPS) backend when available
        if torch.backends.mps.is_available():
            print("Using MPS")
            self.device = torch.device("mps")
            self.model.to(self.device)
        else:
            print("Using CPU")
            self.device = torch.device("cpu")
We’re utilizing MPS (Metal Performance Shaders), Apple’s framework for GPU-accelerated computation on Apple silicon (like the M1 chip in my Mac), which makes the model run noticeably faster than on the CPU alone.
2. Generating the Response
Use the augmented query to produce a response:
    def generate_response(self, augmented_query, max_new_tokens=MAX_NEW_TOKENS):
        # Tokenize the augmented query and generate a completion on the selected device
        inputs = self.tokenizer(augmented_query, return_tensors="pt").to(self.device)
        outputs = self.model.generate(
            **inputs,
            pad_token_id=self.tokenizer.eos_token_id,
            max_new_tokens=max_new_tokens,
            num_return_sequences=NUM_RETURN_SEQUENCES
        )
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Keep only the text after the "Answer: " marker in the prompt
        return response.split("Answer: ")[-1]
In this function, AutoModelForCausalLM is used to generate text. This class provides an interface for causal language models, which are particularly suited for tasks like text completion and response generation.
3. Integrating Components
Let’s combine all components into a chat engine. The chat engine can also be called an orchestrator because it coordinates the various components (retriever, augmenter, generator) to produce the final response.
class ChatEngine:
    def __init__(self, retriever, augmenter, generator, symbol="TSLA"):
        self.retriever = retriever
        self.augmenter = augmenter
        self.generator = generator
        self.symbol = symbol
        self.initialize_data()

    def initialize_data(self):
        # Fetch the stock data once and precompute the TF-IDF index over the news articles
        stock_info, news, financials = self.retriever.get_stock_data(self.symbol)
        vectorizer, tfidf_matrix, documents = self.retriever.preprocess_documents(news)
        self.vectorizer = vectorizer
        self.tfidf_matrix = tfidf_matrix
        self.documents = documents
        self.financials = financials
        self.stock_info = stock_info

    def chat(self, query):
        # Retrieve, augment, and generate: the full RAG loop for a single query
        retrieved_docs = self.retriever.retrieve_documents(
            query,
            self.vectorizer,
            self.tfidf_matrix,
            self.documents,
            k=NUM_DOCS
        )
        augmented_query = self.augmenter.augment_query_with_documents(
            query,
            retrieved_docs,
            self.stock_info,
            self.financials
        )
        response = self.generator.generate_response(augmented_query)
        return response

    def run_chat(self):
        print("Chatbot is ready! Type your questions below (type 'exit' to quit):")
        while True:
            user_input = input("You: ")
            if user_input.lower() in ["exit", "quit"]:
                break
            bot_response = self.chat(user_input)
            print(f"Bot: {bot_response}")


if __name__ == "__main__":
    retriever = Retriever()
    augmenter = Augmenter()
    generator = Generator()
    chat_engine = ChatEngine(retriever, augmenter, generator)
    chat_engine.run_chat()
Now we can run it and talk with our chatbot:
🐦 Tweet-Based Retrieval and Generation
In the second example, we focus on retrieving relevant tweets from a collection and generating responses using an LLM. It uses vector stores and embeddings to enhance the retrieval process.
In this example, we'll take a more sophisticated approach compared to the previous one, leveraging third-party libraries to streamline the process and offload many of the tasks to them.
Setting Up the Embeddings and Vector Store
We use sentence-transformers to create tweet embeddings and FAISS, Meta's similarity-search library, as an in-memory vector store for efficient storage and retrieval.
1. Creating Embeddings
Convert tweets into embeddings for similarity search:
import os

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# In newer LangChain releases these two classes live in the langchain_huggingface package
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline


class ChatEngine:
    def __init__(self):
        # Prefer Apple's MPS backend when available, otherwise fall back to the CPU
        if torch.backends.mps.is_available():
            print("Using MPS")
            self.device = torch.device("mps")
        else:
            print("Using CPU")
            self.device = torch.device("cpu")
        # Sentence-transformers model used to embed the tweets
        self.embeddings = HuggingFaceEmbeddings(
            model_name=EMBEDDING_MODEL_NAME,
            model_kwargs={"device": str(self.device)}
        )
        # Text-generation pipeline wrapped so LangChain can use it as an LLM
        self.tokenizer = AutoTokenizer.from_pretrained(GENERATOR_MODEL)
        self.model = AutoModelForCausalLM.from_pretrained(GENERATOR_MODEL).to(self.device)
        self.pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            max_new_tokens=MAX_NEW_TOKENS,
            return_full_text=True,
            num_return_sequences=NUM_RETURN_SEQUENCES,
            device=self.device,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        self.llm = HuggingFacePipeline(pipeline=self.pipe)
        # Build the FAISS index on the first run
        if not os.path.exists(VECTOR_STORE_PATH):
            self.create_vector_store_from_text()
Here, HuggingFaceEmbeddings is used to generate embeddings from tweets, utilizing the sentence-transformers library for efficient encoding. This model balances performance and speed, making it suitable for real-time applications. Additionally, the language model for text generation is initialized using AutoTokenizer and AutoModelForCausalLM.
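If you’re curious what these embeddings look like on their own, here’s a small standalone sketch. The model name is an assumption (a common sentence-transformers choice) and may differ from whatever EMBEDDING_MODEL_NAME points to:

import numpy as np
from langchain_community.embeddings import HuggingFaceEmbeddings

# Assumed model for illustration; EMBEDDING_MODEL_NAME may point to a different one
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

tweet = "Shipped a new RAG demo today, retrieval quality is everything."
query = "How important is retrieval quality in RAG?"

tweet_vec = np.array(embeddings.embed_query(tweet))  # dense vector (384 dims for this model)
query_vec = np.array(embeddings.embed_query(query))

# Cosine similarity between the two embeddings: higher means more semantically related
score = np.dot(tweet_vec, query_vec) / (np.linalg.norm(tweet_vec) * np.linalg.norm(query_vec))
print(f"similarity: {score:.2f}")

Unlike TF-IDF, these embeddings capture semantic similarity, so a tweet can match a query even when the two share few exact words.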
2. Building the Vector Store
Store tweet embeddings using FAISS:
    # DirectoryLoader, TextLoader, RecursiveCharacterTextSplitter, and FAISS come from the LangChain packages
    def create_vector_store_from_text(self):
        # Load the tweet files, split them into chunks, embed them, and persist a FAISS index
        loader = DirectoryLoader(
            DOCUMENT_PATH,
            glob=DOCUMENT_REGEXP,
            loader_cls=TextLoader
        )
        documents = loader.load()
        splitter = RecursiveCharacterTextSplitter()
        texts = splitter.split_documents(documents)
        vector_store = FAISS.from_documents(
            documents=texts,
            embedding=self.embeddings
        )
        vector_store.save_local(VECTOR_STORE_PATH)
Tweet Loader and Chat Engine
The chat engine retrieves relevant tweet snippets and uses an LLM to generate responses.
1. Loading Tweets
Load and split tweets for embedding:
    def __init__(self):
        if not os.path.exists(VECTOR_STORE_PATH):
            self.create_vector_store_from_text()
This initialization step ensures that the vector store is created if it doesn't already exist.
2. Generating Responses
For this example we will use Meta’s latest LLM, 🦙 Llama 3.1. It is a beast, particularly the 405B model, which stands out not just for its scale but for how it closes the gap between open-source and closed-source models.
Use the retrieval QA chain to produce contextually relevant answers:
    # hub comes from langchain.hub and RetrievalQA from langchain.chains
    def chat(self, query):
        # Load the FAISS index and wire it into a RetrievalQA chain with a standard RAG prompt
        vector_store = FAISS.load_local(
            VECTOR_STORE_PATH,
            self.embeddings,
            allow_dangerous_deserialization=True
        )
        retriever = vector_store.as_retriever()
        prompt = hub.pull("rlm/rag-prompt")
        chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=False,
            chain_type_kwargs={"prompt": prompt},
            verbose=False
        )
        result = chain.invoke({"query": query})
        return result["result"].split("Answer: ")[-1]

    def run_chat(self):
        print("Chatbot is ready! Type your questions below (type 'exit' to quit):")
        while True:
            query = input("You: ")
            if query.lower() in ["exit", "quit"]:
                print("Goodbye!")
                break
            answer = self.chat(query)
            print(f"Bot: {answer}")
In this setup, the RetrievalQA chain is used to process queries, retrieve relevant tweets, and generate responses with the LLM; the chain_type="stuff" option simply "stuffs" all retrieved documents into a single prompt. We chose this approach to leverage the strengths of pre-existing components and simplify the development process.
Now, run this code and interact with your chatbot:
You can find the full source code on my GitHub repository.
Data source
Additional materials
- A Simple Guide to Retrieval Augmented Generation by Abhinav Kimothi
- Meta's Llama models
- "Searching for Best Practices in Retrieval-Augmented Generation" Paper