23 Neural nets - generalization

Stat 406

Geoff Pleiss, Trevor Campbell

Last modified – 18 November 2024

\[ \DeclareMathOperator*{\argmin}{argmin} \DeclareMathOperator*{\argmax}{argmax} \DeclareMathOperator*{\minimize}{minimize} \DeclareMathOperator*{\maximize}{maximize} \DeclareMathOperator*{\find}{find} \DeclareMathOperator{\st}{subject\,\,to} \newcommand{\E}{E} \newcommand{\Expect}[1]{\E\left[ #1 \right]} \newcommand{\Var}[1]{\mathrm{Var}\left[ #1 \right]} \newcommand{\Cov}[2]{\mathrm{Cov}\left[#1,\ #2\right]} \newcommand{\given}{\ \vert\ } \newcommand{\X}{\mathbf{X}} \newcommand{\x}{\mathbf{x}} \newcommand{\y}{\mathbf{y}} \newcommand{\P}{\mathcal{P}} \newcommand{\R}{\mathbb{R}} \newcommand{\norm}[1]{\left\lVert #1 \right\rVert} \newcommand{\snorm}[1]{\lVert #1 \rVert} \newcommand{\tr}[1]{\mbox{tr}(#1)} \newcommand{\brt}{\widehat{\beta}^R_{s}} \newcommand{\brl}{\widehat{\beta}^R_{\lambda}} \newcommand{\bls}{\widehat{\beta}_{ols}} \newcommand{\blt}{\widehat{\beta}^L_{s}} \newcommand{\bll}{\widehat{\beta}^L_{\lambda}} \newcommand{\U}{\mathbf{U}} \newcommand{\D}{\mathbf{D}} \newcommand{\V}{\mathbf{V}} \]

This lecture

  1. What factors affect generalization? (The ability to make accurate predictions)

  2. Why do NN generalize, despite having lots of parameters?

  3. Modern techniques to improve generalization

What factors affect generalization?

(The ability to make accurate predictions)

Tunable parameters of NN

  1. Number of hidden layers (\(L\))

  2. Width of hidden layers (\(D\))

  3. Nonlinearity function

  4. Loss function

  5. Initial SGD step size

  6. SGD step size decay rate

  7. SGD batch size

  8. SGD stopping criterion

  9. Amount of regularization (we’ll talk about this concept in a bit)

  10. Initialization of NN parameters

How to tune NN parameters

😓

  • There are exponentially many designs of NNs

  • Training a single NN is expensive

  • NN training depends on random initialization, so you generally have to do multiple runs

In Practice

  • Compare a handful of designs on a single holdout set (no cross val)

  • Principled NN architecture search is an active area of research

Some Common Patterns to Reduce the Search Space

  1. Use ReLU nonlinearities, and nothing else

  2. Use the same width for all layers (or grow the width with a simple formula)

  3. Measure loss on a validation set throughout training, and stop SGD when the validation loss plateaus (see the sketch after this list)

  4. Ask a grad student for their tricks
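
A minimal early-stopping sketch for pattern 3, using hypothetical helpers (init_params(), sgd_epoch(), and validation_loss() are placeholders, not functions from a real library):

# Early stopping: run SGD until the validation loss stops improving
# init_params(), sgd_epoch(), and validation_loss() are hypothetical placeholders
patience <- 5                     # tolerate 5 epochs without improvement
best_val <- Inf
stall <- 0
params <- init_params()
while (stall < patience) {
  params <- sgd_epoch(params)     # one SGD pass over the training data
  val <- validation_loss(params)  # loss on a held-out validation set
  if (val < best_val) {
    best_val <- val
    best_params <- params
    stall <- 0
  } else {
    stall <- stall + 1
  }
}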

Why do NN generalize…

… despite having tons of parameters?

Capacity vs Generalization

Consider a NN with ReLU nonlinearities \(g( \boldsymbol w^\top \boldsymbol z) = \max\{\boldsymbol w^\top \boldsymbol z, 0 \}\) with \(L\) hidden layers, each with \(D\) hidden activations.

Recall:

  • Number of piecewise-linear regions: \(O(D^L)\) (exponential!)

  • Number of parameters: \(O(D^2 L)\)
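
A quick sanity check on the parameter count, as a minimal sketch (assuming input dimension p, a scalar output, and bias terms; the numbers below are illustrative):

# Count MLP parameters: p inputs, L hidden layers of width D, scalar output, biases included
n_params <- function(p, D, L) {
  first  <- D * p + D             # input -> first hidden layer
  hidden <- (L - 1) * (D^2 + D)   # hidden -> hidden layers
  output <- D + 1                 # last hidden layer -> output
  first + hidden + output
}
n_params(p = 10, D = 100, L = 4)  # on the order of D^2 * L for wide networks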

This implies:

  • Our NN is capable of learning complicated functions (many piecewise-linear components)

  • But will it learn the right function from limited data?

Recall: Bias/Variance Tradeoff For Trees

  • Neural networks have lots of parameters ( \(O(D^2 L)\), which is typically \(> n\) )

  • In theory, we would expect similar bias/variance curves for neural networks as a function of # params

The Surprising Bias/Var Curves For NN (Double Descent)

  • NN risk (as a function of # params) experiences a “double descent” shape?!?!?!

  • Most modern NN have tons of parameters, and so they’re explained by the right side of the graph

The Surprising Bias/Var Curves For NN (Double Descent)

Image credit: Belkin et al. (2019)

  • Double descent is a newly discovered phenomenon (~2019)

  • Statisticians are still trying to understand why it occurs.
    There has been good progress since ~2020!

To Understand Double Descent: Study Basis Regression

The double descent phenomenon is not specific to neural networks.
We can observe it in basis regression (read: linear models!) as we increase the number of basis functions \(> n\):

library(splines)
library(tidyverse) # tibble, ggplot2, purrr, and tidyr are used below
# colour strings used in the plots below (assumed values; the course theme defines these elsewhere)
orange <- "#E69F00"
blue <- "#0072B2"
set.seed(20221102)
n <- 20
df <- tibble(
  x = seq(-1.5 * pi, 1.5 * pi, length.out = n),
  y = sin(x) + runif(n, -0.5, 0.5)
)
g <- ggplot(df, aes(x, y)) + geom_point() + stat_function(fun = sin) + ylim(c(-2, 2))
g

xn <- seq(-1.5 * pi, 1.5 * pi, length.out = 1000)
# Spline basis regression "by hand": pseudoinverse solution via the SVD
X <- bs(df$x, df = 20, intercept = TRUE)
Xn <- bs(xn, df = 20, intercept = TRUE)
S <- svd(X)
yhat <- Xn %*% S$v %*% diag(1 / S$d) %*% crossprod(S$u, df$y)
g + geom_line(data = tibble(x = xn, y = drop(yhat)), colour = orange) +
  ggtitle("20 basis functions (n=20)")

xn <- seq(-1.5 * pi, 1.5 * pi, length.out = 1000)
# Spline basis regression "by hand": pseudoinverse solution via the SVD
X <- bs(df$x, df = 40, intercept = TRUE)
Xn <- bs(xn, df = 40, intercept = TRUE)
S <- svd(X)
yhat <- Xn %*% S$v %*% diag(1 / S$d) %*% crossprod(S$u, df$y)
g + geom_line(data = tibble(x = xn, y = drop(yhat)), colour = orange) +
  ggtitle("40 basis functions (n=20)")

doffs <- 4:50
mse <- function(x, y) mean((x - y)^2)
get_errs <- function(doff) {
  X <- bs(df$x, df = doff, intercept = TRUE)
  Xn <- bs(xn, df = doff, intercept = TRUE)
  S <- svd(X)
  yh <- S$u %*% crossprod(S$u, df$y)                       # fitted values on the training set
  bhat <- S$v %*% diag(1 / S$d) %*% crossprod(S$u, df$y)   # (minimum-norm) coefficient estimate
  yhat <- Xn %*% bhat                                      # predictions on the fine grid
  nb <- sqrt(sum(bhat^2))                                  # norm of the coefficient estimate
  tibble(train = mse(df$y, yh), test = mse(yhat, sin(xn)), norm = nb)
}
errs <- map(doffs, get_errs) |>
  list_rbind() |>
  mutate(`degrees of freedom` = doffs) |>
  pivot_longer(train:test, values_to = "error")
ggplot(errs, aes(`degrees of freedom`, error, color = name)) +
  geom_line(linewidth = 2) +
  coord_cartesian(ylim = c(0, .12)) +
  scale_x_log10() +
  scale_colour_manual(values = c(blue, orange), name = "") +
  geom_vline(xintercept = 20)

  • The peak in test error occurs when # basis functions = n!

  • This is the point at which our basis regressor is able to perfectly fit the training data.

Understanding Double Descent (Hand-Wavy)

Let \(\boldsymbol Z \in \R^{n \times d}\) be the matrix of basis expansions for our \(n\) training points.

Basis regression is just OLS with the basis expansion \(\boldsymbol Z\): \[ \min_{\boldsymbol \beta} \left\Vert \boldsymbol Z \boldsymbol \beta - \boldsymbol y \right\Vert_2^2. \]

  • When \(d < n\), the regressor is underparameterized.
    I.e. there is no \(\boldsymbol \beta\) that perfectly explains our training responses given our basis-expanded training inputs.

  • When \(d = n\), there is a value of \(\boldsymbol \beta\) that fits our training data perfectly.
    I.e. \(\Vert \boldsymbol Z \boldsymbol \beta - \boldsymbol y \Vert = 0\).

    • We are fitting both the noise and the signal (leading to a high-variance predictor).

  • When \(d > n\), we can also fit the data (noise + signal) perfectly. 👋 However, having more features means that the noise gets “spread out” over all of the parameters. 👋

    • 👋 Since each parameter captures only “some” of the noise, our predictions are driven less by that noise. 👋

    • This explanation is overly simplified, and there is a lot more at play (one missing piece is sketched below).
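
One piece worth making explicit: when \(d > n\), the least-squares problem above has infinitely many solutions, so which \(\boldsymbol \beta\) do we actually get? The SVD-based code earlier computes the minimum-norm interpolating solution (gradient descent started from \(\boldsymbol \beta = 0\) finds the same one):

\[ \hat{\boldsymbol \beta}_{\text{min-norm}} = \argmin_{\boldsymbol \beta \::\: \boldsymbol Z \boldsymbol \beta = \boldsymbol y} \Vert \boldsymbol \beta \Vert_2 = \boldsymbol Z^\dagger \boldsymbol y = \boldsymbol V \boldsymbol D^{-1} \boldsymbol U^\top \boldsymbol y, \qquad \text{where } \boldsymbol Z = \boldsymbol U \boldsymbol D \boldsymbol V^\top. \]

This is the quantity whose size is tracked in the norm column of get_errs() above: among all interpolating solutions we take the smallest one, which is part of why fitting the noise perfectly is less damaging than the classical bias/variance picture suggests.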

Understanding Double Descent (Less Hand-Wavy)

(From Hastie et al., 2020)

  • \(\gamma = D / N\) (ratio of features / data)

  • \(\sigma^2 = \mathrm{Var}[Y \mid X]\) (observational noise variance)

  • When basis features are uncorrelated, we have (asymptotically)

\[ \begin{aligned} \mathrm{Bias}^2 &= \begin{cases} 0 & \gamma < 1 \text{ (underparam.)} \\ 1 - \tfrac{1}{\gamma} & \gamma \geq 1 \text{ (overparam.)} \end{cases} \\ & \\ \mathrm{Var} &= \begin{cases} \sigma^2 \tfrac{\gamma}{1 - \gamma} & \gamma < 1 \text{ (underparam.)} \\ \sigma^2 \tfrac{1}{\gamma - 1} & \gamma \geq 1 \text{ (overparam.)} \end{cases} \\ \end{aligned} \]
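
These formulas are easy to visualize. A minimal sketch (the noise level \(\sigma^2 = 0.25\) is an arbitrary choice for illustration; uses the tidyverse functions loaded earlier):

# Plot the asymptotic bias^2, variance, and their sum from the formulas above
sigma2 <- 0.25                              # arbitrary noise level for illustration
curves <- tibble(gamma = seq(0.1, 5, length.out = 500)) |>
  mutate(
    bias2 = if_else(gamma < 1, 0, 1 - 1 / gamma),
    variance = if_else(gamma < 1, sigma2 * gamma / (1 - gamma), sigma2 / (gamma - 1)),
    risk = bias2 + variance
  ) |>
  pivot_longer(bias2:risk)
ggplot(curves, aes(gamma, value, colour = name)) +
  geom_line(linewidth = 1) +
  coord_cartesian(ylim = c(0, 2)) +         # the variance blows up near gamma = 1
  geom_vline(xintercept = 1, linetype = "dashed")

The spike at \(\gamma = 1\) and the second descent for \(\gamma > 1\) mirror the basis-regression experiment above.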

Do we need to worry about variance?

Regularizing a neural network (adding a complexity penalty to the loss) is a common practice to prevent overfitting to the noise.

\[ \argmin_{\{\boldsymbol W^{(t)}\}_{t=1}^L,\ \boldsymbol \beta} \sum_{i=1}^n \ell\left(y_i, \hat f_\mathrm{NN}(\boldsymbol x_i)\right) \: + \: \text{complexity penalty} \]

E.g. weight decay / L2 regularization:

\[ \text{complexity penalty} = \frac{\lambda}{2} \left( \Vert \boldsymbol \beta \Vert_2^2 + \sum_{t=1}^L \Vert \mathrm{vec} (\boldsymbol W^{(t)}) \Vert_2^2 \right) \]
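
As a concrete sketch of this penalty (W_list and beta below are hypothetical containers for the NN parameters, not objects from a real library):

# Weight decay / L2 penalty: lambda/2 times the sum of squared entries of all parameters
# W_list (list of hidden-layer weight matrices) and beta (output weights) are hypothetical inputs
l2_penalty <- function(W_list, beta, lambda) {
  frob2 <- vapply(W_list, function(W) sum(W^2), numeric(1))  # ||vec(W^(t))||_2^2 for each layer
  lambda / 2 * (sum(beta^2) + sum(frob2))
}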

  • \(\lambda\) is a tuning parameter

  • What does weight decay / L2 regularization remind you of? Think about linear models

Do we need to worry about variance?

\[ \text{complexity penalty} = \frac{\lambda}{2} \left( \Vert \boldsymbol \beta \Vert_2^2 + \sum_{t=1}^L \Vert \mathrm{vec} (\boldsymbol W^{(t)}) \Vert_2^2 \right) \]

  • Before we understood double descent, we used to think you needed high \(\lambda\) (lots of regularization) to combat high variance

    • People invented many other regularizers (e.g. dropout, pruning, mixup, etc.)

  • Now that we understand double descent (and realize we don’t have a variance problem), it’s uncommon to do anything more than light weight decay (small \(\lambda\))

Modern Techniques to Improve Generalization

Specialty architectures

So far we’ve studied neural networks where we (recursively) construct basis functions from “building blocks” of the form: \[ \boldsymbol a^{(t)}_j = g( \boldsymbol w^{(t)\top}_j \boldsymbol a^{(t - 1)}) \]
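
A minimal sketch of this recursion in R (the dimensions and random weights below are placeholders, purely for illustration):

# Forward pass of an MLP with ReLU nonlinearities
relu <- function(z) pmax(z, 0)
mlp_forward <- function(x, W_list, beta) {
  a <- x
  for (W in W_list) a <- relu(W %*% a)  # a^(t) = g(W^(t) a^(t-1))
  drop(crossprod(beta, a))              # linear output layer
}
set.seed(1)
p <- 2; D <- 4; L <- 3                  # input dimension, width, depth (illustrative values)
W_list <- c(                            # random placeholder weights, for illustration only
  list(matrix(rnorm(D * p), D, p)),
  replicate(L - 1, matrix(rnorm(D * D), D, D), simplify = FALSE)
)
beta <- rnorm(D)
mlp_forward(c(0.5, -1), W_list, beta)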

  • These neural networks are known as multilayer perceptrons (MLP).

  • By using different building blocks, we can make neural networks that are better suited to different types of data. E.g.:

    1. Convolutional NN (good for image data)

    2. Graph NN (good for molecules, social networks, etc.)

    3. Transformers (good for language and sequential data)

Specialty architectures: convolutional neural networks

Rather than computing an inner product with the hidden layer parameters (i.e. \(\boldsymbol w^{(t)\top}_j \boldsymbol a^{(t - 1)}\)), we instead perform a convolution:

\[ \boldsymbol a^{(t)}_j = g( \boldsymbol w^{(t)}_j \star \boldsymbol a^{(t - 1)}) \]
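
A minimal 1-D sketch of this building block (in practice the "convolution" is a sliding inner product with a small filter w that is shared across positions; the filter values below are arbitrary):

# 1-D convolutional building block: slide a small filter over the previous layer's activations
relu <- function(z) pmax(z, 0)
conv1d <- function(a, w) {
  k <- length(w)
  positions <- seq_len(length(a) - k + 1)
  relu(vapply(positions, function(i) sum(w * a[i:(i + k - 1)]), numeric(1)))
}
conv1d(a = sin(seq(0, 2 * pi, length.out = 10)), w = c(0.25, 0.5, 0.25))

Because the same small filter is reused at every position, this layer has far fewer parameters than a fully connected layer of the same size, and it naturally picks up local spatial structure.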

  • Captures spatial correlations amongst neighbouring pixels

  • Predictions remain constant even if we translate objects in the image

The convolutional building blocks are usually combined with other building blocks, like pooling layers and normalization layers.

Specialty architectures: convolutional neural networks

Why is a convolutional neural network better for images?

Image credit: Varsha Kishore

  • With a standard MLP, we’d need to “flatten” our image into a vector of pixels. This flattening doesn’t preserve spatial correlations amongst pixels.

  • If the dog in our image shifts, then we are not guaranteed to make the same prediction (we are not translation invariant).

Transfer Learning

You want to build an image classifier for CT scans, but you only have \(n=1000\) 😢
Conventional wisdom would tell you that you don’t have enough data to train a neural network.

Transfer learning to the rescue!

  • Start with an existing neural network trained on a related predictive task

  • Train this neural network on your data using gradient descent with a small step size
    Also known as fine-tuning

Transfer Learning

Why Does This Work?

  • The original NN has learned basis functions that are REALLY good for image data

  • You are now essentially using these good basis functions on your smaller dataset

Final Thoughts

  • Not much theory for why NN work (though this is increasing)

  • NN are best for unstructured data types (e.g. images, text, etc.)

    • Best when combined with a specialty architecture (e.g. convolutional NN)

    • If you have “tabular” data, use another algorithm (e.g. random forest)

  • Transfer learning is now the de facto approach

    • Try not to train NN from scratch

    • Makes NN work for small datasets

  • NN are computationally expensive

    • They won’t run on your laptop

    • You need a GPU cluster

  • NN are amazing, but they’re not always the right solution. What are some other downsides?

Final Thoughts

  • If you want to play around with NN, learn Python

    • There’s an example on the website of how to train NN in R. It’s gnarly.

    • Use the PyTorch library

  • There’s a wide world of NN to learn about!

Next time…

Module 5

unsupervised learning