Rollup merge of #83608 - Kimundi:index_many, r=Mark-Simulacrum

Add slice methods for indexing via an array of indices.

Disclaimer: It's been a while since I contributed to the main Rust repo, apologies in advance if this is large enough already that it should've been an RFC.

---

# Update:

- Based on feedback, removed the `&[T]` variant of this API, and removed the requirements for the indices to be sorted.

# Description

This adds the following slice methods to `core`:

```rust
impl<T> [T] {
    pub unsafe fn get_many_unchecked_mut<const N: usize>(&mut self, indices: [usize; N]) -> [&mut T; N];
    pub fn get_many_mut<const N: usize>(&mut self, indices: [usize; N]) -> Option<[&mut T; N]>;
}
```

This allows creating multiple mutable references to disjoint positions in a slice, which previously required awkward code using `split_at_mut()` or `iter_mut()`. For the bounds-checked variant, each index is checked against the bounds of the slice and against every other index, which requires `N * (N + 1) / 2` comparison operations (`N` bounds checks plus `N * (N - 1) / 2` pairwise checks).
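For context, here is a minimal sketch of the kind of `split_at_mut()` workaround referred to above, written against stable Rust. The helper name `two_mut` is hypothetical and only serves as an illustration; it is not part of this PR.

```rust
/// Hypothetical helper (not part of this PR): returns mutable references to
/// two distinct, in-bounds positions of `slice`, or `None` otherwise.
fn two_mut<T>(slice: &mut [T], i: usize, j: usize) -> Option<(&mut T, &mut T)> {
    if i == j || i >= slice.len() || j >= slice.len() {
        return None;
    }
    if i < j {
        // `i` falls in the left half; `j` is the first element of the right half.
        let (left, right) = slice.split_at_mut(j);
        Some((&mut left[i], &mut right[0]))
    } else {
        // Mirror case: `j` falls in the left half; `i` starts the right half.
        let (left, right) = slice.split_at_mut(i);
        Some((&mut right[0], &mut left[j]))
    }
}

fn main() {
    let mut v = [1, 2, 3, 4];
    if let Some((a, b)) = two_mut(&mut v, 0, 2) {
        std::mem::swap(a, b);
    }
    assert_eq!(v, [3, 2, 1, 4]);
    assert!(two_mut(&mut v, 1, 1).is_none());
    assert!(two_mut(&mut v, 0, 4).is_none());
}
```

This already needs case analysis for two indices, and it does not generalize cleanly to `N` indices in arbitrary order.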

This has a proof-of-concept standalone implementation here: https://crates.io/crates/index_many

Care has been taken to ensure that the implementation passes Miri's borrow checks and generates straightforward assembly (though this was only checked on x86_64).
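As a sanity check on the `N * (N + 1) / 2` figure quoted in this section, here is a minimal standalone sketch. It assumes the check compares each index once against the slice length and once against every earlier index. The function name `check_valid_counted` and the comparison counter are illustrative additions, not part of the proposed API.

```rust
/// Illustrative sketch: validates `indices` against `len` and against each
/// other, and also counts the comparisons performed.
/// Returns `(valid, comparisons)`.
fn check_valid_counted<const N: usize>(indices: &[usize; N], len: usize) -> (bool, usize) {
    let mut valid = true;
    let mut comparisons = 0;
    for (i, &idx) in indices.iter().enumerate() {
        // One bounds comparison per index.
        valid &= idx < len;
        comparisons += 1;
        // One inequality comparison against each earlier index.
        for &earlier in &indices[..i] {
            valid &= idx != earlier;
            comparisons += 1;
        }
    }
    (valid, comparisons)
}

fn main() {
    // N = 2 -> 3 comparisons, N = 3 -> 6, N = 4 -> 10.
    assert_eq!(check_valid_counted(&[0, 2], 4), (true, 3));
    assert_eq!(check_valid_counted(&[0, 4, 2], 5), (true, 6));
    assert_eq!(check_valid_counted(&[1, 3, 3, 4], 5), (false, 10));
}
```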

# Example

```rust
let v = &mut [1, 2, 3, 4];
let [a, b] = v.get_many_mut([0, 2]).unwrap();
std::mem::swap(a, b);
*b += 100;
assert_eq!(v, &[3, 2, 101, 4]);
```

# Codegen Examples

<details>
  <summary>Click to expand!</summary>

Disclaimer: Taken from local tests with the standalone implementation.

## Unchecked Indexing:

```rust
pub unsafe fn example_unchecked(slice: &mut [usize], indices: [usize; 3]) -> [&mut usize; 3] {
    slice.get_many_unchecked_mut(indices)
}
```

```nasm
example_unchecked:
 mov     rcx, qword ptr [r9]
 mov     r8, qword ptr [r9 + 8]
 mov     r9, qword ptr [r9 + 16]
 lea     rcx, [rdx + 8*rcx]
 lea     r8, [rdx + 8*r8]
 lea     rdx, [rdx + 8*r9]
 mov     qword ptr [rax], rcx
 mov     qword ptr [rax + 8], r8
 mov     qword ptr [rax + 16], rdx
 ret
```

## Checked Indexing (Option):

```rust
pub unsafe fn example_option(slice: &mut [usize], indices: [usize; 3]) -> Option<[&mut usize; 3]> {
    slice.get_many_mut(indices)
}
```

```nasm
example_option:
 mov     r10, qword ptr [r9 + 8]
 mov     rcx, qword ptr [r9 + 16]
 cmp     rcx, r10
 je      .LBB0_7
 mov     r9, qword ptr [r9]
 cmp     rcx, r9
 je      .LBB0_7
 cmp     rcx, r8
 jae     .LBB0_7
 cmp     r10, r9
 je      .LBB0_7
 cmp     r9, r8
 jae     .LBB0_7
 cmp     r10, r8
 jae     .LBB0_7
 lea     r8, [rdx + 8*r9]
 lea     r9, [rdx + 8*r10]
 lea     rcx, [rdx + 8*rcx]
 mov     qword ptr [rax], r8
 mov     qword ptr [rax + 8], r9
 mov     qword ptr [rax + 16], rcx
 ret
.LBB0_7:
 mov     qword ptr [rax], 0
 ret
```

## Checked Indexing (Panic):

```rust
pub fn example_panic(slice: &mut [usize], indices: [usize; 3]) -> [&mut usize; 3] {
    let len = slice.len();
    match slice.get_many_mut(indices) {
        Some(s) => s,
        None => {
            let tmp = indices;
            index_many::sorted_bound_check_failed(&tmp, len)
        }
    }
}
```

```nasm
example_panic:
 sub     rsp, 56
 mov     rax, qword ptr [r9]
 mov     r10, qword ptr [r9 + 8]
 mov     r9, qword ptr [r9 + 16]
 cmp     r9, r10
 je      .LBB0_6
 cmp     r9, rax
 je      .LBB0_6
 cmp     r9, r8
 jae     .LBB0_6
 cmp     r10, rax
 je      .LBB0_6
 cmp     rax, r8
 jae     .LBB0_6
 cmp     r10, r8
 jae     .LBB0_6
 lea     rax, [rdx + 8*rax]
 lea     r8, [rdx + 8*r10]
 lea     rdx, [rdx + 8*r9]
 mov     qword ptr [rcx], rax
 mov     qword ptr [rcx + 8], r8
 mov     qword ptr [rcx + 16], rdx
 mov     rax, rcx
 add     rsp, 56
 ret
.LBB0_6:
 mov     qword ptr [rsp + 32], rax
 mov     qword ptr [rsp + 40], r10
 mov     qword ptr [rsp + 48], r9
 lea     rcx, [rsp + 32]
 mov     edx, 3
 call    index_many::bound_check_failed
 ud2
```
</details>

# Extensions

There are multiple optional extensions to this.

## Indexing With Ranges

This could easily be expanded to allow indexing with `[I; N]` where `I: SliceIndex<Self>`. I wanted to keep the initial implementation simple, so I didn't include it yet.
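To give a rough idea of what range-based indexing could offer, here is a minimal sketch for the special case of two `Range<usize>` values, written with safe code on stable Rust. The free function `two_ranges_mut` is hypothetical; it is neither the proposed `[I; N]` API nor how such an extension would necessarily be implemented.

```rust
use std::ops::Range;

/// Hypothetical helper: returns mutable subslices for two in-bounds,
/// non-overlapping ranges, or `None` otherwise. Ranges that touch or
/// intersect in any other way are conservatively rejected.
fn two_ranges_mut<T>(
    slice: &mut [T],
    a: Range<usize>,
    b: Range<usize>,
) -> Option<(&mut [T], &mut [T])> {
    let len = slice.len();
    // Reject inverted or out-of-bounds ranges.
    if a.start > a.end || a.end > len || b.start > b.end || b.end > len {
        return None;
    }
    if a.end <= b.start {
        // `a` lies entirely before `b`.
        let (left, right) = slice.split_at_mut(b.start);
        Some((&mut left[a], &mut right[..b.end - b.start]))
    } else if b.end <= a.start {
        // `b` lies entirely before `a`.
        let (left, right) = slice.split_at_mut(a.start);
        Some((&mut right[..a.end - a.start], &mut left[b]))
    } else {
        None
    }
}

fn main() {
    let mut v = [1, 2, 3, 4, 5, 6];
    let (head, tail) = two_ranges_mut(&mut v, 0..2, 4..6).unwrap();
    head.swap_with_slice(tail);
    assert_eq!(v, [5, 6, 3, 4, 1, 2]);
    assert!(two_ranges_mut(&mut v, 0..3, 2..4).is_none());
}
```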

## Panicking Variant

We could also add this method:

```rust
impl<T> [T] {
    fn index_many_mut<const N: usize>(&mut self, indices: [usize; N]) -> [&mut T; N];
}
```

This would work similarly to the regular index operator and panic on out-of-bounds indices. The advantage is that we could more easily ensure good codegen together with a useful panic message, which is non-trivial with the `Option` variant.

This is implemented in the standalone crate and serves as the basis for the panicking codegen example above.
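A minimal sketch of how such a panicking variant might behave, written as a free function so that it compiles on stable Rust. This is an illustration under stated assumptions and not the standalone crate's implementation: the panic messages are made up, and panicking on duplicate indices (in addition to out-of-bounds ones) is an assumption needed to keep the returned references from aliasing.

```rust
/// Hypothetical stand-in for the proposed method, as a free function.
/// Panics if an index is out of bounds or (assumption) appears twice.
fn index_many_mut<T, const N: usize>(slice: &mut [T], indices: [usize; N]) -> [&mut T; N] {
    let len = slice.len();
    for (i, &idx) in indices.iter().enumerate() {
        assert!(idx < len, "index {idx} out of range for slice of length {len}");
        assert!(
            !indices[..i].contains(&idx),
            "index {idx} was passed more than once"
        );
    }
    let ptr = slice.as_mut_ptr();
    // SAFETY: every index was checked above to be in bounds and distinct
    // from all earlier ones, so the references created here never alias.
    std::array::from_fn(|i| unsafe { &mut *ptr.add(indices[i]) })
}

fn main() {
    let mut v = [1, 2, 3, 4];
    let [a, b] = index_many_mut(&mut v, [0, 2]);
    std::mem::swap(a, b);
    *b += 100;
    assert_eq!(v, [3, 2, 101, 4]);
}
```

Because the failing index is known at the point of the check, the panic message can name it directly, which is the kind of diagnostic that is harder to recover from a bare `None`.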
Manish Goregaokar 2022-11-22 01:26:05 -05:00 committed by GitHub
commit 1dd515f273
5 changed files with 201 additions and 0 deletions

@@ -506,3 +506,6 @@ impl Error for crate::ffi::FromBytesWithNulError {
#[unstable(feature = "cstr_from_bytes_until_nul", issue = "95027")]
impl Error for crate::ffi::FromBytesUntilNulError {}
#[unstable(feature = "get_many_mut", issue = "104642")]
impl<const N: usize> Error for crate::slice::GetManyMutError<N> {}

@@ -7,6 +7,7 @@
#![stable(feature = "rust1", since = "1.0.0")]
use crate::cmp::Ordering::{self, Greater, Less};
use crate::fmt;
use crate::intrinsics::{assert_unsafe_precondition, exact_div};
use crate::marker::Copy;
use crate::mem::{self, SizedTypeProperties};
@@ -4082,6 +4083,88 @@ impl<T> [T] {
        *self = rem;
        Some(last)
    }

    /// Returns mutable references to many indices at once, without doing any checks.
    ///
    /// For a safe alternative see [`get_many_mut`].
    ///
    /// # Safety
    ///
    /// Calling this method with overlapping or out-of-bounds indices is *[undefined behavior]*
    /// even if the resulting references are not used.
    ///
    /// # Examples
    ///
    /// ```
    /// #![feature(get_many_mut)]
    ///
    /// let x = &mut [1, 2, 4];
    ///
    /// unsafe {
    ///     let [a, b] = x.get_many_unchecked_mut([0, 2]);
    ///     *a *= 10;
    ///     *b *= 100;
    /// }
    /// assert_eq!(x, &[10, 2, 400]);
    /// ```
    ///
    /// [`get_many_mut`]: slice::get_many_mut
    /// [undefined behavior]: https://doc.rust-lang.org/reference/behavior-considered-undefined.html
    #[unstable(feature = "get_many_mut", issue = "104642")]
    #[inline]
    pub unsafe fn get_many_unchecked_mut<const N: usize>(
        &mut self,
        indices: [usize; N],
    ) -> [&mut T; N] {
        // NB: This implementation is written as it is because any variation of
        // `indices.map(|i| self.get_unchecked_mut(i))` would make miri unhappy,
        // or generate worse code otherwise. This is also why we need to go
        // through a raw pointer here.
        let slice: *mut [T] = self;
        let mut arr: mem::MaybeUninit<[&mut T; N]> = mem::MaybeUninit::uninit();
        let arr_ptr = arr.as_mut_ptr();

        // SAFETY: We expect `indices` to contain disjoint values that are
        // in bounds of `self`.
        unsafe {
            for i in 0..N {
                let idx = *indices.get_unchecked(i);
                *(*arr_ptr).get_unchecked_mut(i) = &mut *slice.get_unchecked_mut(idx);
            }
            arr.assume_init()
        }
    }
    /// Returns mutable references to many indices at once.
    ///
    /// Returns an error if any index is out-of-bounds, or if the same index was
    /// passed more than once.
    ///
    /// # Examples
    ///
    /// ```
    /// #![feature(get_many_mut)]
    ///
    /// let v = &mut [1, 2, 3];
    /// if let Ok([a, b]) = v.get_many_mut([0, 2]) {
    ///     *a = 413;
    ///     *b = 612;
    /// }
    /// assert_eq!(v, &[413, 2, 612]);
    /// ```
    #[unstable(feature = "get_many_mut", issue = "104642")]
    #[inline]
    pub fn get_many_mut<const N: usize>(
        &mut self,
        indices: [usize; N],
    ) -> Result<[&mut T; N], GetManyMutError<N>> {
        if !get_many_check_valid(&indices, self.len()) {
            return Err(GetManyMutError { _private: () });
        }
        // SAFETY: The `get_many_check_valid()` call checked that all indices
        // are disjoint and in bounds.
        unsafe { Ok(self.get_many_unchecked_mut(indices)) }
    }
}
impl<T, const N: usize> [[T; N]] {
@@ -4304,3 +4387,56 @@ impl<T, const N: usize> SlicePattern for [T; N] {
        self
    }
}

/// This checks every index against each other, and against `len`.
///
/// This will do `binomial(N + 1, 2) = N * (N + 1) / 2 = 0, 1, 3, 6, 10, ..`
/// comparison operations.
fn get_many_check_valid<const N: usize>(indices: &[usize; N], len: usize) -> bool {
    // NB: The optimizer should inline the loops into a sequence
    // of instructions without additional branching.
    let mut valid = true;
    for (i, &idx) in indices.iter().enumerate() {
        valid &= idx < len;
        for &idx2 in &indices[..i] {
            valid &= idx != idx2;
        }
    }
    valid
}
/// The error type returned by [`get_many_mut<N>`][`slice::get_many_mut`].
///
/// It indicates one of two possible errors:
/// - An index is out-of-bounds.
/// - The same index appeared multiple times in the array.
///
/// # Examples
///
/// ```
/// #![feature(get_many_mut)]
///
/// let v = &mut [1, 2, 3];
/// assert!(v.get_many_mut([0, 999]).is_err());
/// assert!(v.get_many_mut([1, 1]).is_err());
/// ```
#[unstable(feature = "get_many_mut", issue = "104642")]
// NB: The N here is there to be forward-compatible with adding more details
// to the error type at a later point
pub struct GetManyMutError<const N: usize> {
    _private: (),
}

#[unstable(feature = "get_many_mut", issue = "104642")]
impl<const N: usize> fmt::Debug for GetManyMutError<N> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("GetManyMutError").finish_non_exhaustive()
    }
}

#[unstable(feature = "get_many_mut", issue = "104642")]
impl<const N: usize> fmt::Display for GetManyMutError<N> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt("an index is out of bounds or appeared multiple times in the array", f)
    }
}

@@ -108,6 +108,7 @@
#![feature(provide_any)]
#![feature(utf8_chunks)]
#![feature(is_ascii_octdigit)]
#![feature(get_many_mut)]
#![deny(unsafe_op_in_unsafe_fn)]
#![deny(fuzzy_provenance_casts)]

@@ -2595,3 +2595,63 @@ fn test_flatten_mut_size_overflow() {
    let x = &mut [[(); usize::MAX]; 2][..];
    let _ = x.flatten_mut();
}

#[test]
fn test_get_many_mut_normal_2() {
    let mut v = vec![1, 2, 3, 4, 5];
    let [a, b] = v.get_many_mut([3, 0]).unwrap();
    *a += 10;
    *b += 100;
    assert_eq!(v, vec![101, 2, 3, 14, 5]);
}

#[test]
fn test_get_many_mut_normal_3() {
    let mut v = vec![1, 2, 3, 4, 5];
    let [a, b, c] = v.get_many_mut([0, 4, 2]).unwrap();
    *a += 10;
    *b += 100;
    *c += 1000;
    assert_eq!(v, vec![11, 2, 1003, 4, 105]);
}

#[test]
fn test_get_many_mut_empty() {
    let mut v = vec![1, 2, 3, 4, 5];
    let [] = v.get_many_mut([]).unwrap();
    assert_eq!(v, vec![1, 2, 3, 4, 5]);
}

#[test]
fn test_get_many_mut_single_first() {
    let mut v = vec![1, 2, 3, 4, 5];
    let [a] = v.get_many_mut([0]).unwrap();
    *a += 10;
    assert_eq!(v, vec![11, 2, 3, 4, 5]);
}

#[test]
fn test_get_many_mut_single_last() {
    let mut v = vec![1, 2, 3, 4, 5];
    let [a] = v.get_many_mut([4]).unwrap();
    *a += 10;
    assert_eq!(v, vec![1, 2, 3, 4, 15]);
}

#[test]
fn test_get_many_mut_oob_nonempty() {
    let mut v = vec![1, 2, 3, 4, 5];
    assert!(v.get_many_mut([5]).is_err());
}

#[test]
fn test_get_many_mut_oob_empty() {
    let mut v: Vec<i32> = vec![];
    assert!(v.get_many_mut([0]).is_err());
}

#[test]
fn test_get_many_mut_duplicate() {
    let mut v = vec![1, 2, 3, 4, 5];
    assert!(v.get_many_mut([1, 3, 3, 4]).is_err());
}

@@ -347,6 +347,7 @@
#![feature(stdsimd)]
#![feature(test)]
#![feature(trace_macros)]
#![feature(get_many_mut)]
//
// Only used in tests/benchmarks:
//