OpenKalman
TensorContractionOp.hpp
Go to the documentation of this file.
1 /* This file is part of OpenKalman, a header-only C++ library for
2  * Kalman filters and other recursive filters.
3  *
4  * Copyright (c) 2023 Christopher Lee Ogden <ogden@gatech.edu>
5  *
6  * This Source Code Form is subject to the terms of the Mozilla Public
7  * License, v. 2.0. If a copy of the MPL was not distributed with this
8  * file, You can obtain one at https://mozilla.org/MPL/2.0/.
9  */
10 
16 #ifndef OPENKALMAN_EIGEN_TRAITS_TENSORCONTRACTIONOP_HPP
17 #define OPENKALMAN_EIGEN_TRAITS_TENSORCONTRACTIONOP_HPP
18 
19 
20 namespace OpenKalman::interface
21 {
22  template<typename Indices, typename LhsXprType, typename RhsXprType, typename OutputKernelType>
23  struct indexible_object_traits<Eigen::TensorContractionOp<Indices, LhsXprType, RhsXprType, OutputKernelType>>
24  : Eigen3::indexible_object_traits_tensor_base<Eigen::TensorContractionOp<Indices, LhsXprType, RhsXprType, OutputKernelType>>
25  {
26  private:
27 
28  using Xpr = Eigen::TensorContractionOp<Indices, LhsXprType, RhsXprType, OutputKernelType>;
30 
31  public:
32 
33  template<typename Arg, typename N>
34  static constexpr std::size_t get_vector_space_descriptor(const Arg& arg, N n)
35  {
36  using IndexType = typename Xpr::Index;
37  return Eigen::TensorEvaluator<const Arg, Eigen::DefaultDevice>{arg, Eigen::DefaultDevice{}}.dimensions()[static_cast<IndexType>(n)];
38  }
39 
40 
41  // nested_object() not defined
42 
43 
44  template<typename Arg>
45  static constexpr auto get_constant(const Arg& arg)
46  {
47  using Scalar = scalar_type_of_t<Arg>;
48 
49  if constexpr (zero<LhsXprType>)
50  {
51  return constant_coefficient{arg.lhsExpression()};
52  }
53  else if constexpr (zero<RhsXprType>)
54  {
55  return constant_coefficient{arg.rhsExpression()};
56  }
57  else if constexpr (constant_diagonal_matrix<LhsXprType> and constant_matrix<RhsXprType>)
58  {
59  if constexpr (std::tuple_size_v<decltype(arg.indices())> == 1)
60  {
61  return constant_diagonal_coefficient{arg.lhsExpression()} * constant_coefficient{arg.rhsExpression()};
62  }
63  else
64  {
65  auto& indices = arg.indices();
66  auto dims = Eigen::TensorEvaluator<const Arg, Eigen::DefaultDevice>{arg, Eigen::DefaultDevice{}}.dimensions();
67  auto f = [&dims](const Scalar& a, auto b) -> Scalar { return a * dims[b.first]; };
68  auto factor = std::accumulate(++indices.cbegin(), indices.cend(), Scalar{1}, f);
69  return factor * (constant_diagonal_coefficient{arg.lhsExpression()} * constant_coefficient{arg.rhsExpression()});
70  }
71  }
72  else if constexpr (constant_matrix<LhsXprType> and constant_diagonal_matrix<RhsXprType>)
73  {
74  if constexpr (std::tuple_size_v<decltype(arg.indices())> == 1)
75  {
76  return constant_coefficient{arg.lhsExpression()} * constant_diagonal_coefficient{arg.rhsExpression()};
77  }
78  else
79  {
80  auto& indices = arg.indices();
81  auto dims = Eigen::TensorEvaluator<const Arg, Eigen::DefaultDevice>{arg, Eigen::DefaultDevice{}}.dimensions();
82  auto f = [&dims](const Scalar& a, auto b) -> Scalar { return a * dims[b.first]; };
83  auto factor = std::accumulate(++indices.cbegin(), indices.cend(), Scalar{1}, f);
84  return factor * (constant_coefficient{arg.lhsExpression()} * constant_diagonal_coefficient{arg.rhsExpression()});
85  }
86  }
87  else
88  {
89  auto& indices = arg.indices();
90  auto dims = Eigen::TensorEvaluator<const Arg, Eigen::DefaultDevice>{arg, Eigen::DefaultDevice{}}.dimensions();
91  auto f = [&dims](const Scalar& a, auto b) -> Scalar { return a * dims[b.first]; };
92  auto factor = std::accumulate(indices.cbegin(), indices.cend(), Scalar{1}, f);
93  return factor * (constant_coefficient{arg.lhsExpression()} * constant_coefficient{arg.rhsExpression()});
94  }
95  }
96 
97 
98  template<typename Arg>
99  static constexpr auto get_constant_diagonal(const Arg& arg)
100  {
101  if constexpr (std::tuple_size_v<decltype(arg.indices())> == 1)
102  {
103  return values::operation {std::multiplies{},
105  }
106  else
107  {
108  using Scalar = scalar_type_of_t<Arg>;
109  auto& indices = arg.indices();
110  auto dims = Eigen::TensorEvaluator<const Arg, Eigen::DefaultDevice>{arg, Eigen::DefaultDevice{}}.dimensions();
111  auto f = [&dims](const Scalar& a, auto b) -> Scalar { return a * dims[b.first]; };
112  auto factor = std::accumulate(++indices.cbegin(), indices.cend(), Scalar{1}, f);
113  return factor * (constant_diagonal_coefficient{arg.lhsExpression()} * constant_coefficient{arg.rhsExpression()});
114  }
115  }
116 
117 
118  // one_dimensional not defined
119 
120  // is_square not defined
121 
122  //template<TriangleType t>
123  //static constexpr bool is_triangular = std::tuple_size_v<decltype(std::declval<T>().indices())> == 1 and
124  // triangular_matrix<LhsXprType, t> and triangular_matrix<RhsXprType, t>;
125 
126 
127  static constexpr bool is_triangular_adapter = false;
128 
129 
130  static constexpr bool is_hermitian = std::tuple_size_v<decltype(std::declval<Xpr>().indices())> == 1 and
131  ((constant_diagonal_matrix<LhsXprType> and hermitian_matrix<RhsXprType, Applicability::permitted>) or
132  (constant_diagonal_matrix<RhsXprType> and hermitian_matrix<LhsXprType, Applicability::permitted>));
133 
134 
135  static constexpr bool is_writable = false;
136 
137 
138  // raw_data() not defined
139 
140 
141  // layout not defined
142 
143  };
144 
145 } // namespace OpenKalman::interface
146 
147 #endif //OPENKALMAN_EIGEN_TRAITS_TENSORCONTRACTIONOP_HPP
Definition: indexible_object_traits.hpp:36
Definition: basics.hpp:41
An operation involving some number of values.
Definition: operation.hpp:69
typename scalar_type_of< T >::type scalar_type_of_t
helper template for scalar_type_of.
Definition: scalar_type_of.hpp:54
Definition: eigen-comma-initializers.hpp:20
Trait object providing get and set routines for Eigen tensors.
Definition: eigen-tensor-forward-declarations.hpp:114
The constant associated with T, assuming T is a constant_matrix.
Definition: constant_coefficient.hpp:36
The constant associated with T, assuming T is a constant_diagonal_matrix.
Definition: constant_diagonal_coefficient.hpp:32
constexpr auto get_vector_space_descriptor(const T &t, const N &n)
Get the coordinates::pattern object for index N of indexible object T.
Definition: get_vector_space_descriptor.hpp:56