Skip to content

Commit 7516458

Browse files
committed
Improved trimming strategy
Signed-off-by: Dan Hoeflinger <dan.hoeflinger@intel.com>
1 parent c0c3ac0 commit 7516458

File tree

1 file changed

+35
-28
lines changed

1 file changed

+35
-28
lines changed

include/oneapi/dpl/pstl/algorithm_impl.h

Lines changed: 35 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3676,45 +3676,52 @@ __pattern_set_intersection(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& _
36763676
if (__n1 == 0 || __n2 == 0)
36773677
return __result;
36783678

3679-
// testing whether the sequences are intersected
3680-
_RandomAccessIterator1 __left_bound_seq_1 = std::lower_bound(__first1, __last1, *__first2, __comp);
3681-
//{1} < {2}: seq 2 is wholly greater than seq 1, so, the intersection is empty
3682-
if (__left_bound_seq_1 == __last1)
3683-
return __result;
3684-
3685-
// testing whether the sequences are intersected
3686-
_RandomAccessIterator2 __left_bound_seq_2 = std::lower_bound(__first2, __last2, *__first1, __comp);
3687-
//{2} < {1}: seq 1 is wholly greater than seq 2, so, the intersection is empty
3688-
if (__left_bound_seq_2 == __last2)
3689-
return __result;
3690-
3691-
// Two trimming strategies are available (mutually exclusive):
3692-
// Strategy A: Trim range1 (elements < *__first2), keep range2 full
3693-
// Strategy B: Trim range2 (elements < *__first1), keep range1 full
3694-
// Choose the strategy that trims more elements (eliminates more non-overlapping work).
3695-
3696-
const _DifferenceType1 __trimmed_from_range1 = __left_bound_seq_1 - __first1;
3697-
const _DifferenceType2 __trimmed_from_range2 = __left_bound_seq_2 - __first2;
3698-
3679+
// Trim non-overlapping portions from both ends.
36993680
_RandomAccessIterator1 __begin1 = __first1;
3681+
_RandomAccessIterator1 __end1 = __last1;
37003682
_RandomAccessIterator2 __begin2 = __first2;
3683+
_RandomAccessIterator2 __end2 = __last2;
37013684

3702-
if (__trimmed_from_range1 >= __trimmed_from_range2)
3685+
// Trim the beginning of whichever range starts earlier
3686+
if (__comp(*__first2, *__first1))
37033687
{
3704-
__begin1 = __left_bound_seq_1;
3705-
__n1 = __last1 - __begin1;
3688+
// range 2 starts before range 1; trim beginning of range 2 to *__first1
3689+
__begin2 = std::lower_bound(__first2, __last2, *__first1, __comp);
3690+
if (__begin2 == __last2)
3691+
return __result;
37063692
}
3707-
else
3693+
else if (__comp(*__first1, *__first2))
3694+
{
3695+
// range 1 starts before range 2; trim beginning of range 1 to *__first2
3696+
__begin1 = std::lower_bound(__first1, __last1, *__first2, __comp);
3697+
if (__begin1 == __last1)
3698+
return __result;
3699+
}
3700+
3701+
// Trim the end of whichever range ends later
3702+
if (__comp(*(__end1 - 1), *(__end2 - 1)))
37083703
{
3709-
__begin2 = __left_bound_seq_2;
3710-
__n2 = __last2 - __begin2;
3704+
// range 1 ends before range 2; trim end of range 2 to *(__end1 - 1)
3705+
__end2 = std::upper_bound(__begin2, __end2, *(__end1 - 1), __comp);
37113706
}
3707+
else if (__comp(*(__end2 - 1), *(__end1 - 1)))
3708+
{
3709+
// range 2 ends before range 1; trim end of range 1 to *(__end2 - 1)
3710+
__end1 = std::upper_bound(__begin1, __end1, *(__end2 - 1), __comp);
3711+
}
3712+
3713+
// End trimming may have eliminated all overlap
3714+
if (__begin1 == __end1 || __begin2 == __end2)
3715+
return __result;
3716+
3717+
__n1 = __end1 - __begin1;
3718+
__n2 = __end2 - __begin2;
37123719

37133720
const _DifferenceType __total_work = __n1 + __n2;
37143721
if (__total_work > __set_algo_cut_off)
37153722
{
37163723
return __internal::__parallel_set_op(
3717-
__tag, std::forward<_ExecutionPolicy>(__exec), __begin1, __last1, __begin2, __last2, __result,
3724+
__tag, std::forward<_ExecutionPolicy>(__exec), __begin1, __end1, __begin2, __end2, __result,
37183725
[](_DifferenceType __n, _DifferenceType __m) { return std::min(__n, __m); },
37193726
[](_RandomAccessIterator1 __lmda_first1, _RandomAccessIterator1 __lmda_last1,
37203727
_RandomAccessIterator2 __lmda_first2, _RandomAccessIterator2 __lmda_last2, _T* __result, _Compare __comp,
@@ -3728,7 +3735,7 @@ __pattern_set_intersection(__parallel_tag<_IsVector> __tag, _ExecutionPolicy&& _
37283735
}
37293736

37303737
// Work too small for parallelization - use serial algorithm
3731-
return std::set_intersection(__left_bound_seq_1, __last1, __left_bound_seq_2, __last2, __result, __comp);
3738+
return std::set_intersection(__begin1, __end1, __begin2, __end2, __result, __comp);
37323739
}
37333740

37343741
//------------------------------------------------------------------------

0 commit comments

Comments
 (0)