Make internal mask implementation safe

This commit is contained in:
Caleb Zulawski 2022-02-09 04:54:05 +00:00 committed by Jubilee Young
parent 11c3eefa35
commit 20fa4b7623
3 changed files with 75 additions and 30 deletions

View File

@ -1,7 +1,7 @@
#![allow(unused_imports)]
use super::MaskElement;
use crate::simd::intrinsics;
use crate::simd::{LaneCount, Simd, SupportedLaneCount};
use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask};
use core::marker::PhantomData;
/// A mask where each lane is represented by a single bit.
@ -116,13 +116,20 @@ where
}
#[inline]
pub unsafe fn to_bitmask_integer<U>(self) -> U {
pub fn to_bitmask_integer<U>(self) -> U
where
super::Mask<T, LANES>: ToBitMask<BitMask = U>,
{
// Safety: these are the same types
unsafe { core::mem::transmute_copy(&self.0) }
}
// Safety: U must be the integer with the exact number of bits required to hold the bitmask for
#[inline]
pub unsafe fn from_bitmask_integer<U>(bitmask: U) -> Self {
pub fn from_bitmask_integer<U>(bitmask: U) -> Self
where
super::Mask<T, LANES>: ToBitMask<BitMask = U>,
{
// Safety: these are the same types
unsafe { Self(core::mem::transmute_copy(&bitmask), PhantomData) }
}

View File

@ -2,7 +2,7 @@
use super::MaskElement;
use crate::simd::intrinsics;
use crate::simd::{LaneCount, Simd, SupportedLaneCount};
use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask};
#[repr(transparent)]
pub struct Mask<T, const LANES: usize>(Simd<T, LANES>)
@ -66,6 +66,23 @@ where
}
}
// Used for bitmask bit order workaround
pub(crate) trait ReverseBits {
fn reverse_bits(self) -> Self;
}
macro_rules! impl_reverse_bits {
{ $($int:ty),* } => {
$(
impl ReverseBits for $int {
fn reverse_bits(self) -> Self { <$int>::reverse_bits(self) }
}
)*
}
}
impl_reverse_bits! { u8, u16, u32, u64 }
impl<T, const LANES: usize> Mask<T, LANES>
where
T: MaskElement,
@ -110,16 +127,34 @@ where
}
#[inline]
pub unsafe fn to_bitmask_integer<U>(self) -> U {
// Safety: caller must only return bitmask types
unsafe { intrinsics::simd_bitmask(self.0) }
pub(crate) fn to_bitmask_integer<U: ReverseBits>(self) -> U
where
super::Mask<T, LANES>: ToBitMask<BitMask = U>,
{
// Safety: U is required to be the appropriate bitmask type
let bitmask: U = unsafe { intrinsics::simd_bitmask(self.0) };
// LLVM assumes bit order should match endianness
if cfg!(target_endian = "big") {
bitmask.reverse_bits()
} else {
bitmask
}
}
// Safety: U must be the integer with the exact number of bits required to hold the bitmask for
// this mask
#[inline]
pub unsafe fn from_bitmask_integer<U>(bitmask: U) -> Self {
// Safety: caller must only pass bitmask types
pub(crate) fn from_bitmask_integer<U: ReverseBits>(bitmask: U) -> Self
where
super::Mask<T, LANES>: ToBitMask<BitMask = U>,
{
// LLVM assumes bit order should match endianness
let bitmask = if cfg!(target_endian = "big") {
bitmask.reverse_bits()
} else {
bitmask
};
// Safety: U is required to be the appropriate bitmask type
unsafe {
Self::from_int_unchecked(intrinsics::simd_select_bitmask(
bitmask,

View File

@ -1,9 +1,26 @@
use super::{mask_impl, Mask, MaskElement};
use crate::simd::{LaneCount, SupportedLaneCount};
mod sealed {
pub trait Sealed {}
}
pub use sealed::Sealed;
impl<T, const LANES: usize> Sealed for Mask<T, LANES>
where
T: MaskElement,
LaneCount<LANES>: SupportedLaneCount,
{
}
/// Converts masks to and from integer bitmasks.
///
/// Each bit of the bitmask corresponds to a mask lane, starting with the LSB.
pub trait ToBitMask {
///
/// # Safety
/// This trait is `unsafe` and sealed, since the `BitMask` type must match the number of lanes in
/// the mask.
pub unsafe trait ToBitMask: Sealed {
/// The integer bitmask type.
type BitMask;
@ -14,32 +31,18 @@ pub trait ToBitMask {
fn from_bitmask(bitmask: Self::BitMask) -> Self;
}
/// Converts masks to and from byte array bitmasks.
///
/// Each bit of the bitmask corresponds to a mask lane, starting with the LSB of the first byte.
pub trait ToBitMaskArray {
/// The length of the bitmask array.
const BYTES: usize;
/// Converts a mask to a bitmask.
fn to_bitmask_array(self) -> [u8; Self::BYTES];
/// Converts a bitmask to a mask.
fn from_bitmask_array(bitmask: [u8; Self::BYTES]) -> Self;
}
macro_rules! impl_integer_intrinsic {
{ $(unsafe impl ToBitMask<BitMask=$int:ty> for Mask<_, $lanes:literal>)* } => {
$(
impl<T: MaskElement> ToBitMask for Mask<T, $lanes> {
unsafe impl<T: MaskElement> ToBitMask for Mask<T, $lanes> {
type BitMask = $int;
fn to_bitmask(self) -> $int {
unsafe { self.0.to_bitmask_integer() }
self.0.to_bitmask_integer()
}
fn from_bitmask(bitmask: $int) -> Self {
unsafe { Self(mask_impl::Mask::from_bitmask_integer(bitmask)) }
Self(mask_impl::Mask::from_bitmask_integer(bitmask))
}
}
)*