Skip to content

Commit b3dd095

Browse files
[oneDPL][Tests] Refactoring of utility functions to evaluate projected values and compare result in one place (#2461)
1 parent d075f89 commit b3dd095

8 files changed

Lines changed: 149 additions & 129 deletions

File tree

include/oneapi/dpl/pstl/algorithm_impl.h

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3318,21 +3318,19 @@ __parallel_set_op(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _RandomA
33183318

33193319
//try searching for the first element which not equal to *__b
33203320
if (__b != __first1)
3321-
__b += __internal::__pstl_upper_bound(__b, _DifferenceType1{0}, __last1 - __b,
3322-
std::invoke(__proj1, *__b), __comp, __proj1);
3321+
__b += __internal::__pstl_upper_bound(__b, _DifferenceType1{0}, __last1 - __b, __b, __comp, __proj1, __proj1);
33233322

33243323
//try searching for the first element which not equal to *__e
33253324
if (__e != __last1)
3326-
__e += __internal::__pstl_upper_bound(__e, _DifferenceType1{0}, __last1 - __e,
3327-
std::invoke(__proj1, *__e), __comp, __proj1);
3325+
__e += __internal::__pstl_upper_bound(__e, _DifferenceType1{0}, __last1 - __e, __e, __comp, __proj1, __proj1);
33283326

33293327
//check is [__b; __e) empty
33303328
if (__e - __b < 1)
33313329
{
33323330
_RandomAccessIterator2 __bb = __last2;
33333331
if (__b != __last1)
33343332
__bb = __first2 + __internal::__pstl_lower_bound(__first2, _DifferenceType2{0}, __last2 - __first2,
3335-
std::invoke(__proj1, *__b), __comp, __proj2);
3333+
__b, __comp, __proj2, __proj1);
33363334

33373335
const _DifferenceType __buf_pos = __size_func((__b - __first1), (__bb - __first2));
33383336
return _SetRange{0, 0, __buf_pos};
@@ -3342,12 +3340,12 @@ __parallel_set_op(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _RandomA
33423340
_RandomAccessIterator2 __bb = __first2;
33433341
if (__b != __first1)
33443342
__bb = __first2 + __internal::__pstl_lower_bound(__first2, _DifferenceType2{0}, __last2 - __first2,
3345-
std::invoke(__proj1, *__b), __comp, __proj2);
3343+
__b, __comp, __proj2, __proj1);
33463344

33473345
_RandomAccessIterator2 __ee = __last2;
33483346
if (__e != __last1)
3349-
__ee = __bb + __internal::__pstl_lower_bound(__bb, _DifferenceType2{0}, __last2 - __bb,
3350-
std::invoke(__proj1, *__e), __comp, __proj2);
3347+
__ee = __bb + __internal::__pstl_lower_bound(__bb, _DifferenceType2{0}, __last2 - __bb, __e, __comp,
3348+
__proj2, __proj1);
33513349

33523350
const _DifferenceType __buf_pos = __size_func((__b - __first1), (__bb - __first2));
33533351
auto __buffer_b = __tmp_memory + __buf_pos;
@@ -3402,8 +3400,8 @@ __parallel_set_union_op(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& __ex
34023400

34033401
// testing whether the sequences are intersected
34043402
_RandomAccessIterator1 __left_bound_seq_1 =
3405-
__first1 + __internal::__pstl_lower_bound(__first1, _DifferenceType1{0}, __last1 - __first1,
3406-
std::invoke(__proj2, *__first2), __comp, __proj1);
3403+
__first1 + __internal::__pstl_lower_bound(__first1, _DifferenceType1{0}, __last1 - __first1, __first2, __comp,
3404+
__proj1, __proj2);
34073405

34083406
if (__left_bound_seq_1 == __last1)
34093407
{
@@ -3421,8 +3419,8 @@ __parallel_set_union_op(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& __ex
34213419

34223420
// testing whether the sequences are intersected
34233421
_RandomAccessIterator2 __left_bound_seq_2 =
3424-
__first2 + __internal::__pstl_lower_bound(__first2, _DifferenceType2{0}, __last2 - __first2,
3425-
std::invoke(__proj1, *__first1), __comp, __proj2);
3422+
__first2 + __internal::__pstl_lower_bound(__first2, _DifferenceType2{0}, __last2 - __first2, __first1, __comp,
3423+
__proj2, __proj1);
34263424

34273425
if (__left_bound_seq_2 == __last2)
34283426
{

include/oneapi/dpl/pstl/algorithm_ranges_impl.h

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -692,8 +692,8 @@ __pattern_includes(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& __exec, _
692692
std::invoke(__comp, std::invoke(__proj1, *(__last1 - 1)), std::invoke(__proj2, *(__last2 - 1))))
693693
return false;
694694

695-
__first1 += oneapi::dpl::__internal::__pstl_lower_bound(__first1, _DifferenceType1{0}, __last1 - __first1,
696-
std::invoke(__proj2, *__first2), __comp, __proj1);
695+
__first1 += oneapi::dpl::__internal::__pstl_lower_bound(__first1, _DifferenceType1{0}, __last1 - __first1, __first2,
696+
__comp, __proj1, __proj2);
697697
if (__first1 == __last1)
698698
return false;
699699

@@ -721,19 +721,18 @@ __pattern_includes(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& __exec, _
721721
if (__is_equal_sorted(__i, __j - 1))
722722
return false;
723723

724-
__i += oneapi::dpl::__internal::__pstl_upper_bound(__i, _DifferenceType2{0}, __last2 - __i,
725-
std::invoke(__proj2, *__i), __comp, __proj2);
724+
__i += oneapi::dpl::__internal::__pstl_upper_bound(__i, _DifferenceType2{0}, __last2 - __i, __i, __comp,
725+
__proj2, __proj2);
726726
}
727727

728728
//1.2 right bound, case "[...aaa]aaaxyz" - searching "x"
729729
if (__j < __last2 && __is_equal_sorted(__j - 1, __j))
730-
__j += oneapi::dpl::__internal::__pstl_upper_bound(__j, _DifferenceType2{0}, __last2 - __j,
731-
std::invoke(__proj2, *__j), __comp, __proj2);
730+
__j += oneapi::dpl::__internal::__pstl_upper_bound(__j, _DifferenceType2{0}, __last2 - __j, __j, __comp,
731+
__proj2, __proj2);
732732

733733
//2. testing is __a subsequence of the second range included into the first range
734-
auto __b = __first1 +
735-
oneapi::dpl::__internal::__pstl_lower_bound(__first1, _DifferenceType1{0}, __last1 - __first1,
736-
std::invoke(__proj2, *__i), __comp, __proj1);
734+
auto __b = __first1 + oneapi::dpl::__internal::__pstl_lower_bound(
735+
__first1, _DifferenceType1{0}, __last1 - __first1, __i, __comp, __proj1, __proj2);
737736

738737
return !std::ranges::includes(__b, __last1, __i, __j, __comp, __proj1, __proj2);
739738
});
@@ -885,15 +884,15 @@ __pattern_set_intersection(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& _
885884
// testing whether the sequences are intersected
886885
auto __left_bound_seq_1 =
887886
__first1 + oneapi::dpl::__internal::__pstl_lower_bound(__first1, _DifferenceType1{0}, __last1 - __first1,
888-
std::invoke(__proj2, *__first2), __comp, __proj1);
887+
__first2, __comp, __proj1, __proj2);
889888
//{1} < {2}: seq 2 is wholly greater than seq 1, so, the intersection is empty
890889
if (__left_bound_seq_1 == __last1)
891890
return {__last1, __last2, __result};
892891

893892
// testing whether the sequences are intersected
894893
auto __left_bound_seq_2 =
895894
__first2 + oneapi::dpl::__internal::__pstl_lower_bound(__first2, _DifferenceType2{0}, __last2 - __first2,
896-
std::invoke(__proj1, *__first1), __comp, __proj2);
895+
__first1, __comp, __proj2, __proj1);
897896
//{2} < {1}: seq 1 is wholly greater than seq 2, so, the intersection is empty
898897
if (__left_bound_seq_2 == __last2)
899898
return {__last1, __last2, __result};
@@ -1021,7 +1020,7 @@ __pattern_set_difference(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& __e
10211020
// testing whether the sequences are intersected
10221021
auto __left_bound_seq_1 =
10231022
__first1 + oneapi::dpl::__internal::__pstl_lower_bound(__first1, _DifferenceType1{0}, __last1 - __first1,
1024-
std::invoke(__proj2, *__first2), __comp, __proj1);
1023+
__first2, __comp, __proj1, __proj2);
10251024
//{1} < {2}: seq 2 is wholly greater than seq 1, so, parallel copying just first sequence
10261025
if (__left_bound_seq_1 == __last1)
10271026
{
@@ -1033,7 +1032,7 @@ __pattern_set_difference(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& __e
10331032
// testing whether the sequences are intersected
10341033
auto __left_bound_seq_2 =
10351034
__first2 + oneapi::dpl::__internal::__pstl_lower_bound(__first2, _DifferenceType2{0}, __last2 - __first2,
1036-
std::invoke(__proj1, *__first1), __comp, __proj2);
1035+
__first1, __comp, __proj2, __proj1);
10371036
//{2} < {1}: seq 1 is wholly greater than seq 2, so, parallel copying just first sequence
10381037
if (__left_bound_seq_2 == __last2)
10391038
{

include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1996,8 +1996,9 @@ struct __partial_merge_kernel
19961996
const auto __shift =
19971997
/* index inside p1 */ __global_idx - __start_1 +
19981998
/* relative position in p3 */
1999-
oneapi::dpl::__internal::__pstl_lower_bound(__in_acc2, __start_2, __part_end_2, __in_acc1[__global_idx],
2000-
__comp, oneapi::dpl::identity{}) -
1999+
oneapi::dpl::__internal::__pstl_lower_bound_idx(__in_acc2, __start_2, __part_end_2, __in_acc1,
2000+
__global_idx, __comp, oneapi::dpl::identity{},
2001+
oneapi::dpl::identity{}) -
20012002
__start_2;
20022003
__out_acc[__out_shift + __shift] = __in_acc1[__global_idx];
20032004
}
@@ -2015,8 +2016,9 @@ struct __partial_merge_kernel
20152016
const auto __shift =
20162017
/* index inside p3 */ __global_idx - __start_2 +
20172018
/* relative position in p1 */
2018-
oneapi::dpl::__internal::__pstl_upper_bound(__in_acc1, __start_1, __part_end_1, __in_acc2[__global_idx],
2019-
__comp, oneapi::dpl::identity{}) -
2019+
oneapi::dpl::__internal::__pstl_upper_bound_idx(__in_acc1, __start_1, __part_end_1, __in_acc2,
2020+
__global_idx, __comp, oneapi::dpl::identity{},
2021+
oneapi::dpl::identity{}) -
20202022
__start_1;
20212023
__out_acc[__out_shift + __shift] = __in_acc2[__global_idx];
20222024
}

include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_reduce_then_scan.h

Lines changed: 17 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -389,42 +389,32 @@ struct __gen_set_mask
389389

390390
std::size_t __nb = __set_b.size();
391391

392-
// This reference extends the lifetime of a temporary object returned by operator[]
393-
// so that it can be safely used with identity projections
394-
auto&& __val_a = __set_a[__id];
395-
auto&& __val_a_proj = std::invoke(__proj1, std::forward<decltype(__val_a)>(__val_a));
396-
397-
auto __res =
398-
oneapi::dpl::__internal::__pstl_lower_bound(__set_b, std::size_t{0}, __nb, __val_a_proj, __comp, __proj2);
392+
auto __res = oneapi::dpl::__internal::__pstl_lower_bound_idx(__set_b, std::size_t{0}, __nb, __set_a, __id,
393+
__comp, __proj2, __proj1);
399394
constexpr bool __is_difference = std::is_same_v<_SetTag, oneapi::dpl::unseq_backend::_DifferenceTag>;
400395

401396
//initialization is true in case of difference operation; false - intersection.
402397
bool bres = __is_difference;
403398

404-
if (__res == __nb || std::invoke(__comp, __val_a_proj, std::invoke(__proj2, __set_b[__res])))
399+
if (__res == __nb ||
400+
std::invoke(__comp, std::invoke(__proj1, __set_a[__id]), std::invoke(__proj2, __set_b[__res])))
405401
{
406-
// there is no __val_a in __set_b, so __set_b in the difference {__set_a}/{__set_b};
402+
// there is no __set_a[__id] in __set_b, so __set_b in the difference {__set_a}/{__set_b};
407403
}
408404
else
409405
{
410-
// This reference extends the lifetime of a temporary object returned by operator[]
411-
// so that it can be safely used with identity projections
412-
auto&& __val_b = __set_b[__res];
413-
auto&& __val_b_proj = std::invoke(__proj2, std::forward<decltype(__val_b)>(__val_b));
414-
415406
//Difference operation logic: if number of duplication in __set_a on left side from __id > total number of
416407
//duplication in __set_b then a mask is 1
417408

418409
//Intersection operation logic: if number of duplication in __set_a on left side from __id <= total number of
419410
//duplication in __set_b then a mask is 1
420411

421412
const std::size_t __count_a_left =
422-
__id - oneapi::dpl::__internal::__pstl_left_bound(__set_a, std::size_t{0}, __id, __val_a_proj, __comp, __proj1) + 1;
413+
__id - oneapi::dpl::__internal::__pstl_left_bound_idx(__set_a, std::size_t{0}, __id, __set_a, __id, __comp, __proj1, __proj1) + 1;
423414

424-
const std::size_t __count_b =
425-
oneapi::dpl::__internal::__pstl_right_bound(__set_b, __res, __nb, __val_b_proj, __comp, __proj2) -
426-
oneapi::dpl::__internal::__pstl_left_bound(__set_b, std::size_t{0}, __res, __val_b_proj, __comp,
427-
__proj2);
415+
const std::size_t __count_b =
416+
oneapi::dpl::__internal::__pstl_right_bound_idx(__set_b, __res, __nb, __set_b, __res, __comp, __proj2, __proj2) -
417+
oneapi::dpl::__internal::__pstl_left_bound_idx(__set_b, std::size_t{0}, __res, __set_b, __res, __comp, __proj2, __proj2);
428418

429419
if constexpr (__is_difference)
430420
bres = __count_a_left > __count_b; /*difference*/
@@ -713,24 +703,20 @@ struct __gen_set_balanced_path
713703
return std::make_tuple(__merge_path_rng1, __merge_path_rng2, false);
714704
}
715705

716-
// This reference extends the lifetime of a temporary object returned by operator[]
717-
// so that it can be safely used with identity projections
718-
auto&& __ele_val = __rng1[__merge_path_rng1 - 1];
719-
auto&& __ele_val_proj = std::invoke(__proj1, std::forward<decltype(__ele_val)>(__ele_val));
720-
721-
if (std::invoke(__comp, __ele_val_proj, std::invoke(__proj2, __rng2[__merge_path_rng2])))
706+
if (std::invoke(__comp, std::invoke(__proj1, __rng1[__merge_path_rng1 - 1]),
707+
std::invoke(__proj2, __rng2[__merge_path_rng2])))
722708
{
723709
// There is no chance that the balanced path differs from the merge path here, because the previous element of
724710
// rng1 does not match the next element of rng2. We can just return the merge path.
725711
return std::make_tuple(__merge_path_rng1, __merge_path_rng2, false);
726712
}
727713

728714
// find first element of repeating sequence in the first set of the previous element
729-
_Index __rng1_repeat_start = oneapi::dpl::__internal::__biased_lower_bound</*__last_bias=*/true>(
730-
__rng1, __rng1_begin, __merge_path_rng1, __ele_val_proj, __comp, __proj1);
715+
_Index __rng1_repeat_start = oneapi::dpl::__internal::__biased_lower_bound_idx</*__last_bias=*/true>(
716+
__rng1, __rng1_begin, __merge_path_rng1, __rng1, __merge_path_rng1 - 1, __comp, __proj1, __proj1);
731717
// find first element of repeating sequence in the second set of the next element
732-
_Index __rng2_repeat_start = oneapi::dpl::__internal::__biased_lower_bound</*__last_bias=*/true>(
733-
__rng2, __rng2_begin, __merge_path_rng2, __ele_val_proj, __comp, __proj2);
718+
_Index __rng2_repeat_start = oneapi::dpl::__internal::__biased_lower_bound_idx</*__last_bias=*/true>(
719+
__rng2, __rng2_begin, __merge_path_rng2, __rng1, __merge_path_rng1 - 1, __comp, __proj2, __proj1);
734720

735721
_Index __rng1_repeats = __merge_path_rng1 - __rng1_repeat_start;
736722
_Index __rng2_repeats_bck = __merge_path_rng2 - __rng2_repeat_start;
@@ -748,8 +734,8 @@ struct __gen_set_balanced_path
748734
// Calculate the max location to search in the second set for future repeats, limiting to the edge of the range
749735
_Index __fwd_search_bound = std::min(__merge_path_rng2 + __fwd_search_count, __rng2_end);
750736

751-
_Index __balanced_path_intersection_rng2 = oneapi::dpl::__internal::__pstl_upper_bound(
752-
__rng2, __merge_path_rng2, __fwd_search_bound, __ele_val_proj, __comp, __proj2);
737+
_Index __balanced_path_intersection_rng2 = oneapi::dpl::__internal::__pstl_upper_bound_idx(
738+
__rng2, __merge_path_rng2, __fwd_search_bound, __rng1, __merge_path_rng1 - 1, __comp, __proj2, __proj1);
753739

754740
// Calculate the number of matchable "future" repeats in the second set
755741
_Index __matchable_forward_ele_rng2 = __balanced_path_intersection_rng2 - __merge_path_rng2;

0 commit comments

Comments
 (0)