
    Lesson 39 • Advanced

    Model Compression

    Shrink a trained model so it ships on phones, browsers, and cheap GPUs — quantization, pruning, knowledge distillation, low-rank factorization, and the accuracy/size/latency trade-off that decides which one to reach for.

    What You'll Learn in This Lesson

    • Quantize FP32 weights to INT8 with a scale factor, and measure the error
    • Prune away small weights with magnitude and structured pruning
    • Distill a large teacher model into a small student with soft labels
    • Use low-rank factorization to split one big matrix into two small ones
    • Weigh the accuracy vs size vs latency trade-off for a target device
    • Avoid the accuracy cliff, the no-speedup pruning trap, and skipping fine-tuning

    🧳 Real-World Analogy: Packing a Suitcase

    A full-precision model is an over-packed suitcase you can barely close. Compression is packing it smartly so it fits the carry-on limit — the same trip, far less bulk. Each technique is a different packing trick:

    • Quantization — swap hardcover books for paperbacks. Same words, a quarter of the weight (FP32 → INT8).
    • Pruning — leave behind the clothes you never wear. Most weights barely matter, so drop them.
    • Knowledge distillation — instead of the whole encyclopaedia, carry an expert's concise summary (a small student trained by a big teacher).
    • Low-rank factorization — roll your clothes instead of folding flat. The same content in a more compact shape (one big matrix becomes two small ones).

    And the trade-off is the airline's scale: pack too aggressively and you leave behind something you needed (accuracy). The art is the lightest bag that still has everything for the trip — small enough for an edge device, cheap to run, and fast to respond.

    1. Why Compress a Model at All?

    A trained model is just a big pile of numbers (its weights). The more numbers and the more bits each one uses, the more memory it eats and the slower it runs. Compression trades a sliver of accuracy for three things you nearly always need in production:

    • Edge devices — a phone, watch, or browser has a few GB of RAM and no datacentre GPU. A 280GB model simply won't load; a 4GB one will.
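    • Cost: smaller models fit on cheaper hardware. A 70B-parameter LLaMA model needs about 280GB for its weights in FP32, which means at least four 80GB A100 GPUs; in INT4 the weights shrink to roughly 35GB, which fits on a single 48GB or 80GB card.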
    • Speed — fewer bytes to move and integer maths to run means lower latency (time per request), which users feel directly.
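The sizes quoted above follow from simple arithmetic: parameter count times bytes per parameter. The sketch below is a rough estimate under stated assumptions: it counts weights only (no activations or other runtime memory) and treats 1GB as 1e9 bytes. The helper name `weight_gb` is made up for this illustration.

```python
# Rough weight-memory estimate: parameter count times bytes per parameter.
# Assumptions: weights only, and 1 GB counted as 1e9 bytes.
BYTES_PER_PARAM = {"FP32": 4.0, "FP16": 2.0, "INT8": 1.0, "INT4": 0.5}

def weight_gb(num_params: float, fmt: str) -> float:
    """Return the approximate size of the weights in gigabytes."""
    return num_params * BYTES_PER_PARAM[fmt] / 1e9

params_70b = 70e9  # a 70-billion-parameter model
sizes = {fmt: weight_gb(params_70b, fmt) for fmt in BYTES_PER_PARAM}

for fmt, gb in sizes.items():
    print(f"{fmt}: {gb:.0f} GB")
# FP32: 280 GB
# FP16: 140 GB
# INT8: 70 GB
# INT4: 35 GB
```

Real deployments need headroom on top of these numbers, so treat them as a lower bound when picking hardware.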

    2. Quantization: Fewer Bits per Number (FP32 → INT8)

    Most models store each weight as a 32-bit float (FP32) — 4 bytes that can represent tiny fractions very precisely. Quantization asks: do you really need that precision? Often you can map those floats onto 8-bit integers (INT8) — just 256 possible values, 1 byte each. That is an instant 4x size cut, and integer maths runs faster on most CPUs and accelerators.

    The trick is a single scale factor. You find the largest absolute weight, divide the range from its negative to its positive into evenly spaced integer steps (−127 to 127 for symmetric INT8), and store which step each weight is closest to. To use the model you multiply the integer back by the scale to get an approximate float. The gap between the original and the recovered float is the quantization error, which stays small when the weights are well-behaved (no extreme outliers stretching the range).

    Going further — INT4 (16 levels) gives 8x compression but a bigger error, so it needs smart schemes like GPTQ or AWQ. The two ways to apply quantization: post-training (PTQ) quantizes an already-trained model (easy, slight loss), while quantization-aware training (QAT) simulates the rounding during training so the model learns to tolerate it (better, more work).
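To see why fewer bits cost more accuracy, the sketch below runs the same symmetric scale-and-round round-trip at 8 bits and at 4 bits and compares the mean error. It is a plain-Python illustration, not GPTQ or AWQ; the weights are illustrative values and the helper names are made up for this example. For simplicity it uses a symmetric grid (−7 to 7 for INT4), so one of the 16 INT4 levels goes unused.

```python
# Symmetric round-trip quantization at a chosen bit width, in plain Python.
# The weights are illustrative values, not taken from a real model.

def quantize_roundtrip(weights, num_bits):
    """Quantize to signed integers of num_bits, then dequantize."""
    q_max = 2 ** (num_bits - 1) - 1          # 127 for INT8, 7 for INT4
    scale = max(abs(w) for w in weights) / q_max
    q = [max(-q_max, min(q_max, round(w / scale))) for w in weights]
    return [i * scale for i in q]

def mean_abs_error(original, recovered):
    """Average absolute gap between original and recovered floats."""
    return sum(abs(a - b) for a, b in zip(original, recovered)) / len(original)

weights = [0.12, -0.83, 0.55, -0.20, 0.91, -0.47, 0.05, 0.73]

err_int8 = mean_abs_error(weights, quantize_roundtrip(weights, 8))
err_int4 = mean_abs_error(weights, quantize_roundtrip(weights, 4))

print(f"INT8 mean error: {err_int8:.5f}")
print(f"INT4 mean error: {err_int4:.5f}")
print(f"INT4 error is larger: {err_int4 > err_int8}")  # True
```

With only 15 usable steps across the range, each INT4 step is much coarser, which is why 4-bit schemes need extra care to stay accurate.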

    Worked example — one-line INT8 quantization in PyTorch:

    # Post-training dynamic quantization in PyTorch.
    # Maps FP32 weights to INT8 with one call: about 4x smaller, faster on CPU.
    
    import os
    
    import torch
    import torch.nn as nn
    
    # A small model standing in for any trained network.
    model = nn.Sequential(
        nn.Linear(512, 512),
        nn.ReLU(),
        nn.Linear(512, 10),
    )
    model.eval()
    
    # Quantize the Linear layers to INT8. Dynamic = activations are
    # quantized on the fly at inference, weights are stored as INT8.
    quantized = torch.quantization.quantize_dynamic(
        model, {nn.Linear}, dtype=torch.qint8
    )
    
    def size_mb(m, path="tmp.pt"):
        """Save the state dict, measure the file in MB, then delete it."""
        torch.save(m.state_dict(), path)
        size = os.path.getsize(path) / 1e6
        os.remove(path)
        return size
    
    print(f"FP32 size: {size_mb(model):.2f} MB")
    print(f"INT8 size: {size_mb(quantized):.2f} MB")
    
    # Expected output (approximate):
    #   FP32 size: 1.07 MB
    #   INT8 size: 0.28 MB

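    The INT8 state dict is roughly a quarter of the FP32 one, close to the 4x you'd expect from 4 bytes down to 1. It lands slightly short of a clean 4x because the biases stay in FP32 and the file also stores scale metadata. You'll build this scale-and-round logic by hand in the runnable example below.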

    3. Pruning: Delete the Weights That Barely Matter

    Trained networks are wasteful: a large fraction of weights are so close to zero they contribute almost nothing. Pruning deletes them. The simplest recipe is magnitude pruning — pick a threshold, then zero out every weight whose absolute value is below it. The fraction of weights now zero is the sparsity.

    There is a crucial catch. Magnitude (unstructured) pruning scatters zeros anywhere in the matrix — you get high sparsity, but the matrix is still the same shape, so ordinary GPUs do the same amount of work and you get no speedup, only a smaller file (when stored sparsely). Structured pruning instead removes whole rows, channels, or attention heads, so the matrix genuinely shrinks and standard hardware runs it faster. Rule of thumb: unstructured for size, structured for latency.

    Magnitude (unstructured)

    • Zeroes individual weights below a threshold
    • Reaches very high sparsity (90%+)
    • Smaller file, but no speedup on dense hardware

    Structured

    • Removes whole rows / channels / heads
    • Lower max sparsity, harder to do well
    • Matrix actually shrinks → real latency win
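The difference between the two columns shows up directly in the matrix shape. The sketch below uses NumPy on a small random matrix standing in for a layer's weights; the 50% pruning ratio and the choice to rank rows by L2 norm are assumptions made for illustration, not a recommended recipe.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(8, 6))   # stand-in weight matrix: 8 output units, 6 inputs

# Unstructured: zero the 50% of entries with the smallest magnitude.
# The matrix keeps its shape, so a dense matmul does the same work.
threshold = np.quantile(np.abs(W), 0.5)
W_unstructured = np.where(np.abs(W) >= threshold, W, 0.0)
sparsity = float(np.mean(W_unstructured == 0))

# Structured: drop the 4 rows (output units) with the smallest L2 norm.
# The matrix itself gets smaller, so a dense matmul does less work.
row_norms = np.linalg.norm(W, axis=1)
keep = np.sort(np.argsort(row_norms)[4:])
W_structured = W[keep]

print(W_unstructured.shape, f"sparsity={sparsity:.2f}")  # (8, 6) sparsity=0.50
print(W_structured.shape)                                # (4, 6)
```

In a real network, removing output rows from one layer also means removing the matching input columns from the next layer, which is part of why structured pruning is harder to do well.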

    4. Knowledge Distillation: Teacher Trains Student

    Rather than shrink a model, knowledge distillation trains a brand-new small student model to imitate a large, accurate teacher. The clever part is what the student copies. A normal label is hard — "this image is a cat, everything else is 0%". The teacher's full probability distribution is a soft label: "94% cat, 4% dog, 1% bird". That extra structure — the teacher quietly saying "this looks a bit like a dog too" — is the dark knowledge the student learns from.

    To expose that structure you raise the softmax temperature. Higher temperature flattens the distribution so the small probabilities become visible and informative. The student is trained to match these softened teacher outputs (often plus the real labels). The result is a small model that punches well above its size — DistilBERT is 40% smaller than BERT while keeping about 97% of its accuracy.
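The flattening effect of temperature is easy to check by hand. This plain-Python sketch applies a softmax at T = 1 and at T = 4 to one set of illustrative logits; the function name `softmax_with_temperature` is made up for this example.

```python
import math

def softmax_with_temperature(logits, T=1.0):
    """Divide logits by T, then apply a numerically stable softmax."""
    scaled = [z / T for z in logits]
    m = max(scaled)                          # subtract the max for stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [5.0, 2.0, 0.5, 0.1, -1.0]          # illustrative teacher logits

sharp = softmax_with_temperature(logits, T=1.0)
soft = softmax_with_temperature(logits, T=4.0)

print([round(p, 3) for p in sharp])  # [0.934, 0.046, 0.01, 0.007, 0.002]
print([round(p, 3) for p in soft])   # [0.432, 0.204, 0.14, 0.127, 0.096]
```

At T = 1 the top class swallows almost all of the probability; at T = 4 the runner-up classes become large enough for the student to learn from.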

    Worked example — the distillation loss in PyTorch:

    # The core of knowledge distillation: a student copies a teacher's
    # SOFT labels (full probability distribution), not just the hard label.
    
    import torch
    import torch.nn.functional as F
    
    T = 4.0   # temperature: higher T = softer, more informative targets
    
    # Logits the teacher and student produce for one example, 5 classes.
    teacher_logits = torch.tensor([[5.0, 2.0, 0.5, 0.1, -1.0]])
    student_logits = torch.tensor([[4.2, 2.3, 0.4, 0.2, -0.8]])
    
    # Soften both with temperature, then match them with KL-divergence.
    teacher_soft = F.softmax(teacher_logits / T, dim=1)
    student_soft = F.log_softmax(student_logits / T, dim=1)
    
    distill_loss = F.kl_div(student_soft, teacher_soft, reduction="batchmean") * (T * T)
    
    print(f"Teacher soft: {[round(p, 3) for p in teacher_soft[0].tolist()]}")
    print(f"Distill loss: {distill_loss.item():.4f}")
    
    # Expected output (approximate):
    #   Teacher soft: [0.432, 0.204, 0.14, 0.127, 0.096]
    #   Distill loss: roughly 0.11

    The KL-divergence measures how far the student's softened output is from the teacher's. Dividing the logits by T shrinks the soft-target gradients by roughly 1/T², so multiplying the loss by T*T brings them back to the same scale as the hard-label loss. Minimise this and the student's distribution slides toward the teacher's.
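Training against the softened teacher outputs "plus the real labels" usually means a weighted sum of two losses. The sketch below is one common way to combine them, not the only one; the function name `distillation_loss` and the mixing weight `alpha` are assumptions introduced for this illustration.

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.5):
    """Blend the soft-label KL loss with ordinary hard-label cross-entropy.

    alpha is a hypothetical mixing weight: 1.0 = teacher only, 0.0 = labels only.
    """
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction="batchmean",
    ) * (T * T)
    hard_loss = F.cross_entropy(student_logits, labels)
    return alpha * soft_loss + (1.0 - alpha) * hard_loss

teacher_logits = torch.tensor([[5.0, 2.0, 0.5, 0.1, -1.0]])
student_logits = torch.tensor([[4.2, 2.3, 0.4, 0.2, -0.8]], requires_grad=True)
labels = torch.tensor([0])   # the true class for this one example

loss = distillation_loss(student_logits, teacher_logits, labels)
loss.backward()              # gradients flow to the student only

print(f"combined loss: {loss.item():.4f}")
print(f"gradient shape: {tuple(student_logits.grad.shape)}")  # (1, 5)
```

In a real training loop the teacher's logits would be computed under `torch.no_grad()` and the student's would come from the student model's forward pass.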

    5. Low-Rank Factorization: Split One Big Matrix Into Two Small Ones

    A linear layer is a weight matrix. A 1000×1000 matrix holds a million numbers. Low-rank factorization approximates that matrix as the product of two skinny matrices — say 1000×r times r×1000 for a small rank r. With r = 50 that is 1000×50 + 50×1000 = 100,000 numbers instead of 1,000,000 — a 10x reduction for that layer.

    It works because real weight matrices are often approximately low rank: most of their information lives in far fewer dimensions than their shape suggests, so a smaller rank captures almost all of it. The same idea powers LoRA, a widely used way to fine-tune large language models cheaply, where you train only a tiny low-rank update instead of the full matrix. The cost is that too small a rank throws away real signal, so you pick the smallest r your accuracy can tolerate.

    # Parameter count: full matrix vs a rank-r factorization
    d_in, d_out, r = 1000, 1000, 50
    
    full   = d_in * d_out              # 1,000,000 numbers
    factor = d_in * r + r * d_out      # 100,000 numbers
    print(f"full:   {full:,}")         # full:   1,000,000
    print(f"rank-{r}: {factor:,}")     # rank-50: 100,000
    print(f"compression: {full / factor:.0f}x")  # compression: 10x
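The parameter count says how much you save; a truncated SVD is one standard way to actually find the two skinny matrices. The sketch below uses NumPy on a synthetic matrix built to be nearly rank 10, which is an assumption made so the example is predictable. Real weight matrices need their rank chosen by checking accuracy.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, true_rank, r = 200, 200, 10, 10

# Synthetic weight matrix: exactly rank 10, plus a little noise.
W = rng.normal(size=(d_in, true_rank)) @ rng.normal(size=(true_rank, d_out))
W += 0.01 * rng.normal(size=(d_in, d_out))

# Truncated SVD: keep only the r largest singular values.
U, S, Vt = np.linalg.svd(W, full_matrices=False)
A = U[:, :r] * S[:r]          # d_in x r  (singular values folded into A)
B = Vt[:r, :]                 # r x d_out
W_approx = A @ B

rel_error = np.linalg.norm(W - W_approx) / np.linalg.norm(W)

print(A.shape, B.shape)                               # (200, 10) (10, 200)
print(f"params: {W.size:,} -> {A.size + B.size:,}")   # params: 40,000 -> 4,000
print(f"relative error: {rel_error:.4f}")
```

Because the synthetic matrix really is close to rank 10, the relative error comes out tiny. Lowering `r` below the true rank makes the error climb quickly, which is the "too small a rank throws away signal" effect.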

    6. The Accuracy / Size / Latency Trade-off

    There is no free lunch — every technique spends a little accuracy to buy size and speed. Choosing well means knowing your target. Deploying to a watch? You need maximum size and latency savings and can accept more accuracy loss. Serving a medical model? Guard accuracy and compress conservatively.

    In practice you stack techniques: distill to a smaller architecture, prune it structurally, then quantize to INT8 — combinations reach 10–50x compression on edge devices. Crucially, some tasks live near an accuracy cliff: maths reasoning and code generation collapse under aggressive quantization, while text classification barely notices. Always benchmark on your task, not a generic one.

    | If your priority is… | Reach for | Watch out for |
    |---|---|---|
    | Easy 2–4x with near-zero loss | FP16 / INT8 quantization (PTQ) | Sensitive tasks may still drop |
    | Smallest file size | INT4 + unstructured pruning | Often no speedup on dense GPUs |
    | Lowest latency | Structured pruning + INT8 | Harder to keep accuracy |
    | A whole new tiny model | Knowledge distillation | Needs the teacher + a training run |
    | Cheap fine-tuning of an LLM | Low-rank factorization (LoRA) | Too-low rank loses signal |

    ▶️ Worked Example: Quantize to INT8 and Back (run it)

    This is INT8 quantization with no libraries — just a scale factor, rounding, and the error it costs. Read the comments, then press run and watch the recovered floats land close to the originals.

    Worked Example: INT8 Quantization Round-Trip

    Map FP32 weights onto INT8 with a scale factor and measure the rounding error

    Python
    # Worked example: quantize a list of FP32 weights to INT8 and back.
    # INT8 stores 256 levels (-128..127). We map the float range onto that
    # grid with a single 'scale' factor, then measure the rounding error.
    
    weights = [0.12, -0.83, 0.55, -0.20, 0.91, -0.47, 0.05, 0.73]
    
    # 1) Build the scale: how much value each integer step represents.
    w_max = max(abs(w) for w in weights)     # symmetric range around 0
    scale = w_max / 127                       # 127 = max positive INT8
    
    # 2) Quantize: float ->
    ...

    🎯 Your Turn #1: Quantize the Weights

    Fill in the two blanks marked ___ so the round-trip works: build the scale from INT8's largest value, and dequantize with the same scale. Check your output against the # ✅ Expected output comment.

    Your Turn #1: INT8 Quantization

    Finish the scale and the dequantize step, then read the mean error

    Python
    # 🎯 YOUR TURN #1 — quantize floats to INT8 and measure the error
    # Fill in the two blanks marked ___ so the round-trip works.
    
    weights = [0.40, -0.10, 0.80, -0.60, 0.20]
    
    w_max = max(abs(w) for w in weights)
    scale = w_max / ___          # 👉 INT8's largest positive value is 127
    
    # Quantize each weight to the nearest integer step, then back to float.
    q   = [round(w / scale) for w in weights]
    deq = [i * ___ for i in q]   # 👉 multiply by the SAME scale to recover floats
    
    mean_err = sum(abs(w - d)
    ...

    🎯 Your Turn #2: Prune the Small Weights

    Fill in the two blanks so magnitude pruning zeroes every weight below the threshold and counts the survivors. Half the weights should fall away here.

    Your Turn #2: Magnitude Pruning

    Zero the small weights, count what remains, and read the sparsity

    Python
    # 🎯 YOUR TURN #2 — magnitude-prune the small weights
    # Zero out every weight whose absolute value is below the threshold,
    # then count how many survive.
    
    weights = [0.9, -0.05, 0.4, 0.02, -0.7, 0.01, 0.6, -0.03]
    threshold = 0.1
    
    # Keep a weight if |weight| is at or above the threshold, else set it to 0.
    pruned = [w if abs(w) >= threshold else ___ for w in weights]   # 👉 prune small ones to 0
    
    # Count the survivors: the weights that are still non-zero.
    remaining = sum(1 for w in pruned if w != 
    ...

    Common Errors (And How to Fix Them)

    These four mistakes sink most first compression attempts:

    ❌ Too-aggressive quantization (the accuracy cliff)

    Jumping straight to INT4 on a sensitive task — maths reasoning, code generation — and watching quality fall off a cliff while text-classification benchmarks looked fine.

    ✅ Fix: step down gradually (FP16 → INT8 → INT4), use a smart scheme like GPTQ/AWQ for 4-bit, and benchmark on your actual task at each step.

    ❌ Unstructured pruning, expecting a speedup

    Pruning 90% of weights to zero and being baffled that inference is exactly as slow — the matrix is the same shape, just full of zeros, so a dense GPU does the same work.

    ✅ Fix: use structured pruning (remove whole channels/heads) for latency, or sparse-aware kernels/hardware if you must keep unstructured sparsity.

    ❌ No fine-tune after compressing

    Pruning or quantizing and shipping immediately, then losing several points of accuracy that were completely recoverable.

    ✅ Fix: fine-tune the compressed model for a few epochs (or use quantization-aware training) so the surviving weights adapt and claw the accuracy back.
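Quantization-aware training "simulates the rounding during training", and a common way to do that is a fake-quantize step with a straight-through estimator. The sketch below is a hand-rolled illustration of that idea, not PyTorch's built-in QAT workflow; the function name `fake_quantize` is made up for this example.

```python
import torch

def fake_quantize(w: torch.Tensor, num_bits: int = 8) -> torch.Tensor:
    """Round weights to the INT grid in the forward pass, keep them float.

    Straight-through estimator: the backward pass treats the rounding as
    the identity, so gradients still reach the underlying float weights.
    """
    q_max = 2 ** (num_bits - 1) - 1
    scale = w.detach().abs().max() / q_max
    w_q = torch.round(w / scale).clamp(-q_max, q_max) * scale
    return w + (w_q - w).detach()   # forward value: w_q, gradient: as if w

w = torch.tensor([0.12, -0.83, 0.55, -0.20], requires_grad=True)

w_fq = fake_quantize(w)             # what the layer would "see" in training
loss = (w_fq ** 2).sum()            # stand-in loss for illustration
loss.backward()

print("fake-quantized:", w_fq.detach().tolist())
print("gradient:      ", w.grad.tolist())
```

Because the loss is computed on the rounded weights, the model is pushed toward weights that still work after rounding, while the optimizer keeps updating the full-precision copy.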

    ❌ Trusting average accuracy, hitting an accuracy cliff in production

    A compressed model that looks fine on average but collapses on a specific slice (a rare class, long inputs) because you only checked an aggregate number.

    ✅ Fix: evaluate before vs after on the same held-out set, broken down by slice, and set an accuracy budget you refuse to cross.
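A per-slice check with an accuracy budget can be a few lines of bookkeeping. The sketch below is a minimal plain-Python version; the slice names, the result counts, and the 3-point budget are all hypothetical values chosen to show an aggregate number hiding a collapsed slice.

```python
from collections import defaultdict

def accuracy_by_slice(records):
    """records: iterable of (slice_name, is_correct) pairs."""
    hits, totals = defaultdict(int), defaultdict(int)
    for name, ok in records:
        totals[name] += 1
        hits[name] += int(ok)
    return {name: hits[name] / totals[name] for name in totals}

def overall_accuracy(records):
    return sum(int(ok) for _, ok in records) / len(records)

def slices_over_budget(before, after, budget):
    """Names of slices whose accuracy dropped by more than the budget."""
    return sorted(n for n in before if before[n] - after[n] > budget)

# Hypothetical results on the SAME held-out set, before and after compression.
baseline = ([("common", True)] * 190 + [("common", False)] * 10
            + [("rare", True)] * 8 + [("rare", False)] * 2)
compressed = ([("common", True)] * 188 + [("common", False)] * 12
              + [("rare", True)] * 4 + [("rare", False)] * 6)

budget = 0.03  # refuse any drop larger than 3 points
overall_drop = overall_accuracy(baseline) - overall_accuracy(compressed)
failing = slices_over_budget(
    accuracy_by_slice(baseline), accuracy_by_slice(compressed), budget
)

print(f"overall drop within budget: {overall_drop <= budget}")  # True
print(f"slices over budget: {failing}")                         # ['rare']
```

The aggregate drop passes the budget, yet the rare slice loses half its accuracy, which is exactly the failure an average-only check misses.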

    📋 Quick Reference

    | Technique | What it does | Typical compression | Speedup? |
    |---|---|---|---|
    | FP16 quantization | 16-bit floats instead of 32 | 2x | Yes |
    | INT8 quantization (PTQ) | FP32 → INT8 via a scale | 4x | Yes |
    | INT4 (GPTQ / AWQ) | 4-bit weights for LLMs | 8x | Yes |
    | Magnitude pruning | Zero small weights | ≈2x (file) | No (dense HW) |
    | Structured pruning | Drop whole channels/heads | 1.5–3x | Yes |
    | Knowledge distillation | Small student copies teacher | 2–10x | Yes |
    | Low-rank factorization | One matrix → two small ones | 2–10x/layer | Yes |

    ❓ Frequently Asked Questions

    Q: What is model compression?

    A: Model compression is a set of techniques that make a trained model smaller and faster to run while keeping its accuracy as close to the original as possible. The four main techniques are quantization (use fewer bits per number), pruning (delete weights that barely matter), knowledge distillation (train a small student model to copy a large teacher), and low-rank factorization (replace one big weight matrix with two small ones).

    Q: What is quantization and why does it shrink a model?

    A: Quantization stores each weight using fewer bits — for example mapping 32-bit floats (FP32) onto 8-bit integers (INT8). Since each number now takes 1 byte instead of 4, the model is 4x smaller, and integer maths runs faster on most hardware. You store a scale factor so you can convert the integers back to approximate floats at inference time.

    Q: What is the difference between magnitude pruning and structured pruning?

    A: Magnitude (unstructured) pruning zeroes out individual weights with the smallest absolute value, scattered anywhere in the matrix. It gives high sparsity but rarely speeds anything up, because the matrix is still the same shape with zeros in it. Structured pruning removes whole rows, channels, or attention heads, so the matrix actually gets smaller and standard hardware runs it faster.

    Q: What is knowledge distillation?

    A: Knowledge distillation trains a small student model to imitate a large teacher model. Instead of only learning the correct label, the student learns from the teacher's full probability distribution (the soft labels), which carries extra information about how classes relate. DistilBERT is a famous example: 40% smaller than BERT while keeping about 97% of its accuracy.

    Q: Why does my compressed model run at the same speed but lose accuracy?

    A: Those are two separate problems. The accuracy loss usually comes from compressing too aggressively: going straight to INT4 on a sensitive task, for example, crosses an accuracy cliff and quality collapses. The missing speedup usually comes from unstructured pruning: the weights are zero but the matrix shape is unchanged, so dense hardware does the same amount of work. Use structured pruning for real latency wins, and always fine-tune after compressing.

    Q: Do I need to retrain after compressing a model?

    A: Usually yes. A short fine-tune (or quantization-aware training) lets the remaining weights adjust to the loss caused by pruning or quantization, which recovers most of the dropped accuracy. Skipping the fine-tune is one of the most common reasons a compressed model underperforms.

    🎯 Mini-Challenge: Prune Then Quantize

    Now combine both skills with only a comment outline — no filled-in logic. Magnitude-prune a layer, then quantize the survivors to INT8 and print a tiny compression report.

    Mini-Challenge: Prune Then Quantize

    Prune small weights, count survivors, then quantize them to INT8

    Python
    # 🎯 MINI-CHALLENGE: a tiny compression report
    # 1. Start with weights = [0.7, -0.02, 0.5, 0.08, -0.9, 0.03, 0.6, -0.4]
    # 2. Magnitude-prune: set any weight with |w| < 0.1 to 0
    # 3. Count remaining non-zero weights and the sparsity %
    # 4. Quantize the SURVIVING weights to INT8:
    #       scale = max(abs(w) for non-zero w) / 127
    #       q = round(w / scale) for each surviving weight
    # 5. Print: remaining count, sparsity %, and the INT8 list
    #
    # ✅ Expected output (your numbers will match if the logi
    ...
    🎉

    Lesson 39 complete — you can shrink a model for deployment!

    You can quantize FP32 weights to INT8 with a scale factor, prune small weights with magnitude and structured pruning, distill a teacher into a small student, factorize a big matrix into two small ones, and weigh the accuracy/size/latency trade-off for any target device — while dodging the accuracy cliff, the no-speedup pruning trap, and skipping the fine-tune.

    🚀 Up next: Hardware Optimization — tune a compressed model for a specific chip with kernels, fusion, and accelerators.
