Skip to content

Commit 950b890

Browse files
[oneMKL][BLAS] value_or_pointer wrapper for BLAS USM scalar parameters (#503)
* [oneMKL][BLAS] value_or_pointer wrapper for BLAS UUSM scalar parameters
1 parent 145b62a commit 950b890

49 files changed

Lines changed: 325 additions & 212 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

source/elements/oneMKL/source/domains/blas/axpby.rst

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,9 @@ axpby (USM Version)
127127
namespace oneapi::mkl::blas::column_major {
128128
sycl::event axpby(sycl::queue &queue,
129129
std::int64_t n,
130-
T alpha,
130+
value_or_pointer<T> alpha,
131131
const T *x, std::int64_t incx,
132-
const T beta,
132+
value_or_pointer<T> beta,
133133
T *y, std::int64_t incy,
134134
const std::vector<sycl::event> &dependencies = {})
135135
}
@@ -138,9 +138,9 @@ axpby (USM Version)
138138
namespace oneapi::mkl::blas::row_major {
139139
sycl::event axpby(sycl::queue &queue,
140140
std::int64_t n,
141-
T alpha,
141+
value_or_pointer<T> alpha,
142142
const T *x, std::int64_t incx,
143-
const T beta,
143+
value_or_pointer<T> beta,
144144
T *y, std::int64_t incy,
145145
const std::vector<sycl::event> &dependencies = {})
146146
}
@@ -156,10 +156,10 @@ axpby (USM Version)
156156
Number of elements in vector ``x`` and ``y``.
157157

158158
alpha
159-
Specifies the scalar alpha.
159+
Specifies the scalar alpha. See :ref:`value_or_pointer` for more details.
160160

161161
beta
162-
Specifies the scalar beta.
162+
Specifies the scalar beta. See :ref:`value_or_pointer` for more details.
163163

164164
x
165165
Pointer to the input vector ``x``. The allocated memory must be

source/elements/oneMKL/source/domains/blas/axpy.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ axpy (USM Version)
138138
namespace oneapi::mkl::blas::column_major {
139139
sycl::event axpy(sycl::queue &queue,
140140
std::int64_t n,
141-
T alpha,
141+
value_or_pointer<T> alpha,
142142
const T *x,
143143
std::int64_t incx,
144144
T *y,
@@ -150,7 +150,7 @@ axpy (USM Version)
150150
namespace oneapi::mkl::blas::row_major {
151151
sycl::event axpy(sycl::queue &queue,
152152
std::int64_t n,
153-
T alpha,
153+
value_or_pointer<T> alpha,
154154
const T *x,
155155
std::int64_t incx,
156156
T *y,
@@ -169,7 +169,7 @@ axpy (USM Version)
169169
Number of elements in vector ``x``.
170170

171171
alpha
172-
Specifies the scalar alpha.
172+
Specifies the scalar alpha. See :ref:`value_or_pointer` for more details.
173173

174174
x
175175
Pointer to the input vector ``x``. The array holding the vector

source/elements/oneMKL/source/domains/blas/axpy_batch.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ The total number of vectors in ``x`` and ``y`` are given by the ``batch_size`` p
292292
namespace oneapi::mkl::blas::column_major {
293293
sycl::event axpy_batch(sycl::queue &queue,
294294
std::int64_t n,
295-
T alpha,
295+
value_or_pointer<T> alpha,
296296
const T *x,
297297
std::int64_t incx,
298298
std::int64_t stridex,
@@ -307,7 +307,7 @@ The total number of vectors in ``x`` and ``y`` are given by the ``batch_size`` p
307307
namespace oneapi::mkl::blas::row_major {
308308
sycl::event axpy_batch(sycl::queue &queue,
309309
std::int64_t n,
310-
T alpha,
310+
value_or_pointer<T> alpha,
311311
const T *x,
312312
std::int64_t incx,
313313
std::int64_t stridex,
@@ -329,7 +329,7 @@ The total number of vectors in ``x`` and ``y`` are given by the ``batch_size`` p
329329
Number of elements in ``X`` and ``Y``.
330330

331331
alpha
332-
Specifies the scalar ``alpha``.
332+
Specifies the scalar ``alpha``. See :ref:`value_or_pointer` for more details.
333333

334334
x
335335
Pointer to input vectors ``X`` with size ``stridex`` * ``batch_size``.

source/elements/oneMKL/source/domains/blas/gbmv.rst

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -195,12 +195,12 @@ gbmv (USM Version)
195195
std::int64_t n,
196196
std::int64_t kl,
197197
std::int64_t ku,
198-
T alpha,
198+
value_or_pointer<T> alpha,
199199
const T *a,
200200
std::int64_t lda,
201201
const T *x,
202202
std::int64_t incx,
203-
T beta,
203+
value_or_pointer<T> beta,
204204
T *y,
205205
std::int64_t incy,
206206
const std::vector<sycl::event> &dependencies = {})
@@ -214,12 +214,12 @@ gbmv (USM Version)
214214
std::int64_t n,
215215
std::int64_t kl,
216216
std::int64_t ku,
217-
T alpha,
217+
value_or_pointer<T> alpha,
218218
const T *a,
219219
std::int64_t lda,
220220
const T *x,
221221
std::int64_t incx,
222-
T beta,
222+
value_or_pointer<T> beta,
223223
T *y,
224224
std::int64_t incy,
225225
const std::vector<sycl::event> &dependencies = {})
@@ -253,7 +253,7 @@ gbmv (USM Version)
253253
zero.
254254

255255
alpha
256-
Scaling factor for the matrix-vector product.
256+
Scaling factor for the matrix-vector product. See :ref:`value_or_pointer` for more details.
257257

258258
a
259259
Pointer to input matrix ``A``. The array holding input matrix
@@ -276,7 +276,7 @@ gbmv (USM Version)
276276
Stride of vector ``x``. Must not be zero.
277277

278278
beta
279-
Scaling factor for vector ``y``.
279+
Scaling factor for vector ``y``. See :ref:`value_or_pointer` for more details.
280280

281281
y
282282
Pointer to input/output vector ``y``. The length ``len`` of

source/elements/oneMKL/source/domains/blas/gemm.rst

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -308,12 +308,12 @@ gemm (USM Version)
308308
std::int64_t m,
309309
std::int64_t n,
310310
std::int64_t k,
311-
Ts alpha,
311+
value_or_pointer<Ts> alpha,
312312
const Ta *a,
313313
std::int64_t lda,
314314
const Tb *b,
315315
std::int64_t ldb,
316-
Ts beta,
316+
value_or_poitner<Ts> beta,
317317
Tc *c,
318318
std::int64_t ldc,
319319
const std::vector<sycl::event> &dependencies = {})
@@ -327,12 +327,12 @@ gemm (USM Version)
327327
std::int64_t m,
328328
std::int64_t n,
329329
std::int64_t k,
330-
Ts alpha,
330+
value_or_pointer<Ts> alpha,
331331
const Ta *a,
332332
std::int64_t lda,
333333
const Tb *b,
334334
std::int64_t ldb,
335-
Ts beta,
335+
value_or_pointer<Ts> beta,
336336
Tc *c,
337337
std::int64_t ldc,
338338
const std::vector<sycl::event> &dependencies = {})
@@ -373,7 +373,7 @@ gemm (USM Version)
373373

374374

375375
alpha
376-
Scaling factor for the matrix-matrix product.
376+
Scaling factor for the matrix-matrix product. See :ref:`value_or_pointer` for more details.
377377

378378

379379
a
@@ -453,7 +453,7 @@ gemm (USM Version)
453453
- ``ldb`` must be at least ``k``.
454454

455455
beta
456-
Scaling factor for matrix ``C``.
456+
Scaling factor for matrix ``C``. See :ref:`value_or_pointer` for more details.
457457

458458
c
459459
The pointer to input/output matrix ``C``. It must have a

source/elements/oneMKL/source/domains/blas/gemm_batch.rst

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -593,14 +593,14 @@ in ``a``, ``b`` and ``c`` are given by the ``batch_size`` parameter.
593593
std::int64_t m,
594594
std::int64_t n,
595595
std::int64_t k,
596-
T alpha,
596+
value_or_pointer<T> alpha,
597597
const T *a,
598598
std::int64_t lda,
599599
std::int64_t stridea,
600600
const T *b,
601601
std::int64_t ldb,
602602
std::int64_t strideb,
603-
T beta,
603+
value_or_pointer<T> beta,
604604
T *c,
605605
std::int64_t ldc,
606606
std::int64_t stridec,
@@ -616,14 +616,14 @@ in ``a``, ``b`` and ``c`` are given by the ``batch_size`` parameter.
616616
std::int64_t m,
617617
std::int64_t n,
618618
std::int64_t k,
619-
T alpha,
619+
value_or_pointer<T> alpha,
620620
const T *a,
621621
std::int64_t lda,
622622
std::int64_t stridea,
623623
const T *b,
624624
std::int64_t ldb,
625625
std::int64_t strideb,
626-
T beta,
626+
value_or_pointer<T> beta,
627627
T *c,
628628
std::int64_t ldc,
629629
std::int64_t stridec,
@@ -657,7 +657,7 @@ in ``a``, ``b`` and ``c`` are given by the ``batch_size`` parameter.
657657
least zero.
658658

659659
alpha
660-
Scaling factor for the matrix-matrix products.
660+
Scaling factor for the matrix-matrix products. See :ref:`value_or_pointer` for more details.
661661

662662
a
663663
Pointer to input matrices ``A`` with size ``stridea`` * ``batch_size``.
@@ -704,7 +704,7 @@ in ``a``, ``b`` and ``c`` are given by the ``batch_size`` parameter.
704704
Stride between different ``B`` matrices.
705705

706706
beta
707-
Scaling factor for the matrices ``C``.
707+
Scaling factor for the matrices ``C``. See :ref:`value_or_pointer` for more details.
708708

709709
c
710710
Pointer to input/output matrices ``C`` with size ``stridec`` * ``batch_size``.

source/elements/oneMKL/source/domains/blas/gemm_bias.rst

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -311,14 +311,14 @@ gemm_bias (USM Version)
311311
std::int64_t m,
312312
std::int64_t n,
313313
std::int64_t k,
314-
float alpha,
314+
value_or_pointer<float> alpha,
315315
const Ta *a,
316316
std::int64_t lda,
317317
Ta ao,
318318
const Tb *b,
319319
std::int64_t ldb,
320320
Tb bo,
321-
float beta,
321+
value_or_pointer<float> beta,
322322
std::int32_t *c,
323323
std::int64_t ldc,
324324
const std::int32_t *co,
@@ -334,14 +334,14 @@ gemm_bias (USM Version)
334334
std::int64_t m,
335335
std::int64_t n,
336336
std::int64_t k,
337-
float alpha,
337+
value_or_pointer<float> alpha,
338338
const Ta *a,
339339
std::int64_t lda,
340340
Ta ao,
341341
const Tb *b,
342342
std::int64_t ldb,
343343
Tb bo,
344-
float beta,
344+
value_or_pointer<float> beta,
345345
std::int32_t *c,
346346
std::int64_t ldc,
347347
const std::int32_t *co,
@@ -385,7 +385,7 @@ gemm_bias (USM Version)
385385
at least zero.
386386

387387
alpha
388-
Scaling factor for the matrix-matrix product.
388+
Scaling factor for the matrix-matrix product. See :ref:`value_or_pointer` for more details.
389389

390390
a
391391
Pointer to input matrix ``A``.
@@ -470,7 +470,7 @@ gemm_bias (USM Version)
470470
Specifies the scalar offset value for matrix ``B``.
471471

472472
beta
473-
Scaling factor for matrix ``C``.
473+
Scaling factor for matrix ``C``. See :ref:`value_or_pointer` for more details.
474474

475475
c
476476
Pointer to input/output matrix ``C``. It must have a

source/elements/oneMKL/source/domains/blas/gemmt.rst

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -259,12 +259,12 @@ gemmt (USM Version)
259259
onemkl::transpose transb,
260260
std::int64_t n,
261261
std::int64_t k,
262-
T alpha,
262+
value_or_pointer<T> alpha,
263263
const T *a,
264264
std::int64_t lda,
265265
const T *b,
266266
std::int64_t ldb,
267-
T beta,
267+
value_or_pointer<T> beta,
268268
T *c,
269269
std::int64_t ldc,
270270
const std::vector<sycl::event> &dependencies = {})
@@ -278,12 +278,12 @@ gemmt (USM Version)
278278
onemkl::transpose transb,
279279
std::int64_t n,
280280
std::int64_t k,
281-
T alpha,
281+
value_or_pointer<T> alpha,
282282
const T *a,
283283
std::int64_t lda,
284284
const T *b,
285285
std::int64_t ldb,
286-
T beta,
286+
value_or_pointer<T> beta,
287287
T *c,
288288
std::int64_t ldc,
289289
const std::vector<sycl::event> &dependencies = {})
@@ -323,7 +323,7 @@ gemmt (USM Version)
323323
at least zero.
324324

325325
alpha
326-
Scaling factor for the matrix-matrix product.
326+
Scaling factor for the matrix-matrix product. See :ref:`value_or_pointer` for more details.
327327

328328
a
329329
Pointer to input matrix ``A``.
@@ -402,7 +402,7 @@ gemmt (USM Version)
402402
- ``ldb`` must be at least ``k``.
403403

404404
beta
405-
Scaling factor for matrix ``C``.
405+
Scaling factor for matrix ``C``. See :ref:`value_or_pointer` for more details.
406406

407407
c
408408
Pointer to input/output matrix ``C``. Must have size at least

source/elements/oneMKL/source/domains/blas/gemv.rst

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -177,12 +177,12 @@ gemv (USM Version)
177177
onemkl::transpose trans,
178178
std::int64_t m,
179179
std::int64_t n,
180-
T alpha,
180+
value_or_pointer<T> alpha,
181181
const T *a,
182182
std::int64_t lda,
183183
const T *x,
184184
std::int64_t incx,
185-
T beta,
185+
value_or_pointer<T> beta,
186186
T *y,
187187
std::int64_t incy,
188188
const std::vector<sycl::event> &dependencies = {})
@@ -194,12 +194,12 @@ gemv (USM Version)
194194
onemkl::transpose trans,
195195
std::int64_t m,
196196
std::int64_t n,
197-
T alpha,
197+
value_or_pointer<T> alpha,
198198
const T *a,
199199
std::int64_t lda,
200200
const T *x,
201201
std::int64_t incx,
202-
T beta,
202+
value_or_pointer<T> beta,
203203
T *y,
204204
std::int64_t incy,
205205
const std::vector<sycl::event> &dependencies = {})
@@ -227,7 +227,7 @@ gemv (USM Version)
227227
of ``n`` must be at least zero.
228228

229229
alpha
230-
Scaling factor for the matrix-vector product.
230+
Scaling factor for the matrix-vector product. See :ref:`value_or_pointer` for more details.
231231

232232
a
233233
Pointer to the input matrix ``A``. Must have a size of at
@@ -251,7 +251,7 @@ gemv (USM Version)
251251
The stride of vector ``x``. Must not be zero.
252252

253253
beta
254-
The scaling factor for vector ``y``.
254+
The scaling factor for vector ``y``. See :ref:`value_or_pointer` for more details.
255255

256256
y
257257
Pointer to input/output vector ``y``. The length ``len`` of

0 commit comments

Comments
 (0)