00 Gradient descent

Stat 406

Geoff Pleiss, Trevor Campbell

Last modified – 21 October 2024

\[ \DeclareMathOperator*{\argmin}{argmin} \DeclareMathOperator*{\argmax}{argmax} \DeclareMathOperator*{\minimize}{minimize} \DeclareMathOperator*{\maximize}{maximize} \DeclareMathOperator*{\find}{find} \DeclareMathOperator{\st}{subject\,\,to} \newcommand{\E}{E} \newcommand{\Expect}[1]{\E\left[ #1 \right]} \newcommand{\Var}[1]{\mathrm{Var}\left[ #1 \right]} \newcommand{\Cov}[2]{\mathrm{Cov}\left[#1,\ #2\right]} \newcommand{\given}{\ \vert\ } \newcommand{\X}{\mathbf{X}} \newcommand{\x}{\mathbf{x}} \newcommand{\y}{\mathbf{y}} \newcommand{\P}{\mathcal{P}} \newcommand{\R}{\mathbb{R}} \newcommand{\norm}[1]{\left\lVert #1 \right\rVert} \newcommand{\snorm}[1]{\lVert #1 \rVert} \newcommand{\tr}[1]{\mbox{tr}(#1)} \newcommand{\brt}{\widehat{\beta}^R_{s}} \newcommand{\brl}{\widehat{\beta}^R_{\lambda}} \newcommand{\bls}{\widehat{\beta}_{ols}} \newcommand{\blt}{\widehat{\beta}^L_{s}} \newcommand{\bll}{\widehat{\beta}^L_{\lambda}} \newcommand{\U}{\mathbf{U}} \newcommand{\D}{\mathbf{D}} \newcommand{\V}{\mathbf{V}} \]

Motivation: maximum likelihood estimation as optimization

By the principle of maximum likelihood, we have that

\[ \begin{align*} \hat \beta &= \argmax_{\beta} \prod_{i=1}^n \P(Y_i \mid X_i) \\ &= \argmin_{\beta} \sum_{i=1}^n -\log\P(Y_i \mid X_i) \end{align*} \]

Under the model we use for logistic regression… \[ \begin{gathered} \P(Y=1 \mid X=x) = h(\beta^\top x), \qquad \P(Y=0 \mid X=x) = h(-\beta^\top x), \\ h(z) = \tfrac{1}{1+e^{-z}} \end{gathered} \]

… we can’t simply find the argmin with algebra.

Gradient descent: the workhorse optimization algorithm

We’ll see “gradient descent” a few times:

  1. logistic regression (fitting it by maximum likelihood)
  2. gradient boosting
  3. neural networks

This seems like a good time to explain it.

So what is it and how does it work?

Very basic example

Suppose I want to minimize \(f(x)=(x-6)^2\) numerically.

I start at a point (say \(x_1=23\))

I want to “go” in the negative direction of the gradient.

The gradient (at \(x_1=23\)) is \(f'(23)=2(23-6)=34\).

Move the current value toward \(x_1 - 34\).

\(x_2 = x_1 - \gamma 34\), for \(\gamma\) small.

In general, \(x_{n+1} = x_n -\gamma f'(x_n)\).

niter <- 10
gam <- 0.1
x <- double(niter)
x[1] <- 23
grad <- function(x) 2 * (x - 6)
for (i in 2:niter) x[i] <- x[i - 1] - gam * grad(x[i - 1])
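
For this particular \(f\), the iterates in the loop above can be written in closed form, which may help explain why they approach the minimizer at 6. Subtracting 6 from both sides of the update gives \[ x_{n+1} - 6 = (1 - 2\gamma)(x_n - 6) \quad\Longrightarrow\quad x_n - 6 = (1-2\gamma)^{n-1}(x_1 - 6). \] With \(\gamma = 0.1\) and \(x_1 = 23\), the distance to the minimizer is \(17 \times 0.8^{n-1}\), so it shrinks by 20% per step and \(x_{10} \approx 8.28\).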

Why does this work?

Heuristic interpretation:

  • Gradient tells me the slope.

  • negative gradient points toward the minimum

  • go that way, but not too far (or we’ll miss it)

Why does this work?

More rigorous interpretation:

  • Taylor expansion \[ f(x) \approx f(x_0) + \nabla f(x_0)^{\top}(x-x_0) + \frac{1}{2}(x-x_0)^\top H(x_0) (x-x_0) \]

  • replace \(H\) with \(\gamma^{-1} I\)

  • minimize this quadratic approximation in \(x\): \[ 0\overset{\textrm{set}}{=}\nabla f(x_0) + \frac{1}{\gamma}(x-x_0) \Longrightarrow x = x_0 - \gamma \nabla f(x_0) \]
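
As a numerical sanity check of the last step, here is a minimal sketch that builds the quadratic surrogate for the earlier example \(f(x) = (x-6)^2\) at \(x_0 = 23\) with \(\gamma = 0.1\), and minimizes it with `optimize()`. The names `surrogate`, `x_num`, and `x_gd`, and the search interval, are illustrative choices and not part of the slides.

```r
f <- function(x) (x - 6)^2
grad <- function(x) 2 * (x - 6)
x0 <- 23
gam <- 0.1

# Taylor expansion at x0 with the Hessian replaced by (1 / gam) * I
surrogate <- function(x) f(x0) + grad(x0) * (x - x0) + (x - x0)^2 / (2 * gam)

# numerical minimizer of the surrogate (interval is an assumption)
x_num <- optimize(surrogate, interval = c(0, 40))$minimum

# the gradient descent update from the same point
x_gd <- x0 - gam * grad(x0)

c(numerical = x_num, gradient_step = x_gd)
```

The two numbers should agree up to the tolerance of `optimize()`.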

Visually

What \(\gamma\)? (more details than we have time for)

What to use for \(\gamma_n\)?

Fixed

  • Only works if \(\gamma\) is exactly right
  • Usually does not work

Decay on a schedule

\(\gamma_{n+1} = \frac{\gamma_n}{1+cn}\) or \(\gamma_{n} = \gamma_0 b^n\)
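
A minimal sketch of the two schedules, to see how quickly each one shrinks the step size. The values of `gamma0`, `c_decay`, and `b` are hypothetical and chosen only for illustration.

```r
niter <- 10
gamma0 <- 0.5   # hypothetical starting step size
c_decay <- 0.1  # hypothetical constant c
b <- 0.9        # hypothetical base b

# recursive schedule: gamma_{n+1} = gamma_n / (1 + c * n)
gamma_rec <- double(niter)
gamma_rec[1] <- gamma0
for (n in 1:(niter - 1)) gamma_rec[n + 1] <- gamma_rec[n] / (1 + c_decay * n)

# geometric schedule: gamma_n = gamma_0 * b^n
gamma_geo <- gamma0 * b^(1:niter)

cbind(n = 1:niter, gamma_rec, gamma_geo)
```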

Exact line search

  • Tells you exactly how far to go.
  • At each iteration \(n\), solve \(\gamma_n = \arg\min_{s \geq 0} f( x^{(n)} - s \nabla f(x^{(n)}))\)
  • Usually can’t solve this.
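
For a simple function the one-dimensional problem can be solved numerically. The sketch below does this with `optimize()` for \(f(x_1,x_2) = x_1^2 + 0.5x_2^2\). The search interval `c(0, 1)` is an assumption that happens to contain the best step for this \(f\); `phi` and `s_star` are illustrative names.

```r
f <- function(x) x[1]^2 + 0.5 * x[2]^2
grad <- function(x) c(2, 1) * x

x <- matrix(0, 40, 2)
x[1, ] <- c(1, 1)
for (k in 2:40) {
  g <- grad(x[k - 1, ])
  # objective along the negative gradient, as a function of the step size s
  phi <- function(s) f(x[k - 1, ] - s * g)
  # approximate line search over s in (0, 1)
  s_star <- optimize(phi, interval = c(0, 1))$minimum
  x[k, ] <- x[k - 1, ] - s_star * g
}
f(x[40, ])
```

Each iteration now costs an inner optimization, which is one reason this is rarely practical for expensive \(f\).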

\[ f(x_1,x_2) = x_1^2 + 0.5x_2^2\]

x <- matrix(0, 40, 2); x[1, ] <- c(1, 1)
grad <- function(x) c(2, 1) * x

\[ f(x_1,x_2) = x_1^2 + 0.5x_2^2\]

gamma <- .1
for (k in 2:40) x[k, ] <- x[k - 1, ] - gamma * grad(x[k - 1, ])

\[ f(x_1,x_2) = x_1^2 + 0.5x_2^2\]

gamma <- .9 # bigger gamma
for (k in 2:40) x[k, ] <- x[k - 1, ] - gamma * grad(x[k - 1, ])

\[ f(x_1,x_2) = x_1^2 + 0.5x_2^2\]

gamma <- .9 # big, but decrease it on schedule
for (k in 2:40) x[k, ] <- x[k - 1, ] - gamma * .9^k * grad(x[k - 1, ])

\[ f(x_1,x_2) = x_1^2 + 0.5x_2^2\]

gamma <- .5 # theoretically optimal
for (k in 2:40) x[k, ] <- x[k - 1, ] - gamma * grad(x[k - 1, ])
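
One way to read the fixed-step runs above: the gradient is \((2x_1,\ x_2)\), so each update rescales the two coordinates separately, \[ x_1^{(k)} = (1 - 2\gamma)\, x_1^{(k-1)}, \qquad x_2^{(k)} = (1 - \gamma)\, x_2^{(k-1)}. \] With \(\gamma = 0.1\) the factors are \(0.8\) and \(0.9\), so progress is slow. With \(\gamma = 0.9\) they are \(-0.8\) and \(0.1\), so \(x_1\) flips sign at every step and shrinks slowly. With \(\gamma = 0.5\) they are \(0\) and \(0.5\), so \(x_1\) is solved in a single step.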

When do we stop?

For \(\epsilon>0\), small

Check any / all of

  1. \(|f'(x)| < \epsilon\)
  2. \(|x^{(k)} - x^{(k-1)}| < \epsilon\)
  3. \(|f(x^{(k)}) - f(x^{(k-1)})| < \epsilon\)
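
A minimal sketch of a gradient descent loop that checks all three criteria, reusing the earlier example \(f(x) = (x-6)^2\). The function `gd()` and its arguments are hypothetical, written only to illustrate the checks.

```r
f <- function(x) (x - 6)^2
grad <- function(x) 2 * (x - 6)

gd <- function(x0, gam = 0.1, eps = 1e-6, maxit = 1000) {
  x_old <- x0
  for (k in seq_len(maxit)) { # cap the iterations as a safeguard
    x_new <- x_old - gam * grad(x_old)
    small_grad <- abs(grad(x_new)) < eps         # criterion 1
    small_step <- abs(x_new - x_old) < eps       # criterion 2
    small_drop <- abs(f(x_new) - f(x_old)) < eps # criterion 3
    if (small_grad || small_step || small_drop) break
    x_old <- x_new
  }
  list(
    x = x_new, iterations = k,
    stopped_by = c(grad = small_grad, step = small_step, drop = small_drop)
  )
}

res <- gd(23)
res
```

The criteria need not trigger at the same iteration, so `stopped_by` records which one fired first.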

Stochastic gradient descent (SGD)

If optimizing \(\argmin_\beta \sum_{i=1}^n -\log P_\beta(Y_i \mid X_i)\), then the derivative is also additive: \[ \sum_{i=1}^n \frac{\partial}{\partial \beta} \left[-\log P_\beta(Y_i \mid X_i) \right] \]

If \(n\) is really big, it may take a long time to compute this

So, just sample a subset of data \(\mathcal{M} \subset \{ (X_i, Y_i)\}_{i=1}^n\) and approximate: \[\sum_{i=1}^n \frac{\partial}{\partial \beta} \left[-\log P_\beta(Y_i \mid X_i) \right] \approx \frac{n}{\vert \mathcal M \vert}\sum_{i\in\mathcal{M}} \frac{\partial}{\partial \beta} \left[-\log P_\beta(Y_i \mid X_i) \right]\]

For SGD need:

  • the gradient estimates to be unbiased (are they?)
  • decaying step size \(\gamma\) (why?)
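
To get a feel for the first requirement, here is a small simulation sketch for scalar logistic regression. It compares the full-data gradient with the average of many rescaled minibatch gradients. The data-generating values (`beta_true`, `n`, `m`) and the helper names are assumptions made for illustration.

```r
set.seed(406)
n <- 1000
beta_true <- 2 # hypothetical value, used only to simulate data
x <- runif(n, -5, 5)
y <- rbinom(n, 1, 1 / (1 + exp(-beta_true * x)))

# per-observation derivative of -log P_beta(y_i | x_i), scalar beta
grad_i <- function(beta, x, y) -x * (y - 1 / (1 + exp(-beta * x)))

beta <- 0.5 # point at which the gradient is evaluated
full_grad <- sum(grad_i(beta, x, y))

# minibatch estimate, rescaled by n / |M| to target the full sum
m <- 50
minibatch_grad <- function() {
  idx <- sample(n, m) # uniform subset, without replacement
  (n / m) * sum(grad_i(beta, x[idx], y[idx]))
}
estimates <- replicate(5000, minibatch_grad())

c(full = full_grad, mean_est = mean(estimates), sd_est = sd(estimates))
```

The mean of the estimates should sit close to the full gradient, while any single estimate can be far from it. That noise is what motivates the decaying step size.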

SGD

\[ \begin{aligned} f'(\beta) &= \frac{1}{n}\sum_{i=1}^n f'_i(\beta) \approx \frac{1}{|\mathcal{M}_j|}\sum_{i\in\mathcal{M}_j}f'_{i}(\beta) \end{aligned} \]

Instead of drawing samples independently, better to:

  • Randomly order the whole dataset (\(n\) points), then iterate:
    • grab the next M points
    • compute a gradient estimate based on those points, take a step
    • once you exhaust all the data, that’s an “epoch”; start from the beginning again

Gradient estimates are still marginally unbiased (why?)

This is the workhorse for neural network optimization
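
The epoch scheme described above can be sketched for scalar logistic regression as follows. This is one possible implementation under stated assumptions, not the course's reference code: the function `sgd_logistic()`, the \(\gamma_0/\sqrt{t}\) step-size schedule, the batch size, and the simulated data are all illustrative choices.

```r
set.seed(406)
n <- 1000
beta_true <- 2 # hypothetical value, used only to simulate data
x <- runif(n, -5, 5)
y <- rbinom(n, 1, 1 / (1 + exp(-beta_true * x)))

sgd_logistic <- function(x, y, beta0 = 0, gamma0 = 2, m = 50, epochs = 20) {
  n <- length(x)
  beta <- beta0
  step <- 0
  for (e in seq_len(epochs)) {
    ord <- sample(n) # randomly order the whole dataset
    for (start in seq(1, n, by = m)) {
      idx <- ord[start:min(start + m - 1, n)] # grab the next m points
      px <- 1 / (1 + exp(-beta * x[idx]))
      g <- mean(-x[idx] * (y[idx] - px)) # averaged minibatch gradient
      step <- step + 1
      beta <- beta - gamma0 / sqrt(step) * g # decaying step size
    }
  }
  beta
}

beta_sgd <- sgd_logistic(x, y)
beta_glm <- unname(coef(glm(y ~ x - 1, family = "binomial")))
c(sgd = beta_sgd, glm = beta_glm)
```

The loop runs for a fixed number of epochs rather than testing a convergence criterion, and `glm()` is used only as a reference value.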

When do we stop SGD?

For \(\epsilon>0\), small

Can we check any / all of

  1. \(|f'(x)| < \epsilon\) ?
  2. \(|x^{(k)} - x^{(k-1)}| < \epsilon\) ?
  3. \(|f(x^{(k)}) - f(x^{(k-1)})| < \epsilon\) ?

None of these is reliable because of the stochasticity: each quantity is computed from a noisy minibatch estimate. Knowing when to terminate SGD is hard.

Practice with GD and Logistic regression

Gradient descent for Logistic regression

\[ \begin{gathered} \P(Y=1 \mid X=x) = h(\beta^\top x), \qquad \P(Y=0 \mid X=x) = h(-\beta^\top x), \\ \\ h(z) = \tfrac{1}{1+e^{-z}} \end{gathered} \]

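library(tidyverse) # tibble, ggplot2, purrr, dplyr are used below
n <- 100
beta <- 2 # true coefficient used to simulate the data
x <- runif(n, -5, 5)
# despite its name, this is the inverse logit h(z) = 1 / (1 + exp(-z))
logit <- function(x) 1 / (1 + exp(-x))
p <- logit(beta * x) # P(Y = 1 | X = x)
y <- rbinom(n, 1, p)
df <- tibble(x, y)
ggplot(df, aes(x, y)) +
  geom_point(colour = "cornflowerblue") +
  stat_function(fun = ~ logit(beta * .x))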

\[ \P(Y=1 \mid X=x) = h(\beta^\top x), \qquad \P(Y=0 \mid X=x) = h(-\beta^\top x) \]

Under maximum likelihood

\[ \hat \beta = \argmin_{\beta} \underbrace{ \textstyle{\sum_{i=1}^n - \log( \P_\beta(Y_i=y_i \mid X_i=x_i) )} }_{:= -\ell(\beta)} \]

\[ \begin{align*} \P_\beta(Y_i=y_i \mid X_i=x_i) &= h\left( [-1]^{1 - y_i} \beta^\top x_i \right) \\ \\ -\ell(\beta) &= \sum_{i=1}^n -\log\left( \P_\beta(Y_i=y_i \mid X_i=x_i) \right) \\ &= \sum_{i=1}^n \log\left( 1 + \exp\left( [-1]^{y_i} \beta^\top x_i \right) \right) \\ \\ -\frac{\partial \ell}{\partial \beta} &= \sum_{i=1}^n x_i[-1]^{y_i} \frac{\exp\left( [-1]^{y_i} \beta^\top x_i \right)}{1 + \exp\left( [-1]^{y_i} \beta^\top x_i \right)} \\ &= -\sum_{i=1}^n x_i \left( y_i - h(\beta^\top x_i) \right) \end{align*} \]
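
A quick way to gain confidence in a hand-derived gradient is to compare it with a finite-difference approximation. The sketch below does this for scalar \(\beta\) on simulated data. The simulation settings, the evaluation point `beta`, and the step `h` are assumptions, and `negll_sum()` and `neg_grad()` are illustrative names.

```r
set.seed(406)
n <- 100
x <- runif(n, -5, 5)
y <- rbinom(n, 1, 1 / (1 + exp(-2 * x))) # simulated with a hypothetical beta = 2

# -ell(beta) = sum_i log(1 + exp((-1)^{y_i} * beta * x_i))
negll_sum <- function(beta) sum(log(1 + exp((-1)^y * beta * x)))

# analytic derivative of -ell(beta) from the display above
neg_grad <- function(beta) {
  s <- (-1)^y
  sum(x * s * exp(s * beta * x) / (1 + exp(s * beta * x)))
}

beta <- 0.7
h <- 1e-6
# central finite difference
fd <- (negll_sum(beta + h) - negll_sum(beta - h)) / (2 * h)
# equivalent compact form: -sum x_i (y_i - h(beta x_i))
compact <- -sum(x * (y - 1 / (1 + exp(-beta * x))))

c(analytic = neg_grad(beta), finite_difference = fd, compact = compact)
```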

Finding \(\hat\beta = \argmin_{\beta} -\ell(\beta)\) with gradient descent:

  1. Input \(\beta_0,\ \gamma>0,\ \epsilon>0,\ \tfrac{d \ell}{d\beta}\).
  2. For \(j=1,\ 2,\ \ldots\), \[\beta_j = \beta_{j-1} - \gamma \tfrac{d}{d\beta}\left(-\!\ell(\beta_{j-1}) \right)\]
  3. Stop if \(|\beta_j - \beta_{j-1}| < \epsilon\) or \(|d\ell / d\beta\ | < \epsilon\).

beta.mle <- function(x, y, beta0, gamma = 0.5, jmax = 50, eps = 1e-6) {
  beta <- double(jmax) # place to hold stuff (always preallocate space)
  beta[1] <- beta0 # starting value
  for (j in 2:jmax) { # avoid possibly infinite while loops
    px <- logit(beta[j - 1] * x) # fitted P(Y = 1 | x) at the current beta
    # gradient of the negative log-likelihood, averaged over the n points
    # (mean rather than sum; this only rescales gamma by a factor of n)
    grad <- mean(-x * (y - px))
    beta[j] <- beta[j - 1] - gamma * grad
    if (abs(grad) < eps || abs(beta[j] - beta[j - 1]) < eps) break
  }
  beta[1:j] # drop the unused preallocated entries
}

Try it:

too_big <- beta.mle(x, y, beta0 = 5, gamma = 50)
too_small <- beta.mle(x, y, beta0 = 5, gamma = 1)
just_right <- beta.mle(x, y, beta0 = 5, gamma = 10)
negll <- function(beta) {
  -beta * mean(x * y) -
    rowMeans(log(1 / (1 + exp(outer(beta, x)))))
}
blah <- list_rbind(
  map(
    rlang::dots_list(
      too_big, too_small, just_right, .named = TRUE
    ), 
    as_tibble),
  names_to = "gamma"
) |> mutate(negll = negll(value))
ggplot(blah, aes(value, negll)) +
  geom_point(aes(colour = gamma)) +
  facet_wrap(~gamma, ncol = 1) +
  stat_function(fun = negll, xlim = c(-2.5, 5)) +
  scale_y_log10() + 
  xlab("beta") + 
  ylab("negative log likelihood") +
  geom_vline(xintercept = tail(just_right, 1)) +
  scale_colour_brewer(palette = "Set1") +
  theme(legend.position = "none")

Check vs. glm()

summary(glm(y ~ x - 1, family = "binomial"))

Call:
glm(formula = y ~ x - 1, family = "binomial")

Coefficients:
  Estimate Std. Error z value Pr(>|z|)    
x   1.9174     0.4785   4.008 6.13e-05 ***
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

(Dispersion parameter for binomial family taken to be 1)

    Null deviance: 138.629  on 100  degrees of freedom
Residual deviance:  32.335  on  99  degrees of freedom
AIC: 34.335

Number of Fisher Scoring iterations: 7