Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 143 additions & 14 deletions src/array_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,24 +71,30 @@ impl ArrayIndex {
pub fn get_ranges_and_squeeze_dims(
&self,
shape: &Vec<u64>,
) -> PyResult<(Vec<Range<u64>>, Vec<usize>)> {
// Input validation
if self.0.len() > shape.len() {
) -> PyResult<(Vec<Range<u64>>, Vec<usize>, Vec<usize>)> {
// Count how many actual dimensions are consumed by non-NewAxis indices
let consumed_dims: usize = self
.0
.iter()
.filter(|&x| !matches!(x, IndexType::NewAxis | IndexType::Ellipsis))
.count();
if consumed_dims > shape.len() {
return Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
"Too many indices for array",
));
}

let mut ranges = Vec::new();
let mut squeeze_dims = Vec::new();
let mut newaxis_dims = Vec::new();
let mut current_dim = 0;

let mut shape_idx = 0;
let mut ellipsis_seen = false;
let explicit_dims: usize = self
.0
.iter()
.filter(|&x| !matches!(x, IndexType::Ellipsis))
.filter(|&x| !matches!(x, IndexType::Ellipsis | IndexType::NewAxis))
.count();
let ellipsis_dims = shape.len().saturating_sub(explicit_dims);

Expand All @@ -100,7 +106,6 @@ impl ArrayIndex {
"Only one ellipsis allowed in index",
));
}
// Add full ranges for all dimensions represented by the ellipsis
for _ in 0..ellipsis_dims {
ranges.push(Range {
start: 0,
Expand All @@ -117,7 +122,6 @@ impl ArrayIndex {
start: normalized_idx,
end: normalized_idx + 1,
});
// Mark this dimension for squeezing
squeeze_dims.push(current_dim);

shape_idx += 1;
Expand Down Expand Up @@ -180,11 +184,7 @@ impl ArrayIndex {
current_dim += 1;
}
IndexType::NewAxis => {
ranges.push(Range {
start: 0,
end: shape[shape_idx],
});

newaxis_dims.push(current_dim);
current_dim += 1;
}
}
Expand All @@ -198,11 +198,23 @@ impl ArrayIndex {
});
shape_idx += 1;
}
// We must sort squeeze_dims in descending order so removing one doesn't shift
// the indices of subsequent ones we want to remove.
squeeze_dims.sort_by(|a, b| b.cmp(a));

Ok((ranges, squeeze_dims))
// Adjust squeeze_dims from output space to read-array space:
// each NewAxis before a squeeze dim shifts it left by 1.
for pos in squeeze_dims.iter_mut() {
let shift = newaxis_dims.iter().filter(|&&n| n < *pos).count();
*pos -= shift;
}

// Adjust newaxis_dims from output space to post-squeeze space:
// each squeezed dim before a NewAxis shifts it left by 1.
for pos in newaxis_dims.iter_mut() {
let shift = squeeze_dims.iter().filter(|&&s| s < *pos).count();
*pos -= shift;
}
Comment on lines 201 to +215

Ok((ranges, squeeze_dims, newaxis_dims))
}

fn normalize_index(idx: i64, dim_size: u64) -> PyResult<u64> {
Expand Down Expand Up @@ -380,6 +392,123 @@ mod tests {
});
}

#[test]
fn test_newaxis() {
Python::initialize();

Python::attach(|py| {
// arr[np.newaxis] on shape (5,)
let shape = vec![5];
let none = py.None();
let none_value = none.bind(py);
let tuple = pyo3::types::PyTuple::new(py, [none_value]).unwrap();
let index = ArrayIndex::extract(tuple.as_any().as_borrowed()).unwrap();
let (ranges, squeeze, newaxis) = index.get_ranges_and_squeeze_dims(&shape).unwrap();
assert_eq!(ranges.len(), 1);
assert_eq!(ranges[0], Range { start: 0, end: 5 });
assert!(squeeze.is_empty());
assert_eq!(newaxis, vec![0]);

// arr[np.newaxis, :3] on shape (5,)
let slice = PySlice::new(py, 0, 3, 1);
let tuple = pyo3::types::PyTuple::new(
py,
[none_value, &slice.into_any()],
)
.unwrap();
let index = ArrayIndex::extract(tuple.as_any().as_borrowed()).unwrap();
let (ranges, squeeze, newaxis) = index.get_ranges_and_squeeze_dims(&shape).unwrap();
assert_eq!(ranges.len(), 1);
assert_eq!(ranges[0], Range { start: 0, end: 3 });
assert!(squeeze.is_empty());
assert_eq!(newaxis, vec![0]);

// arr[1, np.newaxis] on shape (5,)
let int_value = 1i64.into_pyobject(py).unwrap();
let tuple = pyo3::types::PyTuple::new(
py,
[&int_value.clone().into_any(), none_value],
)
.unwrap();
let index = ArrayIndex::extract(tuple.as_any().as_borrowed()).unwrap();
let (ranges, squeeze, newaxis) = index.get_ranges_and_squeeze_dims(&shape).unwrap();
assert_eq!(ranges.len(), 1);
assert_eq!(ranges[0], Range { start: 1, end: 2 });
assert_eq!(squeeze, vec![0]);
assert_eq!(newaxis, vec![0]);

// arr[np.newaxis, 1] on shape (5,) — newaxis then int
let tuple = pyo3::types::PyTuple::new(
py,
[none_value, &int_value.into_any()],
)
.unwrap();
let index = ArrayIndex::extract(tuple.as_any().as_borrowed()).unwrap();
let (ranges, squeeze, newaxis) = index.get_ranges_and_squeeze_dims(&shape).unwrap();
assert_eq!(ranges.len(), 1);
assert_eq!(ranges[0], Range { start: 1, end: 2 });
assert_eq!(squeeze, vec![0]);
assert_eq!(newaxis, vec![0]);

// arr[np.newaxis, np.newaxis] on shape (5,) — two newaxis
let tuple = pyo3::types::PyTuple::new(
py,
[none_value, none_value],
)
.unwrap();
let index = ArrayIndex::extract(tuple.as_any().as_borrowed()).unwrap();
let (ranges, squeeze, newaxis) = index.get_ranges_and_squeeze_dims(&shape).unwrap();
assert_eq!(ranges.len(), 1);
assert_eq!(ranges[0], Range { start: 0, end: 5 });
assert!(squeeze.is_empty());
assert_eq!(newaxis, vec![0, 1]);
});
}

#[test]
fn test_ellipsis_with_newaxis() {
Python::initialize();

Python::attach(|py| {
let shape = vec![2, 3, 4, 5];
let ellipsis = pyo3::types::PyEllipsis::get(py).into_any();
let none = py.None();
let none_value = none.bind(py);

// arr[..., np.newaxis] on shape (2,3,4,5)
let tuple = pyo3::types::PyTuple::new(
py,
[&ellipsis, none_value],
)
.unwrap();
let index = ArrayIndex::extract(tuple.as_any().as_borrowed()).unwrap();
let (ranges, squeeze, newaxis) = index.get_ranges_and_squeeze_dims(&shape).unwrap();
assert_eq!(ranges.len(), 4);
assert_eq!(ranges[0], Range { start: 0, end: 2 });
assert_eq!(ranges[1], Range { start: 0, end: 3 });
assert_eq!(ranges[2], Range { start: 0, end: 4 });
assert_eq!(ranges[3], Range { start: 0, end: 5 });
assert!(squeeze.is_empty());
assert_eq!(newaxis, vec![4]);

// arr[np.newaxis, ...] on shape (2,3,4,5)
let tuple = pyo3::types::PyTuple::new(
py,
[none_value, &ellipsis],
)
.unwrap();
let index = ArrayIndex::extract(tuple.as_any().as_borrowed()).unwrap();
let (ranges, squeeze, newaxis) = index.get_ranges_and_squeeze_dims(&shape).unwrap();
assert_eq!(ranges.len(), 4);
assert_eq!(ranges[0], Range { start: 0, end: 2 });
assert_eq!(ranges[1], Range { start: 0, end: 3 });
assert_eq!(ranges[2], Range { start: 0, end: 4 });
assert_eq!(ranges[3], Range { start: 0, end: 5 });
assert!(squeeze.is_empty());
assert_eq!(newaxis, vec![0]);
});
}

#[test]
#[should_panic]
fn test_invalid_input() {
Expand Down
28 changes: 22 additions & 6 deletions src/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,8 @@ impl OmFileReader {

if !bound_object.hasattr("cat_file")? || !bound_object.hasattr("size")? {
return Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
"Input must be a valid fsspec file object with read, seek methods and fs attribute",
));
"Input must be a valid fsspec file object with `cat_file` and `size` methods.",
));
}

let backend = ReaderBackendImpl::FsSpec(FsSpecBackend::new(fs_obj, path)?);
Expand Down Expand Up @@ -500,7 +500,7 @@ impl OmFileReader {
let array_reader = reader
.expect_array_with_io_sizes(65536, 512)
.map_err(|_| Self::only_arrays_error())?;
let (read_ranges, squeeze_dims) =
let (read_ranges, squeeze_dims, newaxis_dims) =
ranges.get_ranges_and_squeeze_dims(&self.shape)?;
let dtype = array_reader.data_type();

Expand All @@ -522,6 +522,7 @@ impl OmFileReader {
&array_reader,
&read_ranges,
&squeeze_dims,
&newaxis_dims,
)?;
Ok(OmFileTypedArray::Int8(array))
}
Expand All @@ -530,6 +531,7 @@ impl OmFileReader {
&array_reader,
&read_ranges,
&squeeze_dims,
&newaxis_dims,
)?;
Ok(OmFileTypedArray::Uint8(array))
}
Expand All @@ -538,6 +540,7 @@ impl OmFileReader {
&array_reader,
&read_ranges,
&squeeze_dims,
&newaxis_dims,
)?;
Ok(OmFileTypedArray::Int16(array))
}
Expand All @@ -546,6 +549,7 @@ impl OmFileReader {
&array_reader,
&read_ranges,
&squeeze_dims,
&newaxis_dims,
)?;
Ok(OmFileTypedArray::Uint16(array))
}
Expand All @@ -554,6 +558,7 @@ impl OmFileReader {
&array_reader,
&read_ranges,
&squeeze_dims,
&newaxis_dims,
)?;
Ok(OmFileTypedArray::Int32(array))
}
Expand All @@ -562,6 +567,7 @@ impl OmFileReader {
&array_reader,
&read_ranges,
&squeeze_dims,
&newaxis_dims,
)?;
Ok(OmFileTypedArray::Uint32(array))
}
Expand All @@ -570,6 +576,7 @@ impl OmFileReader {
&array_reader,
&read_ranges,
&squeeze_dims,
&newaxis_dims,
)?;
Ok(OmFileTypedArray::Int64(array))
}
Expand All @@ -578,6 +585,7 @@ impl OmFileReader {
&array_reader,
&read_ranges,
&squeeze_dims,
&newaxis_dims,
)?;
Ok(OmFileTypedArray::Uint64(array))
}
Expand All @@ -586,6 +594,7 @@ impl OmFileReader {
&array_reader,
&read_ranges,
&squeeze_dims,
&newaxis_dims,
)?;
Ok(OmFileTypedArray::Float(array))
}
Expand All @@ -594,6 +603,7 @@ impl OmFileReader {
&array_reader,
&read_ranges,
&squeeze_dims,
&newaxis_dims,
)?;
Ok(OmFileTypedArray::Double(array))
}
Expand Down Expand Up @@ -646,15 +656,14 @@ fn read_and_process_array<T: Element + OmFileArrayDataType + Clone + Zero>(
reader: &OmFileArrayRs<impl OmFileReaderBackend>,
read_ranges: &[Range<u64>],
squeeze_dims: &[usize],
newaxis_dims: &[usize],
) -> PyResult<ndarray::ArrayD<T>> {
let array = reader
.read::<T>(read_ranges)
.map_err(convert_omfilesrs_error)?;

// Filter out dimensions of size 1 that correspond to integer indices
// This assumes the `array` returned by `read` has the full dimensionality
// matching `read_ranges` (which it does in omfiles-rs).
let new_shape: Vec<usize> = array
let mut new_shape: Vec<usize> = array
.shape()
.iter()
.enumerate()
Expand All @@ -667,6 +676,13 @@ fn read_and_process_array<T: Element + OmFileArrayDataType + Clone + Zero>(
})
.collect();

// Insert size-1 dimensions for NewAxis, sorted ascending
let mut newaxis_sorted = newaxis_dims.to_vec();
newaxis_sorted.sort();
for &pos in newaxis_sorted.iter() {
new_shape.insert(pos, 1);
}

Ok(array
.into_shape_with_order(new_shape)
.map_err(|e| PyValueError::new_err(e.to_string()))?)
Expand Down
Loading