Having worked through the main body of Sebastian Raschka’s book “Build a Large Language Model (from Scratch)”, I wanted to try an experiment: is it possible to train a base model of my own, on my own hardware?
The book shows you how to train your LLM, does a basic training run on a small dataset, and then we switch to downloading the “pre-cooked” weights from OpenAI. That makes sense given that not every reader will have access to enough hardware to really train from scratch. And right back at the start of this series, I did some naive scaling of numbers I’d got when fine-tuning LLMs and came to the conclusion that it would be impossible in a reasonable time.
But the speed I got with my RTX 3090 on the book’s small training run made me think that perhaps – just perhaps! – it might actually be possible to train a model of this size – about 163M parameters – on my own hardware. Not, perhaps, on a small laptop, but at least on a reasonably high-end “gaming” PC.
Additionally, Andrej Karpathy recently announced nanochat, “the best ChatGPT that $100 can buy”. He mentions on the main page that he’s trained a model called d32, with 32 Transformer layers, which has 1.9B parameters, for about $800. His smaller 20-layer d20 model, with 561M parameters, he says should be trainable in about four hours on an 8x H100 GPU node, which costs about $24/hour – hence the $100 total price.
What’s even more interesting about nanochat is that it’s built with PyTorch; initially I’d got the impression that it was based on his pure C/CUDA llm.c, which I would imagine would give a huge speedup. But no – he’s using the same stack as I have been in this series!
Karpathy’s models are both larger than 163M parameters, so it definitely sounded like this might be doable. Obviously, I’m nowhere near as experienced an AI developer, and he’s using a larger machine (8 GPUs and each of them has > 3x more VRAM than mine), but he’s also including the time to train a tokeniser and instruction fine-tune into that four hours – and his smaller model is more than three times larger than mine. So that should all help.
This post is a little less structured than the others in my LLM from scratch series, as it’s essentially a tidied version of the notes I kept as I worked through the project.
But so as not to bury the lede: using the Hugging Face FineWeb-series datasets, I was able to train a GPT-2 small sized base model to a level where it was almost as good as the original in just over 48 hours on my own hardware! Base models: not just for the big AI labs.
Here’s the full story.
The model
For this project, I want to use the exact same model code as Raschka presented in the LLM from scratch book – my copy here. There have been a number of architectural improvements to LLMs since GPT-2, but for now it’s best to keep things simple.
But there are still some settings to decide on. The config dictionary for the models we’ve been using has these parameters:
- vocab_size. This is determined by the tokenizer, and I want to use the GPT-2 one, so it will need to be 50257.
- context_length. GPT-2 has a 1,024-token context length, so I'll stick with that.
- emb_dim, n_heads, n_layers – these define which of the different GPT-2 model classes we're training, and I want to stick to the smallest gpt2-small one, so they will be 768, 12 and 12 respectively.
- drop_rate. One of the most surprising things to me in the "architectural improvements" post linked above was that dropout is no longer used so much. However, this appears to be tied in to the one-epoch training that has taken off since GPT-2, so I think it would be best to stick to 0.1 here.
- qkv_bias. From what Raschka says in the book, this doesn't add on much value, even though the original GPT-2 used it, so let's set it to False.
There’s also the aspect of weight-tying – the original GPT-2 reused its embedding matrix as the weights for the linear layer that projects the context vectors from the last Transformers layer into vocab space to get the logits.
There’s nothing in the code we’ve been working with to enforce that, though – when we do our small train in the book, we’re using independent weights for each of those steps. The only time it is “enforced” is when we download the pretrained weights from OpenAI, where we put the same values into both the embedding matrix and the final output head.
Given that Raschka says it's in general better to avoid weight-tying, and that actually doing it would be harder than not doing it, it seems a no-brainer to not do it.
So, what does that mean about our model?
In [1]: big_train_params = {
...: "vocab_size": 50257,
...: "context_length": 1024,
...: "emb_dim": 768,
...: "n_heads": 12,
...: "n_layers": 12,
...: "drop_rate": 0.1,
...: "qkv_bias": False
...: }
In [2]: from gpt import GPTModel
In [3]: model = GPTModel(big_train_params)
In [4]: sum(p.numel() for p in model.parameters())
Out[4]: 163009536
That matches what we got when working through the book; 163M parameters. Can we train it?
The data
It seems like every AI project starts with the question “what data can we use?”
The original report on GPT-2, “Language Models are Unsupervised Multitask Learners”, is frustratingly lacking in details. However, it does say that they trained it on “8 million documents for a total of 40 GB of text”. Now, according to OpenAI, it’s reasonable to assume roughly four characters per token for typical English text, so 40 GB of text is ~10 billion tokens. That data was essentially gathered by scraping pages linked from Reddit posts with at least three karma, so it was reasonably high quality. Can we get something similar?
Conveniently, Hugging Face host a big dataset called FineWeb, and that has a 10 billion token “sample” dataset, randomly selected from the full 18.5 trillion tokens. So the sample feels like it’s order-of-magnitude right. And while reading more about Karpathy’s nanochat, I spotted that it uses FineWeb-Edu, which is a version of FineWeb that contains “only the most educational web pages”.
I wrote a script to download both of those, and kicked it off. It took about 20 minutes for each one (slow wifi in my study, I was getting < 5MB/s); FineWeb’s 10B sample took up about 29 GiB, and FineWeb-Edu’s about 27 GiB.
Time to take a look at them. The Hugging Face datasets load_dataset function loads up all of the files you provide, and you can tell it how to split them up into train/validation/test sets. This command just loads up the whole FineWeb one and says “treat it all as the train split”, which is good enough for now:
In [1]: from datasets import load_dataset
In [2]: fw = load_dataset(
...: "parquet",
...: data_files="./fineweb/sample/10BT/*.parquet",
...: split="train"
...: )
Generating train split: 14868862 examples [01:53, 130852.34 examples/s]
Loading dataset shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 102/102 [00:03<00:00, 31.90it/s]
Yikes. It took 1 minute, 53 seconds to generate the train split. However, that appears to be a one-off cost – when I accessed it again later using the same code in a different Python session, it just did the second “Loading dataset shards” portion, taking three seconds, not the generation of the split. Presumably it caches it.
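If you’re curious where that cache ends up, I believe the Dataset object can point you at it – I haven’t dug into this myself, so treat it as a pointer rather than something I’ve verified:
fw.cache_files[0]["filename"]  # path of one of the Arrow files backing the dataset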
Anyway, let’s see what’s in it:
In [3]: print(fw)
Dataset({
features: ['text', 'id', 'dump', 'url', 'date', 'file_path', 'language', 'language_score', 'token_count'],
num_rows: 14868862
})
Great, so we have 14,868,862 rows, each of which has various bits of information. Checking the first one’s text:
In [7]: print(fw[0]["text"][:500])
|Viewing Single Post From: Spoilers for the Week of February 11th|
|Lil||Feb 1 2013, 09:58 AM|
Don't care about Chloe/Taniel/Jen-Jen. Don't care about Sami, really, but hoping
that we get some good "SAMANTHA GENE!!" Marlena Death-Stares out of it. And
"newfound" feelings. Please. If only.
STEFANO!! STEFANO, STEFANO, STEFANO!!!! :cheer:
|Spoilers for the Week of February 11th · DAYS: News, Spoilers & Discussion|
Well, for FineWeb, that doesn’t look particularly “fine”, but I guess it’s better than the stuff that Karpathy talked about in his recent interview with Dwarkesh Patel:
When you’re looking at a pre-training dataset in the frontier lab and you look at a random internet document, it’s total garbage. I don’t even know how this works at all. It’s [stuff] like stock tickers, symbols, it’s a huge amount of slop and garbage from like all the corners of the internet
Let’s take a look at FineWeb-Edu.
In [8]: fw_edu = load_dataset(
...: "parquet",
...: data_files="./fineweb-edu/sample/10BT/*.parquet",
...: split="train"
...: )
Generating train split: 9672101 examples [01:32, 104057.34 examples/s]
Loading dataset shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [00:02<00:00, 48.62it/s]
In [9]: print(fw_edu[0]["text"][:500])
The Independent Jane
For all the love, romance and scandal in Jane Austen’s books, what they are
really about is freedom and independence. Independence of thought and the
freedom to choose.
Elizabeth’s refusal of Mr. Collins offer of marriage showed an independence
seldom seen in heroines of the day. Her refusal of Mr. Darcy while triggered by
anger showed a level of independence that left him shocked and stunned.
The freedom she exhibited in finally accepting him in direct defiance of Lady Cath
That looks a lot better!
Now let’s take a look at the document lengths in terms of tokens. There’s a token_count column, but I don’t know which tokeniser that’s for, so to be safe we’ll calculate it ourselves.
How long would it take to tokenise every row in FineWeb 10B to check? Let’s tokenise the first 10,000 of the 14,868,862 that we have, and see how long that would take – then we can work out the estimated time for the whole thing.
In [25]: import tiktoken
In [26]: import time
In [27]: tokenizer = tiktoken.get_encoding("gpt2")
In [28]: start = time.time()
...: for entry in fw.select(range(10_000)):
...: tokenizer.encode(entry["text"])
...: end = time.time()
In [29]: end - start
Out[29]: 1.4528205394744873
In [30]: fw
Out[30]:
Dataset({
features: ['text', 'id', 'dump', 'url', 'date', 'file_path', 'language', 'language_score', 'token_count'],
num_rows: 14868862
})
In [31]: (14868862 / 10_000) * 1.4528205394744873
Out[31]: 2160.1788112211702
2,160 seconds or about 36 minutes. Yikes!
After a bit of digging, though, I found that tiktoken tokenisers can handle batches (poorly documented, but it’s there in the source):
In [45]: text_batch = ["a", "b", "c"]
In [46]: tokenizer.encode_batch(text_batch)
Out[46]: [[64], [65], [66]]
Also, we can map a function over an entire HF dataset, and that can be made to run with multiple processes. So, we can combine the two:
In [47]: import os
In [53]: def add_len(examples):
...: texts = [t or "" for t in examples["text"]]
...: tokens = tokenizer.encode_batch(texts, disallowed_special=())
...: return {"tok_len": [len(t) for t in tokens]}
...:
In [54]: start = time.time()
...: fw_with_len = fw.map(
...: add_len,
...: batched=True,
...: batch_size=1024,
...: num_proc=os.cpu_count(),
...: )
...: end = time.time()
Map (num_proc=24): 100%|████████████████████████████████████████████████████████████████████████████████████████████| 14868862/14868862 [03:15<00:00, 75869.33 examples/s]
Just over three minutes, not too bad! (The reason the command count above jumps from 47 to 53 was that in the first run I didn’t have the disallowed_special=() in there – one of the rows in the dataset had <|endoftext|> in it, and the tokenizer rejected it. I’m going to play fast and loose and ignore that for now.)
Now let’s see what it added:
In [56]: fw_with_len[0].keys()
Out[56]: dict_keys(['text', 'id', 'dump', 'url', 'date', 'file_path', 'language', 'language_score', 'token_count', 'tok_len'])
In [57]: fw_with_len[0]["tok_len"]
Out[57]: 142
In [58]: len(fw_with_len["tok_len"])
Out[58]: 14868862
In [59]: fw_with_len["tok_len"][0]
Out[59]: 142
Cool! We’ve added a tok_len column with the number of GPT-2 tokens for each row, and we can extract what amounts to a list of those values. Let’s plot them as a histogram.
Trying to do it directly – that is, just doing
ax.hist(fw_with_len["tok_len"], bins=bins)
...seems to make Matplotlib very unhappy, and my interpreter crashed with an OOM – I think it might be trying to load all of the dataset – text, IDs, etc. – into RAM in one go.
So I started a fresh one and did the stuff to load it and annotate it with token lengths again – weirdly, this time the mapping only took 10 seconds or so! That was strange, I’ll need to look into that. Perhaps the earlier command added the tok_len column to the files on disk?
To work around the memory issue, I converted the tok_len column from the dataset to an actual list:
In [11]: lengths = [n for n in fw_with_len["tok_len"]]
That took ten or twenty seconds. Let’s then try the plot again (full code this time):
In [19]: import numpy as np
...: import matplotlib.pyplot as plt
...:
...: bins = np.arange(0, 2048 + 16, 16)
...:
...: plt.xkcd()
...: plt.rcParams['font.family'] = "xkcd"
...: fig = plt.figure(figsize=(10, 6))
...: ax = plt.gca()
...:
...: ax.hist(lengths, bins=bins)
...: ax.set_xlabel("TOKENIZED LENGTH (GPT-2 TOKENS)")
...: ax.set_ylabel("COUNT")
...: ax.set_title("FINEWEB DISTRIBUTION OF TOKENIZED LENGTHS")
...:
...: mean_len = float(np.mean(lengths))
...: median_len = float(np.median(lengths))
...: h_mean = ax.axvline(mean_len, linestyle="--", label=f"MEAN = {mean_len:.1f}")
...: h_med = ax.axvline(median_len, linestyle=":", label=f"MEDIAN = {median_len:.1f}")
...: ax.legend(handles=[h_mean, h_med])
...:
...: ax.grid(True, axis="y", alpha=0.3)
...: plt.tight_layout()
...: plt.savefig("fineweb-token-length-distribution.png")
That took about 11s to run, and the result is this:
[Histogram: FineWeb distribution of tokenized lengths (GPT-2 tokens), with mean and median marked]
That’s really promising! The bulk of them are less than our 1,024-token sequence length. If we present each row in the dataset as a stand-alone training sample, cropping them when necessary, perhaps we won’t lose too much data? Let’s see.
First step, how many tokens are there in total?
In [20]: sum(lengths)
Out[20]: 10336315397
Nice, about 10B, as expected. How many tokens would we have if we cropped them to the default GPT-2 context length of 1,024?
In [21]: sum(l if l < 1024 else 1024 for l in lengths)
Out[21]: 7354541756
Ouch, 7.3B. That’s quite a reduction:
In [22]: 7354541756 / 10336315397
Out[22]: 0.7115245107685639
So we’re losing 29% of our tokens by that cropping. That’s from curtailing just 16% of the sequences:
In [26]: len([l for l in lengths if l > 1024])
Out[26]: 2438899
In [27]: len(lengths)
Out[27]: 14868862
In [28]: 2438899 / 14868862
Out[28]: 0.1640272806351959
That’s not great.
I feel that we have two options here:
- Crop all of the input sequences – that is, each row in the dataset – so that each one is no more than our 1,024 sequence length. Then we can pad them out with end-of-sequence tokens (as is the standard) so that they’re all 1,024. This will lose us quite a lot of tokens, but has the big benefit of being easy.
- Treat the corpus as, essentially, one long document, with end-of-sequence delimiters between each row, then split that up into 1,024-token sequences. Doing it this way would mean we’d use all of our training data. But it would be more complicated, especially if we hit memory constraints.
At this point in the experiment, I’m going to keep both options open. I’m inclined towards the latter (I believe it’s closer to what the real GPT-2 train did), but I’m not sure.
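Just to make that second option concrete, here’s a rough sketch of how the packing might work – this is hypothetical scoping code rather than anything I’ll necessarily use, and it assumes the GPT-2 tokenizer and the fw dataset from earlier:

import tiktoken

tokenizer = tiktoken.get_encoding("gpt2")
EOT = tokenizer.eot_token  # the <|endoftext|> token, id 50256
SEQ_LENGTH = 1024


def pack_into_sequences(texts, seq_length=SEQ_LENGTH):
    # Treat the corpus as one long token stream, with <|endoftext|> between
    # rows, and yield chunks of seq_length + 1 tokens (inputs plus targets).
    buffer = []
    for text in texts:
        buffer.extend(tokenizer.encode(text, disallowed_special=()))
        buffer.append(EOT)
        while len(buffer) >= seq_length + 1:
            yield buffer[:seq_length + 1]
            # Drop what we've used; the next chunk's first input token is
            # this chunk's last target token.
            buffer = buffer[seq_length:]


# Usage against the FineWeb dataset loaded earlier would look something like:
# for chunk in pack_into_sequences(row["text"] for row in fw):
#     inputs, targets = chunk[:-1], chunk[1:]

Each yielded chunk is seq_length + 1 tokens long, so the inputs and the shifted-by-one targets can both be sliced out of it.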
Anyway, we’re scoping things out here, so let’s move on.
Epochs
After looking at the data, I’ve thought a bit more about this. I’d previously been thinking in terms of training across all of the tokens in the dataset; we’d work our way through the 10B tokens, and then we’d be done.
But when training a model, you do multiple epochs, normally – you run through the dataset once, updating your gradients as you go, then run through it again likewise, and eventually you stop when your validation loss starts rising.
I think that because I’d read that LLMs are normally trained on just one epoch these days, I’d kind of internalised that we only need to do one. But it wasn’t the case in 2019 when GPT-2 came out. They had less data – just 10B tokens or so, compared to insanely huge datasets like the full FineWeb (not the 10B one we’ve been looking at – the 18.5T full one), so they would have trained it for some number of epochs.
How many? That’s another case where the GPT-2 paper is annoyingly light on detail. This report says in the “Replicating GPT-2” section that OpenAI trained it for 800k iterations with a batch size of 512. Plugging in a sequence length of 1,024, that gives us this many tokens:
800,000 × 512 × 1,024 = 419,430,400,000
Over 419B tokens!
Now, if we believe that their dataset was 10B tokens, then we can work out how many epochs that came to:
419,430,400,000 / 10,000,000,000 ≈ 41.94
The same report says that they – the report’s authors, that is – make that “around a total of 60 epochs through the training set”. I believe the training set they’re talking about could well be slightly smaller than the original GPT-2 one: the GPT-2 authors never released their dataset, WebText, so the report’s authors are using OpenWebText, a different dataset that tries to replicate it.
That sounds expensive; even without knowing how many tokens per second we can train for, 40-odd epochs of 10B tokens each sounds like it would take a long time. Are there any other comparison points that might tell us how long to train for?
Well, there’s a “Chinchilla heuristic” that I’ve heard of, which says that you should train on about 20 tokens per model parameter. I spent some time reading into where that comes from; originally it’s from “Training Compute-Optimal Large Language Models” by Google DeepMind. It’s an interesting paper, and surprisingly easy to read, with a few bits of maths that get a bit hairy (but that aren’t required to get a good-enough feel for what they’re saying). I recommend you take a look.
It was written in 2022, and the authors felt that people were scaling up models a lot, but weren’t increasing the number of tokens that they used for training enough. So, they trained a huge number of models, trying to answer the question: “given a particular budget in training FLOPs, what is the optimal balance of training tokens versus parameters to make sure you’re using those FLOPs most efficiently?”. They were arguing against the method taken in a particular paper, where another team had trained a model (called Gopher) on significantly fewer tokens than they thought optimal.
The number of FLOPs used to train a model scales with both the number of parameters and the number of tokens you train it on – roughly, total compute is proportional to the product of the two. So if you get 2x the FLOPs that you had before, you could train the same model on twice as many tokens, or you could double its size. Which is better? Their conclusion was that you should scale parameters and tokens up by the same factor – in the 2x-compute case, multiply both by √2 (since FLOPs grow with their product) – and that this gets you better performance than putting all of the extra compute into just one of them.
As you can probably see, by doing this they indirectly worked out an optimal number of training tokens for any particular model size; they don’t state the “20x” heuristic themselves, but it’s pretty clear from table 3 in the paper, where they give a number of model sizes and the optimal number of tokens for each.
Now, this number is not the number of tokens you need to train for to get the best model you can for a particular number of parameters; a model of a given size can always be trained more and will (hopefully) get better. But it tells you when you’ve trained on enough tokens that you could get better results by training a larger model than you have right now.
They’re implicitly assuming that models can get as large as you want, which of course is not the case – in reality, you’re going to be targeting a particular model size, the size that can fit on your training hardware (or more likely with production models, the size that can fit on your planned inference hardware).
But interestingly, looking at the README.md for Karpathy’s nanochat project, he trained his 1.9B “d32” model on 38B tokens – exactly 20x. And if you look at the speedrun.sh script in the same repo, he explicitly says that he’s training for 20x parameters for the smaller d20 model:
# The d20 model is 561M parameters.
# Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens.
If Andrej Karpathy thinks that training for Chinchilla-optimality is the right way to go, then who am I to disagree? ;-)
More seriously, perhaps the better quality of the dataset makes this a reasonable thing to do. From the GPT-2 paper, their description of how they got the data:
...we created a new web scrape which emphasizes document quality. To do this we only scraped web pages which have been curated/filtered by humans. Manually filtering a full web scrape would be exceptionally expensive so as a starting point, we scraped all outbound links from Reddit, a social media platform, which received at least 3 karma. This can be thought of as a heuristic indicator for whether other users found the link interesting, educational, or just funny.
That’s a clever trick, but I believe that FineWeb is much more carefully filtered and improved than the WebText dataset they got from that. Back in 2019, they had to do everything from scratch – find appropriate ways to get data, filter it, and so on. Now we can just download stuff from Hugging Face. So maybe Chinchilla-optimal is enough.
Anyway, we have 163,009,536 parameters, so on that basis, let’s train for:
163,009,536 × 20 = 3,260,190,720
...tokens. (I’ll just use 3.2B from now on, but that’s the actual number I mean.)
That’s pretty cool! We have more than that number of tokens already in our FineWeb 10B sample, so we can do a single-epoch training run.
So the question is – is that even doable on my hardware?
Tokens per second
It all hinges on how many tokens per second we can train at. A good way to check this is to write a throwaway “trainer”. We can use that to work out what our maximum batch size is with the RTX 3090’s 24 GiB of VRAM, then run a bunch of batches through – a forward and backward pass for each – and see how many tokens per second we get.
This won’t estimate how much time we’ll spend validating the model, of course. But my gut says that we should spend no more than 5% of our training time running validations, so later on we can do a similar test – eval mode, forward pass only, no gradient tracking – and use that to work out how many tokens should be in the validation set.
So, let’s estimate training speed. This code gets an estimate of tokens/second at different batch sizes. Hopefully it’s clear enough to not need an in-depth explanation. An outline (there’s a condensed sketch of the script after this list):
- We load enough GPT-2 tokens from FineWeb for NUM_BATCHES batches of MAX_BATCH_SIZE sequences each, every one of those sequences being SEQ_LENGTH long (plus one extra token for the targets we’re comparing them to). Note that we’re not bothering to separate them with anything for this test.
- We then loop over batch sizes from 1 to MAX_BATCH_SIZE.
- Then we create our model and put it on the CUDA device. We do this for each batch size rather than creating one and then using it for all of them, so that they’re all starting from the same point – the torch.manual_seed should make sure that they’re identical.
- For each batch size, we create input and output batches as tensors – note that we’re not putting these on CUDA yet; I wanted to do that in the training loop to mirror what a real training loop will have to do. When we’re training with 3.2B tokens, having them all on CUDA would be a waste of VRAM, so we’ll push a batch there on each iteration.
- We do a stripped-down training loop – for each batch, put the inputs and outputs onto CUDA, then do a forward pass, work out the loss, do a backward pass, and take an optimiser step. We do the same NUM_BATCHES iterations per batch size.
- Finally, we print out the number of tokens we trained on for this batch size, how long it took, and the number of tokens per second.
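I won’t paste the whole script, but here’s a minimal sketch of the idea – it assumes the GPTModel class and config from the book, AdamW as the optimiser, and a placeholder learning rate, so treat it as an outline rather than the exact code:

import time

import tiktoken
import torch
from datasets import load_dataset
from tqdm import tqdm

from gpt import GPTModel  # the model code from the book

SEQ_LENGTH = 1024
MAX_BATCH_SIZE = 8
NUM_BATCHES = 100

big_train_params = {
    "vocab_size": 50257, "context_length": 1024, "emb_dim": 768,
    "n_heads": 12, "n_layers": 12, "drop_rate": 0.1, "qkv_bias": False,
}


def main():
    device = torch.device("cuda")
    tokenizer = tiktoken.get_encoding("gpt2")

    # Load just enough FineWeb text to cover the largest test.
    fw = load_dataset(
        "parquet", data_files="./fineweb/sample/10BT/*.parquet", split="train"
    )
    needed = NUM_BATCHES * MAX_BATCH_SIZE * (SEQ_LENGTH + 1)
    tokens = []
    for row in fw:
        tokens.extend(tokenizer.encode(row["text"], disallowed_special=()))
        if len(tokens) >= needed:
            break

    for batch_size in range(1, MAX_BATCH_SIZE + 1):
        print(f"Testing with batch size {batch_size}")
        torch.manual_seed(42)  # same starting weights for every batch size
        model = GPTModel(big_train_params).to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

        start = time.time()
        for i in tqdm(range(NUM_BATCHES)):
            # Slice out batch_size sequences, each SEQ_LENGTH + 1 tokens long
            # (the extra token is for the shifted-by-one targets).
            seqs = [
                tokens[
                    (i * batch_size + j) * SEQ_LENGTH:
                    (i * batch_size + j) * SEQ_LENGTH + SEQ_LENGTH + 1
                ]
                for j in range(batch_size)
            ]
            batch = torch.tensor(seqs)
            inputs = batch[:, :-1].to(device)
            outputs = batch[:, 1:].to(device)

            optimizer.zero_grad()
            logits = model(inputs)
            loss = torch.nn.functional.cross_entropy(
                logits.flatten(0, 1), outputs.flatten()
            )
            loss.backward()
            optimizer.step()
        elapsed = time.time() - start

        trained = NUM_BATCHES * batch_size * SEQ_LENGTH
        print(f"Done, trained on {trained:,} tokens in {elapsed:.4f}s.")
        print(f"Tokens per second: {trained / elapsed:,.0f}")


if __name__ == "__main__":
    main()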
Here’s what it prints out:
Loading dataset shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 102/102 [00:00<00:00, 362.71it/s]
Testing with batch size 1
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:10<00:00, 9.77it/s]
Done, trained on 102,400 tokens in 10.2348s.
Tokens per second: 10,005
Testing with batch size 2
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:17<00:00, 5.60it/s]
Done, trained on 204,800 tokens in 17.8631s.
Tokens per second: 11,464
Testing with batch size 3
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:25<00:00, 3.93it/s]
Done, trained on 307,200 tokens in 25.4152s.
Tokens per second: 12,087
Testing with batch size 4
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:33<00:00, 3.02it/s]
Done, trained on 409,600 tokens in 33.1185s.
Tokens per second: 12,367
Testing with batch size 5
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:40<00:00, 2.46it/s]
Done, trained on 512,000 tokens in 40.6351s.
Tokens per second: 12,599
Testing with batch size 6
0%| | 0/100 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/home/giles/Dev/llm-from-scratch/measure-tokens-per-second.py", line 89, in <module>
main()
~~~~^^
...
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.15 GiB. GPU 0 has a total capacity of 23.56 GiB of which 269.19 MiB is free. Including non-PyTorch memory, this process has 20.99 GiB memory in use. Of the allocated memory 18.67 GiB is allocated by PyTorch, and 2.02 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
So we can see that it gets faster as we increase the batch size, which makes sense because we’re handling sequences in parallel, but it does flatten off a bit, which makes sense because there’s a limit to how much parallelism we can do, even on a GPU.
Let’s see how that fits in with the different training sizes we looked at above:
- Chinchilla heuristic, 20x parameters – 3.2B tokens: 247,850 seconds, which is just less than three days
- Estimated GPT-2 train, 419B tokens: 32,452,947 seconds, which is just over a year.
OK. We’re definitely not going to be able to train this thing the GPT-2 way! I expected that to be the case, but now we have a solid proof of that.
But the three-day Chinchilla-optimal train actually sounds doable! I’m heading to London to visit family soon, so won’t be using my home PC. With a bit of help from Tailscale I’ll be able to log into it from my laptop, though, so I can potentially nurse a run through.
Can we make it any faster?
Now, when doing the fine-tuning work, I found that you could generally speed things up by doing everything in 16-bit rather than 32-bit. Intuitively that makes sense – lower-precision numbers, fewer bits, means less work for the GPU doing the various multiplications and additions that are involved in our train.
Working with ChatGPT, I found a couple of ways to take advantage of that. Firstly, using TF32.
The normal float32 format uses 8 bits for the exponent, and 23 for the mantissa. If you haven’t looked into how floats are represented in memory (or if you’ve forgotten), that means that, using m to mean the mantissa and x the exponent, the numbers are represented in memory as
m × 2^x
TF32 is messier; it has the same exponent size – and thus the same range – as float32, but it essentially ignores the lower 13 bits of the mantissa. So it takes up the same amount of memory, but is lower-precision, which means that calculations can be faster. Most importantly, cards like the RTX 3090 have dedicated “tensor cores” – as opposed to the normal CUDA cores that do normal matrix multiplications – and they operate in TF32. Unsurprisingly, “TF32” is “tensor float 32-bit”.
The PyTorch set_float32_matmul_precision allows you to tell it what precision to use for matrix multiplications; the default is "highest", which means “use float32 all of the time”, so you’re stuck using just the CUDA cores. If, instead, you set it to "high", then it will use TF32 if the hardware supports it and it has the appropriate kernels available. So that will let us use the tensor cores.
I added this to the code above just above the loop over the different batch sizes:
torch.set_float32_matmul_precision("high")
Let it run, and:
Testing with batch size 1
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:08<00:00, 11.66it/s]
Done, trained on 102,400 tokens in 8.5799s.
Tokens per second: 11,934
Testing with batch size 2
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:15<00:00, 6.65it/s]
Done, trained on 204,800 tokens in 15.0287s.
Tokens per second: 13,627
Testing with batch size 3
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:20<00:00, 4.85it/s]
Done, trained on 307,200 tokens in 20.6374s.
Tokens per second: 14,885
Testing with batch size 4
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:27<00:00, 3.61it/s]
Done, trained on 409,600 tokens in 27.7148s.
Tokens per second: 14,779
Testing with batch size 5
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:33<00:00, 3.01it/s]
Done, trained on 512,000 tokens in 33.2420s.
Tokens per second: 15,402
That’s a 22% speedup! Of course, the precision of the training isn’t as good. But given that many modern models are trained at 16-bit (I’ve seen suggestions that some are even trained as low as 4-bit) then that shouldn’t matter.
Let’s see whether we can train in 16-bit instead. PyTorch has a smart mode where you can tell it “use 16-bit where it makes sense, otherwise use 32-bit” – AMP, which stands for “Automatic Mixed Precision”. There’s a great recipe for how to use it in the docs, so let’s use that. We need to create a GradScaler object, which scales the loss up during the backward pass so that small 16-bit gradients don’t underflow, then unscales them before the optimiser step. We can re-use it across all batch sizes, so we create it just before the loop:
scaler = torch.amp.GradScaler()
...then we need to replace this core part of our training loop:
logits = model(inputs)
loss = torch.nn.functional.cross_entropy(
logits.flatten(0, 1), outputs.flatten()
)
loss.backward()
optimizer.step()
...with some code to use AMP and that scaler – basically we use a context manager to switch it on when we’re doing the forward pass and work out the loss, and then use the scaler to manage the backward pass and the optimiser’s step:
with torch.amp.autocast(device_type=device.type, dtype=torch.float16):
logits = model(inputs)
loss = torch.nn.functional.cross_entropy(
logits.flatten(0, 1), outputs.flatten()
)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
Running that gives us these results:
(llm-from-scratch) giles@perry:~/Dev/llm-from-scratch (main)$ python measure-tokens-per-second.py
Loading dataset shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 102/102 [00:00<00:00, 340.25it/s]
Testing with batch size 1
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:07<00:00, 13.38it/s]
Done, trained on 102,400 tokens in 7.4764s.
Tokens per second: 13,696
Testing with batch size 2
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:12<00:00, 8.11it/s]
Done, trained on 204,800 tokens in 12.3286s.
Tokens per second: 16,611
Testing with batch size 3
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:16<00:00, 6.02it/s]
Done, trained on 307,200 tokens in 16.6238s.
Tokens per second: 18,479
Testing with batch size 4
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:21<00:00, 4.67it/s]
Done, trained on 409,600 tokens in 21.3936s.
Tokens per second: 19,145
Testing with batch size 5
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:25<00:00, 3.87it/s]
Done, trained on 512,000 tokens in 25.8624s.
Tokens per second: 19,797
Testing with batch size 6
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:30<00:00, 3.25it/s]
Done, trained on 614,400 tokens in 30.7239s.
Tokens per second: 19,997
Testing with batch size 7
0%| | 0/100 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/home/giles/Dev/llm-from-scratch/measure-tokens-per-second.py", line 94, in <module>
main()
Wow! With that we can train on 3.2B tokens in about 160,000 seconds, which is 44 hours. That’s definitely doable.
Now, what happens if we remove the
torch.set_float32_matmul_precision("high")
...so that we’re using AMP, but not the tensor cores?
(llm-from-scratch) giles@perry:~/Dev/llm-from-scratch (main)$ python measure-tokens-per-second.py
Loading dataset shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 102/102 [00:00<00:00, 365.94it/s]
Testing with batch size 1
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:07<00:00, 13.03it/s]
Done, trained on 102,400 tokens in 7.6736s.
Tokens per second: 13,344
Testing with batch size 2
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:12<00:00, 8.04it/s]
Done, trained on 204,800 tokens in 12.4383s.
Tokens per second: 16,465
Testing with batch size 3
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:16<00:00, 5.96it/s]
Done, trained on 307,200 tokens in 16.7851s.
Tokens per second: 18,301
Testing with batch size 4
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:21<00:00, 4.64it/s]
Done, trained on 409,600 tokens in 21.5571s.
Tokens per second: 19,000
Testing with batch size 5
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:25<00:00, 3.85it/s]
Done, trained on 512,000 tokens in 25.9610s.
Tokens per second: 19,721
Testing with batch size 6
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:30<00:00, 3.24it/s]
Done, trained on 614,400 tokens in 30.8405s.
Tokens per second: 19,921
Testing with batch size 7
0%| | 0/100 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/home/giles/Dev/llm-from-scratch/measure-tokens-per-second.py", line 93, in <module>
main()
~~~~^^
File "/home/giles/Dev/llm-from-scratch/measure-tokens-per-second.py", line 81, in main
It’s basically the same: around 300 tokens/second slower at the smallest batch size, down to about 70 slower at the largest. Still, it seems better to keep the “high” matmul precision in place rather than “highest”.
Right. We have the beginnings of a training loop that should be able to let us run a Chinchilla-optimal train on a GPT-2 small sized model in 44 hours, and I have the time to do it. And it looks like a batch size of six is what we can fit into the RTX 3090’s 24 GiB of VRAM.
What else are we going to need to build something to do this?
Checkpointing
If I want to do a long training run, then stuff might go wrong – it might crash for some reason. So we’re going to need to save checkpoints as we go and be able to restart training from those checkpoints.
In those, we’re going to need to save the model and the optimiser’s state, plus some kind of info about how far through the dataset we are. We should keep training and validation losses too, so that we can easily chart and recover our progress. And according to this forum post we also need to save the scaler – which makes me think it has state in it, so we probably should have used a fresh scaler for each batch size above; let’s hope that doesn’t prove to be a problem (note from later: it wasn’t).
I wrote a script to create a model, train it for a bit, and then dump out all of that apart from the metadata (which I reckon is going to be less than 1kB). I wanted to use the safetensors format for all of it, but unfortunately I couldn’t get it to work for the optimiser or the scaler, so had to use torch.save for those (which I don’t like because it uses pickle, which introduces serious problems if you ever want to move files from machine to machine, as the Python and library versions need to match perfectly). Ah well. Here’s what the test checkpoint looks like:
(llm-from-scratch) giles@perry:~/Dev/llm-from-scratch (main)$ du -sh test-checkpoint
1.9G test-checkpoint
(llm-from-scratch) giles@perry:~/Dev/llm-from-scratch (main)$ ls -lh test-checkpoint
total 1.9G
-rw-r--r-- 1 giles giles 670M Nov 11 15:21 model.safetensors
-rw-r--r-- 1 giles giles 1.3G Nov 11 15:21 optimizer.pt
-rw-r--r-- 1 giles giles 1.4K Nov 11 15:21 scaler.pt
That’s huge! And it’s almost all the optimiser. From what I read, that stores two numbers per parameter, so it makes sense that it’s double the size of the model weights. And at 32 bits – four bytes – per parameter, 670MiB for the model is sane.
Timing-wise, it takes about a second to save, the same to load, so that’s fine.
So timing is fine, and while the disk usage is pretty high, it’s not so huge that it can’t be managed with careful planning – just don’t checkpoint so often that we run out of disk during the train (I have a 2TiB disk, but it’s far from empty).
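For what it’s worth, the save/load side of this only needs a few lines – here’s a sketch, assuming safetensors for the model and torch.save for the optimiser and scaler as described above, plus a small JSON file for the metadata that the test checkpoint skipped:

import json
from pathlib import Path

import torch
from safetensors.torch import load_file, save_file


def save_checkpoint(dirname, model, optimizer, scaler, metadata):
    # Model weights go into safetensors; optimiser and scaler state still
    # need torch.save (i.e. pickle), as discussed above.
    path = Path(dirname)
    path.mkdir(parents=True, exist_ok=True)
    save_file(model.state_dict(), str(path / "model.safetensors"))
    torch.save(optimizer.state_dict(), path / "optimizer.pt")
    torch.save(scaler.state_dict(), path / "scaler.pt")
    (path / "metadata.json").write_text(json.dumps(metadata))


def load_checkpoint(dirname, model, optimizer, scaler):
    # Restores state in place and returns the metadata (losses, position
    # in the dataset, and so on).
    path = Path(dirname)
    model.load_state_dict(load_file(str(path / "model.safetensors")))
    optimizer.load_state_dict(torch.load(path / "optimizer.pt"))
    scaler.load_state_dict(torch.load(path / "scaler.pt"))
    return json.loads((path / "metadata.json").read_text())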
It’s p