Skip to content

Commit f92fb4c

Browse files
author
Nikita Kulikov
committed
Minor changes and simplifications in testing
1 parent 7426eb1 commit f92fb4c

4 files changed

Lines changed: 39 additions & 8 deletions

File tree

onedal/interop/tests/test_csr_table_interop.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import pytest
1919
from scipy.sparse import csr_matrix, find, isspmatrix_csr
2020

21+
import onedal
2122
from onedal.interop.array import from_array
2223
from onedal.interop.csr_table import from_csr_table, is_csr_entity, to_csr_table
2324

@@ -78,6 +79,10 @@ def generate_csr_data(gen, shape, per_row, dtypes):
7879
return (data, indices, offsets)
7980

8081

82+
sp_indexing = onedal._backend.data_management.sparse_indexing
83+
indexing_offset_map = {sp_indexing.zero_based: 0, sp_indexing.one_based: 1}
84+
85+
8186
@pytest.mark.parametrize("shape", table_dimensions)
8287
@pytest.mark.parametrize("dtype", get_dtype_list())
8388
@pytest.mark.parametrize("itype", [np.int32, np.uint32, np.int64])
@@ -100,9 +105,16 @@ def test_host_csr_table_functionality(shape, dtype, itype):
100105
assert onedal_table.get_row_count() == row_count
101106
assert onedal_table.get_column_count() == col_count
102107

103-
onedal_indices = from_array(onedal_table.get_column_indices())
108+
curr_indexing = onedal_table.get_indexing()
109+
offset = indexing_offset_map[curr_indexing]
110+
111+
def get_indices(array):
112+
raw = from_array(array)
113+
return raw - offset
114+
115+
onedal_indices = get_indices(onedal_table.get_column_indices())
104116
np.testing.assert_equal(scipy_indices, onedal_indices)
105-
onedal_offsets = from_array(onedal_table.get_row_offsets())
117+
onedal_offsets = get_indices(onedal_table.get_row_offsets())
106118
np.testing.assert_equal(scipy_offsets, onedal_offsets)
107119
onedal_data = from_array(onedal_table.get_data())
108120
np.testing.assert_equal(scipy_data, onedal_data)

onedal/interop/tests/test_homogen_table_interop.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import numpy as np
1818
import pytest
1919

20+
import onedal
2021
from onedal.interop.homogen_table import (
2122
from_homogen_table,
2223
is_homogen_entity,
@@ -47,6 +48,24 @@
4748
(123, 999),
4849
]
4950

51+
data_layout = onedal._backend.data_management.data_layout
52+
53+
54+
def check_table_dimensions(table, shape, transpose):
55+
assert table.get_row_count() == shape[0]
56+
assert table.get_column_count() == shape[1]
57+
58+
is_simple = shape[0] == 1
59+
is_simple = is_simple or shape[1] == 1
60+
61+
curr_layout = table.get_data_layout()
62+
if transpose and not is_simple:
63+
column_major = data_layout.column_major
64+
assert curr_layout == column_major
65+
else:
66+
row_major = data_layout.row_major
67+
assert curr_layout == row_major
68+
5069

5170
@pytest.mark.skipif(not dpctl_available, reason="requires dpctl>=0.14")
5271
@pytest.mark.parametrize("queue", get_queues("cpu,gpu"))
@@ -68,8 +87,8 @@ def test_device_array_functionality(queue, backend, transpose, shape, dtype):
6887
assert is_homogen_entity(onedal_table)
6988
del dpctl_tensor, wrapped_tensor
7089

71-
assert onedal_table.get_row_count() == dpctl_sua["shape"][0]
72-
assert onedal_table.get_column_count() == dpctl_sua["shape"][1]
90+
curr_shape = numpy_array.shape
91+
check_table_dimensions(onedal_table, curr_shape, transpose)
7392

7493
return_table = from_homogen_table(onedal_table)
7594

@@ -100,8 +119,8 @@ def test_host_homogen_table_functionality(backend, transpose, shape, dtype):
100119
onedal_table = to_homogen_table(wrapped_tensor)
101120
assert is_homogen_entity(onedal_table)
102121

103-
assert onedal_table.get_row_count() == numpy_iface["shape"][0]
104-
assert onedal_table.get_column_count() == numpy_iface["shape"][1]
122+
curr_shape = numpy_array.shape
123+
check_table_dimensions(onedal_table, curr_shape, transpose)
105124

106125
return_table = from_homogen_table(onedal_table)
107126
return_iface = return_table.__array_interface__

onedal/svm/tests/test_svc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def _restore_from_saved(md, saved_dict):
4848
setattr(md, check_f, saved_dict[check_f])
4949

5050

51-
#def test_estimator():
51+
# def test_estimator():
5252
# def dummy(*args, **kwargs):
5353
# pass
5454
#

onedal/svm/tests/test_svr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def _restore_from_saved(md, saved_dict):
4949
setattr(md, check_f, saved_dict[check_f])
5050

5151

52-
#def test_estimator():
52+
# def test_estimator():
5353
# def dummy(*args, **kwargs):
5454
# pass
5555
#

0 commit comments

Comments
 (0)