diff --git a/src/array_index.rs b/src/array_index.rs index d83ac73..3f19cb6 100644 --- a/src/array_index.rs +++ b/src/array_index.rs @@ -71,9 +71,14 @@ impl ArrayIndex { pub fn get_ranges_and_squeeze_dims( &self, shape: &Vec, - ) -> PyResult<(Vec>, Vec)> { - // Input validation - if self.0.len() > shape.len() { + ) -> PyResult<(Vec>, Vec, Vec)> { + // Count how many actual dimensions are consumed by non-NewAxis indices + let consumed_dims: usize = self + .0 + .iter() + .filter(|&x| !matches!(x, IndexType::NewAxis | IndexType::Ellipsis)) + .count(); + if consumed_dims > shape.len() { return Err(PyErr::new::( "Too many indices for array", )); @@ -81,6 +86,7 @@ impl ArrayIndex { let mut ranges = Vec::new(); let mut squeeze_dims = Vec::new(); + let mut newaxis_dims = Vec::new(); let mut current_dim = 0; let mut shape_idx = 0; @@ -88,7 +94,7 @@ impl ArrayIndex { let explicit_dims: usize = self .0 .iter() - .filter(|&x| !matches!(x, IndexType::Ellipsis)) + .filter(|&x| !matches!(x, IndexType::Ellipsis | IndexType::NewAxis)) .count(); let ellipsis_dims = shape.len().saturating_sub(explicit_dims); @@ -100,7 +106,6 @@ impl ArrayIndex { "Only one ellipsis allowed in index", )); } - // Add full ranges for all dimensions represented by the ellipsis for _ in 0..ellipsis_dims { ranges.push(Range { start: 0, @@ -117,7 +122,6 @@ impl ArrayIndex { start: normalized_idx, end: normalized_idx + 1, }); - // Mark this dimension for squeezing squeeze_dims.push(current_dim); shape_idx += 1; @@ -180,11 +184,7 @@ impl ArrayIndex { current_dim += 1; } IndexType::NewAxis => { - ranges.push(Range { - start: 0, - end: shape[shape_idx], - }); - + newaxis_dims.push(current_dim); current_dim += 1; } } @@ -198,11 +198,23 @@ impl ArrayIndex { }); shape_idx += 1; } - // We must sort squeeze_dims in descending order so removing one doesn't shift - // the indices of subsequent ones we want to remove. squeeze_dims.sort_by(|a, b| b.cmp(a)); - Ok((ranges, squeeze_dims)) + // Adjust squeeze_dims from output space to read-array space: + // each NewAxis before a squeeze dim shifts it left by 1. + for pos in squeeze_dims.iter_mut() { + let shift = newaxis_dims.iter().filter(|&&n| n < *pos).count(); + *pos -= shift; + } + + // Adjust newaxis_dims from output space to post-squeeze space: + // each squeezed dim before a NewAxis shifts it left by 1. + for pos in newaxis_dims.iter_mut() { + let shift = squeeze_dims.iter().filter(|&&s| s < *pos).count(); + *pos -= shift; + } + + Ok((ranges, squeeze_dims, newaxis_dims)) } fn normalize_index(idx: i64, dim_size: u64) -> PyResult { @@ -380,6 +392,123 @@ mod tests { }); } + #[test] + fn test_newaxis() { + Python::initialize(); + + Python::attach(|py| { + // arr[np.newaxis] on shape (5,) + let shape = vec![5]; + let none = py.None(); + let none_value = none.bind(py); + let tuple = pyo3::types::PyTuple::new(py, [none_value]).unwrap(); + let index = ArrayIndex::extract(tuple.as_any().as_borrowed()).unwrap(); + let (ranges, squeeze, newaxis) = index.get_ranges_and_squeeze_dims(&shape).unwrap(); + assert_eq!(ranges.len(), 1); + assert_eq!(ranges[0], Range { start: 0, end: 5 }); + assert!(squeeze.is_empty()); + assert_eq!(newaxis, vec![0]); + + // arr[np.newaxis, :3] on shape (5,) + let slice = PySlice::new(py, 0, 3, 1); + let tuple = pyo3::types::PyTuple::new( + py, + [none_value, &slice.into_any()], + ) + .unwrap(); + let index = ArrayIndex::extract(tuple.as_any().as_borrowed()).unwrap(); + let (ranges, squeeze, newaxis) = index.get_ranges_and_squeeze_dims(&shape).unwrap(); + assert_eq!(ranges.len(), 1); + assert_eq!(ranges[0], Range { start: 0, end: 3 }); + assert!(squeeze.is_empty()); + assert_eq!(newaxis, vec![0]); + + // arr[1, np.newaxis] on shape (5,) + let int_value = 1i64.into_pyobject(py).unwrap(); + let tuple = pyo3::types::PyTuple::new( + py, + [&int_value.clone().into_any(), none_value], + ) + .unwrap(); + let index = ArrayIndex::extract(tuple.as_any().as_borrowed()).unwrap(); + let (ranges, squeeze, newaxis) = index.get_ranges_and_squeeze_dims(&shape).unwrap(); + assert_eq!(ranges.len(), 1); + assert_eq!(ranges[0], Range { start: 1, end: 2 }); + assert_eq!(squeeze, vec![0]); + assert_eq!(newaxis, vec![0]); + + // arr[np.newaxis, 1] on shape (5,) — newaxis then int + let tuple = pyo3::types::PyTuple::new( + py, + [none_value, &int_value.into_any()], + ) + .unwrap(); + let index = ArrayIndex::extract(tuple.as_any().as_borrowed()).unwrap(); + let (ranges, squeeze, newaxis) = index.get_ranges_and_squeeze_dims(&shape).unwrap(); + assert_eq!(ranges.len(), 1); + assert_eq!(ranges[0], Range { start: 1, end: 2 }); + assert_eq!(squeeze, vec![0]); + assert_eq!(newaxis, vec![0]); + + // arr[np.newaxis, np.newaxis] on shape (5,) — two newaxis + let tuple = pyo3::types::PyTuple::new( + py, + [none_value, none_value], + ) + .unwrap(); + let index = ArrayIndex::extract(tuple.as_any().as_borrowed()).unwrap(); + let (ranges, squeeze, newaxis) = index.get_ranges_and_squeeze_dims(&shape).unwrap(); + assert_eq!(ranges.len(), 1); + assert_eq!(ranges[0], Range { start: 0, end: 5 }); + assert!(squeeze.is_empty()); + assert_eq!(newaxis, vec![0, 1]); + }); + } + + #[test] + fn test_ellipsis_with_newaxis() { + Python::initialize(); + + Python::attach(|py| { + let shape = vec![2, 3, 4, 5]; + let ellipsis = pyo3::types::PyEllipsis::get(py).into_any(); + let none = py.None(); + let none_value = none.bind(py); + + // arr[..., np.newaxis] on shape (2,3,4,5) + let tuple = pyo3::types::PyTuple::new( + py, + [&ellipsis, none_value], + ) + .unwrap(); + let index = ArrayIndex::extract(tuple.as_any().as_borrowed()).unwrap(); + let (ranges, squeeze, newaxis) = index.get_ranges_and_squeeze_dims(&shape).unwrap(); + assert_eq!(ranges.len(), 4); + assert_eq!(ranges[0], Range { start: 0, end: 2 }); + assert_eq!(ranges[1], Range { start: 0, end: 3 }); + assert_eq!(ranges[2], Range { start: 0, end: 4 }); + assert_eq!(ranges[3], Range { start: 0, end: 5 }); + assert!(squeeze.is_empty()); + assert_eq!(newaxis, vec![4]); + + // arr[np.newaxis, ...] on shape (2,3,4,5) + let tuple = pyo3::types::PyTuple::new( + py, + [none_value, &ellipsis], + ) + .unwrap(); + let index = ArrayIndex::extract(tuple.as_any().as_borrowed()).unwrap(); + let (ranges, squeeze, newaxis) = index.get_ranges_and_squeeze_dims(&shape).unwrap(); + assert_eq!(ranges.len(), 4); + assert_eq!(ranges[0], Range { start: 0, end: 2 }); + assert_eq!(ranges[1], Range { start: 0, end: 3 }); + assert_eq!(ranges[2], Range { start: 0, end: 4 }); + assert_eq!(ranges[3], Range { start: 0, end: 5 }); + assert!(squeeze.is_empty()); + assert_eq!(newaxis, vec![0]); + }); + } + #[test] #[should_panic] fn test_invalid_input() { diff --git a/src/reader.rs b/src/reader.rs index eeb7eb8..e226640 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -208,8 +208,8 @@ impl OmFileReader { if !bound_object.hasattr("cat_file")? || !bound_object.hasattr("size")? { return Err(PyErr::new::( - "Input must be a valid fsspec file object with read, seek methods and fs attribute", - )); + "Input must be a valid fsspec file object with `cat_file` and `size` methods.", + )); } let backend = ReaderBackendImpl::FsSpec(FsSpecBackend::new(fs_obj, path)?); @@ -500,7 +500,7 @@ impl OmFileReader { let array_reader = reader .expect_array_with_io_sizes(65536, 512) .map_err(|_| Self::only_arrays_error())?; - let (read_ranges, squeeze_dims) = + let (read_ranges, squeeze_dims, newaxis_dims) = ranges.get_ranges_and_squeeze_dims(&self.shape)?; let dtype = array_reader.data_type(); @@ -522,6 +522,7 @@ impl OmFileReader { &array_reader, &read_ranges, &squeeze_dims, + &newaxis_dims, )?; Ok(OmFileTypedArray::Int8(array)) } @@ -530,6 +531,7 @@ impl OmFileReader { &array_reader, &read_ranges, &squeeze_dims, + &newaxis_dims, )?; Ok(OmFileTypedArray::Uint8(array)) } @@ -538,6 +540,7 @@ impl OmFileReader { &array_reader, &read_ranges, &squeeze_dims, + &newaxis_dims, )?; Ok(OmFileTypedArray::Int16(array)) } @@ -546,6 +549,7 @@ impl OmFileReader { &array_reader, &read_ranges, &squeeze_dims, + &newaxis_dims, )?; Ok(OmFileTypedArray::Uint16(array)) } @@ -554,6 +558,7 @@ impl OmFileReader { &array_reader, &read_ranges, &squeeze_dims, + &newaxis_dims, )?; Ok(OmFileTypedArray::Int32(array)) } @@ -562,6 +567,7 @@ impl OmFileReader { &array_reader, &read_ranges, &squeeze_dims, + &newaxis_dims, )?; Ok(OmFileTypedArray::Uint32(array)) } @@ -570,6 +576,7 @@ impl OmFileReader { &array_reader, &read_ranges, &squeeze_dims, + &newaxis_dims, )?; Ok(OmFileTypedArray::Int64(array)) } @@ -578,6 +585,7 @@ impl OmFileReader { &array_reader, &read_ranges, &squeeze_dims, + &newaxis_dims, )?; Ok(OmFileTypedArray::Uint64(array)) } @@ -586,6 +594,7 @@ impl OmFileReader { &array_reader, &read_ranges, &squeeze_dims, + &newaxis_dims, )?; Ok(OmFileTypedArray::Float(array)) } @@ -594,6 +603,7 @@ impl OmFileReader { &array_reader, &read_ranges, &squeeze_dims, + &newaxis_dims, )?; Ok(OmFileTypedArray::Double(array)) } @@ -646,15 +656,14 @@ fn read_and_process_array( reader: &OmFileArrayRs, read_ranges: &[Range], squeeze_dims: &[usize], + newaxis_dims: &[usize], ) -> PyResult> { let array = reader .read::(read_ranges) .map_err(convert_omfilesrs_error)?; // Filter out dimensions of size 1 that correspond to integer indices - // This assumes the `array` returned by `read` has the full dimensionality - // matching `read_ranges` (which it does in omfiles-rs). - let new_shape: Vec = array + let mut new_shape: Vec = array .shape() .iter() .enumerate() @@ -667,6 +676,13 @@ fn read_and_process_array( }) .collect(); + // Insert size-1 dimensions for NewAxis, sorted ascending + let mut newaxis_sorted = newaxis_dims.to_vec(); + newaxis_sorted.sort(); + for &pos in newaxis_sorted.iter() { + new_shape.insert(pos, 1); + } + Ok(array .into_shape_with_order(new_shape) .map_err(|e| PyValueError::new_err(e.to_string()))?) diff --git a/src/reader_async.rs b/src/reader_async.rs index 507c943..bc25990 100644 --- a/src/reader_async.rs +++ b/src/reader_async.rs @@ -459,7 +459,7 @@ impl OmFileReaderAsync { /// TypeError: If the data type is not supported. async fn read_array<'py>(&self, ranges: ArrayIndex) -> PyResult { // Convert the Python ranges to Rust ranges - let (read_ranges, squeeze_dims) = ranges.get_ranges_and_squeeze_dims(&self.shape)?; + let (read_ranges, squeeze_dims, newaxis_dims) = ranges.get_ranges_and_squeeze_dims(&self.shape)?; let guard = self .reader @@ -479,52 +479,52 @@ impl OmFileReaderAsync { let result = match data_type { OmDataType::Int8Array => { let array = - read_and_process_array::(reader, &read_ranges, &squeeze_dims).await?; + read_and_process_array::(reader, &read_ranges, &squeeze_dims, &newaxis_dims).await?; Ok(OmFileTypedArray::Int8(array)) } OmDataType::Int16Array => { let array = - read_and_process_array::(reader, &read_ranges, &squeeze_dims).await?; + read_and_process_array::(reader, &read_ranges, &squeeze_dims, &newaxis_dims).await?; Ok(OmFileTypedArray::Int16(array)) } OmDataType::Int32Array => { let array = - read_and_process_array::(reader, &read_ranges, &squeeze_dims).await?; + read_and_process_array::(reader, &read_ranges, &squeeze_dims, &newaxis_dims).await?; Ok(OmFileTypedArray::Int32(array)) } OmDataType::Int64Array => { let array = - read_and_process_array::(reader, &read_ranges, &squeeze_dims).await?; + read_and_process_array::(reader, &read_ranges, &squeeze_dims, &newaxis_dims).await?; Ok(OmFileTypedArray::Int64(array)) } OmDataType::Uint8Array => { let array = - read_and_process_array::(reader, &read_ranges, &squeeze_dims).await?; + read_and_process_array::(reader, &read_ranges, &squeeze_dims, &newaxis_dims).await?; Ok(OmFileTypedArray::Uint8(array)) } OmDataType::Uint16Array => { let array = - read_and_process_array::(reader, &read_ranges, &squeeze_dims).await?; + read_and_process_array::(reader, &read_ranges, &squeeze_dims, &newaxis_dims).await?; Ok(OmFileTypedArray::Uint16(array)) } OmDataType::Uint32Array => { let array = - read_and_process_array::(reader, &read_ranges, &squeeze_dims).await?; + read_and_process_array::(reader, &read_ranges, &squeeze_dims, &newaxis_dims).await?; Ok(OmFileTypedArray::Uint32(array)) } OmDataType::Uint64Array => { let array = - read_and_process_array::(reader, &read_ranges, &squeeze_dims).await?; + read_and_process_array::(reader, &read_ranges, &squeeze_dims, &newaxis_dims).await?; Ok(OmFileTypedArray::Uint64(array)) } OmDataType::FloatArray => { let array = - read_and_process_array::(reader, &read_ranges, &squeeze_dims).await?; + read_and_process_array::(reader, &read_ranges, &squeeze_dims, &newaxis_dims).await?; Ok(OmFileTypedArray::Float(array)) } OmDataType::DoubleArray => { let array = - read_and_process_array::(reader, &read_ranges, &squeeze_dims).await?; + read_and_process_array::(reader, &read_ranges, &squeeze_dims, &newaxis_dims).await?; Ok(OmFileTypedArray::Double(array)) } _ => { @@ -568,6 +568,7 @@ async fn read_and_process_array( reader: &OmFileReaderAsyncRs, read_ranges: &[Range], squeeze_dims: &[usize], + newaxis_dims: &[usize], ) -> PyResult> where T: Element + OmFileArrayDataType + Clone + Zero + Send + Sync + 'static, @@ -580,10 +581,7 @@ where .await .map_err(convert_omfilesrs_error)?; - // Filter out dimensions of size 1 that correspond to integer indices - // This assumes the `array` returned by `read` has the full dimensionality - // matching `read_ranges` (which it does in omfiles-rs). - let new_shape: Vec = array + let mut new_shape: Vec = array .shape() .iter() .enumerate() @@ -596,6 +594,12 @@ where }) .collect(); + let mut newaxis_sorted = newaxis_dims.to_vec(); + newaxis_sorted.sort(); + for &pos in newaxis_sorted.iter() { + new_shape.insert(pos, 1); + } + Ok(array .into_shape_with_order(new_shape) .map_err(|e| PyValueError::new_err(e.to_string()))?) diff --git a/tests/test_read_write.py b/tests/test_read_write.py index 9667820..14b714d 100644 --- a/tests/test_read_write.py +++ b/tests/test_read_write.py @@ -463,6 +463,85 @@ def test_reader_close(temp_om_file): ) +def _ref_data(): + return np.arange(25, dtype=np.float32).reshape(5, 5) + + +def test_indexing_integer(temp_om_file): + ref = _ref_data() + reader = omfiles.OmFileReader(temp_om_file) + np.testing.assert_array_equal(reader[0], ref[0]) + np.testing.assert_array_equal(reader[-1], ref[-1]) + np.testing.assert_array_equal(reader[2, 3], np.float32(ref[2, 3])) + reader.close() + + +def test_indexing_slice(temp_om_file): + ref = _ref_data() + reader = omfiles.OmFileReader(temp_om_file) + np.testing.assert_array_equal(reader[1:4], ref[1:4]) + np.testing.assert_array_equal(reader[1:4, 2:5], ref[1:4, 2:5]) + np.testing.assert_array_equal(reader[:3, :2], ref[:3, :2]) + np.testing.assert_array_equal(reader[2:, 3:], ref[2:, 3:]) + reader.close() + + +def test_indexing_ellipsis(temp_om_file): + ref = _ref_data() + reader = omfiles.OmFileReader(temp_om_file) + np.testing.assert_array_equal(reader[...], ref[...]) + np.testing.assert_array_equal(reader[1, ...], ref[1, ...]) + np.testing.assert_array_equal(reader[..., 2], ref[..., 2]) + np.testing.assert_array_equal(reader[1:4, ...], ref[1:4, ...]) + np.testing.assert_array_equal(reader[..., 2:5], ref[..., 2:5]) + reader.close() + + +def test_indexing_negative_slice(temp_om_file): + ref = _ref_data() + reader = omfiles.OmFileReader(temp_om_file) + np.testing.assert_array_equal(reader[-3:], ref[-3:]) + np.testing.assert_array_equal(reader[-3:-1], ref[-3:-1]) + np.testing.assert_array_equal(reader[-4:-1, -3:], ref[-4:-1, -3:]) + reader.close() + + +def test_indexing_newaxis(temp_om_file): + ref = _ref_data() + reader = omfiles.OmFileReader(temp_om_file) + np.testing.assert_array_equal(reader[None], ref[None]) + np.testing.assert_array_equal(reader[None, :3], ref[None, :3]) + np.testing.assert_array_equal(reader[1, None], ref[1, None]) + np.testing.assert_array_equal(reader[None, 1], ref[None, 1]) + np.testing.assert_array_equal(reader[None, None], ref[None, None]) + np.testing.assert_array_equal(reader[..., None], ref[..., None]) + np.testing.assert_array_equal(reader[None, ...], ref[None, ...]) + reader.close() + + +def test_indexing_mixed(temp_om_file): + ref = _ref_data() + reader = omfiles.OmFileReader(temp_om_file) + np.testing.assert_array_equal(reader[1:4, 2], ref[1:4, 2]) + np.testing.assert_array_equal(reader[0, 1:4], ref[0, 1:4]) + np.testing.assert_array_equal(reader[1:4, ...], ref[1:4, ...]) + np.testing.assert_array_equal(reader[..., 2:5], ref[..., 2:5]) + reader.close() + + +def test_indexing_errors(temp_om_file): + reader = omfiles.OmFileReader(temp_om_file) + with pytest.raises(IndexError): + _ = reader[10] + with pytest.raises(IndexError): + _ = reader[0, 10] + with pytest.raises(IndexError): + _ = reader[-10] + with pytest.raises(IndexError): + _ = reader[0, 0, 0] + reader.close() + + def test_child_traversal(temp_hierarchical_om_file): reader = omfiles.OmFileReader(temp_hierarchical_om_file)