
Handling latent parameters #595

Draft · wants to merge 1 commit into base: master

Conversation

@torfjelde torfjelde commented Apr 21, 2024

Okay, so this all started from this gist https://gist.github.com/JasonPekos/82be830e4bf390fd1cc2886a7518aede by @JasonPekos using Turing.jl combined with @gdalle's HiddenMarkovModels.jl.

This PR demonstrates how we could support such latent-parameter models on the RHS of ~ with no (or minimal) intervention from the user, and it turns out to be fairly easy to handle.

We could also do the same to replace @submodel, though there it's a bit non-trivial because there are two "types" of realizations: the random variables involved in ~ and those returned, so let's leave that for now.

But with this PR, one needs to implement a few methods:

  • latent(dist): returns a Distribution for the latent parameters.
  • conditional(dist, latents): returns a Distribution for the conditional distribution of the data given the latent parameters.
  • marginalize(dist): returns a Distribution for the marginal distribution of the data.

And that's it!
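To make the three methods concrete before the full HMM example that follows, here is a minimal sketch for a two-component Gaussian mixture, where the latent variable is the component indicator. Note that `LatentMixture` and these free-standing `latent`/`conditional`/`marginalize` functions are illustrative stand-ins only; the PR's actual methods live in the `DynamicPPL` namespace.

```julia
using Distributions

# Illustrative stand-in: a two-component Gaussian mixture with an explicit
# latent component indicator `z`. Not part of the PR itself.
struct LatentMixture
    w::Vector{Float64}  # component weights
    μ::Vector{Float64}  # component means
end

# latent(dist): distribution of the latent indicator z ∈ 1:K.
latent(d::LatentMixture) = Categorical(d.w)
# conditional(dist, z): distribution of the data given the latent z.
conditional(d::LatentMixture, z::Integer) = Normal(d.μ[z], 1.0)
# marginalize(dist): distribution of the data with z summed out.
marginalize(d::LatentMixture) = MixtureModel(Normal.(d.μ, 1.0), d.w)

d = LatentMixture([0.3, 0.7], [-2.0, 2.0])
z = rand(latent(d))          # latent component indicator
y = rand(conditional(d, z))  # data given that component
logpdf(marginalize(d), y)    # marginal log-density, z summed out
```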

For example, to allow straightforward usage of HiddenMarkovModels.jl with Turing.jl, the following "just works" (it's a bit hacky, but not particularly difficult):

julia> """
           hmm(K, T)

       A Hidden Markov Model with `K` states and `T` observations with marginalized hidden states.
       """
       @model function hmm(K, T)
           # Transition matrix.
           π ~ product_distribution(fill(Dirichlet(fill(1 / K, K)), K))
           # Mean of emission distribution.
           μ ~ Bijectors.ordered(MvNormal(zeros(K), 10I))

           # HMM(init, trans, emissions).
           hmm = HMM(π[:, 1], permutedims(π), Normal.(μ, 1))
           y ~ FixedLengthHMM(hmm, T)

           return y
       end

hmm

julia> model = DynamicPPL.contextualize(hmm(K, T), DynamicPPL.LatentHandlingContext());

julia> # Captures the latent variables!
       rand(OrderedDict, model)
OrderedDict{Any, Any} with 4 entries:
  π   => [0.88056 0.989755; 0.11944 0.010245]
  μ   => [-2.40058, 34.1099]
  y.z => [1, 1, 1, 2, 1, 1, 1, 1, 1, 2]
  y   => [-0.386338, -1.5516, -2.39555, 34.6776, -3.41318, -2.45152, -2.37044, -2.38218, -2.3104, 32.9849]

julia> # And we can condition.
       y = model();

julia> model_conditioned = model | (y=y,);

julia> # And now the latent parameters are also gone!
       rand(OrderedDict, model_conditioned)
OrderedDict{Any, Any} with 2 entries:
  π => [0.755412 0.992134; 0.244588 0.00786611]
  μ => [-1.03905, -1.02069]

The above requires the following to be implemented:

using DynamicPPL, Distributions, HiddenMarkovModels, Random, LinearAlgebra, Bijectors

struct FixedLengthHMM{M}
    hmm::M
    n::Int
end

struct MarginalizedHMM{M,V,F} <: Distribution{V,F}
    hmm::M
end

function MarginalizedHMM(hmm::FixedLengthHMM)
    # TODO: Determine variate form and type from `hmm`.
    return MarginalizedHMM{typeof(hmm),Multivariate,Continuous}(hmm)
end

UnivariateMarginalizedHMM = MarginalizedHMM{<:FixedLengthHMM,Univariate,Continuous}
MultivariateMarginalizedHMM = MarginalizedHMM{<:FixedLengthHMM,Multivariate,Continuous}

function Distributions.rand(rng::Random.AbstractRNG, dist::MarginalizedHMM)
    # Forward the RNG so sampling is reproducible.
    return last(rand(rng, dist.hmm.hmm, dist.hmm.n))
end
function Distributions.logpdf(dist::MarginalizedHMM, x::AbstractVector{<:Real})
    return logdensityof(dist.hmm.hmm, x)
end
function Distributions.logpdf(dist::MarginalizedHMM, x::Real)
    return logdensityof(dist.hmm.hmm, x)
end

# Latent distribution.
struct LatentDistribution{M,V,F} <: Distribution{V,F}
    hmm::M
end

function LatentDistribution(hmm::FixedLengthHMM)
    return LatentDistribution{typeof(hmm),Multivariate,Discrete}(hmm)
end

function Distributions.rand(rng::Random.AbstractRNG, dist::LatentDistribution)
    # Forward the RNG so sampling is reproducible.
    return first(rand(rng, dist.hmm.hmm, dist.hmm.n))
end

function Distributions.logpdf(dist::LatentDistribution, x::AbstractVector{<:Real})
    @assert length(x) == dist.hmm.n "Length of `x` must match number of latent states."
    hmm = dist.hmm.hmm
    lp = HiddenMarkovModels.log_initialization(hmm)[x[1]]
    logtrans = HiddenMarkovModels.log_transition_matrix(hmm)
    for t in 2:(dist.hmm.n)
        lp += logtrans[x[t - 1], x[t]]
    end
    return lp
end

# Conditional.
struct ConditionalHMM{M,A,V,F} <: Distribution{V,F}
    hmm::M
    latents::A
end

function ConditionalHMM(hmm::FixedLengthHMM, latents)
    # TODO: Determine variate form and type from `hmm`.
    return ConditionalHMM{typeof(hmm),typeof(latents),Multivariate,Continuous}(hmm, latents)
end

function Distributions.rand(rng::Random.AbstractRNG, dist::ConditionalHMM)
    # Sample from the emission distributions conditional on the latent states.
    hmm = dist.hmm.hmm
    conditionals = HiddenMarkovModels.obs_distributions(hmm)[dist.latents]
    return [rand(rng, c) for c in conditionals]
end

function Distributions.logpdf(dist::ConditionalHMM, y::AbstractVector{<:Real})
    hmm = dist.hmm.hmm
    conditionals = HiddenMarkovModels.obs_distributions(hmm)[dist.latents]
    return sum(logpdf(c, y[i]) for (i, c) in enumerate(conditionals))
end

# Make compatible with DPPL.
DynamicPPL.check_tilde_rhs(hmm::FixedLengthHMM) = hmm
DynamicPPL.has_latents(hmm::FixedLengthHMM) = true
# Dispatches for the different distributions.
DynamicPPL.latent(hmm::FixedLengthHMM) = LatentDistribution(hmm)
DynamicPPL.conditional(hmm::FixedLengthHMM, latents) = ConditionalHMM(hmm, latents)
DynamicPPL.marginalize(hmm::FixedLengthHMM) = MarginalizedHMM(hmm)
# Choose varname suffix.
function DynamicPPL.suffix_latent_varname(::FixedLengthHMM, vn)
    return DynamicPPL.suffix_varname(vn, Val{:z}())
end

Some other people who might be interested in this: @THargreaves @yebai @devmotion

@torfjelde torfjelde marked this pull request as draft April 21, 2024 18:54
coveralls commented Apr 21, 2024

Pull Request Test Coverage Report for Build 8774826250

  • 0 of 33 (0.0%) changed or added relevant lines in 1 file are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage decreased (-0.7%) to 77.976%

Changes missing coverage:

  • src/latent_handling.jl: 0 of 33 changed/added lines covered (0.0%)

Totals:

  • Change from base Build 8765865071: -0.7%
  • Covered Lines: 2751
  • Relevant Lines: 3528

💛 - Coveralls

gdalle commented Apr 22, 2024

I don't know if I'm familiar enough with Turing to help, but I would be very enthusiastic if it landed in the official package / docs, so ping me whenever you run into HMM-related troubles!

yebai commented Apr 28, 2024

I like the idea here; also agree this can be developed into a complete replacement for submodels.

yebai commented May 5, 2024

In addition to the latent/marginal/conditional APIs for distributions, we can also consider introducing returned:

 x ~ returned(some_turing_model(...))

to replace

@submodel x = some_turing_model(...)

Here, returned has effects similar to our current submodel syntax, except that it replaces assignments with tilde. The semantics here is that x is distributionally equal to the returned variables of some_turing_model.

@torfjelde

Here, returned has effects similar to our current submodel syntax, except that it replaces assignments with tilde. The semantics here is that x is distributionally equal to the returned variables of some_turing_model.

Though not fundamentally against this, I guess the issue with all of this is that it really conflates what ~ is meant to represent. It's really not unlikely that users will start doing stuff like trying to condition in these cases, which will not be possible.

yebai commented May 6, 2024

I guess the issue with all of this is that it really conflates what ~ is meant to represent. It's really not unlikely that users will start doing stuff like trying to condition in these cases, which will not be possible.

Good question. The returned variables in a submodel induce an implicit distribution (sometimes a Delta distribution if RVs do not influence these variables in the submodel). However, the likelihood distribution for this implicit distribution might be hard to compute or intractable in certain cases. This is when the condition syntax becomes intractable since it depends on the likelihood function being tractable.

We can probably update the docs to inform the user that certain model operations are not applicable for submodels or, more generally, distributions without closed-form log density functions.

@torfjelde

However, the likelihood distribution for this implicit distribution might be hard to compute or intractable in certain cases.

But this applies to all of the models in DPPL, because we don't know anything about the underlying model.

We can probably update the docs to inform the user that certain model operations are not applicable for submodels or, more generally, distributions without closed-form log density functions.

We would most certainly have to make the user aware of this if we were to use this syntax.

My issue is that we don't want to end up in a scenario where the user has to ask themselves "is this valid?" every time they write left ~ right. Right now, we're very stringent regarding what's allowed, and that makes it easy to reason about. If we start allowing arbitrary things on the RHS of a ~, we run the risk of making things very confusing for the user; e.g. end up with a bunch of issues just asking "I do condition(model, x=1) and it doesn't work" because they used a model on the RHS.

I'm not sure which side I'm on tbh. We might be able to handle it properly if we introduce a lot of useful error messages, but it requires care.

torfjelde commented May 6, 2024

Also, regarding this particular PR, I was naively thinking that we could just use forward sampling conditioned on the inferred HMM parameters to sample the latents conditioned on data, but of course (thanks to @THargreaves for pointing this out) this doesn't work, since in general
$$p(z \mid \theta)\, p(\theta \mid y) \neq p(z \mid \theta, y)\, p(\theta \mid y) = p(z, \theta \mid y)$$
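A tiny discrete example (all numbers made up) makes this concrete: drawing θ from its posterior and then pushing it through the prior transition p(z ∣ θ) does not recover the joint posterior p(z, θ ∣ y):

```julia
# Toy discrete model with made-up numbers: θ ∈ {1, 2}, z ∈ {1, 2}, y observed.
pθ = [0.5, 0.5]              # prior p(θ)
pz_θ = [0.9 0.1; 0.2 0.8]    # p(z | θ): rows index θ, columns index z
py_θz = [0.7 0.3; 0.4 0.6]   # p(y = 1 | θ, z)

# Exact joint posterior p(θ, z | y = 1), computed by normalizing the joint.
joint = [pθ[t] * pz_θ[t, z] * py_θz[t, z] for t in 1:2, z in 1:2]
post = joint ./ sum(joint)

# Naive scheme: θ from its posterior, then z forward-sampled from p(z | θ).
pθ_y = vec(sum(post, dims=2))                    # p(θ | y = 1)
naive = [pθ_y[t] * pz_θ[t, z] for t in 1:2, z in 1:2]

post ≈ naive  # false: forward-sampling z given posterior θ is biased
```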

This means that we need three different modes of operation:

  1. Sampling from prior: $\theta \sim p(\theta )$ then $z \sim p(z \mid \theta)$
  2. Sampling from marginalized: $y \sim p(y \mid \theta)$
  3. Sampling from posterior: $\theta \sim p(\theta \mid y)$ then $z \sim p(z \mid \theta, y)$ (forward-backward alg)

But currently we don't really have a good way of knowing when to use (1) or (3). This is relevant to the discussion in #589 regarding how to specify whether we are performing "post-inference" analysis or performing inference.

yebai commented May 6, 2024

My view is that the latents shouldn't be handled by Turing's inference engine by default (unless it is a Turing submodel). Instead, they should be inferred by manually specified external algorithms.

gdalle commented May 6, 2024

I'm unsubscribing since I can't help right now, ping me if there is some HMM-related stuff coming up as a test case!

@yebai changed the title from "[Proof of concept] Handling latent parameters" to "Handling latent parameters" on May 8, 2024
torfjelde commented May 9, 2024

As written above, if we have a generative model of the form

θ ~ p(θ)
z ~ p(z ∣ θ)
y ~ p(y ∣ θ, z)

and we want to perform the following:

  1. During inference, we want to marginalize out $z$, to get $p(y \mid \theta)$.
  2. During sampling from the prior, we want to just sample according to the above model.
  3. During posterior sampling, we have samples $\theta \sim p(\theta \mid y)$ from (1), but we then need to sample $z \sim p(z \mid \theta, y)$ and then (if we're predicting) sample $y \sim p(y \mid \theta, z)$.

The first two are "easy" (both are achieved in the current PR), while the third is not so trivial. The problem with the third one is as follows:

  • We need access to the observations $y$ in the model, which means that we're going to hit an observe rather than assume.
  • But if we hit observe, we don't have access to the varname of the observation (y in this case), and so we don't know which varname to use for the latent variable. In (2) (i.e. prior sampling), this is not an issue because we hit assume, which in turn means we have access to its varname => can just use this to construct the varname of the latent.
  • Even if we address the previous issue, we also need a way to distinguish whether we are sampling from the prior or the posterior, which is not currently possible (in fact, this is a nice thing about the current impl of Turing.predict; it just samples from the prior model, i.e. without conditioning on the observations, conditioned on posterior samples).

We could fix this by also adding varname as an argument to the observe pipeline, but this does mean:

  1. It'll be very breaking.
  2. For literals, e.g. 1 ~ Normal(), we need a placeholder (probably just nothing), which is not ideal.

yebai commented May 10, 2024

We could fix this by also adding varname as an argument to the observe pipeline, but this does mean:

I don't think we can handle inference for z in the general setting efficiently; we should probably leave this to the user, but perhaps provide a user-friendly interface so users can provide their samples for z using some external algorithm (e.g. fix/unfix).

@torfjelde

I don't think we can handle inference for z in the general setting efficiently; we should probably leave this to the user, but perhaps provide a user-friendly interface so users can provide their samples for z using some external algorithm (e.g. fix/unfix).

Two things:

  1. It's already possible to "easily" do this by hand by simply putting z := ... in an if-statement that is only executed when running something like generated quantities.
  2. If we don't support it for something like HMM, it does make it very easy for the user to do "the wrong thing" by accident (like I did myself) 😕

yebai commented May 10, 2024

If we don't support it for something like HMM, it does make it very easy for the user to do "the wrong thing" by accident (like I did myself)

It feels like this belongs in the docs!
