Skip to content

Commit 7408a98

Browse files
Merge pull request #92 from Kijewski/pr-escape-at-compile-time
derive: escape strings at compile-time when possible
2 parents ab485df + 7ea3484 commit 7408a98

File tree

7 files changed

+200
-74
lines changed

7 files changed

+200
-74
lines changed

rinja/src/filters/escape.rs

+6-62
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
use std::convert::Infallible;
22
use std::fmt::{self, Display, Formatter, Write};
3-
use std::num::NonZeroU8;
43
use std::{borrow, str};
54

65
/// Marks a string (or other `Display` type) as safe
@@ -83,69 +82,14 @@ pub fn e(text: impl fmt::Display, escaper: impl Escaper) -> Result<Safe<impl Dis
8382
pub struct Html;
8483

8584
impl Escaper for Html {
86-
fn write_escaped_str<W: Write>(&self, mut fmt: W, string: &str) -> fmt::Result {
87-
let mut escaped_buf = *b"&#__;";
88-
let mut last = 0;
89-
90-
for (index, byte) in string.bytes().enumerate() {
91-
const MIN_CHAR: u8 = b'"';
92-
const MAX_CHAR: u8 = b'>';
93-
94-
struct Table {
95-
_align: [usize; 0],
96-
lookup: [Option<[NonZeroU8; 2]>; (MAX_CHAR - MIN_CHAR + 1) as usize],
97-
}
98-
99-
const TABLE: Table = {
100-
const fn n(c: u8) -> Option<[NonZeroU8; 2]> {
101-
let n0 = match NonZeroU8::new(c / 10 + b'0') {
102-
Some(n) => n,
103-
None => panic!(),
104-
};
105-
let n1 = match NonZeroU8::new(c % 10 + b'0') {
106-
Some(n) => n,
107-
None => panic!(),
108-
};
109-
Some([n0, n1])
110-
}
111-
112-
let mut table = Table {
113-
_align: [],
114-
lookup: [None; (MAX_CHAR - MIN_CHAR + 1) as usize],
115-
};
116-
117-
table.lookup[(b'"' - MIN_CHAR) as usize] = n(b'"');
118-
table.lookup[(b'&' - MIN_CHAR) as usize] = n(b'&');
119-
table.lookup[(b'\'' - MIN_CHAR) as usize] = n(b'\'');
120-
table.lookup[(b'<' - MIN_CHAR) as usize] = n(b'<');
121-
table.lookup[(b'>' - MIN_CHAR) as usize] = n(b'>');
122-
table
123-
};
124-
125-
let escaped = match byte {
126-
MIN_CHAR..=MAX_CHAR => TABLE.lookup[(byte - MIN_CHAR) as usize],
127-
_ => None,
128-
};
129-
if let Some(escaped) = escaped {
130-
escaped_buf[2] = escaped[0].get();
131-
escaped_buf[3] = escaped[1].get();
132-
fmt.write_str(&string[last..index])?;
133-
fmt.write_str(unsafe { std::str::from_utf8_unchecked(escaped_buf.as_slice()) })?;
134-
last = index + 1;
135-
}
136-
}
137-
fmt.write_str(&string[last..])
85+
#[inline]
86+
fn write_escaped_str<W: Write>(&self, fmt: W, string: &str) -> fmt::Result {
87+
crate::html::write_escaped_str(fmt, string)
13888
}
13989

140-
fn write_escaped_char<W: Write>(&self, mut fmt: W, c: char) -> fmt::Result {
141-
fmt.write_str(match (c.is_ascii(), c as u8) {
142-
(true, b'"') => "&#34;",
143-
(true, b'&') => "&#38;",
144-
(true, b'\'') => "&#39;",
145-
(true, b'<') => "&#60;",
146-
(true, b'>') => "&#62;",
147-
_ => return fmt.write_char(c),
148-
})
90+
#[inline]
91+
fn write_escaped_char<W: Write>(&self, fmt: W, c: char) -> fmt::Result {
92+
crate::html::write_escaped_char(fmt, c)
14993
}
15094
}
15195

rinja/src/html.rs

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
use std::fmt;
2+
use std::num::NonZeroU8;
3+
4+
#[allow(unused)]
5+
pub(crate) fn write_escaped_str(mut fmt: impl fmt::Write, string: &str) -> fmt::Result {
6+
let mut escaped_buf = *b"&#__;";
7+
let mut last = 0;
8+
9+
for (index, byte) in string.bytes().enumerate() {
10+
let escaped = match byte {
11+
MIN_CHAR..=MAX_CHAR => TABLE.lookup[(byte - MIN_CHAR) as usize],
12+
_ => None,
13+
};
14+
if let Some(escaped) = escaped {
15+
escaped_buf[2] = escaped[0].get();
16+
escaped_buf[3] = escaped[1].get();
17+
fmt.write_str(&string[last..index])?;
18+
fmt.write_str(unsafe { std::str::from_utf8_unchecked(escaped_buf.as_slice()) })?;
19+
last = index + 1;
20+
}
21+
}
22+
fmt.write_str(&string[last..])
23+
}
24+
25+
#[allow(unused)]
26+
pub(crate) fn write_escaped_char(mut fmt: impl fmt::Write, c: char) -> fmt::Result {
27+
fmt.write_str(match (c.is_ascii(), c as u8) {
28+
(true, b'"') => "&#34;",
29+
(true, b'&') => "&#38;",
30+
(true, b'\'') => "&#39;",
31+
(true, b'<') => "&#60;",
32+
(true, b'>') => "&#62;",
33+
_ => return fmt.write_char(c),
34+
})
35+
}
36+
37+
const MIN_CHAR: u8 = b'"';
38+
const MAX_CHAR: u8 = b'>';
39+
40+
struct Table {
41+
_align: [usize; 0],
42+
lookup: [Option<[NonZeroU8; 2]>; (MAX_CHAR - MIN_CHAR + 1) as usize],
43+
}
44+
45+
const TABLE: Table = {
46+
const fn n(c: u8) -> Option<[NonZeroU8; 2]> {
47+
assert!(MIN_CHAR <= c && c <= MAX_CHAR);
48+
49+
let n0 = match NonZeroU8::new(c / 10 + b'0') {
50+
Some(n) => n,
51+
None => panic!(),
52+
};
53+
let n1 = match NonZeroU8::new(c % 10 + b'0') {
54+
Some(n) => n,
55+
None => panic!(),
56+
};
57+
Some([n0, n1])
58+
}
59+
60+
let mut table = Table {
61+
_align: [],
62+
lookup: [None; (MAX_CHAR - MIN_CHAR + 1) as usize],
63+
};
64+
65+
table.lookup[(b'"' - MIN_CHAR) as usize] = n(b'"');
66+
table.lookup[(b'&' - MIN_CHAR) as usize] = n(b'&');
67+
table.lookup[(b'\'' - MIN_CHAR) as usize] = n(b'\'');
68+
table.lookup[(b'<' - MIN_CHAR) as usize] = n(b'<');
69+
table.lookup[(b'>' - MIN_CHAR) as usize] = n(b'>');
70+
table
71+
};

rinja/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
mod error;
5858
pub mod filters;
5959
pub mod helpers;
60+
mod html;
6061

6162
use std::{fmt, io};
6263

rinja_derive/src/generator.rs

+85-12
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use quote::quote;
1414

1515
use crate::config::WhitespaceHandling;
1616
use crate::heritage::{Context, Heritage};
17+
use crate::html::write_escaped_str;
1718
use crate::input::{Source, TemplateInput};
1819
use crate::{CompileError, MsgValidEscapers, CRATE};
1920

@@ -1162,8 +1163,76 @@ impl<'a> Generator<'a> {
11621163
}
11631164

11641165
fn write_expr(&mut self, ws: Ws, s: &'a WithSpan<'a, Expr<'a>>) {
1166+
// In here, we inspect in the expression if it is a literal, and if it is, whether it
1167+
// can be escaped at compile time. We use an IIFE to make the code more readable
1168+
// (immediate returns, try expressions).
1169+
let writable = (|| -> Option<Writable<'a>> {
1170+
enum InputKind<'a> {
1171+
StrLit(&'a str),
1172+
CharLit(&'a str),
1173+
}
1174+
enum OutputKind {
1175+
Html,
1176+
Text,
1177+
}
1178+
1179+
// for now, we only escape strings and chars at compile time
1180+
let lit = match &**s {
1181+
Expr::StrLit(input) => InputKind::StrLit(input),
1182+
Expr::CharLit(input) => InputKind::CharLit(input),
1183+
_ => return None,
1184+
};
1185+
1186+
// we only optimize for known escapers
1187+
let output = match self.input.escaper.strip_prefix(CRATE)? {
1188+
"::filters::Html" => OutputKind::Html,
1189+
"::filters::Text" => OutputKind::Text,
1190+
_ => return None,
1191+
};
1192+
1193+
// the input could be string escaped if it contains any backslashes
1194+
let escaped = match lit {
1195+
InputKind::StrLit(s) => s,
1196+
InputKind::CharLit(s) => s,
1197+
};
1198+
let unescaped = if escaped.find('\\').is_none() {
1199+
// if the literal does not contain any backslashes, then it does not need unescaping
1200+
Cow::Borrowed(escaped)
1201+
} else {
1202+
// convert the input into a TokenStream and extract the first token
1203+
Cow::Owned(match lit {
1204+
InputKind::StrLit(escaped) => {
1205+
let input = format!(r#""{escaped}""#);
1206+
let input = input.parse().ok()?;
1207+
let input = syn::parse2::<syn::LitStr>(input).ok()?;
1208+
input.value()
1209+
}
1210+
InputKind::CharLit(escaped) => {
1211+
let input = format!(r#"'{escaped}'"#);
1212+
let input = input.parse().ok()?;
1213+
let input = syn::parse2::<syn::LitChar>(input).ok()?;
1214+
input.value().to_string()
1215+
}
1216+
})
1217+
};
1218+
1219+
// escape the un-string-escaped input using the selected escaper
1220+
Some(Writable::Lit(match output {
1221+
OutputKind::Text => unescaped,
1222+
OutputKind::Html => {
1223+
let mut escaped = String::with_capacity(unescaped.len() + 20);
1224+
write_escaped_str(&mut escaped, &unescaped).ok()?;
1225+
match escaped == unescaped {
1226+
true => unescaped,
1227+
false => Cow::Owned(escaped),
1228+
}
1229+
}
1230+
}))
1231+
})()
1232+
.unwrap_or(Writable::Expr(s));
1233+
11651234
self.handle_ws(ws);
1166-
self.buf_writable.push(Writable::Expr(s));
1235+
self.buf_writable.push(writable);
11671236
}
11681237

11691238
// Write expression buffer and empty
@@ -1174,7 +1243,7 @@ impl<'a> Generator<'a> {
11741243
) -> Result<usize, CompileError> {
11751244
let mut size_hint = 0;
11761245
let items = mem::take(&mut self.buf_writable.buf);
1177-
let mut it = items.into_iter().enumerate().peekable();
1246+
let mut it = items.iter().enumerate().peekable();
11781247

11791248
while let Some((_, Writable::Lit(s))) = it.peek() {
11801249
size_hint += buf.write_writer(s);
@@ -1267,20 +1336,23 @@ impl<'a> Generator<'a> {
12671336
assert!(rws.is_empty());
12681337
self.next_ws = Some(lws);
12691338
}
1270-
WhitespaceHandling::Preserve => self.buf_writable.push(Writable::Lit(lws)),
1339+
WhitespaceHandling::Preserve => {
1340+
self.buf_writable.push(Writable::Lit(Cow::Borrowed(lws)))
1341+
}
12711342
WhitespaceHandling::Minimize => {
1272-
self.buf_writable
1273-
.push(Writable::Lit(match lws.contains('\n') {
1343+
self.buf_writable.push(Writable::Lit(Cow::Borrowed(
1344+
match lws.contains('\n') {
12741345
true => "\n",
12751346
false => " ",
1276-
}));
1347+
},
1348+
)));
12771349
}
12781350
}
12791351
}
12801352

12811353
if !val.is_empty() {
12821354
self.skip_ws = WhitespaceHandling::Preserve;
1283-
self.buf_writable.push(Writable::Lit(val));
1355+
self.buf_writable.push(Writable::Lit(Cow::Borrowed(val)));
12841356
}
12851357

12861358
if !rws.is_empty() {
@@ -2031,17 +2103,18 @@ impl<'a> Generator<'a> {
20312103
WhitespaceHandling::Preserve => {
20322104
let val = self.next_ws.unwrap();
20332105
if !val.is_empty() {
2034-
self.buf_writable.push(Writable::Lit(val));
2106+
self.buf_writable.push(Writable::Lit(Cow::Borrowed(val)));
20352107
}
20362108
}
20372109
WhitespaceHandling::Minimize => {
20382110
let val = self.next_ws.unwrap();
20392111
if !val.is_empty() {
2040-
self.buf_writable
2041-
.push(Writable::Lit(match val.contains('\n') {
2112+
self.buf_writable.push(Writable::Lit(Cow::Borrowed(
2113+
match val.contains('\n') {
20422114
true => "\n",
20432115
false => " ",
2044-
}));
2116+
},
2117+
)));
20452118
}
20462119
}
20472120
WhitespaceHandling::Suppress => {}
@@ -2481,7 +2554,7 @@ impl<'a> Deref for WritableBuffer<'a> {
24812554

24822555
#[derive(Debug)]
24832556
enum Writable<'a> {
2484-
Lit(&'a str),
2557+
Lit(Cow<'a, str>),
24852558
Expr(&'a WithSpan<'a, Expr<'a>>),
24862559
}
24872560

rinja_derive/src/html.rs

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
../../rinja/src/html.rs

rinja_derive/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
mod config;
55
mod generator;
66
mod heritage;
7+
mod html;
78
mod input;
89
#[cfg(test)]
910
mod tests;

rinja_derive/src/tests.rs

+35
Original file line numberDiff line numberDiff line change
@@ -426,3 +426,38 @@ fn check_bool_conditions() {
426426
3,
427427
);
428428
}
429+
430+
#[test]
431+
fn check_escaping_at_compile_time() {
432+
compare(
433+
r#"The card is
434+
{%- match suit %}
435+
{%- when Suit::Clubs or Suit::Spades -%}
436+
{{ " black" }}
437+
{%- when Suit::Diamonds or Suit::Hearts -%}
438+
{{ " red" }}
439+
{%- endmatch %}"#,
440+
r#"writer.write_str("The card is")?;
441+
match &self.suit {
442+
Suit::Clubs | Suit::Spades => {
443+
writer.write_str(" black")?;
444+
}
445+
Suit::Diamonds | Suit::Hearts => {
446+
writer.write_str(" red")?;
447+
}
448+
}"#,
449+
&[("suit", "Suit")],
450+
16,
451+
);
452+
453+
compare(
454+
r#"{{ '\x41' }}{{ '\n' }}{{ '\r' }}{{ '\t' }}{{ '\\' }}{{ '\u{2665}' }}{{ '\'' }}{{ '\"' }}{{ '"' }}
455+
{{ "\x41\n\r\t\\\u{2665}\'\"'" }}"#,
456+
r#"writer.write_str("A
457+
\r \\♥'\"\"
458+
A
459+
\r \\♥'\"'")?;"#,
460+
&[],
461+
23,
462+
);
463+
}

0 commit comments

Comments
 (0)