diff --git a/rinja/src/filters/escape.rs b/rinja/src/filters/escape.rs index 710709adc..14ad83c65 100644 --- a/rinja/src/filters/escape.rs +++ b/rinja/src/filters/escape.rs @@ -1,6 +1,5 @@ use std::convert::Infallible; use std::fmt::{self, Display, Formatter, Write}; -use std::num::NonZeroU8; use std::{borrow, str}; /// Marks a string (or other `Display` type) as safe @@ -83,69 +82,14 @@ pub fn e(text: impl fmt::Display, escaper: impl Escaper) -> Result(&self, mut fmt: W, string: &str) -> fmt::Result { - let mut escaped_buf = *b"&#__;"; - let mut last = 0; - - for (index, byte) in string.bytes().enumerate() { - const MIN_CHAR: u8 = b'"'; - const MAX_CHAR: u8 = b'>'; - - struct Table { - _align: [usize; 0], - lookup: [Option<[NonZeroU8; 2]>; (MAX_CHAR - MIN_CHAR + 1) as usize], - } - - const TABLE: Table = { - const fn n(c: u8) -> Option<[NonZeroU8; 2]> { - let n0 = match NonZeroU8::new(c / 10 + b'0') { - Some(n) => n, - None => panic!(), - }; - let n1 = match NonZeroU8::new(c % 10 + b'0') { - Some(n) => n, - None => panic!(), - }; - Some([n0, n1]) - } - - let mut table = Table { - _align: [], - lookup: [None; (MAX_CHAR - MIN_CHAR + 1) as usize], - }; - - table.lookup[(b'"' - MIN_CHAR) as usize] = n(b'"'); - table.lookup[(b'&' - MIN_CHAR) as usize] = n(b'&'); - table.lookup[(b'\'' - MIN_CHAR) as usize] = n(b'\''); - table.lookup[(b'<' - MIN_CHAR) as usize] = n(b'<'); - table.lookup[(b'>' - MIN_CHAR) as usize] = n(b'>'); - table - }; - - let escaped = match byte { - MIN_CHAR..=MAX_CHAR => TABLE.lookup[(byte - MIN_CHAR) as usize], - _ => None, - }; - if let Some(escaped) = escaped { - escaped_buf[2] = escaped[0].get(); - escaped_buf[3] = escaped[1].get(); - fmt.write_str(&string[last..index])?; - fmt.write_str(unsafe { std::str::from_utf8_unchecked(escaped_buf.as_slice()) })?; - last = index + 1; - } - } - fmt.write_str(&string[last..]) + #[inline] + fn write_escaped_str(&self, fmt: W, string: &str) -> fmt::Result { + crate::html::write_escaped_str(fmt, string) } - fn write_escaped_char(&self, mut fmt: W, c: char) -> fmt::Result { - fmt.write_str(match (c.is_ascii(), c as u8) { - (true, b'"') => """, - (true, b'&') => "&", - (true, b'\'') => "'", - (true, b'<') => "<", - (true, b'>') => ">", - _ => return fmt.write_char(c), - }) + #[inline] + fn write_escaped_char(&self, fmt: W, c: char) -> fmt::Result { + crate::html::write_escaped_char(fmt, c) } } diff --git a/rinja/src/html.rs b/rinja/src/html.rs new file mode 100644 index 000000000..45c1d7270 --- /dev/null +++ b/rinja/src/html.rs @@ -0,0 +1,71 @@ +use std::fmt; +use std::num::NonZeroU8; + +#[allow(unused)] +pub(crate) fn write_escaped_str(mut fmt: impl fmt::Write, string: &str) -> fmt::Result { + let mut escaped_buf = *b"&#__;"; + let mut last = 0; + + for (index, byte) in string.bytes().enumerate() { + let escaped = match byte { + MIN_CHAR..=MAX_CHAR => TABLE.lookup[(byte - MIN_CHAR) as usize], + _ => None, + }; + if let Some(escaped) = escaped { + escaped_buf[2] = escaped[0].get(); + escaped_buf[3] = escaped[1].get(); + fmt.write_str(&string[last..index])?; + fmt.write_str(unsafe { std::str::from_utf8_unchecked(escaped_buf.as_slice()) })?; + last = index + 1; + } + } + fmt.write_str(&string[last..]) +} + +#[allow(unused)] +pub(crate) fn write_escaped_char(mut fmt: impl fmt::Write, c: char) -> fmt::Result { + fmt.write_str(match (c.is_ascii(), c as u8) { + (true, b'"') => """, + (true, b'&') => "&", + (true, b'\'') => "'", + (true, b'<') => "<", + (true, b'>') => ">", + _ => return fmt.write_char(c), + }) +} + +const MIN_CHAR: u8 = b'"'; +const MAX_CHAR: u8 = b'>'; + +struct Table { + _align: [usize; 0], + lookup: [Option<[NonZeroU8; 2]>; (MAX_CHAR - MIN_CHAR + 1) as usize], +} + +const TABLE: Table = { + const fn n(c: u8) -> Option<[NonZeroU8; 2]> { + assert!(MIN_CHAR <= c && c <= MAX_CHAR); + + let n0 = match NonZeroU8::new(c / 10 + b'0') { + Some(n) => n, + None => panic!(), + }; + let n1 = match NonZeroU8::new(c % 10 + b'0') { + Some(n) => n, + None => panic!(), + }; + Some([n0, n1]) + } + + let mut table = Table { + _align: [], + lookup: [None; (MAX_CHAR - MIN_CHAR + 1) as usize], + }; + + table.lookup[(b'"' - MIN_CHAR) as usize] = n(b'"'); + table.lookup[(b'&' - MIN_CHAR) as usize] = n(b'&'); + table.lookup[(b'\'' - MIN_CHAR) as usize] = n(b'\''); + table.lookup[(b'<' - MIN_CHAR) as usize] = n(b'<'); + table.lookup[(b'>' - MIN_CHAR) as usize] = n(b'>'); + table +}; diff --git a/rinja/src/lib.rs b/rinja/src/lib.rs index cc81796a2..c84212461 100644 --- a/rinja/src/lib.rs +++ b/rinja/src/lib.rs @@ -57,6 +57,7 @@ mod error; pub mod filters; pub mod helpers; +mod html; use std::{fmt, io}; diff --git a/rinja_derive/src/generator.rs b/rinja_derive/src/generator.rs index fa0ef79a1..009069b7b 100644 --- a/rinja_derive/src/generator.rs +++ b/rinja_derive/src/generator.rs @@ -14,6 +14,7 @@ use quote::quote; use crate::config::WhitespaceHandling; use crate::heritage::{Context, Heritage}; +use crate::html::write_escaped_str; use crate::input::{Source, TemplateInput}; use crate::{CompileError, MsgValidEscapers, CRATE}; @@ -1162,8 +1163,76 @@ impl<'a> Generator<'a> { } fn write_expr(&mut self, ws: Ws, s: &'a WithSpan<'a, Expr<'a>>) { + // In here, we inspect in the expression if it is a literal, and if it is, whether it + // can be escaped at compile time. We use an IIFE to make the code more readable + // (immediate returns, try expressions). + let writable = (|| -> Option> { + enum InputKind<'a> { + StrLit(&'a str), + CharLit(&'a str), + } + enum OutputKind { + Html, + Text, + } + + // for now, we only escape strings and chars at compile time + let lit = match &**s { + Expr::StrLit(input) => InputKind::StrLit(input), + Expr::CharLit(input) => InputKind::CharLit(input), + _ => return None, + }; + + // we only optimize for known escapers + let output = match self.input.escaper.strip_prefix(CRATE)? { + "::filters::Html" => OutputKind::Html, + "::filters::Text" => OutputKind::Text, + _ => return None, + }; + + // the input could be string escaped if it contains any backslashes + let escaped = match lit { + InputKind::StrLit(s) => s, + InputKind::CharLit(s) => s, + }; + let unescaped = if escaped.find('\\').is_none() { + // if the literal does not contain any backslashes, then it does not need unescaping + Cow::Borrowed(escaped) + } else { + // convert the input into a TokenStream and extract the first token + Cow::Owned(match lit { + InputKind::StrLit(escaped) => { + let input = format!(r#""{escaped}""#); + let input = input.parse().ok()?; + let input = syn::parse2::(input).ok()?; + input.value() + } + InputKind::CharLit(escaped) => { + let input = format!(r#"'{escaped}'"#); + let input = input.parse().ok()?; + let input = syn::parse2::(input).ok()?; + input.value().to_string() + } + }) + }; + + // escape the un-string-escaped input using the selected escaper + Some(Writable::Lit(match output { + OutputKind::Text => unescaped, + OutputKind::Html => { + let mut escaped = String::with_capacity(unescaped.len() + 20); + write_escaped_str(&mut escaped, &unescaped).ok()?; + match escaped == unescaped { + true => unescaped, + false => Cow::Owned(escaped), + } + } + })) + })() + .unwrap_or(Writable::Expr(s)); + self.handle_ws(ws); - self.buf_writable.push(Writable::Expr(s)); + self.buf_writable.push(writable); } // Write expression buffer and empty @@ -1174,7 +1243,7 @@ impl<'a> Generator<'a> { ) -> Result { let mut size_hint = 0; let items = mem::take(&mut self.buf_writable.buf); - let mut it = items.into_iter().enumerate().peekable(); + let mut it = items.iter().enumerate().peekable(); while let Some((_, Writable::Lit(s))) = it.peek() { size_hint += buf.write_writer(s); @@ -1267,20 +1336,23 @@ impl<'a> Generator<'a> { assert!(rws.is_empty()); self.next_ws = Some(lws); } - WhitespaceHandling::Preserve => self.buf_writable.push(Writable::Lit(lws)), + WhitespaceHandling::Preserve => { + self.buf_writable.push(Writable::Lit(Cow::Borrowed(lws))) + } WhitespaceHandling::Minimize => { - self.buf_writable - .push(Writable::Lit(match lws.contains('\n') { + self.buf_writable.push(Writable::Lit(Cow::Borrowed( + match lws.contains('\n') { true => "\n", false => " ", - })); + }, + ))); } } } if !val.is_empty() { self.skip_ws = WhitespaceHandling::Preserve; - self.buf_writable.push(Writable::Lit(val)); + self.buf_writable.push(Writable::Lit(Cow::Borrowed(val))); } if !rws.is_empty() { @@ -2031,17 +2103,18 @@ impl<'a> Generator<'a> { WhitespaceHandling::Preserve => { let val = self.next_ws.unwrap(); if !val.is_empty() { - self.buf_writable.push(Writable::Lit(val)); + self.buf_writable.push(Writable::Lit(Cow::Borrowed(val))); } } WhitespaceHandling::Minimize => { let val = self.next_ws.unwrap(); if !val.is_empty() { - self.buf_writable - .push(Writable::Lit(match val.contains('\n') { + self.buf_writable.push(Writable::Lit(Cow::Borrowed( + match val.contains('\n') { true => "\n", false => " ", - })); + }, + ))); } } WhitespaceHandling::Suppress => {} @@ -2481,7 +2554,7 @@ impl<'a> Deref for WritableBuffer<'a> { #[derive(Debug)] enum Writable<'a> { - Lit(&'a str), + Lit(Cow<'a, str>), Expr(&'a WithSpan<'a, Expr<'a>>), } diff --git a/rinja_derive/src/html.rs b/rinja_derive/src/html.rs new file mode 120000 index 000000000..a4f1066eb --- /dev/null +++ b/rinja_derive/src/html.rs @@ -0,0 +1 @@ +../../rinja/src/html.rs \ No newline at end of file diff --git a/rinja_derive/src/lib.rs b/rinja_derive/src/lib.rs index 5d3fea316..38db6dd98 100644 --- a/rinja_derive/src/lib.rs +++ b/rinja_derive/src/lib.rs @@ -4,6 +4,7 @@ mod config; mod generator; mod heritage; +mod html; mod input; #[cfg(test)] mod tests; diff --git a/rinja_derive/src/tests.rs b/rinja_derive/src/tests.rs index e03de6708..524470997 100644 --- a/rinja_derive/src/tests.rs +++ b/rinja_derive/src/tests.rs @@ -426,3 +426,38 @@ fn check_bool_conditions() { 3, ); } + +#[test] +fn check_escaping_at_compile_time() { + compare( + r#"The card is + {%- match suit %} + {%- when Suit::Clubs or Suit::Spades -%} + {{ " black" }} + {%- when Suit::Diamonds or Suit::Hearts -%} + {{ " red" }} + {%- endmatch %}"#, + r#"writer.write_str("The card is")?; + match &self.suit { + Suit::Clubs | Suit::Spades => { + writer.write_str(" black")?; + } + Suit::Diamonds | Suit::Hearts => { + writer.write_str(" red")?; + } + }"#, + &[("suit", "Suit")], + 16, + ); + + compare( + r#"{{ '\x41' }}{{ '\n' }}{{ '\r' }}{{ '\t' }}{{ '\\' }}{{ '\u{2665}' }}{{ '\'' }}{{ '\"' }}{{ '"' }} +{{ "\x41\n\r\t\\\u{2665}\'\"'" }}"#, + r#"writer.write_str("A +\r \\♥'\"\" +A +\r \\♥'\"'")?;"#, + &[], + 23, + ); +}