Skip to content

Commit bfd7e34

Browse files
author
b-shi
authored
Update precision support for gemm, gemm batch, rot (#432)
1 parent 8f64ff8 commit bfd7e34

File tree

3 files changed

+122
-39
lines changed

3 files changed

+122
-39
lines changed

source/elements/oneMKL/source/domains/blas/gemm.rst

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,22 +41,38 @@ op(``X``) = ``X``\ :sup:`H`,
4141
.. list-table::
4242
:header-rows: 1
4343

44-
* - Ts
45-
- Ta
46-
- Tb
47-
- Tc
48-
* - ``float``
49-
- ``half``
44+
* - | Ta
45+
| (A matrix)
46+
- | Tb
47+
| (B matrix)
48+
- | Tc
49+
| (C matrix)
50+
- | Ts
51+
| (alpha/beta)
52+
* - ``std::int8_t``
53+
- ``std::int8_t``
54+
- ``std::int32_t``
55+
- ``float``
56+
* - ``std::int8_t``
57+
- ``std::int8_t``
58+
- ``float``
59+
- ``float``
60+
* - ``half``
5061
- ``half``
5162
- ``float``
63+
- ``float``
5264
* - ``half``
5365
- ``half``
5466
- ``half``
5567
- ``half``
56-
* - ``float``
57-
- ``bfloat16``
58-
- ``bfloat16``
59-
- ``float``
68+
* - ``bfloat16``
69+
- ``bfloat16``
70+
- ``float``
71+
- ``float``
72+
* - ``bfloat16``
73+
- ``bfloat16``
74+
- ``bfloat16``
75+
- ``float``
6076
* - ``float``
6177
- ``float``
6278
- ``float``

source/elements/oneMKL/source/domains/blas/gemm_batch.rst

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,54 @@ operation perform a matrix-matrix product with general matrices.
2222
.. list-table::
2323
:header-rows: 1
2424

25-
* - T
26-
* - ``half``
25+
* - | Ta
26+
| (A matrix)
27+
- | Tb
28+
| (B matrix)
29+
- | Tc
30+
| (C matrix)
31+
- | Ts
32+
| (alpha/beta)
33+
* - ``std::int8_t``
34+
- ``std::int8_t``
35+
- ``std::int32_t``
36+
- ``float``
37+
* - ``std::int8_t``
38+
- ``std::int8_t``
39+
- ``float``
40+
- ``float``
41+
* - ``half``
42+
- ``half``
43+
- ``float``
44+
- ``float``
45+
* - ``half``
46+
- ``half``
47+
- ``half``
48+
- ``half``
49+
* - ``bfloat16``
50+
- ``bfloat16``
51+
- ``float``
52+
- ``float``
53+
* - ``bfloat16``
54+
- ``bfloat16``
55+
- ``bfloat16``
56+
- ``float``
2757
* - ``float``
58+
- ``float``
59+
- ``float``
60+
- ``float``
2861
* - ``double``
62+
- ``double``
63+
- ``double``
64+
- ``double``
2965
* - ``std::complex<float>``
66+
- ``std::complex<float>``
67+
- ``std::complex<float>``
68+
- ``std::complex<float>``
3069
* - ``std::complex<double>``
70+
- ``std::complex<double>``
71+
- ``std::complex<double>``
72+
- ``std::complex<double>``
3173

3274
.. _onemkl_blas_gemm_batch_buffer:
3375

source/elements/oneMKL/source/domains/blas/rot.rst

Lines changed: 52 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@ Performs rotation of points in the plane.
1515

1616
Given two vectors ``x`` and ``y`` of ``n`` elements, the ``rot`` routines
1717
compute four scalar-vector products and update the input vectors with
18-
the sum of two of these scalar-vector products as follow:
18+
the sum of two of these scalar-vector products as follows:
1919

2020
.. math::
21-
21+
2222
\left[\begin{array}{c}
2323
x\\y
2424
\end{array}\right]
@@ -28,25 +28,50 @@ the sum of two of these scalar-vector products as follow:
2828
-s*x + c*y
2929
\end{array}\right]
3030
31+
If ``s`` is a complex type, the operation is defined as:
32+
33+
.. math::
34+
\left[\begin{array}{c}
35+
x\\y
36+
\end{array}\right]
37+
\leftarrow
38+
\left[\begin{array}{c}
39+
\phantom{-}c*x + s*y\\
40+
-conj(s)*x + c*y
41+
\end{array}\right]
42+
3143
``rot`` supports the following precisions.
3244

33-
.. list-table::
34-
:header-rows: 1
35-
36-
* - T
37-
- T_scalar
38-
* - ``half``
39-
- ``half``
40-
* - ``bfloat16``
41-
- ``bfloat16``
42-
* - ``float``
43-
- ``float``
44-
* - ``double``
45-
- ``double``
46-
* - ``std::complex<float>``
47-
- ``float``
48-
* - ``std::complex<double>``
49-
- ``double``
45+
.. list-table::
46+
:header-rows: 1
47+
48+
* - T
49+
- T_scalarC
50+
- T_scalarS
51+
* - ``sycl::half``
52+
- ``sycl::half``
53+
- ``sycl::half``
54+
* - ``oneapi::mkl::bfloat16``
55+
- ``oneapi::mkl::bfloat16``
56+
- ``oneapi::mkl::bfloat16``
57+
* - ``float``
58+
- ``float``
59+
- ``float``
60+
* - ``double``
61+
- ``double``
62+
- ``double``
63+
* - ``std::complex<float>``
64+
- ``float``
65+
- ``std::complex<float>``
66+
* - ``std::complex<double>``
67+
- ``double``
68+
- ``std::complex<double>``
69+
* - ``std::complex<float>``
70+
- ``float``
71+
- ``float``
72+
* - ``std::complex<double>``
73+
- ``double``
74+
- ``double``
5075

5176
.. _onemkl_blas_rot_buffer:
5277

@@ -64,8 +89,8 @@ rot (Buffer Version)
6489
std::int64_t incx,
6590
sycl::buffer<T,1> &y,
6691
std::int64_t incy,
67-
T_scalar c,
68-
T_scalar s)
92+
T_scalarC c,
93+
T_scalarS s)
6994
}
7095
.. code-block:: cpp
7196
@@ -76,8 +101,8 @@ rot (Buffer Version)
76101
std::int64_t incx,
77102
sycl::buffer<T,1> &y,
78103
std::int64_t incy,
79-
T_scalar c,
80-
T_scalar s)
104+
T_scalarC c,
105+
T_scalarS s)
81106
}
82107
83108
.. container:: section
@@ -159,8 +184,8 @@ rot (USM Version)
159184
std::int64_t incx,
160185
T *y,
161186
std::int64_t incy,
162-
T_scalar c,
163-
T_scalar s,
187+
T_scalarC c,
188+
T_scalarS s,
164189
const std::vector<sycl::event> &dependencies = {})
165190
}
166191
.. code-block:: cpp
@@ -172,8 +197,8 @@ rot (USM Version)
172197
std::int64_t incx,
173198
T *y,
174199
std::int64_t incy,
175-
T_scalar c,
176-
T_scalar s,
200+
T_scalarC c,
201+
T_scalarS s,
177202
const std::vector<sycl::event> &dependencies = {})
178203
}
179204

0 commit comments

Comments
 (0)