Training GPT-2 Small on 12GB of data sounded simple. It was not simple. Here are 10 ways the universe humbled me, and how I eventually fixed each one.


Challenge 1: The Config That Lied

The Error:

RuntimeError: size mismatch for causal_mask: copying a param with shape
torch.Size([1024, 1024]) from checkpoint, the shape in current model is
torch.Size([512, 512])

What I Did: Tried to resume training with SEQ_LEN = 1024.

What I Should Have Done: Checked the checkpoint. It was trained with SEQ_LEN = 512.

The Fix: Change one number. Feel dumb for 10 minutes.

Lesson: Read your own configs before complaining about PyTorch bugs.
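If you want to skip those 10 minutes next time, inspect the checkpoint before touching the config. A minimal sketch, assuming the checkpoint is a plain state_dict saved with torch.save (the checkpoint.pth path is a placeholder):

import torch

state = torch.load('checkpoint.pth', map_location='cpu')  # placeholder path
print(state['causal_mask'].shape)  # torch.Size([512, 512]) -> the checkpoint expects SEQ_LEN = 512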


Challenge 2: The dtype That Shall Not Be Named

The Error:

NotImplementedError: Could not run 'aten::nll_loss_forward_reduce_cuda_kernel_2d_index'
with arguments from the 'CUDA' backend.
'Int' (dtype)

Translation: CrossEntropyLoss wants int64. I gave it int32. It threw a tantrum.

The Fix:

X = batch[:, :-1].long().to(device)  # .long() = int64
Y = batch[:, 1:].long().to(device)   # Don't forget the targets

Lesson: PyTorch loss functions are dtype snobs. Just .long() everything classification-related.
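A cheap guard worth dropping in right before the loss call (X and Y as above), so the failure is loud instead of cryptic:

assert X.dtype == torch.long and Y.dtype == torch.long, f"expected int64, got {X.dtype}/{Y.dtype}"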


Challenge 3: The Phantom Checkpoint

The Error:

FileNotFoundError: [Errno 2] No such file or directory: '.../checkpoint.pth'

What Happened: My code looked for checkpoint.pth. My training script saved pytorch_model.bin. Nobody told me.

The Fix:

model.load_state_dict(torch.load(f'{checkpoint_path}/pytorch_model.bin'))
optimizer.load_state_dict(torch.load(f'{checkpoint_path}/optimizer_state.bin'))

Lesson: Actually look at what files exist before writing code that loads them. Revolutionary, I know.
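The two-second sanity check that would have prevented this (checkpoint_path being whatever directory the training script saves into):

import os

print(sorted(os.listdir(checkpoint_path)))
# ['optimizer_state.bin', 'pytorch_model.bin']  <- no checkpoint.pth in sight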


Challenge 4: The Glacial Training Run 🐢

The Symptom: 1.6 seconds per step on an A100. The A100. The $10,000 GPU. Running like a potato.

The Problem: Memory-mapped files + random access = death by I/O.

Every batch, mmap was doing random reads across an 11GB file. The disk was crying. The GPU was bored.
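For contrast, here is roughly what the slow path looked like (a reconstruction, not the exact code): every slice below is a random read that can go all the way to disk.

import numpy as np

SEQ_LEN, BATCH_SIZE = 512, 64  # example values
mmap_data = np.memmap('/content/tokens_bpe.bin', dtype=np.int32, mode='r')
idx = np.random.randint(0, len(mmap_data) - SEQ_LEN - 1, size=BATCH_SIZE)
batch = np.stack([mmap_data[i:i + SEQ_LEN + 1] for i in idx])  # each slice = a seek the GPU waits for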

The Fix: Just… load it into RAM:

data = torch.from_numpy(np.fromfile('/content/tokens_bpe.bin', dtype=np.int32))  # all ~11GB of tokens, straight into RAM

Result: 0.8s/step. 2x faster. The GPU started doing GPU things.

Lesson: If it fits in RAM, stop being clever with mmap.


Challenge 5: Python Being Python

The Symptom: Still 0.8s/step. GPU utilization still sad.

The Culprit: This innocent-looking line:

batch = torch.stack([data[i:i+SEQ_LEN+1] for i in idx])

Ah yes, a Python for loop in the hot path. Peak performance.

The Fix: Vectorize like an adult:

offsets = torch.arange(SEQ_LEN + 1)                              # [0, 1, ..., SEQ_LEN]
idx = torch.randint(0, train_size - SEQ_LEN - 1, (BATCH_SIZE,))  # random start positions
indices = idx.unsqueeze(1) + offsets                             # (BATCH_SIZE, SEQ_LEN+1) via broadcasting
batch = data[indices]  # one fancy-index gather, zero Python loops

Result: 0.5s/step. 1.6x faster. Python loops are for preprocessing, not training.


Challenge 6: The GPU That Wasn’t Trying

The Symptom: 0.5s/step. GPU at 60% utilization. It’s just… vibing.

The Fixes:

Step 1 - torch.compile():

model = torch.compile(model)

PyTorch 2.0’s magic spell. Fuses operations, generates custom CUDA kernels, makes everything faster.

Step 2 - Mixed Precision (AMP):

scaler = torch.amp.GradScaler('cuda')   # scales the loss so float16 gradients don't underflow

with torch.amp.autocast('cuda'):        # forward pass runs in mixed precision
    logits = model(X)
    loss = criterion(logits.view(-1, vocab_size), Y.view(-1))

scaler.scale(loss).backward()           # backward on the scaled loss
scaler.step(optimizer)                  # unscales gradients, then optimizer.step()
scaler.update()                         # adjusts the scale factor for the next step

Half precision goes brrr. (autocast('cuda') defaults to float16; pass dtype=torch.bfloat16 if you want bf16, in which case the GradScaler is mostly optional.)

Result: 0.225s/step. 2.2x faster. GPU finally earning its keep.


Challenge 7: The OOM That Wouldn’t Quit

The Error:

torch.OutOfMemoryError: CUDA out of memory. Tried to allocate X GiB

What I Did: Got greedy. batch_size=128, mode='max-autotune'. YOLO.

What Happened: The GPU politely informed me that 40GB isn’t infinite.

The Fix:

BATCH_SIZE = 64  # Humility
model = torch.compile(model)  # Default mode, not "try everything and explode"

Lesson: max-autotune is cool but uses extra VRAM for kernel search. Start conservative.


Challenge 8: The Zombie Memory 🧟

The Horror: After an OOM crash, GPU memory shows 79GB in use. Nothing is running. The memory just… won’t leave.

What I Tried:

gc.collect()  # Please?
torch.cuda.empty_cache()  # Pretty please?
torch.cuda.reset_peak_memory_stats()  # I'm begging here

What Worked: None of that. Runtime > Restart Runtime. Start over. Accept defeat.

Lesson: After OOM, the crashed process ghosts you with its memory. Restart is the only exorcism.


Challenge 9: The Tokenizer That Changed Its Mind

The Problem: Trained with character tokenizer. New data uses BPE. Code hardcoded character tokenizer. Surprise!

The Fix:

tokenizer_type = metadata.get('tokenizer_type', 'character')
if tokenizer_type == 'bpe':
    from tokenizer_bpe import BPETokenizer
    tokenizer = BPETokenizer()
else:
    from tokenizer import CharacterTokenizer
    tokenizer = CharacterTokenizer()

Lesson: Hardcoding is technical debt with a high interest rate.


Challenge 10: The dtype Strikes Back

The Problem: Character tokenization used int16 (a tiny vocab fits with room to spare). BPE uses a 32K vocab, which sits right at int16's 32,767 ceiling, so I stored it as int32 to be safe. The loading code still assumed int16.

The Fix:

dtype_str = metadata.get('dtype', 'int16')
mmap_data = np.memmap(args.bin_file, dtype=dtype_str, mode='r')

Lesson: Store dtype in metadata. Read dtype from metadata. Trust no assumptions.
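What that looks like in practice, sketched out (the metadata file name and fields are stand-ins for whatever your tokenizer script writes):

import json

# At tokenization time: write a sidecar next to the .bin file
metadata = {'tokenizer_type': 'bpe', 'dtype': 'int32', 'vocab_size': 32000}
with open('tokens_bpe_metadata.json', 'w') as f:
    json.dump(metadata, f)

# At training time: read it back instead of assuming anything
with open('tokens_bpe_metadata.json') as f:
    metadata = json.load(f)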


The Optimization Journey (1.6s → 0.225s)

What I Did            | Speed       | Speedup | Effort Level
Baseline (mmap)       | 1.6s/step   | 1x      | N/A
RAM preload           | 0.8s/step   | 2x      | 1 line of code
Vectorized batching   | 0.5s/step   | 3.2x    | 20 minutes
+ torch.compile       | 0.35s/step  | 4.6x    | 1 line of code
+ AMP                 | 0.225s/step | 7.1x    | 10 lines of code

7x speedup. ~100 hours saved. Not bad for an afternoon of debugging.


The Takeaways

On Optimization

  • Profile first. Is it I/O? Compute? Python? Find out before changing random things (see the sketch after this list).
  • RAM beats mmap for random access. Memory-mapping is for sequential reads, not ML training.
  • Vectorize or suffer. Python loops in hot paths are a crime.
  • torch.compile + AMP = free performance. If you’re not using both in 2026, you’re volunteering for slow training.
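A minimal profiling sketch with plain torch.profiler (drop one full training step inside the loop) that would have pointed at I/O on day one:

import torch
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(5):
        # ... one full training step here: get batch -> forward -> backward -> optimizer.step()
        pass
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

If the top rows are data loading and indexing instead of matmuls, you have your answer.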

On Colab Survival

  • Checkpoints go to Drive. Colab will disconnect. It’s not a question of if. (Sketch after this list.)
  • OOM = restart runtime. empty_cache() is a suggestion, not a command.
  • Budget compute units. A100 burns 8 units/hour. Math accordingly.
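A sketch of the checkpoints-to-Drive habit (paths are examples; model and optimizer are whatever your training loop already has):

from google.colab import drive
import os
import torch

drive.mount('/content/drive')
ckpt_dir = '/content/drive/MyDrive/gpt2-checkpoints'  # example path
os.makedirs(ckpt_dir, exist_ok=True)

torch.save(model.state_dict(), f'{ckpt_dir}/pytorch_model.bin')
torch.save(optimizer.state_dict(), f'{ckpt_dir}/optimizer_state.bin')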

On Debugging

  • Check dtypes. Then check them again. Then add .long() anyway.
  • Config mismatches are silent killers. SEQ_LEN, vocab_size, embed_dim — they all must match.
  • Verify file names exist. Before writing the code that loads them. Wild concept.