Skip to content

Commit 95bedef

Browse files
akukanovSergeyKopienkodmitriy-sobolev
authored
Add limited output support to copy_if with device policies (#2501)
Co-authored-by: Sergey Kopienko <sergey.kopienko@intel.com> Co-authored-by: Dmitriy Sobolev <Dmitriy.Sobolev@intel.com>
1 parent 98db1cb commit 95bedef

File tree

9 files changed

+329
-241
lines changed

9 files changed

+329
-241
lines changed

include/oneapi/dpl/pstl/hetero/algorithm_impl_hetero.h

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -918,22 +918,19 @@ _Iterator2
918918
__pattern_copy_if(__hetero_tag<_BackendTag>, _ExecutionPolicy&& __exec, _Iterator1 __first, _Iterator1 __last,
919919
_Iterator2 __result_first, _Predicate __pred)
920920
{
921-
using _It1DifferenceType = typename ::std::iterator_traits<_Iterator1>::difference_type;
922-
923921
if (__first == __last)
924922
return __result_first;
925923

926-
_It1DifferenceType __n = __last - __first;
924+
typename std::iterator_traits<_Iterator1>::difference_type __n = __last - __first;
927925

928926
auto __keep1 = oneapi::dpl::__ranges::__get_sycl_range<__par_backend_hetero::access_mode::read, _Iterator1>();
929927
auto __buf1 = __keep1(__first, __last);
930928
auto __keep2 = oneapi::dpl::__ranges::__get_sycl_range<__par_backend_hetero::access_mode::write, _Iterator2>();
931929
auto __buf2 = __keep2(__result_first, __result_first + __n);
932930

933-
auto __res = __par_backend_hetero::__parallel_copy_if(_BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec),
934-
__buf1.all_view(), __buf2.all_view(), __n, __pred);
931+
std::size_t __num_copied = __par_backend_hetero::__parallel_copy_if(_BackendTag{},
932+
std::forward<_ExecutionPolicy>(__exec), __buf1.all_view(), __buf2.all_view(), __n, __n, __pred)[0];
935933

936-
::std::size_t __num_copied = __res.get(); //is a blocking call
937934
return __result_first + __num_copied;
938935
}
939936

include/oneapi/dpl/pstl/hetero/algorithm_ranges_impl_hetero.h

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include <cstddef>
3737
#include <functional>
3838
#include <type_traits>
39+
#include <array>
3940
#endif
4041

4142
namespace oneapi
@@ -654,16 +655,17 @@ oneapi::dpl::__internal::__difference_t<_Range2>
654655
__pattern_copy_if(__hetero_tag<_BackendTag>, _ExecutionPolicy&& __exec, _Range1&& __rng1, _Range2&& __rng2,
655656
_Predicate __pred, _Assign __assign)
656657
{
657-
oneapi::dpl::__internal::__difference_t<_Range2> __n = oneapi::dpl::__ranges::__size(__rng1);
658-
if (__n == 0)
658+
using _Size = oneapi::dpl::__ranges::__common_size_t<_Range1, _Range2>;
659+
_Size __n = oneapi::dpl::__ranges::__size(__rng1);
660+
_Size __n_out = oneapi::dpl::__ranges::__size(__rng2);
661+
if (__n == 0 || __n_out == 0)
659662
return 0;
660663

661-
auto __res = oneapi::dpl::__par_backend_hetero::__parallel_copy_if(
664+
return oneapi::dpl::__par_backend_hetero::__parallel_copy_if(
662665
_BackendTag{}, std::forward<_ExecutionPolicy>(__exec),
663666
oneapi::dpl::__ranges::__get_subscription_view(std::forward<_Range1>(__rng1)),
664-
oneapi::dpl::__ranges::__get_subscription_view(std::forward<_Range2>(__rng2)), __n, __pred, __assign);
665-
666-
return __res.get(); //is a blocking call
667+
oneapi::dpl::__ranges::__get_subscription_view(std::forward<_Range2>(__rng2)),
668+
__n, __n_out, __pred, __assign)[0];
667669
}
668670

669671
#if _ONEDPL_CPP20_RANGES_PRESENT
@@ -673,15 +675,28 @@ std::ranges::copy_if_result<std::ranges::borrowed_iterator_t<_InRange>, std::ran
673675
__pattern_copy_if_ranges(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _InRange&& __in_r,
674676
_OutRange&& __out_r, _Pred __pred, _Proj __proj)
675677
{
678+
using _Size = oneapi::dpl::__ranges::__common_size_t<_InRange, _OutRange>;
679+
_Size __n = oneapi::dpl::__ranges::__size(__in_r);
680+
if (__n == 0)
681+
return {std::ranges::begin(__in_r), std::ranges::begin(__out_r)};
682+
683+
_Size __n_out = oneapi::dpl::__ranges::__size(__out_r);
684+
if (__n_out == 0)
685+
{
686+
auto __found_it = __pattern_find_if(__tag, std::forward<_ExecutionPolicy>(__exec),
687+
std::forward<_InRange>(__in_r), __pred, __proj);
688+
return {__found_it, std::ranges::begin(__out_r)};
689+
}
690+
676691
oneapi::dpl::__internal::__unary_op<_Pred, _Proj> __pred_1{__pred, __proj};
677692

678-
auto __res_idx = oneapi::dpl::__internal::__ranges::__pattern_copy_if(__tag,
679-
std::forward<_ExecutionPolicy>(__exec), oneapi::dpl::__ranges::views::all_read(__in_r),
680-
oneapi::dpl::__ranges::views::all_write(__out_r), __pred_1,
681-
oneapi::dpl::__internal::__pstl_assign());
693+
std::array<_Size, 2> __stops = oneapi::dpl::__par_backend_hetero::__parallel_copy_if(
694+
_BackendTag{}, std::forward<_ExecutionPolicy>(__exec),
695+
oneapi::dpl::__ranges::views::all_read(std::forward<_InRange>(__in_r)),
696+
oneapi::dpl::__ranges::views::all_write(std::forward<_OutRange>(__out_r)),
697+
__n, __n_out, __pred_1, oneapi::dpl::__internal::__pstl_assign());
682698

683-
return {std::ranges::begin(__in_r) + oneapi::dpl::__ranges::__size(__in_r),
684-
std::ranges::begin(__out_r) + __res_idx};
699+
return {std::ranges::begin(__in_r) + __stops[1], std::ranges::begin(__out_r) + __stops[0]};
685700
}
686701
#endif //_ONEDPL_CPP20_RANGES_PRESENT
687702

0 commit comments

Comments
 (0)