Gradient Descent in HMMs

In this tutorial we explore two ways to use gradient descent when fitting HMMs:

  1. Fitting parameters of an observation model that do not have closed-form updates (e.g., GLMs or neural networks), inside the EM algorithm.
  2. Fitting the entire HMM with gradient-based optimization by leveraging automatic differentiation.

We will explore both approaches below.

using ADTypes
using ComponentArrays
using DensityInterface
using ForwardDiff
using HiddenMarkovModels
using LinearAlgebra
using Optim
using Random
using StableRNGs
using StatsAPI

rng = StableRNG(42)

For both parts of this tutorial we use a simple HMM with Gaussian observations. Gaussian emissions have closed-form M-step updates (weighted means and variances), so gradient-based optimization is overkill here. It keeps the tutorial simple while still illustrating the relevant methods.

We begin by defining a Normal observation model.

mutable struct NormalModel{T}
    μ::T
    logσ::T  # unconstrained parameterization; σ = exp(logσ)
end

model_mean(mod::NormalModel) = mod.μ
stddev(mod::NormalModel) = exp(mod.logσ)

We have defined a simple probability model with two parameters: the mean and the log of the standard deviation. Using logσ is intentional so we can optimize over all real numbers without worrying about the positivity constraint on σ.
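To see why the log-parameterization is convenient, here is a minimal sketch in plain Julia (no packages) of the Gaussian log-density written in terms of logσ, together with its derivative with respect to logσ. Differentiating -log σ - z²/2, where z = (x - μ)/σ and σ = exp(logσ), gives -1 + z². This derivative is defined for every real logσ. The helper names normal_logpdf, dlogpdf_dlogσ and fd are hypothetical and only mirror the model above.

```julia
# Gaussian log-density parameterized by (μ, logσ); mirrors logdensityof above.
normal_logpdf(μ, logσ, x) = -log(2π) / 2 - logσ - ((x - μ) / exp(logσ))^2 / 2

# Analytic derivative with respect to logσ: -1 + z^2 with z = (x - μ) / σ.
function dlogpdf_dlogσ(μ, logσ, x)
    z = (x - μ) / exp(logσ)
    return -1 + z^2
end

# Central finite difference, used as an independent check.
fd(f, t; h=1e-6) = (f(t + h) - f(t - h)) / (2h)

μ, x = 0.5, 1.7
for logσ in (-3.0, 0.0, 2.0)   # every real logσ maps to a valid σ = exp(logσ) > 0
    analytic = dlogpdf_dlogσ(μ, logσ, x)
    numeric = fd(t -> normal_logpdf(μ, t, x), logσ)
    println("logσ = $logσ: analytic = $analytic, finite difference = $numeric")
end
```

At z = ±1 the derivative is zero. This is the familiar condition that the maximum-likelihood σ equals the root-mean-square deviation.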

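Next, we implement the interface that HiddenMarkovModels.jl expects from an observation distribution. It consists of logdensityof (together with the DensityKind trait from DensityInterface.jl), rand, and fit!. We define the first two here and define fit! after setting up the priors it uses.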

function DensityInterface.logdensityof(mod::NormalModel, obs::T) where {T<:Real}
    s = stddev(mod)
    return -log(2π) / 2 - log(s) - ((obs - model_mean(mod)) / s)^2 / 2
end

DensityInterface.DensityKind(::NormalModel) = DensityInterface.HasDensity()

function Random.rand(rng::AbstractRNG, mod::NormalModel{T}) where {T}
    return stddev(mod) * randn(rng, T) + model_mean(mod)
end
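A quick sanity check of the two methods just defined can catch sign or scaling mistakes early. The exponentiated log-density should integrate to 1, and samples from rand should have mean μ and standard deviation σ. The sketch below runs both checks in plain Julia. It uses hypothetical stand-ins normal_logpdf and normal_rand that copy the formulas above, so it does not depend on DensityInterface.

```julia
using Random

# Hypothetical stand-ins that mirror logdensityof and rand for NormalModel.
normal_logpdf(μ, logσ, x) = -log(2π) / 2 - logσ - ((x - μ) / exp(logσ))^2 / 2
normal_rand(rng, μ, logσ) = exp(logσ) * randn(rng) + μ

μ, logσ = 3.0, log(0.75)

# 1) exp(log-density) should integrate to ≈ 1 (midpoint rule on μ ± 10).
h = 1e-3
grid = (μ - 10 + h / 2):h:(μ + 10)
mass = sum(x -> exp(normal_logpdf(μ, logσ, x)), grid) * h

# 2) Sample moments should be close to μ and σ = exp(logσ).
rng = MersenneTwister(1)
xs = [normal_rand(rng, μ, logσ) for _ in 1:100_000]
m = sum(xs) / length(xs)
s = sqrt(sum(abs2, xs .- m) / (length(xs) - 1))
println((mass=mass, sample_mean=m, sample_std=s))
```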

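Because we are fitting Gaussians, a state's variance can collapse toward zero (for example, when a state ends up explaining only a few nearly identical observations). To prevent this, we add priors that regularize the parameters. We use:

  • A weak Normal prior on μ (mean 0, standard deviation 10)
  • A moderate-strength Normal prior on logσ (mean 0, standard deviation 0.5) that pulls σ toward 1
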
const μ_prior = NormalModel(0.0, log(10.0))
const logσ_prior = NormalModel(log(1.0), log(0.5))
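As a rough guide to what these priors mean, the sketch below turns the central 95% interval of each prior (mean ± 1.96 standard deviations) into the implied range for μ and for σ = exp(logσ). Recall that the second argument of NormalModel is a log standard deviation. So μ_prior has standard deviation 10 and logσ_prior has standard deviation 0.5. The code only does the arithmetic with plain numbers, and interval95 is a hypothetical helper.

```julia
# Central 95% interval of a Normal(m, s): m ± 1.96 s.
interval95(m, s) = (m - 1.96 * s, m + 1.96 * s)

# μ_prior = NormalModel(0.0, log(10.0)) is Normal(0, 10) on μ.
μ_lo, μ_hi = interval95(0.0, exp(log(10.0)))

# logσ_prior = NormalModel(log(1.0), log(0.5)) is Normal(0, 0.5) on logσ.
logσ_lo, logσ_hi = interval95(log(1.0), exp(log(0.5)))

# exp is monotone, so the interval endpoints map directly to σ.
σ_lo, σ_hi = exp(logσ_lo), exp(logσ_hi)

println("μ in [$μ_lo, $μ_hi]; σ in [$σ_lo, $σ_hi]")
```

The implied range for σ is roughly [0.38, 2.66]. The true σ = 0.25 of the first state lies outside it. Even so, the Baum–Welch estimate below for that state (logσ ≈ -1.373, against log(0.25) ≈ -1.386) shows the data outweighing the prior, with only a slight pull toward 0.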

function neglogpost(
    μ::T,
    logσ::T,
    data::AbstractVector{<:Real},
    weights::AbstractVector{<:Real},
    μ_prior::NormalModel,
    logσ_prior::NormalModel,
) where {T<:Real}
    tmp = NormalModel(μ, logσ)

    nll = mapreduce(
        i -> -weights[i] * logdensityof(tmp, data[i]), +, eachindex(data, weights)
    )

    nll += -logdensityof(μ_prior, μ)
    nll += -logdensityof(logσ_prior, logσ)

    return nll
end

function neglogpost(
    θ::AbstractVector{T},
    data::AbstractVector{<:Real},
    weights::AbstractVector{<:Real},
    μ_prior::NormalModel,
    logσ_prior::NormalModel,
) where {T<:Real}
    μ, logσ = θ
    return neglogpost(μ, logσ, data, weights, μ_prior, logσ_prior)
end
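Written out, with σ = exp(logσ) and wᵢ the per-observation weights passed to fit!, neglogpost computes the weighted negative log-likelihood plus the negative log prior densities of μ and logσ. Each log-density term has the same form as logdensityof:

```math
\mathcal{L}(\mu, \log\sigma)
  = -\sum_{i=1}^{n} w_i \log \mathcal{N}\!\left(x_i \mid \mu, \sigma^2\right)
    - \log \mathcal{N}\!\left(\mu \mid 0, 10^2\right)
    - \log \mathcal{N}\!\left(\log\sigma \mid 0, 0.5^2\right),
\qquad \sigma = e^{\log\sigma},

\log \mathcal{N}\!\left(x \mid m, s^2\right)
  = -\tfrac{1}{2}\log(2\pi) - \log s - \tfrac{1}{2}\left(\frac{x - m}{s}\right)^2 .
```

The prior terms are not weighted, so their influence shrinks as the total weight Σ wᵢ assigned to a state grows.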

function StatsAPI.fit!(
    mod::NormalModel, data::AbstractVector{<:Real}, weights::AbstractVector{<:Real}
)
    T = promote_type(typeof(mod.μ), typeof(mod.logσ))
    θ0 = T[T(mod.μ), T(mod.logσ)]
    obj = θ -> neglogpost(θ, data, weights, μ_prior, logσ_prior)
    result = Optim.optimize(obj, θ0, BFGS(); autodiff=AutoForwardDiff())
    mod.μ, mod.logσ = Optim.minimizer(result)
    return mod
end
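Without priors, the Gaussian has closed-form weighted maximum-likelihood estimates. The mean is Σ wᵢxᵢ / Σ wᵢ and the variance is Σ wᵢ(xᵢ - mean)² / Σ wᵢ. This tells us what a gradient-based M-step should converge to. The sketch below is a simplified stand-in for fit!. It uses plain gradient descent with a hand-derived gradient instead of BFGS with ForwardDiff, and it drops the priors, so it should land on the closed-form answer. The data, weights, step size and the helper names grad and gd_fit are made up for illustration.

```julia
# Hypothetical weighted data, standing in for obs_seq and one state's weights.
xs = [1.0, 2.0, 4.0, 7.0]
ws = [0.1, 0.4, 0.3, 0.2]

# Closed-form weighted MLE (no priors).
W = sum(ws)
μ_closed = sum(ws .* xs) / W
σ_closed = sqrt(sum(ws .* (xs .- μ_closed) .^ 2) / W)

# Gradient of the weighted NLL, divided by the total weight, w.r.t. (μ, logσ).
function grad(μ, logσ, xs, ws)
    σ = exp(logσ)
    z = (xs .- μ) ./ σ
    Wt = sum(ws)
    gμ = -sum(ws .* z) / (σ * Wt)           # d/dμ
    glogσ = sum(ws .* (1 .- z .^ 2)) / Wt   # d/dlogσ
    return gμ, glogσ
end

# Plain gradient descent from an arbitrary starting point.
function gd_fit(xs, ws; μ=0.0, logσ=0.0, η=0.1, iters=10_000)
    for _ in 1:iters
        gμ, glogσ = grad(μ, logσ, xs, ws)
        μ -= η * gμ
        logσ -= η * glogσ
    end
    return μ, exp(logσ)
end

μ_gd, σ_gd = gd_fit(xs, ws)
println((closed_form=(μ_closed, σ_closed), gradient_descent=(μ_gd, σ_gd)))
```

Adding the priors shifts the optimum slightly away from these values. This is why fit! above needs a numerical optimizer rather than the closed form.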

Now that we have fully defined our observation model, we can create an HMM using it.

init_dist = [0.2, 0.7, 0.1]
init_trans = [
    0.9 0.05 0.05
    0.075 0.9 0.025
    0.1 0.1 0.8
]

obs_dists = [
    NormalModel(-3.0, log(0.25)), NormalModel(0.0, log(0.5)), NormalModel(3.0, log(0.75))
]

hmm_true = HMM(init_dist, init_trans, obs_dists)
Hidden Markov Model with:
 - initialization: [0.2, 0.7, 0.1]
 - transition matrix: [0.9 0.05 0.05; 0.075 0.9 0.025; 0.1 0.1 0.8]
 - observation distributions: [Main.NormalModel{Float64}(-3.0, -1.3862943611198906), Main.NormalModel{Float64}(0.0, -0.6931471805599453), Main.NormalModel{Float64}(3.0, -0.2876820724517809)]

We can now generate data from this HMM. Here rand(rng, hmm, T) returns a named tuple (state_seq, obs_seq) containing T hidden states and T observations.

state_seq, obs_seq = rand(rng, hmm_true, 10_000)
(state_seq = [2, 2, 2, 2, 2, 2, 2, 2, 2, 2  …  2, 1, 1, 1, 1, 1, 1, 1, 3, 3], obs_seq = [0.18676222545978152, -0.010693241131214582, -0.2784822510041131, 0.08762810726537032, -1.0844753758113237, 0.16158875227330957, 0.23966610660195012, 0.3420226663030283, 0.5216136091863335, 0.36307062133252843  …  -0.5352268156888037, -3.4262027995877578, -2.715359345491884, -3.39569049903953, -2.9847698058919203, -2.6313341019638505, -2.6252349000380084, -2.8993960061998374, 2.814659049681727, 2.9787243151708425])
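Drawing from an HMM is usually done by ancestral sampling. First draw the initial state from init, then each next state from the row of the transition matrix indexed by the current state, and finally each observation from the current state's emission distribution. The sketch below implements this in plain Julia with the same parameters as hmm_true. It is an illustration of the technique, not the package's rand. sample_categorical and sample_hmm are hypothetical helpers.

```julia
using Random

# Draw an index from a discrete distribution p (entries sum to 1).
function sample_categorical(rng, p)
    u = rand(rng)
    c = 0.0
    for k in eachindex(p)
        c += p[k]
        u <= c && return k
    end
    return lastindex(p)  # guard against round-off in the cumulative sum
end

# Ancestral sampling: z₁ ~ init, zₜ ~ trans[zₜ₋₁, :], xₜ ~ Normal(μ[zₜ], σ[zₜ]).
function sample_hmm(rng, init, trans, μ, σ, T)
    state_seq = Vector{Int}(undef, T)
    obs_seq = Vector{Float64}(undef, T)
    state_seq[1] = sample_categorical(rng, init)
    for t in 2:T
        state_seq[t] = sample_categorical(rng, view(trans, state_seq[t - 1], :))
    end
    for t in 1:T
        k = state_seq[t]
        obs_seq[t] = μ[k] + σ[k] * randn(rng)
    end
    return (; state_seq, obs_seq)
end

init = [0.2, 0.7, 0.1]
trans = [0.9 0.05 0.05; 0.075 0.9 0.025; 0.1 0.1 0.8]
μ = [-3.0, 0.0, 3.0]
σ = [0.25, 0.5, 0.75]
state_seq, obs_seq = sample_hmm(MersenneTwister(0), init, trans, μ, σ, 10_000)
```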

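Next, we fit a new HMM to these data, starting from a rough initial guess. Baum–Welch alternates two steps. The E-step computes the posterior probability of each hidden state at every time step. The M-step updates the HMM parameters. During the M-step, each state's observation model is updated by calling our fit! with that state's posterior probabilities as weights, so its parameters are fit via gradient-based optimization (BFGS).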

init_dist_guess = fill(1.0 / 3, 3)
init_trans_guess = [
    0.98 0.01 0.01
    0.01 0.98 0.01
    0.01 0.01 0.98
]

obs_dist_guess = [
    NormalModel(-2.0, log(1.0)), NormalModel(2.0, log(1.0)), NormalModel(0.0, log(1.0))
]

hmm_guess = HMM(init_dist_guess, init_trans_guess, obs_dist_guess)

hmm_est, lls = baum_welch(hmm_guess, obs_seq)
(Hidden Markov Model with:
 - initialization: [9.488970316129715e-219, 2.084310358336173e-32, 1.0]
 - transition matrix: [0.9009541783953148 0.0523454731235879 0.04670034848109712; 0.0922091349544837 0.8101900503265297 0.09760081471898664; 0.07177949170293166 0.022646527730598564 0.9055739805664698]
 - observation distributions: [Main.NormalModel{Float64}(-2.998739990936056, -1.3730911369240926), Main.NormalModel{Float64}(3.031999162184866, -0.3158190178345715), Main.NormalModel{Float64}(0.0056615595968454386, -0.701341692965898)], [-17992.376536779982, -11827.462331361017, -9046.465817561802, -9029.354676316934, -9029.22729506789, -9029.224714316486, -9029.224605808902, -9029.224596144144])

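Great! Using BFGS inside each M-step, we recovered the true parameters up to a relabeling of the states. Estimated state 2 (μ ≈ 3.03) corresponds to true state 3, and estimated state 3 (μ ≈ 0.006) corresponds to true state 2. The estimated initial distribution is nearly a point mass because a single sequence contains only one draw of the initial state. The log-likelihood trace lls rises quickly and levels off at about -9029.22.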
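The weights that reach fit! come from the E-step. For every time step t, it computes γₜ(k) = P(zₜ = k | x₁, …, x_T), the posterior probability of state k. The sketch below computes these posteriors in plain Julia with a scaled forward-backward pass (Rabiner-style normalization) on a tiny made-up two-state example. It illustrates the standard recursions and is not HiddenMarkovModels.jl's implementation. normal_pdf and forward_backward are hypothetical helpers.

```julia
using LinearAlgebra

# Gaussian density (not log), used for the emission likelihoods.
normal_pdf(x, μ, σ) = exp(-((x - μ) / σ)^2 / 2) / (σ * sqrt(2π))

# Scaled forward-backward. Returns γ with γ[k, t] = P(z_t = k | x_1:T)
# and the log-likelihood log p(x_1:T).
function forward_backward(init, trans, μ, σ, xs)
    K, T = length(init), length(xs)
    B = [normal_pdf(xs[t], μ[k], σ[k]) for k in 1:K, t in 1:T]  # K × T emission likelihoods
    α = zeros(K, T)   # forward variables, normalized at each step
    β = ones(K, T)    # backward variables, scaled by the same constants
    c = zeros(T)      # per-step normalizers; sum(log, c) is the log-likelihood
    α[:, 1] = init .* B[:, 1]
    c[1] = sum(α[:, 1])
    α[:, 1] ./= c[1]
    for t in 2:T
        α[:, t] = (trans' * α[:, t - 1]) .* B[:, t]
        c[t] = sum(α[:, t])
        α[:, t] ./= c[t]
    end
    for t in (T - 1):-1:1
        β[:, t] = trans * (B[:, t + 1] .* β[:, t + 1]) ./ c[t + 1]
    end
    return α .* β, sum(log, c)
end

init = [0.5, 0.5]
trans = [0.9 0.1; 0.2 0.8]
μ = [-1.0, 1.0]
σ = [0.5, 0.5]
xs = [-1.2, -0.8, 0.1, 1.1, 0.9]

γ, loglik = forward_backward(init, trans, μ, σ, xs)
w_state1 = γ[1, :]   # per-observation weights for state 1's M-step
```

Each column of γ sums to 1. A state's row of γ plays the role of the weights argument in fit!.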

Now we will fit the entire HMM using gradient-based optimization by leveraging automatic differentiation. The key idea is that the forward algorithm marginalizes out the latent states, providing the likelihood of the observations directly as a function of all model parameters.

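We can therefore minimize the negative log-likelihood returned by forward. Each objective evaluation runs the forward algorithm, whose cost grows as O(TK²) for T observations and K states. Forward-mode differentiation with ForwardDiff adds work that grows with the number of parameters. This can be expensive for large datasets, but it allows end-to-end gradient-based fitting of arbitrarily parameterized HMMs. Unlike the EM fit above, this objective includes no priors on μ and logσ.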
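In log space, the forward recursion is log αₜ(j) = log Bⱼ(xₜ) + logsumexpᵢ(log αₜ₋₁(i) + log Aᵢⱼ). The result is log p(x₁, …, x_T) = logsumexpⱼ log α_T(j). This sums over all K^T state sequences in O(TK²) time. The sketch below implements the recursion in plain Julia and checks it against brute-force enumeration on a short made-up sequence, using the parameters of hmm_true. The helper names are hypothetical. HiddenMarkovModels.jl's forward is a separate implementation.

```julia
# Numerically stable log(sum(exp.(v))).
function logsumexp(v)
    m = maximum(v)
    return m + log(sum(exp.(v .- m)))
end

# Gaussian log-density in terms of (x, μ, σ).
normal_logpdf(x, μ, σ) = -log(2π) / 2 - log(σ) - ((x - μ) / σ)^2 / 2

# Log-space forward algorithm: returns log p(x_1:T) with the states summed out.
function forward_loglik(init, trans, μ, σ, xs)
    K = length(init)
    logA = log.(trans)
    logα = [log(init[k]) + normal_logpdf(xs[1], μ[k], σ[k]) for k in 1:K]
    for t in 2:length(xs)
        prev = logα
        logα = [normal_logpdf(xs[t], μ[j], σ[j]) + logsumexp(prev .+ logA[:, j]) for j in 1:K]
    end
    return logsumexp(logα)
end

# Brute force: sum the joint probability over all K^T state sequences.
function brute_force_loglik(init, trans, μ, σ, xs)
    K, T = length(init), length(xs)
    total = 0.0
    for zs in Iterators.product(ntuple(_ -> 1:K, T)...)
        lp = log(init[zs[1]]) + normal_logpdf(xs[1], μ[zs[1]], σ[zs[1]])
        for t in 2:T
            lp += log(trans[zs[t - 1], zs[t]]) + normal_logpdf(xs[t], μ[zs[t]], σ[zs[t]])
        end
        total += exp(lp)
    end
    return log(total)
end

init = [0.2, 0.7, 0.1]
trans = [0.9 0.05 0.05; 0.075 0.9 0.025; 0.1 0.1 0.8]
μ = [-3.0, 0.0, 3.0]
σ = [0.25, 0.5, 0.75]
xs = [0.2, -0.1, 2.7, 3.1, -2.9, -3.2]

println((forward=forward_loglik(init, trans, μ, σ, xs),
         brute_force=brute_force_loglik(init, trans, μ, σ, xs)))
```

Every operation here is differentiable in init, trans, μ and σ. That is what lets automatic differentiation produce gradients of the marginal likelihood.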

To respect HMM constraints, we optimize unconstrained parameters and map them to valid probability distributions via softmax:

  • π = softmax(ηπ)
  • each row of A = softmax of the corresponding row of ηA

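This softmax parameterization has two properties worth checking. First, any real-valued logits produce a valid probability vector. Second, adding the same constant to every logit in a row leaves that row unchanged, so each row carries one redundant degree of freedom along which the objective is flat. The self-contained sketch below writes softmax and rowsoftmax the same way as the code that follows. The logit values are arbitrary.

```julia
function softmax(v::AbstractVector)
    m = maximum(v)            # subtract the max for numerical stability
    ex = exp.(v .- m)
    return ex ./ sum(ex)
end

function rowsoftmax(M::AbstractMatrix)
    A = similar(M)
    for i in 1:size(M, 1)
        A[i, :] .= softmax(view(M, i, :))
    end
    return A
end

logits = [2.0 -1.0 0.5; -30.0 40.0 0.0; 0.0 0.0 0.0]
A = rowsoftmax(logits)

# Adding a (different) constant to each row does not change the result.
A_shifted = rowsoftmax(logits .+ [10.0, -5.0, 3.0])
```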
function softmax(v::AbstractVector)
    m = maximum(v)
    ex = exp.(v .- m)
    return ex ./ sum(ex)
end

function rowsoftmax(M::AbstractMatrix)
    A = similar(M)
    for i in 1:size(M, 1)
        A[i, :] .= softmax(view(M, i, :))
    end
    return A
end

function unpack_to_hmm(θ::ComponentVector)
    K = length(θ.ηπ)

    π = softmax(θ.ηπ)
    A = rowsoftmax(θ.ηA)
    dists = [NormalModel(θ.μ[k], θ.logσ[k]) for k in 1:K]

    return HMM(π, A, dists)
end

function hmm_to_θ0(hmm::HMM)
    K = length(hmm.init)

    T = promote_type(
        eltype(hmm.init),
        eltype(hmm.trans),
        eltype(hmm.dists[1].μ),
        eltype(hmm.dists[1].logσ),
    )

    ηπ = log.(hmm.init .+ eps(T))
    ηA = log.(hmm.trans .+ eps(T))

    μ = [hmm.dists[k].μ for k in 1:K]
    logσ = [hmm.dists[k].logσ for k in 1:K]

    return ComponentVector(; ηπ=ηπ, ηA=ηA, μ=μ, logσ=logσ)
end

function negloglik_from_θ(θ::ComponentVector, obs_seq)
    hmm = unpack_to_hmm(θ)
    _, loglik = forward(hmm, obs_seq; error_if_not_finite=false)
    return -loglik[1]
end
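hmm_to_θ0 builds its starting logits by taking logs of the guessed probabilities. This works because softmax(log.(p)) returns p for any probability vector p. The eps(T) offset keeps an exact zero from turning into a -Inf logit, at the cost of a tiny nonzero probability after the round trip. The self-contained check below uses Float64, so eps(T) becomes eps(). softmax is repeated from above.

```julia
function softmax(v::AbstractVector)
    m = maximum(v)
    ex = exp.(v .- m)
    return ex ./ sum(ex)
end

p = [0.2, 0.7, 0.1]
q = [0.5, 0.5, 0.0]   # contains an exact zero

# log.(p .+ eps()) mirrors hmm_to_θ0 with T = Float64.
p_back = softmax(log.(p .+ eps()))
q_logits = log.(q .+ eps())
q_back = softmax(q_logits)
```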

θ0 = hmm_to_θ0(hmm_guess)
ax = getaxes(θ0)

obj(x) = negloglik_from_θ(ComponentVector(x, ax), obs_seq)

result = Optim.optimize(obj, Vector(θ0), BFGS(); autodiff=AutoForwardDiff())
hmm_est2 = unpack_to_hmm(ComponentVector(Optim.minimizer(result), ax))
Hidden Markov Model with:
 - initialization: [4.832572481121333e-18, 4.219535328535932e-18, 1.0]
 - transition matrix: [0.900954177626253 0.052345216589079584 0.04670060578466745; 0.09220927064123588 0.8101851152959032 0.09760561406286092; 0.07177945535615372 0.022648712661270023 0.9055718319825763]
 - observation distributions: [Main.NormalModel{Float64}(-2.998740427485512, -1.3737153743543045), Main.NormalModel{Float64}(3.0320182094334345, -0.31622032685126156), Main.NormalModel{Float64}(0.005660568991269338, -0.7017072704058933)]

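We have now trained an HMM using gradient-based optimization over all parameters at once. The estimates closely match the Baum–Welch fit: the means and transition probabilities agree to about four decimal places, and the logσ values agree to within 0.001. The small logσ differences are consistent with the priors used in the EM fit, which pull logσ slightly toward 0 and are absent from this objective.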


This page was generated using Literate.jl.