Control dependency
Here, we give an example of a controlled HMM (also called an input-output HMM), in the special case of Markov switching regression.
using DensityInterface
using Distributions
using HiddenMarkovModels
import HiddenMarkovModels as HMMs
using LinearAlgebra
using Random
using StableRNGs
using StatsAPI
rng = StableRNG(63);
Model
A Markov switching regression is like a classical regression, except that the weights depend on the unobserved state of an HMM. We can represent it with the following subtype of AbstractHMM (see Custom HMM structures), which has one vector of coefficients $\beta_i$ per state.
struct ControlledGaussianHMM{T} <: AbstractHMM
    init::Vector{T}
    trans::Matrix{T}
    dist_coeffs::Vector{Vector{T}}
end
In state $i$ with a vector of controls $u$, our observation is given by the linear model $y \sim \mathcal{N}(\beta_i^\top u, 1)$. Controls must be provided to both transition_matrix and obs_distributions, even if only one of them uses them.
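The sketch below makes the emission model concrete. For a single time step, it evaluates the log-density of $y$ under each state's regression $\mathcal{N}(\beta_i^\top u, 1)$. It uses only LinearAlgebra. The helpers normal_logpdf_unit and emission_logdensities are hypothetical; in the package, the obs_distributions method below plays this role by returning Normal objects.

```julia
using LinearAlgebra

# Log-density of y under N(μ, 1), written out by hand so no packages are needed.
normal_logpdf_unit(y, μ) = -log(2π) / 2 - (y - μ)^2 / 2

# Hypothetical helper: for a control vector u, state i emits y ~ N(βᵢᵀu, 1).
# Returns one log-density per state.
function emission_logdensities(coeffs::Vector{Vector{Float64}}, u::Vector{Float64}, y::Float64)
    return [normal_logpdf_unit(y, dot(β, u)) for β in coeffs]
end

coeffs_demo = [-ones(3), ones(3)]    # same coefficients as the simulation below
u_demo = [0.5, -1.0, 2.0]            # sum(u_demo) = 1.5
lds = emission_logdensities(coeffs_demo, u_demo, 1.5)
```

Here the second state's mean is exactly $1.5$, so it gets the highest log-density.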
function HMMs.initialization(hmm::ControlledGaussianHMM)
    return hmm.init
end
function HMMs.transition_matrix(hmm::ControlledGaussianHMM, control::AbstractVector)
    return hmm.trans
end
function HMMs.obs_distributions(hmm::ControlledGaussianHMM, control::AbstractVector)
    return [Normal(dot(hmm.dist_coeffs[i], control), 1.0) for i in 1:length(hmm)]
end
In this case, the transition matrix does not depend on the control.
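Here, transition_matrix ignores its control. If the control were meant to drive the transitions, one common modelling choice is a multinomial-logistic parameterization, where row $i$ is a softmax of $W_i u$. This is an assumption, not something this example or the package prescribes. The stdlib-only sketch below uses the hypothetical names softmax and controlled_transition_matrix and an arbitrary weight layout.

```julia
using LinearAlgebra

# Numerically stable softmax: subtracting the max does not change the result.
function softmax(x::AbstractVector)
    z = exp.(x .- maximum(x))
    return z ./ sum(z)
end

# Hypothetical parameterization: row i of the transition matrix is softmax(W[i] * u),
# where W[i] is an N × d weight matrix. Each row is a probability vector for any u.
function controlled_transition_matrix(W::Vector{Matrix{Float64}}, u::Vector{Float64})
    N = length(W)
    P = zeros(N, N)
    for i in 1:N
        P[i, :] = softmax(W[i] * u)
    end
    return P
end

W = [[0.5 0.0 0.0; -0.5 0.0 0.0], [0.0 1.0 0.0; 0.0 -1.0 0.0]]  # two states, d = 3
u_ctrl = [1.0, 2.0, 0.0]
P = controlled_transition_matrix(W, u_ctrl)
```

Because each row is normalized, the result is a valid stochastic matrix for any control. A transition_matrix method returning such a matrix would also need its own fit! logic, since the standard normalized-count update estimates a control-independent matrix.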
Simulation
d = 3
init = [0.6, 0.4]
trans = [0.7 0.3; 0.2 0.8]
dist_coeffs = [-ones(d), ones(d)]
hmm = ControlledGaussianHMM(init, trans, dist_coeffs);
Simulation requires a vector of controls, each being a vector itself with the right dimension.
Let us build several sequences of variable lengths.
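Conceptually, sampling one such sequence works in two steps. First, draw the hidden chain from init and trans. Then draw each observation from the current state's regression, given the control at that time step. The stdlib-only sketch below spells this out. The helpers sample_categorical and simulate_switching are hypothetical; this is not the package's rand implementation. The noise variance is fixed at 1, as in the model above.

```julia
using LinearAlgebra, Random

# Draw an index from the discrete distribution p using one uniform variate.
function sample_categorical(rng::AbstractRNG, p::AbstractVector)
    r = rand(rng)
    c = 0.0
    for i in eachindex(p)
        c += p[i]
        r <= c && return i
    end
    return lastindex(p)   # guard against rounding when sum(p) is slightly below 1
end

# Simulate one sequence of the switching regression, with one control vector per time step.
function simulate_switching(rng, init, trans, coeffs, controls)
    T = length(controls)
    states = Vector{Int}(undef, T)
    obs = Vector{Float64}(undef, T)
    for t in 1:T
        p = t == 1 ? init : trans[states[t - 1], :]
        states[t] = sample_categorical(rng, p)
        obs[t] = dot(coeffs[states[t]], controls[t]) + randn(rng)   # y ~ N(βᵢᵀu, 1)
    end
    return (; state_seq=states, obs_seq=obs)
end

rng_sketch = Xoshiro(0)
dim = 3
controls = [randn(rng_sketch, dim) for _ in 1:50]
sim = simulate_switching(rng_sketch, [0.6, 0.4], [0.7 0.3; 0.2 0.8], [-ones(dim), ones(dim)], controls)
```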
control_seqs = [[randn(rng, d) for t in 1:rand(100:200)] for k in 1:1000];
obs_seqs = [rand(rng, hmm, control_seq).obs_seq for control_seq in control_seqs];
obs_seq = reduce(vcat, obs_seqs)
control_seq = reduce(vcat, control_seqs)
seq_ends = cumsum(length.(obs_seqs));
Inference
Not much changes from the case with simple time dependency.
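Conceptually, the control only changes how the emission scores are computed at each time step; the dynamic programming over states is unchanged. Below is a minimal log-space Viterbi sketch for a single sequence of the Gaussian switching regression. It uses the hypothetical name viterbi_controlled and assumes unit noise variance. With seq_ends, the same recursion would simply be run on each segment separately.

```julia
using LinearAlgebra

# Log-space Viterbi for one sequence. Compared with the uncontrolled case, the only
# change is that the emission log-densities at time t depend on controls[t].
function viterbi_controlled(init, trans, coeffs, obs, controls)
    N, T = length(init), length(obs)
    logemit(t) = [-log(2π) / 2 - (obs[t] - dot(coeffs[i], controls[t]))^2 / 2 for i in 1:N]
    logtrans = log.(trans)
    δ = log.(init) .+ logemit(1)   # best log-score of a path ending in each state
    back = zeros(Int, N, T)        # back[j, t]: best predecessor of state j at time t
    for t in 2:T
        δ_new = similar(δ)
        for j in 1:N
            scores = δ .+ logtrans[:, j]
            back[j, t] = argmax(scores)
            δ_new[j] = scores[back[j, t]]
        end
        δ = δ_new .+ logemit(t)
    end
    path = zeros(Int, T)
    path[T] = argmax(δ)
    for t in T:-1:2
        path[t - 1] = back[path[t], t]   # follow the back-pointers
    end
    return path, maximum(δ)
end

# With controls [1, 1], the state means are -2 and +2.
path_demo, score_demo = viterbi_controlled(
    [0.5, 0.5], [0.9 0.1; 0.1 0.9], [-ones(2), ones(2)],
    [-2.0, -2.0, 2.0, 2.0], fill([1.0, 1.0], 4),
)
```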
best_state_seq, _ = viterbi(hmm, obs_seq, control_seq; seq_ends)
([1, 1, 1, 1, 1, 2, 2, 1, 1, 1 … 2, 2, 1, 1, 1, 1, 1, 1, 1, 1], [-307.0946501237572, -178.74169062556288, -358.5632654371543, -290.93668352299943, -248.89409606676082, -187.19574304254445, -327.590964672368, -294.8980751442451, -363.02069870025025, -293.19103736938104 … -239.5558721434393, -285.6811961498322, -297.73393682786957, -212.42338489789483, -291.6224632524889, -188.77429379234698, -300.5029930897908, -201.8510123690674, -330.6329903708718, -337.5399974872673])
Learning
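Once more, we override the fit! function. The state-related parameters (init and trans) are estimated in the standard way from the posteriors γ and ξ. Meanwhile, for each state $i$, the observation coefficients are given by the formula for weighted least squares, $\beta_i = \arg\min_\beta \sum_t \gamma_i(t) (y_t - \beta^\top u_t)^2$, with the state posteriors as weights.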
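The weighted least-squares step in the fit! method below multiplies both sides by the square root of the state posteriors. As a small stdlib-only check on toy data made up for illustration, this gives the same coefficients as solving the weighted normal equations $(U^\top \Gamma_i U)\beta_i = U^\top \Gamma_i y$, where $\Gamma_i$ is diagonal with entries $\gamma_i(t)$.

```julia
using LinearAlgebra

# Toy data: T = 4 time steps, d = 2 controls.
U_demo = [1.0 0.0; 0.0 1.0; 1.0 1.0; 2.0 -1.0]   # T × d design matrix (one control per row)
y_demo = [1.0, 2.0, 2.5, 0.5]
w_demo = [0.9, 0.1, 0.5, 0.7]                    # posterior weights γᵢ(t)

# Route 1: scale rows by √w, then solve an ordinary least-squares problem (QR-based `\`).
Wsqrt = Diagonal(sqrt.(w_demo))
β_scaled = (Wsqrt * U_demo) \ (Wsqrt * y_demo)

# Route 2: solve the weighted normal equations directly.
Γ = Diagonal(w_demo)
β_normal = (U_demo' * Γ * U_demo) \ (U_demo' * Γ * y_demo)
```

The √w scaling lets `\` factor the tall matrix directly, instead of forming $U^\top \Gamma U$. This avoids squaring the condition number.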
function StatsAPI.fit!(
    hmm::ControlledGaussianHMM{T},
    fb_storage::HMMs.ForwardBackwardStorage,
    obs_seq::AbstractVector,
    control_seq::AbstractVector;
    seq_ends,
) where {T}
    (; γ, ξ) = fb_storage
    N = length(hmm)
    # Accumulate expected initial states and transition counts over all sequences
    hmm.init .= 0
    hmm.trans .= 0
    for k in eachindex(seq_ends)
        t1, t2 = HMMs.seq_limits(seq_ends, k)
        hmm.init .+= γ[:, t1]
        hmm.trans .+= sum(ξ[t1:t2])
    end
    # Normalize into a probability vector and a row-stochastic matrix
    hmm.init ./= sum(hmm.init)
    for row in eachrow(hmm.trans)
        row ./= sum(row)
    end
    # Weighted least squares for each state, with weights γ[i, t]
    U = reduce(hcat, control_seq)'  # T × d design matrix
    y = obs_seq
    for i in 1:N
        W = sqrt.(Diagonal(γ[i, :]))
        hmm.dist_coeffs[i] = (W * U) \ (W * y)
    end
end
Now we put it to the test.
init_guess = [0.5, 0.5]
trans_guess = [0.6 0.4; 0.3 0.7]
dist_coeffs_guess = [-2 * ones(d), 2 * ones(d)]
hmm_guess = ControlledGaussianHMM(init_guess, trans_guess, dist_coeffs_guess);
hmm_est, loglikelihood_evolution = baum_welch(hmm_guess, obs_seq, control_seq; seq_ends)
first(loglikelihood_evolution), last(loglikelihood_evolution)
(-477070.37617184647, -259086.9833297345)
How did we perform?
cat(hmm_est.trans, hmm.trans; dims=3)
2×2×2 Array{Float64, 3}:
[:, :, 1] =
 0.703392  0.296608
 0.197335  0.802665
[:, :, 2] =
 0.7  0.3
 0.2  0.8
hcat(hmm_est.dist_coeffs[1], hmm.dist_coeffs[1])
3×2 Matrix{Float64}:
 -1.0011   -1.0
 -1.00507  -1.0
 -1.00671  -1.0
hcat(hmm_est.dist_coeffs[2], hmm.dist_coeffs[2])
3×2 Matrix{Float64}:
 0.998608  1.0
 1.00218   1.0
 0.999313  1.0
Built-in ControlledEmissionHMM
The pattern above defines a custom AbstractHMM subtype, which is the right tool when controls influence the transition matrix or the initial distribution.
When only the emissions depend on the control, the package ships ControlledEmissionHMM so you don't have to write the boilerplate. init and trans are stored as plain control-independent vectors and matrices, and you only have to provide a control-aware emission distribution per state.
Defining a control-aware emission
Each emission subtypes ControlledEmission and must implement three "non-standard" methods that take the control as an extra argument:
DensityInterface.logdensityof(d, obs, control) for inference,
Random.rand(rng, d, control) for sampling,
StatsAPI.fit!(d, obs_seq, control_seq, weights) for learning.
Subtyping ControlledEmission documents the interface and provides DensityKind automatically. At inference time, obs_distributions(hmm, control) wraps each emission into a ControlBoundEmission bound to control, so the inner inference loops keep using the standard two-argument logdensityof(dist, obs) / rand(rng, dist) interface.
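The binding idea can be illustrated without the package: fix the control once, then expose a two-argument interface that generic code can call. This stdlib-only sketch uses hypothetical names throughout: LinearGaussianSketch, BoundSketch, logdens3, logdens2, sample3, sample2 and state_logdensities. They stand in for the real types and the DensityInterface/Random methods.

```julia
using Random

# Hypothetical control-aware emission: y ~ N(β0 + β1·u, σ²).
struct LinearGaussianSketch
    β0::Float64
    β1::Float64
    σ::Float64
end

# Three-argument forms: the control is an explicit argument.
logdens3(d::LinearGaussianSketch, y, u) =
    -log(2π) / 2 - log(d.σ) - ((y - (d.β0 + d.β1 * u)) / d.σ)^2 / 2
sample3(rng::AbstractRNG, d::LinearGaussianSketch, u) = d.β0 + d.β1 * u + d.σ * randn(rng)

# Binding stores the control next to the emission, giving a two-argument view.
struct BoundSketch{D,C}
    dist::D
    control::C
end
logdens2(b::BoundSketch, y) = logdens3(b.dist, y, b.control)
sample2(rng::AbstractRNG, b::BoundSketch) = sample3(rng, b.dist, b.control)

# Generic code that only knows the two-argument form, like an inner inference loop.
state_logdensities(dists, y) = [logdens2(b, y) for b in dists]

emissions = [LinearGaussianSketch(-1.0, 2.0, 0.5), LinearGaussianSketch(0.0, -1.0, 1.0)]
u_t = 0.5
bound = [BoundSketch(e, u_t) for e in emissions]   # analogous to obs_distributions(hmm, u_t)
lds_bound = state_logdensities(bound, 0.0)
```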
To mirror the example above, we define a Gaussian whose mean is linear in a scalar control:
mutable struct LinearGaussian{T} <: ControlledEmission
    β0::T
    β1::T
    logσ::T
end
function DensityInterface.logdensityof(d::LinearGaussian, obs::Real, control::Real)
    μ = d.β0 + d.β1 * control
    σ = exp(d.logσ)
    return -log(2π) / 2 - d.logσ - ((obs - μ) / σ)^2 / 2
end
function Random.rand(rng::AbstractRNG, d::LinearGaussian, control::Real)
    μ = d.β0 + d.β1 * control
    σ = exp(d.logσ)
    return μ + σ * randn(rng)
end
The fit! method below performs a weighted maximum-likelihood update, where the weights are the state posteriors $\gamma_t$ supplied by Baum-Welch. Maximizing the weighted Gaussian log-likelihood over $(\beta_0, \beta_1)$ is an ordinary weighted least squares problem: writing $S_k = \sum_t \gamma_t u_t^k$ and $T_k = \sum_t \gamma_t u_t^k y_t$, the normal equations have the closed-form solution used below, with $\Delta = S_0 S_2 - S_1^2$. Given those coefficients, the variance estimate is the weighted mean of the squared residuals, $\sigma^2 = \frac{1}{S_0} \sum_t \gamma_t (y_t - \mu_t)^2$.
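For reference, the closed form follows from setting the partial derivatives of the weighted log-likelihood to zero, using the same $S_k$, $T_k$ and $\Delta$ as above.

```latex
\begin{aligned}
\ell(\beta_0, \beta_1, \sigma) &= \sum_t \gamma_t \left[ -\tfrac{1}{2}\log(2\pi) - \log\sigma - \frac{(y_t - \beta_0 - \beta_1 u_t)^2}{2\sigma^2} \right] \\
\frac{\partial \ell}{\partial \beta_0} = 0 &\iff S_0 \beta_0 + S_1 \beta_1 = T_0 \\
\frac{\partial \ell}{\partial \beta_1} = 0 &\iff S_1 \beta_0 + S_2 \beta_1 = T_1 \\
\beta_0 = \frac{T_0 S_2 - T_1 S_1}{\Delta}, &\qquad \beta_1 = \frac{T_1 S_0 - T_0 S_1}{\Delta}, \qquad \Delta = S_0 S_2 - S_1^2 \\
\frac{\partial \ell}{\partial \sigma} = 0 &\iff \sigma^2 = \frac{1}{S_0} \sum_t \gamma_t \left(y_t - \beta_0 - \beta_1 u_t\right)^2
\end{aligned}
```

By the Cauchy-Schwarz inequality, $\Delta \ge 0$. Equality holds only when the controls are constant wherever the weights are nonzero, in which case the slope $\beta_1$ is not identifiable.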
function StatsAPI.fit!(
    d::LinearGaussian,
    obs_seq::AbstractVector{<:Real},
    control_seq::AbstractVector{<:Real},
    weights::AbstractVector{<:Real},
)
    S0 = sum(weights)                             # S₀ = Σₜ γₜ
    S1 = sum(weights .* control_seq)              # S₁ = Σₜ γₜ·uₜ
    S2 = sum(weights .* control_seq .^ 2)         # S₂ = Σₜ γₜ·uₜ²
    T0 = sum(weights .* obs_seq)                  # T₀ = Σₜ γₜ·yₜ
    T1 = sum(weights .* control_seq .* obs_seq)   # T₁ = Σₜ γₜ·uₜ·yₜ
    Δ = S0 * S2 - S1^2                            # Δ = S₀·S₂ - S₁² (determinant of the normal equations)
    d.β0 = (T0 * S2 - T1 * S1) / Δ                # β₀ = (T₀·S₂ - T₁·S₁) / Δ
    d.β1 = (T1 * S0 - T0 * S1) / Δ                # β₁ = (T₁·S₀ - T₀·S₁) / Δ
    sse = sum(weights .* (obs_seq .- (d.β0 .+ d.β1 .* control_seq)) .^ 2)  # SSE = Σₜ γₜ·(yₜ - μₜ)²
    d.logσ = log(sqrt(sse / S0))                  # logσ = ½·log(SSE / S₀)
    return d
end
Building the HMM
Construction takes the standard parameters plus the vector of control-aware emissions; no AbstractHMM subtype is needed.
dists_lg = [LinearGaussian(-1.0, 2.0, log(0.5)), LinearGaussian(0.0, -1.0, log(1.0))]
hmm_lg = ControlledEmissionHMM(init, trans, dists_lg);
Simulation
A ControlledEmissionHMM always requires a concrete control sequence: calling rand(hmm, T::Integer) is not supported, since there is no sensible default control. Provide a control_seq of the desired length instead.
control_seq_lg = randn(rng, 10000);
obs_seq_lg = rand(rng, hmm_lg, control_seq_lg).obs_seq;
Inference and learning
Inference works exactly as with any other AbstractHMM:
best_state_seq_lg, _ = viterbi(hmm_lg, obs_seq_lg, control_seq_lg)
([2, 2, 2, 2, 2, 2, 1, 1, 1, 2 … 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], [-15914.59538384597])
For learning, ControlledEmissionHMM ships its own fit! method. It re-estimates init and trans in the standard way and then calls your distribution's fit!(d, obs_seq, control_seq, weights) for each state, so there is nothing to override on the HMM itself.
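This division of labour can be sketched without the package for a single sequence. All names here are illustrative: ProportionalEmission is a hypothetical emission with $y \approx \beta u$, and fit_weighted! and mstep! are hypothetical helpers. The posteriors γ and ξ are taken as given (in practice they come from the forward-backward pass), and ξ holds one pairwise-posterior matrix per transition. This is not the package's implementation.

```julia
# Hypothetical control-aware emission: y ≈ β·u (no intercept, for brevity).
mutable struct ProportionalEmission
    β::Float64
end

# Four-argument weighted fit, mirroring the shape fit!(d, obs_seq, control_seq, weights).
# Minimizes Σₜ wₜ·(yₜ - β·uₜ)², whose solution is β = Σ w·u·y / Σ w·u².
function fit_weighted!(d::ProportionalEmission, obs, controls, w)
    d.β = sum(w .* controls .* obs) / sum(w .* controls .^ 2)
    return d
end

# One M-step: normalized counts for init and trans, then one weighted fit per state.
function mstep!(init, trans, dists, γ, ξ, obs, controls)
    init .= γ[:, 1]
    init ./= sum(init)
    trans .= sum(ξ)                    # expected transition counts
    trans ./= sum(trans; dims=2)       # normalize each row
    for i in eachindex(dists)
        fit_weighted!(dists[i], obs, controls, γ[i, :])   # state i's posterior weights
    end
    return init, trans, dists
end

obs_m = [2.0, 4.0, -1.0, -2.0]
ctrl_m = [1.0, 2.0, 1.0, 2.0]
γ_m = [1.0 1.0 0.0 0.0; 0.0 0.0 1.0 1.0]
ξ_m = [[1.0 0.0; 0.0 0.0], [0.0 1.0; 0.0 0.0], [0.0 0.0; 0.0 1.0]]
init_m, trans_m = zeros(2), zeros(2, 2)
dists_m = [ProportionalEmission(0.0), ProportionalEmission(0.0)]
mstep!(init_m, trans_m, dists_m, γ_m, ξ_m, obs_m, ctrl_m)
```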
dists_lg_guess = [LinearGaussian(-0.5, 1.0, log(1.0)), LinearGaussian(0.0, 0.0, log(1.0))]
hmm_lg_guess = ControlledEmissionHMM(init_guess, trans_guess, dists_lg_guess)
hmm_lg_est, ll_lg = baum_welch(hmm_lg_guess, obs_seq_lg, control_seq_lg)
first(ll_lg), last(ll_lg)
(-20868.582255382586, -15022.967843880258)
This page was generated using Literate.jl.