16 #ifndef OPENKALMAN_CONTRACT_HPP 17 #define OPENKALMAN_CONTRACT_HPP 23 template<
typename C,
typename A,
typename B, std::size_t...Is>
24 static constexpr
auto contract_constant(C&& c, A&& a, B&& b, std::index_sequence<Is...>)
26 return make_constant<A>(std::forward<C>(c),
27 get_vector_space_descriptor<0>(a), get_vector_space_descriptor<1>(b), get_vector_space_descriptor<Is + 2>(a)...);
45 template<indexible A, indexible B> requires dimension_size_of_index_is<A, 1, index_dimension_of_v<B, 0>,
Applicability::permitted> and
46 (index_count_v<A> ==
dynamic_size or index_count_v<A> <= 2) and (index_count_v<B> ==
dynamic_size or index_count_v<B> <= 2)
49 template<
typename A,
typename B, std::enable_if_t<indexible<A> and indexible<B> and
50 (dimension_size_of_index_is<A, 1, index_dimension_of<B, 0>::value, Applicability::permitted>) and
51 (index_count<A>::value == dynamic_size or index_count<A>::value <= 2) and (index_count<B>::value == dynamic_size or index_count<B>::value <= 2),
int> = 0>
52 constexpr decltype(auto)
54 contract(A&& a, B&& b)
56 if constexpr (dynamic_dimension<A, 1> or dynamic_dimension<B, 0>) if (get_vector_space_descriptor<1>(a) != get_vector_space_descriptor<0>(b))
57 throw std::domain_error {"In contract, columns of a (" + std::to_
string(get_index_dimension_of<1>(a)) +
58 ") do not match rows of b (" + std::to_
string(get_index_dimension_of<0>(b)) + ")"};
60 constexpr std::
size_t dims = std::max({index_count_v<A>, index_count_v<B>, 2_uz});
61 constexpr std::make_index_sequence<dims - 2> seq;
63 using Scalar = std::decay_t<decltype(std::declval<scalar_type_of_t<A>>() * std::declval<scalar_type_of_t<B>>())>;
65 if constexpr (
identity_matrix<B> and square_shaped<B>)
67 if constexpr (dynamic_dimension<A, 1> and not dynamic_dimension<B, 1>)
68 return
internal::make_fixed_size_adapter<vector_space_descriptor_of_t<A, 0>, vector_space_descriptor_of_t<B, 1>>(std::forward<A>(a));
70 return std::forward<A>(a);
72 else if constexpr (
identity_matrix<A> and square_shaped<A>)
74 if constexpr (dynamic_dimension<B, 0> and not dynamic_dimension<A, 0>)
75 return
internal::make_fixed_size_adapter<vector_space_descriptor_of_t<A, 0>, vector_space_descriptor_of_t<B, 1>>(std::forward<B>(b));
77 return std::forward<B>(b);
79 else if constexpr (zero<A> or zero<B>)
81 return detail::contract_constant(values::Fixed<Scalar, 0>{}, std::forward<A>(a), std::forward<B>(b), seq);
83 else if constexpr (constant_matrix<A> and constant_matrix<B>)
85 auto dim_const = [](const auto& a, const auto& b) {
86 if constexpr (dynamic_dimension<A, 1>) return values::cast_to<Scalar>(get_index_dimension_of<0>(b));
87 else return values::cast_to<Scalar>(get_index_dimension_of<1>(a));
90 auto abd = constant_coefficient{a} * constant_coefficient{b} * std::move(dim_const);
91 return detail::contract_constant(std::move(abd), std::forward<A>(a), std::forward<B>(b), seq);
93 else if constexpr (diagonal_matrix<A> and constant_matrix<B>)
95 auto col = diagonal_of(std::forward<A>(a)) * constant_coefficient{b}();
96 return chipwise_operation<1>([&]{ return col; }, get_index_dimension_of<1>(b));
102 else if constexpr (constant_matrix<A> and diagonal_matrix<B>)
104 auto row = transpose(diagonal_of(std::forward<B>(b))) * constant_coefficient{a}();
105 return chipwise_operation<0>([&]{ return row; }, get_index_dimension_of<0>(a));
111 else if constexpr (diagonal_matrix<A> and diagonal_matrix<B>)
113 auto ret {to_diagonal(n_ary_operation(std::multiplies<Scalar>{}, diagonal_of(std::forward<A>(a)), diagonal_of(std::forward<B>(b))))};
116 else if constexpr (
interface::contract_defined_for<A, A, B>)
118 auto x =
interface::library_
interface<std::decay_t<A>>::contract(std::forward<A>(a), std::forward<B>(b));
119 auto ret =
internal::make_fixed_size_adapter<vector_space_descriptor_of_t<A, 0>, vector_space_descriptor_of_t<B, 1>>(std::move(x));
121 constexpr TriangleType tri = triangle_type_of_v<A, B>;
122 if constexpr (tri != TriangleType::any and not triangular_matrix<decltype(ret), tri>)
123 return make_triangular_matrix<tri>(std::move(ret));
127 else if constexpr ((hermitian_matrix<A> or hermitian_matrix<B>) and
128 interface::contract_defined_for<B, decltype(adjo
int(std::declval<B>())), decltype(adjo
int(std::declval<A>()))>)
130 return adjo
int(
interface::library_
interface<std::decay_t<B>>::contract(adjo
int(std::forward<B>(b)), adjo
int(std::forward<A>(a))));
132 else if constexpr (
interface::contract_defined_for<B, decltype(transpose(std::declval<B>())), decltype(transpose(std::declval<A>()))>)
134 return transpose(
interface::library_
interface<std::decay_t<B>>::contract(transpose(std::forward<B>(b)), transpose(std::forward<A>(a))));
138 return
interface::library_
interface<std::decay_t<A>>::contract(std::forward<A>(a), to_native_matrix<A>(std::forward<B>(b)));
Definition: tuple_reverse.hpp:103
The root namespace for OpenKalman.
Definition: basics.hpp:34
The concept, trait, or restraint is permitted, but whether it applies is not necessarily known at compile time.
typename vector_space_descriptor_of< T, N >::type vector_space_descriptor_of_t
helper template for vector_space_descriptor_of.
Definition: vector_space_descriptor_of.hpp:56
constexpr bool compatible_with_vector_space_descriptor_collection
indexible T is compatible with pattern_collection D.
Definition: compatible_with_vector_space_descriptor_collection.hpp:74
constexpr std::size_t dynamic_size
A constant indicating that a size or index is dynamic.
Definition: global-definitions.hpp:33