OpenKalman
contract.hpp
Go to the documentation of this file.
1 /* This file is part of OpenKalman, a header-only C++ library for
2  * Kalman filters and other recursive filters.
3  *
4  * Copyright (c) 2022-2023 Christopher Lee Ogden <ogden@gatech.edu>
5  *
6  * This Source Code Form is subject to the terms of the Mozilla Public
7  * License, v. 2.0. If a copy of the MPL was not distributed with this
8  * file, You can obtain one at https://mozilla.org/MPL/2.0/.
9  */
10 
16 #ifndef OPENKALMAN_CONTRACT_HPP
17 #define OPENKALMAN_CONTRACT_HPP
18 
19 namespace OpenKalman
20 {
21  namespace detail
22  {
23  template<typename C, typename A, typename B, std::size_t...Is>
24  static constexpr auto contract_constant(C&& c, A&& a, B&& b, std::index_sequence<Is...>)
25  {
26  return make_constant<A>(std::forward<C>(c),
27  get_vector_space_descriptor<0>(a), get_vector_space_descriptor<1>(b), get_vector_space_descriptor<Is + 2>(a)...);
28  }
29 
30 
31  /* // Only for use with alternate code below
32  template<typename A, typename B, std::size_t...Is>
33  static constexpr auto contract_dimensions(A&& a, B&& b, std::index_sequence<Is...>)
34  {
35  return std::tuple {get_vector_space_descriptor<0>(a), get_vector_space_descriptor<1>(b), get_vector_space_descriptor<Is + 2>(a)...};
36  }
37  */
38  } // namespace detail
39 
40 
44 #ifdef __cpp_concepts
45  template<indexible A, indexible B> requires dimension_size_of_index_is<A, 1, index_dimension_of_v<B, 0>, Applicability::permitted> and
46  (index_count_v<A> == dynamic_size or index_count_v<A> <= 2) and (index_count_v<B> == dynamic_size or index_count_v<B> <= 2)
48 #else
49  template<typename A, typename B, std::enable_if_t<indexible<A> and indexible<B> and
50  (dimension_size_of_index_is<A, 1, index_dimension_of<B, 0>::value, Applicability::permitted>) and
51  (index_count<A>::value == dynamic_size or index_count<A>::value <= 2) and (index_count<B>::value == dynamic_size or index_count<B>::value <= 2), int> = 0>
52  constexpr decltype(auto)
53 #endif
54  contract(A&& a, B&& b)
55  {
56  if constexpr (dynamic_dimension<A, 1> or dynamic_dimension<B, 0>) if (get_vector_space_descriptor<1>(a) != get_vector_space_descriptor<0>(b))
57  throw std::domain_error {"In contract, columns of a (" + std::to_string(get_index_dimension_of<1>(a)) +
58  ") do not match rows of b (" + std::to_string(get_index_dimension_of<0>(b)) + ")"};
59 
60  constexpr std::size_t dims = std::max({index_count_v<A>, index_count_v<B>, 2_uz});
61  constexpr std::make_index_sequence<dims - 2> seq;
62 
63  using Scalar = std::decay_t<decltype(std::declval<scalar_type_of_t<A>>() * std::declval<scalar_type_of_t<B>>())>;
64 
65  if constexpr (identity_matrix<B> and square_shaped<B>)
66  {
67  if constexpr (dynamic_dimension<A, 1> and not dynamic_dimension<B, 1>)
68  return internal::make_fixed_size_adapter<vector_space_descriptor_of_t<A, 0>, vector_space_descriptor_of_t<B, 1>>(std::forward<A>(a));
69  else
70  return std::forward<A>(a);
71  }
72  else if constexpr (identity_matrix<A> and square_shaped<A>)
73  {
74  if constexpr (dynamic_dimension<B, 0> and not dynamic_dimension<A, 0>)
75  return internal::make_fixed_size_adapter<vector_space_descriptor_of_t<A, 0>, vector_space_descriptor_of_t<B, 1>>(std::forward<B>(b));
76  else
77  return std::forward<B>(b);
78  }
79  else if constexpr (zero<A> or zero<B>)
80  {
81  return detail::contract_constant(values::Fixed<Scalar, 0>{}, std::forward<A>(a), std::forward<B>(b), seq);
82  }
83  else if constexpr (constant_matrix<A> and constant_matrix<B>)
84  {
85  auto dim_const = [](const auto& a, const auto& b) {
86  if constexpr (dynamic_dimension<A, 1>) return values::cast_to<Scalar>(get_index_dimension_of<0>(b));
87  else return values::cast_to<Scalar>(get_index_dimension_of<1>(a));
88  }(a, b);
89 
90  auto abd = constant_coefficient{a} * constant_coefficient{b} * std::move(dim_const);
91  return detail::contract_constant(std::move(abd), std::forward<A>(a), std::forward<B>(b), seq);
92  }
93  else if constexpr (diagonal_matrix<A> and constant_matrix<B>)
94  {
95  auto col = diagonal_of(std::forward<A>(a)) * constant_coefficient{b}();
96  return chipwise_operation<1>([&]{ return col; }, get_index_dimension_of<1>(b));
97  //Another way to do this:
98  //auto tup = detail::contract_dimensions(std::forward<A>(a), std::forward<B>(b), seq);
99  //auto op = [](auto&& x){ return std::forward<decltype(x)>(x); };
100  //return n_ary_operation(std::move(tup), op, std::move(col));
101  }
102  else if constexpr (constant_matrix<A> and diagonal_matrix<B>)
103  {
104  auto row = transpose(diagonal_of(std::forward<B>(b))) * constant_coefficient{a}();
105  return chipwise_operation<0>([&]{ return row; }, get_index_dimension_of<0>(a));
106  //Another way to do this:
107  //auto tup = detail::contract_dimensions(std::forward<A>(a), std::forward<B>(b), seq);
108  //auto op = [](auto&& x){ return std::forward<decltype(x)>(x); };
109  //return n_ary_operation(std::move(tup), op, std::move(row));
110  }
111  else if constexpr (diagonal_matrix<A> and diagonal_matrix<B>)
112  {
113  auto ret {to_diagonal(n_ary_operation(std::multiplies<Scalar>{}, diagonal_of(std::forward<A>(a)), diagonal_of(std::forward<B>(b))))};
114  return ret;
115  }
116  else if constexpr (interface::contract_defined_for<A, A, B>)
117  {
118  auto x = interface::library_interface<std::decay_t<A>>::contract(std::forward<A>(a), std::forward<B>(b));
119  auto ret = internal::make_fixed_size_adapter<vector_space_descriptor_of_t<A, 0>, vector_space_descriptor_of_t<B, 1>>(std::move(x));
120 
121  constexpr TriangleType tri = triangle_type_of_v<A, B>;
122  if constexpr (tri != TriangleType::any and not triangular_matrix<decltype(ret), tri>)
123  return make_triangular_matrix<tri>(std::move(ret));
124  else
125  return ret;
126  }
127  else if constexpr ((hermitian_matrix<A> or hermitian_matrix<B>) and
128  interface::contract_defined_for<B, decltype(adjoint(std::declval<B>())), decltype(adjoint(std::declval<A>()))>)
129  {
130  return adjoint(interface::library_interface<std::decay_t<B>>::contract(adjoint(std::forward<B>(b)), adjoint(std::forward<A>(a))));
131  }
132  else if constexpr (interface::contract_defined_for<B, decltype(transpose(std::declval<B>())), decltype(transpose(std::declval<A>()))>)
133  {
134  return transpose(interface::library_interface<std::decay_t<B>>::contract(transpose(std::forward<B>(b)), transpose(std::forward<A>(a))));
135  }
136  else
137  {
138  return interface::library_interface<std::decay_t<A>>::contract(std::forward<A>(a), to_native_matrix<A>(std::forward<B>(b)));
139  }
140  }
141 
142 } // namespace OpenKalman
143 
144 
145 #endif //OPENKALMAN_CONTRACT_HPP
Definition: tuple_reverse.hpp:103
The root namespace for OpenKalman.
Definition: basics.hpp:34
The concept, trait, or restraint is permitted, but whether it applies is not necessarily known at compile time.
typename vector_space_descriptor_of< T, N >::type vector_space_descriptor_of_t
helper template for vector_space_descriptor_of.
Definition: vector_space_descriptor_of.hpp:56
constexpr bool compatible_with_vector_space_descriptor_collection
indexible T is compatible with pattern_collection D.
Definition: compatible_with_vector_space_descriptor_collection.hpp:74
constexpr std::size_t dynamic_size
A constant indicating that a size or index is dynamic.
Definition: global-definitions.hpp:33