KellerJordan / modded-nanogpt
NanoGPT (124M) in 3 minutes
This repository hosts the NanoGPT speedrun, in which we (collaboratively and competitively) search for the fastest algorithm that uses 8 NVIDIA H100 GPUs to train a language model to 3.28 cross-entropy loss on the FineWeb validation set.
The target (3.28 validation loss on FineWeb) follows Andrej Karpathy's GPT-2 replication in llm.c, which attains that loss after running for 45 minutes. The speedrun code also descends from llm.c's PyTorch trainer, which itself descends from NanoGPT, hence the name of the repo. Thanks to the efforts of many contributors, this repo now contains a training algorithm that attains the target performance in under 3 minutes.
This improvement in training speed has been brought about by the techniques listed in the record history below, as well as many systems optimizations.
Contributors list (growing with each new record): @bozavlado; @brendanh0gan; @fernbear.bsky.social; @Grad62304977; @jxbz; @kellerjordan0; @KoszarskyB; @leloykun; @YouJiacheng; @jadenj3o; @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad; @ryanyang0
To run the current record, run the following commands.
git clone https://github.com/KellerJordan/modded-nanogpt.git && cd modded-nanogpt
pip install -r requirements.txt
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --upgrade
# downloads only the first 800M training tokens to save time
python data/cached_fineweb10B.py 8
./run.sh
Note: torch.compile will add around 5 minutes of latency the first time you run the code.
For cases where CUDA or NCCL versions aren't compatible with your current system setup, Docker can be a helpful alternative. This approach standardizes the CUDA, NCCL, cuDNN, and Python versions, reducing dependency issues and simplifying setup. Note: an NVIDIA driver must already be installed on the system; beyond that, only the driver and Docker are needed.
git clone https://github.com/KellerJordan/modded-nanogpt.git && cd modded-nanogpt
sudo docker build -t modded-nanogpt .
sudo docker run -it --rm --gpus all -v $(pwd):/modded-nanogpt modded-nanogpt python data/cached_fineweb10B.py 8
sudo docker run -it --rm --gpus all -v $(pwd):/modded-nanogpt modded-nanogpt sh run.sh
To get an interactive Docker shell, you can use
sudo docker run -it --rm --gpus all -v $(pwd):/modded-nanogpt modded-nanogpt bash
The following is the historical progression of world speed records for this competitive task:
Train a neural network to ≤3.28 validation loss on FineWeb using 8x NVIDIA H100s.
Note: The 3.28 target was selected to match Andrej Karpathy's GPT-2 (small) reproduction.
# | Record time | Description | Date | Log | Contributors |
---|---|---|---|---|---|
1 | 45 minutes | llm.c baseline | 05/28/24 | log | @karpathy, llm.c contributors |
2 | 31.4 minutes | Tuned learning rate & rotary embeddings | 06/06/24 | log | @kellerjordan0 |
3 | 24.9 minutes | Introduced the Muon optimizer | 10/04/24 | none | @kellerjordan0, @jxbz |
4 | 22.3 minutes | Muon improvements | 10/11/24 | log | @kellerjordan0, @bozavlado |
5 | 15.2 minutes | Pad embeddings, ReLU², zero-init projections, QK-norm | 10/14/24 | log | @Grad62304977, @kellerjordan0 |
6 | 13.1 minutes | Distributed the overhead of Muon | 10/18/24 | log | @kellerjordan0 |
7 | 12.0 minutes | Upgraded to PyTorch 2.5.0 | 10/18/24 | log | @kellerjordan0 |
8 | 10.8 minutes | Untied embedding and head | 11/03/24 | log | @Grad62304977, @kellerjordan0 |
9 | 8.2 minutes | Value and embedding skip connections, momentum warmup, logit softcap | 11/06/24 | log | @Grad62304977, @kellerjordan0 |
10 | 7.8 minutes | Bfloat16 activations | 11/08/24 | log | @kellerjordan0 |
11 | 7.2 minutes | U-net pattern skip connections & double lr | 11/10/24 | log | @brendanh0gan |
12 | 5.03 minutes | 1024-ctx dense causal attention → 64K-ctx FlexAttention | 11/19/24 | log | @KoszarskyB |
13 | 4.66 minutes | Attention window warmup | 11/24/24 | log | @fernbear.bsky.social |
14 | 4.41 minutes | Value Embeddings | 12/04/24 | log | @KoszarskyB |
15 | 3.95 minutes | U-net pattern value embeddings, assorted code optimizations | 12/08/24 | log | @leloykun, @YouJiacheng |
16 | 3.80 minutes | Split value embeddings, block sliding window, separate block mask | 12/10/24 | log | @YouJiacheng |
17 | 3.57 minutes | Sparsify value embeddings, improve rotary embeddings, drop an attn layer | 12/17/24 | log | @YouJiacheng |
18 | 3.4 minutes | Lower logit softcap from 30 to 15 | 01/04/25 | log | @KoszarskyB |
19 | 3.142 minutes | FP8 head, offset logits, lr decay to 0.1 instead of 0.0 | 01/13/25 | log | @YouJiacheng |
20 | 2.992 minutes | Merged QKV weights, long-short attention, attention scale, lower Adam epsilon, batched Muon | 01/16/25 | log | @leloykun, @fernbear.bsky.social, @YouJiacheng, @brendanh0gan, @scottjmaddox, @Grad62304977 |
21 | 2.933 minutes | Reduced batch size | 01/26/25 | log | @leloykun |
21 | 2.997 minutes | 21st record with new timing | 02/01/25 | log | not a new record, just re-timing #21 with the updated rules |
21 | 3.014 minutes | 21st record with latest torch | 05/24/25 | log | not a new record, just re-timing #21 with latest torch |
22 | 2.990 minutes | Faster gradient all-reduce | 05/24/25 | log | @KonstantinWilleke, @alexrgilbert, @adricarda, @tuttyfrutyee, @vdlad; The Enigma project |
23 | 2.979 minutes | Overlap computation and gradient communication | 05/25/25 | log | @ryanyang0 |
24 | 2.966 minutes | Replace gradient all_reduce with reduce_scatter | 05/30/25 | log | @vagrawal |
25 | 2.896 minutes | Upgrade PyTorch to 2.9.0.dev20250713+cu126 | 07/13/25 | log | @kellerjordan0 |
26 | 2.863 minutes | Align training batch starts with EoS, increase cooldown frac to .45 | 07/13/25 | log | @ClassicLarry |
The only rules are that new records must:
- Not set any torch._inductor.config or torch.compile flags. (These can save a few seconds, but they can also make compilation take >30min. This rule was introduced after the 21st record.) Note: torch._inductor.config.coordinate_descent_tuning is allowed for the GPT-2 Medium track (a.k.a. the 2.92 track).
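For concreteness, the kind of setting this rule refers to is an Inductor/compile flag assignment such as the following (shown only as an illustration):

```python
import torch._inductor.config as inductor_config

# Example of a torch._inductor.config flag: setting it is banned on this track
# (after record 21), but explicitly allowed on the GPT-2 Medium track.
inductor_config.coordinate_descent_tuning = True
```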
Other than that, anything and everything is fair game!
The target metric is cross-entropy loss on the FineWeb val set. To speak mathematically, the goal of the speedrun is to obtain a probability model of language which assigns a probability of at least math.exp(-3.28 * 10485760) to the first 10,485,760 tokens of the FineWeb val set. Hence, e.g., we allow evaluation at any sequence length, so long as we still have a valid probability model of language.
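For illustration, the arithmetic behind that bound (the constant names here are just for this example):

```python
import math

# The 3.28 target, restated as a bound on total log-probability over the
# 10,485,760-token FineWeb validation prefix (illustrative arithmetic only).
NUM_VAL_TOKENS = 10_485_760
TARGET_MEAN_LOSS = 3.28  # mean per-token cross-entropy, in nats

# mean loss <= 3.28  <=>  sum of per-token log-probs >= -3.28 * 10,485,760
min_total_log_prob = -TARGET_MEAN_LOSS * NUM_VAL_TOKENS
print(min_total_log_prob)             # -34393292.8
print(math.exp(min_total_log_prob))   # underflows to 0.0, hence working in log space
```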
After the 21st record, we made two changes to the timing. First, there used to be an initial "grace period" of 10 untimed steps to allow kernel warmup. We replaced this with an explicit kernel-warmup section which is untimed and uses dummy data. This results in an extra runtime of 850ms from the 10 extra timed steps.
Second, we banned the use of torch._inductor.config.coordinate_descent_tuning. This saves ~25min of untimed pre-run compilation, but results in an extra runtime of ~3s.
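For intuition, here is a rough sketch of what such a timing scheme can look like. This is illustrative only, not the repo's actual harness: the function, the batch format, and the assumption that the model's forward pass returns a scalar loss are all invented for the example.

```python
import copy
import time
import torch

# Illustrative sketch of the post-record-21 timing protocol: kernel warmup runs
# untimed on dummy data, the model and optimizer are reset, and then every real
# training step is timed, including the first ones.
def timed_training_run(model, opt, real_batches, dummy_batches):
    init_model_state = copy.deepcopy(model.state_dict())
    init_opt_state = copy.deepcopy(opt.state_dict())

    # Untimed warmup: pays torch.compile / kernel-autotuning costs up front.
    for inputs, targets in dummy_batches:
        model(inputs, targets).backward()  # assumes forward() returns the loss
        opt.step()
        opt.zero_grad(set_to_none=True)

    # Discard the warmup's effect on the weights before the clock starts.
    model.load_state_dict(init_model_state)
    opt.load_state_dict(init_opt_state)
    torch.cuda.synchronize()

    t0 = time.perf_counter()  # every real step is timed
    for inputs, targets in real_batches:
        model(inputs, targets).backward()
        opt.step()
        opt.zero_grad(set_to_none=True)
    torch.cuda.synchronize()
    return time.perf_counter() - t0
```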
Thanks to the statistical testing of @vagrawal (holder of the 24th record), we have learned that records 23, 24, and in all likelihood 22 and 25, actually attain a mean loss of 3.281, which is slightly above the 3.28 target. Therefore, if we were to adhere completely to the speedrun rules, we would have to deny that these are valid records. However, we have decided to leave them in place as valid, for two reasons: (a) the extra loss is most likely my (@kellerjordan0) own fault rather than that of the records, and (b) it is most likely easily addressable.
Here's what happened: Records #22 to #25 each change only the systems/implementation of the speedrun. Therefore, the requirement to do statistical testing to confirm they hit the target was waived, since in theory they should have hit it automatically, by virtue of the fact that they didn't touch the ML (i.e., they didn't change the architecture, learning rate, etc.).
So if these records shouldn't have changed the ML, what explains the regression in val loss? Most likely, the regression was indeed not introduced by any of these records. Instead, it was probably caused by my own non-record in which I re-timed record #21 with the latest torch, because in that non-record I also changed the constants used to cast the lm_head to fp8. I thought this change would be a (small) strict improvement, but apparently that was not the case.
It is therefore probable that each of records #22-25 could easily be made fully valid simply by reverting the change I made to those constants, so they shall be upheld as valid records.
Going forward, record #26 fortunately brought the speedrun back into the green in terms of <3.28 loss, so (with high p-value) it should be in a good state now.
Notable runs:
Notable forks:
The target loss for this track is lowered from 3.28 to 2.92, as per Andrej Karpathy's 350M-parameter llm.c baseline. This baseline generates a model with performance similar to the original GPT-2 Medium, whereas the first track's baseline generates a model on par with GPT-2 Small. All other rules remain the same.
Note: torch._inductor.config.coordinate_descent_tuning is turned on after record 6 (*).
# | Record time | Description | Date | Log | Contributors |
---|---|---|---|---|---|
1 | 5.8 hours | llm.c baseline (350M parameters) | 05/28/24 | log | @karpathy, llm.c contributors |
2 | 29.3 minutes | Initial record based on scaling up the GPT-2 small track speedrun | 01/18/25 | log | @kellerjordan0 |
3 | 28.1 minutes | Added standard weight decay | 02/08/25 | log | @kellerjordan0 |
4 | 27.7 minutes | Tuned Muon Newton-Schulz coefficients | 02/14/25 | log | @leloykun |
5 | 27.2 minutes | Increased learning rate cooldown phase duration | 03/06/25 | log | @YouJiacheng |
6 | 25.95 minutes* | 2x MLP wd, qkv norm, all_reduce/opt.step() overlap, optimized skip pattern | 03/25/25 | log | @YouJiacheng |
7 | 25.29 minutes | Remove FP8 head; ISRU logits softcap; New sharded mixed precision Muon; merge weights | 04/16/25 | log | @YouJiacheng |
8 | 24.50 minutes | Cubic sliding window size schedule, 2× max window size (24.84 minutes; 24.5 min repro) | 04/22/25 | log | @jadenj3o |
A: The officially stated goal of NanoGPT speedrunning is as follows: gotta go fast. But for something a little more verbose involving an argument for good benchmarking, here's some kind of manifesto, adorned with a blessing from the master: https://x.com/karpathy/status/1846790537262571739
A: Because it is a competitive benchmark. In particular, if you attain a new speed record (using whatever method you want), there is an open invitation for you to post that record (on arXiv or X) and thereby vacuum up all the clout for yourself. I will even help you do it by reposting you as much as I can.
A: This is hard to refute, since "at scale" is an infinite category (what if the methods stop working only for >100T models?), making it impossible to fully prove. Also, I would agree that some of the methods used in the speedrun are unlikely to scale, particularly those which impose additional structure on the network, such as logit softcapping. But if the reader cares about 1.5B models, they might be convinced by this result:
Straightforwardly scaling up the speedrun (10/18/24 version) to 1.5B parameters yields a model with GPT-2 (1.5B)-level HellaSwag performance 2.5x more cheaply than @karpathy's baseline ($233 instead of $576):
Muon is defined as follows:
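In code, a rough, illustrative sketch of a single step (placeholder names and hyperparameter values; details such as Nesterov-style momentum and update scaling are omitted): Muon is SGD-momentum whose 2D update matrices are orthogonalized by NewtonSchulz5 before being applied.

```python
# Illustrative sketch of one Muon step for a single 2D weight matrix W with
# gradient G and momentum buffer buf (lr/momentum values are placeholders).
def muon_step(W, G, buf, lr=0.02, momentum=0.95):
    buf.mul_(momentum).add_(G)                     # momentum accumulation
    update = zeroth_power_via_newtonschulz5(buf)   # orthogonalize: ~ U @ V.T of buf
    W.data.add_(update, alpha=-lr)                 # SGD-style parameter update
```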
Here NewtonSchulz5 is the following Newton-Schulz iteration [2, 3], which approximately replaces G with U @ V.T, where U, S, V = G.svd().
import torch

@torch.compile
def zeroth_power_via_newtonschulz5(G, steps=5, eps=1e-7):
    # Approximates U @ V.T, where U, S, V = G.svd(), via a quintic Newton-Schulz iteration.
    assert len(G.shape) == 2
    a, b, c = (3.4445, -4.7750, 2.0315)  # coefficients of the quintic iteration
    X = G.bfloat16() / (G.norm() + eps)  # Frobenius normalization bounds the spectral norm by 1
    if G.size(0) > G.size(1):
        X = X.T  # work in the wide orientation so X @ X.T is the smaller Gram matrix
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X  # X <- a*X + b*(X X^T)X + c*(X X^T)^2 X
    if G.size(0) > G.size(1):
        X = X.T  # restore the original orientation
    return X.to(G.dtype)
For this training scenario, Muon has the following favorable properties:
Many of the choices made to generate this optimizer were obtained experimentally by our pursuit of CIFAR-10 speedrunning. In particular, we experimentally obtained the following practices:
Our use of a Newton-Schulz iteration for orthogonalization traces to Bernstein & Newhouse (2024), who suggested it as a way to compute Shampoo [5, 6] preconditioners, and theoretically explored Shampoo without preconditioner accumulation. In particular, Jeremy Bernstein @jxbz sent us the draft, which caused us to experiment with various Newton-Schulz iterations as the orthogonalization method for this optimizer. If we had used SVD instead of a Newton-Schulz iteration, this optimizer would have been too slow to be useful. Bernstein & Newhouse also pointed out that Shampoo without preconditioner accumulation is equivalent to steepest descent in the spectral norm, and therefore Shampoo can be thought of as a way to smooth out spectral steepest descent. The proposed optimizer can be thought of as a second way of smoothing spectral steepest descent, with a different set of memory and runtime tradeoffs compared to Shampoo.
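To make that equivalence concrete, here is a small numerical check (illustrative only, not part of the repo) that Shampoo's update without preconditioner accumulation, (G G^T)^(-1/4) G (G^T G)^(-1/4), coincides with the orthogonalized gradient U @ V.T:

```python
import torch

# Numerical sanity check of the identity discussed above.
torch.manual_seed(0)
G = torch.randn(256, 256, dtype=torch.float64)  # square, so both Gram matrices are invertible
U, S, Vh = torch.linalg.svd(G)

def inv_fourth_root(M):
    # M^{-1/4} for a symmetric positive-definite matrix M, via eigendecomposition
    evals, evecs = torch.linalg.eigh(M)
    return evecs @ torch.diag(evals.pow(-0.25)) @ evecs.T

shampoo_update = inv_fourth_root(G @ G.T) @ G @ inv_fourth_root(G.T @ G)
print(torch.allclose(shampoo_update, U @ Vh, atol=1e-6))  # True
```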
A: Change run.sh to have a different --nproc_per_node. This should not change the behavior of the training.
Citation:
@misc{modded_nanogpt_2024,
  author = {Keller Jordan and Jeremy Bernstein and Brendan Rappazzo and
            @fernbear.bsky.social and Boza Vlado and You Jiacheng and
            Franz Cesista and Braden Koszarsky and @Grad62304977},
  title = {modded-nanogpt: Speedrunning the NanoGPT baseline},
  year = {2024},
  url = {https://github.com/KellerJordan/modded-nanogpt}
}