I’m carrying on with my "extra credit" projects after finishing the main body of Sebastian Raschka’s book "Build a Large Language Model (from Scratch)". Having proven that I could train a GPT-2 small scale base model from scratch on my RTX 3090 in 48 hours, I wanted to try training it on a multi-GPU machine on Lambda Labs. There are two benefits I see in doing that:
- I can learn what you need to change in a simple single-GPU training loop to make it multi-GPU.
- If I can get the training time for a full base model down from 48 hours to something more manageable (and hopefully not too expensive) – then I can try a few experiments to see how I can improve the quality of the trained model. I have a bunch of ideas about why my own base model wasn’t as good as the original OpenAI one, and it would be good to know which (if any) of them are right.
In addition, I wanted to see if anything unexpected dropped out of it; after all, there were four different sizes of machines that I wanted to try, so I’d be doing four from-scratch trains on the same dataset. Does the machine size affect the quality of the model in some way?
Here’s what happened. As with the last post, this is a set of tidied-up lab notes, so you can see the full journey. There’s a lot to it! I was considering splitting it into multiple posts – "writing the code", "building the datasets", "running the trains" – but they’re interleaved. Each train taught me something about how to structure the code to make it easier to use, so the code kept changing.
So I think it’s worth documenting the process as it really was. If at some point I want to write a how-to document on porting single-GPU code to multi-GPU, I’ll be able to mine this for resources, and in the meantime, hopefully this will be of use to readers – even if it’s just at the level of "I got this error message, how do I fix it?"
Anyway, once again I don’t want to bury the lede, so: after spending US$215.16 on various trains on various servers, I found that a reasonably cheap Lambda Labs instance with 8x A100 GPUs, each of which has 40 GiB of VRAM, is the sweet spot for this particular 163M-parameter, ~Chinchilla-optimal single-epoch run. It can train the model in less than four hours, its GPUs happen to be the right size for batches that minimise loss (more on that later), and it can do that train for about US$35, excluding validation.
If you’d like to read the gory details of what I did, then read on – but if you prefer, you can jump straight to the results.
Which multi-GPU technique?
Back when I was messing around with fine-tuning LLMs using the Hugging Face ecosystem – their "Transformers" library and so on – one of the experiments I did was to fine-tune a 0.5B Qwen model on an 8x GPU machine. As part of that, I came across this excellent HF page summarising different kinds of multi-GPU training techniques. The three that are relevant are:
- DataParallel (DP). With this:
- The default GPU (normally gpu0) is in charge of the process. It gets a batch of data, divides it up into per-GPU "micro-batches", and sends each of those to a thread for each of the other GPUs.
- It then sends an up-to-date version of the model to each GPU.
- Next, all of the per-GPU threads do a forward pass on their replica using their specific micro-batch, and send their outputs to the thread for the default GPU.
- The default GPU thread aggregates all of those outputs (similarly to how the losses across all of our batches and the prefix sequences are aggregated in the normal single-GPU case) to work out an overall loss.
- It then does a backward pass. This will start on the default GPU, as the aggregation step is the first thing that it will come to when going backwards through the steps that came up with that overall loss. However, it will then come to operations that happened on the other GPUs and those are (somehow) parallelised.
- Once that is done, each GPU has gradients that represent how their copies of the model contributed to the overall loss.
- Finally, they send those gradients back to the default GPU, which combines them (I think of this as just being an average, though I gather it’s more complex) and applies them, producing an updated model.
- Then the process repeats; the updated model on the default GPU will be sent to the other GPUs in the second step of the next iteration.
- DistributedDataParallel (DDP). This does less work on the default GPU and does less copying around. Each GPU has its own process (rather than thread), and is essentially responsible for its own training loop. Right at the very start, the default GPU’s process sends the model to all of the others. Then all processes go into their training loop:
- Firstly, each one works out its own micro-batch (which means you need to have code to make sure that the datasets are properly split across the GPUs)
- Each model does its own forward pass, then its own backward pass, working out its own independent gradients.
- As it comes up with those gradients, it hands them to a "reducer", which averages them across all of the processes using an all-reduce operation. This is done in a distributed way – there’s not just one central reducer handling everything.
- When all models have completed the backward pass, the reducer has a set of combined gradients, which is visible from the per-GPU processes.
- Each GPU process does its own optimizer step using those combined gradients.
- That means that there’s no model copy required – each GPU has applied the same gradient update, so they already have in-sync models, assuming everything went well.
- ZeRO. This is a much more complex system, and I went into how it works in this blog post.
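The reason both DP and DDP work is that, when the loss is a mean over the examples, averaging the gradients from several equal-sized micro-batches gives the same result as computing the gradient over one big batch. Here’s a minimal sketch that checks that on the CPU, with no multi-GPU machinery at all; it uses a toy linear model and random data (nothing to do with the GPT code), and the four "GPUs" are just four chunks of the batch:

```python
import torch

torch.manual_seed(0)

# A toy model and a "global batch" of 8 examples, split across 4 pretend GPUs.
model = torch.nn.Linear(4, 1)
inputs = torch.randn(8, 4)
targets = torch.randn(8, 1)
loss_fn = torch.nn.MSELoss()  # a mean over examples, like our cross-entropy loss


def grads_for(x, y):
    """Return the gradient of the mean loss over (x, y) for every parameter."""
    model.zero_grad()
    loss_fn(model(x), y).backward()
    return [p.grad.clone() for p in model.parameters()]


# The gradient over the whole global batch, as a single GPU would compute it.
full_batch_grads = grads_for(inputs, targets)

# Per-micro-batch gradients, as each of the 4 "GPUs" would compute them...
world_size = 4
per_rank_grads = [
    grads_for(x, y)
    for x, y in zip(inputs.chunk(world_size), targets.chunk(world_size))
]
# ...then averaged, which is what DDP's all-reduce leaves on every rank.
averaged_grads = [
    torch.stack(param_grads).mean(dim=0) for param_grads in zip(*per_rank_grads)
]

for full, avg in zip(full_batch_grads, averaged_grads):
    print(torch.allclose(full, avg, atol=1e-6))  # True for each parameter
```

The equivalence relies on the micro-batches being the same size; with uneven splits, a plain average would weight the smaller micro-batch's examples more heavily.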
Now, from what I understand, due to all of the copying around of models, plus the issues inherent with the GIL in Python, DDP is actually better than DP despite being more complicated – and more flexible! Per Hugging Face:
DDP is recommended because it reduces communication overhead between GPUs, efficiently utilizes each GPU, and scales to more than one machine.
It might be a while before I want to try multi-machine training, but it would be awesome to have code that’s ready to do that without needing any extra work.
Now, how to implement it?
Implementing DDP for our model
Hugging Face have a library called Accelerate, which does everything for you:
Accelerate is a library that enables the same PyTorch code to be run across any distributed configuration by adding just four lines of code!
That does sound very useful, but I worry that by using it I won’t learn as much. It also rather ties you in to the HF ecosystem. That’s not necessarily a bad thing – I enjoyed using their stuff in my fine-tuning project – but I’m trying for a somewhat lower-level view in this series.
So, let’s use the PyTorch-native stuff. There’s a "getting started" tutorial, so we can follow that.
It has two options for running using DDP, one with a bit of extra setup code – the first example, under "Basic Use Case" – and one that uses torchrun to make things easier. The second sounds best.
The code changes actually look really simple; given a normal single-GPU training script, you need to do some setup at the start:
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
# ...
torch.accelerator.set_device_index(int(os.environ["LOCAL_RANK"]))
acc = torch.accelerator.current_accelerator()
backend = torch.distributed.get_default_backend_for_device(acc)
dist.init_process_group(backend)
rank = dist.get_rank()
print(f"Start running basic DDP example on rank {rank}.")
# create model and move it to GPU with id rank
device_id = rank % torch.accelerator.device_count()
...then wrap the model itself in a DDP object, which is what you actually do the train on:
model = ToyModel().to(device_id)
ddp_model = DDP(model, device_ids=[device_id])
...and a bit of teardown at the end:
dist.destroy_process_group()
The way to look at this is that torchrun will spin off one process per GPU, each running exactly the same code. They have a "rank", which is an integer saying which of the per-GPU processes they are – 0 for GPU 0, 1 for GPU 1, and so on. There’s a bit of a gotcha here, though – you can see that we’re looking at an environment variable called LOCAL_RANK at the start, but we then get a (non-"local") rank variable from torch.distributed a bit later on. This is due to the multi-machine possibilities with DDP – if you have multiple machines, then the local rank will be "which GPU on the machine does this process relate to", but there will also be a "global" rank, which is unique across all machines. This distinction won’t matter that much during this one-machine test, but it’s worth keeping in mind if we want to keep the code in a shape where it could potentially scale to multiple machines.
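To put some numbers on the local/global distinction: torchrun exposes both values as environment variables (LOCAL_RANK and RANK, along with WORLD_SIZE and GROUP_RANK for the node’s own index). The helper below is hypothetical and only for illustration; it assumes every machine runs the same number of processes, and that the nodes are numbered 0, 1, ... in the order torchrun assigns them. In real code you’d just read RANK or call dist.get_rank():

```python
def global_rank(node_rank: int, procs_per_node: int, local_rank: int) -> int:
    """Hypothetical helper: the global rank of a process, assuming every
    node runs the same number of processes (one per GPU)."""
    return node_rank * procs_per_node + local_rank


# Two machines with 8 GPUs each: local ranks repeat, global ranks don't.
for node in range(2):
    for local in (0, 7):
        print(f"node {node}, LOCAL_RANK={local} -> RANK={global_rank(node, 8, local)}")
# node 0, LOCAL_RANK=0 -> RANK=0
# node 0, LOCAL_RANK=7 -> RANK=7
# node 1, LOCAL_RANK=0 -> RANK=8
# node 1, LOCAL_RANK=7 -> RANK=15
```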
Anyway, after the processes are spun up, they will do their training, and the synchronisation and passing around of gradients during the backward pass will all happen invisibly in the background, so when we do our optimizer.step(), it will have the full set of gradients.
Now that means that we’ll presumably also need to use the rank – that is, which of the n per-GPU processes the current code is running in – when selecting which dataset items to train on. More about that later.
Let’s start writing some code! I’ll use a new repo, into which I can put just the code needed for this train. I’ll also structure it a little better than last time, with separate "runs", each of which has a model config and training parameters, and will later on have its own checkpoints. You can think of these as being one per machine size that I’m trying out – I’ll create a run directory for each one.
Here’s a first cut, simply loading up a model config from a run’s directory, using it to create the model, and then doing the wrapping above – no training at all. Running it with torchrun (and uv, as I’m using that for all new projects):
giles@perry:~/Dev/ddp-base-model-from-scratch (main)$ uv run torchrun ddp_train.py original
On rank 0.
Promising. Now, unfortunately we only have one GPU locally, and the code assumes that it’s one process per GPU (I believe that’s a hard limitation for PyTorch’s DDP), so running with --nproc_per_node=2 blows up. So we can’t do an in-depth test locally.
But at least we know that the basic infra is there and working.
Now let’s move the other training code from the single-GPU script into that file, pretty much blindly. This is the result – it’s doing almost nothing beyond what the last train did, apart from wrapping the model in a DDP object – the only other changes are to use this "runs" directory that we’ve introduced.
As a quick hack, we should try running it. It does a validation and checkpoint before it starts, and we can make that happen quickly by hacking the validation loop to only do a couple of iterations:
for val_inputs, val_targets in tqdm(val_ds[:2]):
(Foreshadowing: that hack will come back to haunt us later!)
I ran that, hit control-C after the validation completed, and it looks OK:
giles@perry:~/Dev/ddp-base-model-from-scratch (main)$ uv run torchrun ddp_train.py original
On rank 0.
Starting training at dataset offset 0
0%| | 0/530630 [00:00<?, ?it/s]Validation/checkpoint
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 10.95it/s]
Continuing training█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 10.96it/s]
0%| | 18/530630 [00:06<45:20:54, 3.25it/s]^CW1203 18:34:11.363000 471545 torch/distributed/elastic/agent/server/api.py:725] Received 2 death signal, shutting down workers
W1203 18:34:11.364000 471545 torch/distributed/elastic/multiprocessing/api.py:908] Sending process 471607 closing signal SIGINT
0%| | 18/530630 [00:07<57:44:53, 2.55it/s]
Aborted!
...and we have what look like solid checkpoints:
giles@perry:~/Dev/ddp-base-model-from-scratch (main)$ ls -lrt runs/original/checkpoints/
total 4
lrwxrwxrwx 1 giles giles 27 Dec 3 18:34 latest -> 20251203Z183404-iteration-0
lrwxrwxrwx 1 giles giles 27 Dec 3 18:34 best -> 20251203Z183404-iteration-0
drwxr-xr-x 2 giles giles 4096 Dec 3 18:34 20251203Z183404-iteration-0
giles@perry:~/Dev/ddp-base-model-from-scratch (main)$ ls -lrth runs/original/checkpoints/20251203Z183404-iteration-0/
total 1.9G
-rw-r--r-- 1 giles giles 670M Dec 3 18:34 model.safetensors
-rw-r--r-- 1 giles giles 1.4K Dec 3 18:34 scaler.pt
-rw-r--r-- 1 giles giles 1.3G Dec 3 18:34 optimizer.pt
-rw-r--r-- 1 giles giles 105 Dec 3 18:34 meta.json
However, loading one of those checkpoints fails:
giles@perry:~/Dev/ddp-base-model-from-scratch (main)$ uv run torchrun ddp_train.py original best
On rank 0.
[rank0]: Traceback (most recent call last):
[rank0]: File "/home/giles/Dev/ddp-base-model-from-scratch/ddp_train.py", line 229, in <module>
[rank0]: main()
[rank0]: ~~~~^^
[rank0]: File "/home/giles/Dev/ddp-base-model-from-scratch/.venv/lib/python3.13/site-packages/click/core.py", line 1485, in __call__
[rank0]: return self.main(*args, **kwargs)
[rank0]: ~~~~~~~~~^^^^^^^^^^^^^^^^^
[rank0]: File "/home/giles/Dev/ddp-base-model-from-scratch/.venv/lib/python3.13/site-packages/click/core.py", line 1406, in main
[rank0]: rv = self.invoke(ctx)
[rank0]: File "/home/giles/Dev/ddp-base-model-from-scratch/.venv/lib/python3.13/site-packages/click/core.py", line 1269, in invoke
[rank0]: return ctx.invoke(self.callback, **ctx.params)
[rank0]: ~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/giles/Dev/ddp-base-model-from-scratch/.venv/lib/python3.13/site-packages/click/core.py", line 824, in invoke
[rank0]: return callback(*args, **kwargs)
[rank0]: File "/home/giles/Dev/ddp-base-model-from-scratch/ddp_train.py", line 211, in main
[rank0]: train_ds_offset, best_loss = load_checkpoint(
[rank0]: ~~~~~~~~~~~~~~~^
[rank0]: run_dir, checkpoint, model, optimizer, scaler
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: )
[rank0]: ^
[rank0]: File "/home/giles/Dev/ddp-base-model-from-scratch/checkpointing.py", line 16, in load_checkpoint
[rank0]: model.load_state_dict(load_file(checkpoint_dir / "model.safetensors"))
[rank0]: ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/giles/Dev/ddp-base-model-from-scratch/.venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 2629, in load_state_dict
[rank0]: raise RuntimeError(
[rank0]: ...<3 lines>...
[rank0]: )
[rank0]: RuntimeError: Error(s) in loading state_dict for GPTModel:
[rank0]: Missing key(s) in state_dict: "tok_emb.weight", "pos_emb.weight", "trf_blocks.0.att.mask", "trf_blocks.0.att.W_query.weight",
...
[rank0]: Unexpected key(s) in state_dict: "module.final_norm.scale", "module.final_norm.shift", "module.out_head.weight", "module.pos_emb.weight", "module.tok_emb.weight"
...
It turns out that the problem is this code when we save it:
save_checkpoint(
    run_dir,
    f"iteration-{ix}",
    model, optimizer, scaler,
    avg_train_loss, val_loss,
    ix,
    is_best
)
The model that we’re saving is the DDP wrapper around our model; my guess is that it does actually include all of the weights for the model, hence the correct-looking size for the checkpoint file, but they’re renamed – the DDP wrapper sees the underlying model as something called module, so (for example) tok_emb.weight would be called module.tok_emb.weight.
Fixing that, with this diff:
diff --git a/ddp_train.py b/ddp_train.py
index 7418851..963fbf7 100644
--- a/ddp_train.py
+++ b/ddp_train.py
@@ -137,12 +137,13 @@ def train(
if (ix % VAL_AND_CHECKPOINT_INTERVAL == 0) or (ix == len(train_ds) - 1):
print("Validation/checkpoint")
model.eval()
+ base_model = model.module
with torch.inference_mode(), torch.amp.autocast(device_type=device.type, dtype=torch.float16):
val_losses = []
for val_inputs, val_targets in tqdm(val_ds):
val_inputs = val_inputs.to(device).to(torch.long)
val_targets = val_targets.to(device).to(torch.long)
- val_logits = model(val_inputs)
+ val_logits = base_model(val_inputs)
val_losses.append(
calculate_loss(val_logits, val_targets).item()
)
@@ -160,7 +161,7 @@ def train(
save_checkpoint(
run_dir,
f"iteration-{ix}",
- model, optimizer, scaler,
+ base_model, optimizer, scaler,
avg_train_loss, val_loss,
ix,
is_best
...sorts it out – we can load our checkpoints again. Here’s the updated file.
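Saving model.module is the cleanest fix, but if you already have checkpoints that were written from the wrapped model, one way to rescue them is to strip the prefix from the keys before loading. Here’s a minimal sketch; the Wrapper class is a hypothetical stand-in for the DDP object (it stores the real model as self.module, which produces the same key prefix), so that the example runs without a process group:

```python
import torch
from torch import nn


class Wrapper(nn.Module):
    """Hypothetical stand-in for DDP: keeps the real model as ``self.module``."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module


def strip_prefix(state_dict: dict, prefix: str = "module.") -> dict:
    """Return a copy of ``state_dict`` with ``prefix`` removed from any key that has it."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


original = nn.Linear(3, 2)
wrapped_keys = list(Wrapper(original).state_dict().keys())
print(wrapped_keys)  # ['module.weight', 'module.bias']

# Load the wrapped model's weights into a fresh, unwrapped model.
fresh = nn.Linear(3, 2)
fresh.load_state_dict(strip_prefix(Wrapper(original).state_dict()))
print(torch.equal(fresh.weight, original.weight))  # True
```

I believe PyTorch also ships a utility for this, torch.nn.modules.utils.consume_prefix_in_state_dict_if_present, which edits the dictionary in place rather than returning a copy.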
I think we’re going to have to revisit checkpointing and validation again; we don’t want to do it in all of our processes, probably only on global rank 0, and we’ll need to somehow synchronise everything so that the other processes don’t carry on training while we’re doing it.
But before we get on to that, there are a couple of other things to change. At the top of the file we’re defining some constants that look wrong:
BATCH_SIZE = 6
SEQ_LENGTH = 1024
VAL_AND_CHECKPOINT_INTERVAL = 2000
Sequence length
We’ll handle the dumbest of these first; it was actually silly that in the old code we had a constant for sequence length. We’re using the context length of the model for that, so it’s duplicated information. Let’s get it from the model_conf:
diff --git a/ddp_train.py b/ddp_train.py
index 963fbf7..77a62ae 100644
--- a/ddp_train.py
+++ b/ddp_train.py
@@ -20,15 +20,14 @@ from gpt import GPTModel
BATCH_SIZE = 6
-SEQ_LENGTH = 1024
VAL_AND_CHECKPOINT_INTERVAL = 2000
class BigTrainDataset(Dataset):
- def __init__(self, all_tokens):
- self.xs = all_tokens[:-1].reshape(-1, BATCH_SIZE, SEQ_LENGTH)
- self.ys = all_tokens[1:].reshape(-1, BATCH_SIZE, SEQ_LENGTH)
+ def __init__(self, all_tokens, seq_length):
+ self.xs = all_tokens[:-1].reshape(-1, BATCH_SIZE, seq_length)
+ self.ys = all_tokens[1:].reshape(-1, BATCH_SIZE, seq_length)
def __getitem__(self, ix):
return (self.xs[ix], self.ys[ix])
@@ -37,9 +36,10 @@ class BigTrainDataset(Dataset):
return self.xs.shape[0]
-def load_dataset(run_dir, split):
+def load_dataset(run_dir, split, seq_length):
return BigTrainDataset(
- load_file(run_dir / "datasets" / f"{split}.safetensors")["tokens"]
+ load_file(run_dir / "datasets" / f"{split}.safetensors")["tokens"],
+ seq_length,
)
@@ -205,8 +205,8 @@ def main(run, checkpoint):
scaler = torch.amp.GradScaler()
- train_ds = load_dataset(run_dir, "train")
- val_ds = load_dataset(run_dir, "validation")
+ train_ds = load_dataset(run_dir, "train", model_conf["context_length"])
+ val_ds = load_dataset(run_dir, "validation", model_conf["context_length"])
if checkpoint:
train_ds_offset, best_loss = load_checkpoint(
...and here’s the updated file. That was nice and simple.
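To see what that reshape does, here’s a toy run with tiny made-up sizes (a batch size of 2 and a sequence length of 3, instead of 6 and 1024). The class mirrors BigTrainDataset, except that it takes the batch size as a parameter, which is my change for the example. Note that the reshape only works if the number of tokens minus one is an exact multiple of batch size × sequence length:

```python
import torch
from torch.utils.data import Dataset


class ToyTrainDataset(Dataset):
    """Mirrors BigTrainDataset, but with the batch size passed in."""

    def __init__(self, all_tokens, batch_size, seq_length):
        # Inputs are every token but the last; targets are the same tokens shifted by one.
        self.xs = all_tokens[:-1].reshape(-1, batch_size, seq_length)
        self.ys = all_tokens[1:].reshape(-1, batch_size, seq_length)

    def __getitem__(self, ix):
        return (self.xs[ix], self.ys[ix])

    def __len__(self):
        return self.xs.shape[0]


# 13 tokens -> 12 input/target positions -> 2 batches of 2 sequences of length 3.
tokens = torch.arange(13)
ds = ToyTrainDataset(tokens, batch_size=2, seq_length=3)
x, y = ds[0]
print(len(ds))      # 2
print(x.tolist())   # [[0, 1, 2], [3, 4, 5]]
print(y.tolist())   # [[1, 2, 3], [4, 5, 6]]
```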
Batch size
The code that we have specifies the batch size for each GPU – that is, with 6, we’ll have six sequences in each batch on each one. As I mentioned earlier, that’s called a "micro-batch" in distributed training like this [1] – a per-GPU batch, as opposed to the overall global batch across all GPUs – so we could just rename it, and then we’d have 6 × ngpus as a global batch size.
However, it feels to me like this is a useful metaparameter to be able to tweak from outside the code. I can see machines with per-GPU VRAM varying from 40 GiB to 160 GiB on Lambda Labs, so the largest micro-batch that fits will pretty clearly vary from one machine type to another. That makes it something we’ll want to configure on a per-run basis, so let’s add a new train.json file to our run config, load that up, and pass it through.
That’s a simple enough fix; no need to note the diff, but here’s the code.
Validation/checkpoint interval
This one we’ll need to think about. The size of our validation set is based on what one process running on my local RTX 3090 can validate in five minutes, and the interval (for which I fairly arbitrarily put 2000 in the code when copying it across) was calibrated for roughly every half-hour. Those numbers in turn were aimed at the 44 hours of training time I expected locally.
For this train, we’ll (hopefully!) be taking significantly less time. We’ll have eight GPUs, so naively that’s 5.5 hours of train time, and each will have more VRAM, so we should be able to bump up the batch size and potentially get even faster than that. Depending on which kind of cards we’re using, they may be faster, too – I found that an A100 is slower (with the same batch size) than the RTX 3090 in my fine-tuning experiments, but the H100 and B200 are likely faster.
I think this is another thing for the train config; we should have the validation interval (in terms of iterations) and the number of batches to do for validation.
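Here’s a minimal sketch of what loading that per-run train config might look like. The file location (train.json inside the run directory) follows the description above, but the key names, the TrainConfig class, and the validation_batches value of 100 in the usage example are my own hypothetical choices, not necessarily what’s in the repo:

```python
import json
import tempfile
from dataclasses import dataclass
from pathlib import Path


@dataclass
class TrainConfig:
    # Hypothetical field names; one train.json per run directory.
    microbatch_size: int      # sequences per GPU per iteration
    validation_interval: int  # iterations between validation/checkpoint runs
    validation_batches: int   # how many validation batches to evaluate each time

    def global_batch_size(self, world_size: int) -> int:
        """Sequences processed per iteration across all GPUs."""
        return self.microbatch_size * world_size


def load_train_config(run_dir: Path) -> TrainConfig:
    with open(run_dir / "train.json") as f:
        return TrainConfig(**json.load(f))


# Usage with a throwaway run directory (the values are illustrative).
with tempfile.TemporaryDirectory() as tmp:
    run_dir = Path(tmp)
    (run_dir / "train.json").write_text(json.dumps(
        {"microbatch_size": 6, "validation_interval": 2000, "validation_batches": 100}
    ))
    conf = load_train_config(run_dir)

print(conf.global_batch_size(8))  # 48
```

Using a dataclass means that a typo in a key name in train.json fails loudly at load time (as an unexpected keyword argument) rather than silently falling back to a default.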
Datasets
Now, let’s move on to the dataset. With the code as it is right now, all of our per-GPU processes are using this code to iterate over the same dataset:
for ix in tqdm(range(train_ds_offset, len(train_ds))):
That means that they’ll all be training on the same data; the synchronisation that is happening "magically" in the background means that they’ll all train on the first item, work out gradients, and step their optimiser – so they’ll essentially (modulo randomness) have the same updates. Pretty pointless! What we want is for each of the n per-GPU processes to train on 1/n of the data.
We have two useful helpers in torch.distributed:
get_rank, which gets the global rank of this process. In our one-machine case, it returns 0 for the process on gpu0, 1 for the one on gpu1, and so on. We’re already using it in that setup code we looked at earlier:
rank = dist.get_rank()
print(f"Start running basic DDP example on rank {rank}.")
# create model and move it to GPU with id rank
device_id = rank % torch.accelerator.device_count()
get_world_size, which tells us how many GPU processes there are (globally – it would be across all machines if we had more than one)
So, the simplest thing to do is to use the world size as a step, and the rank as an offset:
rank = dist.get_rank()
world_size = dist.get_world_size()
for ix in tqdm(range(train_ds_offset + rank, len(train_ds), world_size)):
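A quick way to sanity-check that strided indexing is to compute, in plain Python, which dataset items each rank would see. This sketch uses a hypothetical toy dataset length of 10 and a world size of 4. It shows that the shards don’t overlap and together cover everything, but also that when the length isn’t a multiple of the world size, some ranks get one more item than others, so they would run a different number of loop iterations:

```python
def shard_indices(dataset_len: int, offset: int, rank: int, world_size: int) -> list:
    """The dataset indices that ``rank`` visits with the strided loop above."""
    return list(range(offset + rank, dataset_len, world_size))


world_size = 4
shards = [shard_indices(10, 0, rank, world_size) for rank in range(world_size)]
for rank, shard in enumerate(shards):
    print(rank, shard)
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6]
# 3 [3, 7]

# Every index is covered exactly once, but the per-rank counts differ.
print(sorted(i for shard in shards for i in shard) == list(range(10)))  # True
print([len(shard) for shard in shards])  # [3, 3, 2, 2]
```

For DataLoader-based code, PyTorch’s torch.utils.data.distributed.DistributedSampler deals with the uneven tail by padding the shards (or dropping the extra items, with drop_last=True) so that every rank gets the same number of items.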
Validation and checkpointing only on rank 0
Now, remember that the same code is running for every one of our per-GPU processes. That means that all of them will do the training with forward and backward passes, and their own optimiser steps, all synchronised by PyTorch DDP magic. But they will also do their own validations – which is kind of pointless – and they’ll also try to save their own checkpoints, which would be messy because they could quite easily interfere with each other; after all, all of the processes are running on the same machine and would be writing to the same filesystem.
So, as a first cut, let’s just wrap an if rank == 0 around the eval and checkpointing stuff – we change this:
if (ix % validation_interval == 0) or (ix == len(train_ds) - 1):
...to this:
if rank == 0 and ((ix % validation_interval == 0) or (ix == len(train_ds) - 1)):
That line is getting a bit long, so let’s break it up:
is_eval_iter = (
    (ix % validation_interval == 0)
    or (ix == len(train_ds) - 1)
)
if rank == 0 and is_eval_iter:
That looks OK, but there’s an extra wrinkle: all of the processes are running the same code, so while the rank zero one will do the eval, the others will continue through the script, so they will go right back around our loop and start training on the next batches – which is bad. We want our processes to be proceeding in lockstep, iteration-by-iteration.
Luckily, the solution is simple: the barrier function in torch.distributed basically says "stop here until all of our processes have reached this point".
So we can use two of those – one before the eval loop, to make sure that all of the processes have finished their training part of the iteration before we do the eval on rank zero, and one after the eval, so that the non-rank-zero processes will wait.
One bit of complexity – we want to do those barriers only if it’s an eval iteration, but we want to do them for all processes. So we have to break up the if statement, and we wind up with this:
is_eval_iter = (
    (ix % validation_interval == 0)
    or (ix == len(train_ds) - 1)
)
if is_eval_iter:
    dist.barrier()
    if rank == 0:
        print("Validation/checkpoint")
        model.eval()
        base_model = model.module
        with torch.inference_mode(), torch.amp.autocast(device_type=device.type, dtype=torch.float16):
            val_losses = []
            for val_inputs, val_targets in tqdm(val_ds[:validation_batches]):
                val_inputs = val_inputs.to(device).to(torch.long)
                val_targets = val_targets.to(device).to(torch.long)
                val_logits = base_model(val_inputs)
                val_losses.append(
                    calculate_loss(val_logits, val_targets).item()
                )
            val_loss = sum(val_losses) / len(val_losses)
        if best_loss is None or val_loss < best_loss:
            is_best = True
            best_loss = val_loss
        else:
            is_best = False
        avg_train_loss = sum(train_losses) / len(train_losses)
        train_losses = []
        save_checkpoint(
            run_dir,
            f"iteration-{ix}",
            base_model, optimizer, scaler,
            avg_train_loss, val_loss,
            ix,
            is_best
        )
        generate_training_chart(run_dir)
        model.train()
        print("Continuing training")
    dist.barrier()
That seems to work OK (code here), but it does give a warning:
UserWarning: barrier(): using the device under current context. You can specify ``device_id`` in ``init_process_group`` to mute this warning.
So, we want to pass the device ID in when we call init_process_group. Let’s dig into that a bit.
Revisiting the init code
Here’s the copypasta that I took from the PyTorch tutorial earlier in this post:
torch.accelerator.set_device_index(int(os.environ["LOCAL_RANK"]))
acc = torch.accelerator.current_accelerator()
backend = torch.distributed.get_default_backend_for_device(acc)
dist.init_process_group(backend)
rank = dist.get_rank()
print(f"On rank {rank}.")
device_id = rank % torch.accelerator.device_count()
Let’s dig into what that is doing.
The LOCAL_RANK environment variable is being set by torchrun to 0, 1, 2, etc as appropriate to tell us which process we are on this machine. So the first line is telling PyTorch to use the device with that index for this process.
The next line is getting the current accelerator – that is, an object that represents which acceleration hardware we’re using in this process.
I think that the best way to see the combination of these two lines is that the first says "use gpu0" (or 1, or 2, or...), and then the second says "get the object describing the GPU you’re using right now". So it’s a slightly indirect way of getting the object containing the details of the GPU in question.
Next, we call torch.distributed.get_default_backend_for_device. A backend in this context is the communication library that torch.distributed uses to pass data between processes – for Nvidia GPUs that’s NCCL, which is built on CUDA, and for CPUs it’s usually Gloo.
Once that’s done, we call torch.distributed.init_process_group, passing in the backend that we’re using. We’re saying "initialise the internal data structures for torch.distributed so that they’re all set up properly to work with the backend we specified".
After that, we can do stuff like getting the global rank with dist.get_rank and so on, because torch.distributed has been properly initialized. Presumably at this point we’re talking to any other machines in a multi-machine cluster, so we can find out what our world size is and that kind of thing.
That extra line at the end, to get the device_id:
device_id = rank % torch.accelerator.device_count()
...actually looks erroneous to me. All of our code is assuming one process per GPU. So I think we can just use the LOCAL_RANK there as well.
Let’s rewrite it like this (with some useful comments):
# Which of the one-per-GPU processes are we?
rank = int(os.environ["LOCAL_RANK"])
# Set ourselves up to use the GPU with ID ``rank``
torch.accelerator.set_device_index(rank)
# Get the accelerator object associated with that GPU,
# and the associated backend object (eg. ``nccl`` for CUDA):
acc = torch.accelerator.current_accelerator()
backend = torch.distributed.get_default_backend_for_device(acc)
# Initialize torch.distributed; set the device ID explicitly
# to avoid warnings in ``dist.barrier``
dist.init_process_group(backend, device_id=rank)
print(f"On rank {rank}.")
model = GPTModel(model_conf).to(rank)
That seems to work well! Here’s the code. However, I ran it past ChatGPT (largely to validate my understanding of what was going on), and it highlighted something slightly misleading about it.
Right now, we’re training on a single node, with one process per GPU. But again, one of the neat-o things about this DDP stuff is that it should be able to scale to multiple nodes.
Now, remember that LOCAL_RANK is just the rank of the current process on the specific node that it’s running on – hence the name. If we had two machines, each with 8 GPUs, then there would be a process with local rank zero on each of them.
The "real" rank – that is, across all machines – is the one that you can get from dist.get_rank once it has been initialised. One of the things it does during that initialisation is to talk to all of the other nodes and work that kind of thing out – which of the local rank zero processes across all of the machines is the global rank zero process.
So we need to use the local rank when working out which GPU we should be running on and so on, but we should not treat it as a global rank.
That’s actually quite fine in this case, as we’re calling dist.get_rank inside the training loop when we actually need to use the global one (when indexing into the dataset, or when deciding if we’re the process that should be doing evals and checkpoints). The only place where we might be confusing matters is in that print, which is not important anyway, as the training loop also prints out its rank.
So, let’s tweak it a little more for clarity:
# Which of the one-per-GPU processes are we on this machine?
local_rank = int(os.environ["LOCAL_RANK"])
# Set ourselves up to use the GPU with the ID that matches our local rank
torch.accelerator.set_device_index(local_rank)
# Get the accelerator object associated with that GPU,
# and the associated backend object (eg. ``nccl`` for CUDA):
acc = torch.accelerator.current_accelerator()
backend = torch.distributed.get_default_backend_for_device(acc)
# Initialize torch.distributed; set the device ID explicitly
# to avoid warnings in ``dist.barrier``
dist.init_process_group(backend, device_id=local_rank)
model = GPTModel(model_conf).to(local_rank)
That seems to work well! Here’s the code.
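If you want to exercise this setup code on a machine without a GPU, one option is to fall back to the Gloo backend on the CPU. That’s my substitution, not what the post’s script does, and I’ve used the torch.cuda calls rather than torch.accelerator so that the fallback is explicit. Here’s a minimal sketch that keeps the local rank (used to pick a device) separate from the global rank (which comes from torch.distributed after initialisation). It reads the LOCAL_RANK, RANK and WORLD_SIZE variables that torchrun sets; the setup_distributed name is hypothetical:

```python
import os

import torch
import torch.distributed as dist


def setup_distributed(init_method: str = "env://"):
    """Hypothetical helper: initialise torch.distributed for this process.

    Returns (local_rank, global_rank, world_size, device).
    """
    local_rank = int(os.environ["LOCAL_RANK"])  # which process on *this* machine
    rank = int(os.environ["RANK"])              # unique across all machines
    world_size = int(os.environ["WORLD_SIZE"])  # total number of processes

    if torch.cuda.is_available():
        # One process per GPU: the local rank picks the GPU on this machine.
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
        backend = "nccl"
    else:
        # Substitution for CPU-only testing: Gloo works without any GPU.
        device = torch.device("cpu")
        backend = "gloo"

    dist.init_process_group(
        backend, init_method=init_method, rank=rank, world_size=world_size
    )
    # After initialisation, the global rank is whatever torch.distributed reports.
    return local_rank, dist.get_rank(), dist.get_world_size(), device
```

Under torchrun, the default env:// init method also picks up the MASTER_ADDR and MASTER_PORT variables that torchrun sets, so the function can be called with no arguments. For a quick single-process check outside torchrun, you can set the three rank variables by hand and pass a file:// init method instead.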
Time to run it past ChatGPT to see if I’ve made any dumb errors. Turns out that (unsurprisingly) I have...
Checkpointing, revisited
Let’s go back to our code that decides whether or not it’s an iteration where we need to do a validation run and a checkpoint:
is_eval_iter = (
    (ix % validation_interval == 0)
    or (ix == len(train_ds) - 1)
)
The problem is that our index ix is different in the different processes! Remember, we have this in order to pick out the correct training items:
for ix in tqdm(range(train_ds_offset + rank, len(train_ds), world_size)):
So let’s think about it; in the first run through the loop, with 8 GPUs, we would have:
- ix = 0 for the process with rank 0
- ix = 1 for the process with rank 1
- ...
- ix = 7 for the process with rank 7
In the next run through the loop, we’d have:
- ix = 8 for the process with rank 0
- ix = 9 for the process with rank 1
- ...
- ix = 15 for the process with rank 7
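Putting some numbers on that: here’s a small plain-Python sketch, with 8 ranks and the validation interval of 2000 from earlier, that replays the strided loop and records which ranks would decide that a given pass through the loop is an eval iteration. For simplicity it ignores the "last dataset item" part of the condition:

```python
world_size = 8
validation_interval = 2000


def is_eval_iter(ix: int) -> bool:
    # The ix-based check, minus the "last dataset item" clause.
    return ix % validation_interval == 0


results = {}
for loop_pass in (0, 1, 250):
    # With the strided loop, rank r sees ix = loop_pass * world_size + r.
    ixs = [loop_pass * world_size + rank for rank in range(world_size)]
    results[loop_pass] = [rank for rank, ix in enumerate(ixs) if is_eval_iter(ix)]
    print(f"pass {loop_pass}: ix {ixs[0]}..{ixs[-1]}, eval on ranks {results[loop_pass]}")
# pass 0: ix 0..7, eval on ranks [0]
# pass 1: ix 8..15, eval on ranks []
# pass 250: ix 2000..2007, eval on ranks [0]
```

Because 2000 is a multiple of 8, only rank 0 ever gets True here, so with this particular configuration it would be the only process that ever reaches the eval code.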
So is_eval_iter will give different results for each process. That might not sound like the end of the world – ix % validation_interval will only be zero for one of them, so long as validation_interval is larger than the number of GPUs – but remember that our validation code looks like this:
if is_eval_iter:
dist.barrier()
if rank == 0:
# do the validation and checkpointing
dist.barrier()
Now, if different processes have different values for is_eval_iter, then dist.barrier() will only be called in the one(s) for which it is True. But dist.barrier() means "wait until all processes have reached this barrier". So the processes that call it will block until the others get there. At best everything will get out of sync, and at worst the whole train will deadlock.
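You can check this without any GPUs by simulating the old loop in plain Python. This sketch uses a hypothetical function name and made-up sizes. It computes is_eval_iter for every rank on each pass through the loop. On any pass where the flags disagree, some ranks would call dist.barrier() and the others would carry on training.

```python
from typing import List


def old_eval_flags(loop_iteration: int, world_size: int,
                   validation_interval: int, train_ds_len: int,
                   train_ds_offset: int = 0) -> List[bool]:
    """is_eval_iter for every rank on one pass through the *old* loop.

    The old loop was range(train_ds_offset + rank, len(train_ds), world_size),
    so on pass i, rank r sees ix = train_ds_offset + r + i * world_size.
    """
    flags = []
    for rank in range(world_size):
        ix = train_ds_offset + rank + loop_iteration * world_size
        flags.append(ix % validation_interval == 0 or ix == train_ds_len - 1)
    return flags


# Hypothetical numbers: 8 GPUs, validating every 1,000 items.
disagreements = [
    i for i in range(500)
    if len(set(old_eval_flags(i, 8, 1_000, 1_000_000))) > 1
]
print(disagreements)  # [0, 125, 250, 375]
print(old_eval_flags(0, 8, 1_000, 1_000_000))  # only rank 0 is True
```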
I think that the problem here is that I’m conflating two things: the index of the global step – that is, one iteration across all GPUs – and the dataset element that we want to use. In the original one-GPU case that made sense: iteration 0 was on dataset element 0, iteration 1 was on element 1, and so on. But now the offset into the dataset and the global step are quite different things.
This is quite deeply embedded in the code, but we can fix it!
Let’s start off by changing our checkpoint code, just to rename things. It keeps track of a variable called train_ds_offset, our offset into the training dataset, and uses that both to index into the dataset, and to work out how far through the train we are. The latter is a much better thing to store in a checkpoint, so instead of saving train_ds_offset, we’ll store (and restore) global_step. Basically, just a rename so that the variables and stored JSON match the new reality. Here’s the updated code.
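The stored value is just a number in a JSON file. Here's a minimal sketch of the save/restore round trip, including the "plus one" when resuming that comes up below. It assumes a hypothetical meta.json file inside each checkpoint directory, and the function names are illustrative too. A real checkpoint would also hold the model and optimizer state.

```python
import json
import tempfile
from pathlib import Path
from typing import Optional


def save_progress(checkpoint_dir: Path, global_step: int) -> None:
    """Record how far through the train we are (hypothetical meta.json layout)."""
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    meta = {"global_step": global_step}
    (checkpoint_dir / "meta.json").write_text(json.dumps(meta))


def load_start_step(checkpoint_dir: Optional[Path]) -> int:
    """Where to start: 0 for a fresh train, else the checkpointed step plus one."""
    if checkpoint_dir is None:
        return 0
    meta = json.loads((checkpoint_dir / "meta.json").read_text())
    return meta["global_step"] + 1


with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "step-500"
    save_progress(ckpt, 500)
    print(load_start_step(ckpt))  # 501
print(load_start_step(None))  # 0
```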
Now we need to make a number of minor changes to the training loop just to match that rename of the value that we’re checkpointing (eg. for the code to generate the training chart), but the most important change is to our loop. Instead of iterating over our dataset with a step and an offset so that we can index into it, we first work out how many global steps there will be:
total_global_steps = len(train_ds) // world_size
...then we iterate from our initial global step – zero if we’re starting a fresh train, or whatever global step we were on in a loaded checkpoint plus one if we’re doing a continued train from a checkpoint – up to the total_global_steps:
for global_step in tqdm(range(start_global_step, total_global_steps)):
That means that we need to use the global step, the world size, and our current rank to work out which dataset item we should be training on for this process at this global step. Let’s say that we have eight processes; on the 0th global step, we should have rank 0 training on dataset item 0, rank 1 on item 1, and so on. On the next global step, rank 0 should train on item 8, rank 1 on 9, and so on. So:
inputs, targets = train_ds[global_step * world_size + rank]
That’s actually much more elegant than the earlier code, and seems to work fine. Here it is.
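Here's a quick sanity check of the new indexing in plain Python. It assumes that is_eval_iter is now computed from global_step rather than from the dataset index. The helper names and sizes are hypothetical. Because global_step is identical in every process, all ranks agree on when to evaluate. Each dataset item is used exactly once.

```python
def train_index(global_step: int, world_size: int, rank: int) -> int:
    # Same arithmetic as train_ds[global_step * world_size + rank] above.
    return global_step * world_size + rank


def is_eval_step(global_step: int, total_global_steps: int,
                 validation_interval: int) -> bool:
    # Assumed new form of is_eval_iter: it depends only on global_step.
    return (global_step % validation_interval == 0
            or global_step == total_global_steps - 1)


world_size = 8
dataset_len = 1_003  # hypothetical; deliberately not a multiple of world_size
total_global_steps = dataset_len // world_size  # 125

seen = [train_index(step, world_size, rank)
        for step in range(total_global_steps)
        for rank in range(world_size)]

# Every item from 0 to 999 is used exactly once; items 1000-1002 are dropped.
assert sorted(seen) == list(range(total_global_steps * world_size))
print(total_global_steps, max(seen))  # 125 999

eval_steps = [s for s in range(total_global_steps)
              if is_eval_step(s, total_global_steps, 50)]
print(eval_steps)  # [0, 50, 100, 124]
```

The integer division has one side-effect. Up to world_size - 1 items at the end of the dataset never get trained on. With a dataset this large, that hardly matters.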
Phew, glad to have caught that before I started spending money on machines – it would have been confusing if everything locked up. Thanks, ChatGPT!
Slicing the validation dataset
Another issue that ChatGPT raised was about validation. We don’t want to validate across the whole of the validation dataset; instead, we use the number of batches specified in train.json. I have this code:
for val_inputs, val_targets in tqdm(val_ds[:validation_batches]):
This looked like a nice, quick way to get the first validation_batches elements of the validation dataset. But ChatGPT told me it would raise an exception. It didn’t, though – why?
The problem is that I had validation_batches set to 2 in my training config for testing. Stepping through what that slice does, when we run val_ds[:validation_batches]:
Python calls the __getitem__ on the dataset, passing in a slice object as ix, so this code is called with it:
def __getitem__(self, ix):
return (self.xs[ix], self.ys[ix])
Now, because that code doesn’t do anything clever with slices, they’re passed straight down to the tensors that make up self.xs and self.ys. So it’s actually equivalent to this:
return self.xs[:validation_batches], self.ys[:validation_batches]
Or, to rewrite the whole loop (omitting the tqdm for clarity):
for val_inputs, val_targets in (self.xs[:validation_batches], self.ys[:validation_batches]):
...
So, the first time through the loop, we try to bind our loop variables like this:
val_inputs, val_targets = self.xs[:validation_batches]
That is clearly wrong! It’s equivalent to this:
val_inputs = self.xs[:validation_batches][0]
val_targets = self.xs[:validation_batches][1]
...except that Python will blow up if self.xs[:validation_batches] doesn’t have exactly two elements. That's the normal "ValueError: too many values to unpack" (or "not enough values to unpack" if there are fewer).
But if validation_batches is set to 2, which it happened to be in my case, then it will silently do the wrong thing. The first iteration of our eval loop will get the first X from the validation set as val_inputs and the second X as val_targets, and the second iteration will do the same with the first two Ys.
Nasty! AI code review certainly helped me dodge a bullet on that one.
Let’s fix it; it’s not a big change. We can just do this:
for val_ix in tqdm(range(validation_batches)):
val_inputs, val_targets = val_ds[val_ix]
...and that works! So here’s the code now.
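The slicing trap is easy to reproduce without PyTorch. This sketch uses a hypothetical ToyDataset whose __getitem__ forwards its index to two lists, in the same way that the real one forwards it to two tensors. It shows the buggy slice, the exception raised for other slice lengths, and the fixed loop.

```python
class ToyDataset:
    """Stand-in for the real dataset: __getitem__ forwards ix to xs and ys."""

    def __init__(self, xs, ys):
        self.xs = xs
        self.ys = ys

    def __getitem__(self, ix):
        return (self.xs[ix], self.ys[ix])

    def __len__(self):
        return len(self.xs)


val_ds = ToyDataset(xs=["x0", "x1", "x2", "x3"], ys=["y0", "y1", "y2", "y3"])

# Buggy: val_ds[:2] is the tuple (xs[:2], ys[:2]), so the loop iterates over
# two lists and unpacks each one into (inputs, targets).
buggy = [(inputs, targets) for inputs, targets in val_ds[:2]]
print(buggy)  # [('x0', 'x1'), ('y0', 'y1')]

# Any other slice length fails loudly instead.
slice_error = None
try:
    for inputs, targets in val_ds[:3]:
        pass
except ValueError as e:
    slice_error = e
print(repr(slice_error))

# Fixed: index one item at a time.
fixed = [val_ds[ix] for ix in range(2)]
print(fixed)  # [('x0', 'y0'), ('x1', 'y1')]
```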
Back to the datasets
So, I think we have one final issue, which is the training and validation datasets. In our single-GPU train, we worked out ahead of time how much of FineWeb (or FineWeb-Edu) to train on – the Chinchilla-optimal number – and generated a dataset containing the smallest round number of 6-sequence, 1024-token batches that exceeded our target. We also worked out exactly how large (in terms of batches) our validation dataset needed to be so that each validation run would take five minutes.
There was one big issue with that system; when I decided to do an "extended" train on more of the FineWeb-Edu dataset, in order to see whether I could get the loss down further, I had to do some nasty hackery in order to generate a new one. So it would be nice to not have that problem this time around.
On top of that, we’re likely to be tweaking the batch size quite a lot in this experiment while we find out what fits on the cloud GPUs, and also varying how much validation we do. And we have the world size to worry about too.
I think that the best way to give us the flexibility we need will be to pre-convert the complete FineWeb and FineWeb-Edu datasets into the format we need. Each sequence in the dataset is converted to GPT-2 tokens, and then those sequences are concatenated together, with the <|endoftext|> token (ID 50256) separating them.
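A single flat stream has one advantage: batch size, sequence length and the number of batches all become decisions made at training time rather than when the dataset is built. Here's a minimal sketch of carving (inputs, targets) pairs out of such a stream. It assumes non-overlapping sequences with the targets shifted one token to the right. The function names are hypothetical, and it's written against plain Python sequences for clarity.

```python
from typing import List, Sequence, Tuple


def num_batches(n_tokens: int, batch_size: int, seq_len: int) -> int:
    """How many full batches a stream of n_tokens tokens can provide."""
    # Targets are shifted one token right of the inputs, so the final input
    # token needs one more token after it; sequences otherwise don't overlap.
    return (n_tokens - 1) // (batch_size * seq_len)


def get_batch(tokens: Sequence[int], batch_ix: int, batch_size: int,
              seq_len: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Return (inputs, targets) for batch number batch_ix."""
    start = batch_ix * batch_size * seq_len
    inputs, targets = [], []
    for s in range(batch_size):
        offset = start + s * seq_len
        inputs.append(list(tokens[offset:offset + seq_len]))
        targets.append(list(tokens[offset + 1:offset + seq_len + 1]))
    return inputs, targets


tokens = list(range(10))  # stand-in for a long uint16 token stream
print(num_batches(len(tokens), batch_size=2, seq_len=2))  # 2
print(get_batch(tokens, 1, batch_size=2, seq_len=2))
# ([[4, 5], [6, 7]], [[5, 6], [7, 8]])
print(num_batches(len(tokens), batch_size=1, seq_len=2))  # 4
```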
It would be good to properly nail down the validation dataset at the same time. So we can have a script that loads up the original dataset as downloaded from Hugging Face, splits it into 99% train, 1% validation, does the conversion, and then saves them as safetensors files.
If we use uint16 for those (which is just large enough for our 50,257-token vocab), we can fit the ~10B tokens in each dataset’s train split into 20 GiB of disk. Not too bad.
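The storage claim is easy to check with a little arithmetic. This sketch assumes exactly 10B tokens (the real splits are only approximately that size) and two bytes per token.

```python
EOT_ID = 50256          # GPT-2's <|endoftext|>, the highest ID in its 50,257-token vocab
UINT16_MAX = 2**16 - 1  # 65,535

assert EOT_ID <= UINT16_MAX  # every GPT-2 token ID fits in 16 bits

n_tokens = 10_000_000_000    # "~10B tokens" per train split (approximate)
bytes_needed = n_tokens * 2  # two bytes per uint16 token
print(f"{bytes_needed / 1e9:.1f} GB = {bytes_needed / 2**30:.1f} GiB")  # 20.0 GB = 18.6 GiB
```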
But there will still be the issue of getting them onto our cloud machines. Let’s generate the data, and then work out how to handle that.
I initially tried the code I used last time, adapted to run through the entire dataset. It does the 99%/1% train/validation split, and then for each of those generates a single massive tensor of tokens like this:
- Zoom through the records in the dataset in batches of 1,000.
- For each batch:
  - Tokenise the batch, so we get a list of lists of tokens.
  - Convert that list of lists into a single list, with <|endoftext|> tokens separating each item.
  - Convert that list into a PyTorch uint16 tensor.
  - Add the tensor to a results list.
- After that’s all done, use torch.cat to convert the results list into a single tensor, and then save that with safetensors.
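Here's a minimal sketch of that pipeline. It uses NumPy in place of PyTorch (np.concatenate standing in for torch.cat) and takes the batch tokeniser as a parameter. With tiktoken, encode_batch could be the gpt2 encoding's encode_ordinary_batch. The function names and the "tokens" key in the safetensors file are hypothetical.

```python
from typing import Callable, Iterable, Iterator, List

import numpy as np

EOT_ID = 50256  # GPT-2's <|endoftext|> token ID


def batched(items: Iterable[str], size: int) -> Iterator[List[str]]:
    """Yield successive lists of up to `size` items."""
    batch: List[str] = []
    for item in items:
        batch.append(item)
        if len(batch) == size:
            yield batch
            batch = []
    if batch:
        yield batch


def build_token_array(
    texts: Iterable[str],
    encode_batch: Callable[[List[str]], List[List[int]]],
    batch_size: int = 1_000,
) -> np.ndarray:
    """Tokenise texts in batches and concatenate them into one uint16 array."""
    results: List[np.ndarray] = []
    for batch in batched(texts, batch_size):
        flat: List[int] = []
        for i, tokens in enumerate(encode_batch(batch)):
            # The separator goes *between* documents, including across batches.
            if results or i > 0:
                flat.append(EOT_ID)
            flat.extend(tokens)
        results.append(np.asarray(flat, dtype=np.uint16))
    if not results:
        return np.empty(0, dtype=np.uint16)
    return np.concatenate(results)


def save_tokens(tokens: np.ndarray, path: str) -> None:
    # Lazy import so the rest of the sketch runs without safetensors installed.
    from safetensors.numpy import save_file
    save_file({"tokens": tokens}, path)  # "tokens" is a hypothetical key name


# Stand-in tokeniser for the demo: one "token" per character.
def fake_encode(batch: List[str]) -> List[List[int]]:
    return [[ord(c) for c in text] for text in batch]


arr = build_token_array(["ab", "c", "de"], fake_encode, batch_size=2)
print(arr.tolist(), arr.dtype)  # [97, 98, 50256, 99, 50256, 100, 101] uint16
```

This version puts the separator between documents. Appending one after every document would be an equally reasonable choice. While np.concatenate (or torch.cat) runs, the per-batch pieces and the combined array are both in memory, so peak usage is roughly twice the size of the final array.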
It almost worked! To my surprise, it got all the way to the end, and only blew up with an out-of-memory error when it was trying to save the result – and it did that completely silently, so I thought it had worked right up until I tried to check the file on disk to see how large it was, and it wasn’t there.
The obvious tweak: set the results list to None just after the torch.cat, to free up the memory it’s using. Given that it was the save that triggered the OOM, you’d think that that would be enough – but it turned out not to be so.
Rather than mess around with this for much longer, I just decided to add on 128 GiB of swap to my machine temporarily:
giles@perry:~/Dev/ddp-base-model-from-scratch (main)$ sudo dd if=/dev/zero of=./swap bs=1G count=128
[sudo] password for giles:
128+0 records in
128+0 records out
137438953472 bytes (137 GB, 128 GiB) copied, 63.1124 s, 2.2 GB/s
giles@perry:~/Dev/ddp-base-model-from-scratch (main)$ sudo chmod 0600 ./swap
giles@perry:~/Dev/ddp-base-model-from-scratch (main)$ sudo mkswap ./swap
Setting up swapspace version 1, size = 128 GiB (137438949376 bytes)
no label, UUID=693d72a1-871d-4ab8-b0c8-b383b435ca8f
giles@perry:~/Dev/ddp-base-model-from-scratch (main)$ sudo swapon ./swap
...and that was enough to make it run. So I’ve now generated pre-tokenised, pre-concatenated train and validation sets for both FineWeb and FineWeb-Edu:
giles@perry:~/Dev/ddp-base-model-from-scratch (main)$ ls -lrth fineweb-prepared/
total 20G
-rw-r--r-- 1 giles giles 196M Dec 4 21:02 validation.safetensors
-rw-r--r-- 1 giles giles 20G Dec 4 21:20 train.safetensors
giles@perry:~/Dev/ddp-base-model-from-scratch (main)$ ls -lrth fineweb-edu-prepared/
total 19G
-rw-r--r-- 1 giles giles 192M Dec 4 22:43 validation.safetensors
-rw-r--r-- 1 giles giles 19G Dec 4 22:59 train.safetensors
Now, thinking about how to get the data up to the Lambda Labs machines. I have normal 1 Gbit/s residential broadband, so in theory I could upload 20 GiB in about 200 seconds. But that assumes there’s no network congestion, so I would expect it to take longer. The LL machines are quite expensive, and I don’t want to waste money keeping them up while I’m just uploading data.
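For reference, here's the back-of-the-envelope upload estimate. It assumes the full nominal 1 Gbit/s is available for the upload. The efficiency figures in the loop are purely illustrative.

```python
file_bytes = 20 * 2**30   # one ~20 GiB train file
line_rate_bps = 1e9       # nominal 1 Gbit/s (an assumption about the upstream rate)

seconds = file_bytes * 8 / line_rate_bps
print(f"best case: {seconds:.0f} s")  # best case: 172 s

# Real uploads rarely hit the nominal rate; a few illustrative efficiencies:
for efficiency in (0.5, 0.25):
    print(f"at {efficiency:.0%} of line rate: {seconds / efficiency / 60:.1f} min")
```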
There are two possibilities here:
- I can upload the datasets to Hugging Face; their network connection will be better than mine, so I can just pay the price in time of uploading everything from home once, and then I can download them faster from HF to LL. That also has the benefit of meaning that after this experiment I can safely delete the local files, but then download them again if I need them. And if anyone else wants to repro this experiment, the data will be easily available to them.
- Lambda Labs have persistent filesystems that you can use. They cost $0.20/GB/month, so that would be about $8/month for both datasets (roughly 40 GiB in total). So I could upload the data to a cheap instance with a persistent filesystem mounted, shut down that instance but keep the filesystem, and then mount it on each machine I use to run tests.
I think the best approach is the first one, with the second as a fallback. The HF dataset will still take time to download to LL, even over the faster network connection. That might not be a problem – but if it is, I can download it once to a persistent filesystem from a cheap instance, and then mount that on the expensive machines. Essentially I’d be using the persistent filesystem as a "cache", and would still get the benefits of having easily shareable datasets on Hugging Face.
So, that decided, let’s find out how we can upload a whacking great 20 GiB safetensors file as a dataset on Hugging Face.
Putting the datasets on Hugging Face
It turns out that resources like datasets on HF are just Git repositories that use the LFS (Large File Storage) extension to handle, well, large files. Conveniently, given that I’m using uv to manage my project, I can run their hf CLI tool through uvx with minimal effort, so:
giles@perry:~/Dev/ddp-base-model-from-scratch (main)$ uvx hf auth login
_| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_|
_| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
_|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_|
_| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
_| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_|
To log in, ``huggingface_hub`` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible):
Add token as git credential? [y/N]: n
Token is valid (permission: write).
The token ``[REDACTED]`` has been saved to /home/giles/.cache/huggingface/stored_tokens
Your token has been saved to /home/giles/.cache/huggingface/token
Login successful.
The current active token is: ``[REDACTED]``
giles@perry:~/Dev/ddp-base-model-from-scratch (main