Skip to content

Commit bb79170

Browse files
committed
Another improvement to the SQLite bind code
Any Rust container like `Box<T>`, `Vec<T>` or `String` internally contains a `Unique<T>` pointer, which communicates to the compiler that this container is the owner of that memory location and that all access goes through that pointer. See rust-lang/unsafe-code-guidelines#194 for details. Passing out a pointer to the underlying buffer to SQLite could cause UB according to this definition, at least if someone else accesses the buffer through the original pointer. To prevent that, we temporarily leak the buffer and manage the pointer ourselves. Additionally, this change introduces a way to construct the `BoundStatement` as early as possible as part of the `BoundStatement::bind` function, so that all cleanup code can be concentrated in the corresponding `Drop` impl.
1 parent 10cf0e5 commit bb79170

File tree

1 file changed

+139
-116
lines changed
  • diesel/src/sqlite/connection

1 file changed

+139
-116
lines changed

diesel/src/sqlite/connection/stmt.rs

+139-116
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ use crate::sqlite::{Sqlite, SqliteType};
1111
use crate::util::OnceCell;
1212
use std::ffi::{CStr, CString};
1313
use std::io::{stderr, Write};
14-
use std::mem::ManuallyDrop;
1514
use std::os::raw as libc;
1615
use std::ptr::{self, NonNull};
1716

@@ -58,9 +57,10 @@ impl Statement {
5857
unsafe fn bind(
5958
&mut self,
6059
tpe: SqliteType,
61-
value: &SqliteBindValue<'_>,
60+
value: SqliteBindValue<'_>,
6261
bind_index: i32,
63-
) -> QueryResult<()> {
62+
) -> QueryResult<Option<NonNull<[u8]>>> {
63+
let mut ret_ptr = None;
6464
let result = match (tpe, value) {
6565
(_, SqliteBindValue::Null) => {
6666
ffi::sqlite3_bind_null(self.inner_statement.as_ptr(), bind_index)
@@ -72,47 +72,87 @@ impl Statement {
7272
bytes.len() as libc::c_int,
7373
ffi::SQLITE_STATIC(),
7474
),
75-
(SqliteType::Binary, SqliteBindValue::Binary(bytes)) => ffi::sqlite3_bind_blob(
76-
self.inner_statement.as_ptr(),
77-
bind_index,
78-
bytes.as_ptr() as *const libc::c_void,
79-
bytes.len() as libc::c_int,
80-
ffi::SQLITE_STATIC(),
81-
),
75+
(SqliteType::Binary, SqliteBindValue::Binary(mut bytes)) => {
76+
let len = bytes.len();
77+
// We need a seperate pointer here to pass it to sqlite
78+
// as the returned pointer is a pointer to a dyn sized **slice**
79+
// and not the pointer to the first element of the slice
80+
let ptr;
81+
ret_ptr = if len > 0 {
82+
ptr = bytes.as_mut_ptr();
83+
NonNull::new(Box::into_raw(bytes))
84+
} else {
85+
ptr = std::ptr::null_mut();
86+
None
87+
};
88+
ffi::sqlite3_bind_blob(
89+
self.inner_statement.as_ptr(),
90+
bind_index,
91+
ptr as *const libc::c_void,
92+
len as libc::c_int,
93+
ffi::SQLITE_STATIC(),
94+
)
95+
}
8296
(SqliteType::Text, SqliteBindValue::BorrowedString(bytes)) => ffi::sqlite3_bind_text(
8397
self.inner_statement.as_ptr(),
8498
bind_index,
8599
bytes.as_ptr() as *const libc::c_char,
86100
bytes.len() as libc::c_int,
87101
ffi::SQLITE_STATIC(),
88102
),
89-
(SqliteType::Text, SqliteBindValue::String(bytes)) => ffi::sqlite3_bind_text(
90-
self.inner_statement.as_ptr(),
91-
bind_index,
92-
bytes.as_ptr() as *const libc::c_char,
93-
bytes.len() as libc::c_int,
94-
ffi::SQLITE_STATIC(),
95-
),
103+
(SqliteType::Text, SqliteBindValue::String(bytes)) => {
104+
let mut bytes = Box::<[u8]>::from(bytes);
105+
let len = bytes.len();
106+
// We need a seperate pointer here to pass it to sqlite
107+
// as the returned pointer is a pointer to a dyn sized **slice**
108+
// and not the pointer to the first element of the slice
109+
let ptr;
110+
ret_ptr = if len > 0 {
111+
ptr = bytes.as_mut_ptr();
112+
NonNull::new(Box::into_raw(bytes))
113+
} else {
114+
ptr = std::ptr::null_mut();
115+
None
116+
};
117+
ffi::sqlite3_bind_text(
118+
self.inner_statement.as_ptr(),
119+
bind_index,
120+
ptr as *const libc::c_char,
121+
len as libc::c_int,
122+
ffi::SQLITE_STATIC(),
123+
)
124+
}
96125
(SqliteType::Float, SqliteBindValue::F64(value))
97126
| (SqliteType::Double, SqliteBindValue::F64(value)) => ffi::sqlite3_bind_double(
98127
self.inner_statement.as_ptr(),
99128
bind_index,
100-
*value as libc::c_double,
129+
value as libc::c_double,
101130
),
102131
(SqliteType::SmallInt, SqliteBindValue::I32(value))
103132
| (SqliteType::Integer, SqliteBindValue::I32(value)) => {
104-
ffi::sqlite3_bind_int(self.inner_statement.as_ptr(), bind_index, *value)
133+
ffi::sqlite3_bind_int(self.inner_statement.as_ptr(), bind_index, value)
105134
}
106135
(SqliteType::Long, SqliteBindValue::I64(value)) => {
107-
ffi::sqlite3_bind_int64(self.inner_statement.as_ptr(), bind_index, *value)
136+
ffi::sqlite3_bind_int64(self.inner_statement.as_ptr(), bind_index, value)
108137
}
109138
(t, b) => {
110139
return Err(Error::SerializationError(
111140
format!("Type missmatch: Expected {:?}, got {}", t, b).into(),
112141
))
113142
}
114143
};
115-
ensure_sqlite_ok(result, self.raw_connection())
144+
match ensure_sqlite_ok(result, self.raw_connection()) {
145+
Ok(()) => Ok(ret_ptr),
146+
Err(e) => {
147+
if let Some(ptr) = ret_ptr {
148+
// This is a `NonNul` ptr so it cannot be null
149+
// It points to a slice internally as we did not apply
150+
// any cast above.
151+
std::mem::drop(Box::from_raw(ptr.as_ptr()))
152+
}
153+
Err(e)
154+
}
155+
}
116156
}
117157

118158
fn reset(&mut self) {
@@ -180,152 +220,135 @@ impl Drop for Statement {
180220
}
181221
}
182222

223+
// A warning for future editiors:
224+
// Changing this code to something "simplier" may
225+
// introduce undefined behaviour. Make sure you read
226+
// the following discussions for details about
227+
// the current version:
228+
//
229+
// * https://github.com/weiznich/diesel/pull/7
230+
// * https://users.rust-lang.org/t/code-review-for-unsafe-code-in-diesel/66798/
231+
// * https://github.com/rust-lang/unsafe-code-guidelines/issues/194
183232
struct BoundStatement<'stmt, 'query> {
184233
statement: MaybeCached<'stmt, Statement>,
185234
// we need to store the query here to ensure noone does
186235
// drop it till the end ot the statement
187236
// We use a boxed queryfragment here just to erase the
188-
// generic type, we use ManuallyDrop to communicate
237+
// generic type, we use NonNull to communicate
189238
// that this is a shared buffer
190-
query: ManuallyDrop<Box<dyn QueryFragment<Sqlite> + 'query>>,
239+
query: Option<NonNull<dyn QueryFragment<Sqlite> + 'query>>,
191240
// we need to store any owned bind values speratly, as they are not
192-
// contained in the query itself. We use ManuallyDrop to
241+
// contained in the query itself. We use NonNull to
193242
// communicate that this is a shared buffer
194-
binds_to_free: ManuallyDrop<Vec<(i32, Option<SqliteBindValue<'static>>)>>,
243+
binds_to_free: Vec<(i32, Option<NonNull<[u8]>>)>,
195244
}
196245

197246
impl<'stmt, 'query> BoundStatement<'stmt, 'query> {
198247
fn bind<T>(
199-
mut statement: MaybeCached<'stmt, Statement>,
248+
statement: MaybeCached<'stmt, Statement>,
200249
query: T,
201250
) -> QueryResult<BoundStatement<'stmt, 'query>>
202251
where
203252
T: QueryFragment<Sqlite> + QueryId + 'query,
204253
{
205254
// Don't use a trait object here to prevent using a virtual function call
206255
// For sqlite this can introduce a measurable overhead
207-
let mut query = ManuallyDrop::new(Box::new(query));
256+
let query = Box::new(query);
208257

209258
let mut bind_collector = SqliteBindCollector::new();
210259
query.collect_binds(&mut bind_collector, &mut ())?;
211260
let SqliteBindCollector { binds } = bind_collector;
212261

213-
let binds_to_free = match Self::bind_buffers(binds, &mut statement) {
214-
Ok(value) => value,
215-
Err(e) => {
216-
unsafe {
217-
// We return from this function afterwards and
218-
// any buffer is already unbound by `bind_buffers`
219-
// so it's safe to drop query now
220-
ManuallyDrop::drop(&mut query);
221-
}
222-
return Err(e);
223-
}
224-
};
225-
226-
Ok(Self {
262+
let mut ret = BoundStatement {
227263
statement,
228-
binds_to_free,
229-
query: ManuallyDrop::new(
230-
// Cast to a trait object here, to erase the generic parameter T
231-
ManuallyDrop::into_inner(query) as Box<dyn QueryFragment<Sqlite> + 'query>,
264+
query: None,
265+
binds_to_free: Vec::with_capacity(
266+
binds
267+
.iter()
268+
.filter(|&(b, _)| {
269+
matches!(
270+
b,
271+
SqliteBindValue::BorrowedBinary(_)
272+
| SqliteBindValue::BorrowedString(_)
273+
| SqliteBindValue::String(_)
274+
| SqliteBindValue::Binary(_)
275+
)
276+
})
277+
.count(),
232278
),
233-
})
279+
};
280+
281+
ret.bind_buffers(binds)?;
282+
283+
let query = query as Box<dyn QueryFragment<Sqlite> + 'query>;
284+
ret.query = NonNull::new(Box::into_raw(query));
285+
286+
Ok(ret)
234287
}
235288

236289
// This is a seperate function so that
237290
// not the whole construtor is generic over the query type T.
238291
// This hopefully prevents binary bloat.
239-
fn bind_buffers(
240-
binds: Vec<(SqliteBindValue<'_>, SqliteType)>,
241-
statement: &mut MaybeCached<'stmt, Statement>,
242-
) -> QueryResult<ManuallyDrop<Vec<(i32, Option<SqliteBindValue<'static>>)>>> {
243-
let mut binds_to_free = ManuallyDrop::new(Vec::with_capacity(
244-
binds
245-
.iter()
246-
.filter(|&(b, _)| {
247-
matches!(
248-
b,
249-
SqliteBindValue::BorrowedBinary(_)
250-
| SqliteBindValue::BorrowedString(_)
251-
| SqliteBindValue::String(_)
252-
| SqliteBindValue::Binary(_)
253-
)
254-
})
255-
.count(),
256-
));
292+
fn bind_buffers(&mut self, binds: Vec<(SqliteBindValue<'_>, SqliteType)>) -> QueryResult<()> {
257293
for (bind_idx, (bind, tpe)) in (1..).zip(binds) {
294+
if matches!(
295+
bind,
296+
SqliteBindValue::BorrowedString(_) | SqliteBindValue::BorrowedBinary(_)
297+
) {
298+
// Store the id's of borrowed binds to unbind them on drop
299+
self.binds_to_free.push((bind_idx, None));
300+
}
301+
258302
// It's safe to call bind here as:
259303
// * The type and value matches
260304
// * We ensure that corresponding buffers lives long enough below
261305
// * The statement is not used yet by `step` or anything else
262-
let res = unsafe { statement.bind(tpe, &bind, bind_idx) };
263-
264-
if let Err(e) = res {
265-
Self::unbind_buffers(statement, &binds_to_free);
266-
unsafe {
267-
// It's safe to drop binds_to_free here as
268-
// we've already unbound the buffers
269-
ManuallyDrop::drop(&mut binds_to_free);
270-
}
271-
return Err(e);
272-
}
273-
274-
// We want to unbind the buffers later to ensure
275-
// that sqlite does not access uninitilized memory
276-
match bind {
277-
SqliteBindValue::BorrowedString(_) | SqliteBindValue::BorrowedBinary(_) => {
278-
binds_to_free.push((bind_idx, None));
279-
}
280-
SqliteBindValue::Binary(b) => {
281-
binds_to_free.push((bind_idx, Some(SqliteBindValue::Binary(b))));
282-
}
283-
SqliteBindValue::String(b) => {
284-
binds_to_free.push((bind_idx, Some(SqliteBindValue::String(b))));
285-
}
286-
SqliteBindValue::I32(_)
287-
| SqliteBindValue::I64(_)
288-
| SqliteBindValue::F64(_)
289-
| SqliteBindValue::Null => {}
306+
let res = unsafe { self.statement.bind(tpe, bind, bind_idx) }?;
307+
if let Some(ptr) = res {
308+
// Store the id + pointer for a owned bind
309+
// as we must unbind and free them on drop
310+
self.binds_to_free.push((bind_idx, Some(ptr)));
290311
}
291312
}
292-
Ok(binds_to_free)
313+
Ok(())
293314
}
315+
}
316+
317+
impl<'stmt, 'query> Drop for BoundStatement<'stmt, 'query> {
318+
fn drop(&mut self) {
319+
// First reset the statement, otherwise the bind calls
320+
// below will fails
321+
self.statement.reset();
294322

295-
fn unbind_buffers(
296-
stmt: &mut MaybeCached<'stmt, Statement>,
297-
binds_to_free: &[(i32, Option<SqliteBindValue<'static>>)],
298-
) {
299-
for (idx, _buffer) in binds_to_free {
323+
for (idx, buffer) in std::mem::take(&mut self.binds_to_free) {
300324
unsafe {
301325
// It's always safe to bind null values, as there is no buffer that needs to outlife something
302-
stmt.bind(SqliteType::Text, &SqliteBindValue::Null, *idx)
326+
self.statement
327+
.bind(SqliteType::Text, SqliteBindValue::Null, idx)
303328
.expect(
304329
"Binding a null value should never fail. \
305330
If you ever see this error message please open \
306331
an issue at diesels issue tracker containing \
307332
code how to trigger this message.",
308333
);
309334
}
310-
}
311-
}
312-
}
313335

314-
impl<'stmt, 'query> Drop for BoundStatement<'stmt, 'query> {
315-
fn drop(&mut self) {
316-
// First reset the statement, otherwise the bind calls
317-
// below will fails
318-
self.statement.reset();
336+
if let Some(buffer) = buffer {
337+
unsafe {
338+
// Constructing the `Box` here is safe as we
339+
// got the pointer from a box + it is guarenteed to be not null.
340+
std::mem::drop(Box::from_raw(buffer.as_ptr()));
341+
}
342+
}
343+
}
319344

320-
// Reset the binds that may point to memory that will be/needs to be freed
321-
Self::unbind_buffers(&mut self.statement, &self.binds_to_free);
322-
unsafe {
323-
// We unbound the corresponding buffers above, so it's fine to drop the
324-
// owned binds now
325-
ManuallyDrop::drop(&mut self.binds_to_free);
326-
// We've dropped everything that could reference the query
327-
// so it's safe to drop the query here
328-
ManuallyDrop::drop(&mut self.query);
345+
if let Some(query) = self.query {
346+
unsafe {
347+
// Constructing the `Box` here is safe as we
348+
// got the pointer from a box + it is guarenteed to be not null.
349+
std::mem::drop(Box::from_raw(query.as_ptr()));
350+
}
351+
self.query = None;
329352
}
330353
}
331354
}

0 commit comments

Comments
 (0)