// diesel/query_builder/insert_statement/insert_with_default_for_sqlite.rs

1use super::{BatchInsert, InsertStatement};
2use crate::insertable::InsertValues;
3use crate::insertable::{CanInsertInSingleQuery, ColumnInsertValue, DefaultableColumnInsertValue};
4use crate::prelude::*;
5use crate::query_builder::upsert::on_conflict_clause::OnConflictValues;
6use crate::query_builder::{AstPass, QueryId, ValuesClause};
7use crate::query_builder::{DebugQuery, QueryFragment};
8use crate::query_dsl::methods::ExecuteDsl;
9use crate::sqlite::Sqlite;
10use std::fmt::{self, Debug, Display};
11
/// Formatting backend used by the `Debug` and `Display` impls of SQLite
/// batch-insert `DebugQuery` values in this module.
///
/// The type parameter is a type-level flag; the impls in this file use `Yes`
/// and `No` to select between two different renderings.
pub trait DebugQueryHelper<ContainsDefaultableValue> {
    /// Writes the `Debug` rendering of the query into `f`.
    fn fmt_debug(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result;
    /// Writes the `Display` rendering of the query into `f`.
    fn fmt_display(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result;
}
16
17impl<T, V, QId, Op, Ret, const STATIC_QUERY_ID: bool> DebugQueryHelper<Yes>
18    for DebugQuery<
19        '_,
20        InsertStatement<T, BatchInsert<Vec<ValuesClause<V, T>>, T, QId, STATIC_QUERY_ID>, Op, Ret>,
21        Sqlite,
22    >
23where
24    V: QueryFragment<Sqlite>,
25    T: Copy + QuerySource,
26    Op: Copy,
27    Ret: Copy,
28    for<'b> InsertStatement<T, &'b ValuesClause<V, T>, Op, Ret>: QueryFragment<Sqlite>,
29{
30    fn fmt_debug(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31        let mut statements = vec![String::from("BEGIN")];
32        for record in self.query.records.values.iter() {
33            let stmt = InsertStatement::new(
34                self.query.target,
35                record,
36                self.query.operator,
37                self.query.returning,
38            );
39            statements.push(crate::debug_query(&stmt).to_string());
40        }
41        statements.push("COMMIT".into());
42
43        f.debug_struct("Query")
44            .field("sql", &statements)
45            .field("binds", &[] as &[i32; 0])
46            .finish()
47    }
48
49    fn fmt_display(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        writeln!(f, "BEGIN;")?;
51        for record in self.query.records.values.iter() {
52            let stmt = InsertStatement::new(
53                self.query.target,
54                record,
55                self.query.operator,
56                self.query.returning,
57            );
58            writeln!(f, "{}", crate::debug_query(&stmt))?;
59        }
60        writeln!(f, "COMMIT;")?;
61        Ok(())
62    }
63}
64
65#[allow(unsafe_code)] // cast to transparent wrapper type
66impl<'a, T, V, QId, Op, const STATIC_QUERY_ID: bool> DebugQueryHelper<No>
67    for DebugQuery<'a, InsertStatement<T, BatchInsert<V, T, QId, STATIC_QUERY_ID>, Op>, Sqlite>
68where
69    T: Copy + QuerySource,
70    Op: Copy,
71    DebugQuery<
72        'a,
73        InsertStatement<T, SqliteBatchInsertWrapper<V, T, QId, STATIC_QUERY_ID>, Op>,
74        Sqlite,
75    >: Debug + Display,
76{
77    fn fmt_debug(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
78        let value = unsafe {
79            // This cast is safe as `SqliteBatchInsertWrapper` is #[repr(transparent)]
80            &*(self as *const DebugQuery<
81                'a,
82                InsertStatement<T, BatchInsert<V, T, QId, STATIC_QUERY_ID>, Op>,
83                Sqlite,
84            >
85                as *const DebugQuery<
86                    'a,
87                    InsertStatement<T, SqliteBatchInsertWrapper<V, T, QId, STATIC_QUERY_ID>, Op>,
88                    Sqlite,
89                >)
90        };
91        <_ as Debug>::fmt(value, f)
92    }
93
94    fn fmt_display(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
95        let value = unsafe {
96            // This cast is safe as `SqliteBatchInsertWrapper` is #[repr(transparent)]
97            &*(self as *const DebugQuery<
98                'a,
99                InsertStatement<T, BatchInsert<V, T, QId, STATIC_QUERY_ID>, Op>,
100                Sqlite,
101            >
102                as *const DebugQuery<
103                    'a,
104                    InsertStatement<T, SqliteBatchInsertWrapper<V, T, QId, STATIC_QUERY_ID>, Op>,
105                    Sqlite,
106                >)
107        };
108        <_ as Display>::fmt(value, f)
109    }
110}
111
112impl<T, V, QId, Op, O, const STATIC_QUERY_ID: bool> Display
113    for DebugQuery<
114        '_,
115        InsertStatement<T, BatchInsert<Vec<ValuesClause<V, T>>, T, QId, STATIC_QUERY_ID>, Op>,
116        Sqlite,
117    >
118where
119    T: QuerySource,
120    V: ContainsDefaultableValue<Out = O>,
121    Self: DebugQueryHelper<O>,
122{
123    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124        self.fmt_display(f)
125    }
126}
127
128impl<T, V, QId, Op, O, const STATIC_QUERY_ID: bool> Debug
129    for DebugQuery<
130        '_,
131        InsertStatement<T, BatchInsert<Vec<ValuesClause<V, T>>, T, QId, STATIC_QUERY_ID>, Op>,
132        Sqlite,
133    >
134where
135    T: QuerySource,
136    V: ContainsDefaultableValue<Out = O>,
137    Self: DebugQueryHelper<O>,
138{
139    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
140        self.fmt_debug(f)
141    }
142}
143
/// Type-level "true" marker.
///
/// In this module it is the flag value produced by
/// `ContainsDefaultableValue` for `DefaultableColumnInsertValue`, and it
/// selects the `Yes`-tagged `DebugQueryHelper` / `ExecuteDsl` impls.
/// `Default` is derived so the marker can be created via `O::default()`.
#[allow(missing_debug_implementations, missing_copy_implementations)]
#[derive(Default)]
pub struct Yes;
152
/// Type-level "false" marker.
///
/// In this module it is the flag value produced by
/// `ContainsDefaultableValue` for `ColumnInsertValue`, and it selects the
/// `No`-tagged `DebugQueryHelper` / `ExecuteDsl` impls.
/// `Default` is derived so the marker can be created via `O::default()`.
#[allow(missing_debug_implementations, missing_copy_implementations)]
#[derive(Default)]
pub struct No;
161
/// Type-level boolean OR over the `Yes` / `No` marker types.
///
/// With the impls in this module, `<A as Any<B>>::Out` is `Yes` when either
/// `A` or `B` is `Yes`, and `No` otherwise.
pub trait Any<Rhs> {
    /// Result of combining `Self` with `Rhs`; it must itself be combinable
    /// with both markers so that results can be chained.
    type Out: Any<Yes> + Any<No>;
}
165
/// `No` OR `No` = `No`.
impl Any<No> for No {
    type Out = No;
}
169
/// `No` OR `Yes` = `Yes`.
impl Any<Yes> for No {
    type Out = Yes;
}
173
/// `Yes` OR `No` = `Yes`.
impl Any<No> for Yes {
    type Out = Yes;
}
177
/// `Yes` OR `Yes` = `Yes`.
impl Any<Yes> for Yes {
    type Out = Yes;
}
181
/// Computes, at the type level, a `Yes` / `No` flag for an insert value type.
///
/// The impls in this module yield `Yes` for `DefaultableColumnInsertValue`,
/// `No` for `ColumnInsertValue`, and propagate or combine the flag through
/// arrays, references, `ValuesClause` and tuples.
pub trait ContainsDefaultableValue {
    /// The resulting flag; bounded so it can be OR-ed with further flags.
    type Out: Any<Yes> + Any<No>;
}
185
/// A `ColumnInsertValue` is flagged `No`.
impl<C, B> ContainsDefaultableValue for ColumnInsertValue<C, B> {
    type Out = No;
}
189
/// A `DefaultableColumnInsertValue` is flagged `Yes`.
impl<I> ContainsDefaultableValue for DefaultableColumnInsertValue<I> {
    type Out = Yes;
}
193
/// An array carries the same flag as its element type.
impl<I, const SIZE: usize> ContainsDefaultableValue for [I; SIZE]
where
    I: ContainsDefaultableValue,
{
    type Out = I::Out;
}
200
/// A `ValuesClause` carries the same flag as the values it wraps.
impl<I, T> ContainsDefaultableValue for ValuesClause<I, T>
where
    I: ContainsDefaultableValue,
{
    type Out = I::Out;
}
207
/// A shared reference carries the same flag as the referenced type.
impl<T> ContainsDefaultableValue for &T
where
    T: ContainsDefaultableValue,
{
    type Out = T::Out;
}
214
215impl<V, T, QId, C, Op, O, const STATIC_QUERY_ID: bool> ExecuteDsl<C, Sqlite>
216    for InsertStatement<T, BatchInsert<Vec<ValuesClause<V, T>>, T, QId, STATIC_QUERY_ID>, Op>
217where
218    T: QuerySource,
219    C: Connection<Backend = Sqlite>,
220    V: ContainsDefaultableValue<Out = O>,
221    O: Default,
222    (O, Self): ExecuteDsl<C, Sqlite>,
223{
224    fn execute(query: Self, conn: &mut C) -> QueryResult<usize> {
225        <(O, Self) as ExecuteDsl<C, Sqlite>>::execute((O::default(), query), conn)
226    }
227}
228
229impl<V, T, QId, C, Op, O, Target, ConflictOpt, const STATIC_QUERY_ID: bool> ExecuteDsl<C, Sqlite>
230    for InsertStatement<
231        T,
232        OnConflictValues<
233            BatchInsert<Vec<ValuesClause<V, T>>, T, QId, STATIC_QUERY_ID>,
234            Target,
235            ConflictOpt,
236        >,
237        Op,
238    >
239where
240    T: QuerySource,
241    C: Connection<Backend = Sqlite>,
242    V: ContainsDefaultableValue<Out = O>,
243    O: Default,
244    (O, Self): ExecuteDsl<C, Sqlite>,
245{
246    fn execute(query: Self, conn: &mut C) -> QueryResult<usize> {
247        <(O, Self) as ExecuteDsl<C, Sqlite>>::execute((O::default(), query), conn)
248    }
249}
250
251impl<V, T, QId, C, Op, const STATIC_QUERY_ID: bool> ExecuteDsl<C, Sqlite>
252    for (
253        Yes,
254        InsertStatement<T, BatchInsert<Vec<ValuesClause<V, T>>, T, QId, STATIC_QUERY_ID>, Op>,
255    )
256where
257    C: Connection<Backend = Sqlite>,
258    T: Table + Copy + QueryId + 'static,
259    T::FromClause: QueryFragment<Sqlite>,
260    Op: Copy + QueryId + QueryFragment<Sqlite>,
261    V: InsertValues<Sqlite, T> + CanInsertInSingleQuery<Sqlite> + QueryId,
262{
263    fn execute((Yes, query): Self, conn: &mut C) -> QueryResult<usize> {
264        conn.transaction(|conn| {
265            let mut result = 0;
266            for record in &query.records.values {
267                let stmt =
268                    InsertStatement::new(query.target, record, query.operator, query.returning);
269                result += stmt.execute(conn)?;
270            }
271            Ok(result)
272        })
273    }
274}
275
/// `#[repr(transparent)]` newtype around `BatchInsert`.
///
/// The impls in this module attach SQLite-specific `QueryFragment`,
/// `CanInsertInSingleQuery` and `QueryId` behaviour to this wrapper, and the
/// `No`-tagged debug formatting relies on the transparent layout for its
/// pointer cast.
#[allow(missing_debug_implementations, missing_copy_implementations)]
#[repr(transparent)]
pub struct SqliteBatchInsertWrapper<V, T, QId, const STATIC_QUERY_ID: bool>(
    BatchInsert<V, T, QId, STATIC_QUERY_ID>,
);
281
282impl<V, Tab, QId, const STATIC_QUERY_ID: bool> QueryFragment<Sqlite>
283    for SqliteBatchInsertWrapper<Vec<ValuesClause<V, Tab>>, Tab, QId, STATIC_QUERY_ID>
284where
285    ValuesClause<V, Tab>: QueryFragment<Sqlite>,
286    V: QueryFragment<Sqlite>,
287{
288    fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Sqlite>) -> QueryResult<()> {
289        if !STATIC_QUERY_ID {
290            out.unsafe_to_cache_prepared();
291        }
292
293        let mut values = self.0.values.iter();
294        if let Some(value) = values.next() {
295            value.walk_ast(out.reborrow())?;
296        }
297        for value in values {
298            out.push_sql(", (");
299            value.values.walk_ast(out.reborrow())?;
300            out.push_sql(")");
301        }
302        Ok(())
303    }
304}
305
/// Internal `#[repr(transparent)]` helper newtype.
///
/// It is used as the bound on the `CanInsertInSingleQuery` impl for
/// `SqliteBatchInsertWrapper`, and implements `CanInsertInSingleQuery<Sqlite>`
/// itself by forwarding to the wrapped value.
#[allow(missing_copy_implementations, missing_debug_implementations)]
#[repr(transparent)]
pub struct SqliteCanInsertInSingleQueryHelper<T: ?Sized>(T);
309
310impl<V, T, QId, const STATIC_QUERY_ID: bool> CanInsertInSingleQuery<Sqlite>
311    for SqliteBatchInsertWrapper<Vec<ValuesClause<V, T>>, T, QId, STATIC_QUERY_ID>
312where
313    // We constrain that here on an internal helper type
314    // to make sure that this does not accidentally leak
315    // so that none does really implement normal batch
316    // insert for inserts with default values here
317    SqliteCanInsertInSingleQueryHelper<V>: CanInsertInSingleQuery<Sqlite>,
318{
319    fn rows_to_insert(&self) -> Option<usize> {
320        Some(self.0.values.len())
321    }
322}
323
324impl<T> CanInsertInSingleQuery<Sqlite> for SqliteCanInsertInSingleQueryHelper<T>
325where
326    T: CanInsertInSingleQuery<Sqlite>,
327{
328    fn rows_to_insert(&self) -> Option<usize> {
329        self.0.rows_to_insert()
330    }
331}
332
/// The wrapper reuses the query id of the `BatchInsert` it wraps.
impl<V, T, QId, const STATIC_QUERY_ID: bool> QueryId
    for SqliteBatchInsertWrapper<V, T, QId, STATIC_QUERY_ID>
where
    BatchInsert<V, T, QId, STATIC_QUERY_ID>: QueryId,
{
    // Same id type as the wrapped `BatchInsert`.
    type QueryId = <BatchInsert<V, T, QId, STATIC_QUERY_ID> as QueryId>::QueryId;

    // Same "has a static id" answer as the wrapped `BatchInsert`.
    const HAS_STATIC_QUERY_ID: bool =
        <BatchInsert<V, T, QId, STATIC_QUERY_ID> as QueryId>::HAS_STATIC_QUERY_ID;
}
343
344impl<V, T, QId, C, Op, const STATIC_QUERY_ID: bool> ExecuteDsl<C, Sqlite>
345    for (
346        No,
347        InsertStatement<T, BatchInsert<V, T, QId, STATIC_QUERY_ID>, Op>,
348    )
349where
350    C: Connection<Backend = Sqlite>,
351    T: Table + QueryId + 'static,
352    T::FromClause: QueryFragment<Sqlite>,
353    Op: QueryFragment<Sqlite> + QueryId,
354    SqliteBatchInsertWrapper<V, T, QId, STATIC_QUERY_ID>:
355        QueryFragment<Sqlite> + QueryId + CanInsertInSingleQuery<Sqlite>,
356{
357    fn execute((No, query): Self, conn: &mut C) -> QueryResult<usize> {
358        let query = InsertStatement {
359            records: SqliteBatchInsertWrapper(query.records),
360            operator: query.operator,
361            target: query.target,
362            returning: query.returning,
363            into_clause: query.into_clause,
364        };
365        query.execute(conn)
366    }
367}
368
369impl<V, T, QId, C, Op, Target, ConflictOpt, const STATIC_QUERY_ID: bool> ExecuteDsl<C, Sqlite>
370    for (
371        No,
372        InsertStatement<
373            T,
374            OnConflictValues<BatchInsert<V, T, QId, STATIC_QUERY_ID>, Target, ConflictOpt>,
375            Op,
376        >,
377    )
378where
379    C: Connection<Backend = Sqlite>,
380    T: Table + QueryId + 'static,
381    T::FromClause: QueryFragment<Sqlite>,
382    Op: QueryFragment<Sqlite> + QueryId,
383    OnConflictValues<SqliteBatchInsertWrapper<V, T, QId, STATIC_QUERY_ID>, Target, ConflictOpt>:
384        QueryFragment<Sqlite> + CanInsertInSingleQuery<Sqlite> + QueryId,
385{
386    fn execute((No, query): Self, conn: &mut C) -> QueryResult<usize> {
387        let query = InsertStatement {
388            operator: query.operator,
389            target: query.target,
390            records: OnConflictValues {
391                values: SqliteBatchInsertWrapper(query.records.values),
392                target: query.records.target,
393                action: query.records.action,
394                where_clause: query.records.where_clause,
395            },
396            returning: query.returning,
397            into_clause: query.into_clause,
398        };
399        query.execute(conn)
400    }
401}
402
// Callback macro handed to `__diesel_for_each_tuple!`.
//
// It accepts one or more tuple descriptions of the form
// `Tuple { (idx) -> T, ST, TT, ... }` and, for each of them, invokes
// `impl_contains_defaultable_value!` with the list of `$T` identifiers.
// The `$Tuple`, `$idx`, `$ST` and `$TT` fragments are matched but not used.
macro_rules! tuple_impls {
        ($(
            $Tuple:tt {
                $(($idx:tt) -> $T:ident, $ST:ident, $TT:ident,)+
            }
        )+) => {
            $(
                impl_contains_defaultable_value!($($T,)*);
            )*
        }
    }
414
// Generates `ContainsDefaultableValue` for a tuple of the given type
// parameters. The tuple's `Out` is the `Any`-combination (type-level OR) of
// the `Out` flags of all its elements.
//
// Multi-element tuples are handled by the recursive `@build` arms, which
// thread four accumulators through the recursion:
//   * `start_ts` - the full list of type parameters (for the impl header),
//   * `ts`       - the parameters that still have to be folded in,
//   * `bounds`   - the `Any<..>` where-bounds collected so far,
//   * `out`      - the type expression for the flag combined so far.
macro_rules! impl_contains_defaultable_value {
      // Final `@build` step: exactly one parameter is left, so emit the impl
      // with the collected bounds plus the bound for this last parameter.
      (
        @build
        start_ts = [$($ST: ident,)*],
        ts = [$T1: ident,],
        bounds = [$($bounds: tt)*],
        out = [$($out: tt)*],
    )=> {
        impl<$($ST,)*> ContainsDefaultableValue for ($($ST,)*)
        where
            $($ST: ContainsDefaultableValue,)*
            $($bounds)*
            $T1::Out: Any<$($out)*>,
        {
            type Out = <$T1::Out as Any<$($out)*>>::Out;
        }

    };
    // Recursive `@build` step: fold the next parameter `$T1` into `out`,
    // record the bound that makes this valid, and continue with the rest.
    (
        @build
        start_ts = [$($ST: ident,)*],
        ts = [$T1: ident, $($T: ident,)+],
        bounds = [$($bounds: tt)*],
        out = [$($out: tt)*],
    )=> {
        impl_contains_defaultable_value! {
            @build
            start_ts = [$($ST,)*],
            ts = [$($T,)*],
            bounds = [$($bounds)* $T1::Out: Any<$($out)*>,],
            out = [<$T1::Out as Any<$($out)*>>::Out],
        }
    };
    // Entry point for two or more parameters: seed `out` with the first
    // parameter's flag and fold the remaining parameters into it.
    ($T1: ident, $($T: ident,)+) => {
        impl_contains_defaultable_value! {
            @build
            start_ts = [$T1, $($T,)*],
            ts = [$($T,)*],
            bounds = [],
            out = [$T1::Out],
        }
    };
    // Entry point for a single parameter: a one-element tuple carries the
    // flag of its only element.
    ($T1: ident,) => {
        impl<$T1> ContainsDefaultableValue for ($T1,)
        where $T1: ContainsDefaultableValue,
        {
            type Out = <$T1 as ContainsDefaultableValue>::Out;
        }
    }
}
465
// Passes `tuple_impls` to diesel_derives' `__diesel_for_each_tuple!`, which
// presumably invokes it with the description of every supported tuple size
// (TODO confirm against the macro's definition).
diesel_derives::__diesel_for_each_tuple!(tuple_impls);