Building My Own LLM in Rust: A Wild Ride
I decided to roll my own Large Language Model from scratch. Why? Because I wanted to see how close I could get to the big dogs like PyTorch, how hard it would actually be to build a state-of-the-art model (minus the training), and how resource-efficient I could make it.
Spoiler: I got humbled.
What’s Gemma3, Anyway?
Gemma3 is an open-weights large language model from Google. It’s built from several key components:
- FeedForward – A fully connected neural layer that processes each token’s embedding independently. In Gemma3, it applies non-linear activations (like GeLU) to hidden states, capturing complex patterns. Think of it as the muscle for feature extraction.
- TransformerBlock – The heart of the model. It combines feedforward layers, attention, and normalization. Each block refines the input sequence representation by balancing context (attention) and transformation (feedforward).
- Grouped Query Attention (GQA) – A more efficient form of attention that splits queries into groups, cutting down on memory and compute while keeping performance high.
- RMSNorm – A normalization layer that uses the root mean square of activations instead of subtracting the mean (like LayerNorm does). This stabilizes training and makes gradients smoother. A minimal sketch of the math follows this list.
- RoPE (Rotary Positional Encoding) – Encodes token positions through vector rotations instead of fixed embeddings. This helps the model handle long sequences without blowing up memory.
- Masking – Classic causal masking ensures the model only attends to past tokens during training and inference. No peeking into the future, no cheating.
- Affine Linear Transformations – The standard y = Wx + b operation used throughout, for projecting data into new spaces or adjusting dimensions.
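Since RMSNorm is the easiest of those pieces to show in full, here is a minimal sketch of the idea over a plain f32 slice. It uses the standard gamma scaling and is an illustration, not the exact code from my crate.

// RMSNorm sketch: scale each element by the reciprocal root mean square.
// eps guards against division by zero; gamma is the learned per-feature scale.
fn rms_norm(x: &[f32], gamma: &[f32], eps: f32) -> Vec<f32> {
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    x.iter()
        .zip(gamma)
        .map(|(v, g)| v * inv_rms * g)
        .collect()
}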
Building It, One Trait at a Time
I started small. One piece at a time. I studied Google’s open-source code and adapted it into Rust. I took an approach inspired by PyTorch, built around a Module trait. The Linear layer’s implementation of that trait looks like this:
pub struct Linear {
    weight: Tensor,
    bias: Option<Tensor>,
}

impl From<InitParams> for Linear {
    fn from(params: InitParams) -> Self {
        Self::init(params)
    }
}

pub struct InitParams {
    bias: bool,
    in_features: usize,
    out_features: usize,
    seed: Option<u64>,
}

impl InitParams {
    pub fn new(bias: bool, in_features: usize, out_features: usize, seed: Option<u64>) -> Self {
        Self {
            bias,
            in_features,
            out_features,
            seed,
        }
    }
}

pub struct LinearLoader {
    weight: Tensor,
    bias: Option<Tensor>,
}

impl Module for Linear {
    type InitParams = InitParams;
    type ForwardParams<'a> = &'a Tensor;
    type Loader = LinearLoader;

    // Random initialization; the optional seed keeps results deterministic.
    fn init(params: Self::InitParams) -> Self {
        Self {
            weight: Tensor::rand(&[params.out_features, params.in_features], params.seed),
            bias: params
                .bias
                .then(|| Tensor::rand(&[1, params.out_features], params.seed)),
        }
    }

    // y = x @ W^T (+ b): the affine transformation described above.
    fn forward<'a>(&mut self, params: Self::ForwardParams<'a>) -> Result<Tensor, TensorError> {
        let result = params.matmul(&self.weight.transpose(0, 1));
        match self.bias.as_ref() {
            Some(b) => &result? + b,
            None => result,
        }
    }

    // Build the layer from pre-trained weights instead of random ones.
    fn load(loader: Self::Loader) -> Self {
        Self {
            weight: loader.weight,
            bias: loader.bias,
        }
    }
}
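To get a feel for the trait in use, here is a hypothetical call site. The shapes and seed values are made up for illustration, and the trailing ? assumes the caller returns a Result with a TensorError.

// Hypothetical usage of the Linear module above.
let mut linear = Linear::init(InitParams::new(true, 4, 8, Some(42)));
let x = Tensor::rand(&[2, 4], Some(7)); // a batch of 2 rows with 4 features each
let y = linear.forward(&x)?;            // projected to shape [2, 8]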
Math was mathing
Most of the math was straightforward: read the formulas, check how PyTorch does it, and fill in the gaps. I used test-driven development (TDD) with reference results from PyTorch: generate known outputs, test against them, then fix your math until it matches.
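As a flavour of what those tests look like, here is a hypothetical one. Tensor::from_vec and to_vec are stand-ins for whatever constructors and accessors the tensor type actually exposes; the expected values were computed once in PyTorch and hard-coded.

// Hypothetical reference test: the expected output comes from running the
// same matmul in PyTorch once and pasting the numbers in.
#[test]
fn matmul_matches_pytorch_reference() {
    let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]);
    let b = Tensor::from_vec(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]);
    let out = a.matmul(&b).unwrap();
    assert_eq!(out.to_vec(), vec![19.0, 22.0, 43.0, 50.0]);
}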
Broadcasting and matrix multiplication (for tensors of different shapes) were the tricky parts. PyTorch’s automatic broadcasting led to a few “why isn’t this working” moments early on. But once I found out about this broadcasting business, I just needed to adapt my implementation, add new test cases, and voilà.
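For reference, the broadcasting rule itself is small. This is a sketch of the shape-resolution step under NumPy/PyTorch semantics; the helper name is mine and not necessarily how my crate structures it.

// Align shapes from the trailing dimension; missing dims count as 1,
// and two dims are compatible if they are equal or if either is 1.
fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let len = a.len().max(b.len());
    let mut out = Vec::with_capacity(len);
    for i in 0..len {
        let da = if i < a.len() { a[a.len() - 1 - i] } else { 1 };
        let db = if i < b.len() { b[b.len() - 1 - i] } else { 1 };
        if da == db || da == 1 || db == 1 {
            out.push(da.max(db));
        } else {
            return None; // incompatible shapes
        }
    }
    out.reverse();
    Some(out)
}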
Unlike PyTorch, which relies on global random state, I pass an explicit seed for deterministic random number generation. That made testing waaaay simpler. With PyTorch you have to set the random seed for NumPy and PyTorch itself, and even then the results weren’t always the same.
Masking Mayhem
The mask implementation was… rough. I didn’t have a great idea for how to do it elegantly. My version just uses a big vector of 0s and 1s, which isn’t efficient, since my tensors only support f32. So yeah, I use floating points to store booleans. Dumb and wasteful, but it works.
There’s definitely a better way (maybe something clever with bit manipulation), but for now, it’s good enough.
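Concretely, the mask is built along these lines (a sketch of the idea, not the exact code): 1.0 means “may attend”, 0.0 means “future token, blocked”.

// Row-major seq_len x seq_len causal mask stored as f32, as described above.
fn causal_mask(seq_len: usize) -> Vec<f32> {
    let mut mask = vec![0.0_f32; seq_len * seq_len];
    for i in 0..seq_len {
        for j in 0..=i {
            mask[i * seq_len + j] = 1.0; // token i may attend to tokens 0..=i
        }
    }
    mask
}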
Performance Pain
The biggest bottleneck? matmul.
Even after adding multithreading, it barely helped. Multithreaded math isn’t hard: each chunk of the output is independent, so there are no locks to worry about. But even so, my implementation is slow.
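For context, the multithreading I mean is just splitting output rows across threads, roughly like this sketch (std::thread::scope over row-major slices; simplified, not my actual code):

// Naive row-parallel matmul: A is m x k, B is k x n, both row-major.
fn matmul_parallel(a: &[f32], b: &[f32], m: usize, k: usize, n: usize, threads: usize) -> Vec<f32> {
    let mut out = vec![0.0_f32; m * n];
    let rows_per = (m + threads - 1) / threads;
    std::thread::scope(|s| {
        // Each thread owns a disjoint chunk of output rows, so no locks are needed.
        for (t, chunk) in out.chunks_mut(rows_per * n).enumerate() {
            s.spawn(move || {
                let row0 = t * rows_per;
                for (r, row) in chunk.chunks_mut(n).enumerate() {
                    let i = row0 + r;
                    for kk in 0..k {
                        let aik = a[i * k + kk];
                        for j in 0..n {
                            row[j] += aik * b[kk * n + j];
                        }
                    }
                }
            });
        }
    });
    out
}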
Right now, my model runs about 100× slower than PyTorch and eats 1.5–2 GB of RAM, mostly from constant tensor allocations. No memory reuse. No caching. No mercy.
I could probably fix that by introducing a custom allocator for tensor reuse; PyTorch uses something similar (tensor recycling).
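The reuse idea would look roughly like this toy buffer pool. This is an assumption about a future design, not something the crate has today.

use std::collections::HashMap;

// Keep returned buffers keyed by length and hand them back out
// instead of hitting the allocator for every temporary tensor.
#[derive(Default)]
struct BufferPool {
    free: HashMap<usize, Vec<Vec<f32>>>,
}

impl BufferPool {
    fn get(&mut self, len: usize) -> Vec<f32> {
        match self.free.get_mut(&len).and_then(|bufs| bufs.pop()) {
            Some(mut buf) => {
                buf.iter_mut().for_each(|x| *x = 0.0); // recycled, but zeroed
                buf
            }
            None => vec![0.0; len],
        }
    }

    fn put(&mut self, buf: Vec<f32>) {
        self.free.entry(buf.len()).or_default().push(buf);
    }
}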
Little Wins
Despite all the pain, there were highlights. I managed to beat PyTorch’s erf implementation. Yep, really.
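For reference, a classic way to approximate erf is the Abramowitz & Stegun 7.1.26 polynomial. This sketch is the textbook version, not necessarily the exact variant I benchmarked.

// erf(x) ≈ 1 - (a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5) * exp(-x^2),
// with t = 1 / (1 + p*x); the formula's maximum absolute error is about 1.5e-7.
fn erf_approx(x: f32) -> f32 {
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    let t = 1.0 / (1.0 + 0.3275911 * x);
    let poly = t * (0.254829592
        + t * (-0.284496736
            + t * (1.421413741
                + t * (-1.453152027 + t * 1.061405429))));
    sign * (1.0 - poly * (-x * x).exp())
}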

I also tried the legendary inverse square root trick (à la Quake), but Rust’s built-in SIMD matched it: no real speedup. Either I implemented it wrong, or the Quake devs were just wizards.
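For the curious, the trick in question looks like this when ported to Rust (valid for positive finite inputs):

// The classic Quake III fast inverse square root: bit-level initial guess
// followed by one Newton-Raphson refinement step.
fn fast_inv_sqrt(x: f32) -> f32 {
    let i = 0x5f3759df_u32 - (x.to_bits() >> 1);
    let y = f32::from_bits(i);
    y * (1.5 - 0.5 * x * y * y)
}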
Loading Models (a.k.a. The Lazy Part)
I didn’t bother writing my own model file format. Instead, I used Safetensors, Hugging Face’s simple and safe tensor serialization format. A few traits, a couple of structs, and boom, Gemma3 loaded.
Sometimes laziness pays off.
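For a sense of how little code that took, here is a minimal sketch with the safetensors crate. The file name is made up, and the method names assume a recent version of the crate.

use safetensors::SafeTensors;
use std::fs;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Read the whole file and let the crate parse the header and tensor table.
    let bytes = fs::read("gemma3.safetensors")?;
    let st = SafeTensors::deserialize(&bytes)?;
    for name in st.names() {
        let view = st.tensor(name)?;
        // view.data() holds the raw bytes to copy into my own Tensor type.
        println!("{name}: shape {:?}, dtype {:?}", view.shape(), view.dtype());
    }
    Ok(())
}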
What’s Next?
- Training loop: I want to train something, even if it’s just on a tiny dataset.
- Memory optimization: Too many temporary tensors; I need in-place ops and/or better allocation. I could write a custom allocator, which could be fun.
- BLAS / SIMD: Currently using dumb CPU ops. Adding BLAS or SIMD could speed things up dramatically. GPU offload would be even better.
- Benchmarks & evals: More systematic benchmarks against PyTorch, and maybe a fun eval (who knows, maybe Rust makes it smarter).
Final Thoughts
PyTorch absolutely kicked my ass. But building an LLM in Rust was easier and way more fun than I expected. I learned a ton about how these models actually work under the hood.
The code’s ugly, slow, and probably inefficient… but it works.
Got ideas to make it faster? Fork it, or drop me a message.