OpenKalman
reduce.hpp
Go to the documentation of this file.
1 /* This file is part of OpenKalman, a header-only C++ library for
2  * Kalman filters and other recursive filters.
3  *
4  * Copyright (c) 2022-2024 Christopher Lee Ogden <ogden@gatech.edu>
5  *
6  * This Source Code Form is subject to the terms of the Mozilla Public
7  * License, v. 2.0. If a copy of the MPL was not distributed with this
8  * file, You can obtain one at https://mozilla.org/MPL/2.0/.
9  */
10 
16 #ifndef OPENKALMAN_REDUCE_HPP
17 #define OPENKALMAN_REDUCE_HPP
18 
19 namespace OpenKalman
20 {
21  // -------- //
22  // reduce //
23  // -------- //
24 
25  namespace detail
26  {
27  template<std::size_t index_to_delete, std::size_t...new_indices>
28  constexpr auto delete_reduction_index(std::index_sequence<>) { return std::index_sequence<new_indices...>{}; }
29 
30 
31  template<std::size_t index_to_delete, std::size_t...new_indices, std::size_t index, std::size_t...indices>
32  constexpr auto delete_reduction_index(std::index_sequence<index, indices...>)
33  {
34  if constexpr (index == index_to_delete)
35  return delete_reduction_index<index_to_delete, new_indices...>(std::index_sequence<indices...>{});
36  else
37  return delete_reduction_index<index_to_delete, new_indices..., index>(std::index_sequence<indices...>{});
38  }
39 
40 
41  template<typename BinaryOperation, typename Constant, typename Dim>
42  constexpr auto constant_reduce_operation(const BinaryOperation& op, const Constant& c, const Dim& dim)
43  {
44  if constexpr (internal::is_plus<BinaryOperation>::value) return c * dim;
45  else if constexpr (internal::is_multiplies<BinaryOperation>::value) return values::pow(c, dim);
46  else
47  {
48  if constexpr (values::dynamic<Dim>)
49  {
50  if (values::to_number(dim) <= 1) return values::to_number(c);
51  else return op(constant_reduce_operation(op, c, values::to_number(dim) - 1), values::to_number(c));
52  }
53  else if constexpr (Dim::value <= 1) return c;
54  else
55  {
56  auto dim_m1 = std::integral_constant<std::size_t, Dim::value - 1>{};
57  return values::operation {op, constant_reduce_operation(op, c, dim_m1), c};
58  }
59  }
60  }
61 
62 
63  template<typename BinaryFunction, typename Arg, std::size_t...indices, std::size_t...Ix>
64  constexpr decltype(auto)
65  reduce_impl(BinaryFunction&& b, Arg&& arg, std::index_sequence<indices...> indices_seq, std::index_sequence<Ix...> seq)
66  {
68  {
69  return make_zero<Arg>(internal::get_reduced_vector_space_descriptor<Ix, indices...>(std::forward<Arg>(arg))...);
70  }
71  else if constexpr (constant_matrix<Arg>)
72  {
73  auto dim = internal::count_reduced_dimensions(arg, indices_seq, seq);
74  auto c = detail::constant_reduce_operation(b, constant_coefficient{arg}, dim);
75  return make_constant<Arg>(std::move(c), internal::get_reduced_vector_space_descriptor<Ix, indices...>(std::forward<Arg>(arg))...);
76  }
77  else if constexpr (diagonal_matrix<Arg> and internal::is_plus<BinaryFunction>::value and
78  not dynamic_dimension<Arg, 0> and not dynamic_dimension<Arg, 1> and
79  ((((indices == 1) or ...) and index_dimension_of_v<Arg, 1> >= index_dimension_of_v<Arg, 0>) or
80  (((indices == 0) or ...) and ((indices == 1) or ...))))
81  {
82  return reduce_impl(std::forward<BinaryFunction>(b), diagonal_of(std::forward<Arg>(arg)), delete_reduction_index<1>(indices_seq), seq);
83  }
84  else if constexpr (diagonal_matrix<Arg> and internal::is_plus<BinaryFunction>::value and
85  not dynamic_dimension<Arg, 0> and not dynamic_dimension<Arg, 1> and
86  (((indices == 0) or ...) and index_dimension_of_v<Arg, 1> <= index_dimension_of_v<Arg, 0>))
87  {
88  return reduce_impl(std::forward<BinaryFunction>(b), transpose(diagonal_of(std::forward<Arg>(arg))), delete_reduction_index<0>(indices_seq), seq);
89  }
90  else if constexpr (diagonal_matrix<Arg> and internal::is_multiplies<BinaryFunction>::value and
91  not dynamic_dimension<Arg, 1> and ((indices == 1) or ...))
92  {
93  if constexpr (index_dimension_of_v<Arg, 1> == 1)
94  return reduce_impl(std::forward<BinaryFunction>(b), std::forward<Arg>(arg), delete_reduction_index<0>(indices_seq), seq);
95  else
96  return reduce_impl(std::forward<BinaryFunction>(b), make_zero(diagonal_of(std::forward<Arg>(arg))), delete_reduction_index<0>(indices_seq), seq);
97  }
98  else if constexpr (diagonal_matrix<Arg> and internal::is_multiplies<BinaryFunction>::value and
99  not dynamic_dimension<Arg, 0> and ((indices == 0) or ...))
100  {
101  if constexpr (index_dimension_of_v<Arg, 0> == 1)
102  return reduce_impl(std::forward<BinaryFunction>(b), std::forward<Arg>(arg), delete_reduction_index<0>(indices_seq), seq);
103  else
104  return reduce_impl(std::forward<BinaryFunction>(b), make_zero(transpose(diagonal_of(std::forward<Arg>(arg)))), delete_reduction_index<0>(indices_seq), seq);
105  }
106  else
107  {
108  using LibraryInterface = interface::library_interface<std::decay_t<Arg>>;
109  auto red = LibraryInterface::template reduce<indices...>(std::forward<BinaryFunction>(b), std::forward<Arg>(arg));
110  if constexpr (values::number<decltype(red)>)
111  return make_constant<Arg>(std::move(red), internal::get_reduced_vector_space_descriptor<Ix, indices...>(arg)...);
112  else
113  return red;
114  }
115  }
116 
117  } // namespace detail
118 
119 
133 #ifdef __cpp_concepts
134  template<std::size_t index, std::size_t...indices, typename BinaryFunction, internal::has_uniform_static_vector_space_descriptors<index, indices...> Arg> requires
135  std::is_invocable_r_v<scalar_type_of_t<Arg>, BinaryFunction&&, scalar_type_of_t<Arg>, scalar_type_of_t<Arg>>
136  constexpr indexible decltype(auto)
137 #else
138  template<std::size_t index, std::size_t...indices, typename BinaryFunction, typename Arg, std::enable_if_t<
139  internal::has_uniform_static_vector_space_descriptors<Arg, index, indices...> and
140  std::is_invocable_r<typename scalar_type_of<Arg>::type, BinaryFunction&&, typename scalar_type_of<Arg>::type, typename scalar_type_of<Arg>::type>::value, int> = 0>
141  constexpr decltype(auto)
142 #endif
143  reduce(BinaryFunction&& b, Arg&& arg)
144  {
145  if constexpr (dimension_size_of_index_is<Arg, index, 1>) //< Check if Arg is already reduced along index.
146  {
147  if constexpr (sizeof...(indices) == 0) return std::forward<Arg>(arg);
148  else return reduce<indices...>(std::forward<BinaryFunction>(b), std::forward<Arg>(arg));
149  }
150  else
151  {
152  return detail::reduce_impl(
153  std::forward<BinaryFunction>(b),
154  std::forward<Arg>(arg),
155  std::index_sequence<index, indices...>{},
156  std::make_index_sequence<index_count_v<Arg>>{});
157  }
158  }
159 
160 
171 #ifdef __cpp_concepts
172  template<typename BinaryFunction, internal::has_uniform_static_vector_space_descriptors Arg> requires
173  std::is_invocable_r_v<scalar_type_of_t<Arg>, BinaryFunction&&, scalar_type_of_t<Arg>, scalar_type_of_t<Arg>>
174 #else
175  template<typename BinaryFunction, typename Arg, std::enable_if_t<internal::has_uniform_static_vector_space_descriptors<Arg> and
176  std::is_invocable_r<typename scalar_type_of<Arg>::type, BinaryFunction&&,
177  typename scalar_type_of<Arg>::type, typename scalar_type_of<Arg>::type>::value, int> = 0>
178 #endif
179  constexpr scalar_type_of_t<Arg>
180  reduce(const BinaryFunction& b, const Arg& arg)
181  {
182  auto seq = std::make_index_sequence<index_count_v<Arg>>{};
183 
185  {
186  return 0;
187  }
188  else if constexpr (one_dimensional<Arg>)
189  {
190  return internal::get_singular_component(arg);
191  }
192  else if constexpr (constant_matrix<Arg>)
193  {
194  auto dim = internal::count_reduced_dimensions(arg, seq, seq);
195  return constant_reduce_operation(b, constant_coefficient {arg}, dim);
196  }
197  else
198  {
199  decltype(auto) red = detail::reduce_impl(b, arg, seq, seq);
200  using Red = decltype(red);
201 
202  static_assert(values::number<Red> or one_dimensional<Red, Applicability::permitted>,
203  "Incorrect library interface for total 'reduce' on all indices: must return a scalar or one-by-one matrix.");
204 
205  if constexpr (values::number<Red>)
206  return red;
207  else
208  return internal::get_singular_component(red);
209  }
210  }
211 
212 
213 } // namespace OpenKalman
214 
215 #endif //OPENKALMAN_REDUCE_HPP
constexpr bool diagonal_matrix
Specifies that a type is a diagonal matrix or tensor.
Definition: diagonal_matrix.hpp:32
constexpr bool indexible
T is a generalized tensor type.
Definition: indexible.hpp:32
constexpr bool number
T is a numerical type.
Definition: number.hpp:33
Definition: tuple_reverse.hpp:103
constexpr bool value
T is a numerical value or is reducible to a numerical value.
Definition: value.hpp:31
decltype(auto) constexpr reduce(BinaryFunction &&b, Arg &&arg)
Perform a partial reduction based on an associative binary function, across one or more indices...
Definition: reduce.hpp:143
constexpr auto to_number(Arg arg)
Convert any values::value to a values::number.
Definition: to_number.hpp:34
The constant associated with T, assuming T is a constant_matrix.
Definition: constant_coefficient.hpp:36
decltype(auto) constexpr transpose(Arg &&arg)
Take the transpose of a matrix.
Definition: transpose.hpp:58
The root namespace for OpenKalman.
Definition: basics.hpp:34
decltype(auto) constexpr diagonal_of(Arg &&arg)
Extract a column vector (or column slice for rank>2 tensors) comprising the diagonal elements...
Definition: diagonal_of.hpp:33
constexpr bool dynamic_dimension
Specifies that T's index N has a dimension defined at run time.
Definition: dynamic_dimension.hpp:43
Definition: global-definitions.hpp:130
constexpr auto make_constant(C &&c, Descriptors &&descriptors)
Make a constant object based on a particular library object.
Definition: make_constant.hpp:37
constexpr auto make_zero(Descriptors &&descriptors)
Make a zero associated with a particular library.
Definition: make_zero.hpp:36
Definition: global-definitions.hpp:124
operation(const Operation &, const Args &...) -> operation< Operation, Args... >
Deduction guide.
constant_coefficient(const T &) -> constant_coefficient< T >
Deduction guide for constant_coefficient.
constexpr bool index
An object describing a collection of values::index objects.
Definition: index.hpp:75