// diesel/query_builder/insert_statement/insert_with_default_for_sqlite.rs

1use super::{BatchInsert, InsertStatement};
2use crate::insertable::InsertValues;
3use crate::insertable::{CanInsertInSingleQuery, ColumnInsertValue, DefaultableColumnInsertValue};
4use crate::prelude::*;
5use crate::query_builder::{AstPass, QueryId, ValuesClause};
6use crate::query_builder::{DebugQuery, QueryFragment};
7use crate::query_dsl::methods::ExecuteDsl;
8use crate::sqlite::Sqlite;
9use std::fmt::{self, Debug, Display};
10
/// Formatting backend used by the `Debug`/`Display` impls of SQLite batch insert
/// [`DebugQuery`] values.
///
/// The type parameter selects the implementation: [`Yes`] when the inserted records may
/// contain default values, [`No`] when they cannot.
pub trait DebugQueryHelper<ContainsDefaultableValue> {
    /// Writes the `Debug` representation of the query.
    fn fmt_debug(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result;

    /// Writes the `Display` representation of the query.
    fn fmt_display(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result;
}
15
16impl<T, V, QId, Op, Ret, const STATIC_QUERY_ID: bool> DebugQueryHelper<Yes>
17    for DebugQuery<
18        '_,
19        InsertStatement<T, BatchInsert<Vec<ValuesClause<V, T>>, T, QId, STATIC_QUERY_ID>, Op, Ret>,
20        Sqlite,
21    >
22where
23    V: QueryFragment<Sqlite>,
24    T: Copy + QuerySource,
25    Op: Copy,
26    Ret: Copy,
27    for<'b> InsertStatement<T, &'b ValuesClause<V, T>, Op, Ret>: QueryFragment<Sqlite>,
28{
29    fn fmt_debug(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30        let mut statements = vec![String::from("BEGIN")];
31        for record in self.query.records.values.iter() {
32            let stmt = InsertStatement::new(
33                self.query.target,
34                record,
35                self.query.operator,
36                self.query.returning,
37            );
38            statements.push(crate::debug_query(&stmt).to_string());
39        }
40        statements.push("COMMIT".into());
41
42        f.debug_struct("Query")
43            .field("sql", &statements)
44            .field("binds", &[] as &[i32; 0])
45            .finish()
46    }
47
48    fn fmt_display(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
49        writeln!(f, "BEGIN;")?;
50        for record in self.query.records.values.iter() {
51            let stmt = InsertStatement::new(
52                self.query.target,
53                record,
54                self.query.operator,
55                self.query.returning,
56            );
57            writeln!(f, "{}", crate::debug_query(&stmt))?;
58        }
59        writeln!(f, "COMMIT;")?;
60        Ok(())
61    }
62}
63
64#[allow(unsafe_code)] // cast to transparent wrapper type
65impl<'a, T, V, QId, Op, const STATIC_QUERY_ID: bool> DebugQueryHelper<No>
66    for DebugQuery<'a, InsertStatement<T, BatchInsert<V, T, QId, STATIC_QUERY_ID>, Op>, Sqlite>
67where
68    T: Copy + QuerySource,
69    Op: Copy,
70    DebugQuery<
71        'a,
72        InsertStatement<T, SqliteBatchInsertWrapper<V, T, QId, STATIC_QUERY_ID>, Op>,
73        Sqlite,
74    >: Debug + Display,
75{
76    fn fmt_debug(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77        let value = unsafe {
78            // This cast is safe as `SqliteBatchInsertWrapper` is #[repr(transparent)]
79            &*(self as *const DebugQuery<
80                'a,
81                InsertStatement<T, BatchInsert<V, T, QId, STATIC_QUERY_ID>, Op>,
82                Sqlite,
83            >
84                as *const DebugQuery<
85                    'a,
86                    InsertStatement<T, SqliteBatchInsertWrapper<V, T, QId, STATIC_QUERY_ID>, Op>,
87                    Sqlite,
88                >)
89        };
90        <_ as Debug>::fmt(value, f)
91    }
92
93    fn fmt_display(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94        let value = unsafe {
95            // This cast is safe as `SqliteBatchInsertWrapper` is #[repr(transparent)]
96            &*(self as *const DebugQuery<
97                'a,
98                InsertStatement<T, BatchInsert<V, T, QId, STATIC_QUERY_ID>, Op>,
99                Sqlite,
100            >
101                as *const DebugQuery<
102                    'a,
103                    InsertStatement<T, SqliteBatchInsertWrapper<V, T, QId, STATIC_QUERY_ID>, Op>,
104                    Sqlite,
105                >)
106        };
107        <_ as Display>::fmt(value, f)
108    }
109}
110
111impl<T, V, QId, Op, O, const STATIC_QUERY_ID: bool> Display
112    for DebugQuery<
113        '_,
114        InsertStatement<T, BatchInsert<Vec<ValuesClause<V, T>>, T, QId, STATIC_QUERY_ID>, Op>,
115        Sqlite,
116    >
117where
118    T: QuerySource,
119    V: ContainsDefaultableValue<Out = O>,
120    Self: DebugQueryHelper<O>,
121{
122    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123        self.fmt_display(f)
124    }
125}
126
127impl<T, V, QId, Op, O, const STATIC_QUERY_ID: bool> Debug
128    for DebugQuery<
129        '_,
130        InsertStatement<T, BatchInsert<Vec<ValuesClause<V, T>>, T, QId, STATIC_QUERY_ID>, Op>,
131        Sqlite,
132    >
133where
134    T: QuerySource,
135    V: ContainsDefaultableValue<Out = O>,
136    Self: DebugQueryHelper<O>,
137{
138    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
139        self.fmt_debug(f)
140    }
141}
142
/// Type-level marker: the inserted records may contain default values.
// Zero-sized marker used only for impl selection, so no Debug/Copy impls are provided.
#[allow(missing_debug_implementations, missing_copy_implementations)]
#[derive(Default)]
pub struct Yes;
151
/// Type-level marker: the inserted records cannot contain default values.
// Zero-sized marker used only for impl selection, so no Debug/Copy impls are provided.
#[allow(missing_debug_implementations, missing_copy_implementations)]
#[derive(Default)]
pub struct No;
160
161pub trait Any<Rhs> {
162    type Out: Any<Yes> + Any<No>;
163}
164
165impl Any<No> for No {
166    type Out = No;
167}
168
169impl Any<Yes> for No {
170    type Out = Yes;
171}
172
173impl Any<No> for Yes {
174    type Out = Yes;
175}
176
177impl Any<Yes> for Yes {
178    type Out = Yes;
179}
180
181pub trait ContainsDefaultableValue {
182    type Out: Any<Yes> + Any<No>;
183}
184
185impl<C, B> ContainsDefaultableValue for ColumnInsertValue<C, B> {
186    type Out = No;
187}
188
189impl<I> ContainsDefaultableValue for DefaultableColumnInsertValue<I> {
190    type Out = Yes;
191}
192
193impl<I, const SIZE: usize> ContainsDefaultableValue for [I; SIZE]
194where
195    I: ContainsDefaultableValue,
196{
197    type Out = I::Out;
198}
199
200impl<I, T> ContainsDefaultableValue for ValuesClause<I, T>
201where
202    I: ContainsDefaultableValue,
203{
204    type Out = I::Out;
205}
206
207impl<T> ContainsDefaultableValue for &T
208where
209    T: ContainsDefaultableValue,
210{
211    type Out = T::Out;
212}
213
214impl<V, T, QId, C, Op, O, const STATIC_QUERY_ID: bool> ExecuteDsl<C, Sqlite>
215    for InsertStatement<T, BatchInsert<Vec<ValuesClause<V, T>>, T, QId, STATIC_QUERY_ID>, Op>
216where
217    T: QuerySource,
218    C: Connection<Backend = Sqlite>,
219    V: ContainsDefaultableValue<Out = O>,
220    O: Default,
221    (O, Self): ExecuteDsl<C, Sqlite>,
222{
223    fn execute(query: Self, conn: &mut C) -> QueryResult<usize> {
224        <(O, Self) as ExecuteDsl<C, Sqlite>>::execute((O::default(), query), conn)
225    }
226}
227
228impl<V, T, QId, C, Op, const STATIC_QUERY_ID: bool> ExecuteDsl<C, Sqlite>
229    for (
230        Yes,
231        InsertStatement<T, BatchInsert<Vec<ValuesClause<V, T>>, T, QId, STATIC_QUERY_ID>, Op>,
232    )
233where
234    C: Connection<Backend = Sqlite>,
235    T: Table + Copy + QueryId + 'static,
236    T::FromClause: QueryFragment<Sqlite>,
237    Op: Copy + QueryId + QueryFragment<Sqlite>,
238    V: InsertValues<Sqlite, T> + CanInsertInSingleQuery<Sqlite> + QueryId,
239{
240    fn execute((Yes, query): Self, conn: &mut C) -> QueryResult<usize> {
241        conn.transaction(|conn| {
242            let mut result = 0;
243            for record in &query.records.values {
244                let stmt =
245                    InsertStatement::new(query.target, record, query.operator, query.returning);
246                result += stmt.execute(conn)?;
247            }
248            Ok(result)
249        })
250    }
251}
252
253#[allow(missing_debug_implementations, missing_copy_implementations)]
254#[repr(transparent)]
255pub struct SqliteBatchInsertWrapper<V, T, QId, const STATIC_QUERY_ID: bool>(
256    BatchInsert<V, T, QId, STATIC_QUERY_ID>,
257);
258
259impl<V, Tab, QId, const STATIC_QUERY_ID: bool> QueryFragment<Sqlite>
260    for SqliteBatchInsertWrapper<Vec<ValuesClause<V, Tab>>, Tab, QId, STATIC_QUERY_ID>
261where
262    ValuesClause<V, Tab>: QueryFragment<Sqlite>,
263    V: QueryFragment<Sqlite>,
264{
265    fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Sqlite>) -> QueryResult<()> {
266        if !STATIC_QUERY_ID {
267            out.unsafe_to_cache_prepared();
268        }
269
270        let mut values = self.0.values.iter();
271        if let Some(value) = values.next() {
272            value.walk_ast(out.reborrow())?;
273        }
274        for value in values {
275            out.push_sql(", (");
276            value.values.walk_ast(out.reborrow())?;
277            out.push_sql(")");
278        }
279        Ok(())
280    }
281}
282
/// Newtype through which `CanInsertInSingleQuery<Sqlite>` bounds on batch contents are
/// expressed, so that those bounds are not placed on the contents' own types.
// Internal helper type; Debug/Copy impls are intentionally not provided.
#[allow(missing_copy_implementations, missing_debug_implementations)]
#[repr(transparent)]
pub struct SqliteCanInsertInSingleQueryHelper<T: ?Sized>(T);
286
287impl<V, T, QId, const STATIC_QUERY_ID: bool> CanInsertInSingleQuery<Sqlite>
288    for SqliteBatchInsertWrapper<Vec<ValuesClause<V, T>>, T, QId, STATIC_QUERY_ID>
289where
290    // We constrain that here on an internal helper type
291    // to make sure that this does not accidentally leak
292    // so that none does really implement normal batch
293    // insert for inserts with default values here
294    SqliteCanInsertInSingleQueryHelper<V>: CanInsertInSingleQuery<Sqlite>,
295{
296    fn rows_to_insert(&self) -> Option<usize> {
297        Some(self.0.values.len())
298    }
299}
300
301impl<T> CanInsertInSingleQuery<Sqlite> for SqliteCanInsertInSingleQueryHelper<T>
302where
303    T: CanInsertInSingleQuery<Sqlite>,
304{
305    fn rows_to_insert(&self) -> Option<usize> {
306        self.0.rows_to_insert()
307    }
308}
309
310impl<V, T, QId, const STATIC_QUERY_ID: bool> QueryId
311    for SqliteBatchInsertWrapper<V, T, QId, STATIC_QUERY_ID>
312where
313    BatchInsert<V, T, QId, STATIC_QUERY_ID>: QueryId,
314{
315    type QueryId = <BatchInsert<V, T, QId, STATIC_QUERY_ID> as QueryId>::QueryId;
316
317    const HAS_STATIC_QUERY_ID: bool =
318        <BatchInsert<V, T, QId, STATIC_QUERY_ID> as QueryId>::HAS_STATIC_QUERY_ID;
319}
320
321impl<V, T, QId, C, Op, const STATIC_QUERY_ID: bool> ExecuteDsl<C, Sqlite>
322    for (
323        No,
324        InsertStatement<T, BatchInsert<V, T, QId, STATIC_QUERY_ID>, Op>,
325    )
326where
327    C: Connection<Backend = Sqlite>,
328    T: Table + QueryId + 'static,
329    T::FromClause: QueryFragment<Sqlite>,
330    Op: QueryFragment<Sqlite> + QueryId,
331    SqliteBatchInsertWrapper<V, T, QId, STATIC_QUERY_ID>:
332        QueryFragment<Sqlite> + QueryId + CanInsertInSingleQuery<Sqlite>,
333{
334    fn execute((No, query): Self, conn: &mut C) -> QueryResult<usize> {
335        let query = InsertStatement {
336            records: SqliteBatchInsertWrapper(query.records),
337            operator: query.operator,
338            target: query.target,
339            returning: query.returning,
340            into_clause: query.into_clause,
341        };
342        query.execute(conn)
343    }
344}
345
// Adapter for the tuple-driver macro: for every tuple description it receives, forward just
// the element type idents to `impl_contains_defaultable_value!`. The tuple token, the index
// and the two extra idents per element are matched but not needed here.
macro_rules! tuple_impls {
    ($(
        $arity:tt {
            $(($idx:tt) -> $Ty:ident, $Extra1:ident, $Extra2:ident,)+
        }
    )+) => {
        $(
            impl_contains_defaultable_value!($($Ty,)*);
        )*
    };
}
357
// Implements `ContainsDefaultableValue` for the tuple made of the given type idents.
//
// A tuple's `Out` is the `Any` ("or") fold of all element `Out`s, so it is `Yes` as soon as
// one element reports `Yes`. The internal `@build` arms accumulate the fold:
//   start_ts: every ident of the tuple (used for the impl header),
//   ts:       idents that still have to be folded in,
//   bounds:   `Any` where-clauses required by the fold so far,
//   out:      the type-level result accumulated so far.
macro_rules! impl_contains_defaultable_value {
    // Exactly one ident left: fold it in and emit the impl.
    (
        @build
        start_ts = [$($All: ident,)*],
        ts = [$Head: ident,],
        bounds = [$($bound: tt)*],
        out = [$($acc: tt)*],
    ) => {
        impl<$($All,)*> ContainsDefaultableValue for ($($All,)*)
        where
            $($All: ContainsDefaultableValue,)*
            $($bound)*
            $Head::Out: Any<$($acc)*>,
        {
            type Out = <$Head::Out as Any<$($acc)*>>::Out;
        }
    };
    // Two or more idents left: fold the next one into `out` and recurse on the rest.
    (
        @build
        start_ts = [$($All: ident,)*],
        ts = [$Head: ident, $($Tail: ident,)+],
        bounds = [$($bound: tt)*],
        out = [$($acc: tt)*],
    ) => {
        impl_contains_defaultable_value! {
            @build
            start_ts = [$($All,)*],
            ts = [$($Tail,)*],
            bounds = [$($bound)* $Head::Out: Any<$($acc)*>,],
            out = [<$Head::Out as Any<$($acc)*>>::Out],
        }
    };
    // Public entry for tuples with at least two elements: seed `out` with the first element.
    ($Head: ident, $($Tail: ident,)+) => {
        impl_contains_defaultable_value! {
            @build
            start_ts = [$Head, $($Tail,)*],
            ts = [$($Tail,)*],
            bounds = [],
            out = [$Head::Out],
        }
    };
    // Public entry for 1-tuples: forward to the single element.
    ($Head: ident,) => {
        impl<$Head> ContainsDefaultableValue for ($Head,)
        where
            $Head: ContainsDefaultableValue,
        {
            type Out = <$Head as ContainsDefaultableValue>::Out;
        }
    };
}
408
409diesel_derives::__diesel_for_each_tuple!(tuple_impls);