OpenKalman
transpose.hpp
Go to the documentation of this file.
1 /* This file is part of OpenKalman, a header-only C++ library for
2  * Kalman filters and other recursive filters.
3  *
4  * Copyright (c) 2022-2025 Christopher Lee Ogden <ogden@gatech.edu>
5  *
6  * This Source Code Form is subject to the terms of the Mozilla Public
7  * License, v. 2.0. If a copy of the MPL was not distributed with this
8  * file, You can obtain one at https://mozilla.org/MPL/2.0/.
9  */
10 
16 #ifndef OPENKALMAN_TRANSPOSE_HPP
17 #define OPENKALMAN_TRANSPOSE_HPP
18 
19 #include "values/values.hpp"
26 #include "conjugate.hpp"
27 
28 namespace OpenKalman
29 {
30  namespace detail
31  {
32  template<typename C, typename Arg, std::size_t...Is>
33  constexpr decltype(auto) transpose_constant(C&& c, Arg&& arg, std::index_sequence<Is...>)
34  {
35  return make_constant<Arg>(std::forward<C>(c),
36  get_pattern_collection<1>(arg), get_pattern_collection<0>(arg), get_pattern_collection<Is + 2>(arg)...);
37  }
38 
39 
40  template<typename NestedLayout, std::size_t indexa, std::size_t indexb>
42  {
43  template<class Extents>
44  struct mapping {
45  using extents_type = Extents;
46  using index_type = typename extents_type::index_type;
47  using size_type = typename extents_type::size_type;
48  using rank_type = typename extents_type::rank_type;
50 
51  private:
52 
53  template<typename = std::make_index_sequence<Extents::rank()>>
54  struct transposed_extents {};
55 
56  template<std::size_t...i>
57  struct transposed_extents<std::index_sequence<i...>>
58  {
59  using type = stdex::extents<index_type, (Extents::static_extent(
60  i == indexa ? indexb : i == indexb ? indexa : i))...>;
61  };
62 
63  template<std::size_t...i>
64  constexpr typename transposed_extents<Extents>::type
65  transpose_extents(const stdex::extents<index_type, i...>& e)
66  {
67  return {e.extent(i == indexa ? indexb : i == indexb ? indexa : i)...};
68  }
69 
70  using nested_mapping_type = typename NestedLayout::template mapping<typename transposed_extents<Extents>::type>;
71 
72  public:
73 
74  constexpr explicit
75  mapping(const nested_mapping_type& map)
76  : nested_mapping_(map), extents_(transpose_extents(map.extents()))
77  {}
78 
79  constexpr const extents_type&
80  extents() const noexcept
81  {
82  return extents_;
83  }
84 
85 #ifdef __cpp_concepts
86  template<std::convertible_to<index_type> IndexType0, std::convertible_to<index_type> IndexType1,
87  std::convertible_to<index_type>...IndexTypes>
88 #else
89  template<typename IndexType0, typename IndexType1, typename...IndexTypes, std::enable_if_t<
90  std::is-convertible_v<IndexType0, index_type> and std::is-convertible_v<IndexType1, index_type> and
91  (... and std::is_convertible_v<IndexTypes, index_type>), int> = 0>
92 #endif
93  index_type
94  operator() (IndexType0 i, IndexType1 j, IndexTypes...ks) const
95  {
96  return nested_mapping_(j, i, ks...);
97  }
98 
99  constexpr index_type
100  required_span_size() const noexcept(noexcept(nested_mapping_.required_span_size()))
101  {
102  return nested_mapping_.required_span_size();
103  }
104 
105  const nested_mapping_type&
106  nested_mapping() const { return nested_mapping_; }
107 
108  static constexpr bool
109  is_always_unique() noexcept { return nested_mapping_type::is_always_unique(); }
110 
111  static constexpr bool
112  is_always_exhaustive() noexcept { return nested_mapping_type::is_always_contiguous(); }
113 
114  static constexpr bool
115  is_always_strided() noexcept { return nested_mapping_type::is_always_strided(); }
116 
117  constexpr bool
118  is_unique() const { return nested_mapping_.is_unique(); }
119 
120  constexpr bool
121  is_exhaustive() const { return nested_mapping_.is_exhaustive(); }
122 
123  constexpr bool
124  is_strided() const { return nested_mapping_.is_strided(); }
125 
126  constexpr index_type
127  stride(size_t r) const
128  {
129  assert(this->is_strided());
130  assert(r < extents_type::rank());
131  return nested_mapping_.stride(r == indexa ? indexb : r == indexb ? indexa : r);
132  }
133 
134  template<class OtherExtents>
135  friend constexpr bool
136  operator==(const mapping& lhs, const mapping<OtherExtents>& rhs) noexcept
137  {
138  return lhs.nested_mapping_ == rhs.nested_mapping_;
139  }
140 
141  private:
142 
143  nested_mapping_type nested_mapping_;
144  extents_type extents_;
145 
146  };
147  };
148 
149 
150  }
151 
152 
/**
 * \brief Swap indices indexa and indexb of an indexible object.
 * \details By default indexa = 0 and indexb = 1, i.e., a matrix transpose. The branches are tried in order:
 *  - Arguments treated as invariant are forwarded unchanged. These are square diagonal or square
 *    constant objects, and Hermitian objects whose element type is not complex.
 *  - If indexb == 1 and the library interface defines matrix_transpose for Arg, that branch is taken.
 *    Its body (original line 176) is not visible in this listing.
 *  - If the library interface defines transpose<indexa, indexb> for Arg, it is called.
 *  - An lvalue with at most two indices is transposed through the mdspan returned by get_mdspan.
 *  - A constant object is rebuilt by detail::transpose_constant with permuted pattern collections.
 *  - A Hermitian object is conjugated.
 *  - Any other lvalue is viewed through its mdspan with a transposing layout, and a pattern
 *    collection is reattached.
 *  - Otherwise compilation fails with "Interface not defined".
 * \tparam indexa The first index to swap (must be less than indexb).
 * \tparam indexb The second index to swap.
 * \param arg The object to transpose.
 * \return arg itself (forwarded) if invariant; otherwise the result of the selected branch.
 */
157 #ifdef __cpp_concepts
158  template<std::size_t indexa = 0, std::size_t indexb = 1, indexible Arg> requires (indexa < indexb)
159 #else
160  template<std::size_t indexa = 0, std::size_t indexb = 1, typename Arg, std::enable_if_t<
161  indexible<Arg> and (indexa < indexb), int> = 0>
162 #endif
163  constexpr decltype(auto) transpose(Arg&& arg)
164  {
// NOTE(review): square_matrix, diagonal_matrix and hermitian_matrix are about indices 0 and 1, yet this
// shortcut is applied for every (indexa, indexb). For example, transpose<1, 2> of a square diagonal
// argument would be returned unchanged even when that swap is not the identity. Consider restricting
// `invariant` to indexa == 0 and indexb == 1 (TODO confirm intent).
165  constexpr bool square_matrix = values::size_compares_with<index_dimension_of<Arg, 0>, index_dimension_of<Arg, 1>>;
166  constexpr bool diag_invariant = (diagonal_matrix<Arg> or constant_object<Arg>) and square_matrix;
167  constexpr bool hermitian_invariant = hermitian_matrix<Arg> and not values::complex<element_type_of_t<Arg>>;
168  constexpr bool invariant = diag_invariant or hermitian_invariant;
169 
170  if constexpr (invariant)
171  {
172  return std::forward<Arg>(arg);
173  }
174  else if constexpr (indexb == 1 and interface::matrix_transpose_defined_for<Arg&&>)
175  {
177  }
178  else if constexpr (interface::transpose_defined_for<Arg&&, indexa, indexb>)
179  {
180  return interface::library_interface<stdex::remove_cvref_t<Arg>>::template transpose<indexa, indexb>(std::forward<Arg>(arg));
181  }
182  else if constexpr (std::is_lvalue_reference_v<Arg> and index_count_v<Arg> <= 2)
183  {
// NOTE(review): this recursive call does not forward indexa/indexb, so it always performs transpose<0, 1>.
// For indexb > 1 on an object with at most two indices, the wrong indices would be swapped.
// Because this branch precedes the constant_object and hermitian_matrix branches, lvalue arguments
// with at most two indices never reach those branches.
184  return transpose(get_mdspan(arg));
185  }
186  else if constexpr (constant_object<Arg>)
187  {
188  constexpr std::make_index_sequence<std::max({index_count_v<Arg>, 2_uz}) - 2_uz> seq;
// NOTE(review): transpose_constant is not given indexa/indexb; as defined in this file it always
// exchanges the pattern collections of indices 0 and 1. seq also covers only max(index_count, 2)
// indices, which is too few when indexb >= index_count.
189  return detail::transpose_constant(constant_value(arg), std::forward<Arg>(arg), seq);
190  }
191  else if constexpr (hermitian_matrix<Arg>)
192  {
// For a Hermitian matrix A, A^T == conj(A). This holds for the (0, 1) transpose only.
193  return conjugate(std::forward<Arg>(arg));
194  }
195  else if constexpr (std::is_lvalue_reference_v<Arg>)
196  {
197  auto m = get_mdspan(arg);
// (Original line 198, presumably the definition of layout_type, is not visible in this listing.)
199 
200  auto mapping = layout_type::mapping(m.mapping());
201  auto n = stdex::mdspan(m.data_handle(), mapping, m.accessor());
// NOTE(review): the pattern collection reattached here appears to be arg's original (untransposed)
// one, while the mdspan's indices indexa and indexb are exchanged. Confirm whether the patterns
// should be permuted as well.
202  return attach_pattern(std::move(n), get_pattern_collection(std::forward<Arg>(arg)));
203  }
// NOTE(review): 'constexpr' is missing here. As a plain `if`, both static_asserts below are
// instantiated, so the failure always reports both conditions instead of only the relevant one.
204  else if (indexb == 1)
205  {
206  static_assert(interface::matrix_transpose_defined_for<Arg&&>, "Interface not defined");
207  }
208  else
209  {
210  static_assert(interface::transpose_defined_for<Arg&&, indexa, indexb>, "Interface not defined");
211  }
212  }
213 
214 
215 }
216 
217 #endif
constexpr auto attach_pattern(Arg &&arg, P &&p)
Attach a pattern_collection to an indexible object.
Definition: attach_pattern.hpp:36
constexpr auto get_mdspan(T &t)
Get the mdspan associated with indexible object T.
Definition: get_mdspan.hpp:35
Definition for constant_object.
Header file for code relating to values (e.g., scalars and indices)
decltype(auto) constexpr conjugate(Arg &&arg)
Take the complex conjugate of an indexible object.
Definition: conjugate.hpp:44
decltype(auto) constexpr get_pattern_collection(T &&t)
Get the coordinates::pattern_collection associated with indexible object T.
Definition: get_pattern_collection.hpp:59
Definition: transpose.hpp:41
Definition for diagonal_matrix.
Definition for element_type_of.
Definition: mdspan.hpp:34
The root namespace for OpenKalman.
Definition: basics.hpp:34
An interface to various routines from the linear algebra library associated with indexible object T...
Definition: library_interface.hpp:42
decltype(auto) constexpr transpose(Arg &&arg)
Swap any two indices of an indexible object.
Definition: transpose.hpp:163
Definition of get_mdspan function.
constexpr auto constant_value(T &&t)
The constant value associated with a constant_object or constant_diagonal_object. ...
Definition: constant_value.hpp:37
Definition for hermitian_matrix.
Definition for indexible.
The dimension of an index for a matrix, expression, or array.
Definition: index_dimension_of.hpp:35
Definition of conjugate function.
Definition: extents.hpp:372