OpenKalman
KalmanFilter.hpp
Go to the documentation of this file.
1 /* This file is part of OpenKalman, a header-only C++ library for
2  * Kalman filters and other recursive filters.
3  *
4  * Copyright (c) 2017-2021 Christopher Lee Ogden <ogden@gatech.edu>
5  *
6  * This Source Code Form is subject to the terms of the Mozilla Public
7  * License, v. 2.0. If a copy of the MPL was not distributed with this
8  * file, You can obtain one at https://mozilla.org/MPL/2.0/.
9  */
10 
16 #ifndef OPENKALMAN_KALMANFILTER_HPP
17 #define OPENKALMAN_KALMANFILTER_HPP
18 
19 namespace OpenKalman
20 {
/**
 * \brief A Kalman filter, using one or more statistical transforms.
 * \details Only declared here; the usable definitions are the partial specializations for one transform
 * and for a separate process/measurement transform pair.
 * \tparam Transform The transform type(s) with which the filter is parameterized.
 */
template<typename... Transform>
struct KalmanFilter;
27 
28 
33  template<typename Transform>
34  struct KalmanFilter<Transform>
35  {
36  protected:
37  Transform transform;
38 
39  template<typename XDistribution, typename YDistribution, typename CrossCovariance, typename Measurement>
40  static auto
41  update_step(const XDistribution& Nx, const YDistribution& Ny, const CrossCovariance& P_xy, const Measurement& z)
42  {
43  static_assert(gaussian_distribution<XDistribution>);
44  static_assert(gaussian_distribution<YDistribution>);
45  static_assert(typed_matrix<CrossCovariance>);
46  static_assert(typed_matrix<Measurement> and vector<Measurement>);
47  static_assert(coordinates::compares_with<vector_space_descriptor_of_t<Measurement, 0>,
48  typename DistributionTraits<YDistribution>::StaticDescriptor>);
49  static_assert(coordinates::compares_with<vector_space_descriptor_of_t<CrossCovariance, 0>,
50  typename DistributionTraits<XDistribution>::StaticDescriptor>);
51  static_assert(coordinates::compares_with<vector_space_descriptor_of_t<CrossCovariance, 1>,
52  typename DistributionTraits<YDistribution>::StaticDescriptor>);
53 
54  const auto y = mean_of(Ny);
55  const auto P_yy = covariance_of(Ny);
56  const auto K_adj = solve(adjoint(P_yy), adjoint(P_xy));
57  const auto K = adjoint(K_adj);
58  // K * P_yy == P_xy, or K == P_xy * inverse(P_yy)
59  auto out_x_mean = sum(mean_of(Nx), contract(K, (Mean {z} - y)));
60  using re = typename DistributionTraits<XDistribution>::random_number_engine;
61 
62  if constexpr (cholesky_form<YDistribution>)
63  {
64  // P_xy * adjoint(K) == K * P_yy * adjoint(K) == K * square_root(P_yy) * adjoint(K * square_root(P_yy))
65  // == square(LQ(K * square_root(P_yy)))
66  auto out_x_cov = covariance_of(Nx) - square(LQ_decomposition(K * square_root(P_yy)));
67  return make_GaussianDistribution<re>(std::move(out_x_mean), std::move(out_x_cov));
68  }
69  else
70  {
71  // K == P_xy * inverse(P_yy), so
72  // P_xy * adjoint(K) == P_xy * adjoint(inverse(P_yy)) * adjoint(P_xy)
73  auto out_x_cov = covariance_of(Nx) - Covariance(P_xy * K_adj);
74  return make_GaussianDistribution<re>(std::move(out_x_mean), std::move(out_x_cov));
75  }
76  }
77 
78  public:
79  explicit KalmanFilter(const Transform& trans)
80  : transform(trans)
81  {}
82 
83  /*
84  * Predict the next state based on the process model.
85  */
86  template<typename...ProcessTransformArguments>
87  auto
88  predict(const ProcessTransformArguments&...args)
89  {
90  return transform(args...);
91  }
92 
93  /*
94  * Update the state based on a measurement and the measurement model.
95  */
96  template<typename Measurement, typename State, typename...MeasurementTransformArguments>
97  auto
98  update(const Measurement& z, const State& x, const MeasurementTransformArguments&...args)
99  {
100  const auto [y, P_xy] = transform.transform_with_cross_covariance(x, args...);
101  return update_step(x, y, P_xy, z);
102  }
103 
104  /*
105  * Perform a complete predict-update cycle. Predict the next state based on the process model, and then
106  * update the state based on a measurement and the measurement model.
107  */
108  template<
109  typename Measurement,
110  typename State,
111  typename...ProcessTransformArgs,
112  typename...MeasurementTransformArgs>
113  auto
114  operator()(
115  const Measurement& z,
116  const State& x,
117  const std::tuple<ProcessTransformArgs...>& proc_args = std::tuple {},
118  const std::tuple<MeasurementTransformArgs...>& meas_args = std::tuple {})
119  {
120  const auto&& [y_, P_xy, y] = std::apply(
121  transform.transform_with_cross_covariance,
122  std::tuple_cat(std::forward_as_tuple(x), proc_args, meas_args));
123  return update_step(y, y_, P_xy, z);
124  }
125 
126  };
127 
128 
134  template<typename ProcessTransform, typename MeasurementTransform>
135  struct KalmanFilter<ProcessTransform, MeasurementTransform> : KalmanFilter<ProcessTransform>
136  {
137  protected:
139  using Base::transform;
140  MeasurementTransform measurement_transform;
141 
142  using Base::update_step;
143 
144  public:
145  KalmanFilter(const ProcessTransform& p_transform, const MeasurementTransform& m_transform)
146  : Base(p_transform), measurement_transform(m_transform)
147  {}
148 
149  /*
150  * Update the state based on a measurement and the measurement model.
151  */
152  template<typename Measurement, typename State, typename...MeasurementTransformArguments>
153  auto
154  update(const Measurement& z, const State& x, const MeasurementTransformArguments&...args)
155  {
156  const auto [y, P_xy] = measurement_transform.transform_with_cross_covariance(x, args...);
157  return update_step(x, y, P_xy, z);
158  }
159 
160  /*
161  * Perform a complete predict-update cycle. Predict the next state based on the process model, and then
162  * update the state based on a measurement and the measurement model.
163  */
164  template<
165  typename Measurement,
166  typename State,
167  typename...ProcessTransformArgs,
168  typename...MeasurementTransformArgs>
169  auto
170  operator()(
171  const Measurement& z,
172  const State& x,
173  const std::tuple<ProcessTransformArgs...>& proc_args = std::tuple {},
174  const std::tuple<MeasurementTransformArgs...>& meas_args = std::tuple {})
175  {
176  const auto y = std::apply(transform, std::tuple_cat(std::forward_as_tuple(x), proc_args));
177  const auto [y_, P_xy] = std::apply(
178  measurement_transform.transform_with_cross_covariance,
179  std::tuple_cat(std::forward_as_tuple(y), meas_args));
180  return update_step(y, y_, P_xy, z);
181  }
182 
183  };
184 
185 
187  // Deduction guides //
189 
190  template<typename P>
192 
193  template<typename P, typename M>
194  KalmanFilter(P&&, M&&) -> KalmanFilter<P, M>;
195 
196 }
197 
198 #endif //OPENKALMAN_KALMANFILTER_HPP
decltype(auto) constexpr contract(A &&a, B &&b)
Matrix multiplication of A * B.
Definition: contract.hpp:54
A set of one or more column vectors, each representing a statistical mean.
Definition: forward-class-declarations.hpp:477
A Kalman filter, using one or more statistical transforms.
Definition: KalmanFilter.hpp:26
The root namespace for OpenKalman.
Definition: basics.hpp:34
constexpr auto solve(A &&a, B &&b)
Solve the equation AX = B for X, which may or may not be a unique solution.
Definition: solve.hpp:87
decltype(auto) constexpr adjoint(Arg &&arg)
Take the adjoint of a matrix.
Definition: adjoint.hpp:33
typename vector_space_descriptor_of< T, N >::type vector_space_descriptor_of_t
helper template for vector_space_descriptor_of.
Definition: vector_space_descriptor_of.hpp:56
decltype(auto) constexpr sum(Ts &&...ts)
Element-by-element sum of one or more objects.
Definition: sum.hpp:112
Covariance(M &&) -> Covariance< Dimensions< index_dimension_of_v< M, 0 >>, passable_t< M >>
Deduce Covariance type from a covariance_nestable.
decltype(auto) constexpr LQ_decomposition(A &&a)
Perform an LQ decomposition of matrix A = [L,0]Q, where L is a lower-triangular matrix and Q is orthogonal...
Definition: LQ_decomposition.hpp:33
constexpr detail::update_adaptor update
a std::ranges::range_adaptor_closure associated with update_view.
Definition: update.hpp:403