Chapter 4 · Part 3
Learning to denoise
Everything so far has been arithmetic — no learning required. Now we reach the one part the model actually has to learn, and it's a surprisingly narrow task. We don't ask it to paint, or to understand art, or even to recover the original image directly. We ask it one question:
"Here's a noisy image and the timestep it came from. What noise was added?"
That's it. The network predicts the noise ε̂, and because the noisy image is just a weighted mix of the clean image and the noise, predicting the noise is the same as knowing how to remove it.
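Here is a minimal PyTorch sketch of why knowing ε̂ means knowing how to undo the noise. It assumes the DDPM-style closed-form shortcut xₜ = √ᾱₜ·x₀ + √(1 − ᾱₜ)·ε, where ᾱₜ is the fraction of the original signal's variance left at timestep t. `add_noise` and `remove_noise` are hypothetical helpers, and the random tensor stands in for an image. Rearranging the formula for x₀ recovers the image exactly when the noise estimate is exact, and badly when the estimate is random, which is what an untrained network gives you.

```python
import math
import torch

def add_noise(x0: torch.Tensor, eps: torch.Tensor, alpha_bar: float) -> torch.Tensor:
    # Assumed DDPM-style closed form: x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps
    return math.sqrt(alpha_bar) * x0 + math.sqrt(1.0 - alpha_bar) * eps

def remove_noise(xt: torch.Tensor, eps_hat: torch.Tensor, alpha_bar: float) -> torch.Tensor:
    # The same formula solved for x0, using a noise estimate eps_hat
    return (xt - math.sqrt(1.0 - alpha_bar) * eps_hat) / math.sqrt(alpha_bar)

torch.manual_seed(0)
x0 = torch.rand(3, 8, 8)          # a stand-in "image" with values in [0, 1)
eps = torch.randn_like(x0)        # the noise the forward process adds
alpha_bar = 0.3                   # an arbitrary noise level for illustration

xt = add_noise(x0, eps, alpha_bar)
perfect = remove_noise(xt, eps, alpha_bar)                        # exact noise: exact image
random_guess = remove_noise(xt, torch.randn_like(x0), alpha_bar)  # untrained guess

print((perfect - x0).abs().max().item())       # tiny: float rounding only
print((random_guess - x0).abs().max().item())  # large: the image is wrecked
```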
[Interactive figure: training the network. Untrained, its predicted noise ε̂ is random, and subtracting it wrecks the image; as training proceeds, the guess approaches the real noise and the recovered image snaps into focus.]
Why predict the noise instead of the image?
It seems roundabout: why not just output the clean image? Two reasons. First, noise turns out to be an easier, more stable target to learn. Whatever the timestep, ε is a fresh draw from a standard Gaussian (zero mean, unit variance), so the target has the same scale at every step, while clean images vary wildly. Second, we already have the perfect label for free. The forward process from Chapter 3 handed us the exact ε it used, so every training example comes with its own answer key: no human labeling, no captions, just images.
Training is then a tight loop:
- Take a real image x₀ and pick a random timestep t.
- Use the closed-form shortcut to make xₜ, and remember the noise ε.
- Show the network xₜ and t; it outputs a guess ε̂.
- Nudge its weights to make ε̂ closer to ε.
Repeat millions of times.
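Steps one and two of the loop can be sketched as follows, as one possible definition of the `q_sample` helper used in the training code later in this section. The linear β schedule (1e-4 to 0.02 over T = 1000 steps) is the one from the original DDPM paper and is an assumption here, not necessarily what this series' visuals use. Each image gets its own timestep, so a single batch mixes light and heavy noise.

```python
import torch

T = 1000                                         # number of diffusion steps (assumed)
betas = torch.linspace(1e-4, 0.02, T)            # linear schedule from the DDPM paper (assumed)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)   # fraction of signal variance left after each step

def q_sample(x0: torch.Tensor, t: torch.Tensor):
    """Noise a batch x0 to per-example timesteps t; return (x_t, eps)."""
    eps = torch.randn_like(x0)                               # the noise the network must find
    a_bar = alpha_bars[t].view(-1, *([1] * (x0.dim() - 1)))  # reshape to broadcast over C, H, W
    xt = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * eps
    return xt, eps

torch.manual_seed(0)
x0 = torch.rand(4, 3, 32, 32) * 2 - 1            # stand-in batch scaled to [-1, 1]
t = torch.randint(0, T, (x0.shape[0],))          # step 1: a random timestep per image
xt, eps = q_sample(x0, t)                        # step 2: make x_t and keep eps as the label
print(t.tolist(), xt.shape)
```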
What's doing the predicting: a U-Net
The network here is usually a U-Net: an architecture that downsamples the image through successive layers to capture the big picture, then upsamples it back to full resolution, with skip connections that carry fine detail straight across from the downsampling path to the upsampling path. The timestep t is fed in too (typically as an embedding added to the intermediate features), so the network knows how noisy its input is supposed to be and can calibrate its guess accordingly.
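To make the shape of a U-Net concrete, here is a deliberately tiny PyTorch sketch: one downsampling step, one upsampling step, one skip connection, and the timestep injected as a sinusoidal embedding added to the first feature map. `TinyUNet`, `timestep_embedding`, and all layer sizes are hypothetical choices for illustration; real diffusion U-Nets stack many more levels, residual blocks, normalization and often attention.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def timestep_embedding(t: torch.Tensor, dim: int) -> torch.Tensor:
    # Sinusoidal embedding of integer timesteps (an assumed, commonly used choice)
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = t.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=1)

class TinyUNet(nn.Module):
    """A toy two-level U-Net: down, middle, up, with one skip connection."""

    def __init__(self, channels: int = 3, width: int = 32, t_dim: int = 64):
        super().__init__()
        self.t_dim = t_dim
        self.t_proj = nn.Linear(t_dim, width)                        # timestep -> per-channel shift
        self.inp = nn.Conv2d(channels, width, 3, padding=1)
        self.down = nn.Conv2d(width, width * 2, 3, stride=2, padding=1)          # halve H and W
        self.mid = nn.Conv2d(width * 2, width * 2, 3, padding=1)
        self.up = nn.ConvTranspose2d(width * 2, width, 4, stride=2, padding=1)   # double H and W
        self.out = nn.Conv2d(width * 2, channels, 3, padding=1)      # sees [upsampled, skip]

    def forward(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        temb = self.t_proj(timestep_embedding(t, self.t_dim))[:, :, None, None]
        h0 = F.silu(self.inp(xt) + temb)               # full-resolution features, told the noise level
        h1 = F.silu(self.mid(F.silu(self.down(h0))))   # coarse, big-picture features
        h2 = F.silu(self.up(h1))                       # back to the input resolution
        return self.out(torch.cat([h2, h0], dim=1))    # skip connection restores fine detail

model = TinyUNet()
xt = torch.randn(2, 3, 32, 32)
t = torch.tensor([10, 900])
eps_hat = model(xt, t)
print(eps_hat.shape)  # torch.Size([2, 3, 32, 32])
```

The one hard requirement is that the output has the same shape as xₜ, since it is a per-pixel noise estimate. In this toy version the input height and width must be even so the down and up steps line up.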
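import torch
import torch.nn.functional as F

# Assumes dataset, T, q_sample, unet and optimizer are defined elsewhere.
for x0 in dataset:                             # a batch of real images
    t = torch.randint(0, T, (x0.shape[0],))    # a random noise level for each image
    xt, eps = q_sample(x0, t)                  # noise it; remember the noise
    eps_hat = unet(xt, t)                      # the network's guess
    loss = F.mse_loss(eps_hat, eps)            # how wrong was it?
    optimizer.zero_grad()                      # clear gradients left over from the last step
    loss.backward()                            # compute which way is downhill
    optimizer.step()                           # nudge weights downhill
After enough of these steps, the U-Net becomes a general-purpose noise spotter: hand it any noisy image at any timestep and it points to the noise hiding inside.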
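For a loop that actually runs end to end, the sketch below shrinks everything: the "images" are 2-D points clustered near four corners, and a small MLP that sees xₜ together with t/T stands in for the U-Net. The data, network, schedule, batch size and learning rate are all assumptions chosen so it trains in seconds on a CPU; the loop body follows the same four steps listed earlier.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
T = 1000
betas = torch.linspace(1e-4, 0.02, T)            # assumed DDPM-style linear schedule
alpha_bars = torch.cumprod(1.0 - betas, dim=0)

def sample_data(n: int) -> torch.Tensor:
    # Toy "images": 2-D points clustered near the four corners (+-1, +-1)
    corners = torch.randint(0, 2, (n, 2)).float() * 2 - 1
    return corners + 0.05 * torch.randn(n, 2)

# A small MLP stands in for the U-Net; its input is x_t plus t scaled to [0, 1)
net = nn.Sequential(nn.Linear(3, 128), nn.SiLU(), nn.Linear(128, 128), nn.SiLU(), nn.Linear(128, 2))
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

losses = []
for step in range(2000):
    x0 = sample_data(256)                                    # a batch of real "images"
    t = torch.randint(0, T, (x0.shape[0],))                  # a random timestep per example
    eps = torch.randn_like(x0)                               # the noise, kept as the label
    a_bar = alpha_bars[t][:, None]
    xt = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * eps        # closed-form noising
    eps_hat = net(torch.cat([xt, (t.float() / T)[:, None]], dim=1))  # the network's guess
    loss = F.mse_loss(eps_hat, eps)                          # how wrong was it?
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first 100 steps: {sum(losses[:100]) / 100:.3f}, last 100 steps: {sum(losses[-100:]) / 100:.3f}")
```

Comparing the average loss over the first and last 100 steps shows ε̂ getting closer to ε. The loss does not reach zero, because at low noise levels the exact noise cannot be fully pinned down from xₜ alone.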
Where we're headed
We now have the missing piece — a trained network that can look at static and say "here's the noise." In the visual, we removed all of it in one go because we knew the answer. A real model isn't that confident: it removes just a little at a time, re-checks, and removes a little more. Do that starting from pure static and an image materializes. That loop — sampling — is next.