Training GPT-2 Small on 12GB of data sounded simple. It was not simple. Here are 10 ways the universe humbled me, and how I eventually fixed each one.
Challenge 1: The Config That Lied
The Error: a tensor size mismatch on the positional embeddings the moment I tried to load the checkpoint.
What I Did: Tried to resume training with SEQ_LEN = 1024.
What I Should Have Done: Checked the checkpoint. It was trained with SEQ_LEN = 512.
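The check takes about four lines. A sketch, assuming the checkpoint is a plain state dict and the positional-embedding key contains wpe or pos (both assumptions about my setup):

```python
import torch

# Peek at the saved positional-embedding shape before assuming SEQ_LEN.
state = torch.load("pytorch_model.bin", map_location="cpu")
for name, tensor in state.items():
    if "wpe" in name or "pos" in name:
        print(name, tuple(tensor.shape))   # e.g. (512, 768) -> trained at SEQ_LEN = 512
```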
The Fix: Change one number. Feel dumb for 10 minutes.
Lesson: Read your own configs before complaining about PyTorch bugs.
Challenge 2: The dtype That Shall Not Be Named
The Error: a dtype complaint from CrossEntropyLoss.
Translation: CrossEntropyLoss wants int64. I gave it int32. It threw a tantrum.
The Fix:
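One .long() on the targets, roughly like this (logits and targets stand in for whatever your training loop calls them):

```python
import torch.nn.functional as F

# Cross-entropy wants int64 (torch.long) class indices.
loss = F.cross_entropy(
    logits.view(-1, logits.size(-1)),   # (batch * seq_len, vocab_size)
    targets.view(-1).long(),            # int32 -> int64, tantrum averted
)
```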
Lesson: PyTorch loss functions are dtype snobs. Just .long() everything classification-related.
Challenge 3: The Phantom Checkpoint
The Error: FileNotFoundError, pointing at a checkpoint.pth that was never there.
What Happened: My code looked for checkpoint.pth. My training script saved pytorch_model.bin. Nobody told me.
The Fix:
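In spirit (the directory layout here is mine):

```python
import os
import torch

ckpt_dir = "checkpoints"
print(os.listdir(ckpt_dir))   # look first: pytorch_model.bin, not checkpoint.pth

state = torch.load(os.path.join(ckpt_dir, "pytorch_model.bin"), map_location="cpu")
model.load_state_dict(state)  # model built elsewhere, with the matching config
```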
Lesson: Actually look at what files exist before writing code that loads them. Revolutionary, I know.
Challenge 4: The Glacial Training Run 🐢
The Symptom: 1.6 seconds per step on an A100. The A100. The $10,000 GPU. Running like a potato.
The Problem: Memory-mapped files + random access = death by I/O.
Every batch, mmap was doing random reads across an 11GB file. The disk was crying. The GPU was bored.
The Fix: Just… load it into RAM:
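The whole change, sketched; the file name and dtype are whatever your preprocessing wrote (see Challenge 10):

```python
import numpy as np

# Before: tokens = np.memmap("train.bin", dtype=np.uint16, mode="r")
#         -> every batch is a burst of random disk reads across ~11GB.
# After: one sequential read at startup; batching becomes pure RAM access.
tokens = np.fromfile("train.bin", dtype=np.uint16)
```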
Result: 0.8s/step. 2x faster. The GPU started doing GPU things.
Lesson: If it fits in RAM, stop being clever with mmap.
Challenge 5: Python Being Python
The Symptom: Still 0.8s/step. GPU utilization still sad.
The Culprit: This innocent-looking line:
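Reconstructed from memory (data, ix, and seq_len come from the batch-sampling code); the y line right after it was just as guilty:

```python
# One Python-level slice and tensor copy per sequence, per batch, every step.
x = torch.stack([data[i : i + seq_len] for i in ix])
y = torch.stack([data[i + 1 : i + 1 + seq_len] for i in ix])
```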
Ah yes, a Python for loop in the hot path. Peak performance.
The Fix: Vectorize like an adult:
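The vectorized version, assuming data is a 1-D token tensor already sitting in RAM (Challenge 4) and ix is a (batch_size,) tensor of start positions:

```python
import torch

offsets = torch.arange(seq_len)    # (seq_len,)
rows = ix[:, None] + offsets       # (batch_size, seq_len) index matrix
x = data[rows]                     # one gather, no Python loop
y = data[rows + 1]
```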
Result: 0.5s/step. 1.6x faster. Python loops are for preprocessing, not training.
Challenge 6: The GPU That Wasn’t Trying
The Symptom: 0.5s/step. GPU at 60% utilization. It’s just… vibing.
The Fixes:
Step 1 - torch.compile():
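It genuinely is one line (the first few steps run slower while kernels compile, then it pays for itself):

```python
model = torch.compile(model)   # TorchDynamo traces the model, Inductor emits fused kernels
```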
PyTorch 2.0’s magic spell. Fuses operations, generates custom CUDA kernels, makes everything faster.
Step 2 - Mixed Precision (AMP):
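The training step wrapped in bfloat16 autocast. The model(x, y) call returning (logits, loss) is my model's convention, not a PyTorch API, and bf16 means there's no GradScaler to babysit:

```python
import torch

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    logits, loss = model(x, y)         # forward pass runs in bf16 where safe
loss.backward()                        # backward outside the autocast context
optimizer.step()
optimizer.zero_grad(set_to_none=True)
```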
Bfloat16 goes brrr.
Result: 0.225s/step. 2.2x faster. GPU finally earning its keep.
Challenge 7: The OOM That Wouldn’t Quit
The Error: CUDA out of memory.
What I Did: Got greedy. batch_size=128, mode='max-autotune'. YOLO.
What Happened: The GPU politely informed me that 40GB isn’t infinite.
The Fix:
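Back off on both knobs. The batch size of 32 below is illustrative; the point is a smaller batch plus the default compile mode:

```python
batch_size = 32                 # down from 128
model = torch.compile(model)    # default mode; "max-autotune" spends extra VRAM on kernel search
```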
Lesson: max-autotune is cool but uses extra VRAM for kernel search. Start conservative.
Challenge 8: The Zombie Memory 🧟
The Horror: After an OOM crash, GPU memory shows 79GB in use. Nothing is running. The memory just… won’t leave.
What I Tried:
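The usual incantations, roughly (all real PyTorch calls, for all the good they did):

```python
import gc
import torch

del model, optimizer          # drop the Python references
gc.collect()                  # ask the garbage collector nicely
torch.cuda.empty_cache()      # ask the caching allocator nicely
```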
What Worked: None of that. Runtime > Restart Runtime. Start over. Accept defeat.
Lesson: After OOM, the crashed process ghosts you with its memory. Restart is the only exorcism.
Challenge 9: The Tokenizer That Changed Its Mind
The Problem: I trained with a character tokenizer. The new data uses BPE. The code hardcoded the character tokenizer. Surprise!
The Fix:
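Stop hardcoding and read it from the dataset's metadata. A sketch; the meta.json file, its keys, and both tokenizer classes are hypothetical stand-ins for whatever your pipeline uses:

```python
import json

with open("data/meta.json") as f:
    meta = json.load(f)

# Pick the tokenizer the data was actually encoded with.
if meta["tokenizer"] == "bpe":
    tokenizer = BPETokenizer.load(meta["tokenizer_path"])   # hypothetical loader
else:
    tokenizer = CharTokenizer(meta["chars"])                # hypothetical char-level class

vocab_size = meta["vocab_size"]   # size the embedding from this, not from a constant
```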
Lesson: Hardcoding is technical debt with a high interest rate.
Challenge 10: The dtype Strikes Back
The Problem: The character-level data was stored as int16 (tiny vocab, far below the int16 ceiling). The new BPE data has a 32K vocab, but I wrote it as int32 to be safe. The loading code still assumed int16.
The Fix:
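Same metadata file, one more field. The key name and paths are again placeholders:

```python
import json
import numpy as np

with open("data/meta.json") as f:
    meta = json.load(f)

# Read tokens with the dtype they were actually written in (e.g. "int16" or "int32").
tokens = np.fromfile("data/train.bin", dtype=np.dtype(meta["token_dtype"]))
```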
Lesson: Store dtype in metadata. Read dtype from metadata. Trust no assumptions.
The Optimization Journey (1.6s → 0.225s)
| What I Did | Step Time | Speedup | Effort Level |
|---|---|---|---|
| Baseline (mmap) | 1.6s/step | 1x | N/A |
| RAM preload | 0.8s/step | 2x | 1 line of code |
| Vectorized batching | 0.5s/step | 3.2x | 20 minutes |
| + torch.compile | 0.35s/step | 4.6x | 1 line of code |
| + AMP | 0.225s/step | 7.1x | 10 lines of code |
7x speedup. ~100 hours saved. Not bad for an afternoon of debugging.
The Takeaways
On Optimization
- Profile first. Is it I/O? Compute? Python? Find out before changing random things.
- RAM beats mmap for random access. Memory-mapping is for sequential reads, not ML training.
- Vectorize or suffer. Python loops in hot paths are a crime.
- torch.compile + AMP = free performance. If you’re not using both in 2026, you’re volunteering for slow training.
On Colab Survival
- Checkpoints go to Drive. Colab will disconnect. It’s not a question of if.
- OOM = restart runtime. empty_cache() is a suggestion, not a command.
- Budget compute units. A100 burns 8 units/hour. Math accordingly.
On Debugging
- Check dtypes. Then check them again. Then add .long() anyway.
- Config mismatches are silent killers. SEQ_LEN, vocab_size, embed_dim: they all must match.
- Verify file names exist. Before writing the code that loads them. Wild concept.