Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Imbedded Laplace Approximation #3097

Draft
wants to merge 143 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
143 commits
Select commit Hold shift + click to select a range
3ec8ab6
Port relevant laplace files and fix some unit tests.
charlesm93 Mar 4, 2020
c353fa4
Include laplace.hpp in header files.
charlesm93 Mar 4, 2020
1e19bc7
Merge branch 'try-laplace_approximation' of https://github.com/stan-d…
charlesm93 Mar 29, 2020
ffec2a6
Add rng function for bernoulli logit function.
charlesm93 Mar 29, 2020
0e15cd9
Template x argument.
charlesm93 Mar 29, 2020
f054278
update name laplace_marginal_poisson_log
charlesm93 Jul 17, 2020
a185079
Merge branch 'develop' into try-laplace_approximation2
charlesm93 Jul 17, 2020
d90a5ed
update name of laplace_rng.
charlesm93 Jul 17, 2020
a30f162
rename laplace bernoulli functions.
charlesm93 Jul 17, 2020
ac4de59
Update file names and header includes.
charlesm93 Jul 17, 2020
f66d582
Update signature for laplace_marginal_poisson_log.
charlesm93 Jul 25, 2020
dc00d6f
update signature of laplace_bernoulli_logit.
charlesm93 Jul 25, 2020
2c73a61
update reference for differentiation.
charlesm93 Jan 13, 2021
dab8b73
log likelihood for student t.
charlesm93 Jan 14, 2021
8795fed
logp for negative binomial.
charlesm93 Jan 14, 2021
61ee5b4
Features for neg binomial likelihood.
charlesm93 Jan 15, 2021
8dee734
Finish analytical likelihood diff for neg binomial.
charlesm93 Jan 27, 2021
fcf33be
Prototype differentiation wrt likelihood hyperparameters.
charlesm93 Jan 27, 2021
098df42
progress towards marginal diff.
charlesm93 Jan 27, 2021
4bf51ee
more unit tests.
charlesm93 Jan 30, 2021
9528588
Fix finite diff benchmark.
charlesm93 Jan 30, 2021
8bb9c78
Create wrapper for neg binomial likelihood.
charlesm93 Jan 30, 2021
4936428
update poisson_log likelihood.
charlesm93 Jan 31, 2021
12b6e37
update bernoulli.
charlesm93 Jan 31, 2021
3935877
Steps torwards higher-order autodiff.
charlesm93 Feb 19, 2021
8326baa
prototype likelihood using user-specified likelihood.
charlesm93 Feb 23, 2021
2ab0cb4
add test for autodiffed likelihood.
charlesm93 Feb 23, 2021
7831316
block diag hessian computation.
charlesm93 Feb 25, 2021
153cb8c
autodiff for non-diag hessian and eta.
charlesm93 Mar 15, 2021
e95375b
prototype eta differentiation.
charlesm93 Mar 16, 2021
d66684c
update rng functions for new interface.
charlesm93 Mar 23, 2021
1fe30d1
update bernoulli analytical lk.
charlesm93 Mar 23, 2021
f8c6ac0
clean up skim test.
charlesm93 Mar 24, 2021
d4b7c21
Extend gp motorcycle test.
charlesm93 Mar 30, 2021
f7831d1
wrapper for general laplace_marginal_pdf
charlesm93 Mar 30, 2021
23b20ae
lpdf and lpmf wrapper for general laplace approximation.
charlesm93 Mar 31, 2021
3a9df9e
update rng functions.
charlesm93 Apr 1, 2021
77e343f
Update all (relevant) unit tests and make sure they run.
charlesm93 Apr 1, 2021
fa3624c
Merge branch 'develop' of https://github.com/stan-dev/math into develop
charlesm93 Apr 1, 2021
cd272c8
Edit files to insure all unit tests still run.
charlesm93 Apr 1, 2021
a554cfd
Add inline keyword for internal functions.
charlesm93 Apr 2, 2021
0d29780
Temporary signature in agreement with parser.
charlesm93 Apr 9, 2021
d110e18
Fix bugs to run function from Stan.
charlesm93 Apr 16, 2021
0966258
prototype linesearch step.
charlesm93 Apr 21, 2021
ab69c98
prototype line search.
charlesm93 Apr 21, 2021
5c4b2e4
Update convergence criterion for linesearch.
charlesm93 Apr 21, 2021
8d45513
simplify the linesearch.
charlesm93 Apr 21, 2021
fcbf075
linesearch: check for non-finite values.
charlesm93 Apr 21, 2021
150f713
prototype Jarnos newton solver.
charlesm93 Apr 22, 2021
be9f880
prototype treatment of diagonal covariance.
charlesm93 Apr 28, 2021
47e2827
attempt at debug diagonal K case...
charlesm93 Apr 28, 2021
6bf438f
return int term.
charlesm93 Apr 28, 2021
9d9cffe
Merge remote-tracking branch 'origin/develop' into try-laplace_student
SteveBronder Sep 27, 2021
3985450
move some files around and cleanup tests. Comment out tests that do n…
SteveBronder Sep 28, 2021
77ad12a
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Oct 1, 2021
c1cf1b9
cleanup more of neg_binomial_2 tests
SteveBronder Oct 1, 2021
3f310e7
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Oct 1, 2021
36efe15
remove laplace vari to instead use reverse pass callback
SteveBronder Oct 4, 2021
bac6b0d
update poisson likelihood
SteveBronder Oct 4, 2021
769450f
update laplace ints
SteveBronder Oct 4, 2021
103c200
make
SteveBronder Oct 21, 2021
c2b2ae7
revert vari for laplace_density_vari
SteveBronder Dec 2, 2021
92e031f
Merge branch 'experimental/laplace' of https://github.com/stan-dev/ma…
charlesm93 Dec 2, 2021
805ae7f
formalize unit test for laplacr_marginal_poisson_log_pmf
charlesm93 Dec 2, 2021
2bcd90b
clean unit test for laplace_marginal_bernoulli_logit_lpmf.
charlesm93 Dec 3, 2021
d1d328f
start building unit test for laplace_marginal_lpdf_test.cpp
charlesm93 Dec 3, 2021
f73aec2
Add bernoulli_logit test for laplace_marginal_lpmf.
charlesm93 Dec 4, 2021
1592328
Add (incomplete) motorcycle GP test for laplacr_marginal_lpdf.
charlesm93 Dec 4, 2021
bc74943
rewrite laplace_marginal_poisson_log_lpmf to use the general likeliho…
charlesm93 Dec 5, 2021
6bf8246
(This time for real) Make laplace_marginal_log_lpmf use general likel…
charlesm93 Dec 5, 2021
b2e4b2f
make laplace_marginal_bernoulli_logit_lpmf use general likelihood fun…
charlesm93 Dec 6, 2021
031a81b
Delete research tests which are / will not be unit tests.
charlesm93 Dec 6, 2021
ae138fe
refactor rng test for laplace_poisson_log.
charlesm93 Jan 13, 2022
5231604
Benchmark laplace_poisson_log_rng against multi_normal_rng.
charlesm93 Jan 16, 2022
9e3f9ae
refactor rng poisson log to use full autodiff and add tests for expos…
charlesm93 Jan 16, 2022
f218ba0
Merge remote-tracking branch 'origin/develop' into experimental/laplace
SteveBronder Jan 27, 2022
0fc47f5
update tolerance for nearness tests
SteveBronder Jan 27, 2022
7a40142
update tolerance for nearness tests
SteveBronder Jan 27, 2022
a35ff32
cleanup laplace but failing address sanitizer
SteveBronder Jan 28, 2022
cfd2199
update off by one error
SteveBronder Jan 28, 2022
269b9af
Merge branch 'experimental/laplace' of https://github.com/stan-dev/ma…
charlesm93 Jan 28, 2022
d7e507f
clang format
SteveBronder Jan 28, 2022
c78c7d1
Merge branch 'experimental/laplace' of https://github.com/stan-dev/ma…
charlesm93 Jan 28, 2022
b1d8312
update to remove some extra functions and just call them inline
SteveBronder Jan 28, 2022
ce75b03
Merge branch 'experimental/laplace' of https://github.com/stan-dev/ma…
charlesm93 Jan 28, 2022
c7499b3
local changes for rng tests.
charlesm93 Jan 28, 2022
396278c
remove extra functions and have hessian_times_vector return a vector …
SteveBronder Jan 28, 2022
ef7d2a4
Merge branch 'experimental/laplace' of github.com:stan-dev/math into …
SteveBronder Jan 28, 2022
c6f222a
cleanup
SteveBronder Jan 29, 2022
fdcf8ca
clang format
SteveBronder Jan 29, 2022
a009b8f
remove the in/out parameters from hessian_block_diag and diff
SteveBronder Feb 3, 2022
ae6511e
removes in/out parameters from laplace_marginal_density_est
SteveBronder Feb 3, 2022
41a3d23
small fix in laplace_marginal eta adjoint
SteveBronder Feb 3, 2022
1f6199a
small cleanup
SteveBronder Feb 3, 2022
adcddca
swap x and phi to covariance functor
SteveBronder Feb 3, 2022
3b60f23
start working on variadic version of laplace
SteveBronder Feb 4, 2022
26c6a14
update laplace so it takes in variadic arguments for known likelihoods
SteveBronder Feb 5, 2022
03dd6e0
clang format
SteveBronder Feb 5, 2022
265991c
update tests
SteveBronder Feb 7, 2022
6b97f1f
Removes diagonal_covariance option and code. Also removes do_line_search
SteveBronder Feb 18, 2022
9c5e5ee
Merge remote-tracking branch 'origin/develop' into experimental/laplace
SteveBronder Mar 7, 2022
de38e18
add _tol_ versions of functions for laplace
SteveBronder Mar 9, 2022
31ab267
update with forwards
SteveBronder Mar 9, 2022
25e3ae8
update templates
SteveBronder Mar 9, 2022
c5d814c
update tol params
SteveBronder Mar 9, 2022
efb4662
default value for propto
SteveBronder Mar 9, 2022
f1a05d9
update
SteveBronder Mar 10, 2022
edb2e0d
update
SteveBronder Mar 10, 2022
8900bdb
update names for poisson_log with extra term
SteveBronder Apr 20, 2022
5d685b8
Merge remote-tracking branch 'origin/develop' into experimental/laplace
SteveBronder Apr 29, 2022
ea0e9be
Merge branch 'experimental/laplace' of https://github.com/stan-dev/ma…
charlesm93 May 18, 2022
cfa518a
update signature for rng test.
charlesm93 May 18, 2022
90d50cc
Make minimum hessian_block_size = 1, and treat this as the special di…
charlesm93 May 18, 2022
7bafe13
change TRUE to true and FALSE to false
SteveBronder May 20, 2022
04b17c5
update to develop
SteveBronder May 31, 2024
e76c980
Merge remote-tracking branch 'origin/develop' into experimental/laplace
SteveBronder Jul 22, 2024
6ed072f
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Jul 22, 2024
a33aba4
move laplace files around. Get expect ad to work for laplace
SteveBronder Jul 22, 2024
902cfff
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Jul 22, 2024
ae91da6
update
SteveBronder Jul 25, 2024
051c389
fix header includes
SteveBronder Jul 25, 2024
74d2ea2
update
SteveBronder Jul 25, 2024
4c0c815
update laplace_likelihood
SteveBronder Jul 26, 2024
a8a8257
update
SteveBronder Jul 26, 2024
e21b654
adds variadic args to laplace_likelihood member functions
SteveBronder Jul 26, 2024
784de61
remove laplace_likelihood struct and now just use a namespace
SteveBronder Jul 29, 2024
fa1dad8
cleanup
SteveBronder Jul 30, 2024
1ea25be
cleanup
SteveBronder Jul 30, 2024
51106c3
update
SteveBronder Jul 30, 2024
0e3eb28
clang-format
SteveBronder Jul 30, 2024
d8f95ba
update docs
SteveBronder Aug 5, 2024
f4048ea
Merge commit 'f120c5e8e86c473f64ab222613e173acfa2648de' into HEAD
yashikno Aug 6, 2024
1f05474
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Aug 6, 2024
a78d42c
test headers and cpplint
SteveBronder Aug 6, 2024
6301e37
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Aug 6, 2024
a8be255
update
SteveBronder Aug 6, 2024
846cbbc
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Aug 6, 2024
dceaf5b
update test names
SteveBronder Aug 6, 2024
4dc95e9
update docs
SteveBronder Aug 6, 2024
c2b9ed2
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Aug 6, 2024
0b1cec2
force c++17
SteveBronder Aug 6, 2024
4bb62ea
remove double fvar grad in compute_s2
SteveBronder Aug 15, 2024
746a6f2
remove double fvar grad in compute_s2
SteveBronder Aug 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 2 additions & 25 deletions make/compiler_flags
Original file line number Diff line number Diff line change
Expand Up @@ -120,32 +120,9 @@ INC_GTEST ?= -I $(GTEST)/include -I $(GTEST)
CPPFLAGS_BOOST ?= -DBOOST_DISABLE_ASSERTS
CPPFLAGS_SUNDIALS ?= -DNO_FPRINTF_OUTPUT $(CPPFLAGS_OPTIM_SUNDIALS) $(CXXFLAGS_FLTO_SUNDIALS)
#CPPFLAGS_GTEST ?=
STAN_HAS_CXX17 ?= false
ifeq ($(CXX_TYPE), gcc)
GCC_GE_73 := $(shell [ $(CXX_MAJOR) -gt 7 -o \( $(CXX_MAJOR) -eq 7 -a $(CXX_MINOR) -ge 1 \) ] && echo true)
ifeq ($(GCC_GE_73),true)
STAN_HAS_CXX17 := true
endif
else ifeq ($(CXX_TYPE), clang)
CLANG_GE_5 := $(shell [ $(CXX_MAJOR) -gt 5 -o \( $(CXX_MAJOR) -eq 5 -a $(CXX_MINOR) -ge 0 \) ] && echo true)
ifeq ($(CLANG_GE_5),true)
STAN_HAS_CXX17 := true
endif
else ifeq ($(CXX_TYPE), mingw32-gcc)
MINGW_GE_50 := $(shell [ $(CXX_MAJOR) -gt 5 -o \( $(CXX_MAJOR) -eq 5 -a $(CXX_MINOR) -ge 0 \) ] && echo true)
ifeq ($(MINGW_GE_50),true)
STAN_HAS_CXX17 := true
endif
endif

ifeq ($(STAN_HAS_CXX17), true)
CXXFLAGS_LANG ?= -std=c++17
CXXFLAGS_STANDARD ?= c++17
else
$(warning "Stan cannot detect if your compiler has the C++17 standard. If it does, please set STAN_HAS_CXX17=true in your make/local file. C++17 support is mandatory in the next release of Stan. Defaulting to C++14")
CXXFLAGS_LANG ?= -std=c++1y
CXXFLAGS_STANDARD ?= c++1y
endif
CXXFLAGS_LANG ?= -std=c++17
CXXFLAGS_STANDARD ?= c++17
#CXXFLAGS_BOOST ?=
CXXFLAGS_SUNDIALS ?= -pipe $(CXXFLAGS_OPTIM_SUNDIALS) $(CPPFLAGS_FLTO_SUNDIALS)
#CXXFLAGS_GTEST
Expand Down
10 changes: 5 additions & 5 deletions stan/math/mix.hpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
#ifndef STAN_MATH_MIX_HPP
#define STAN_MATH_MIX_HPP

#include <stan/math/mix/meta.hpp>
#include <stan/math/mix/fun.hpp>
#include <stan/math/mix/functor.hpp>

#include <stan/math/fwd/constraint.hpp>
#include <stan/math/fwd/core.hpp>
#include <stan/math/fwd/meta.hpp>
Expand All @@ -26,4 +21,9 @@

#include <stan/math/prim.hpp>

#include <stan/math/mix/meta.hpp>
#include <stan/math/mix/fun.hpp>
#include <stan/math/mix/functor.hpp>
#include <stan/math/mix/prob.hpp>

#endif
7 changes: 5 additions & 2 deletions stan/math/mix/functor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
#define STAN_MATH_MIX_FUNCTOR_HPP

#include <stan/math/mix/functor/derivative.hpp>
#include <stan/math/mix/functor/finite_diff_grad_hessian.hpp>
#include <stan/math/mix/functor/finite_diff_grad_hessian_auto.hpp>
#include <stan/math/mix/functor/finite_diff_grad_hessian.hpp>
#include <stan/math/mix/functor/grad_hessian.hpp>
#include <stan/math/mix/functor/grad_tr_mat_times_hessian.hpp>
#include <stan/math/mix/functor/gradient_dot_vector.hpp>
#include <stan/math/mix/functor/hessian.hpp>
#include <stan/math/mix/functor/laplace_base_rng.hpp>
#include <stan/math/mix/functor/laplace_likelihood.hpp>
#include <stan/math/mix/functor/laplace_marginal_density.hpp>
#include <stan/math/mix/functor/hessian_block_diag.hpp>
#include <stan/math/mix/functor/hessian_times_vector.hpp>
#include <stan/math/mix/functor/partial_derivative.hpp>

#endif
56 changes: 56 additions & 0 deletions stan/math/mix/functor/hessian_block_diag.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#ifndef STAN_MATH_MIX_FUNCTOR_HESSIAN_BLOCK_DIAG_HPP
#define STAN_MATH_MIX_FUNCTOR_HESSIAN_BLOCK_DIAG_HPP

#include <stan/math/mix/functor/hessian_times_vector.hpp>
#include <Eigen/Sparse>

namespace stan {
namespace math {

/**
* Returns a block diagonal Hessian by computing the relevant directional
* derivatives and storing them in a matrix.
* For m the size of each block, the operations const m calls to
* hessian_times_vector, that is m forward sweeps and m reverse sweeps.
* @tparam F
* @tparam Eta
* @tparam Args
* @param f
* @param x
* @param eta
* @param hessian_block_size
* @param args
*/
template <typename F, typename Eta, typename... Args,
require_eigen_t<Eta>* = nullptr>
inline Eigen::SparseMatrix<double> hessian_block_diag(
F&& f, const Eigen::VectorXd& x, const Eta& eta,
const Eigen::Index hessian_block_size, Args&&... args) {
using Eigen::MatrixXd;
using Eigen::VectorXd;

const Eigen::Index x_size = x.size();
Eigen::SparseMatrix<double> H(x_size, x_size);
H.reserve(Eigen::VectorXi::Constant(x_size, hessian_block_size));
VectorXd v(x_size);
Eigen::Index n_blocks = x_size / hessian_block_size;
for (Eigen::Index i = 0; i < hessian_block_size; ++i) {
v.setZero();
for (Eigen::Index j = i; j < x_size; j += hessian_block_size) {
v(j) = 1;
}
VectorXd Hv = hessian_times_vector(f, x, eta, v, args...);
for (int j = 0; j < n_blocks; ++j) {
for (int k = 0; k < hessian_block_size; ++k) {
H.insert(k + j * hessian_block_size, i + j * hessian_block_size)
= Hv(k + j * hessian_block_size);
}
}
}
return H;
}

} // namespace math
} // namespace stan

#endif
31 changes: 28 additions & 3 deletions stan/math/mix/functor/hessian_times_vector.hpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
#ifndef STAN_MATH_MIX_FUNCTOR_HESSIAN_TIMES_VECTOR_HPP
#define STAN_MATH_MIX_FUNCTOR_HESSIAN_TIMES_VECTOR_HPP

#include <stan/math/fwd/core.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/fwd/core.hpp>
#include <stan/math/rev/core.hpp>
#include <stdexcept>
#include <vector>
#include <stan/math/prim/meta.hpp>

namespace stan {
namespace math {
Expand Down Expand Up @@ -35,6 +34,7 @@ void hessian_times_vector(const F& f,
Hv(i) = x_var(i).adj();
}
}

template <typename T, typename F>
void hessian_times_vector(const F& f,
const Eigen::Matrix<T, Eigen::Dynamic, 1>& x,
Expand All @@ -47,6 +47,31 @@ void hessian_times_vector(const F& f,
Hv = H * v;
}

/**
* Overload Hessian_times_vector function, under stan/math/mix/functor
* to handle functions which take in arguments eta, delta, delta_int,
* and pstream.
*/
template <typename F, typename Eta, require_eigen_t<Eta>* = nullptr,
typename... Args>
inline Eigen::VectorXd hessian_times_vector(const F& f,
const Eigen::VectorXd& x,
const Eta& eta,
const Eigen::VectorXd& v,
Args&&... args) {
nested_rev_autodiff nested;
const Eigen::Index x_size = x.size();
Eigen::Matrix<var, Eigen::Dynamic, 1> x_var = x;
Eigen::Matrix<fvar<var>, Eigen::Dynamic, 1> x_fvar(x_size);
for (Eigen::Index i = 0; i < x_size; i++) {
x_fvar(i) = fvar<var>(x_var(i), v(i));
}
fvar<var> fx_fvar = f(x_fvar, eta, args...);
grad(fx_fvar.d_.vi_);
return x_var.adj();
}

} // namespace math
} // namespace stan

#endif
92 changes: 92 additions & 0 deletions stan/math/mix/functor/laplace_base_rng.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#ifndef STAN_MATH_MIX_FUNCTOR_LAPLACE_BASE_RNG_HPP
#define STAN_MATH_MIX_FUNCTOR_LAPLACE_BASE_RNG_HPP

#include <stan/math/mix/functor/laplace_marginal_density.hpp>
#include <stan/math/prim/prob/multi_normal_cholesky_rng.hpp>
#include <stan/math/prim/fun.hpp>

#include <Eigen/Sparse>

namespace stan {
namespace math {

/**
* In a latent gaussian model,
*
* theta ~ Normal(theta | 0, Sigma(phi, x))
* y ~ pi(y | theta, eta)
*
* returns a multivariate normal random variate sampled
* from the gaussian approximation of p(theta_pred | y, phi, x_pred).
* Note that while the data is observed at x, the new samples
* are drawn for covariates x_pred.
* To sample the "original" theta's, set x_pred = x.
* @tparam D
* @tparam LLArgs
* @tparam ThetaMatrix
* @tparam EtaMatrix
* @tparam CovarFun
* @tparam RNG
* @tparam TrainTuple
* @tparam PredTuple
* @tparam Args
* @param ll_fun
* @param ll_args
* @param covariance_function
* @param eta
* @param theta_0
* @param options
* @param train_tuple
* @param pred_tuple
* @param rng
* @param msgs
* @param args
*/
template <typename D, typename LLArgs, typename ThetaMatrix, typename EtaMatrix,
typename CovarFun, class RNG, typename TrainTuple, typename PredTuple,
typename... Args,
require_all_eigen_t<ThetaMatrix, EtaMatrix>* = nullptr>
inline Eigen::VectorXd laplace_base_rng(
D&& ll_fun, LLArgs&& ll_args, CovarFun&& covariance_function,
const ThetaMatrix& eta, const EtaMatrix& theta_0,
const laplace_options& options, TrainTuple&& train_tuple,
PredTuple&& pred_tuple, RNG& rng, std::ostream* msgs, Args&&... args) {
using Eigen::MatrixXd;
using Eigen::VectorXd;
auto args_dbl = std::make_tuple(to_ref(value_of(args))...);
auto eta_dbl = value_of(eta);
auto md_est = apply(
[&](auto&&... args_val) {
return laplace_marginal_density_est(
ll_fun, ll_args, covariance_function, eta_dbl, value_of(theta_0),
msgs, options, args_val...);
},
std::tuple_cat(std::forward<TrainTuple>(train_tuple), args_dbl));
// Modified R&W method
MatrixXd covariance_pred = apply(
[&covariance_function, &msgs](auto&&... args_val) {
return covariance_function(args_val..., msgs);
},
std::tuple_cat(std::forward<PredTuple>(pred_tuple), args_dbl));
VectorXd pred_mean = covariance_pred * md_est.l_grad.head(theta_0.rows());
if (options.solver == 1 || options.solver == 2) {
Eigen::MatrixXd V_dec = mdivide_left_tri<Eigen::Lower>(
md_est.L, md_est.W_r * covariance_pred);
Eigen::MatrixXd Sigma = covariance_pred - V_dec.transpose() * V_dec;
return multi_normal_rng(pred_mean, Sigma, rng);
} else {
Eigen::MatrixXd Sigma
= covariance_pred
- covariance_pred
* (md_est.W_r
- md_est.W_r
* md_est.LU.solve(md_est.covariance * md_est.W_r))
* covariance_pred;
return multi_normal_rng(pred_mean, Sigma, rng);
}
}

} // namespace math
} // namespace stan

#endif
Loading