16 #ifndef OPENKALMAN_SOLVE_HPP 17 #define OPENKALMAN_SOLVE_HPP 23 template<
typename A,
typename B>
24 void solve_check_A_and_B_rows_match(
const A& a,
const B& b)
26 if (get_vector_space_descriptor<0>(a) != get_vector_space_descriptor<0>(b))
27 throw std::domain_error {
"The rows of the two operands of the solve function must be the same, but instead " 28 "the first operand has " + std::to_string(get_index_dimension_of<0>(a)) +
" rows and the second operand has " +
29 std::to_string(get_index_dimension_of<0>(b)) +
" rows"};
// \internal
// Apply structural wrappers to a raw solve result, depending on properties of A and B.
// The return type is decltype(auto), so forwarding arg unchanged preserves its reference type.
// NOTE(review): this listing is missing original lines 35-37 and 40-41. As a result, the first
// branch's condition, the definition of `tri`, and the body of the constant-diagonal/hermitian
// branch are not visible here.
33 template<
typename A,
typename B,
typename Arg>
34 decltype(
auto) wrap_solve_result_impl(Arg&& arg)
// Wraps the result as a triangular matrix of triangle type `tri` (defined in a line not shown).
38 return make_triangular_matrix<tri>(std::forward<Arg>(arg));
// Taken when one operand is constant-diagonal and the other is hermitian (body not fully visible).
39 else if constexpr (((constant_diagonal_matrix<A> and hermitian_matrix<B>) or (constant_diagonal_matrix<B> and hermitian_matrix<A>)))
// Forwards arg unchanged. Presumably this is the fallback when no wrapper applies — confirm.
42 return std::forward<Arg>(arg);
46 template<
typename A,
typename B,
typename Arg>
47 decltype(
auto) wrap_solve_result(Arg&& arg)
49 using V0 = vector_space_descriptor_of_t<A, 1>;
50 using V1 = vector_space_descriptor_of_t<B, 1>;
51 return internal::make_fixed_size_adapter<V0, V1>(wrap_solve_result_impl<A, B>(std::forward<Arg>(arg)));
// Solve the equation AX = B for X, which may or may not be a unique solution.
// NOTE(review): this listing omits several original lines of solve(). Among them are the
// preprocessor switch between the two declarations below, the function name and parameters
// `solve(A&& a, B&& b)`, braces, the definition of V1, and the constant computed in the
// constant/constant case. The comments here describe only the visible code.
// Template parameters:
//   must_be_unique: when true, operands that are both zero are rejected. This happens either
//     statically (first constraint) or at run time (std::runtime_error in the zero-B branch).
//   must_be_exact: when true, the constraints reject combinations such as a zero A with a
//     non-zero constant or constant-diagonal B. Presumably no exact solution exists for these.
// In the concepts declaration, the return type must be compatible with the pair
// (column descriptor of A, column descriptor of B).
71 template<
bool must_be_unique = false,
bool must_be_exact = false,
typename A,
typename B> requires
72 (not zero<A> or not zero<B> or not must_be_unique) and
73 (not zero<A> or not (constant_matrix<B> or constant_diagonal_matrix<B>) or zero<B> or not must_be_exact) and
74 (not constant_matrix<A> or not constant_diagonal_matrix<B> or has_dynamic_dimensions<A> or
75 (index_dimension_of_v<A, 0> <= index_dimension_of_v<A, 1> and index_dimension_of_v<B, 0> <= index_dimension_of_v<A, 1>) or
76 (index_dimension_of_v<A, 0> == 1 and index_dimension_of_v<B, 0> == 1) or not must_be_exact)
77 constexpr compatible_with_vector_space_descriptor_collection<std::tuple<vector_space_descriptor_of_t<A, 1>, vector_space_descriptor_of_t<B, 1>>>
auto 79 template<
// Declaration without concepts: the same constraints, expressed through std::enable_if_t.
// Keep these in sync with the requires clause above.
bool must_be_unique =
false,
bool must_be_exact =
false,
typename A,
typename B, std::enable_if_t<
80 (not zero<A> or not zero<B> or not must_be_unique) and
81 (not zero<A> or not (constant_matrix<B> or constant_diagonal_matrix<B>) or zero<B> or not must_be_exact) and
82 (not constant_matrix<A> or not constant_diagonal_matrix<B> or has_dynamic_dimensions<A> or
83 (index_dimension_of_v<A, 0> <= index_dimension_of_v<A, 1> and index_dimension_of_v<B, 0> <= index_dimension_of_v<A, 1>) or
84 (index_dimension_of_v<A, 0> == 1 and index_dimension_of_v<B, 0> == 1) or not must_be_exact),
int> = 0>
// When both row dimensions are static they must be equal at compile time. Dynamic row
// dimensions are checked at run time in each branch by detail::solve_check_A_and_B_rows_match.
89 static_assert(dynamic_dimension<A, 0> or dynamic_dimension<B, 0> or index_dimension_of_v<A, 0> == index_dimension_of_v<B, 0>,
90 "The rows of two operands of the solve function must be the same.");
// Case: B is zero. The result is a zero matrix shaped (columns of A) x (columns of B).
94 if constexpr (zero<B>)
96 if constexpr (dynamic_dimension<A, 0> or dynamic_dimension<B, 0>) detail::solve_check_A_and_B_rows_match(a, b);
// If a unique solution is required and A is not known to be constant or constant-diagonal,
// check at run time whether every element of a is zero. The lambda yields 0 only when both
// arguments are 0, so, assuming reduce folds over every element, the reduction is 0 only for
// an all-zero a.
// NOTE(review): a constant or constant-diagonal A skips this check. If such an A can hold a
// constant that is zero only at run time, no error is raised — confirm this is intended.
98 if constexpr (must_be_unique and not constant_matrix<A> and not constant_diagonal_matrix<A>)
101 if (
reduce([](
auto c1,
auto c2) {
if (c1 == 0 and c2 == 0)
return 0;
else return 1; }, a) == 0)
102 throw std::runtime_error {
"solve function requires a unique solution, " 103 "but because operands A and B are both zero matrices, result X may take on any value"};
104 else return make_zero<B>(get_vector_space_descriptor<1>(a), get_vector_space_descriptor<1>(b));
106 else return make_zero<B>(get_vector_space_descriptor<1>(a), get_vector_space_descriptor<1>(b));
// Case: A is zero and B is not statically zero. The result is a zero matrix shaped
// (columns of A) x (columns of B).
108 else if constexpr (zero<A>)
110 if constexpr (dynamic_dimension<A, 0> or dynamic_dimension<B, 0>) detail::solve_check_A_and_B_rows_match(a, b);
111 return make_zero<B>(get_vector_space_descriptor<1>(a), get_vector_space_descriptor<1>(b));
// Case: A has one column, and either A or B has one row. X is b divided by the single
// component of a, or b itself when A is an identity matrix. Its row descriptor is
// coordinates::Axis.
113 else if constexpr (index_dimension_of_v<A, 1> == 1 and (index_dimension_of_v<A, 0> == 1 or index_dimension_of_v<B, 0> == 1))
115 if constexpr (dynamic_dimension<A, 0> or dynamic_dimension<B, 0>) detail::solve_check_A_and_B_rows_match(a, b);
119 if constexpr (identity_matrix<A>)
120 return internal::make_fixed_size_adapter<coordinates::Axis, V1>(std::forward<B>(b));
122 return internal::make_fixed_size_adapter<coordinates::Axis, V1>(scalar_quotient(std::forward<B>(b), internal::get_singular_component(std::forward<A>(a))));
// Case: A is constant-diagonal and either square or statically has as many columns as B has
// rows. V0 is chosen by best_vector_space_descriptor among the rows of b and the rows and
// columns of a. Only the identity sub-case, which returns b itself, is fully visible here.
124 else if constexpr (constant_diagonal_matrix<A> and (square_shaped<A> or (not dynamic_dimension<A, 1> and index_dimension_of_v<A, 1> == index_dimension_of_v<B, 0>)))
126 if constexpr (dynamic_dimension<A, 0> or dynamic_dimension<B, 0>) detail::solve_check_A_and_B_rows_match(a, b);
128 using V0 = decltype(internal::best_vector_space_descriptor(get_vector_space_descriptor<0>(b), get_vector_space_descriptor<0>(a), get_vector_space_descriptor<1>(a)));
131 if constexpr (identity_matrix<A> and square_shaped<A>)
132 return internal::make_fixed_size_adapter<V0, V1>(std::forward<B>(b));
// Case: A and B are both constant matrices. X is a constant matrix shaped
// (columns of A) x (columns of B); the computation of its constant is not shown in this listing.
136 else if constexpr (constant_matrix<A> and (constant_matrix<B>))
138 if constexpr (dynamic_dimension<A, 0> or dynamic_dimension<B, 0>) detail::solve_check_A_and_B_rows_match(a, b);
140 return make_constant<B>(
142 get_vector_space_descriptor<1>(a), get_vector_space_descriptor<1>(b));
// Case: A is constant, and either one operand has a single row, or an inexact solution is
// acceptable. An inexact solution is acceptable when must_be_exact is false and either
// uniqueness is not required or A is statically at least as tall as it is wide. The argument
// passed to wrap_solve_result is not shown in this listing.
144 else if constexpr (constant_matrix<A> and (index_dimension_of_v<A, 0> == 1 or index_dimension_of_v<B, 0> == 1 or
145 (not must_be_exact and (not must_be_unique or (not has_dynamic_dimensions<A> and index_dimension_of_v<A, 0> >= index_dimension_of_v<A, 1>)))))
147 if constexpr (dynamic_dimension<A, 0> or dynamic_dimension<B, 0>) detail::solve_check_A_and_B_rows_match(a, b);
149 return detail::wrap_solve_result<A, B>(
// Case: A is diagonal and square. Solve element-wise with op(b_elem, a_elem).
152 else if constexpr (diagonal_matrix<A> and square_shaped<A>)
154 auto op = [](
auto&& b_elem,
auto&& a_elem) {
// The result must be infinite here; presumably this handles a zero diagonal element (the
// enclosing condition is not shown). It throws std::logic_error if B's scalar type has no infinity.
157 if constexpr (not std::numeric_limits<
scalar_type_of_t<B>>::has_infinity)
throw std::logic_error {
158 "In solve function, an element should be infinite, but the scalar type does not have infinite values"};
159 else return std::numeric_limits<scalar_type_of_t<B>>::infinity();
// Ordinary case: divide the element of b by the element of a, converted to B's scalar type.
163 return std::forward<decltype(b_elem)>(b_elem) /
static_cast<scalar_type_of_t<B>>(std::forward<decltype(a_elem)>(a_elem));
// The expression that applies op is not shown in this listing.
167 return detail::wrap_solve_result<A, B>(
// Case: the library interface for A defines solve for these operands, so delegate to it.
170 else if constexpr (interface::solve_defined_for<A, must_be_unique, must_be_exact, A, B>)
172 return detail::wrap_solve_result<A, B>(
173 Interface::template solve<must_be_unique, must_be_exact>(std::forward<A>(a), std::forward<B>(b)));
// Otherwise (presumably the final fallback): convert b to A's native matrix type, then
// delegate to the library interface.
177 return detail::wrap_solve_result<A, B>(
178 Interface::template solve<must_be_unique, must_be_exact>(std::forward<A>(a), to_native_matrix<A>(std::forward<B>(b))));
185 #endif //OPENKALMAN_SOLVE_HPP constexpr auto n_ary_operation(const std::tuple< Ds... > &d_tup, Operation &&operation, Args &&...args)
Perform a component-wise n-ary operation, using broadcasting to match the size of a pattern matrix...
Definition: n_ary_operation.hpp:319
TriangleType
The type of a triangular matrix.
Definition: global-definitions.hpp:60
decltype(auto) constexpr make_hermitian_matrix(Arg &&arg)
Creates a hermitian_matrix by, if necessary, wrapping the argument in a hermitian_adapter.
Definition: make_hermitian_matrix.hpp:37
typename scalar_type_of< T >::type scalar_type_of_t
helper template for scalar_type_of.
Definition: scalar_type_of.hpp:54
Lower, upper, or diagonal matrix.
decltype(auto) constexpr reduce(BinaryFunction &&b, Arg &&arg)
Perform a partial reduction based on an associative binary function, across one or more indices...
Definition: reduce.hpp:143
The constant associated with T, assuming T is a constant_matrix.
Definition: constant_coefficient.hpp:36
decltype(auto) constexpr all_vector_space_descriptors(const T &t)
Return a collection of coordinates::pattern objects associated with T.
Definition: all_vector_space_descriptors.hpp:52
The root namespace for OpenKalman.
Definition: basics.hpp:34
An interface to various routines from the linear algebra library associated with indexible object T...
Definition: library_interface.hpp:37
The constant associated with T, assuming T is a constant_diagonal_matrix.
Definition: constant_diagonal_coefficient.hpp:32
constexpr auto solve(A &&a, B &&b)
Solve the equation AX = B for X, which may or may not be a unique solution.
Definition: solve.hpp:87
decltype(auto) constexpr diagonal_of(Arg &&arg)
Extract a column vector (or column slice for rank>2 tensors) comprising the diagonal elements...
Definition: diagonal_of.hpp:33
typename vector_space_descriptor_of< T, N >::type vector_space_descriptor_of_t
helper template for vector_space_descriptor_of.
Definition: vector_space_descriptor_of.hpp:56