// Check that, for every index ix in ixs, the pattern of index ix of each argument
// is compatible with the target pattern std::get<ix>(d_tup). Compatible means
// either equal to it or a uniform component of it (i.e., broadcastable).
// NOTE(review): the function name is not visible in this listing. The call
// `detail::check_n_ary_dims(std::index_sequence_for<Ds...> {}, d_tup, args...)`
// elsewhere in this file suggests this is detail::check_n_ary_dims — confirm.
// NOTE(review): lines of the original body are missing from this listing. This
// includes the braces and the `if constexpr` that selects between the run-time
// check and the compile-time check below.
16 #ifndef OPENKALMAN_N_ARY_OPERATION_HPP 17 #define OPENKALMAN_N_ARY_OPERATION_HPP 34 template<std::size_t...ixs,
typename DTup,
typename...Args>
// The outer lambda is invoked once per ix (fold over ixs). The inner lambda takes a
// single argument and is presumably invoked once per element of args.
37 return ([](
const DTup& d_tup,
const Args&...args){
38 constexpr
auto ix = ixs;
39 return ([](
const DTup& d_tup,
const auto& arg){
// decltype of a reference parameter: Arg is a const lvalue-reference type here.
40 using Arg = decltype(arg);
// Run-time check: compare the argument's pattern for index ix with the target pattern.
43 auto arg_d = get_pattern_collection<ix>(arg);
44 auto tup_d = std::get<ix>(d_tup);
47 if (not (arg_d == tup_d) and not coordinates::is_uniform_pattern_component_of(arg_d, tup_d))
// NOTE(review): dim_arg_d and dim_tup_d are not declared in the visible lines.
// The assembled message also never closes the "(" it opens. When dim_tup_d != 1,
// it also lacks a space before "(the dimension ...".
48 throw std::logic_error {
"In an argument to n_ary_operation, the dimension of index " +
49 std::to_string(ix) +
" is " + std::to_string(dim_arg_d) +
", but should be 1 " +
50 (dim_tup_d == 1 ?
"" :
"or " + std::to_string(dim_tup_d)) +
51 "(the dimension of Ds number " + std::to_string(ix)};
// Compile-time check. D_Arg must compare with D, or be equivalent to a uniform
// component of D. Alternatively, if ix is at or beyond the argument's index count,
// D must be a uniform pattern.
// NOTE(review): this path uses get_index_pattern(arg, integral_constant), whereas the
// run-time path uses get_pattern_collection<ix>(arg); confirm the two agree.
55 using D_Arg = std::decay_t<decltype(get_index_pattern(arg, std::integral_constant<std::size_t, ix>{}))>;
56 using D = collections::collection_element_t<ix, DTup>;
57 static_assert(coordinates::compares_with<D_Arg, D> or equivalent_to_uniform_pattern_component_of<D_Arg, D> or
58 (ix >= index_count_v<Arg> and coordinates::uniform_pattern<D>),
59 "In argument to n_ary_operation, the dimension of each index must be either 1 or that of Ds.");
62 }(d_tup, args...),...);
// Returns whether Op can be invoked with Args... followed by one argument per element
// of I.... Since I... are std::size_t non-type template parameters, decltype(I) is
// std::size_t.
// NOTE(review): the name and parameter list are not visible in this listing. The name
// comes from call sites such as is_invocable_with_indices<Op, Args...>(seq); the I...
// pack is presumably deduced from a std::index_sequence parameter.
66 template<
typename Op,
typename...Args, std::size_t...I>
69 return std::is_invocable_v<Op, Args..., decltype(I)...>;
// Invoke op on args. If const Op& accepts one trailing std::size_t per element of seq,
// each of those indices is passed as 0. Otherwise op(args...) is called.
// The result is returned exactly as op returns it (decltype(auto)).
// In the visible code this helper appears only inside decltype(...), where it is used
// to compute a result type, so the zero indices act as placeholders.
73 template<
typename Op, std::size_t...I,
typename...Args>
74 constexpr decltype(
auto) n_ary_invoke_op(const Op& op,
std::index_sequence<I...> seq, Args&&...args)
76 if constexpr (is_invocable_with_indices<const Op&, Args&&...>(seq))
77 return op(std::forward<Args>(args)...,
static_cast<decltype(I)
>(0)...);
// NOTE(review): the `else` separating the two returns is on a line missing from this listing.
79 return op(std::forward<Args>(args)...);
// Result-type trait for n_ary_invoke_op. The struct names and bodies are not visible
// in this listing.
// Visible here:
//  - two primary-template headers, the second with an extra `typename = void`
//    parameter (SFINAE, pre-C++20 path);
//  - a `requires`-constrained partial specialization and an enable_if-style partial
//    specialization. Both apply when Op is invocable with Args..., with or without
//    `indices` trailing std::size_t arguments;
//  - `type`, the result of n_ary_invoke_op for those arguments.
84 template<
typename Operation, std::size_t indices,
typename...Args>
86 template<
typename Operation, std::size_t indices,
typename = void,
typename...Args>
92 template<
typename Op, std::size_t indices,
typename...Args>
93 requires (is_invocable_with_indices<Op, Args...>(std::make_index_sequence<indices> {})) or
94 std::is_invocable_v<Op, Args...>
97 template<typename Op, std::size_t indices, typename...Args>
99 is_invocable_with_indices<Op, Args...>(std::make_index_sequence<indices> {}) or
100 std::is_invocable_v<Op, Args...>>, Args...>
103 using type = decltype(n_ary_invoke_op(std::declval<Op>(), std::make_index_sequence<indices> {}, std::declval<Args>()...));
// NOTE(review): only the template header of the following definition is visible. Its
// body, under #ifdef __cpp_concepts, is missing from this listing.
107 template<
typename Op, std::size_t indices,
typename...Args>
109 #ifdef __cpp_concepts
// Pre-C++20 (no concepts) trait: a primary template defaulting through `typename = void`,
// plus a partial specialization enabled when Op is invocable with lvalue references to
// the element types of Args. The specialization also applies when Op is invocable with
// those references followed by Indices std::size_t indices. This mirrors the
// n_ary_operator concept defined under __cpp_concepts.
// NOTE(review): the struct name and the enable_if wrapper are on lines missing from this listing.
116 #ifndef __cpp_concepts 117 template<
typename Op, std::size_t Indices,
typename = void,
typename...Args>
120 template<
typename Op, std::size_t Indices,
typename...Args>
122 std::is_invocable<Op, typename std::add_lvalue_reference<typename element_type_of<Args>::type>::type...>::value or
123 is_invocable_with_indices<Op, typename std::add_lvalue_reference<typename element_type_of<Args>::type>::type...>(
124 std::make_index_sequence<Indices> {})>, Args...>
// Op is an n-ary operator for Args if it can be called with lvalue references to the
// element type of each argument. It also qualifies if it can be called with those
// references followed by Indices std::size_t index arguments.
129 template<
typename Op, std::size_t Indices,
typename...Args>
130 #ifdef __cpp_concepts 131 concept n_ary_operator = std::is_invocable_v<Op, std::add_lvalue_reference_t<element_type_of_t<Args>>...> or
132 is_invocable_with_indices<Op, std::add_lvalue_reference_t<element_type_of_t<Args>>...>(std::make_index_sequence<Indices> {});
// Read the component of arg addressed by the result indices j...
// Any index that is not less than the extent of the corresponding index of arg is
// clamped. In the first branch it becomes 0, so an argument whose extent is smaller
// than the result's (e.g. extent 1) is read at index 0. This supports broadcasting.
// NOTE(review): the function parameters are on lines missing from this listing.
138 template<
typename Arg, std::size_t...I,
typename...J>
// Same number of arg indices as result indices: pair them up directly.
141 if constexpr (
sizeof...(I) ==
sizeof...(J))
142 return access(std::forward<Arg>(arg), (j < get_index_extent<I>(arg) ? j : 0)...);
// Otherwise, arg index I takes element I of the tuple {j...}.
// NOTE(review): the fallback return for j >= dim is on a line missing from this listing
// (presumably 0, matching the first branch).
144 return access(std::forward<Arg>(arg), [](
auto dim,
const auto& j_tup){
145 auto j = std::get<I>(j_tup);
146 if (j < dim)
return j;
148 }(get_index_extent<I>(arg), std::tuple {j...})...);
// Evaluate op on the components of every argument in args_tup at result indices j....
// Each argument's component is fetched through n_ary_operation_get_component_impl.
// NOTE(review): two call forms are visible, and the lines between them, including the
// selecting condition, are missing. Presumably one form also passes j... to op as
// trailing index arguments and the other does not — confirm.
152 template<
typename Op,
typename ArgsTup, std::size_t...ArgI,
typename...J>
156 return op(n_ary_operation_get_component_impl(
157 std::get<ArgI>(std::forward<ArgsTup>(args_tup)),
161 return op(n_ary_operation_get_component_impl(
162 std::get<ArgI>(std::forward<ArgsTup>(args_tup)),
// Base case of the component loop: all result indices j... are fixed.
// Compute the component from every argument and store it into m at j....
168 template<
typename M,
typename Op,
typename ArgsTup,
typename...J>
169 inline void n_ary_operation_iterate(M& m,
const Op& op, ArgsTup&& args_tup,
std::index_sequence<>, J...j)
// NOTE(review): ArgsTup is deduced from a forwarding reference and may be a reference
// type here; confirm collections::size_of_v handles that.
171 std::make_index_sequence<collections::size_of_v<ArgsTup>> seq;
172 set_component(m, n_ary_operation_get_component(op, std::forward<ArgsTup>(args_tup), seq, j...), j...);
// Recursive case: loop over every value i of result index I, then recurse on the
// remaining indices Is..., appending i to the index list j....
// NOTE(review): args_tup is std::forward-ed on every loop iteration. If ArgsTup is an
// rvalue tuple (e.g. from std::forward_as_tuple), its elements are handed on as rvalues
// repeatedly. Confirm that nothing downstream moves from them.
176 template<
typename M,
typename Op,
typename ArgsTup, std::size_t I, std::size_t...Is,
typename...J>
179 for (std::size_t i = 0; i < get_index_extent<I>(m); i++)
180 n_ary_operation_iterate(m, op, std::forward<ArgsTup>(args_tup),
std::index_sequence<Is...> {}, j..., i);
// NOTE(review): only the template header, the return type, and the last line of this
// definition are visible; its name and most of its body are missing from this listing.
// The visible tail expands get_index_extent<Ix_Ds>(arg) over Ix_Ds....
185 template<
typename...Ds,
typename Arg, std::size_t...Ix_Ds>
186 static decltype(
auto)
193 get_index_extent<Ix_Ds>(arg))...);
// Core of n_ary_operation. Produces an object shaped by d_tup (one pattern per index),
// using PatternMatrix to choose the result type. Three strategies, tried in order:
//  1. If every argument is a constant_object and op does not take indices, build a
//     constant object via make_constant<PatternMatrix>(c, ds...).
//  2. If op takes indices and the library interface defines n_ary_operation for
//     PatternMatrix, delegate to it and attach the patterns in d_tup to the result.
//  3. Otherwise, build a dense object and fill it component by component.
// NOTE(review): c, ret, m, and Scalar are declared on lines missing from this listing.
// NOTE(review): `static` gives this namespace-scope template internal linkage in every
// including translation unit; confirm this is intended.
197 template<
typename PatternMatrix,
typename...Ds,
typename Op,
typename...Args>
198 static constexpr
auto 199 n_ary_operation_impl(
const std::tuple<Ds...>& d_tup, Op&& op, Args&&...args)
201 constexpr std::index_sequence_for<Ds...> seq;
// NOTE(review): this check uses element_type_of_t<Args> directly, while the next
// branch and n_ary_operator use std::add_lvalue_reference_t of it; confirm intended.
204 if constexpr (
sizeof...(Args) > 0 and (constant_object<Args> and ...) and not is_invocable_with_indices<Op,
element_type_of_t<Args>...>(seq))
208 [](
auto&&...as){
return make_constant<PatternMatrix>(std::forward<decltype(as)>(as)...); },
209 std::tuple_cat(std::tuple{std::move(c)}, d_tup));
212 else if constexpr (is_invocable_with_indices<Op, std::add_lvalue_reference_t<
element_type_of_t<Args>>...>(seq) and
213 interface::n_ary_operation_defined_for<PatternMatrix,
const std::tuple<Ds...>&, Op&&, Args&&...>)
218 return attach_pattern(std::forward<decltype(a)>(a), std::forward<decltype(vs)>(vs)...);
219 }, std::tuple_cat(std::forward_as_tuple(std::move(ret)), d_tup));
// Every index has dimension 1: a single component, computed once.
227 if constexpr (((coordinates::dimension_of_v<Ds> == 1) and ...))
// NOTE(review): op is called here without index arguments. An Op that is invocable
// only with trailing indices (allowed by n_ary_operator) would not compile in this
// branch. Consider n_ary_invoke_op.
230 auto e = op(
access(std::forward<Args>(args))...);
231 return make_dense_object_from<PatternMatrix, layout_of_t<PatternMatrix>, Scalar>(d_tup, e);
236 return make_dense_object<PatternMatrix, layout_of_t<PatternMatrix>, Scalar>(std::forward<decltype(ds)>(ds)...);
// General case: visit every component index of the result (one loop level per Ds).
238 n_ary_operation_iterate(m, op, std::forward_as_tuple(std::forward<Args>(args)...), seq);
// Perform a component-wise n-ary operation. Arguments are broadcast to the patterns
// given in d_tup. Requirements:
//  - Operation must be an n_ary_operator over Args (optionally taking sizeof...(Ds)
//    indices);
//  - every pattern in Ds must have a non-zero dimension;
//  - there must be at least one argument (enable_if path).
// Each argument's patterns are checked against d_tup before the result is computed.
// NOTE(review): the function's parameter list and the declaration of Arg0 are on
// lines missing from this listing.
315 #ifdef __cpp_concepts 317 detail::n_ary_operator<Operation,
sizeof...(Ds), Args...> and (... and (coordinates::dimension_of_v<Ds> != 0))
320 template<
typename...Ds,
typename Operation,
typename...Args, std::enable_if_t<
321 (coordinates::pattern<Ds> and ...) and (indexible<Args> and ...) and (
sizeof...(Args) > 0) and
322 detail::n_ary_operator<Operation,
sizeof...(Ds), Args...> and (... and (coordinates::dimension_of_v<Ds> != 0)),
int> = 0>
// Throws std::logic_error (or fails a static_assert) if an argument cannot be broadcast.
327 detail::check_n_ary_dims(std::index_sequence_for<Ds...> {}, d_tup, args...);
329 return detail::n_ary_operation_impl<Arg0>(d_tup, std::forward<Operation>(
operation), std::forward<Args>(args)...);
// Determine the pattern for index ix that all arguments can be broadcast to.
// In the dynamic/dynamic Euclidean case the result is an extent (std::size_t) rather
// than a pattern. Two patterns are compatible when they compare equal, or when one of
// them has dimension 1 (or is a uniform component of the other); the non-1 pattern is
// kept. Incompatibility is reported by std::invalid_argument at run time or by
// static_assert at compile time.
// NOTE(review): braces and several lines of the body are missing from this listing.
335 template<std::size_t ix,
typename Arg,
typename...Args>
336 constexpr
auto find_max_dim(
const Arg& arg,
const Args&...args)
// Base case: only one argument left; its pattern is the answer unless it is zero-sized.
338 if constexpr (
sizeof...(Args) == 0)
340 auto ret = get_pattern_collection<ix>(arg);
// NOTE(review): the string literal below reads "A dimension of an arguments"
// (grammar); left unchanged here because it is runtime text.
343 if (
get_dimension(ret) == 0)
throw std::invalid_argument {
"A dimension of an arguments " 344 "to n_ary_operation is zero for at least index " + std::to_string(ix) +
"."};
346 else static_assert(index_dimension_of_v<Arg, ix> != 0,
"Arguments to n_ary_operation cannot have zero dimensions");
// Recursive case: combine this argument with the result for the remaining arguments.
351 auto max_d = find_max_dim<ix>(args...);
352 using Arg_D = std::decay_t<decltype(get_index_pattern(arg, std::integral_constant<std::size_t, ix>{}))>;
353 using Max_D = decltype(max_d);
// Both patterns are fixed: decide entirely at compile time.
355 if constexpr (fixed_pattern<Arg_D> and fixed_pattern<Max_D>)
357 constexpr
auto dim_arg_d = coordinates::dimension_of_v<Arg_D>;
358 if constexpr (compares_with<Arg_D, Max_D>or (dim_arg_d == 1 and equivalent_to_uniform_pattern_component_of<Arg_D, Max_D>))
364 constexpr
auto dim_max_d = coordinates::dimension_of_v<Max_D>;
// NOTE(review): this condition looks inverted. It fails exactly when Max_D is a
// dimension-1 uniform component of Arg_D, which is the compatible case that
// returning Arg's pattern below supports. It passes for genuinely incompatible
// patterns. Presumably it should read
// `dim_max_d == 1 and equivalent_to_uniform_pattern_component_of<Max_D, Arg_D>`.
// This assumes the missing lines 359-363 do not handle that case first.
365 static_assert(dim_max_d != 1 or not equivalent_to_uniform_pattern_component_of<Max_D, Arg_D>,
366 "The dimension of arguments to n_ary_operation are not compatible with each other for at least one index.");
367 return get_pattern_collection<ix>(arg);
// Both Euclidean, at least one dynamic: compare dimensions numerically.
370 else if constexpr (coordinates::euclidean_pattern<Arg_D> and coordinates::euclidean_pattern<Max_D>)
372 if constexpr (fixed_pattern<Arg_D>)
374 constexpr std::size_t a = coordinates::dimension_of_v<Arg_D>;
// NOTE(review): m is not declared in the visible lines (presumably the dimension of max_d).
376 if (a != m and a != 1 and m != 1)
throw std::invalid_argument {
"The dimension of arguments to n_ary_operation " 377 "are not compatible with each other for at least index " + std::to_string(ix) +
"."};
379 if constexpr (a == 1)
return max_d;
380 else return get_pattern_collection<ix>(arg);
382 else if constexpr (fixed_pattern<Max_D>)
384 auto arg_d = get_pattern_collection<ix>(arg);
// NOTE(review): a is not declared in the visible lines (presumably the dimension of arg_d).
386 constexpr std::size_t m = coordinates::dimension_of_v<Max_D>;
387 if (a != m and a != 1 and m != 1)
throw std::invalid_argument {
"The dimension of arguments to n_ary_operation " 388 "are not compatible with each other for at least index " + std::to_string(ix) +
"."};
390 if constexpr (m == 1)
return arg_d;
// Both dynamic: the result is an extent, not a pattern.
395 std::size_t a = get_index_extent<ix>(arg);
397 if (a == m or a == 1)
return m;
// Given m == 1, `m <= a` is equivalent to a != 0, so a zero extent is rejected for
// this argument. The base case only zero-checks the last argument.
398 else if (m == 1 and m <= a)
return a;
399 else throw std::invalid_argument {
"The dimension of arguments to n_ary_operation are not compatible with " 400 "each other for at least index " + std::to_string(ix) +
"."};
// General (non-Euclidean, not both fixed) case: test uniform-component relations at run time.
405 auto arg_d = get_pattern_collection<ix>(arg);
407 if (coordinates::is_uniform_pattern_component_of(arg_d, max_d))
413 else if (coordinates::is_uniform_pattern_component_of(max_d, arg_d))
419 else throw std::invalid_argument {
"The dimension of arguments to n_ary_operation are not compatible with " 420 "each other for at least index " + std::to_string(ix) +
"."};
// Build a tuple holding, for each index ix in ixs, the common (broadcast) pattern of
// that index across all arguments, as computed by find_max_dim<ix>.
426 template<std::size_t...ixs,
typename...Args>
429 return std::tuple {find_max_dim<ixs>(args...)...};
// Perform a component-wise n-ary operation without explicit target patterns.
// The result has max(index_count_v<Args>...) indices. The pattern of each index is the
// common broadcast pattern of the arguments, as computed by find_max_dims. Those checks
// already reject incompatible arguments, so check_n_ary_dims is not called here.
// NOTE(review): the parameter list and the declaration of Arg0 are on lines missing from
// this listing. std::move(d_tup) binds to n_ary_operation_impl's const& parameter, so no
// move actually occurs.
483 #ifdef __cpp_concepts 484 template<
typename Operation,
indexible...Args> requires (
sizeof...(Args) > 0) and
485 detail::n_ary_operator<Operation, std::max({index_count_v<Args>...}), Args...>
487 template<
typename Operation,
typename...Args, std::enable_if_t<(indexible<Args> and ...) and
488 (
sizeof...(Args) > 0) and detail::n_ary_operator<Operation, std::max({
index_count<Args>::value...}), Args...>,
int> = 0>
493 auto d_tup = detail::find_max_dims(std::make_index_sequence<std::max({index_count_v<Args>...})> {}, args...);
495 return detail::n_ary_operation_impl<Arg0>(std::move(d_tup), std::forward<Operation>(
operation), std::forward<Args>(args)...);
// Base case for nullary operations: all indices is... are known. Store op's value at
// m[is...]. Op receives the indices if it accepts sizeof...(Is) std::size_t arguments;
// otherwise it is called with no arguments.
// NOTE(review): invocability is checked for Operation (non-const), but op is a
// const Operation&. The check also uses std::size_t, while some elements of is... may be
// std::integral_constant values supplied by the recursive overload.
501 template<
typename M,
typename Operation,
typename Vs_tuple,
typename Index_seq,
typename K_seq,
typename...Is>
502 void nullary_set_components(M& m,
const Operation& op,
const Vs_tuple&, Index_seq, K_seq, Is...is)
504 constexpr
auto seq = std::make_index_sequence<
sizeof...(Is)> {};
505 if constexpr (detail::is_invocable_with_indices<Operation>(seq))
506 set_component(m, op(is...), is...);
// NOTE(review): the `else` is on a line missing from this listing.
508 set_component(m, op(), is...);
// Recursive case: choose the value of index DsIndex, then recurse on DsIndices....
// If DsIndex is one of the `indices` that select a specific operation, its value is the
// compile-time constant Ks paired with it. Otherwise, every value of that index is
// visited in a loop.
// NOTE(review): the index_seq and k_seq parameters are on lines missing from this listing.
512 template<std::size_t DsIndex, std::size_t...DsIndices,
typename M,
typename Operation,
513 typename Vs_tuple, std::size_t...indices, std::size_t...Ks,
typename...Is>
514 void nullary_set_components(M& m,
const Operation& op,
const Vs_tuple& ds_tup,
517 if constexpr (((DsIndex == indices) or ...))
// Picks the K paired with DsIndex. This assumes DsIndex appears at most once in
// indices (the caller names the sequence UniqueIndicesSeq); otherwise the Ks would be
// summed. NOTE(review): `size_t` is unqualified here, unlike `std::size_t` elsewhere.
519 constexpr std::integral_constant<size_t, ((DsIndex == indices ? Ks : 0) + ...)> i;
520 nullary_set_components<DsIndices...>(m, op, ds_tup, index_seq, k_seq, is..., i);
525 for (std::size_t i = 0; i < get_dimension(std::get<DsIndex>(ds_tup)); ++i)
527 nullary_set_components<DsIndices...>(m, op, ds_tup, index_seq, k_seq, is..., i);
// Base case: the positions of all operation-selecting indices are fixed (recorded in
// k_seq). Apply the operation numbered CurrentOpIndex to every component that matches
// those positions.
// NOTE(review): the unique_indices_seq and k_seq parameters are on lines missing from this listing.
533 template<std::size_t CurrentOpIndex, std::size_t factor,
typename M,
typename Operations_tuple,
534 typename Vs_tuple,
typename UniqueIndicesSeq, std::size_t...AllDsIndices,
typename K_seq>
535 void nullary_iterate(M& m,
const Operations_tuple& op_tup,
const Vs_tuple& ds_tup,
538 nullary_set_components<AllDsIndices...>(m, std::get<CurrentOpIndex>(op_tup), ds_tup, unique_indices_seq, k_seq);
// Recursive case: for each position j (from Js...) of the operation-selecting index
// `index`, offset the operation number by new_factor * j and recurse on the remaining
// indices. factor starts at sizeof...(Operations) (see the caller). Operations are
// therefore laid out in row-major order over the listed indices, with the first listed
// index varying slowest.
// NOTE(review): the unique_indices_seq, all_ds_seq, and j_seqs parameters are on lines
// missing from this listing.
542 template<std::size_t CurrentOpIndex, std::size_t factor, std::size_t
index, std::size_t...indices,
543 typename M,
typename Operations_tuple,
typename Vs_tuple,
typename UniqueIndicesSeq,
typename AllDsSeq,
544 std::size_t...Ks, std::size_t...Js,
typename...J_seqs>
545 void nullary_iterate(M& m,
const Operations_tuple& op_tup,
const Vs_tuple& ds_tup,
549 constexpr std::size_t new_factor = factor / coordinates::dimension_of_v<collections::collection_element_t<index, Vs_tuple>>;
551 (nullary_iterate<CurrentOpIndex + new_factor * Js, new_factor, indices...>(
552 m, op_tup, ds_tup, unique_indices_seq, all_ds_seq,
std::index_sequence<Ks..., Js>{}, j_seqs...),...);
// Nullary form: create an object with patterns d_tup whose components are produced by
// operations. Each operation takes either no arguments or sizeof...(Ds) indices.
// Strategies, tried in order:
//  - a single operation is applied to every component (via n_ary_operation_impl);
//  - when all Ds are fixed, there is exactly one operation per component, and no
//    operation takes indices, the components are produced directly by operations()...;
//  - otherwise a dense object is created and filled by nullary_iterate, which assigns
//    operations by position along the listed `indices`.
// NOTE(review): the parameter list, part of the constraints, and the declaration of
// Scalar are on lines missing from this listing.
632 #ifdef __cpp_concepts 636 (detail::n_ary_operator<Operations,
sizeof...(Ds)> and ...)
638 template<
typename PatternMatrix, std::size_t...indices,
typename...Ds,
typename...Operations, std::enable_if_t<
639 indexible<PatternMatrix> and (coordinates::pattern<Ds> and ...) and
642 (detail::n_ary_operator<Operations,
sizeof...(Ds)> and ...),
int> = 0>
650 if constexpr (
sizeof...(Operations) == 1)
// NOTE(review): this branch passes std::decay_t<PatternMatrix>, whereas the other
// branches use PatternMatrix as-is.
652 return detail::n_ary_operation_impl<std::decay_t<PatternMatrix>>(d_tup, operations...);
// NOTE(review): this branch ignores `indices...` and relies on make_dense_object_from's
// component order. Confirm it matches the order nullary_iterate would use.
655 else if constexpr (((not dynamic_pattern<Ds>) and ...) and
656 sizeof...(operations) == (coordinates::dimension_of_v<Ds> * ...) and
657 not (detail::is_invocable_with_indices<const Operations&>(std::make_index_sequence<
sizeof...(Ds)> {}) or ...))
659 return make_dense_object_from<PatternMatrix, layout_of_t<PatternMatrix>, Scalar>(d_tup, operations()...);
664 auto m = make_dense_object<PatternMatrix, layout_of_t<PatternMatrix>, Scalar>(d_tup);
665 detail::nullary_iterate<0,
sizeof...(Operations), indices...>(
667 std::forward_as_tuple(operations...),
670 std::index_sequence_for<Ds...> {},
// Nullary form without explicit patterns: the patterns are taken from PatternMatrix.
// Each listed index must have a fixed dimension, and the number of operations must
// equal the product of those dimensions (checked in requires / enable_if). The work is
// then delegated to the d_tup overload.
// NOTE(review): the documented signature takes only `const Operations&...operations`,
// so Ds... cannot be deduced. It appears to be always empty, making
// (coordinates::pattern<Ds> and ...) vacuous.
728 #ifdef __cpp_concepts 730 requires ((not dynamic_dimension<PatternMatrix, indices>) and ...) and
731 (
sizeof...(Operations) == (1 * ... * index_dimension_of_v<PatternMatrix, indices>))
733 template<
typename PatternMatrix, std::size_t...indices,
typename...Ds,
typename...Operations, std::enable_if_t<
734 indexible<PatternMatrix> and (coordinates::pattern<Ds> and ...) and
735 ((not dynamic_dimension<PatternMatrix, indices>) and ...) and
741 auto d_tup = get_pattern_collection<PatternMatrix>();
742 return n_ary_operation<PatternMatrix, indices...>(d_tup, operations...);
// Apply operation to the single element elem, located at indices j....
// The visible branch is selected when const Operation& accepts (Elem&, J...).
// NOTE(review): the bodies of this branch and of its alternative are on lines missing
// from this listing. Presumably the alternative calls the operation without indices.
752 template<
typename Operation,
typename Elem,
typename...J>
753 inline void do_elem_operation_in_place_impl(
const Operation&
operation, Elem& elem, J...j)
755 if constexpr (std::is_invocable_v<const Operation&, Elem&, J...>)
// Apply operation in place to the component of arg at indices j....
// If access() yields something assignable, operate on it directly. Otherwise work on a
// local copy and write it back with set_component.
762 template<
typename Operation,
typename Arg,
typename...J>
763 inline void do_elem_operation_in_place(
const Operation& operation, Arg& arg, J...j)
765 auto&& elem =
access(arg, j...);
766 if constexpr (std::is_assignable_v<decltype((elem)), std::decay_t<decltype(elem)>>)
768 do_elem_operation_in_place_impl(operation, elem, j...);
// Element is not writable through access(): modify a copy, then store it back.
772 auto e {std::forward<decltype(elem)>(elem)};
773 static_assert(std::is_assignable_v<decltype((e)), std::decay_t<decltype(elem)>>);
774 do_elem_operation_in_place_impl(operation, e, j...);
775 set_component(arg, std::move(e), j...);
// Visit every component of arg recursively. Each call fixes the next index,
// n = sizeof...(J), by looping over its extent. When all indices are fixed, the
// operation is applied to that element.
// NOTE(review): the condition that chooses between looping and applying (presumably a
// comparison of n with count) is on lines missing from this listing.
780 template<
typename Operation,
typename Arg,
typename Count,
typename...J>
781 inline void unary_operation_in_place_impl(
const Operation& operation, Arg& arg,
const Count& count, J...j)
783 constexpr
auto n =
sizeof...(J);
786 for (std::size_t i = 0; i < get_index_extent<n>(arg); ++i)
787 unary_operation_in_place_impl(operation, arg, count, j..., i);
791 do_elem_operation_in_place(operation, arg, j...);
// Perform a component-wise, in-place unary operation on a writable object.
// Every component is visited and modified by operation; the argument is then returned
// with its value category preserved (std::forward<Arg>(arg) with decltype(auto)).
// NOTE(review): when arg is an rvalue, the return value is an rvalue reference to the
// caller's temporary. Binding it, e.g. `auto&& r = unary_operation_in_place(op, make());`,
// dangles after the full-expression.
// NOTE(review): the function is constexpr, but it calls the non-constexpr inline
// helper unary_operation_in_place_impl.
832 #ifdef __cpp_concepts 833 template<
typename Operation, writable Arg> requires detail::n_ary_operator<Operation, index_count_v<Arg>, Arg>
835 template<
typename Operation,
typename Arg, std::enable_if_t<writable<Arg> and
836 detail::n_ary_operator<Operation, index_count_v<Arg>, Arg>,
int> = 0>
838 constexpr decltype(
auto)
843 detail::unary_operation_in_place_impl(operation, arg,
count_indices(arg));
844 return std::forward<Arg>(arg);
constexpr auto n_ary_operation(const Operations &...operations)
This is an overloaded member function, provided for convenience. It differs from the above function only in what argument(s) it accepts.
Definition: n_ary_operation.hpp:739
constexpr auto attach_pattern(Arg &&arg, P &&p)
Attach a pattern_collection to an indexible object.
Definition: attach_pattern.hpp:36
constexpr auto n_ary_operation(const std::tuple< Ds... > &d_tup, Operation &&operation, Args &&...args)
Perform a component-wise n-ary operation, using broadcasting to match the size of a pattern matrix...
Definition: n_ary_operation.hpp:325
Definition: n_ary_operation.hpp:118
Definition for layout_of.
Definition for constant_value.
Definition for constant_object.
constexpr auto count_indices(const T &)
Get the number of indices necessary to address all the components of an indexible object...
Definition: count_indices.hpp:51
constexpr bool pattern
An object describing the set of coordinates associated with a tensor index.
Definition: pattern.hpp:31
decltype(auto) constexpr unary_operation_in_place(const Operation &operation, Arg &&arg)
Perform a component-wise, in-place unary operation.
Definition: n_ary_operation.hpp:839
constexpr auto dimension_of_v
Helper template for coordinates::dimension_of.
Definition: dimension_of.hpp:56
constexpr bool indexible
T is a multidimensional array type.
Definition: indexible.hpp:32
constexpr bool value
T is a fixed or dynamic value that is reducible to a number.
Definition: value.hpp:45
The size of a coordinates::pattern.
Definition: dimension_of.hpp:36
decltype(auto) constexpr broadcast(Arg &&arg, const Factors &...factors)
Broadcast an object by replicating it by factors specified for each index.
Definition: broadcast.hpp:49
Definition for dynamic_dimension.
decltype(auto) constexpr apply(F &&f, T &&t)
A generalization of std::apply.
Definition: apply.hpp:49
constexpr auto get_dimension(const Arg &arg)
Get the vector dimension of coordinates::pattern Arg.
Definition: get_dimension.hpp:54
constexpr bool compares_with_pattern_collection
Compares the associated pattern collection of indexible T with pattern_collection D...
Definition: compares_with_pattern_collection.hpp:33
The root namespace for OpenKalman.
Definition: basics.hpp:34
An interface to various routines from the linear algebra library associated with indexible object T...
Definition: library_interface.hpp:42
decltype(auto) constexpr access(Arg &&arg, const Indices &indices)
Access a component of an indexible object at a given set of indices.
Definition: access.hpp:74
Definition: n_ary_operation.hpp:108
constexpr bool fixed_pattern
A coordinates::pattern for which the dimension is fixed at compile time.
Definition: fixed_pattern.hpp:46
constexpr bool dynamic_pattern
A coordinates::pattern for which the size is defined at runtime.
Definition: dynamic_pattern.hpp:31
constexpr auto constant_value(T &&t)
The constant value associated with a constant_object or constant_diagonal_object. ...
Definition: constant_value.hpp:37
Definition: trait_backports.hpp:64
typename collection_element< i, T >::type collection_element_t
Helper template for collection_element.
Definition: collection_element.hpp:116
The dimension of an index for a matrix, expression, or array.
Definition: index_dimension_of.hpp:35
The minimum number of indices needed to access all the components of an object (i.e., the rank or order).
Definition: index_count.hpp:34
Definition: n_ary_operation.hpp:88
typename element_type_of< T >::type element_type_of_t
helper template for element_type_of.
Definition: element_type_of.hpp:54
constexpr auto operation(Operation &&op, Args &&...args)
A potentially constant-evaluated operation involving some number of values.
Definition: operation.hpp:98
constexpr bool index
An object describing a collection of values::index objects.
Definition: index.hpp:77