Rollup merge of #118569 - blyxxyz:platform-os-str-slice, r=Mark-Simulacrum

Move `OsStr::slice_encoded_bytes` validation to platform modules

This delegates OS string slicing (`OsStr::slice_encoded_bytes`) validation to the underlying platform implementation. For now that results in increased performance and better error messages on Windows without any changes to semantics. In the future we may want to provide different semantics for different platforms.

The existing implementation is still used on Unix and most other platforms and is now optimized a little better.

Tracking issue: https://github.com/rust-lang/rust/issues/118485

cc `@epage,` `@BurntSushi`
This commit is contained in:
Matthias Krüger 2024-02-18 18:54:32 +01:00 committed by GitHub
commit 99560a428a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 219 additions and 47 deletions

View File

@ -127,6 +127,11 @@
//! trait, which provides a [`from_wide`] method to convert a native Windows
//! string (without the terminating nul character) to an [`OsString`].
//!
//! ## Other platforms
//!
//! Many other platforms provide their own extension traits in a
//! `std::os::*::ffi` module.
//!
//! ## On all platforms
//!
//! On all platforms, [`OsStr`] consists of a sequence of bytes that is encoded as a superset of
@ -135,6 +140,8 @@
//! For limited, inexpensive conversions from and to bytes, see [`OsStr::as_encoded_bytes`] and
//! [`OsStr::from_encoded_bytes_unchecked`].
//!
//! For basic string processing, see [`OsStr::slice_encoded_bytes`].
//!
//! [Unicode scalar value]: https://www.unicode.org/glossary/#unicode_scalar_value
//! [Unicode code point]: https://www.unicode.org/glossary/#code_point
//! [`env::set_var()`]: crate::env::set_var "env::set_var"

View File

@ -9,7 +9,7 @@ use crate::hash::{Hash, Hasher};
use crate::ops::{self, Range};
use crate::rc::Rc;
use crate::slice;
use crate::str::{from_utf8 as str_from_utf8, FromStr};
use crate::str::FromStr;
use crate::sync::Arc;
use crate::sys::os_str::{Buf, Slice};
@ -997,42 +997,15 @@ impl OsStr {
/// ```
#[unstable(feature = "os_str_slice", issue = "118485")]
pub fn slice_encoded_bytes<R: ops::RangeBounds<usize>>(&self, range: R) -> &Self {
#[track_caller]
fn check_valid_boundary(bytes: &[u8], index: usize) {
if index == 0 || index == bytes.len() {
return;
}
// Fast path
if bytes[index - 1].is_ascii() || bytes[index].is_ascii() {
return;
}
let (before, after) = bytes.split_at(index);
// UTF-8 takes at most 4 bytes per codepoint, so we don't
// need to check more than that.
let after = after.get(..4).unwrap_or(after);
match str_from_utf8(after) {
Ok(_) => return,
Err(err) if err.valid_up_to() != 0 => return,
Err(_) => (),
}
for len in 2..=4.min(index) {
let before = &before[index - len..];
if str_from_utf8(before).is_ok() {
return;
}
}
panic!("byte index {index} is not an OsStr boundary");
}
let encoded_bytes = self.as_encoded_bytes();
let Range { start, end } = slice::range(range, ..encoded_bytes.len());
check_valid_boundary(encoded_bytes, start);
check_valid_boundary(encoded_bytes, end);
// `check_public_boundary` should panic if the index does not lie on an
// `OsStr` boundary as described above. It's possible to do this in an
// encoding-agnostic way, but details of the internal encoding might
// permit a more efficient implementation.
self.inner.check_public_boundary(start);
self.inner.check_public_boundary(end);
// SAFETY: `slice::range` ensures that `start` and `end` are valid
let slice = unsafe { encoded_bytes.get_unchecked(start..end) };

View File

@ -194,15 +194,65 @@ fn slice_encoded_bytes() {
}
#[test]
#[should_panic(expected = "byte index 2 is not an OsStr boundary")]
#[should_panic]
fn slice_out_of_bounds() {
let crab = OsStr::new("🦀");
let _ = crab.slice_encoded_bytes(..5);
}
#[test]
#[should_panic]
fn slice_mid_char() {
let crab = OsStr::new("🦀");
let _ = crab.slice_encoded_bytes(..2);
}
#[cfg(unix)]
#[test]
#[should_panic(expected = "byte index 1 is not an OsStr boundary")]
fn slice_invalid_data() {
use crate::os::unix::ffi::OsStrExt;
let os_string = OsStr::from_bytes(b"\xFF\xFF");
let _ = os_string.slice_encoded_bytes(1..);
}
#[cfg(unix)]
#[test]
#[should_panic(expected = "byte index 1 is not an OsStr boundary")]
fn slice_partial_utf8() {
use crate::os::unix::ffi::{OsStrExt, OsStringExt};
let part_crab = OsStr::from_bytes(&"🦀".as_bytes()[..3]);
let mut os_string = OsString::from_vec(vec![0xFF]);
os_string.push(part_crab);
let _ = os_string.slice_encoded_bytes(1..);
}
#[cfg(unix)]
#[test]
fn slice_invalid_edge() {
use crate::os::unix::ffi::{OsStrExt, OsStringExt};
let os_string = OsStr::from_bytes(b"a\xFFa");
assert_eq!(os_string.slice_encoded_bytes(..1), "a");
assert_eq!(os_string.slice_encoded_bytes(1..), OsStr::from_bytes(b"\xFFa"));
assert_eq!(os_string.slice_encoded_bytes(..2), OsStr::from_bytes(b"a\xFF"));
assert_eq!(os_string.slice_encoded_bytes(2..), "a");
let os_string = OsStr::from_bytes(&"abc🦀".as_bytes()[..6]);
assert_eq!(os_string.slice_encoded_bytes(..3), "abc");
assert_eq!(os_string.slice_encoded_bytes(3..), OsStr::from_bytes(b"\xF0\x9F\xA6"));
let mut os_string = OsString::from_vec(vec![0xFF]);
os_string.push("🦀");
assert_eq!(os_string.slice_encoded_bytes(..1), OsStr::from_bytes(b"\xFF"));
assert_eq!(os_string.slice_encoded_bytes(1..), "🦀");
}
#[cfg(windows)]
#[test]
#[should_panic(expected = "byte index 3 is not an OsStr boundary")]
#[should_panic(expected = "byte index 3 lies between surrogate codepoints")]
fn slice_between_surrogates() {
use crate::os::windows::ffi::OsStringExt;
@ -216,10 +266,14 @@ fn slice_between_surrogates() {
fn slice_surrogate_edge() {
use crate::os::windows::ffi::OsStringExt;
let os_string = OsString::from_wide(&[0xD800]);
let mut with_crab = os_string.clone();
with_crab.push("🦀");
let surrogate = OsString::from_wide(&[0xD800]);
let mut pre_crab = surrogate.clone();
pre_crab.push("🦀");
assert_eq!(pre_crab.slice_encoded_bytes(..3), surrogate);
assert_eq!(pre_crab.slice_encoded_bytes(3..), "🦀");
assert_eq!(with_crab.slice_encoded_bytes(..3), os_string);
assert_eq!(with_crab.slice_encoded_bytes(3..), "🦀");
let mut post_crab = OsString::from("🦀");
post_crab.push(&surrogate);
assert_eq!(post_crab.slice_encoded_bytes(..4), "🦀");
assert_eq!(post_crab.slice_encoded_bytes(4..), surrogate);
}

View File

@ -211,6 +211,49 @@ impl Slice {
unsafe { mem::transmute(s) }
}
#[track_caller]
#[inline]
pub fn check_public_boundary(&self, index: usize) {
if index == 0 || index == self.inner.len() {
return;
}
if index < self.inner.len()
&& (self.inner[index - 1].is_ascii() || self.inner[index].is_ascii())
{
return;
}
slow_path(&self.inner, index);
/// We're betting that typical splits will involve an ASCII character.
///
/// Putting the expensive checks in a separate function generates notably
/// better assembly.
#[track_caller]
#[inline(never)]
fn slow_path(bytes: &[u8], index: usize) {
let (before, after) = bytes.split_at(index);
// UTF-8 takes at most 4 bytes per codepoint, so we don't
// need to check more than that.
let after = after.get(..4).unwrap_or(after);
match str::from_utf8(after) {
Ok(_) => return,
Err(err) if err.valid_up_to() != 0 => return,
Err(_) => (),
}
for len in 2..=4.min(index) {
let before = &before[index - len..];
if str::from_utf8(before).is_ok() {
return;
}
}
panic!("byte index {index} is not an OsStr boundary");
}
}
#[inline]
pub fn from_str(s: &str) -> &Slice {
unsafe { Slice::from_encoded_bytes_unchecked(s.as_bytes()) }

View File

@ -6,7 +6,7 @@ use crate::fmt;
use crate::mem;
use crate::rc::Rc;
use crate::sync::Arc;
use crate::sys_common::wtf8::{Wtf8, Wtf8Buf};
use crate::sys_common::wtf8::{check_utf8_boundary, Wtf8, Wtf8Buf};
use crate::sys_common::{AsInner, FromInner, IntoInner};
#[derive(Clone, Hash)]
@ -171,6 +171,11 @@ impl Slice {
mem::transmute(Wtf8::from_bytes_unchecked(s))
}
#[track_caller]
pub fn check_public_boundary(&self, index: usize) {
check_utf8_boundary(&self.inner, index);
}
#[inline]
pub fn from_str(s: &str) -> &Slice {
unsafe { mem::transmute(Wtf8::from_str(s)) }

View File

@ -885,15 +885,43 @@ fn decode_surrogate_pair(lead: u16, trail: u16) -> char {
unsafe { char::from_u32_unchecked(code_point) }
}
/// Copied from core::str::StrPrelude::is_char_boundary
/// Copied from str::is_char_boundary
#[inline]
pub fn is_code_point_boundary(slice: &Wtf8, index: usize) -> bool {
if index == slice.len() {
if index == 0 {
return true;
}
match slice.bytes.get(index) {
None => false,
Some(&b) => b < 128 || b >= 192,
None => index == slice.len(),
Some(&b) => (b as i8) >= -0x40,
}
}
/// Verify that `index` is at the edge of either a valid UTF-8 codepoint
/// (i.e. a codepoint that's not a surrogate) or of the whole string.
///
/// These are the cases currently permitted by `OsStr::slice_encoded_bytes`.
/// Splitting between surrogates is valid as far as WTF-8 is concerned, but
/// we do not permit it in the public API because WTF-8 is considered an
/// implementation detail.
#[track_caller]
#[inline]
pub fn check_utf8_boundary(slice: &Wtf8, index: usize) {
if index == 0 {
return;
}
match slice.bytes.get(index) {
Some(0xED) => (), // Might be a surrogate
Some(&b) if (b as i8) >= -0x40 => return,
Some(_) => panic!("byte index {index} is not a codepoint boundary"),
None if index == slice.len() => return,
None => panic!("byte index {index} is out of bounds"),
}
if slice.bytes[index + 1] >= 0xA0 {
// There's a surrogate after index. Now check before index.
if index >= 3 && slice.bytes[index - 3] == 0xED && slice.bytes[index - 2] >= 0xA0 {
panic!("byte index {index} lies between surrogate codepoints");
}
}
}

View File

@ -663,3 +663,65 @@ fn wtf8_to_owned() {
assert_eq!(string.bytes, b"\xED\xA0\x80");
assert!(!string.is_known_utf8);
}
#[test]
fn wtf8_valid_utf8_boundaries() {
let mut string = Wtf8Buf::from_str("aé 💩");
string.push(CodePoint::from_u32(0xD800).unwrap());
string.push(CodePoint::from_u32(0xD800).unwrap());
check_utf8_boundary(&string, 0);
check_utf8_boundary(&string, 1);
check_utf8_boundary(&string, 3);
check_utf8_boundary(&string, 4);
check_utf8_boundary(&string, 8);
check_utf8_boundary(&string, 14);
assert_eq!(string.len(), 14);
string.push_char('a');
check_utf8_boundary(&string, 14);
check_utf8_boundary(&string, 15);
let mut string = Wtf8Buf::from_str("a");
string.push(CodePoint::from_u32(0xD800).unwrap());
check_utf8_boundary(&string, 1);
let mut string = Wtf8Buf::from_str("\u{D7FF}");
string.push(CodePoint::from_u32(0xD800).unwrap());
check_utf8_boundary(&string, 3);
let mut string = Wtf8Buf::new();
string.push(CodePoint::from_u32(0xD800).unwrap());
string.push_char('\u{D7FF}');
check_utf8_boundary(&string, 3);
}
#[test]
#[should_panic(expected = "byte index 4 is out of bounds")]
fn wtf8_utf8_boundary_out_of_bounds() {
let string = Wtf8::from_str("");
check_utf8_boundary(&string, 4);
}
#[test]
#[should_panic(expected = "byte index 1 is not a codepoint boundary")]
fn wtf8_utf8_boundary_inside_codepoint() {
let string = Wtf8::from_str("é");
check_utf8_boundary(&string, 1);
}
#[test]
#[should_panic(expected = "byte index 1 is not a codepoint boundary")]
fn wtf8_utf8_boundary_inside_surrogate() {
let mut string = Wtf8Buf::new();
string.push(CodePoint::from_u32(0xD800).unwrap());
check_utf8_boundary(&string, 1);
}
#[test]
#[should_panic(expected = "byte index 3 lies between surrogate codepoints")]
fn wtf8_utf8_boundary_between_surrogates() {
let mut string = Wtf8Buf::new();
string.push(CodePoint::from_u32(0xD800).unwrap());
string.push(CodePoint::from_u32(0xD800).unwrap());
check_utf8_boundary(&string, 3);
}