From RAGs to Riches: Hands-on Retrieval-Augmented Generation

In our last post, we explored the basics of Retrieval-Augmented Generation (RAG), a method that combines the retrieval of relevant external data with the generative power of large language models (LLMs). Grounding the model in retrieved data helps AI systems produce responses that are more accurate, contextually rich, and up to date. Today, we’ll dive into two practical implementations of RAG, showing how to integrate external data sources to enhance your AI capabilities.

RAG Architecture

At its core, RAG enriches a language model’s input with relevant, up-to-date information from external sources, giving the model the data it needs to generate more informed and precise answers. Here’s a breakdown of the RAG architecture:

1. User Query Submission

The process begins when a user submits a query. This could be anything from a straightforward question to a complex problem that requires detailed context. The query is sent to an orchestrator, which acts as the system’s command center, managing the flow of information throughout the RAG process.

2. Retrieval Phase

Orchestrator to Retriever

The orchestrator sends the query to the retriever, whose job is to gather relevant information that will enhance the language model’s response.

Gathering Relevant Information

The retriever accesses a knowledge source — whether it’s a database, an API, a document repository, or another data source rich in factual information. It processes the query and returns the most relevant documents, snippets, or data points. The quality of this retrieved information is crucial, as it directly influences the relevance and accuracy of the final output.

3. Augmentation Phase

Orchestrator Augments the Query

With the relevant information in hand, the orchestrator enriches the original query with this new context. This step amplifies the input given to the language model, preparing it to generate a more detailed and contextually appropriate response.

Forming the Augmented Query

The augmented query is essentially an enhanced version of the original query. It’s the user’s question, but now supplemented with additional, retrieved information, providing a rich prompt for the language model to work from.

4. Generation Phase

Passing to the Language Model

The orchestrator forwards this augmented query to the LLM. With the added context, the model is primed to process the query more effectively.

Generating the Response

Using its natural language understanding and generation capabilities, the language model generates a response. Because the prompt now contains the retrieved context, the answer is more likely to be accurate and tailored to the user’s question.

5. Response Delivery

Finally, the orchestrator delivers the generated response back to the user, completing a more informed and context-aware interaction.
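To make the five steps concrete, here is a minimal, framework-free sketch of the flow. It is only an illustration, not the implementation used later in this post: the word-overlap `retrieve` function, the prompt layout in `augment`, and the `echo_llm` stand-in for a real model are all hypothetical, chosen so the example runs without any model or API.

```python
import re
from typing import Callable, List


def tokenize(text: str) -> set:
    """Lowercase the text and keep only its words (runs of word characters)."""
    return set(re.findall(r"\w+", text.lower()))


def retrieve(query: str, corpus: List[str], k: int = 2) -> List[str]:
    """Retrieval phase (toy version): rank documents by how many query words they share."""
    query_words = tokenize(query)
    ranked = sorted(corpus, key=lambda doc: len(query_words & tokenize(doc)), reverse=True)
    return ranked[:k]


def augment(query: str, docs: List[str]) -> str:
    """Augmentation phase: put the retrieved context in front of the user's question."""
    context = "\n".join(f"- {doc}" for doc in docs)
    return f"Context:\n{context}\n\nQuestion: {query}\nAnswer:"


def orchestrate(query: str, corpus: List[str], llm: Callable[[str], str], k: int = 2) -> str:
    """Orchestrator: receive the query, retrieve, augment, generate, and deliver the response."""
    docs = retrieve(query, corpus, k)
    augmented_query = augment(query, docs)
    return llm(augmented_query)


corpus = [
    "Tesla reported quarterly deliveries above expectations.",
    "The weather in Paris is mild this week.",
    "Analysts raised their Tesla price targets.",
]
echo_llm = lambda prompt: prompt  # stand-in "LLM" that simply returns the prompt it receives
print(orchestrate("What do analysts say about Tesla?", corpus, echo_llm))
```

Swapping `retrieve` for a TF-IDF or embedding search, and `echo_llm` for a real model, gives the shape of the two implementations that follow.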

Practical Implementations of RAG

To put the RAG architecture into perspective, let’s explore two practical implementations.

📈 Combining Stock Data and News Articles

The implementation below is a chatbot that uses stock data and related news articles to give up-to-date answers to user queries. It retrieves and processes external information, augments the user’s query with it, and generates a response using a pre-trained LLM. Imagine a user asks, "What is the current market outlook for Tesla?"

We’ll walk through each component and explain how they fit together to create a powerful RAG system.

Setting Up the Retriever

The retriever’s job is to gather information that will enrich the LLM’s responses. In this case, we’ll pull stock data from Yahoo Finance and scrape related news articles.

1. Fetching Stock Data

To retrieve comprehensive stock information, including financials and news articles, we use the yfinance library. Here's how it works:

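import yfinance as yf

class Retriever:
    @staticmethod
    def get_stock_data(symbol):
        stock = yf.Ticker(symbol)
        stock_info = stock.info  # company profile, prices, and key metrics as a dict
        news = stock.news if hasattr(stock, "news") else []  # recent news items (list of dicts)
        financials = stock.quarterly_financials  # quarterly income statement as a DataFrame
        return stock_info, news, financials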

This code snippet fetches the stock information for a given symbol, along with news articles and financial data.

2. Scraping News Articles

Next, we scrape the content of news articles using BeautifulSoup. This provides additional context that can be used to augment user queries:

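import requests
from bs4 import BeautifulSoup

@staticmethod
def fetch_story_content(link):
    try:
        # A timeout keeps one slow site from blocking the whole pipeline.
        response = requests.get(link, timeout=10)
        response.raise_for_status()
        soup = BeautifulSoup(response.content, "html.parser")
        # Keep only paragraph text; navigation and footers mostly live outside <p> tags.
        paragraphs = soup.find_all("p")
        content = " ".join(para.get_text() for para in paragraphs)
        return content
    except requests.RequestException:
        # Network errors, HTTP errors, and missing or invalid URLs all yield empty content.
        return ""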

This function retrieves the text from the paragraphs of an article given its URL.
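Because `fetch_story_content` needs a live URL, it is awkward to try in isolation. The sketch below, assuming `beautifulsoup4` is installed, runs just the parsing step on an inline HTML string; the `extract_paragraph_text` helper and the sample page are hypothetical. Unlike the original, it passes a separator and `strip=True` to `get_text`, which keeps words apart across nested tags and trims stray whitespace.

```python
from bs4 import BeautifulSoup


def extract_paragraph_text(html: str) -> str:
    """Join the text of every <p> element, as fetch_story_content does after downloading a page."""
    soup = BeautifulSoup(html, "html.parser")
    # " " as separator keeps words apart across nested tags; strip=True trims stray whitespace.
    return " ".join(p.get_text(" ", strip=True) for p in soup.find_all("p"))


sample_html = """
<html><body>
  <nav>Home | Markets</nav>
  <p>Tesla shares rose 3% on Monday.</p>
  <p>Analysts cited strong <b>delivery</b> numbers.</p>
  <footer>Copyright 2024</footer>
</body></html>
"""
print(extract_paragraph_text(sample_html))
```

Keeping only `<p>` elements drops the navigation and footer text; the flip side is that text outside `<p>` tags, such as headlines or bullet lists, is dropped too.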

Preprocessing and Document Retrieval

To find the articles most relevant to a query, we vectorize them with TF-IDF and compare them to the query using cosine similarity.

1. Preprocessing Documents

Convert documents into TF-IDF vectors for similarity computation:

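from sklearn.feature_extraction.text import TfidfVectorizer

@staticmethod
def preprocess_documents(documents):
    # Scrape each news item's full text; failed downloads come back as empty strings.
    texts = [Retriever.fetch_story_content(doc.get("link", "")) for doc in documents]
    # Learn the vocabulary and IDF weights from the articles, then encode each one.
    # Note: fit_transform raises ValueError if every text is empty (no vocabulary).
    vectorizer = TfidfVectorizer()
    tfidf_matrix = vectorizer.fit_transform(texts)
    return vectorizer, tfidf_matrix, documents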

This function creates TF-IDF vectors from the scraped article content, preparing them for similarity analysis. TF-IDF gives more weight to words that are frequent in a document but rare across the collection, which reduces the influence of common words that appear everywhere.

2. Retrieving Relevant Documents

Retrieve the documents most similar to the query using cosine similarity, so that the augmented query carries the most relevant context:

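from sklearn.metrics.pairwise import cosine_similarity

@staticmethod
def retrieve_documents(query, vectorizer, tfidf_matrix, documents, k=NUM_DOCS):
    # Encode the query with the vocabulary learned from the articles.
    query_vec = vectorizer.transform([query])
    similarities = cosine_similarity(query_vec, tfidf_matrix).flatten()
    # argsort is ascending: take the last k indices and reverse them so the best match comes first.
    related_docs_indices = similarities.argsort()[-k:][::-1]
    return [documents[i] for i in related_docs_indices]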

This function returns the top k documents most similar to the query, which give the best contextual fit for augmentation. Cosine similarity is the cosine of the angle between two vectors: it is 1 when they point in the same direction and, for non-negative TF-IDF vectors, 0 when they share no terms.
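The sketch below, on three made-up headlines, checks that `cosine_similarity` matches the textbook formula (dot product divided by the product of the vector norms) and shows how `argsort()[-k:][::-1]` picks the top k. It assumes scikit-learn and NumPy, as in the code above.

```python
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

headlines = [
    "tesla deliveries beat expectations",
    "oil prices fall on weak demand",
    "tesla shares jump on strong deliveries",
]
vectorizer = TfidfVectorizer()
tfidf_matrix = vectorizer.fit_transform(headlines)
query_vec = vectorizer.transform(["tesla deliveries"])

# Textbook cosine similarity: dot product divided by the product of the two vector norms.
docs = tfidf_matrix.toarray()
query = query_vec.toarray()[0]
manual = docs @ query / (np.linalg.norm(docs, axis=1) * np.linalg.norm(query))

# The scikit-learn helper used in retrieve_documents.
library = cosine_similarity(query_vec, tfidf_matrix).flatten()

# argsort is ascending, so take the last k indices and reverse them: best match first.
k = 2
top_k = library.argsort()[-k:][::-1]
print(np.round(library, 3), top_k)
```

Both Tesla headlines contain both query words, but the shorter one ranks first because the query terms make up a larger share of its normalized vector; the oil headline shares no terms and scores 0.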

Augmenting the Query

Next, we’ll augment the user’s query with the retrieved information to provide the LLM with a richer context for generating responses.

1. Augmenting the User's Query

from transformers import AutoTokenizer

@staticmethod
def augment_query_with_documents(query, documents, stock_info, financials, max_length=MAX_LENGTH):
    context_docs = Augmenter.generate_context_docs(documents)
    stock_info_text = Augmenter.stock_info_to_text(stock_info)
    augmented_query = (
        f"{context_docs}\n{stock_info_text}\n"
        f"Answer the following question based on the above context:\n{query}\nAnswer:"
    )
    # Loading the tokenizer on every call is slow; in practice, load it once and reuse it.
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_MODEL)
    inputs = tokenizer(augmented_query, return_tensors="pt")
    if inputs["input_ids"].size(1) > max_length:
        # Truncate from the left: the oldest context is dropped, while the question
        # and the trailing "Answer:" cue at the end of the prompt are kept.
        tokenizer.truncation_side = "left"
        truncated_inputs = tokenizer(
            augmented_query, max_length=max_length, truncation=True, return_tensors="pt"
        )
        augmented_query = tokenizer.decode(
            truncated_inputs["input_ids"][0],
            skip_special_tokens=True
        )

    return augmented_query

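In this code, AutoTokenizer counts the tokens in the augmented query; if the count exceeds max_length, the query is truncated so it fits within the model’s input limit. Note that the financials argument is accepted but not used in this snippet.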
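The two Augmenter helpers called above, `generate_context_docs` and `stock_info_to_text`, are not shown in this post. Below is one hypothetical way they could look, assuming news items carry `title` and `publisher` keys and using a few common keys from the yfinance `info` dictionary. It also includes a hypothetical `build_prompt` that trims the context rather than the question when the prompt is too long; it counts whitespace-separated words as a stand-in for real model tokens, so treat it as an illustration of the idea only.

```python
from typing import Dict, List, Sequence


def generate_context_docs(documents: List[Dict]) -> str:
    """Hypothetical: one line per retrieved news item, built from its title and publisher."""
    lines = [
        f"- {doc.get('title', 'Untitled')} ({doc.get('publisher', 'unknown source')})"
        for doc in documents
    ]
    return "News:\n" + "\n".join(lines)


def stock_info_to_text(
    stock_info: Dict, keys: Sequence[str] = ("shortName", "currentPrice", "marketCap")
) -> str:
    """Hypothetical: render a few selected fields of the yfinance info dict, skipping missing ones."""
    parts = [f"{key}: {stock_info[key]}" for key in keys if key in stock_info]
    return "Stock info:\n" + "\n".join(parts)


def build_prompt(context: str, query: str, max_words: int) -> str:
    """Hypothetical: drop the oldest context words so the question always fits.

    Whitespace-separated words stand in for model tokens, and the kept context
    is re-joined with single spaces, so its original line breaks are not preserved.
    """
    question = f"Answer the following question based on the above context:\n{query}\nAnswer:"
    budget = max(max_words - len(question.split()), 0)
    context_words = context.split()
    if len(context_words) > budget:
        context_words = context_words[len(context_words) - budget:]
    return " ".join(context_words) + "\n" + question


news = [{"title": "Tesla beats estimates", "publisher": "Reuters"}, {"title": "EV demand cools"}]
info = {"shortName": "Tesla, Inc.", "currentPrice": 250.0}
context = generate_context_docs(news) + "\n" + stock_info_to_text(info)
print(build_prompt(context, "What is the outlook for Tesla?", max_words=25))
```

Cutting from the front of the context, rather than from the end of the whole prompt, keeps the question and the trailing `Answer:` cue in place even when the context has to shrink.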

Generating Responses

Finally, we generate responses using a pre-trained LLM. This step involves using the augmented query to produce informative answers.

1. Initializing the Generator

Load and set up the pre-trained model:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class Generator:
    def __init__(self, model_name=GENERATOR_MODEL):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        if torch.backends.mps.is_available():
            print("Using MPS")
            self.device = torch.device("mps")
            self.model.to(self.device)
        else:
            # The model is loaded on the CPU by default, so no move is needed here.
            print("Using CPU")
            self.device = torch.device("cpu")

We use PyTorch’s MPS (Metal Performance Shaders) backend to run the model on the GPU of Apple devices, which is typically much faster than using the CPU alone. MPS is Apple’s GPU compute framework; it is especially useful on Macs with Apple silicon such as the M1 (like my Mac). If MPS is not available, the code falls back to the CPU.
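Device selection generalizes naturally. This sketch is a variation rather than the code above: it also checks for an NVIDIA GPU via CUDA before trying MPS, and finally falls back to the CPU. The `pick_device` name is hypothetical. It then runs a small matrix multiplication on whichever device was chosen.

```python
import torch


def pick_device() -> torch.device:
    """Prefer an NVIDIA GPU (CUDA), then Apple's MPS backend, then the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = pick_device()
x = torch.ones(2, 2, device=device)
# Each entry of x @ x is 1*1 + 1*1 = 2, so the four entries sum to 8.
print(device, (x @ x).sum().item())
```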

2. Generating the Response

Use the augmented query to produce a response:

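def generate_response(self, augmented_query, max_new_tokens=MAX_NEW_TOKENS):
    inputs = self.tokenizer(augmented_query, return_tensors="pt").to(self.device)
    outputs = self.model.generate(
        **inputs,
        pad_token_id=self.tokenizer.eos_token_id,
        max_new_tokens=max_new_tokens,
        num_return_sequences=NUM_RETURN_SEQUENCES
    )
    # generate() returns the prompt tokens followed by the new ones; decode only the
    # continuation so the answer does not depend on the model emitting "Answer: " verbatim.
    prompt_length = inputs["input_ids"].shape[1]
    response = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    return response.strip()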

Here, AutoModelForCausalLM loads a causal (decoder-only) language model, which predicts each token from the tokens before it. Its generate method extends the prompt one token at a time, up to max_new_tokens, which makes it well suited to text completion and response generation.
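To build intuition for what `generate` does under greedy decoding (used when sampling and beam search are both off), here is a toy sketch. The "model" is a hypothetical lookup table that scores the next token given only the previous one, whereas a real causal LM conditions on the entire prefix; the loop structure is the same, though: pick the highest-scoring next token, append it, and stop at an end-of-sequence token or after `max_new_tokens` steps.

```python
from typing import Dict, List

# Toy "causal LM": for each token, scores for the possible next tokens (hypothetical numbers).
NEXT_TOKEN_SCORES: Dict[str, Dict[str, float]] = {
    "Answer:": {"Tesla": 0.6, "The": 0.4},
    "Tesla": {"looks": 0.7, "is": 0.3},
    "looks": {"strong": 0.8, "weak": 0.2},
    "strong": {"<eos>": 0.9, "today": 0.1},
}


def greedy_generate(prompt_tokens: List[str], max_new_tokens: int) -> List[str]:
    """Repeatedly append the highest-scoring next token, then return only the new tokens."""
    tokens = list(prompt_tokens)
    for _ in range(max_new_tokens):
        scores = NEXT_TOKEN_SCORES.get(tokens[-1])
        if not scores:  # the toy model has nothing to say after this token
            break
        next_token = max(scores, key=scores.get)
        if next_token == "<eos>":  # end-of-sequence token: stop generating
            break
        tokens.append(next_token)
    return tokens[len(prompt_tokens):]


print(greedy_generate(["What", "is", "the", "outlook?", "Answer:"], max_new_tokens=10))
```

Returning only the tokens produced after the prompt, rather than searching the decoded text for a marker, keeps answer extraction independent of how the model formats its output.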

3. Integrating Components

Let’s combine all components into a chat engine. The chat engine can also be called an orchestrator because it coordinates the various components (retriever, augmenter, generator) to produce the final response.

class ChatEngine:
    def __init__(self, retriever, augmenter, generator, symbol="TSLA"):
        self.retriever = retriever
        self.augmenter = augmenter
        self.generator = generator
        self.symbol = symbol
        self.initialize_data()

    def initialize_data(self):
        stock_info, news, financials = self.retriever.get_stock_data(self.symbol)
        vectorizer, tfidf_matrix, documents = self.retriever.preprocess_documents(news)
        self.vectorizer = vectorizer
        self.tfidf_matrix = tfidf_matrix
        self.documents = documents
        self.financials = financials
        self.stock_info = stock_info

    def chat(self, query):
        retrieved_docs = self.retriever.retrieve_documents(
            query, 
            self.vectorizer, 
            self.tfidf_matrix, 
            self.documents, 
            k=NUM_DOCS
        )
        augmented_query = self.augmenter.augment_query_with_documents(
            query, 
            retrieved_docs, 
            self.stock_info, 
            self.financials
        )
        response = self.generator.generate_response(augmented_query)
        return response

    def run_chat(self):
        print("Chatbot is ready! Type your questions below (type 'exit' to quit):")
        while True:
            user_input = input("You: ")
            if user_input.lower() in ["exit", "quit"]:
                break
            bot_response = self.chat(user_input)
            print(f"Bot: {bot_response}")

if __name__ == "__main__":
    retriever = Retriever()
    augmenter = Augmenter()
    generator = Generator()

    chat_engine = ChatEngine(retriever, augmenter, generator)
    chat_engine.run_chat()

Now we can run it and talk with our chatbot:

[Screenshot: sample chat session]

🐦 Tweet-Based Retrieval and Generation

In the second example, we retrieve relevant tweets from a collection and generate responses with an LLM, using embeddings and a vector store to improve retrieval.

This time, we take a more sophisticated approach than in the previous example and rely on third-party libraries to handle much of the pipeline for us.

Setting Up the Embeddings and Vector Store

We use sentence-transformers to create tweet embeddings and FAISS, Meta’s open-source library for efficient similarity search, to store the vectors in memory and retrieve the nearest ones.
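Conceptually, a vector store keeps one embedding per text chunk and, at query time, returns the chunks whose vectors lie closest to the query’s vector. The NumPy sketch below does this with an exact, brute-force search by Euclidean (L2) distance, the same kind of search a flat FAISS index performs, only without FAISS’s optimized kernels. The `TinyVectorStore` class and the 2-D "embeddings" are hypothetical; real sentence-transformers embeddings have hundreds of dimensions.

```python
from typing import List, Sequence, Tuple

import numpy as np


class TinyVectorStore:
    """Keeps (text, vector) pairs and answers exact nearest-neighbour queries by L2 distance."""

    def __init__(self, dim: int):
        self.vectors = np.empty((0, dim), dtype=np.float32)
        self.texts: List[str] = []

    def add(self, texts: Sequence[str], vectors) -> None:
        """Append new texts and their embeddings (one row per text)."""
        self.vectors = np.vstack([self.vectors, np.asarray(vectors, dtype=np.float32)])
        self.texts.extend(texts)

    def search(self, query_vector, k: int = 2) -> List[Tuple[str, float]]:
        """Return the k stored texts closest to the query, nearest first, with their distances."""
        query = np.asarray(query_vector, dtype=np.float32)
        distances = np.linalg.norm(self.vectors - query, axis=1)
        nearest = np.argsort(distances)[:k]
        return [(self.texts[i], float(distances[i])) for i in nearest]


store = TinyVectorStore(dim=2)
store.add(
    ["Launching rockets today", "Electric cars are the future", "Rockets and cars, both fun"],
    [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]],
)
for text, distance in store.search([0.9, 0.1], k=2):
    print(f"{distance:.3f}  {text}")
```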

1. Creating Embeddings

Convert tweets into embeddings for similarity search:

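import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Module paths as in langchain-community 0.2; newer releases also ship these
# classes in the separate langchain-huggingface package.
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline

class ChatEngine: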
    def __init__(self):
        if torch.backends.mps.is_available():
            print("Using MPS")
            self.device = torch.device("mps")
        else:
            print("Using CPU")
            self.device = torch.device("cpu")

        self.embeddings = HuggingFaceEmbeddings(
            model_name=EMBEDDING_MODEL_NAME,
            model_kwargs={"device": str(self.device)}
        )

        self.tokenizer = AutoTokenizer.from_pretrained(GENERATOR_MODEL)
        self.model = AutoModelForCausalLM.from_pretrained(GENERATOR_MODEL).to(self.device)
        self.pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            max_new_tokens=MAX_NEW_TOKENS,
            return_full_text=True,
            num_return_sequences=NUM_RETURN_SEQUENCES,
            device=self.device,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        self.llm = HuggingFacePipeline(pipeline=self.pipe)
        if not os.path.exists(VECTOR_STORE_PATH):
            self.create_vector_store_from_text()

Here, HuggingFaceEmbeddings wraps a sentence-transformers model (EMBEDDING_MODEL_NAME) that encodes each tweet chunk into a dense vector; smaller models of this kind encode quickly enough for interactive use, while larger ones trade speed for retrieval quality. The generator is initialized with AutoTokenizer and AutoModelForCausalLM, wrapped in a transformers text-generation pipeline, and exposed to LangChain through HuggingFacePipeline.

2. Building the Vector Store

Store tweet embeddings using FAISS:

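from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter

def create_vector_store_from_text(self):
    # Load every file under DOCUMENT_PATH that matches the DOCUMENT_REGEXP glob pattern as plain text.
    loader = DirectoryLoader(
        DOCUMENT_PATH, 
        glob=DOCUMENT_REGEXP, 
        loader_cls=TextLoader
    )
    documents = loader.load()
    
    # Split long documents into overlapping chunks (LangChain's defaults, since no arguments are given).
    splitter = RecursiveCharacterTextSplitter()
    texts = splitter.split_documents(documents)
    
    # Embed every chunk, index the vectors with FAISS, and persist the index to disk.
    vector_store = FAISS.from_documents(
        documents=texts,
        embedding=self.embeddings
    )
    vector_store.save_local(VECTOR_STORE_PATH)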
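RecursiveCharacterTextSplitter tries coarse separators first (paragraph breaks, then line breaks, then spaces, then single characters) and falls back to finer ones only for pieces that are still too long; called with no arguments, as here, it uses LangChain’s defaults of 4,000-character chunks with a 200-character overlap. The sketch below is a simplified, hypothetical re-creation of the idea (no overlap, length measured in characters) to show why chunks tend to end at natural boundaries; it is not LangChain’s actual implementation.

```python
from typing import List

SEPARATORS = ["\n\n", "\n", " ", ""]


def recursive_split(text: str, chunk_size: int, separators: List[str] = SEPARATORS) -> List[str]:
    """Split on the coarsest separator, recurse with finer ones for oversized pieces,
    and greedily merge neighbouring pieces back together up to chunk_size characters."""
    if chunk_size < 1:
        raise ValueError("chunk_size must be positive")
    if len(text) <= chunk_size:
        return [text] if text else []
    sep, *finer = separators
    pieces = text.split(sep) if sep else list(text)
    chunks: List[str] = []
    current = ""
    for piece in pieces:
        if len(piece) > chunk_size:
            # Flush the chunk in progress, then break the oversized piece with finer separators.
            if current:
                chunks.append(current)
                current = ""
            chunks.extend(recursive_split(piece, chunk_size, finer))
            continue
        candidate = f"{current}{sep}{piece}" if current else piece
        if len(candidate) <= chunk_size:
            current = candidate  # the piece still fits: grow the current chunk
        else:
            chunks.append(current)  # full: start a new chunk with this piece
            current = piece
    if current:
        chunks.append(current)
    return chunks


text = "Tesla stock rose.\n\nDeliveries beat estimates today.\n\nAnalysts are upbeat."
print(recursive_split(text, chunk_size=40))
print(recursive_split(text, chunk_size=20))
```

With room to spare, each paragraph becomes its own chunk; with a tighter limit, only the paragraph that does not fit is split further, at a space.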

Tweet Loader and Chat Engine

The chat engine retrieves relevant tweet snippets and uses an LLM to generate responses.

1. Loading Tweets

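Loading and splitting the tweets happens in create_vector_store_from_text, so the constructor only needs to call it when no saved vector store exists yet: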

def __init__(self):
    if not os.path.exists(VECTOR_STORE_PATH):
        self.create_vector_store_from_text()

This initialization step ensures that the vector store is created if it doesn't already exist.

2. Generating Responses

For this example, we will use 🦙 Llama 3.1, Meta’s latest LLM at the time of writing. It is a beast, particularly the 405B model, which stands out not just for its scale but for how far it narrows the gap between open-source and closed-source models.

Use the retrieval QA chain to produce contextually relevant answers:

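from langchain import hub
from langchain.chains import RetrievalQA
from langchain_community.vectorstores import FAISS

def chat(self, query):
    # The saved store includes a pickle file, hence allow_dangerous_deserialization;
    # only load vector stores you created yourself. Reloading the store and re-pulling
    # the prompt on every query is simple but slow; both could be done once in __init__.
    vector_store = FAISS.load_local(
        VECTOR_STORE_PATH, 
        self.embeddings, 
        allow_dangerous_deserialization=True
    )
    retriever = vector_store.as_retriever()
    prompt = hub.pull("rlm/rag-prompt")  # RAG prompt template from the LangChain Hub
    chain = RetrievalQA.from_chain_type(
        llm=self.llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=False,
        chain_type_kwargs={"prompt": prompt},
        verbose=False
    )
    result = chain.invoke({"query": query})
    # If the output still contains the prompt, keep only what follows the last "Answer: ".
    return result["result"].split("Answer: ")[-1]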

def run_chat(self):
    print("Chatbot is ready! Type your questions below (type 'exit' to quit):")
    while True:
        query = input("You: ")
        if query.lower() in ["exit", "quit"]:
            print("Goodbye!")
            break
        answer = self.chat(query)
        print(f"Bot: {answer}")

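In this setup, the RetrievalQA chain handles the whole RAG loop: the vector store retriever returns the tweet chunks closest to the query (four by default), the "stuff" chain type inserts all of them into the rlm/rag-prompt template pulled from the LangChain Hub, and the LLM generates the answer. Delegating these steps to well-tested libraries keeps our own code short and simplifies development.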
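The "stuff" strategy is simple enough to sketch by hand: every retrieved chunk is concatenated into one context string and formatted into a single prompt. The template below is a hypothetical stand-in, not the exact text of the rlm/rag-prompt hub prompt, and `extract_answer` is a hypothetical, slightly more forgiving variant of the `split("Answer: ")` post-processing used in `chat`: it splits on the marker without the trailing space and strips whitespace, so it also works when the model starts its answer on a new line.

```python
from typing import List

# Hypothetical template in the spirit of a RAG prompt; not the exact rlm/rag-prompt text.
TEMPLATE = (
    "Use the following context to answer the question.\n"
    "Question: {question}\nContext: {context}\nAnswer:"
)


def stuff_documents(question: str, docs: List[str], separator: str = "\n\n") -> str:
    """'Stuff' strategy: concatenate every retrieved chunk into one context and fill the template."""
    return TEMPLATE.format(question=question, context=separator.join(docs))


def extract_answer(generated: str, marker: str = "Answer:") -> str:
    """Keep only the text after the last answer marker, tolerating a space or a newline after it."""
    return generated.rsplit(marker, 1)[-1].strip()


prompt = stuff_documents("Why build rockets?", ["Tweet about Mars.", "Tweet about launches."])
print(prompt)
print(extract_answer(prompt + "\nTo reach Mars."))
```

Because every chunk lands in one prompt, "stuff" works as long as the retrieved text fits in the model’s context window; RetrievalQA also accepts other chain types, such as "map_reduce" and "refine", for longer inputs.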

Now, run this code and interact with your chatbot:

[Screenshot: sample chat session]

You can find the full source code on my GitHub repository. Data source
