Encapsulate Python sequence-like indexers (#12669)

This encapsulates a lot of the common logic around Python sequence-like
indexers (`SliceOrInt`) into iterators that handle converting negative
indices and slices into `usize` indices for containers of a given size.

These indexers now all implement `ExactSizeIterator` and
`DoubleEndedIterator`, so they can be used with all `Iterator` methods,
and can be used (with `Iterator::map` and friends) as inputs to
`PyList::new_bound`, which makes code simpler at all points of use.

The special-cased uses of this kind of thing from `CircuitData` are
replaced with the new forms.  This had no measurable impact on
performance on my machine, and removes a lot of noise from error-handling
and highly specialised functions.
This commit is contained in:
Jake Lishman 2024-07-01 13:59:21 +01:00 committed by GitHub
parent a7fc2daf4c
commit 373e8a68c8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 524 additions and 303 deletions

2
Cargo.lock generated
View File

@ -1180,6 +1180,7 @@ dependencies = [
"rayon",
"rustworkx-core",
"smallvec",
"thiserror",
]
[[package]]
@ -1192,6 +1193,7 @@ dependencies = [
"numpy",
"pyo3",
"smallvec",
"thiserror",
]
[[package]]

View File

@ -20,6 +20,7 @@ num-complex = "0.4"
ndarray = "^0.15.6"
numpy = "0.21.0"
smallvec = "1.13"
thiserror = "1.0"
# Most of the crates don't need the feature `extension-module`, since only `qiskit-pyext` builds an
# actual C extension (the feature disables linking in `libpython`, which is forbidden in Python

View File

@ -23,6 +23,7 @@ rustworkx-core = "0.15"
faer = "0.19.1"
itertools = "0.13.0"
qiskit-circuit.workspace = true
thiserror.workspace = true
[dependencies.smallvec]
workspace = true

View File

@ -21,9 +21,9 @@ use std::f64::consts::PI;
use std::ops::Deref;
use std::str::FromStr;
use pyo3::exceptions::{PyIndexError, PyValueError};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyString;
use pyo3::types::{PyList, PyString};
use pyo3::wrap_pyfunction;
use pyo3::Python;
@ -31,8 +31,8 @@ use ndarray::prelude::*;
use numpy::PyReadonlyArray2;
use pyo3::pybacked::PyBackedStr;
use qiskit_circuit::slice::{PySequenceIndex, SequenceIndex};
use qiskit_circuit::util::c64;
use qiskit_circuit::SliceOrInt;
pub const ANGLE_ZERO_EPSILON: f64 = 1e-12;
@ -97,46 +97,15 @@ impl OneQubitGateSequence {
Ok(self.gates.len())
}
fn __getitem__(&self, py: Python, idx: SliceOrInt) -> PyResult<PyObject> {
match idx {
SliceOrInt::Slice(slc) => {
let len = self.gates.len().try_into().unwrap();
let indices = slc.indices(len)?;
let mut out_vec: Vec<(String, SmallVec<[f64; 3]>)> = Vec::new();
// Start and stop will always be positive the slice api converts
// negatives to the index for example:
// list(range(5))[-1:-3:-1]
// will return start=4, stop=2, and step=-1
let mut pos: isize = indices.start;
let mut cond = if indices.step < 0 {
pos > indices.stop
} else {
pos < indices.stop
};
while cond {
if pos < len as isize {
out_vec.push(self.gates[pos as usize].clone());
}
pos += indices.step;
if indices.step < 0 {
cond = pos > indices.stop;
} else {
cond = pos < indices.stop;
}
}
Ok(out_vec.into_py(py))
}
SliceOrInt::Int(idx) => {
let len = self.gates.len() as isize;
if idx >= len || idx < -len {
Err(PyIndexError::new_err(format!("Invalid index, {idx}")))
} else if idx < 0 {
let len = self.gates.len();
Ok(self.gates[len - idx.unsigned_abs()].to_object(py))
} else {
Ok(self.gates[idx as usize].to_object(py))
}
}
fn __getitem__(&self, py: Python, idx: PySequenceIndex) -> PyResult<PyObject> {
match idx.with_len(self.gates.len())? {
SequenceIndex::Int(idx) => Ok(self.gates[idx].to_object(py)),
indices => Ok(PyList::new_bound(
py,
indices.iter().map(|pos| self.gates[pos].to_object(py)),
)
.into_any()
.unbind()),
}
}
}

View File

@ -21,10 +21,6 @@
use approx::{abs_diff_eq, relative_eq};
use num_complex::{Complex, Complex64, ComplexFloat};
use num_traits::Zero;
use pyo3::exceptions::{PyIndexError, PyValueError};
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
use pyo3::Python;
use smallvec::{smallvec, SmallVec};
use std::f64::consts::{FRAC_1_SQRT_2, PI};
use std::ops::Deref;
@ -37,7 +33,11 @@ use ndarray::prelude::*;
use ndarray::Zip;
use numpy::PyReadonlyArray2;
use numpy::{IntoPyArray, ToPyArray};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
use pyo3::types::PyList;
use crate::convert_2q_block_matrix::change_basis;
use crate::euler_one_qubit_decomposer::{
@ -52,8 +52,8 @@ use rand_distr::StandardNormal;
use rand_pcg::Pcg64Mcg;
use qiskit_circuit::gate_matrix::{CX_GATE, H_GATE, ONE_QUBIT_IDENTITY, SX_GATE, X_GATE};
use qiskit_circuit::slice::{PySequenceIndex, SequenceIndex};
use qiskit_circuit::util::{c64, GateArray1Q, GateArray2Q, C_M_ONE, C_ONE, C_ZERO, IM, M_IM};
use qiskit_circuit::SliceOrInt;
const PI2: f64 = PI / 2.;
const PI4: f64 = PI / 4.;
@ -1131,46 +1131,15 @@ impl TwoQubitGateSequence {
Ok(self.gates.len())
}
fn __getitem__(&self, py: Python, idx: SliceOrInt) -> PyResult<PyObject> {
match idx {
SliceOrInt::Slice(slc) => {
let len = self.gates.len().try_into().unwrap();
let indices = slc.indices(len)?;
let mut out_vec: TwoQubitSequenceVec = Vec::new();
// Start and stop will always be positive the slice api converts
// negatives to the index for example:
// list(range(5))[-1:-3:-1]
// will return start=4, stop=2, and step=-
let mut pos: isize = indices.start;
let mut cond = if indices.step < 0 {
pos > indices.stop
} else {
pos < indices.stop
};
while cond {
if pos < len as isize {
out_vec.push(self.gates[pos as usize].clone());
}
pos += indices.step;
if indices.step < 0 {
cond = pos > indices.stop;
} else {
cond = pos < indices.stop;
}
}
Ok(out_vec.into_py(py))
}
SliceOrInt::Int(idx) => {
let len = self.gates.len() as isize;
if idx >= len || idx < -len {
Err(PyIndexError::new_err(format!("Invalid index, {idx}")))
} else if idx < 0 {
let len = self.gates.len();
Ok(self.gates[len - idx.unsigned_abs()].to_object(py))
} else {
Ok(self.gates[idx as usize].to_object(py))
}
}
fn __getitem__(&self, py: Python, idx: PySequenceIndex) -> PyResult<PyObject> {
match idx.with_len(self.gates.len())? {
SequenceIndex::Int(idx) => Ok(self.gates[idx].to_object(py)),
indices => Ok(PyList::new_bound(
py,
indices.iter().map(|pos| self.gates[pos].to_object(py)),
)
.into_any()
.unbind()),
}
}
}

View File

@ -14,6 +14,7 @@ hashbrown.workspace = true
num-complex.workspace = true
ndarray.workspace = true
numpy.workspace = true
thiserror.workspace = true
[dependencies.pyo3]
workspace = true

View File

@ -22,11 +22,12 @@ use crate::imports::{BUILTIN_LIST, QUBIT};
use crate::interner::{IndexedInterner, Interner, InternerKey};
use crate::operations::{Operation, OperationType, Param, StandardGate};
use crate::parameter_table::{ParamEntry, ParamTable, GLOBAL_PHASE_INDEX};
use crate::{Clbit, Qubit, SliceOrInt};
use crate::slice::{PySequenceIndex, SequenceIndex};
use crate::{Clbit, Qubit};
use pyo3::exceptions::{PyIndexError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyList, PySet, PySlice, PyTuple, PyType};
use pyo3::types::{PyList, PySet, PyTuple, PyType};
use pyo3::{intern, PyTraverseError, PyVisit};
use hashbrown::{HashMap, HashSet};
@ -321,7 +322,7 @@ impl CircuitData {
}
pub fn append_inner(&mut self, py: Python, value: PyRef<CircuitInstruction>) -> PyResult<bool> {
let packed = self.pack(py, value)?;
let packed = self.pack(value)?;
let new_index = self.data.len();
self.data.push(packed);
self.update_param_table(py, new_index, None)
@ -744,184 +745,130 @@ impl CircuitData {
}
// Note: we also rely on this to make us iterable!
pub fn __getitem__(&self, py: Python, index: &Bound<PyAny>) -> PyResult<PyObject> {
// Internal helper function to get a specific
// instruction by index.
fn get_at(
self_: &CircuitData,
py: Python<'_>,
index: isize,
) -> PyResult<Py<CircuitInstruction>> {
let index = self_.convert_py_index(index)?;
if let Some(inst) = self_.data.get(index) {
let qubits = self_.qargs_interner.intern(inst.qubits_id);
let clbits = self_.cargs_interner.intern(inst.clbits_id);
Py::new(
py,
CircuitInstruction::new(
py,
inst.op.clone(),
self_.qubits.map_indices(qubits.value),
self_.clbits.map_indices(clbits.value),
inst.params.clone(),
inst.extra_attrs.clone(),
),
)
} else {
Err(PyIndexError::new_err(format!(
"No element at index {:?} in circuit data",
index
)))
}
}
if index.is_exact_instance_of::<PySlice>() {
let slice = self.convert_py_slice(index.downcast_exact::<PySlice>()?)?;
let result = slice
.into_iter()
.map(|i| get_at(self, py, i))
.collect::<PyResult<Vec<_>>>()?;
Ok(result.into_py(py))
} else {
Ok(get_at(self, py, index.extract()?)?.into_py(py))
pub fn __getitem__(&self, py: Python, index: PySequenceIndex) -> PyResult<PyObject> {
// Get a single item, assuming the index is validated as in bounds.
let get_single = |index: usize| {
let inst = &self.data[index];
let qubits = self.qargs_interner.intern(inst.qubits_id);
let clbits = self.cargs_interner.intern(inst.clbits_id);
CircuitInstruction::new(
py,
inst.op.clone(),
self.qubits.map_indices(qubits.value),
self.clbits.map_indices(clbits.value),
inst.params.clone(),
inst.extra_attrs.clone(),
)
.into_py(py)
};
match index.with_len(self.data.len())? {
SequenceIndex::Int(index) => Ok(get_single(index)),
indices => Ok(PyList::new_bound(py, indices.iter().map(get_single)).into_py(py)),
}
}
pub fn __delitem__(&mut self, py: Python, index: SliceOrInt) -> PyResult<()> {
match index {
SliceOrInt::Slice(slice) => {
let slice = {
let mut s = self.convert_py_slice(&slice)?;
if s.len() > 1 && s.first().unwrap() < s.last().unwrap() {
// Reverse the order so we're sure to delete items
// at the back first (avoids messing up indices).
s.reverse()
}
s
};
for i in slice.into_iter() {
self.__delitem__(py, SliceOrInt::Int(i))?;
pub fn __delitem__(&mut self, py: Python, index: PySequenceIndex) -> PyResult<()> {
self.delitem(py, index.with_len(self.data.len())?)
}
pub fn setitem_no_param_table_update(
&mut self,
index: usize,
value: PyRef<CircuitInstruction>,
) -> PyResult<()> {
let mut packed = self.pack(value)?;
std::mem::swap(&mut packed, &mut self.data[index]);
Ok(())
}
pub fn __setitem__(&mut self, index: PySequenceIndex, value: &Bound<PyAny>) -> PyResult<()> {
fn set_single(slf: &mut CircuitData, index: usize, value: &Bound<PyAny>) -> PyResult<()> {
let py = value.py();
let mut packed = slf.pack(value.downcast::<CircuitInstruction>()?.borrow())?;
slf.remove_from_parameter_table(py, index)?;
std::mem::swap(&mut packed, &mut slf.data[index]);
slf.update_param_table(py, index, None)?;
Ok(())
}
let py = value.py();
match index.with_len(self.data.len())? {
SequenceIndex::Int(index) => set_single(self, index, value),
indices @ SequenceIndex::PosRange {
start,
stop,
step: 1,
} => {
// `list` allows setting a slice with step +1 to an arbitrary length.
let values = value.iter()?.collect::<PyResult<Vec<_>>>()?;
for (index, value) in indices.iter().zip(values.iter()) {
set_single(self, index, value)?;
}
if indices.len() > values.len() {
self.delitem(
py,
SequenceIndex::PosRange {
start: start + values.len(),
stop,
step: 1,
},
)?
} else {
for value in values[indices.len()..].iter().rev() {
self.insert(stop as isize, value.downcast()?.borrow())?;
}
}
self.reindex_parameter_table(py)?;
Ok(())
}
SliceOrInt::Int(index) => {
let index = self.convert_py_index(index)?;
if self.data.get(index).is_some() {
if index == self.data.len() {
// For individual removal from param table before
// deletion
self.remove_from_parameter_table(py, index)?;
self.data.remove(index);
} else {
// For delete in the middle delete before reindexing
self.data.remove(index);
self.reindex_parameter_table(py)?;
indices => {
let values = value.iter()?.collect::<PyResult<Vec<_>>>()?;
if indices.len() == values.len() {
for (index, value) in indices.iter().zip(values.iter()) {
set_single(self, index, value)?;
}
Ok(())
} else {
Err(PyIndexError::new_err(format!(
"No element at index {:?} in circuit data",
index
Err(PyValueError::new_err(format!(
"attempt to assign sequence of size {:?} to extended slice of size {:?}",
values.len(),
indices.len(),
)))
}
}
}
}
pub fn setitem_no_param_table_update(
&mut self,
py: Python<'_>,
index: isize,
value: &Bound<PyAny>,
) -> PyResult<()> {
let index = self.convert_py_index(index)?;
let value: PyRef<CircuitInstruction> = value.downcast()?.borrow();
let mut packed = self.pack(py, value)?;
std::mem::swap(&mut packed, &mut self.data[index]);
Ok(())
}
pub fn __setitem__(
&mut self,
py: Python<'_>,
index: SliceOrInt,
value: &Bound<PyAny>,
) -> PyResult<()> {
match index {
SliceOrInt::Slice(slice) => {
let indices = slice.indices(self.data.len().try_into().unwrap())?;
let slice = self.convert_py_slice(&slice)?;
let values = value.iter()?.collect::<PyResult<Vec<Bound<PyAny>>>>()?;
if indices.step != 1 && slice.len() != values.len() {
// A replacement of a different length when step isn't exactly '1'
// would result in holes.
return Err(PyValueError::new_err(format!(
"attempt to assign sequence of size {:?} to extended slice of size {:?}",
values.len(),
slice.len(),
)));
}
for (i, v) in slice.iter().zip(values.iter()) {
self.__setitem__(py, SliceOrInt::Int(*i), v)?;
}
if slice.len() > values.len() {
// Delete any extras.
let slice = PySlice::new_bound(
py,
indices.start + values.len() as isize,
indices.stop,
1isize,
);
self.__delitem__(py, SliceOrInt::Slice(slice))?;
} else {
// Insert any extra values.
for v in values.iter().skip(slice.len()).rev() {
let v: PyRef<CircuitInstruction> = v.extract()?;
self.insert(py, indices.stop, v)?;
}
}
Ok(())
pub fn insert(&mut self, mut index: isize, value: PyRef<CircuitInstruction>) -> PyResult<()> {
// `list.insert` has special-case extra clamping logic for its index argument.
let index = {
if index < 0 {
// This can't exceed `isize::MAX` because `self.data[0]` is larger than a byte.
index += self.data.len() as isize;
}
SliceOrInt::Int(index) => {
let index = self.convert_py_index(index)?;
let value: PyRef<CircuitInstruction> = value.extract()?;
let mut packed = self.pack(py, value)?;
self.remove_from_parameter_table(py, index)?;
std::mem::swap(&mut packed, &mut self.data[index]);
self.update_param_table(py, index, None)?;
Ok(())
if index < 0 {
0
} else if index as usize > self.data.len() {
self.data.len()
} else {
index as usize
}
}
}
pub fn insert(
&mut self,
py: Python<'_>,
index: isize,
value: PyRef<CircuitInstruction>,
) -> PyResult<()> {
let index = self.convert_py_index_clamped(index);
let old_len = self.data.len();
let packed = self.pack(py, value)?;
};
let py = value.py();
let packed = self.pack(value)?;
self.data.insert(index, packed);
if index == old_len {
self.update_param_table(py, old_len, None)?;
if index == self.data.len() - 1 {
self.update_param_table(py, index, None)?;
} else {
self.reindex_parameter_table(py)?;
}
Ok(())
}
pub fn pop(&mut self, py: Python<'_>, index: Option<PyObject>) -> PyResult<PyObject> {
let index =
index.unwrap_or_else(|| std::cmp::max(0, self.data.len() as isize - 1).into_py(py));
let item = self.__getitem__(py, index.bind(py))?;
self.__delitem__(py, index.bind(py).extract()?)?;
pub fn pop(&mut self, py: Python<'_>, index: Option<PySequenceIndex>) -> PyResult<PyObject> {
let index = index.unwrap_or(PySequenceIndex::Int(-1));
let native_index = index.with_len(self.data.len())?;
let item = self.__getitem__(py, index)?;
self.delitem(py, native_index)?;
Ok(item)
}
@ -931,7 +878,7 @@ impl CircuitData {
value: &Bound<CircuitInstruction>,
params: Option<Vec<(usize, Vec<PyObject>)>>,
) -> PyResult<bool> {
let packed = self.pack(py, value.try_borrow()?)?;
let packed = self.pack(value.try_borrow()?)?;
let new_index = self.data.len();
self.data.push(packed);
self.update_param_table(py, new_index, params)
@ -1175,56 +1122,22 @@ impl CircuitData {
}
impl CircuitData {
/// Converts a Python slice to a `Vec` of indices into
/// the instruction listing, [CircuitData.data].
fn convert_py_slice(&self, slice: &Bound<PySlice>) -> PyResult<Vec<isize>> {
let indices = slice.indices(self.data.len().try_into().unwrap())?;
if indices.step > 0 {
Ok((indices.start..indices.stop)
.step_by(indices.step as usize)
.collect())
} else {
let mut out = Vec::with_capacity(indices.slicelength as usize);
let mut x = indices.start;
while x > indices.stop {
out.push(x);
x += indices.step;
}
Ok(out)
/// Native internal driver of `__delitem__` that uses a Rust-space version of the
/// `SequenceIndex`. This assumes that the `SequenceIndex` contains only in-bounds indices, and
/// panics if not.
fn delitem(&mut self, py: Python, indices: SequenceIndex) -> PyResult<()> {
// We need to delete in reverse order so we don't invalidate higher indices with a deletion.
for index in indices.descending() {
self.data.remove(index);
}
}
/// Converts a Python index to an index into the instruction listing,
/// or one past its end.
/// If the resulting index would be < 0, clamps to 0.
/// If the resulting index would be > len(data), clamps to len(data).
fn convert_py_index_clamped(&self, index: isize) -> usize {
let index = if index < 0 {
index + self.data.len() as isize
} else {
index
};
std::cmp::min(std::cmp::max(0, index), self.data.len() as isize) as usize
}
/// Converts a Python index to an index into the instruction listing.
fn convert_py_index(&self, index: isize) -> PyResult<usize> {
let index = if index < 0 {
index + self.data.len() as isize
} else {
index
};
if index < 0 || index >= self.data.len() as isize {
return Err(PyIndexError::new_err(format!(
"Index {:?} is out of bounds.",
index,
)));
if !indices.is_empty() {
self.reindex_parameter_table(py)?;
}
Ok(index as usize)
Ok(())
}
fn pack(&mut self, py: Python, inst: PyRef<CircuitInstruction>) -> PyResult<PackedInstruction> {
fn pack(&mut self, inst: PyRef<CircuitInstruction>) -> PyResult<PackedInstruction> {
let py = inst.py();
let qubits = Interner::intern(
&mut self.qargs_interner,
InternerKey::Value(self.qubits.map_bits(inst.qubits.bind(py))?.collect()),

View File

@ -17,23 +17,13 @@ pub mod gate_matrix;
pub mod imports;
pub mod operations;
pub mod parameter_table;
pub mod slice;
pub mod util;
mod bit_data;
mod interner;
use pyo3::prelude::*;
use pyo3::types::PySlice;
/// A private enumeration type used to extract arguments to pymethod
/// that may be either an index or a slice
#[derive(FromPyObject)]
pub enum SliceOrInt<'a> {
// The order here defines the order the variants are tried in the FromPyObject` derivation.
// `Int` is _much_ more common, so that should be first.
Int(isize),
Slice(Bound<'a, PySlice>),
}
pub type BitType = u32;
#[derive(Copy, Clone, Debug, Hash, Ord, PartialOrd, Eq, PartialEq)]

375
crates/circuit/src/slice.rs Normal file
View File

@ -0,0 +1,375 @@
// This code is part of Qiskit.
//
// (C) Copyright IBM 2024
//
// This code is licensed under the Apache License, Version 2.0. You may
// obtain a copy of this license in the LICENSE.txt file in the root directory
// of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
//
// Any modifications or derivative works of this code must retain this
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.
use thiserror::Error;
use pyo3::exceptions::PyIndexError;
use pyo3::prelude::*;
use pyo3::types::PySlice;
use self::sealed::{Descending, SequenceIndexIter};
/// An index as it arrives from Python space for a sequence-like container: either one integer
/// or a `slice` object.
///
/// The integer is signed because Python allows negative indices.  Call `with_len` to resolve
/// the index against a concrete collection length and obtain a Rust-native indexer.
pub enum PySequenceIndex<'py> {
    /// A single, possibly negative, integer index.
    Int(isize),
    /// A Python `slice` object whose bounds have not yet been resolved against a length.
    Slice(Bound<'py, PySlice>),
}
impl<'py> FromPyObject<'py> for PySequenceIndex<'py> {
    /// Extract either a `slice` or an integer index from an arbitrary Python object.
    ///
    /// Fails with whatever error the `isize` extraction produces if the object is neither.
    fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
        // Python does not allow `slice` to be subclassed, so an exact type check is both safe
        // and faster.  `downcast_exact` is only a pointer comparison, so trying it first costs
        // the (more common) integer path next to nothing, whereas trying the integer first
        // makes `slice` inputs significantly slower.
        match ob.downcast_exact::<PySlice>() {
            Ok(slice) => Ok(Self::Slice(slice.clone())),
            Err(_) => ob.extract::<isize>().map(Self::Int),
        }
    }
}
impl<'py> PySequenceIndex<'py> {
    /// Resolve this index against a collection of length `len`, producing a Rust-native
    /// indexer.
    ///
    /// # Errors
    ///
    /// Returns `PySequenceIndexError::OutOfRange` when an integer index does not refer to an
    /// element of a collection of length `len`, and passes through any Python exception raised
    /// while resolving a slice.
    pub fn with_len(&self, len: usize) -> Result<SequenceIndex, PySequenceIndexError> {
        match self {
            Self::Int(raw) => {
                let raw = *raw;
                let resolved = if raw < 0 {
                    // Negative indices count back from the end; an underflow means the index
                    // points before the first element.
                    len.checked_sub(raw.unsigned_abs())
                } else {
                    Some(raw as usize).filter(|position| *position < len)
                };
                resolved
                    .map(SequenceIndex::Int)
                    .ok_or(PySequenceIndexError::OutOfRange)
            }
            Self::Slice(slice) => {
                let bounds = slice.indices(len as ::std::os::raw::c_long)?;
                if bounds.step > 0 {
                    return Ok(SequenceIndex::PosRange {
                        start: bounds.start as usize,
                        stop: bounds.stop as usize,
                        step: bounds.step as usize,
                    });
                }
                // With a negative step either bound may come back negative: `start` when the
                // collection is empty, and `stop` when index 0 itself should be produced.
                // Those cases are represented as `None`.
                let non_negative = |value: isize| (value >= 0).then_some(value as usize);
                Ok(SequenceIndex::NegRange {
                    start: non_negative(bounds.start),
                    stop: non_negative(bounds.stop),
                    step: bounds.step.unsigned_abs(),
                })
            }
        }
    }
}
/// Error type for problems encountered when calling methods on `PySequenceIndex`.
#[derive(Error, Debug)]
pub enum PySequenceIndexError {
#[error("index out of range")]
OutOfRange,
#[error(transparent)]
InnerPy(#[from] PyErr),
}
impl From<PySequenceIndexError> for PyErr {
fn from(value: PySequenceIndexError) -> PyErr {
match value {
PySequenceIndexError::OutOfRange => PyIndexError::new_err("index out of range"),
PySequenceIndexError::InnerPy(inner) => inner,
}
}
}
/// A Python sequence-like indexer expressed in Rust-native `usize` terms.
///
/// The usual way to obtain one is `PySequenceIndex::with_len`, which guarantees that every
/// index is in bounds for a collection of the length it was given.
///
/// Positive-step and negative-step slices are separate variants so that they translate more
/// easily into static dispatch.  The type can be converted into several kinds of iterator.
#[derive(Clone, Copy, Debug)]
pub enum SequenceIndex {
    /// A single index.
    Int(usize),
    /// A slice with a positive step, walking upwards from `start` towards `stop`.
    PosRange {
        start: usize,
        stop: usize,
        step: usize,
    },
    /// A slice with a negative step, walking downwards from `start` towards `stop`.  `step`
    /// holds the magnitude of the step, and a bound is `None` when the corresponding Python
    /// bound was negative.
    NegRange {
        start: Option<usize>,
        stop: Option<usize>,
        step: usize,
    },
}
impl SequenceIndex {
    /// How many indices this indexer refers to.
    pub fn len(&self) -> usize {
        // Reduce every ranged case to "distance covered" and "stride"; the count is then the
        // ceiling of their quotient.
        let (span, stride) = match *self {
            Self::Int(_) => return 1,
            Self::PosRange { start, stop, step } => (stop.saturating_sub(start), step),
            Self::NegRange { start: None, .. } => return 0,
            Self::NegRange {
                start: Some(start),
                stop: Some(stop),
                step,
            } => (start.saturating_sub(stop), step),
            // No lower bound: the walk runs all the way down to, and including, index 0.
            Self::NegRange {
                start: Some(start),
                stop: None,
                step,
            } => (start + 1, step),
        };
        span / stride + usize::from(span % stride != 0)
    }

    /// Whether this indexer refers to no indices at all.
    pub fn is_empty(&self) -> bool {
        // Present mostly to satisfy clippy; computing the length is already cheap.
        self.len() == 0
    }

    /// Iterate over the indices in their natural order.
    ///
    /// For `Self::Int` this yields exactly one item, though callers usually want to
    /// destructure that case away before iterating.
    pub fn iter(&self) -> SequenceIndexIter {
        match *self {
            Self::Int(value) => SequenceIndexIter::Int(Some(value)),
            Self::PosRange { start, step, .. } => SequenceIndexIter::PosRange {
                lowest: start,
                step,
                indices: 0..self.len(),
            },
            Self::NegRange { start, step, .. } => SequenceIndexIter::NegRange {
                // A missing `start` means the length is 0, so the iterator never yields and
                // the placeholder value is never observed.
                highest: start.unwrap_or(0),
                step,
                indices: 0..self.len(),
            },
        }
    }

    /// Iterate over the contained indices in an order that is guaranteed to run from the
    /// highest index to the lowest.
    pub fn descending(&self) -> Descending {
        Descending(self.iter())
    }
}
impl IntoIterator for SequenceIndex {
type Item = usize;
type IntoIter = SequenceIndexIter;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
// Private module to make it impossible to construct or inspect the internals of the iterator types
// from outside this file, while still allowing them to be used.
mod sealed {
    /// Custom iterator for indices for Python sequence-likes.
    ///
    /// In the range variants, `indices` is a `Range` that runs from 0 to the length of the
    /// iterator; each position in it is mapped onto the real index when it is yielded.  In
    /// theory we could generate the iterators ourselves, but that ends up with a lot of
    /// boilerplate.
    #[derive(Clone, Debug)]
    pub enum SequenceIndexIter {
        /// A single index, taken (and replaced by `None`) the first time it is yielded.
        Int(Option<usize>),
        /// Indices `lowest`, `lowest + step`, `lowest + 2 * step`, ...
        PosRange {
            lowest: usize,
            step: usize,
            indices: ::std::ops::Range<usize>,
        },
        /// Indices `highest`, `highest - step`, `highest - 2 * step`, ...
        NegRange {
            highest: usize,
            // The magnitude of the step.  This is a negative range, so forwards iteration
            // steps downwards from `highest`.
            step: usize,
            indices: ::std::ops::Range<usize>,
        },
    }

    impl Iterator for SequenceIndexIter {
        type Item = usize;

        #[inline]
        fn next(&mut self) -> Option<Self::Item> {
            match self {
                Self::Int(value) => value.take(),
                Self::PosRange {
                    lowest,
                    step,
                    indices,
                } => indices.next().map(|idx| *lowest + idx * *step),
                Self::NegRange {
                    highest,
                    step,
                    indices,
                } => indices.next().map(|idx| *highest - idx * *step),
            }
        }

        #[inline]
        fn size_hint(&self) -> (usize, Option<usize>) {
            match self {
                Self::Int(None) => (0, Some(0)),
                Self::Int(Some(_)) => (1, Some(1)),
                Self::PosRange { indices, .. } | Self::NegRange { indices, .. } => {
                    indices.size_hint()
                }
            }
        }
    }

    impl DoubleEndedIterator for SequenceIndexIter {
        #[inline]
        fn next_back(&mut self) -> Option<Self::Item> {
            match self {
                Self::Int(value) => value.take(),
                Self::PosRange {
                    lowest,
                    step,
                    indices,
                } => indices.next_back().map(|idx| *lowest + idx * *step),
                Self::NegRange {
                    highest,
                    step,
                    indices,
                } => indices.next_back().map(|idx| *highest - idx * *step),
            }
        }
    }

    impl ExactSizeIterator for SequenceIndexIter {}

    // Both underlying sources (`Option::take` and `Range`) keep returning `None` once exhausted.
    impl ::std::iter::FusedIterator for SequenceIndexIter {}

    /// Adaptor over [`SequenceIndexIter`] whose forward direction always runs from the highest
    /// index to the lowest, whichever way the wrapped iterator runs.
    #[derive(Clone, Debug)]
    pub struct Descending(pub SequenceIndexIter);

    impl Iterator for Descending {
        type Item = usize;

        #[inline]
        fn next(&mut self) -> Option<Self::Item> {
            match self.0 {
                // These already yield their highest index first.
                SequenceIndexIter::Int(_) | SequenceIndexIter::NegRange { .. } => self.0.next(),
                // A positive range climbs, so it has to be consumed from the back.
                SequenceIndexIter::PosRange { .. } => self.0.next_back(),
            }
        }

        #[inline]
        fn size_hint(&self) -> (usize, Option<usize>) {
            self.0.size_hint()
        }
    }

    impl DoubleEndedIterator for Descending {
        #[inline]
        fn next_back(&mut self) -> Option<Self::Item> {
            match self.0 {
                SequenceIndexIter::Int(_) | SequenceIndexIter::NegRange { .. } => {
                    self.0.next_back()
                }
                SequenceIndexIter::PosRange { .. } => self.0.next(),
            }
        }
    }

    impl ExactSizeIterator for Descending {}

    impl ::std::iter::FusedIterator for Descending {}
}
#[cfg(test)]
mod test {
    use super::*;

    /// Parametrisations shared by the iterator tests.  Each entry pairs an indexer with the
    /// values that a normal forward iteration over it is expected to produce.
    fn index_iterator_cases() -> impl Iterator<Item = (SequenceIndex, Vec<usize>)> {
        let forward = |start, stop, step| SequenceIndex::PosRange { start, stop, step };
        let backward = |start, stop, step| SequenceIndex::NegRange { start, stop, step };
        vec![
            (SequenceIndex::Int(3), vec![3]),
            (forward(0, 5, 2), vec![0, 2, 4]),
            (forward(2, 10, 1), vec![2, 3, 4, 5, 6, 7, 8, 9]),
            (forward(1, 15, 3), vec![1, 4, 7, 10, 13]),
            (backward(Some(3), None, 1), vec![3, 2, 1, 0]),
            (backward(Some(3), None, 2), vec![3, 1]),
            (backward(Some(2), Some(0), 1), vec![2, 1]),
            (backward(Some(2), Some(0), 2), vec![2]),
            (backward(Some(2), Some(0), 3), vec![2]),
            (backward(Some(10), Some(2), 3), vec![10, 7, 4]),
            (backward(None, None, 1), vec![]),
            (backward(None, None, 3), vec![]),
        ]
        .into_iter()
    }

    /// Test that the index iterator's implementation of `ExactSizeIterator` is correct.
    #[test]
    fn index_iterator() {
        for (index, forwards) in index_iterator_cases() {
            // Check both the yielded values and the `size_hint` reported before every single
            // call to `next`, including the final one that returns `None`.
            let mut iter = index.iter();
            let mut actual = Vec::new();
            let mut sizes = vec![iter.size_hint().0];
            while let Some(value) = iter.next() {
                actual.push(value);
                sizes.push(iter.size_hint().0);
            }
            assert_eq!(
                actual, forwards,
                "values for {:?}\nActual : {:?}\nExpected: {:?}",
                index, actual, forwards,
            );
            let expected_sizes = (0..=forwards.len()).rev().collect::<Vec<_>>();
            assert_eq!(
                sizes, expected_sizes,
                "sizes for {:?}\nActual : {:?}\nExpected: {:?}",
                index, sizes, expected_sizes,
            );
        }
    }

    /// Test that the index iterator's implementation of `DoubleEndedIterator` is correct.
    #[test]
    fn reversed_index_iterator() {
        for (index, mut expected) in index_iterator_cases() {
            let actual = index.iter().rev().collect::<Vec<_>>();
            expected.reverse();
            assert_eq!(
                actual, expected,
                "reversed {:?}\nActual : {:?}\nExpected: {:?}",
                index, actual, expected,
            );
        }
    }

    /// Test that `descending` produces its values in reverse-sorted order.
    #[test]
    fn descending() {
        for (index, mut expected) in index_iterator_cases() {
            let actual = index.descending().collect::<Vec<_>>();
            expected.sort_unstable_by(|left, right| right.cmp(left));
            assert_eq!(
                actual, expected,
                "descending {:?}\nActual : {:?}\nExpected: {:?}",
                index, actual, expected,
            );
        }
    }
}