Auto merge of #3662 - RalfJung:simd-bitmask, r=RalfJung

simd_bitmask: work correctly for sizes like 24
This commit is contained in:
bors 2024-06-09 09:48:26 +00:00
commit de822dc602
4 changed files with 130 additions and 51 deletions

View File

@ -374,7 +374,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
let val = if dest.layout().abi.is_signed() {
Scalar::from_int(i, dest.layout().size)
} else {
Scalar::from_uint(u64::try_from(i.into()).unwrap(), dest.layout().size)
// `unwrap` can only fail here if `i` is negative
Scalar::from_uint(u128::try_from(i.into()).unwrap(), dest.layout().size)
};
self.eval_context_mut().write_scalar(val, dest)
}

View File

@ -1,3 +1,5 @@
#![warn(clippy::arithmetic_side_effects)]
mod atomic;
mod simd;

View File

@ -458,26 +458,48 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
);
}
// The mask must be an integer or an array.
assert!(
mask.layout.ty.is_integral()
|| matches!(mask.layout.ty.kind(), ty::Array(elemty, _) if elemty == &this.tcx.types.u8)
);
assert_eq!(bitmask_len, mask.layout.size.bits());
assert_eq!(dest_len, yes_len);
assert_eq!(dest_len, no_len);
// Read the mask, either as an integer or as an array.
let mask: u64 = match mask.layout.ty.kind() {
ty::Uint(_) => {
// Any larger integer type is fine.
assert!(mask.layout.size.bits() >= bitmask_len);
this.read_scalar(mask)?.to_bits(mask.layout.size)?.try_into().unwrap()
}
ty::Array(elem, _len) if elem == &this.tcx.types.u8 => {
// The array must have exactly the right size.
assert_eq!(mask.layout.size.bits(), bitmask_len);
// Read the raw bytes.
let mask = mask.assert_mem_place(); // arrays cannot be immediate
let mask_bytes =
this.read_bytes_ptr_strip_provenance(mask.ptr(), mask.layout.size)?;
// Turn them into a `u64` in the right way.
let mask_size = mask.layout.size.bytes_usize();
let mut mask_arr = [0u8; 8];
match this.data_layout().endian {
Endian::Little => {
// Fill the first N bytes.
mask_arr[..mask_size].copy_from_slice(mask_bytes);
u64::from_le_bytes(mask_arr)
}
Endian::Big => {
// Fill the last N bytes.
let i = mask_arr.len().strict_sub(mask_size);
mask_arr[i..].copy_from_slice(mask_bytes);
u64::from_be_bytes(mask_arr)
}
}
}
_ => bug!("simd_select_bitmask: invalid mask type {}", mask.layout.ty),
};
let dest_len = u32::try_from(dest_len).unwrap();
let bitmask_len = u32::try_from(bitmask_len).unwrap();
// To read the mask, we transmute it to an integer.
// That does the right thing wrt endianness.
let mask_ty = this.machine.layouts.uint(mask.layout.size).unwrap();
let mask = mask.transmute(mask_ty, this)?;
let mask: u64 = this.read_scalar(&mask)?.to_bits(mask_ty.size)?.try_into().unwrap();
for i in 0..dest_len {
let bit_i = simd_bitmask_index(i, dest_len, this.data_layout().endian);
let mask = mask & 1u64.checked_shl(bit_i).unwrap();
let mask = mask & 1u64.strict_shl(bit_i);
let yes = this.read_immediate(&this.project_index(&yes, i.into())?)?;
let no = this.read_immediate(&this.project_index(&no, i.into())?)?;
let dest = this.project_index(&dest, i.into())?;
@ -489,7 +511,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
// If the mask is "padded", ensure that padding is all-zero.
// This deliberately does not use `simd_bitmask_index`; these bits are outside
// the bitmask. It does not matter in which order we check them.
let mask = mask & 1u64.checked_shl(i).unwrap();
let mask = mask & 1u64.strict_shl(i);
if mask != 0 {
throw_ub_format!(
"a SIMD bitmask less than 8 bits long must be filled with 0s for the remaining bits"
@ -508,28 +530,43 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
);
}
// Returns either an unsigned integer or array of `u8`.
assert!(
dest.layout.ty.is_integral()
|| matches!(dest.layout.ty.kind(), ty::Array(elemty, _) if elemty == &this.tcx.types.u8)
);
assert_eq!(bitmask_len, dest.layout.size.bits());
let op_len = u32::try_from(op_len).unwrap();
let mut res = 0u64;
for i in 0..op_len {
let op = this.read_immediate(&this.project_index(&op, i.into())?)?;
if simd_element_to_bool(op)? {
res |= 1u64
.checked_shl(simd_bitmask_index(i, op_len, this.data_layout().endian))
.unwrap();
let bit_i = simd_bitmask_index(i, op_len, this.data_layout().endian);
res |= 1u64.strict_shl(bit_i);
}
}
// We have to change the type of the place to be able to write `res` into it. This
// transmutes the integer to an array, which does the right thing wrt endianness.
let dest =
dest.transmute(this.machine.layouts.uint(dest.layout.size).unwrap(), this)?;
this.write_int(res, &dest)?;
// Write the result, depending on the `dest` type.
// Returns either an unsigned integer or array of `u8`.
match dest.layout.ty.kind() {
ty::Uint(_) => {
// Any larger integer type is fine, it will be zero-extended.
assert!(dest.layout.size.bits() >= bitmask_len);
this.write_int(res, dest)?;
}
ty::Array(elem, _len) if elem == &this.tcx.types.u8 => {
// The array must have exactly the right size.
assert_eq!(dest.layout.size.bits(), bitmask_len);
// We have to write the result byte-for-byte.
let res_size = dest.layout.size.bytes_usize();
let res_bytes;
let res_bytes_slice = match this.data_layout().endian {
Endian::Little => {
res_bytes = res.to_le_bytes();
&res_bytes[..res_size] // take the first N bytes
}
Endian::Big => {
res_bytes = res.to_be_bytes();
&res_bytes[res_bytes.len().strict_sub(res_size)..] // take the last N bytes
}
};
this.write_bytes_ptr(dest.ptr(), res_bytes_slice.iter().cloned())?;
}
_ => bug!("simd_bitmask: invalid return type {}", dest.layout.ty),
}
}
"cast" | "as" | "cast_ptr" | "expose_provenance" | "with_exposed_provenance" => {
let [op] = check_arg_count(args)?;
@ -615,8 +652,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
let val = if src_index < left_len {
this.read_immediate(&this.project_index(&left, src_index)?)?
} else if src_index < left_len.checked_add(right_len).unwrap() {
let right_idx = src_index.checked_sub(left_len).unwrap();
} else if src_index < left_len.strict_add(right_len) {
let right_idx = src_index.strict_sub(left_len);
this.read_immediate(&this.project_index(&right, right_idx)?)?
} else {
throw_ub_format!(
@ -655,8 +692,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
let val = if src_index < left_len {
this.read_immediate(&this.project_index(&left, src_index)?)?
} else if src_index < left_len.checked_add(right_len).unwrap() {
let right_idx = src_index.checked_sub(left_len).unwrap();
} else if src_index < left_len.strict_add(right_len) {
let right_idx = src_index.strict_sub(left_len);
this.read_immediate(&this.project_index(&right, right_idx)?)?
} else {
throw_ub_format!(

View File

@ -323,38 +323,77 @@ fn simd_mask() {
#[repr(simd, packed)]
#[allow(non_camel_case_types)]
#[derive(Copy, Clone, Debug, PartialEq)]
struct i32x10(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32);
struct i32x10([i32; 10]);
impl i32x10 {
fn splat(x: i32) -> Self {
Self(x, x, x, x, x, x, x, x, x, x)
}
fn from_array(a: [i32; 10]) -> Self {
unsafe { std::mem::transmute(a) }
Self([x; 10])
}
}
unsafe {
let mask = i32x10::from_array([!0, !0, 0, !0, 0, 0, !0, 0, !0, 0]);
let mask = i32x10([!0, !0, 0, !0, 0, 0, !0, 0, !0, 0]);
let mask_bits = if cfg!(target_endian = "little") { 0b0101001011 } else { 0b1101001010 };
let mask_bytes =
if cfg!(target_endian = "little") { [0b01001011, 0b01] } else { [0b11, 0b01001010] };
let bitmask1: u16 = simd_bitmask(mask);
let bitmask2: [u8; 2] = simd_bitmask(mask);
if cfg!(target_endian = "little") {
assert_eq!(bitmask1, 0b0101001011);
assert_eq!(bitmask2, [0b01001011, 0b01]);
} else {
assert_eq!(bitmask1, 0b1101001010);
assert_eq!(bitmask2, [0b11, 0b01001010]);
}
assert_eq!(bitmask1, mask_bits);
assert_eq!(bitmask2, mask_bytes);
let selected1 = simd_select_bitmask::<u16, _>(
if cfg!(target_endian = "little") { 0b0101001011 } else { 0b1101001010 },
mask_bits,
i32x10::splat(!0), // yes
i32x10::splat(0), // no
);
let selected2 = simd_select_bitmask::<[u8; 2], _>(
if cfg!(target_endian = "little") { [0b01001011, 0b01] } else { [0b11, 0b01001010] },
mask_bytes,
i32x10::splat(!0), // yes
i32x10::splat(0), // no
);
assert_eq!(selected1, mask);
assert_eq!(selected2, selected1);
assert_eq!(selected2, mask);
}
// Test for a mask where the next multiple of 8 is not a power of two.
// Twenty-lane `i32` vector used by the checks below. Twenty mask bits occupy
// three bytes (24 bits), i.e. a byte-padded width that is not a power of two.
#[repr(simd, packed)]
#[allow(non_camel_case_types)]
#[derive(Copy, Clone, Debug, PartialEq)]
struct i32x20([i32; 20]);
impl i32x20 {
// Builds a vector in which every one of the 20 lanes holds `x`.
fn splat(x: i32) -> Self {
Self([x; 20])
}
}
// Round trip for the 20-lane vector: vector -> bitmask (as an integer and as a
// byte array) -> vector again via bitmask-driven select.
// NOTE(review): `unsafe` is presumably needed because the SIMD intrinsics are
// unsafe to call — confirm against their declarations.
unsafe {
// Each lane is either all-ones (`!0`, "true") or zero ("false").
let mask = i32x20([!0, !0, 0, !0, 0, 0, !0, 0, !0, 0, 0, 0, 0, !0, !0, !0, !0, !0, !0, !0]);
// Expected integer form of the mask. On little-endian targets lane 0 is bit 0;
// on big-endian targets lane 0 is the most significant of the 20 used bits.
let mask_bits = if cfg!(target_endian = "little") {
0b11111110000101001011
} else {
0b11010010100001111111
};
// Expected byte-array form: the same 20 bits laid out in 3 bytes. Four bits are
// padding and must be zero: the high nibble of the last byte on little-endian,
// the high nibble of the first byte on big-endian.
let mask_bytes = if cfg!(target_endian = "little") {
[0b01001011, 0b11100001, 0b1111]
} else {
[0b1101, 0b00101000, 0b01111111]
};
// `u32` is wider than the 20 bits required; the comparison below expects the
// unused upper bits to be zero.
let bitmask1: u32 = simd_bitmask(mask);
// `[u8; 3]` is exactly the 24-bit padded size.
let bitmask2: [u8; 3] = simd_bitmask(mask);
assert_eq!(bitmask1, mask_bits);
assert_eq!(bitmask2, mask_bytes);
// Selecting between all-ones ("yes") and all-zeros ("no") with the bitmask is
// expected to reproduce the original vector, for both mask representations.
let selected1 = simd_select_bitmask::<u32, _>(
mask_bits,
i32x20::splat(!0), // yes
i32x20::splat(0), // no
);
let selected2 = simd_select_bitmask::<[u8; 3], _>(
mask_bytes,
i32x20::splat(!0), // yes
i32x20::splat(0), // no
);
assert_eq!(selected1, mask);
assert_eq!(selected2, mask);
}
}