Gradient Descent in HMMs
In this tutorial we explore two ways to use gradient descent when fitting HMMs:
- Fitting parameters of an observation model that do not have closed-form updates (e.g., GLMs, neural networks, etc.), inside the EM algorithm.
- Fitting the entire HMM with gradient-based optimization by leveraging automatic differentiation.
We will explore both approaches below.
using ADTypes
using ComponentArrays
using DensityInterface
using ForwardDiff
using HiddenMarkovModels
using LinearAlgebra
using Optim
using Random
using StableRNGs
using StatsAPI
rng = StableRNG(42)
StableRNGs.LehmerRNG(state=0x00000000000000000000000000000055)
For both parts of this tutorial we use a simple HMM with Gaussian observations. Using gradient-based optimization here is overkill, but it keeps the tutorial simple while illustrating the relevant methods.
We begin by defining a Normal observation model.
mutable struct NormalModel{T}
    μ::T
    logσ::T # unconstrained parameterization; σ = exp(logσ)
end
model_mean(mod::NormalModel) = mod.μ
stddev(mod::NormalModel) = exp(mod.logσ)
stddev (generic function with 1 method)
We have defined a simple probability model with two parameters: the mean and the log of the standard deviation. Using logσ is intentional so we can optimize over all real numbers without worrying about the positivity constraint on σ.
Next, we provide the minimal interface expected by HiddenMarkovModels.jl: (logdensityof, rand, fit!).
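For reference, the `logdensityof` method defined next evaluates the usual Gaussian log-density. Writing x for the observation and ℓ for the stored `logσ` field (notation introduced here for illustration only), it reads:

$$
\log p(x \mid \mu, \sigma) = -\tfrac{1}{2}\log(2\pi) - \log\sigma - \frac{1}{2}\left(\frac{x - \mu}{\sigma}\right)^{2}, \qquad \sigma = \exp(\ell)
$$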
function DensityInterface.logdensityof(mod::NormalModel, obs::T) where {T<:Real}
    s = stddev(mod)
    return -log(2π) / 2 - log(s) - ((obs - model_mean(mod)) / s)^2 / 2
end
DensityInterface.DensityKind(::NormalModel) = DensityInterface.HasDensity()
function Random.rand(rng::AbstractRNG, mod::NormalModel{T}) where {T}
    return stddev(mod) * randn(rng, T) + model_mean(mod)
end
Because we are fitting a Gaussian (and the variance can collapse to ~0), we add weak priors to regularize the parameters. We use:
- A weak Normal prior on μ
- A moderate-strength Normal prior on logσ that pulls σ toward ~1
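Written out, the objective minimized in the M-step is a weighted negative log-posterior. The sketch below uses ℓ for logσ, y_i for the data, and w_i for the weights (symbols chosen here for illustration), with the prior means and standard deviations taken from the constants defined next:

$$
J(\mu, \ell) = -\sum_{i} w_i \log \mathcal{N}\!\left(y_i \mid \mu, e^{2\ell}\right) - \log \mathcal{N}\!\left(\mu \mid 0, 10^{2}\right) - \log \mathcal{N}\!\left(\ell \mid 0, 0.5^{2}\right)
$$

Here the second argument of each Normal density is a variance. In EM the weights are typically the posterior probabilities of the corresponding state at each time step; that is an assumption about how `fit!` gets called and is not shown in this tutorial.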
const μ_prior = NormalModel(0.0, log(10.0))
const logσ_prior = NormalModel(log(1.0), log(0.5))
function neglogpost(
    μ::T,
    logσ::T,
    data::AbstractVector{<:Real},
    weights::AbstractVector{<:Real},
    μ_prior::NormalModel,
    logσ_prior::NormalModel,
) where {T<:Real}
    tmp = NormalModel(μ, logσ)
    nll = mapreduce(
        i -> -weights[i] * logdensityof(tmp, data[i]), +, eachindex(data, weights)
    )
    nll += -logdensityof(μ_prior, μ)
    nll += -logdensityof(logσ_prior, logσ)
    return nll
end
function neglogpost(
    θ::AbstractVector{T},
    data::AbstractVector{<:Real},
    weights::AbstractVector{<:Real},
    μ_prior::NormalModel,
    logσ_prior::NormalModel,
) where {T<:Real}
    μ, logσ = θ
    return neglogpost(μ, logσ, data, weights, μ_prior, logσ_prior)
end
function StatsAPI.fit!(
    mod::NormalModel, data::AbstractVector{<:Real}, weights::AbstractVector{<:Real}
)
    T = promote_type(typeof(mod.μ), typeof(mod.logσ))
    θ0 = T[T(mod.μ), T(mod.logσ)]
    obj = θ -> neglogpost(θ, data, weights, μ_prior, logσ_prior)
    result = Optim.optimize(obj, θ0, BFGS(); autodiff=AutoForwardDiff())
    mod.μ, mod.logσ = Optim.minimizer(result)
    return mod
end
Now that we have fully defined our observation model, we can create an HMM using it.
init_dist = [0.2, 0.7, 0.1]
init_trans = [
    0.9 0.05 0.05
    0.075 0.9 0.025
    0.1 0.1 0.8
]
obs_dists = [
    NormalModel(-3.0, log(0.25)), NormalModel(0.0, log(0.5)), NormalModel(3.0, log(0.75))
]
hmm_true = HMM(init_dist, init_trans, obs_dists)
Hidden Markov Model with:
- initialization: [0.2, 0.7, 0.1]
- transition matrix: [0.9 0.05 0.05; 0.075 0.9 0.025; 0.1 0.1 0.8]
- observation distributions: [Main.NormalModel{Float64}(-3.0, -1.3862943611198906), Main.NormalModel{Float64}(0.0, -0.6931471805599453), Main.NormalModel{Float64}(3.0, -0.2876820724517809)]
We can now generate data from this HMM. Note: rand(rng, hmm, T) returns (state_seq, obs_seq).
state_seq, obs_seq = rand(rng, hmm_true, 10_000)
(state_seq = [2, 2, 2, 2, 2, 2, 2, 2, 2, 2 … 2, 1, 1, 1, 1, 1, 1, 1, 3, 3], obs_seq = [0.18676222545978152, -0.010693241131214582, -0.2784822510041131, 0.08762810726537032, -1.0844753758113237, 0.16158875227330957, 0.23966610660195012, 0.3420226663030283, 0.5216136091863335, 0.36307062133252843 … -0.5352268156888037, -3.4262027995877578, -2.715359345491884, -3.39569049903953, -2.9847698058919203, -2.6313341019638505, -2.6252349000380084, -2.8993960061998374, 2.814659049681727, 2.9787243151708425])
Next we fit a new HMM to this data. Baum–Welch will perform EM updates for the HMM parameters; during the M-step, our observation model parameters are fit via gradient-based optimization (BFGS).
init_dist_guess = fill(1.0 / 3, 3)
init_trans_guess = [
    0.98 0.01 0.01
    0.01 0.98 0.01
    0.01 0.01 0.98
]
obs_dist_guess = [
    NormalModel(-2.0, log(1.0)), NormalModel(2.0, log(1.0)), NormalModel(0.0, log(1.0))
]
hmm_guess = HMM(init_dist_guess, init_trans_guess, obs_dist_guess)
hmm_est, lls = baum_welch(hmm_guess, obs_seq)
(Hidden Markov Model with:
- initialization: [9.488970316129715e-219, 2.084310358336173e-32, 1.0]
- transition matrix: [0.9009541783953148 0.0523454731235879 0.04670034848109712; 0.0922091349544837 0.8101900503265297 0.09760081471898664; 0.07177949170293166 0.022646527730598564 0.9055739805664698]
- observation distributions: [Main.NormalModel{Float64}(-2.998739990936056, -1.3730911369240926), Main.NormalModel{Float64}(3.031999162184866, -0.3158190178345715), Main.NormalModel{Float64}(0.0056615595968454386, -0.701341692965898)], [-17992.376536779982, -11827.462331361017, -9046.465817561802, -9029.354676316934, -9029.22729506789, -9029.224714316486, -9029.224605808902, -9029.224596144144])
Great! We were able to fit the model using gradient descent inside EM.
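Note that the estimated states come out in a different order than the generating ones (the fitted means are roughly -3, 3, 0 rather than -3, 0, 3), which is expected because HMM state labels are only identifiable up to permutation. A minimal sketch of one way to align them before comparing, assuming that sorting by mean is a sensible matching rule (it is here only because the true means are already in increasing order); the variable names are illustrative and the numbers are copied from the outputs above:

```julia
# Generating parameters (from the definition of hmm_true).
μ_true = [-3.0, 0.0, 3.0]
A_true = [
    0.9 0.05 0.05
    0.075 0.9 0.025
    0.1 0.1 0.8
]

# Estimated parameters (copied from the printed hmm_est).
μ_est = [-2.998739990936056, 3.031999162184866, 0.0056615595968454386]
A_est = [
    0.9009541783953148 0.0523454731235879 0.04670034848109712
    0.0922091349544837 0.8101900503265297 0.09760081471898664
    0.07177949170293166 0.022646527730598564 0.9055739805664698
]

# Order the estimated states by mean so they line up with the true states.
perm = sortperm(μ_est)

μ_aligned = μ_est[perm]
A_aligned = A_est[perm, perm]  # rows and columns must be permuted together

# Largest absolute deviations from the generating parameters.
μ_err = maximum(abs.(μ_aligned .- μ_true))
A_err = maximum(abs.(A_aligned .- A_true))
```

After alignment, both the means and the transition matrix are close to the values used to generate the data.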
Now we will fit the entire HMM using gradient-based optimization by leveraging automatic differentiation. The key idea is that the forward algorithm marginalizes out the latent states, providing the likelihood of the observations directly as a function of all model parameters.
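In standard textbook form, with π the initial distribution, A the transition matrix, and b_k(y_t) the density of observation y_t under state k (notation assumed here, not taken from the library), the forward recursion and the resulting marginal likelihood are:

$$
\alpha_1(k) = \pi_k\, b_k(y_1), \qquad \alpha_{t+1}(j) = b_j(y_{t+1}) \sum_{i=1}^{K} \alpha_t(i)\, A_{ij}, \qquad p(y_{1:T}) = \sum_{k=1}^{K} \alpha_T(k)
$$

Every operation is a sum or product of differentiable quantities, which is why automatic differentiation can propagate gradients through it.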
We can therefore optimize the negative log-likelihood returned by forward. Each objective evaluation runs the forward algorithm, which can be expensive for large datasets, but this approach allows end-to-end gradient-based fitting for arbitrary parameterized HMMs.
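To make the cost and structure of one objective evaluation concrete, here is a minimal sketch of a scaled forward pass that returns only the log-likelihood. It is an illustration under simplifying assumptions, not the HiddenMarkovModels.jl implementation: `forward_loglik` and the precomputed matrix `logB` are hypothetical names introduced for this example.

```julia
# Illustrative scaled forward pass (not the library's implementation).
# init:  length-K initial distribution
# trans: K×K row-stochastic transition matrix
# logB:  K×T matrix with logB[k, t] = log p(obs_t | state k)
function forward_loglik(init::AbstractVector, trans::AbstractMatrix, logB::AbstractMatrix)
    T = size(logB, 2)
    # Subtract the column maximum before exponentiating to limit underflow.
    m = maximum(view(logB, :, 1))
    α = init .* exp.(view(logB, :, 1) .- m)
    c = sum(α)
    α = α ./ c                      # normalize so α stays a probability vector
    loglik = log(c) + m
    for t in 2:T
        m = maximum(view(logB, :, t))
        # Predict with the transition matrix, then weight by the emissions.
        α = (trans' * α) .* exp.(view(logB, :, t) .- m)
        c = sum(α)
        α = α ./ c
        loglik += log(c) + m        # accumulate the log normalizers
    end
    return loglik
end

# Usage: when both states emit identically, the hidden chain is irrelevant
# and the log-likelihood is the sum of the per-step log-densities.
init = [0.5, 0.5]
trans = [0.9 0.1; 0.2 0.8]
logB = log.([0.3 0.4; 0.3 0.4])
ll = forward_loglik(init, trans, logB)
```

Each evaluation costs on the order of K² operations per time step, which is the expense referred to above.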
To respect HMM constraints, we optimize unconstrained parameters and map them to valid probability distributions via softmax:
- π = softmax(ηπ)
- each row of A = softmax(row_logits)
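Spelled out elementwise, with ηπ the vector of initial-state logits and ηA the matrix of transition logits (superscripts below are labels, not exponents):

$$
\pi_k = \frac{\exp\left(\eta^{\pi}_k\right)}{\sum_{j=1}^{K} \exp\left(\eta^{\pi}_j\right)}, \qquad A_{ij} = \frac{\exp\left(\eta^{A}_{ij}\right)}{\sum_{l=1}^{K} \exp\left(\eta^{A}_{il}\right)}
$$

Two properties are worth keeping in mind. Softmax is unchanged when a constant is added to every logit, so the logits are not uniquely determined by the probabilities. And for any probability vector p with positive entries, softmax(log p) = p, which is why taking logs of an initial guess gives valid starting logits.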
function softmax(v::AbstractVector)
    m = maximum(v)
    ex = exp.(v .- m)
    return ex ./ sum(ex)
end
function rowsoftmax(M::AbstractMatrix)
    A = similar(M)
    for i in 1:size(M, 1)
        A[i, :] .= softmax(view(M, i, :))
    end
    return A
end
function unpack_to_hmm(θ::ComponentVector)
    K = length(θ.ηπ)
    π = softmax(θ.ηπ)
    A = rowsoftmax(θ.ηA)
    dists = [NormalModel(θ.μ[k], θ.logσ[k]) for k in 1:K]
    return HMM(π, A, dists)
end
function hmm_to_θ0(hmm::HMM)
    K = length(hmm.init)
    T = promote_type(
        eltype(hmm.init),
        eltype(hmm.trans),
        eltype(hmm.dists[1].μ),
        eltype(hmm.dists[1].logσ),
    )
    ηπ = log.(hmm.init .+ eps(T))
    ηA = log.(hmm.trans .+ eps(T))
    μ = [hmm.dists[k].μ for k in 1:K]
    logσ = [hmm.dists[k].logσ for k in 1:K]
    return ComponentVector(; ηπ=ηπ, ηA=ηA, μ=μ, logσ=logσ)
end
function negloglik_from_θ(θ::ComponentVector, obs_seq)
    hmm = unpack_to_hmm(θ)
    _, loglik = forward(hmm, obs_seq; error_if_not_finite=false)
    return -loglik[1]
end
θ0 = hmm_to_θ0(hmm_guess)
ax = getaxes(θ0)
obj(x) = negloglik_from_θ(ComponentVector(x, ax), obs_seq)
result = Optim.optimize(obj, Vector(θ0), BFGS(); autodiff=AutoForwardDiff())
hmm_est2 = unpack_to_hmm(ComponentVector(result.minimizer, ax))
Hidden Markov Model with:
- initialization: [4.832572481121333e-18, 4.219535328535932e-18, 1.0]
- transition matrix: [0.900954177626253 0.052345216589079584 0.04670060578466745; 0.09220927064123588 0.8101851152959032 0.09760561406286092; 0.07177945535615372 0.022648712661270023 0.9055718319825763]
- observation distributions: [Main.NormalModel{Float64}(-2.998740427485512, -1.3737153743543045), Main.NormalModel{Float64}(3.0320182094334345, -0.31622032685126156), Main.NormalModel{Float64}(0.005660568991269338, -0.7017072704058933)]
We have now trained an HMM using gradient-based optimization over all parameters!
This page was generated using Literate.jl.