
    Lesson 22 • Advanced

    Training Stability Techniques

    By the end of this lesson you'll be able to spot an unstable training run, clip gradients to a safe norm, warm up and schedule the learning rate, and make every run reproducible — so deep learning actually converges instead of blowing up.

    What You'll Learn in This Lesson

    • You'll be able to explain exploding vs vanishing gradients in one sentence
    • You'll clip a gradient vector to a maximum norm — by hand and with PyTorch
    • You'll add learning-rate warmup and a decay schedule to a training loop
    • You'll choose weight init and normalisation that keep signals stable
    • You'll detect NaN/Inf loss and loss spikes, and find the first bad step
    • You'll seed every RNG source to make a run fully reproducible

    🏗️ Real-World Analogy: keeping a tall tower from toppling

    Think of a deep network as a very tall tower of building blocks, one block per layer. Training nudges the blocks to make the tower taller and straighter. The taller it gets, the easier it is to topple, and the techniques a builder uses map closely onto stable training:

    • A level foundation = good weight initialisation. Start the blocks balanced (He/Xavier) and the tower rises straight; start them crooked and it leans from block one.
    • Re-levelling each floor = normalisation. Batch Norm and Layer Norm re-centre and re-scale each layer's signal so errors don't accumulate up the stack.
    • A gentle starting push = learning-rate warmup. Shove a fresh tower hard and it tips; push softly at first, then harder once it's settled.
    • A safety cap on each shove = gradient clipping. Limit how hard any single nudge can be so one bad gradient can't knock the whole thing over.
    • Noticing the lean early = NaN/spike detection. Catch the wobble at floor 5, not after the tower is already rubble and every later block is garbage.

    1. Exploding and Vanishing Gradients

    During backpropagation the gradient reaching an early layer is a product of one factor per later layer (roughly, each layer's local derivative). Exploding gradients happen when those factors are larger than 1 on average: the product grows huge, the weights jump, and the loss blows up to inf and then NaN. Vanishing gradients are the opposite: factors below 1 shrink the product toward zero, so the early layers barely update and the network stops learning.

    You spot exploding gradients by watching the gradient norm climb and the loss spike; you spot vanishing gradients when the loss plateaus early and the first layers' weights hardly move. The cures in this lesson — init, normalisation, clipping, and warmup — all work by keeping that running product near 1.
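To make the "running product" concrete, here is a minimal sketch that multiplies one constant factor per layer. The function name `product_of_factors` and the factors 1.5, 0.5, and 1.0 are illustrative assumptions; a real network has a different factor at every layer.

```python
def product_of_factors(factor, depth):
    """Multiply the same per-layer factor together `depth` times."""
    result = 1.0
    for _ in range(depth):
        result *= factor
    return result

depth = 50  # hypothetical number of layers
exploding = product_of_factors(1.5, depth)  # factors above 1 grow without bound
vanishing = product_of_factors(0.5, depth)  # factors below 1 shrink toward zero
stable = product_of_factors(1.0, depth)     # factors at 1 keep the signal intact

print(f"factor 1.5 over {depth} layers: {exploding:.3e}")
print(f"factor 0.5 over {depth} layers: {vanishing:.3e}")
print(f"factor 1.0 over {depth} layers: {stable:.3e}")
```

Even a modest per-layer factor compounds quickly with depth, which is why the techniques in this lesson aim to keep each factor close to 1.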

    2. Reading the Loss Curve: spikes, NaN, and Inf

    The loss numbers tell you almost everything. A sudden jump (say, to more than 3x the previous step's loss) is a loss spike, often caused by a too-high learning rate or an unclipped gradient. Inf means a value overflowed the floating-point range. NaN ("not a number") means a value became undefined, and the run is dead from that point on. The neat trick: NaN is the only float value that is not equal to itself, so loss != loss is a one-line NaN test (no import needed). That test does not catch Inf; use math.isinf(loss) or math.isfinite(loss) for that.

    Run this and watch it flag the spike at step 4 and the NaN at step 7:

    Worked Example: spot the spike and the NaN

    Walk a loss list and flag the first sign of trouble

    Python
    # Worked example: read a loss curve and spot the trouble
    # A training run that goes wrong usually leaves clues in the loss numbers.
    
    losses = [2.40, 2.10, 1.85, 1.60, 9.80, 1.55, 1.50, float("nan"), 1.48]
    #                                  ^^^^ a loss SPIKE         ^^^ then NaN
    
    prev = losses[0]
    for step, loss in enumerate(losses):
        note = ""
        if loss != loss:                 # NaN is the only value not equal to itself
            note = "<-- NaN! training is broken from here"
        elif loss > 
    ...
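The listing above is cut off, so the following is a separate, self-contained sketch of the same idea rather than a reconstruction of it. It assumes a spike means "more than 3x the last finite loss"; the helper name `scan_losses` and the `spike_factor` parameter are hypothetical.

```python
import math

def scan_losses(losses, spike_factor=3.0):
    """Return (step, kind) pairs for loss spikes and non-finite values."""
    events = []
    prev = None  # last finite loss seen so far
    for step, loss in enumerate(losses):
        if math.isnan(loss) or math.isinf(loss):
            events.append((step, "non-finite"))
            continue  # keep prev pointing at the last finite value
        if prev is not None and loss > spike_factor * prev:
            events.append((step, "spike"))
        prev = loss
    return events

losses = [2.40, 2.10, 1.85, 1.60, 9.80, 1.55, 1.50, float("nan"), 1.48]
events = scan_losses(losses)
for step, kind in events:
    print(f"step {step}: {kind}")
```

On this list the scan reports a spike at step 4 and a non-finite value at step 7, counting steps from 0.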

    3. Gradient Clipping by Norm

    The cleanest fix for exploding gradients is to cap the norm (the length) of the gradient. If the gradient's L2 norm exceeds a threshold, you scale the whole vector down by one factor — max_norm / norm — so its direction stays the same and only its size is capped. That last point is why clipping by norm beats clipping each value independently: distorting direction sends the optimiser somewhere it never intended to go.

    This worked example clips the vector [3, 4, 12] (norm 13) down to norm 5:

    Worked Example: clip a gradient to a max norm

    Scale the whole vector down so its length is at most max_norm

    Python
    # Worked example: clip a gradient vector to a maximum norm
    # Exploding gradients = the update vector gets huge and blows up the weights.
    # The fix: if its length (L2 norm) is too big, scale the WHOLE vector down.
    
    def l2_norm(vec):
        """Length of the vector = sqrt(sum of squares)."""
        return sum(v * v for v in vec) ** 0.5
    
    def clip_grad_norm(grad, max_norm):
        """Scale grad so its norm is at most max_norm. Direction is preserved."""
        norm = l2_norm(grad)
        if norm > max_norm:
           
    ...
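To see why preserving direction matters, this sketch compares clipping by norm with clipping each value independently on the same vector [3, 4, 12]. The helpers `clip_by_value` and `cosine` are illustrative additions, and the value limit of 5 is an arbitrary choice for the comparison.

```python
def l2_norm(vec):
    """Length of the vector = sqrt(sum of squares)."""
    return sum(v * v for v in vec) ** 0.5

def clip_by_norm(grad, max_norm):
    """Scale the whole vector by one factor so its norm is at most max_norm."""
    norm = l2_norm(grad)
    if norm > max_norm:
        scale = max_norm / norm
        return [g * scale for g in grad]
    return list(grad)

def clip_by_value(grad, limit):
    """Clamp each element independently to [-limit, limit]."""
    return [max(-limit, min(limit, g)) for g in grad]

def cosine(a, b):
    """Cosine similarity: 1.0 means the two vectors point the same way."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (l2_norm(a) * l2_norm(b))

grad = [3.0, 4.0, 12.0]               # norm = 13
by_norm = clip_by_norm(grad, 5.0)     # same direction, norm 5
by_value = clip_by_value(grad, 5.0)   # only the largest element is clamped

print("by norm :", [round(g, 3) for g in by_norm], "cosine", round(cosine(grad, by_norm), 4))
print("by value:", by_value, "cosine", round(cosine(grad, by_value), 4))
```

Clipping by norm keeps the cosine similarity with the original gradient at 1, while clipping by value tilts the vector toward the elements that were not clamped.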

    In real code you rarely need to hand-roll this. PyTorch's clip_grad_norm_ clips across every parameter at once, in place, and returns the total gradient norm measured before clipping so you can log it:

    Worked Example (PyTorch): clip_grad_norm_

    The production version — read it, note the # Expected output

    Python
    import torch
    import torch.nn as nn
    
    # The real thing: PyTorch ships clip_grad_norm_ so you never hand-roll it.
    # It clips across ALL the model's parameters at once, in place.
    
    torch.manual_seed(0)                     # reproducibility: same run every time
    model = nn.Linear(4, 1)
    x = torch.randn(8, 4)
    y = torch.randn(8, 1)
    
    loss = ((model(x) - y) ** 2).mean()
    loss.backward()                          # fills .grad on every parameter
    
    before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_
    ...
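Because the listing above is truncated, here is a separate minimal sketch of where the clipping call sits in one training step: after backward() and before the optimizer step. The deliberately large targets, the SGD optimizer, and the variable names `pre_clip` and `post_clip` are assumptions made for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(8, 4)
y = 100.0 * torch.randn(8, 1)   # deliberately large targets to provoke big gradients
max_norm = 1.0

opt.zero_grad()
loss = ((model(x) - y) ** 2).mean()
loss.backward()                 # 1. fill .grad on every parameter

# 2. clip in place; the return value is the total norm measured BEFORE clipping
pre_clip = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

# Recompute the total norm by hand to confirm the gradients were rescaled
post_clip = torch.sqrt(sum((p.grad ** 2).sum() for p in model.parameters()))

opt.step()                      # 3. update the weights with the clipped gradients

print(f"norm before clipping: {float(pre_clip):.3f}")
print(f"norm after clipping : {float(post_clip):.3f}")
```

Logging the returned pre-clip norm every step is a cheap way to watch for the climbing gradient norm described in section 1.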

    🎯 Your Turn: finish the norm clipper

    Fill in the two blanks so the clipper scales [6, 8] (norm 10) down to norm 5.

    Your Turn: clip_grad_norm

    Replace each ___ using the 👉 hints, then check against the expected output

    Python
    # 🎯 YOUR TURN — finish the gradient-norm clipper (fill in the ___)
    
    def l2_norm(vec):
        return sum(v * v for v in vec) ** 0.5
    
    def clip_grad_norm(grad, max_norm):
        norm = l2_norm(grad)
        if norm > max_norm:
            scale = ___                       # 👉 shrink factor = max_norm / norm
            return [g * scale for g in grad]
        return list(grad)                     # already small enough
    
    grad = [6.0, 8.0]                         # norm = sqrt(36+64) = 10.0
    clipped = clip_grad_norm(grad
    ...

    4. Learning-Rate Warmup, Schedules, and Mixed Precision

    A fresh model has random weights, so the first gradients are loud. Warmup ramps the learning rate linearly from near zero up to your target over the first few steps, letting Adam's running statistics settle before you take big steps. After warmup you usually decay the rate — cosine or step schedules — for a cleaner final model.
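A warmup-then-decay schedule can be written as a plain function of the step number. The sketch below assumes linear warmup followed by cosine decay to zero; the function name `lr_at` and the values of `warmup_steps` and `total_steps` are illustrative, not recommendations.

```python
import math

def lr_at(step, base_lr=1e-3, warmup_steps=5, total_steps=50):
    """Linear warmup to base_lr, then cosine decay toward zero."""
    if step < warmup_steps:
        # Ramp up: reaches base_lr on the last warmup step
        return base_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase completed, from 0.0 to 1.0
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))

schedule = [lr_at(step) for step in range(50)]
for step in (0, 2, 4, 5, 25, 49):
    print(f"step {step:2d}: lr = {schedule[step]:.6f}")
```

In a PyTorch loop one way to apply it is to write the value into each entry of optimizer.param_groups before the optimizer step.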

    Mixed precision (AMP) speeds up training by computing in lower precision such as float16, but tiny gradients can underflow to zero. GradScaler multiplies the loss up before backward() and unscales the gradients afterwards; it also skips the optimizer step automatically when it detects inf/NaN gradients. Always call scaler.unscale_(optimizer) before clipping so the clip sees the true gradient norms rather than the scaled ones:

    Worked Example (PyTorch): warmup + AMP + clipping

    The full stability stack in one loop — read the # Expected output

    Python
    import torch
    import torch.nn as nn
    from torch.cuda.amp import autocast, GradScaler
    
    # Production stability stack: warmup schedule + AMP loss scaling + clipping.
    torch.manual_seed(0)
    
    model = nn.Linear(4, 1)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    scaler = GradScaler()                    # AMP: scales loss so small grads don't underflow to 0
    warmup_steps = 5
    base_lr = 1e-3
    
    def lr_at(step):
        # Linear warmup: ramp 0 -> base_lr over warmup_steps, then hold.
        if step < warmup_steps
    ...
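The listing above is truncated, so this is an independent sketch of the order of operations rather than the original loop: set the learning rate, scale the loss, call backward(), unscale, clip, step, update. It assumes AMP is switched on only when a CUDA device is present (on CPU the scaler is created disabled and acts as a pass-through), and it keeps the lesson's torch.cuda.amp import; recent PyTorch releases steer you to torch.amp.GradScaler("cuda") instead.

```python
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler

torch.manual_seed(0)
use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"

model = nn.Linear(4, 1).to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler(enabled=use_cuda)    # disabled scaler = plain pass-through
base_lr, warmup_steps, max_norm = 1e-3, 5, 1.0

x = torch.randn(8, 4, device=device)
y = torch.randn(8, 1, device=device)

def lr_at(step):
    """Linear warmup to base_lr, then hold."""
    return base_lr * min(1.0, (step + 1) / warmup_steps)

grad_norms = []
for step in range(10):
    for group in opt.param_groups:       # apply the warmup schedule
        group["lr"] = lr_at(step)
    opt.zero_grad()
    with torch.autocast(device_type=device, enabled=use_cuda):
        loss = ((model(x) - y) ** 2).mean()
    scaler.scale(loss).backward()        # backward on the scaled loss
    scaler.unscale_(opt)                 # restore true gradient magnitudes
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    scaler.step(opt)                     # skipped if grads contain inf/NaN
    scaler.update()                      # adjust the scale factor for the next step
    grad_norms.append(float(total_norm))

print("final lr:", opt.param_groups[0]["lr"])
print("last pre-clip grad norm:", round(grad_norms[-1], 4))
```

If the clip ran before unscale_, max_norm would be compared against gradients inflated by the loss scale, so nearly every step would be clipped.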

    5. Weight Init and Normalisation for Stability

    Initialisation sets the starting scale of every weight so the signal neither shrinks nor blows up as it passes through layers. Use He (Kaiming) init for ReLU networks and Xavier (Glorot) init for tanh/sigmoid. Avoid plain randn * 1.0 or randn * 0.01 for a deep network: at typical layer widths the first makes activations explode and the second makes them vanish.
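The effect of the weight scale can be checked with a small pure-Python simulation: push a random vector through a stack of random ReLU layers and measure the size of what comes out. The width of 64, depth of 20, and the helper name `final_rms` are assumptions chosen to keep the run fast; He init is taken here as a standard deviation of sqrt(2 / fan_in).

```python
import random

def final_rms(weight_std, width=64, depth=20, seed=0):
    """RMS of the activations after `depth` random ReLU layers."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(width)]
    for _ in range(depth):
        nxt = []
        for _ in range(width):
            # One output unit: weighted sum of the inputs, then ReLU
            pre = sum(rng.gauss(0.0, weight_std) * xj for xj in x)
            nxt.append(max(0.0, pre))
        x = nxt
    return (sum(v * v for v in x) / width) ** 0.5

width = 64
too_big = final_rms(1.0, width)                # plain randn * 1.0
too_small = final_rms(0.01, width)             # plain randn * 0.01
he = final_rms((2.0 / width) ** 0.5, width)    # He: std = sqrt(2 / fan_in)

print(f"std 1.0 : {too_big:.3e}")
print(f"std 0.01: {too_small:.3e}")
print(f"He      : {he:.3e}")
```

Only the He-scaled weights leave the signal at roughly its starting magnitude; the other two drift by many orders of magnitude over 20 layers.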

    Normalisation re-centres and re-scales each layer's activations during training so the next layer always sees a well-behaved distribution. Batch Norm normalises each feature across the batch (great for CNNs with large batches); Layer Norm normalises each example across its features (the default for Transformers, common in RNNs, and the safe choice for small batches).
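As a by-hand illustration of what Layer Norm computes for a single example, the sketch below subtracts the mean of the features and divides by their standard deviation. It leaves out the learnable scale and shift that nn.LayerNorm applies by default, and the function name `layer_norm` and the `eps` value are assumptions.

```python
def layer_norm(features, eps=1e-5):
    """Normalise one example across its features to mean 0 and variance about 1."""
    mean = sum(features) / len(features)
    var = sum((f - mean) ** 2 for f in features) / len(features)
    return [(f - mean) / (var + eps) ** 0.5 for f in features]

activations = [100.0, 200.0, 300.0, 400.0]   # badly scaled layer output
normed = layer_norm(activations)

print([round(v, 3) for v in normed])
```

Whatever scale the incoming activations have, the next layer sees values centred on zero with unit spread.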

    Use for ReLU / large batches

    nn.init.kaiming_normal_(w)   # He
    nn.BatchNorm2d(channels)     # Batch Norm

    Use for Transformers / small batches

    nn.init.xavier_uniform_(w)   # Xavier
    nn.LayerNorm(features)       # Layer Norm

    🎯 Your Turn: find the first bad step

    Detecting trouble early is half the battle. Fill in the blanks so the loop reports the first NaN or Inf in the loss list.

    Your Turn: detect the first NaN/Inf

    Replace each ___ using the 👉 hints, then check against the expected output

    Python
    # 🎯 YOUR TURN — report the FIRST broken step in a loss list (fill in the ___)
    import math
    
    losses = [1.9, 1.5, 1.2, float("inf"), 0.9, float("nan")]
    
    first_bad = None
    for step, loss in enumerate(losses):
        if math.isnan(loss) or ___:           # 👉 also catch infinity: math.isinf(loss)
            first_bad = step
            break                             # stop at the first bad step
    
    if first_bad is ___:                      # 👉 None means we found nothing bad
        print("All losses are finite —
    ...

    6. Reproducibility: seed everything

    If two runs give different results you can't tell whether a change helped or you just got lucky. Seed every source of randomness before you build the model — Python's random, NumPy, and PyTorch (CPU and GPU):

    import random
    import numpy as np
    import torch
    
    def seed_everything(seed=42):
        random.seed(seed)                 # Python's RNG
        np.random.seed(seed)              # NumPy
        torch.manual_seed(seed)           # PyTorch CPU
        torch.cuda.manual_seed_all(seed)  # PyTorch GPU
        # For bitwise-identical runs (slower):
        torch.use_deterministic_algorithms(True)
    
    seed_everything(42)   # call ONCE, at the very top, before building the model

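    Call seed_everything once at the top of your script. Without it, weight init, shuffling, and dropout all differ between runs. Note that torch.use_deterministic_algorithms(True) raises a RuntimeError when an operation has no deterministic implementation, and some CUDA operations also need the CUBLAS_WORKSPACE_CONFIG environment variable set; drop that line if you only need approximately repeatable runs.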
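The effect of seeding is easy to verify with the standard library alone. In this sketch `fake_run` is a hypothetical stand-in for a training run: it draws a few "weights" and a shuffled data order from Python's random module only, so it shows the principle without covering NumPy or PyTorch.

```python
import random

def fake_run(seed):
    """Stand-in for a training run: random 'weights' plus a shuffled data order."""
    random.seed(seed)                       # seed BEFORE any randomness is used
    weights = [random.gauss(0.0, 1.0) for _ in range(3)]
    order = list(range(5))
    random.shuffle(order)
    return weights, order

run_a = fake_run(42)
run_b = fake_run(42)   # same seed: identical weights and order
run_c = fake_run(7)    # different seed: a different run

print("same seed matches     :", run_a == run_b)
print("different seed matches:", run_a == run_c)
```

The same check, comparing two seeded runs for equality, is a quick sanity test to apply to a real training script after adding seed_everything.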

    🎯 Mini-Challenge: a tiny training guard

    The scaffolding is gone now: only a comment outline is given. Write the whole function yourself.

    Mini-Challenge: guard(grad, loss, max_norm)

    Skip broken steps, clip the rest — match the expected output

    Python
    # 🎯 MINI-CHALLENGE: a tiny "training guard"
    # Write a function guard(grad, loss, max_norm) that:
    #   1. Returns "skip" if loss is NaN or Inf   (use math.isnan / math.isinf)
    #   2. Otherwise clips grad to max_norm by L2 norm and returns the clipped list
    # Then test it on a healthy step AND a NaN step.
    #
    # ✅ Expected output:
    # healthy: [3.0, 4.0]
    # broken : skip
    
    import math
    
    # your code here
    

    Common Errors (And How to Fix Them)

    ❌ No gradient clipping → loss explodes to NaN

    An unclipped gradient spikes, the weights jump, and the loss goes inf then NaN.

    ✅ Fix: add nn.utils.clip_grad_norm_(model.parameters(), 1.0) after backward() and before optimizer.step().

    ❌ Learning rate too high → loss spikes or diverges

    The loss oscillates or climbs instead of falling — the steps are too big.

    ✅ Fix: lower the LR (try 10x smaller) and add warmup so early steps stay small.

    ❌ Bad initialisation → dead neurons or instant explosion

    randn * 1.0 explodes through deep layers; randn * 0.01 makes activations vanish and ReLUs go dead.

    ✅ Fix: use He init for ReLU (kaiming_normal_), Xavier for tanh/sigmoid.

    ❌ Not seeding → results change every run

    You can't reproduce a result or tell whether a tweak actually helped.

    ✅ Fix: call seed_everything(42) once at the very top, before building the model.

    ❌ Ignoring NaN → every later step is garbage

    Once a NaN enters the weights it poisons all subsequent updates, but training keeps "running".

    ✅ Fix: check each step with math.isnan(loss) (or the equivalent loss != loss) and stop or skip immediately.

    📋 Quick Reference

    Technique         | Fixes                  | How / Use With
    Clip by norm      | Exploding gradients    | clip_grad_norm_(params, 1.0)
    LR warmup         | Early instability      | Ramp 0 → base over first steps
    LR schedule       | Noisy late training    | Cosine / step decay after warmup
    He init           | Vanishing activations  | ReLU networks
    Xavier init       | Vanishing activations  | tanh, sigmoid
    Batch Norm        | Covariate shift        | CNNs, large batches
    Layer Norm        | Covariate shift        | Transformers, small batches
    AMP + GradScaler  | Underflow / inf grads  | Mixed-precision training
    NaN check         | Silent corruption      | loss != loss / math.isnan
    Seeding           | Non-reproducibility    | seed_everything(42) once

    ❓ Frequently Asked Questions

    Q: What causes exploding and vanishing gradients?

    A: Both come from repeatedly multiplying gradients through many layers during backpropagation. If the factors are bigger than 1 on average, the product grows without bound (exploding — you see huge values, then NaN). If they are smaller than 1, the product shrinks toward zero (vanishing — early layers stop learning). Good weight initialisation, normalisation, residual connections, and gradient clipping all keep these products near 1.

    Q: Should I clip gradients by value or by norm?

    A: Clip by norm. Clipping by value caps each element independently, which distorts the gradient's direction. Clipping by global norm scales the whole vector by one factor only when its total length exceeds the threshold, so the direction is preserved and only the magnitude is capped. In PyTorch use torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm). Common starting points for max_norm are 1.0 for Transformers and around 5.0 for RNNs; treat these as values to tune, not fixed rules.

    Q: Why do I need learning-rate warmup?

    A: At the very start of training the weights are random, so the first few gradients can be large and noisy — especially with adaptive optimisers like Adam, whose running statistics are not yet warmed up. Ramping the learning rate linearly from near zero to the target over the first few hundred to a few thousand steps lets those statistics settle before you take big steps, which prevents an early loss explosion. After warmup most people decay the rate (cosine or step schedule) for a cleaner final model.

    Q: My loss became NaN — what now?

    A: NaN means a number became undefined (often 0/0, log(0), or overflow from an exploding gradient). Do not ignore it: once a NaN enters the weights every later step is garbage. Detect it early (a loss is NaN when loss != loss, or use math.isnan), then fix the cause — lower the learning rate, add or tighten gradient clipping, check for log(0) or divide-by-zero in your loss, and confirm your inputs contain no NaN/Inf. With mixed precision, GradScaler skips the update automatically when it sees inf or NaN gradients.
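One common way to guard against log(0) in a loss is to clamp probabilities away from 0 and 1 before taking the logarithm. The sketch below does this for a single-example binary cross-entropy; the function name and the `eps` value of 1e-7 are assumptions. In plain Python math.log(0) raises a ValueError, whereas tensor libraries typically return -inf, which then turns into NaN further along.

```python
import math

def binary_cross_entropy(p, target, eps=1e-7):
    """Cross-entropy for one prediction, with p clamped to [eps, 1 - eps]."""
    p = min(max(p, eps), 1.0 - eps)   # keep both log() arguments strictly positive
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))

confident_wrong = binary_cross_entropy(0.0, 1.0)   # would be log(0) without the clamp
confident_right = binary_cross_entropy(1.0, 1.0)
uncertain = binary_cross_entropy(0.5, 1.0)

print(f"confident and wrong: {confident_wrong:.3f}")
print(f"confident and right: {confident_right:.3e}")
print(f"uncertain          : {uncertain:.3f}")
```

The clamp turns an undefined loss into a large but finite one, so the step still produces a usable gradient.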

    Q: How do I make a training run reproducible?

    A: Seed every source of randomness before you build the model: random.seed(s), numpy.random.seed(s), and torch.manual_seed(s) (plus torch.cuda.manual_seed_all(s) on GPU). For bitwise-identical runs also set torch.use_deterministic_algorithms(True) and avoid nondeterministic GPU kernels. Reproducibility is what lets you tell whether a change actually helped, instead of chasing random noise between runs.

    🎉 Lesson Complete!

    You can now keep a tall network from toppling: you read a loss curve for spikes and NaN, clip gradients to a safe norm (by hand and with clip_grad_norm_), warm up and schedule the learning rate, pick init/normalisation that hold signals steady, and seed every RNG for reproducible runs.

    🚀 Up next: Generative Models — now that training is stable, build models that create entirely new data.
