Skip to content

Commit 51a7396

Browse files
committed
Move OsStr::slice_encoded_bytes validation to platform modules
On Windows and UEFI this improves performance and error messaging. On other platforms we optimize the fast path a bit more. This also prepares for later relaxing the checks on certain platforms.
1 parent d9d89fd commit 51a7396

File tree

7 files changed

+219
-47
lines changed

7 files changed

+219
-47
lines changed

library/std/src/ffi/mod.rs

+7
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,11 @@
127127
//! trait, which provides a [`from_wide`] method to convert a native Windows
128128
//! string (without the terminating nul character) to an [`OsString`].
129129
//!
130+
//! ## Other platforms
131+
//!
132+
//! Many other platforms provide their own extension traits in a
133+
//! `std::os::*::ffi` module.
134+
//!
130135
//! ## On all platforms
131136
//!
132137
//! On all platforms, [`OsStr`] consists of a sequence of bytes that is encoded as a superset of
@@ -135,6 +140,8 @@
135140
//! For limited, inexpensive conversions from and to bytes, see [`OsStr::as_encoded_bytes`] and
136141
//! [`OsStr::from_encoded_bytes_unchecked`].
137142
//!
143+
//! For basic string processing, see [`OsStr::slice_encoded_bytes`].
144+
//!
138145
//! [Unicode scalar value]: https://www.unicode.org/glossary/#unicode_scalar_value
139146
//! [Unicode code point]: https://www.unicode.org/glossary/#code_point
140147
//! [`env::set_var()`]: crate::env::set_var "env::set_var"

library/std/src/ffi/os_str.rs

+8-35
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use crate::hash::{Hash, Hasher};
99
use crate::ops::{self, Range};
1010
use crate::rc::Rc;
1111
use crate::slice;
12-
use crate::str::{from_utf8 as str_from_utf8, FromStr};
12+
use crate::str::FromStr;
1313
use crate::sync::Arc;
1414

1515
use crate::sys::os_str::{Buf, Slice};
@@ -997,42 +997,15 @@ impl OsStr {
997997
/// ```
998998
#[unstable(feature = "os_str_slice", issue = "118485")]
999999
pub fn slice_encoded_bytes<R: ops::RangeBounds<usize>>(&self, range: R) -> &Self {
1000-
#[track_caller]
1001-
fn check_valid_boundary(bytes: &[u8], index: usize) {
1002-
if index == 0 || index == bytes.len() {
1003-
return;
1004-
}
1005-
1006-
// Fast path
1007-
if bytes[index - 1].is_ascii() || bytes[index].is_ascii() {
1008-
return;
1009-
}
1010-
1011-
let (before, after) = bytes.split_at(index);
1012-
1013-
// UTF-8 takes at most 4 bytes per codepoint, so we don't
1014-
// need to check more than that.
1015-
let after = after.get(..4).unwrap_or(after);
1016-
match str_from_utf8(after) {
1017-
Ok(_) => return,
1018-
Err(err) if err.valid_up_to() != 0 => return,
1019-
Err(_) => (),
1020-
}
1021-
1022-
for len in 2..=4.min(index) {
1023-
let before = &before[index - len..];
1024-
if str_from_utf8(before).is_ok() {
1025-
return;
1026-
}
1027-
}
1028-
1029-
panic!("byte index {index} is not an OsStr boundary");
1030-
}
1031-
10321000
let encoded_bytes = self.as_encoded_bytes();
10331001
let Range { start, end } = slice::range(range, ..encoded_bytes.len());
1034-
check_valid_boundary(encoded_bytes, start);
1035-
check_valid_boundary(encoded_bytes, end);
1002+
1003+
// `check_public_boundary` should panic if the index does not lie on an
1004+
// `OsStr` boundary as described above. It's possible to do this in an
1005+
// encoding-agnostic way, but details of the internal encoding might
1006+
// permit a more efficient implementation.
1007+
self.inner.check_public_boundary(start);
1008+
self.inner.check_public_boundary(end);
10361009

10371010
// SAFETY: `slice::range` ensures that `start` and `end` are valid
10381011
let slice = unsafe { encoded_bytes.get_unchecked(start..end) };

library/std/src/ffi/os_str/tests.rs

+61-7
Original file line numberDiff line numberDiff line change
@@ -194,15 +194,65 @@ fn slice_encoded_bytes() {
194194
}
195195

196196
#[test]
197-
#[should_panic(expected = "byte index 2 is not an OsStr boundary")]
197+
#[should_panic]
198+
fn slice_out_of_bounds() {
199+
let crab = OsStr::new("🦀");
200+
let _ = crab.slice_encoded_bytes(..5);
201+
}
202+
203+
#[test]
204+
#[should_panic]
198205
fn slice_mid_char() {
199206
let crab = OsStr::new("🦀");
200207
let _ = crab.slice_encoded_bytes(..2);
201208
}
202209

210+
#[cfg(unix)]
211+
#[test]
212+
#[should_panic(expected = "byte index 1 is not an OsStr boundary")]
213+
fn slice_invalid_data() {
214+
use crate::os::unix::ffi::OsStrExt;
215+
216+
let os_string = OsStr::from_bytes(b"\xFF\xFF");
217+
let _ = os_string.slice_encoded_bytes(1..);
218+
}
219+
220+
#[cfg(unix)]
221+
#[test]
222+
#[should_panic(expected = "byte index 1 is not an OsStr boundary")]
223+
fn slice_partial_utf8() {
224+
use crate::os::unix::ffi::{OsStrExt, OsStringExt};
225+
226+
let part_crab = OsStr::from_bytes(&"🦀".as_bytes()[..3]);
227+
let mut os_string = OsString::from_vec(vec![0xFF]);
228+
os_string.push(part_crab);
229+
let _ = os_string.slice_encoded_bytes(1..);
230+
}
231+
232+
#[cfg(unix)]
233+
#[test]
234+
fn slice_invalid_edge() {
235+
use crate::os::unix::ffi::{OsStrExt, OsStringExt};
236+
237+
let os_string = OsStr::from_bytes(b"a\xFFa");
238+
assert_eq!(os_string.slice_encoded_bytes(..1), "a");
239+
assert_eq!(os_string.slice_encoded_bytes(1..), OsStr::from_bytes(b"\xFFa"));
240+
assert_eq!(os_string.slice_encoded_bytes(..2), OsStr::from_bytes(b"a\xFF"));
241+
assert_eq!(os_string.slice_encoded_bytes(2..), "a");
242+
243+
let os_string = OsStr::from_bytes(&"abc🦀".as_bytes()[..6]);
244+
assert_eq!(os_string.slice_encoded_bytes(..3), "abc");
245+
assert_eq!(os_string.slice_encoded_bytes(3..), OsStr::from_bytes(b"\xF0\x9F\xA6"));
246+
247+
let mut os_string = OsString::from_vec(vec![0xFF]);
248+
os_string.push("🦀");
249+
assert_eq!(os_string.slice_encoded_bytes(..1), OsStr::from_bytes(b"\xFF"));
250+
assert_eq!(os_string.slice_encoded_bytes(1..), "🦀");
251+
}
252+
203253
#[cfg(windows)]
204254
#[test]
205-
#[should_panic(expected = "byte index 3 is not an OsStr boundary")]
255+
#[should_panic(expected = "byte index 3 lies between surrogate codepoints")]
206256
fn slice_between_surrogates() {
207257
use crate::os::windows::ffi::OsStringExt;
208258

@@ -216,10 +266,14 @@ fn slice_between_surrogates() {
216266
fn slice_surrogate_edge() {
217267
use crate::os::windows::ffi::OsStringExt;
218268

219-
let os_string = OsString::from_wide(&[0xD800]);
220-
let mut with_crab = os_string.clone();
221-
with_crab.push("🦀");
269+
let surrogate = OsString::from_wide(&[0xD800]);
270+
let mut pre_crab = surrogate.clone();
271+
pre_crab.push("🦀");
272+
assert_eq!(pre_crab.slice_encoded_bytes(..3), surrogate);
273+
assert_eq!(pre_crab.slice_encoded_bytes(3..), "🦀");
222274

223-
assert_eq!(with_crab.slice_encoded_bytes(..3), os_string);
224-
assert_eq!(with_crab.slice_encoded_bytes(3..), "🦀");
275+
let mut post_crab = OsString::from("🦀");
276+
post_crab.push(&surrogate);
277+
assert_eq!(post_crab.slice_encoded_bytes(..4), "🦀");
278+
assert_eq!(post_crab.slice_encoded_bytes(4..), surrogate);
225279
}

library/std/src/sys/os_str/bytes.rs

+43
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,49 @@ impl Slice {
211211
unsafe { mem::transmute(s) }
212212
}
213213

214+
#[track_caller]
215+
#[inline]
216+
pub fn check_public_boundary(&self, index: usize) {
217+
if index == 0 || index == self.inner.len() {
218+
return;
219+
}
220+
if index < self.inner.len()
221+
&& (self.inner[index - 1].is_ascii() || self.inner[index].is_ascii())
222+
{
223+
return;
224+
}
225+
226+
slow_path(&self.inner, index);
227+
228+
/// We're betting that typical splits will involve an ASCII character.
229+
///
230+
/// Putting the expensive checks in a separate function generates notably
231+
/// better assembly.
232+
#[track_caller]
233+
#[inline(never)]
234+
fn slow_path(bytes: &[u8], index: usize) {
235+
let (before, after) = bytes.split_at(index);
236+
237+
// UTF-8 takes at most 4 bytes per codepoint, so we don't
238+
// need to check more than that.
239+
let after = after.get(..4).unwrap_or(after);
240+
match str::from_utf8(after) {
241+
Ok(_) => return,
242+
Err(err) if err.valid_up_to() != 0 => return,
243+
Err(_) => (),
244+
}
245+
246+
for len in 2..=4.min(index) {
247+
let before = &before[index - len..];
248+
if str::from_utf8(before).is_ok() {
249+
return;
250+
}
251+
}
252+
253+
panic!("byte index {index} is not an OsStr boundary");
254+
}
255+
}
256+
214257
#[inline]
215258
pub fn from_str(s: &str) -> &Slice {
216259
unsafe { Slice::from_encoded_bytes_unchecked(s.as_bytes()) }

library/std/src/sys/os_str/wtf8.rs

+6-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use crate::fmt;
66
use crate::mem;
77
use crate::rc::Rc;
88
use crate::sync::Arc;
9-
use crate::sys_common::wtf8::{Wtf8, Wtf8Buf};
9+
use crate::sys_common::wtf8::{check_utf8_boundary, Wtf8, Wtf8Buf};
1010
use crate::sys_common::{AsInner, FromInner, IntoInner};
1111

1212
#[derive(Clone, Hash)]
@@ -171,6 +171,11 @@ impl Slice {
171171
mem::transmute(Wtf8::from_bytes_unchecked(s))
172172
}
173173

174+
#[track_caller]
175+
pub fn check_public_boundary(&self, index: usize) {
176+
check_utf8_boundary(&self.inner, index);
177+
}
178+
174179
#[inline]
175180
pub fn from_str(s: &str) -> &Slice {
176181
unsafe { mem::transmute(Wtf8::from_str(s)) }

library/std/src/sys_common/wtf8.rs

+32-4
Original file line numberDiff line numberDiff line change
@@ -885,15 +885,43 @@ fn decode_surrogate_pair(lead: u16, trail: u16) -> char {
885885
unsafe { char::from_u32_unchecked(code_point) }
886886
}
887887

888-
/// Copied from core::str::StrPrelude::is_char_boundary
888+
/// Copied from str::is_char_boundary
889889
#[inline]
890890
pub fn is_code_point_boundary(slice: &Wtf8, index: usize) -> bool {
891-
if index == slice.len() {
891+
if index == 0 {
892892
return true;
893893
}
894894
match slice.bytes.get(index) {
895-
None => false,
896-
Some(&b) => b < 128 || b >= 192,
895+
None => index == slice.len(),
896+
Some(&b) => (b as i8) >= -0x40,
897+
}
898+
}
899+
900+
/// Verify that `index` is at the edge of either a valid UTF-8 codepoint
901+
/// (i.e. a codepoint that's not a surrogate) or of the whole string.
902+
///
903+
/// These are the cases currently permitted by `OsStr::slice_encoded_bytes`.
904+
/// Splitting between surrogates is valid as far as WTF-8 is concerned, but
905+
/// we do not permit it in the public API because WTF-8 is considered an
906+
/// implementation detail.
907+
#[track_caller]
908+
#[inline]
909+
pub fn check_utf8_boundary(slice: &Wtf8, index: usize) {
910+
if index == 0 {
911+
return;
912+
}
913+
match slice.bytes.get(index) {
914+
Some(0xED) => (), // Might be a surrogate
915+
Some(&b) if (b as i8) >= -0x40 => return,
916+
Some(_) => panic!("byte index {index} is not a codepoint boundary"),
917+
None if index == slice.len() => return,
918+
None => panic!("byte index {index} is out of bounds"),
919+
}
920+
if slice.bytes[index + 1] >= 0xA0 {
921+
// There's a surrogate after index. Now check before index.
922+
if index >= 3 && slice.bytes[index - 3] == 0xED && slice.bytes[index - 2] >= 0xA0 {
923+
panic!("byte index {index} lies between surrogate codepoints");
924+
}
897925
}
898926
}
899927

library/std/src/sys_common/wtf8/tests.rs

+62
Original file line numberDiff line numberDiff line change
@@ -663,3 +663,65 @@ fn wtf8_to_owned() {
663663
assert_eq!(string.bytes, b"\xED\xA0\x80");
664664
assert!(!string.is_known_utf8);
665665
}
666+
667+
#[test]
668+
fn wtf8_valid_utf8_boundaries() {
669+
let mut string = Wtf8Buf::from_str("aé 💩");
670+
string.push(CodePoint::from_u32(0xD800).unwrap());
671+
string.push(CodePoint::from_u32(0xD800).unwrap());
672+
check_utf8_boundary(&string, 0);
673+
check_utf8_boundary(&string, 1);
674+
check_utf8_boundary(&string, 3);
675+
check_utf8_boundary(&string, 4);
676+
check_utf8_boundary(&string, 8);
677+
check_utf8_boundary(&string, 14);
678+
assert_eq!(string.len(), 14);
679+
680+
string.push_char('a');
681+
check_utf8_boundary(&string, 14);
682+
check_utf8_boundary(&string, 15);
683+
684+
let mut string = Wtf8Buf::from_str("a");
685+
string.push(CodePoint::from_u32(0xD800).unwrap());
686+
check_utf8_boundary(&string, 1);
687+
688+
let mut string = Wtf8Buf::from_str("\u{D7FF}");
689+
string.push(CodePoint::from_u32(0xD800).unwrap());
690+
check_utf8_boundary(&string, 3);
691+
692+
let mut string = Wtf8Buf::new();
693+
string.push(CodePoint::from_u32(0xD800).unwrap());
694+
string.push_char('\u{D7FF}');
695+
check_utf8_boundary(&string, 3);
696+
}
697+
698+
#[test]
699+
#[should_panic(expected = "byte index 4 is out of bounds")]
700+
fn wtf8_utf8_boundary_out_of_bounds() {
701+
let string = Wtf8::from_str("aé");
702+
check_utf8_boundary(&string, 4);
703+
}
704+
705+
#[test]
706+
#[should_panic(expected = "byte index 1 is not a codepoint boundary")]
707+
fn wtf8_utf8_boundary_inside_codepoint() {
708+
let string = Wtf8::from_str("é");
709+
check_utf8_boundary(&string, 1);
710+
}
711+
712+
#[test]
713+
#[should_panic(expected = "byte index 1 is not a codepoint boundary")]
714+
fn wtf8_utf8_boundary_inside_surrogate() {
715+
let mut string = Wtf8Buf::new();
716+
string.push(CodePoint::from_u32(0xD800).unwrap());
717+
check_utf8_boundary(&string, 1);
718+
}
719+
720+
#[test]
721+
#[should_panic(expected = "byte index 3 lies between surrogate codepoints")]
722+
fn wtf8_utf8_boundary_between_surrogates() {
723+
let mut string = Wtf8Buf::new();
724+
string.push(CodePoint::from_u32(0xD800).unwrap());
725+
string.push(CodePoint::from_u32(0xD800).unwrap());
726+
check_utf8_boundary(&string, 3);
727+
}

0 commit comments

Comments
 (0)