OpenKalman
solve.hpp
Go to the documentation of this file.
1 /* This file is part of OpenKalman, a header-only C++ library for
2  * Kalman filters and other recursive filters.
3  *
4  * Copyright (c) 2022 Christopher Lee Ogden <ogden@gatech.edu>
5  *
6  * This Source Code Form is subject to the terms of the Mozilla Public
7  * License, v. 2.0. If a copy of the MPL was not distributed with this
8  * file, You can obtain one at https://mozilla.org/MPL/2.0/.
9  */
10 
16 #ifndef OPENKALMAN_SOLVE_HPP
17 #define OPENKALMAN_SOLVE_HPP
18 
19 namespace OpenKalman
20 {
21  namespace detail
22  {
23  template<typename A, typename B>
24  void solve_check_A_and_B_rows_match(const A& a, const B& b)
25  {
26  if (get_vector_space_descriptor<0>(a) != get_vector_space_descriptor<0>(b))
27  throw std::domain_error {"The rows of the two operands of the solve function must be the same, but instead "
28  "the first operand has " + std::to_string(get_index_dimension_of<0>(a)) + " rows and the second operand has " +
29  std::to_string(get_index_dimension_of<0>(b)) + " rows"};
30  }
31 
32 
33  template<typename A, typename B, typename Arg>
34  decltype(auto) wrap_solve_result_impl(Arg&& arg)
35  {
36  constexpr TriangleType tri = triangle_type_of_v<A, B>;
37  if constexpr (tri != TriangleType::any)
38  return make_triangular_matrix<tri>(std::forward<Arg>(arg));
39  else if constexpr (((constant_diagonal_matrix<A> and hermitian_matrix<B>) or (constant_diagonal_matrix<B> and hermitian_matrix<A>)))
40  return make_hermitian_matrix(std::forward<Arg>(arg));
41  else
42  return std::forward<Arg>(arg);
43  }
44 
45 
46  template<typename A, typename B, typename Arg>
47  decltype(auto) wrap_solve_result(Arg&& arg)
48  {
49  using V0 = vector_space_descriptor_of_t<A, 1>;
50  using V1 = vector_space_descriptor_of_t<B, 1>;
51  return internal::make_fixed_size_adapter<V0, V1>(wrap_solve_result_impl<A, B>(std::forward<Arg>(arg)));
52  }
53  } // namespace detail
54 
55 
70  #ifdef __cpp_concepts
71  template<bool must_be_unique = false, bool must_be_exact = false, typename A, typename B> requires
72  (not zero<A> or not zero<B> or not must_be_unique) and
73  (not zero<A> or not (constant_matrix<B> or constant_diagonal_matrix<B>) or zero<B> or not must_be_exact) and
74  (not constant_matrix<A> or not constant_diagonal_matrix<B> or has_dynamic_dimensions<A> or
75  (index_dimension_of_v<A, 0> <= index_dimension_of_v<A, 1> and index_dimension_of_v<B, 0> <= index_dimension_of_v<A, 1>) or
76  (index_dimension_of_v<A, 0> == 1 and index_dimension_of_v<B, 0> == 1) or not must_be_exact)
77  constexpr compatible_with_vector_space_descriptor_collection<std::tuple<vector_space_descriptor_of_t<A, 1>, vector_space_descriptor_of_t<B, 1>>> auto
78  #else
79  template<bool must_be_unique = false, bool must_be_exact = false, typename A, typename B, std::enable_if_t<
80  (not zero<A> or not zero<B> or not must_be_unique) and
81  (not zero<A> or not (constant_matrix<B> or constant_diagonal_matrix<B>) or zero<B> or not must_be_exact) and
82  (not constant_matrix<A> or not constant_diagonal_matrix<B> or has_dynamic_dimensions<A> or
83  (index_dimension_of_v<A, 0> <= index_dimension_of_v<A, 1> and index_dimension_of_v<B, 0> <= index_dimension_of_v<A, 1>) or
84  (index_dimension_of_v<A, 0> == 1 and index_dimension_of_v<B, 0> == 1) or not must_be_exact), int> = 0>
85  constexpr auto
86  #endif
87  solve(A&& a, B&& b)
88  {
89  static_assert(dynamic_dimension<A, 0> or dynamic_dimension<B, 0> or index_dimension_of_v<A, 0> == index_dimension_of_v<B, 0>,
90  "The rows of two operands of the solve function must be the same.");
91 
93 
94  if constexpr (zero<B>)
95  {
96  if constexpr (dynamic_dimension<A, 0> or dynamic_dimension<B, 0>) detail::solve_check_A_and_B_rows_match(a, b);
97 
98  if constexpr (must_be_unique and not constant_matrix<A> and not constant_diagonal_matrix<A>)
99  {
100  // \todo the predicate should be simpler, such as (a == to_native_matrix<A>(make_zero(a))), but this causes a failure in Eigen.
101  if (reduce([](auto c1, auto c2) { if (c1 == 0 and c2 == 0) return 0; else return 1; }, a) == 0)
102  throw std::runtime_error {"solve function requires a unique solution, "
103  "but because operands A and B are both zero matrices, result X may take on any value"};
104  else return make_zero<B>(get_vector_space_descriptor<1>(a), get_vector_space_descriptor<1>(b));
105  }
106  else return make_zero<B>(get_vector_space_descriptor<1>(a), get_vector_space_descriptor<1>(b));
107  }
108  else if constexpr (zero<A>) //< This will be a non-exact solution unless b is zero.
109  {
110  if constexpr (dynamic_dimension<A, 0> or dynamic_dimension<B, 0>) detail::solve_check_A_and_B_rows_match(a, b);
111  return make_zero<B>(get_vector_space_descriptor<1>(a), get_vector_space_descriptor<1>(b));
112  }
113  else if constexpr (index_dimension_of_v<A, 1> == 1 and (index_dimension_of_v<A, 0> == 1 or index_dimension_of_v<B, 0> == 1))
114  {
115  if constexpr (dynamic_dimension<A, 0> or dynamic_dimension<B, 0>) detail::solve_check_A_and_B_rows_match(a, b);
116 
118 
119  if constexpr (identity_matrix<A>)
120  return internal::make_fixed_size_adapter<coordinates::Axis, V1>(std::forward<B>(b));
121  else
122  return internal::make_fixed_size_adapter<coordinates::Axis, V1>(scalar_quotient(std::forward<B>(b), internal::get_singular_component(std::forward<A>(a))));
123  }
124  else if constexpr (constant_diagonal_matrix<A> and (square_shaped<A> or (not dynamic_dimension<A, 1> and index_dimension_of_v<A, 1> == index_dimension_of_v<B, 0>)))
125  {
126  if constexpr (dynamic_dimension<A, 0> or dynamic_dimension<B, 0>) detail::solve_check_A_and_B_rows_match(a, b);
127 
128  using V0 = decltype(internal::best_vector_space_descriptor(get_vector_space_descriptor<0>(b), get_vector_space_descriptor<0>(a), get_vector_space_descriptor<1>(a)));
130 
131  if constexpr (identity_matrix<A> and square_shaped<A>)
132  return internal::make_fixed_size_adapter<V0, V1>(std::forward<B>(b));
133  else
134  return internal::make_fixed_size_adapter<V0, V1>(scalar_quotient(std::forward<B>(b), constant_diagonal_coefficient{std::forward<A>(a)}));
135  }
136  else if constexpr (constant_matrix<A> and (constant_matrix<B>))
137  {
138  if constexpr (dynamic_dimension<A, 0> or dynamic_dimension<B, 0>) detail::solve_check_A_and_B_rows_match(a, b);
139 
140  return make_constant<B>(
141  constant_coefficient{b} / (values::cast_to<scalar_type_of_t<A>>(get_index_dimension_of<1>(a)) * constant_coefficient{a}),
142  get_vector_space_descriptor<1>(a), get_vector_space_descriptor<1>(b));
143  }
144  else if constexpr (constant_matrix<A> and (index_dimension_of_v<A, 0> == 1 or index_dimension_of_v<B, 0> == 1 or
145  (not must_be_exact and (not must_be_unique or (not has_dynamic_dimensions<A> and index_dimension_of_v<A, 0> >= index_dimension_of_v<A, 1>)))))
146  {
147  if constexpr (dynamic_dimension<A, 0> or dynamic_dimension<B, 0>) detail::solve_check_A_and_B_rows_match(a, b);
148 
149  return detail::wrap_solve_result<A, B>(
150  scalar_quotient(std::forward<B>(b), values::cast_to<scalar_type_of_t<A>>(get_index_dimension_of<1>(a)) * constant_coefficient{a}));
151  }
152  else if constexpr (diagonal_matrix<A> and square_shaped<A>)
153  {
154  auto op = [](auto&& b_elem, auto&& a_elem) {
155  if (a_elem == 0)
156  {
157  if constexpr (not std::numeric_limits<scalar_type_of_t<B>>::has_infinity) throw std::logic_error {
158  "In solve function, an element should be infinite, but the scalar type does not have infinite values"};
159  else return std::numeric_limits<scalar_type_of_t<B>>::infinity();
160  }
161  else
162  {
163  return std::forward<decltype(b_elem)>(b_elem) / static_cast<scalar_type_of_t<B>>(std::forward<decltype(a_elem)>(a_elem));
164  }
165  };
166 
167  return detail::wrap_solve_result<A, B>(
168  n_ary_operation(all_vector_space_descriptors(b), std::move(op), std::forward<B>(b), diagonal_of(std::forward<A>(a))));
169  }
170  else if constexpr (interface::solve_defined_for<A, must_be_unique, must_be_exact, A, B>)
171  {
172  return detail::wrap_solve_result<A, B>(
173  Interface::template solve<must_be_unique, must_be_exact>(std::forward<A>(a), std::forward<B>(b)));
174  }
175  else
176  {
177  return detail::wrap_solve_result<A, B>(
178  Interface::template solve<must_be_unique, must_be_exact>(std::forward<A>(a), to_native_matrix<A>(std::forward<B>(b))));
179  }
180  }
181 
182 
183 } // namespace OpenKalman
184 
185 #endif //OPENKALMAN_SOLVE_HPP
constexpr auto n_ary_operation(const std::tuple< Ds... > &d_tup, Operation &&operation, Args &&...args)
Perform a component-wise n-ary operation, using broadcasting to match the size of a pattern matrix...
Definition: n_ary_operation.hpp:319
TriangleType
The type of a triangular matrix.
Definition: global-definitions.hpp:60
decltype(auto) constexpr make_hermitian_matrix(Arg &&arg)
Creates a hermitian_matrix by, if necessary, wrapping the argument in a hermitian_adapter.
Definition: make_hermitian_matrix.hpp:37
typename scalar_type_of< T >::type scalar_type_of_t
helper template for scalar_type_of.
Definition: scalar_type_of.hpp:54
Lower, upper, or diagonal matrix.
decltype(auto) constexpr reduce(BinaryFunction &&b, Arg &&arg)
Perform a partial reduction based on an associative binary function, across one or more indices...
Definition: reduce.hpp:143
The constant associated with T, assuming T is a constant_matrix.
Definition: constant_coefficient.hpp:36
decltype(auto) constexpr all_vector_space_descriptors(const T &t)
Return a collection of coordinates::pattern objects associated with T.
Definition: all_vector_space_descriptors.hpp:52
The root namespace for OpenKalman.
Definition: basics.hpp:34
An interface to various routines from the linear algebra library associated with indexible object T...
Definition: library_interface.hpp:37
The constant associated with T, assuming T is a constant_diagonal_matrix.
Definition: constant_diagonal_coefficient.hpp:32
constexpr auto solve(A &&a, B &&b)
Solve the equation AX = B for X, which may or may not be a unique solution.
Definition: solve.hpp:87
decltype(auto) constexpr diagonal_of(Arg &&arg)
Extract a column vector (or column slice for rank>2 tensors) comprising the diagonal elements...
Definition: diagonal_of.hpp:33
typename vector_space_descriptor_of< T, N >::type vector_space_descriptor_of_t
helper template for vector_space_descriptor_of.
Definition: vector_space_descriptor_of.hpp:56