Building a lean and robust model serving layer for vector embedding and search with Hugging Face’s new Candle Framework
Photo by Clay Banks on Unsplash
1. Intro
The incredible progress in AI research and tooling over the past decade has given rise to more accurate and more reliable machine learning models, as well as to libraries and frameworks that make it increasingly simple to integrate AI capabilities into existing applications.
However, one area that remains a considerable challenge in demanding production environments is inference at scale. Suppose, for example, that we have a simple search service that receives a few keywords, uses a language model to embed them in a vector, and then searches for similar texts or documents in some vector database. This is quite a popular use case, and also a central part of the RAG (retrieval-augmented generation) architecture, which is commonly used to apply generative AI to domain-specific knowledge and data.
In itself, this seems like a relatively straightforward use case to implement. There are many open-source language models and model hubs that we can use as embedding models in a few lines of code. If we further assume that the number of vectors we need to store and query is relatively moderate (e.g., less than 1M) then there are plenty of simple options for vector storage and search: from plain in-memory storage to databases such as Postgres, Redis, or Elastic.
But what if our service is required to serve thousands or hundreds of thousands of requests per second? What if we need to maintain a relatively low (sub-second) latency on every request? What if we need to scale out fast to serve bursts of requests?
Although our use case is indeed quite simple, the scale and load requirements certainly make it a challenge. Scalable high-throughput systems are usually based on multiple instances of small and performant binaries that can quickly bootstrap, scale, and serve requests. This creates a challenge in the context of machine learning systems, and specifically in deep learning, because the common libraries are usually heavy and clunky, partly because most are implemented in Python, which is not easy to scale in demanding environments.
Therefore, when facing such challenges, we will either choose to use some paid serving platform that will take care of the scale, or we will have to create our own dedicated serving layer using a number of technologies.
In response to these challenges, Hugging Face has introduced the Candle framework, described as “a minimalist ML framework for Rust with a focus on performance…and ease of use”. Candle lets us build robust and lightweight model inference services in Rust, using a torch-like API. Inference services based on Candle can scale easily, bootstrap rapidly, and process requests very quickly, which makes them better suited to cloud-native serverless environments that are designed to deal with challenges of scale and resilience.
The purpose of this post is to take the popular use case described earlier and show how it can be implemented, end to end, using the Candle framework. We will delve into a relatively straightforward but robust implementation of a vector embedding and search REST service based on Candle and Axum (as our web framework). We will use a specific news headlines dataset but the code can be very easily extended to any textual dataset.
This will be a very hands-on and practical post. Section 2 lays out the main design or flow of our service, as well as the involved components we will develop and work with. Section 3 focuses on the Candle framework and shows how to implement a vector embedding and search functionality using a Bert model. Section 4 shows how to wrap the model inference functionality in a REST web service using Axum. Section 5 explains how to create the actual embedding and artifacts that our service will need. Section 6 concludes.
2. High Level Service Design
I will start with the main building blocks of the inference service that we are going to implement. The main requirement is to create an HTTP REST endpoint that will receive a textual query consisting of a few keywords and respond with the top 5 news headlines that are most similar to the search query.
For this task, we will use Bert as a language model because it usually performs well on document embedding tasks. In an offline batch process (that will be explained in Section 5), we will use Bert to embed about 20K news headlines or documents and create a vector embedding for each headline that the service will use to search for matches. Both the embedding and the textual headlines will be serialized to a binary format that each service instance will load in order to serve query requests.
Image by author
As illustrated above, upon receiving a request with a search query, the service will first use a language model to embed the search query in a vector. Next, it will search the pre-loaded embedding for the N most similar vectors (each vector representing a news headline). Finally, it will use the indices of the most similar vectors to fetch the actual textual headline each one represents, using a mapping file.
We will implement this by creating a module, or Rust library, with one struct named BertInferenceModel that will provide the main functions: model loading, sentence embedding, and vector search. The module will be used by a REST service implemented using the Axum web framework. The next section will focus on implementing the module, while the section that follows will be dedicated to the web service itself.
Please note that the next sections include many code examples, though for the sake of clarity they only show the main functionality implemented. For a complete implementation of the solution please refer to the companion git repo linked below.
3. Model Serving and Embedding using Candle
This section will focus on the implementation of the module that will serve as an abstraction layer on top of the Candle library API. We will implement this as a struct named BertInferenceModel that will consist of 3 main functions: model loading, inference (or sentence embedding), and a simple vector search using cosine similarity.
pub struct BertInferenceModel {
    model: BertModel,
    tokenizer: Tokenizer,
    device: Device,
    embeddings: Tensor,
}
BertInferenceModel will encapsulate the Bert model and tokenizer that we download from the Hugging Face repository and will essentially wrap their functionality.
Loading Models from Hugging Face Hub
BertInferenceModel will be instantiated with a load() function that will return a new instance of the struct, loaded with the relevant model and tokenizer, and ready for inference tasks. The load function’s parameters include the name and revision of the model we wish to load (we will use Bert sentence transformers) and the embedding file path.
pub fn load(
    model_name: &str,
    revision: &str,
    embeddings_filename: &str,
    embeddings_key: &str,
) -> anyhow::Result<Self> {
    // Signature only; the body is walked through in the snippets that follow.
    todo!()
}
As shown in the load function code below, loading a model involves creating a Repo struct which contains the properties of the relevant repo in Hugging Face (e.g. name and revision), and then creating an API struct in order to actually connect to the repo and download the relevant files (the model weights required to create the model are stored in Hugging Face’s safetensors format).
https://medium.com/media/920e56e68b31ecb4021140bd3e7fa436/href
The api.get function returns the local name of the relevant file (whether it was downloaded or just read from cache). Files will be downloaded only once while subsequent calls to api.get will simply use the cached version. We instantiate a tokenizer struct using the tokenizer config file, and build our model using the weights file (in safetensors format) and a config file.
After we have a model and tokenizer loaded, we can finally load the actual embedding file, which we will use to search for matches. I will later show how we generate it using the same model and then serialize it to a file. Loading embeddings as a Tensor using Hugging Face’s safetensors module is fairly simple: all we need is the file name and the key that was given to the tensor when it was saved.
https://medium.com/media/8247cb8f7283d0c9326c7d05c31d9483/href
Now that we have a model and tokenizer loaded, as well as the embedding vectors in memory, we have finished initializing the BertInferenceModel, which is returned to the calling function, and can move on to implement the inference method.
Sentence Inference and Embedding
The inference function is also fairly straightforward. We start by using the loaded tokenizer to encode the sentence (line 5). The encode() function returns an Encoding struct with a get_ids() function that returns an array of token ids, a numerical representation of the words in the sentence we want to embed. Next, we wrap the array of token ids in a tensor that we can feed to our embedding model, and use the model’s forward function to get the embedding vectors that represent the sentence (line 10).
https://medium.com/media/267b21d8d64f52496c6d53bd74cbb87c/href
The dimensions of the vectors we get in line 12 from the embedding model are [128, 384]. This is because the Bert sentence-transformers model we use represents each token or word with a vector of size 384, and the maximum input length of a sentence is 128 tokens (because our input has just a few words, most of it is padding). In other words, we essentially get a vector of size 384 per token or word in our sentence, in addition to the vectors for padding and other special tokens.
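To make the 128 rows of that shape more concrete, here is a minimal sketch, using only the Rust standard library, of how a short list of token ids could be padded to a fixed length. The function name pad_token_ids, the pad id passed in, and the attention mask are assumptions made for illustration; in the actual service the tokenizer takes care of padding.

```rust
/// Pads (or truncates) a list of token ids to exactly `max_len` entries.
/// Returns the padded ids and an attention mask (1 = real token, 0 = padding).
/// `pad_id` is an assumption here; the real value comes from the tokenizer config.
pub fn pad_token_ids(ids: &[u32], max_len: usize, pad_id: u32) -> (Vec<u32>, Vec<u32>) {
    // Keep at most `max_len` real tokens.
    let kept = ids.len().min(max_len);

    let mut padded = Vec::with_capacity(max_len);
    padded.extend_from_slice(&ids[..kept]);
    // Fill the remaining positions with the padding id.
    padded.resize(max_len, pad_id);

    // The mask marks which positions hold real tokens.
    let mut mask = vec![1u32; kept];
    mask.resize(max_len, 0);

    (padded, mask)
}
```

With a query of only a few tokens and max_len set to 128, almost all positions end up as padding, which is why the model output has 128 rows regardless of how short the query is.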
Next, we need to compress the sentence vectors from a tensor of size [128, 384] to a single vector of size [1, 384] that will represent or capture “the essence” of the sentence, so that we can match it to other sentences in our embedding and find those that are similar to it. To do so, and partly because the inputs we are dealing with are short keywords rather than long sentences, we will use max pooling, which creates a new vector by taking the maximum value along each dimension of a given tensor in order to capture the most salient features in each dimension. This is fairly simple to implement using Candle’s API. Finally, we apply L2 normalization so that all vectors have unit magnitude, which keeps differences in vector length from skewing the comparison and makes cosine similarity equivalent to a plain dot product. You can see the actual implementation of both the pooling and normalization functions below.
https://medium.com/media/5217b1a43f966bd313e497d941bdc29f/href
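For readers who want to see the arithmetic without the tensor API, the two steps can be sketched on plain vectors. This is not the Candle implementation from the post; the function names max_pool and l2_normalize and the nested-Vec layout ([tokens][hidden]) are assumptions chosen for illustration.

```rust
/// Max pooling over the token axis: for each hidden dimension, keep the
/// largest value seen across all token vectors.
/// Input is laid out as [tokens][hidden]; the output has length `hidden`.
pub fn max_pool(token_vectors: &[Vec<f32>]) -> Vec<f32> {
    let hidden = token_vectors.first().map_or(0, |v| v.len());
    let mut pooled = vec![f32::NEG_INFINITY; hidden];
    for token in token_vectors {
        for (slot, &value) in pooled.iter_mut().zip(token.iter()) {
            if value > *slot {
                *slot = value;
            }
        }
    }
    pooled
}

/// L2 normalization: divide every component by the vector's Euclidean norm
/// so that the result has unit length. A zero vector is returned unchanged.
pub fn l2_normalize(v: &[f32]) -> Vec<f32> {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm == 0.0 {
        return v.to_vec();
    }
    v.iter().map(|x| x / norm).collect()
}
```

Applied to a [128, 384] output, max_pool would yield one vector of 384 values, and l2_normalize would then scale that vector to length 1.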
Measuring Vector Similarity
Although this is not directly related to the Candle library, our module will also offer a vector search utility method that will receive a vector and use its internal embedding in order to return the indices of the most similar vectors.
This is implemented pretty naively: we first create a collection of tuples (line 7), where the first member of the tuple represents the index of the relevant text and the second member represents the cosine similarity score. We then iterate over all indices and measure the cosine similarity between each stored vector and the given vector we need to match. Finally, we add the (index, similarity score) tuple to the collection, sort it, and return the top N requested.
https://medium.com/media/95f51bc3d1dd4ac8ae7d3ff389ccff9b/href
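The same naive scan can be sketched with the standard library alone. This is an illustration under assumptions rather than the post's Candle-based code: the names cosine_similarity and top_n_similar are hypothetical, and the embedding is represented as a slice of plain vectors instead of a Tensor.

```rust
/// Cosine similarity between two equal-length vectors.
/// Returns 0.0 if either vector has zero magnitude.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a * norm_b)
}

/// Scores `query` against every row of `embeddings` and returns the
/// `top_n` (index, score) pairs, best match first.
pub fn top_n_similar(query: &[f32], embeddings: &[Vec<f32>], top_n: usize) -> Vec<(usize, f32)> {
    let mut scored: Vec<(usize, f32)> = embeddings
        .iter()
        .enumerate()
        .map(|(index, row)| (index, cosine_similarity(query, row)))
        .collect();
    // Sort by score, descending; total_cmp gives a total order over f32.
    scored.sort_by(|a, b| b.1.total_cmp(&a.1));
    scored.truncate(top_n);
    scored
}
```

This scan is linear in the number of stored vectors and sorts all scores, which is acceptable for the moderate embedding sizes discussed in the post; a larger collection would likely call for a dedicated vector index.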
4. Embed and Search Web Service
Now that we have a struct that wraps our main model functionality, we need to encapsulate it in a REST service. We will create a REST endpoint with one POST route that receives a few keywords in a JSON payload and returns the most similar headlines. Upon request, the service will embed the keywords in a vector, search its in-memory embedding for the most similar vectors, and get back their indices. It will then use those indices to find the corresponding headlines in the text mapping file.
We will implement the service using the excellent Axum web framework. Most of the relevant code is typical Axum boilerplate, so I won’t go too deep into the details of how to create a REST endpoint using Axum. As in many web frameworks, building a REST endpoint typically involves creating a Router and registering a handler function on some route to process the request. However, an ML model serving layer has the additional complexity of managing the state and persistence of the model itself. Model loading can be expensive in terms of performance, as it involves the IO of loading model files (whether from Hugging Face’s repo or from local storage). We therefore need a way to cache and reuse the model across multiple requests.
To address such requirements, Axum provides the application State feature, which we can use to initialize and persist any asset we want injected into the context of every request. Let’s look at the entire initialization code of the service first, line by line, and see how this works.
https://medium.com/media/2b0695608be4490b326f16db3d9078e6/href
Each service instance will create and load a model wrapper and then cache it for reuse by each request it receives. In line 3 we are creating the model wrapper by calling the load() function to bootstrap and load the model. Besides the name and version of the Bert model we load from HF, we also need to specify the location of the embedding file that we load to memory in order to search for similar vectors, as well as the key that we used when creating the embedding.
In addition to the actual model, we also need to cache the mapping file for reuse by each request. After the service uses the model to embed the keywords, it searches for the most similar vectors in its embedding file and then returns their indices. The service then uses the mapping file in order to extract the actual texts corresponding to the indices of the most similar vectors. In a more robust production system, upon receiving the indices of the most similar vectors from the model wrapper, the service would fetch the actual texts from some fast access database, though in our case it will suffice to read it from a preloaded list of strings stored in a file. In line 10, we load the list that was pre-saved as a binary file.
Now we have 2 assets that we need to cache and reuse: the model (wrapper) and the mapping file. Axum enables us to do this using an Arc, a thread-safe reference-counting pointer that each request will share. As you can see in line 15, we create a new Arc around a tuple that consists of the model wrapper and the mapping file. In lines 17–19, we create a new HTTP route to the function that will handle each request.
let shared_state =
    Arc::new((bert_model, text_map));
let app = Router::new()
    .route("/similar", post(find_similar))
    .with_state(shared_state);
In order to cache the tuple and make it available to each request, we use the with_state(state) function to inject it into the relevant request context. Let’s see exactly how.
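As a brief aside before the handler, the sharing behaviour of Arc itself can be sketched with the standard library alone. This does not show Axum's internals; FakeModel, SharedState and read_from_workers are hypothetical names, and plain threads stand in for concurrent requests.

```rust
use std::sync::Arc;
use std::thread;

/// Hypothetical stand-in for the model wrapper; the real one holds the
/// Bert model, tokenizer, and embeddings.
pub struct FakeModel {
    pub name: String,
}

/// Same shape as the state in the post: a tuple of (model, text map).
pub type SharedState = Arc<(FakeModel, Vec<String>)>;

/// Spawns `workers` threads that each receive a clone of the Arc (a cheap
/// pointer copy, not a copy of the model) and read from the shared state.
pub fn read_from_workers(state: SharedState, workers: usize) -> Vec<usize> {
    let handles: Vec<_> = (0..workers)
        .map(|_| {
            let state = Arc::clone(&state);
            thread::spawn(move || {
                // Borrow both parts of the tuple through the Arc.
                let (_model, text_map) = &*state;
                text_map.len()
            })
        })
        .collect();
    handles.into_iter().map(|h| h.join().unwrap()).collect()
}
```

The point of the sketch is that the expensive assets are created once and every worker only holds a reference-counted pointer to them, which is the property the service relies on when it shares the model and mapping file across requests.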
Handling Requests
Our service will serve HTTP POST requests that carry the following payload, consisting of the keywords and the number of similar vectors or headlines we want to receive.
{
  "text": "europe climate change storm",
  "num_results": 5
}
We will implement the corresponding request and response structs of the handler function and Axum will take care of serialization and de-serialization when required.
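use serde::{Deserialize, Serialize};

// Request body; the JSON field "text" carries the search keywords.
#[derive(Deserialize)]
struct ReqPayload {
    #[serde(rename = "text")]
    keywords: String,
    num_results: u32,
}

// Response body: the matching headlines.
#[derive(Serialize)]
struct ResPayload {
    text: Vec<String>,
}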
Next, we can move to the handler function itself. The handler will take 2 parameters: the application state that we initialized before (Axum will take care of injecting it into each function call), and the request struct we defined earlier.
https://medium.com/media/e0142628f80b21e300815d3f2f9ada42/href
Processing each request will consist of 4 main stages that should be straightforward by now. In line 5, we first extract a reference to the state tuple that holds the model and the mapping file. In line 6, we embed the keywords in a vector using the model. Next, in line 9, we search for the N most similar vectors. The score_vector_similarity() function returns a vector of tuples, each consisting of an index and a cosine similarity score. Finally, we iterate over the tuples, extract the string corresponding to each index from the mapping file, and wrap the results in the response payload struct.
And... we are good to go! Although a test on a single machine doesn’t necessarily say much, I tested this on my Mac with an embedding of about 20K vectors and got a nice average response time of about 100ms. Not bad for Bert-based vector embedding plus vector search.
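The last stage, turning search hits into response strings, can be sketched as follows. The function name hits_to_headlines is hypothetical, and the output format is only modelled on the sample response shown in this post; skipping out-of-range indices is an assumption about how one might handle a mismatch between the embedding and the mapping file.

```rust
/// Turns (index, score) search hits into display strings by looking each
/// index up in the preloaded text map. Indices that fall outside the map
/// are skipped rather than causing a panic.
pub fn hits_to_headlines(hits: &[(usize, f32)], text_map: &[String]) -> Vec<String> {
    hits.iter()
        .filter_map(|&(index, score)| {
            text_map
                .get(index)
                .map(|text| format!("Item:{} (index: {} score:{})", text, index, score))
        })
        .collect()
}
```

Using get() instead of direct indexing keeps a stale or truncated mapping file from crashing a request handler; whether to skip or report such an index is a design choice for the service.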
curl -s -w "\nTotal time: %{time_total}s\n" \
  -X POST http://localhost:3000/similar \
  -H "Content-Type: application/json" \
  -d '{"text": "self driving cars navigation", "num_results": 3}' | jq
{
  "text": [
    "Item:Stereo Acoustic Perception … (index: 8441 score:0.8516491)",
    "Item:Vision-based Navigation of … (index: 7253 score:0.85097575)",
    "Item:Learning On-Road Visual ….. (index: 30670 score:0.8500275)"
  ]
}
Total time: 0.091665s
(This example was created using an embedding generated over the arXiv paper abstracts dataset. The actual dataset is available here under a public domain license.)
5. Generating the Embedding
Just before we conclude, there is one last component in this flow that we need to cover. Until now we have assumed the existence of an embedding file in which we search for similar vectors. However, I haven’t explained how to create the embedding file itself.
Recall that BertInferenceModel, the struct created in Section 3, already contains a function that embeds a set of keywords into a vector. To embed multiple sets of keywords, all we need to do is process them as a batch.
https://medium.com/media/17c845a3ef61fff1c9fb43ce8006c5a0/href
The main difference in the way we use BertInferenceModel is that we call the tokenizer’s encode_batch function rather than encode, which takes a vector of strings rather than a single string. Then we simply stack all the vectors into a single tensor and feed it to the model’s forward() function, just as we did with a single vector embedding (you can see the full source code of the function in the companion repo linked below).
Once we have a function that can embed multiple strings, the embedding generator itself is pretty straightforward. It uses the rayon crate to parallelize the embedding of the text file, and afterwards it stacks the results together to create a single tensor. Finally, it writes the embedding to disk using the safetensors format. The embedding is an important asset in this pipeline, as it needs to be copied to every service instance.
https://medium.com/media/b324ea77d74434e80be3960b9a62451a/href
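The batch-then-stack pattern can also be sketched without external crates. Note the substitution: this sketch uses the standard library's scoped threads instead of rayon, and embed_batch is a hypothetical stand-in for the model's batch embedding function, so it illustrates the flow rather than the generator in the companion repo.

```rust
use std::thread;

/// Splits `texts` into batches of `batch_size`, embeds each batch on its own
/// scoped thread, and concatenates the results in the original order.
/// `embed_batch` is a hypothetical stand-in for the model's batch embedding
/// function; it is expected to return one vector per input text.
pub fn embed_all<F>(texts: &[String], batch_size: usize, embed_batch: F) -> Vec<Vec<f32>>
where
    F: Fn(&[String]) -> Vec<Vec<f32>> + Sync,
{
    assert!(batch_size > 0, "batch_size must be positive");
    let embed_batch = &embed_batch;
    thread::scope(|scope| {
        // One thread per batch; all threads are spawned before any join.
        let handles: Vec<_> = texts
            .chunks(batch_size)
            .map(|batch| scope.spawn(move || embed_batch(batch)))
            .collect();
        // Joining in spawn order keeps row i aligned with texts[i].
        handles
            .into_iter()
            .flat_map(|handle| handle.join().expect("embedding thread panicked"))
            .collect()
    })
}
```

Because this sketch spawns one thread per batch, it only suits a modest number of batches; a work-stealing pool such as the one rayon provides bounds the number of threads, which is one reason to prefer it for a large text file. Keeping the output rows in input order matters, since the service maps a row index back to a headline.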
Now we can conclude 🙂
6. Conclusion
One of the greatest challenges in ML engineering is running inference at scale. AI is anything but lightweight, and therefore scaling inference workloads is a great pain that often ends up being very expensive or over-engineered. This is exactly the challenge that Hugging Face’s Candle library tries to address. Using a Torch-like API in Rust, it enables us to create a lean and fast model serving layer that can easily scale and run in a serverless environment.
In this post, I showed how we can use Candle to create an end-to-end model inference layer that can serve requests for vector embedding and search. I explained how we can wrap a Bert sentence-transformers model in a library with a small memory footprint and use it in a REST service based on Axum.
The true value of Hugging Face’s Candle library lies in its ability to bridge the gap between powerful ML capabilities and efficient resource utilization. By leveraging Rust’s performance and safety features, Candle paves the way for more sustainable and cost-effective ML solutions. This is particularly beneficial for organizations seeking to deploy AI at scale without the overheads. I am hopeful that with Candle, we will see a new wave of ML applications that are not only high-performing but also lighter and more adaptable to various environments.
Some resources on Candle
https://github.com/huggingface/candle
https://medium.com/@Aaron0928/hugging-face-has-written-a-new-ml-framework-in-rust-now-open-sourced-1afea2113410
https://pub.towardsai.net/candle-and-falcon-a-guide-to-large-language-models-in-rust-3f0a4369df03
All source code for this post can be found in my github repo here
Streamlining Serverless ML Inference: Unleashing Candle Framework’s Power in Rust was originally published in Towards Data Science on Medium.