// diesel/pg/query_builder/copy/copy_from.rs

1use std::borrow::Cow;
2use std::marker::PhantomData;
3
4use byteorder::NetworkEndian;
5use byteorder::WriteBytesExt;
6
7use super::CommonOptions;
8use super::CopyFormat;
9use super::CopyTarget;
10use crate::expression::bound::Bound;
11use crate::insertable::ColumnInsertValue;
12use crate::pg::backend::FailedToLookupTypeError;
13use crate::pg::metadata_lookup::PgMetadataCacheKey;
14use crate::pg::Pg;
15use crate::pg::PgMetadataLookup;
16use crate::query_builder::BatchInsert;
17use crate::query_builder::QueryFragment;
18use crate::query_builder::QueryId;
19use crate::query_builder::ValuesClause;
20use crate::serialize::IsNull;
21use crate::serialize::ToSql;
22use crate::Connection;
23use crate::Insertable;
24use crate::QueryResult;
25use crate::{Column, Table};
26
/// Possible values of the `HEADER` option of a `COPY FROM` statement.
#[derive(Debug, Copy, Clone)]
pub enum CopyHeader {
    /// Explicitly state whether the input begins with a header line
    /// (`true`) or not (`false`).
    Set(bool),
    /// Require the header line to match the names of the targeted table's
    /// columns, failing the copy if they differ.
    Match,
}
37
38#[derive(Debug, Default)]
39pub struct CopyFromOptions {
40    common: CommonOptions,
41    default: Option<String>,
42    header: Option<CopyHeader>,
43}
44
45impl QueryFragment<Pg> for CopyFromOptions {
46    fn walk_ast<'b>(
47        &'b self,
48        mut pass: crate::query_builder::AstPass<'_, 'b, Pg>,
49    ) -> crate::QueryResult<()> {
50        if self.any_set() {
51            let mut comma = "";
52            pass.push_sql(" WITH (");
53            self.common.walk_ast(pass.reborrow(), &mut comma);
54            if let Some(ref default) = self.default {
55                pass.push_sql(comma);
56                comma = ", ";
57                pass.push_sql("DEFAULT '");
58                // cannot use binds here :(
59                pass.push_sql(default);
60                pass.push_sql("'");
61            }
62            if let Some(ref header) = self.header {
63                pass.push_sql(comma);
64                // commented out because rustc complains otherwise
65                //comma = ", ";
66                pass.push_sql("HEADER ");
67                match header {
68                    CopyHeader::Set(true) => pass.push_sql("1"),
69                    CopyHeader::Set(false) => pass.push_sql("0"),
70                    CopyHeader::Match => pass.push_sql("MATCH"),
71                }
72            }
73
74            pass.push_sql(")");
75        }
76        Ok(())
77    }
78}
79
80impl CopyFromOptions {
81    fn any_set(&self) -> bool {
82        self.common.any_set() || self.default.is_some() || self.header.is_some()
83    }
84}
85
86#[derive(Debug)]
87pub struct CopyFrom<S, F> {
88    options: CopyFromOptions,
89    copy_callback: F,
90    p: PhantomData<S>,
91}
92
/// Query fragment rendering `COPY <target> FROM STDIN ...`.
///
/// `target` is expected to implement `CopyFromExpression<T>`, which provides
/// the target, options and data; `T` is carried only at the type level.
pub(crate) struct InternalCopyFromQuery<S, T> {
    pub(crate) target: S,
    p: PhantomData<T>,
}
97
#[cfg(feature = "postgres")]
impl<S, T> InternalCopyFromQuery<S, T> {
    /// Wraps `target`; the second type parameter is purely phantom.
    pub(crate) fn new(target: S) -> Self {
        InternalCopyFromQuery {
            p: PhantomData,
            target,
        }
    }
}
107
108impl<S, T> QueryId for InternalCopyFromQuery<S, T>
109where
110    S: CopyFromExpression<T>,
111{
112    const HAS_STATIC_QUERY_ID: bool = false;
113    type QueryId = ();
114}
115
116impl<S, T> QueryFragment<Pg> for InternalCopyFromQuery<S, T>
117where
118    S: CopyFromExpression<T>,
119{
120    fn walk_ast<'b>(
121        &'b self,
122        mut pass: crate::query_builder::AstPass<'_, 'b, Pg>,
123    ) -> crate::QueryResult<()> {
124        pass.unsafe_to_cache_prepared();
125        pass.push_sql("COPY ");
126        self.target.walk_target(pass.reborrow())?;
127        pass.push_sql(" FROM STDIN");
128        self.target.options().walk_ast(pass.reborrow())?;
129        Ok(())
130    }
131}
132
133pub trait CopyFromExpression<T> {
134    type Error: From<crate::result::Error> + std::error::Error;
135
136    fn callback(&mut self, copy: &mut impl std::io::Write) -> Result<(), Self::Error>;
137
138    fn walk_target<'b>(
139        &'b self,
140        pass: crate::query_builder::AstPass<'_, 'b, Pg>,
141    ) -> crate::QueryResult<()>;
142
143    fn options(&self) -> &CopyFromOptions;
144}
145
146impl<S, F, E> CopyFromExpression<S::Table> for CopyFrom<S, F>
147where
148    E: From<crate::result::Error> + std::error::Error,
149    S: CopyTarget,
150    F: Fn(&mut dyn std::io::Write) -> Result<(), E>,
151{
152    type Error = E;
153
154    fn callback(&mut self, copy: &mut impl std::io::Write) -> Result<(), Self::Error> {
155        (self.copy_callback)(copy)
156    }
157
158    fn options(&self) -> &CopyFromOptions {
159        &self.options
160    }
161
162    fn walk_target<'b>(
163        &'b self,
164        pass: crate::query_builder::AstPass<'_, 'b, Pg>,
165    ) -> crate::QueryResult<()> {
166        S::walk_target(pass)
167    }
168}
169
170struct Dummy;
171
172impl PgMetadataLookup for Dummy {
173    fn lookup_type(&mut self, type_name: &str, schema: Option<&str>) -> crate::pg::PgTypeMetadata {
174        let cache_key = PgMetadataCacheKey::new(
175            schema.map(Into::into).map(Cow::Owned),
176            Cow::Owned(type_name.into()),
177        );
178        crate::pg::PgTypeMetadata(Err(FailedToLookupTypeError::new_internal(cache_key)))
179    }
180}
181
182trait CopyFromInsertableHelper {
183    type Target: CopyTarget;
184    const COLUMN_COUNT: i16;
185
186    fn write_to_buffer(&self, idx: i16, out: &mut Vec<u8>) -> QueryResult<IsNull>;
187}
188
// Implements `CopyFromInsertableHelper` for `ValuesClause`s whose values are a
// tuple of `ColumnInsertValue<Column, Bound<SqlType, RustType>>`. One impl is
// generated for owned `Bound`s and one for borrowed (`&'a Bound`) values, for
// every tuple size passed in by `__diesel_for_each_tuple!` below.
//
// Per tuple entry, as established by the generated bounds:
// - `$idx`: position of the value inside the tuple,
// - `$T`:   SQL type the value is serialized as (`ToSql<$T, Pg>`),
// - `$ST`:  column type (`Column<Table = T>`),
// - `$TT`:  Rust value type implementing `ToSql<$T, Pg>`.
// `$Tuple` is used as the column count (presumably the tuple arity).
macro_rules! impl_copy_from_insertable_helper_for_values_clause {
    ($(
        $Tuple:tt {
            $(($idx:tt) -> $T:ident, $ST:ident, $TT:ident,)+
        }
    )+) => {
        $(
            // Owned values: `ColumnInsertValue<_, Bound<_, _>>`.
            impl<T, $($ST,)* $($T,)* $($TT,)*> CopyFromInsertableHelper for ValuesClause<
                ($(ColumnInsertValue<$ST, Bound<$T, $TT>>,)*),
            T>
                where
                T: Table,
                $($ST: Column<Table = T>,)*
                ($($ST,)*): CopyTarget,
                $($TT: ToSql<$T, Pg>,)*
            {
                // The tuple of columns doubles as the `COPY` target.
                type Target = ($($ST,)*);

                // statically known to always fit
                // as we don't support more than 128 columns
                #[allow(clippy::cast_possible_truncation)]
                const COLUMN_COUNT: i16 = $Tuple as i16;

                // Serializes the `idx`-th value of the row into `out`.
                fn write_to_buffer(&self, idx: i16, out: &mut Vec<u8>) -> QueryResult<IsNull> {
                    use crate::query_builder::ByteWrapper;
                    use crate::serialize::Output;

                    let values = &self.values;
                    // One arm per tuple position; `InsertableWrapper::callback`
                    // only asks for indices in `0..COLUMN_COUNT`.
                    match idx {
                        $($idx =>{
                            let item = &values.$idx.expr.item;
                            // `Dummy` fails every type metadata lookup, so only
                            // values that need no OID lookup can serialize here.
                            let is_null = ToSql::<$T, Pg>::to_sql(
                                item,
                                &mut Output::new( ByteWrapper(out), &mut Dummy as _)
                            ).map_err(crate::result::Error::SerializationError)?;
                            return Ok(is_null);
                        })*
                        _ => unreachable!(),
                    }
                }
            }

            // Borrowed values: `ColumnInsertValue<_, &'a Bound<_, _>>`; the body
            // is identical to the owned variant above.
            impl<'a, T, $($ST,)* $($T,)* $($TT,)*> CopyFromInsertableHelper for ValuesClause<
                ($(ColumnInsertValue<$ST, &'a Bound<$T, $TT>>,)*),
            T>
                where
                T: Table,
                $($ST: Column<Table = T>,)*
                ($($ST,)*): CopyTarget,
                $($TT: ToSql<$T, Pg>,)*
            {
                type Target = ($($ST,)*);

                // statically known to always fit
                // as we don't support more than 128 columns
                #[allow(clippy::cast_possible_truncation)]
                const COLUMN_COUNT: i16 = $Tuple as i16;

                fn write_to_buffer(&self, idx: i16, out: &mut Vec<u8>) -> QueryResult<IsNull> {
                    use crate::query_builder::ByteWrapper;
                    use crate::serialize::Output;

                    let values = &self.values;
                    match idx {
                        $($idx =>{
                            let item = &values.$idx.expr.item;
                            let is_null = ToSql::<$T, Pg>::to_sql(
                                item,
                                &mut Output::new( ByteWrapper(out), &mut Dummy as _)
                            ).map_err(crate::result::Error::SerializationError)?;
                            return Ok(is_null);
                        })*
                        _ => unreachable!(),
                    }
                }
            }
        )*
    }
}

// Instantiate the impls above for every supported tuple size.
diesel_derives::__diesel_for_each_tuple!(impl_copy_from_insertable_helper_for_values_clause);
270
/// Action of a [`CopyFromQuery`] that copies a collection of insertable
/// values using PostgreSQL's binary format.
///
/// The wrapped value is taken out (leaving `None`) when the data is written.
#[derive(Debug)]
pub struct InsertableWrapper<I>(Option<I>);
273
274impl<I, T, V, QId, const STATIC_QUERY_ID: bool> CopyFromExpression<T> for InsertableWrapper<I>
275where
276    I: Insertable<T, Values = BatchInsert<Vec<V>, T, QId, STATIC_QUERY_ID>>,
277    V: CopyFromInsertableHelper,
278{
279    type Error = crate::result::Error;
280
281    fn callback(&mut self, copy: &mut impl std::io::Write) -> Result<(), Self::Error> {
282        let io_result_mapper = |e| crate::result::Error::DeserializationError(Box::new(e));
283        // see https://www.postgresql.org/docs/current/sql-copy.html for
284        // a description of the binary format
285        //
286        // We don't write oids
287
288        // write the header
289        copy.write_all(&super::COPY_MAGIC_HEADER)
290            .map_err(io_result_mapper)?;
291        copy.write_i32::<NetworkEndian>(0)
292            .map_err(io_result_mapper)?;
293        copy.write_i32::<NetworkEndian>(0)
294            .map_err(io_result_mapper)?;
295        // write the data
296        // we reuse the same buffer here again and again
297        // as we expect the data to be "similar"
298        // this skips reallocating
299        let mut buffer = Vec::<u8>::new();
300        let values = self
301            .0
302            .take()
303            .expect("We only call this callback once")
304            .values();
305        for i in values.values {
306            // column count
307            buffer
308                .write_i16::<NetworkEndian>(V::COLUMN_COUNT)
309                .map_err(io_result_mapper)?;
310            for idx in 0..V::COLUMN_COUNT {
311                // first write the null indicator as dummy value
312                buffer
313                    .write_i32::<NetworkEndian>(-1)
314                    .map_err(io_result_mapper)?;
315                let len_before = buffer.len();
316                let is_null = i.write_to_buffer(idx, &mut buffer)?;
317                if is_null == IsNull::No {
318                    // fill in the length afterwards
319                    let len_after = buffer.len();
320                    let diff = (len_after - len_before)
321                        .try_into()
322                        .map_err(|e| crate::result::Error::SerializationError(Box::new(e)))?;
323                    let bytes = i32::to_be_bytes(diff);
324                    for (b, t) in bytes.into_iter().zip(&mut buffer[len_before - 4..]) {
325                        *t = b;
326                    }
327                }
328            }
329            copy.write_all(&buffer).map_err(io_result_mapper)?;
330            buffer.clear();
331        }
332        // write the trailer
333        copy.write_i16::<NetworkEndian>(-1)
334            .map_err(io_result_mapper)?;
335        Ok(())
336    }
337
338    fn options(&self) -> &CopyFromOptions {
339        &CopyFromOptions {
340            common: CommonOptions {
341                format: Some(CopyFormat::Binary),
342                freeze: None,
343                delimiter: None,
344                null: None,
345                quote: None,
346                escape: None,
347            },
348            default: None,
349            header: None,
350        }
351    }
352
353    fn walk_target<'b>(
354        &'b self,
355        pass: crate::query_builder::AstPass<'_, 'b, Pg>,
356    ) -> crate::QueryResult<()> {
357        <V as CopyFromInsertableHelper>::Target::walk_target(pass)
358    }
359}
360
/// Builder returned by [`copy_from`].
///
/// Choose the data source with [`from_raw_data`] (a callback writes the
/// payload itself) or [`from_insertable`] (values are encoded for you).
/// For the raw-data variant, the `with_*` methods adjust the options of
/// the generated `COPY` statement.
///
/// [`from_raw_data`]: CopyFromQuery::from_raw_data
/// [`from_insertable`]: CopyFromQuery::from_insertable
#[derive(Debug)]
#[must_use = "`COPY FROM` statements are only executed when calling `.execute()`."]
#[cfg(feature = "postgres_backend")]
pub struct CopyFromQuery<T, Action> {
    table: T,
    action: Action,
}
378
379impl<T> CopyFromQuery<T, NotSet>
380where
381    T: Table,
382{
383    /// Copy data into the database by directly providing the data in the corresponding format
384    ///
385    /// `target` specifies the column selection that is the target of the `COPY FROM` statement
386    /// `action` expects a callback which accepts a [`std::io::Write`] argument. The necessary format
387    /// accepted by this writer sink depends on the options provided via the `with_*` methods
388    #[allow(clippy::wrong_self_convention)] // the sql struct is named that way
389    pub fn from_raw_data<F, C, E>(self, _target: C, action: F) -> CopyFromQuery<T, CopyFrom<C, F>>
390    where
391        C: CopyTarget<Table = T>,
392        F: Fn(&mut dyn std::io::Write) -> Result<(), E>,
393    {
394        CopyFromQuery {
395            table: self.table,
396            action: CopyFrom {
397                p: PhantomData,
398                options: Default::default(),
399                copy_callback: action,
400            },
401        }
402    }
403
404    /// Copy a set of insertable values into the database.
405    ///
406    /// The `insertable` argument is expected to be a `Vec<I>`, `&[I]` or similar, where `I`
407    /// needs to implement `Insertable<T>`. If you use the [`#[derive(Insertable)]`](derive@crate::prelude::Insertable)
408    /// derive macro make sure to also set the `#[diesel(treat_none_as_default_value = false)]` option
409    /// to disable the default value handling otherwise implemented by `#[derive(Insertable)]`.
410    ///
411    /// This uses the binary format. It internally configures the correct
412    /// set of settings and does not allow to set other options
413    #[allow(clippy::wrong_self_convention)] // the sql struct is named that way
414    pub fn from_insertable<I>(self, insertable: I) -> CopyFromQuery<T, InsertableWrapper<I>>
415    where
416        InsertableWrapper<I>: CopyFromExpression<T>,
417    {
418        CopyFromQuery {
419            table: self.table,
420            action: InsertableWrapper(Some(insertable)),
421        }
422    }
423}
424
425impl<T, C, F> CopyFromQuery<T, CopyFrom<C, F>> {
426    /// The format used for the copy statement
427    ///
428    /// See the [PostgreSQL documentation](https://www.postgresql.org/docs/current/sql-copy.html)
429    /// for more details.
430    pub fn with_format(mut self, format: CopyFormat) -> Self {
431        self.action.options.common.format = Some(format);
432        self
433    }
434
435    /// Whether or not the `freeze` option is set
436    ///
437    /// See the [PostgreSQL documentation](https://www.postgresql.org/docs/current/sql-copy.html)
438    /// for more details.
439    pub fn with_freeze(mut self, freeze: bool) -> Self {
440        self.action.options.common.freeze = Some(freeze);
441        self
442    }
443
444    /// Which delimiter should be used for textual input formats
445    ///
446    /// See the [PostgreSQL documentation](https://www.postgresql.org/docs/current/sql-copy.html)
447    /// for more details.
448    pub fn with_delimiter(mut self, delimiter: char) -> Self {
449        self.action.options.common.delimiter = Some(delimiter);
450        self
451    }
452
453    /// Which string should be used in place of a `NULL` value
454    /// for textual input formats
455    ///
456    /// See the [PostgreSQL documentation](https://www.postgresql.org/docs/current/sql-copy.html)
457    /// for more details.
458    pub fn with_null(mut self, null: impl Into<String>) -> Self {
459        self.action.options.common.null = Some(null.into());
460        self
461    }
462
463    /// Which quote character should be used for textual input formats
464    ///
465    /// See the [PostgreSQL documentation](https://www.postgresql.org/docs/current/sql-copy.html)
466    /// for more details.
467    pub fn with_quote(mut self, quote: char) -> Self {
468        self.action.options.common.quote = Some(quote);
469        self
470    }
471
472    /// Which escape character should be used for textual input formats
473    ///
474    /// See the [PostgreSQL documentation](https://www.postgresql.org/docs/current/sql-copy.html)
475    /// for more details.
476    pub fn with_escape(mut self, escape: char) -> Self {
477        self.action.options.common.escape = Some(escape);
478        self
479    }
480
481    /// Which string should be used to indicate that
482    /// the `default` value should be used in place of that string
483    /// for textual formats
484    ///
485    /// See the [PostgreSQL documentation](https://www.postgresql.org/docs/current/sql-copy.html)
486    /// for more details.
487    ///
488    /// (This parameter was added with PostgreSQL 16)
489    pub fn with_default(mut self, default: impl Into<String>) -> Self {
490        self.action.options.default = Some(default.into());
491        self
492    }
493
494    /// Is a header provided as part of the textual input or not
495    ///
496    /// See the [PostgreSQL documentation](https://www.postgresql.org/docs/current/sql-copy.html)
497    /// for more details.
498    pub fn with_header(mut self, header: CopyHeader) -> Self {
499        self.action.options.header = Some(header);
500        self
501    }
502}
503
504/// A custom execute function tailored for `COPY FROM` statements
505///
506/// This trait can be used to execute `COPY FROM` queries constructed
507/// via [`copy_from]`
508pub trait ExecuteCopyFromDsl<C>
509where
510    C: Connection<Backend = Pg>,
511{
512    /// The error type returned by the execute function
513    type Error: std::error::Error;
514
515    /// See the trait documentation for details
516    fn execute(self, conn: &mut C) -> Result<usize, Self::Error>;
517}
518
#[cfg(feature = "postgres")]
impl<T, A> ExecuteCopyFromDsl<crate::PgConnection> for CopyFromQuery<T, A>
where
    A: CopyFromExpression<T>,
{
    type Error = A::Error;

    /// Hands the configured action to the connection's `COPY FROM` support.
    fn execute(self, conn: &mut crate::PgConnection) -> Result<usize, A::Error> {
        let CopyFromQuery { action, .. } = self;
        conn.copy_from::<A, T>(action)
    }
}
530
#[cfg(feature = "r2d2")]
impl<T, A, C> ExecuteCopyFromDsl<crate::r2d2::PooledConnection<crate::r2d2::ConnectionManager<C>>>
    for CopyFromQuery<T, A>
where
    A: CopyFromExpression<T>,
    C: crate::r2d2::R2D2Connection<Backend = Pg> + 'static,
    Self: ExecuteCopyFromDsl<C>,
{
    type Error = <Self as ExecuteCopyFromDsl<C>>::Error;

    /// Delegates to the implementation for the pooled connection's inner
    /// connection type.
    fn execute(
        self,
        conn: &mut crate::r2d2::PooledConnection<crate::r2d2::ConnectionManager<C>>,
    ) -> Result<usize, Self::Error> {
        let inner = &mut **conn;
        <Self as ExecuteCopyFromDsl<C>>::execute(self, inner)
    }
}
548
/// Placeholder action of a freshly created [`CopyFromQuery`]: no data source
/// has been chosen yet.
#[derive(Clone, Copy, Debug)]
pub struct NotSet;
551
/// Creates a `COPY FROM` statement
///
/// This function constructs a `COPY FROM` statement which copies data
/// *from* a source into the database. It's designed to move larger
/// amounts of data into the database.
///
/// This function accepts the target table as its argument.
///
/// There are two ways to construct a `COPY FROM` statement with
/// diesel:
///
/// * By providing a collection of values (e.g. a `Vec<I>` or `&[I]`)
///   where `I` implements `Insertable` for the given table
/// * By providing a target selection (column list or table name)
///   and a callback that provides the data
///
/// The first variant uses the `BINARY` format internally to send
/// the provided data efficiently to the database. It automatically
/// sets the right options and does not allow changing them.
/// Use [`CopyFromQuery::from_insertable`] for this.
///
/// The second variant allows you to control the behaviour
/// of the generated `COPY FROM` statement in detail. It can
/// be set up via the [`CopyFromQuery::from_raw_data`] function.
/// The callback accepts an opaque object as argument that allows
/// writing the corresponding data to the database. The exact
/// format depends on the settings chosen by the various
/// `CopyFromQuery::with_*` methods. See
/// [the PostgreSQL documentation](https://www.postgresql.org/docs/current/sql-copy.html)
/// for more details about the expected formats.
///
/// If you don't have any specific needs you should prefer
/// using the more convenient first variant.
///
/// This functionality is PostgreSQL-specific.
///
/// # Examples
///
/// ## Via [`CopyFromQuery::from_insertable`]
///
/// ```rust
/// # include!("../../../doctest_setup.rs");
/// # use crate::schema::users;
///
/// #[derive(Insertable)]
/// #[diesel(table_name = users)]
/// #[diesel(treat_none_as_default_value = false)]
/// struct NewUser {
///     name: &'static str,
/// }
///
/// # fn run_test() -> QueryResult<()> {
/// # let connection = &mut establish_connection();
///
/// let data = vec![
///     NewUser { name: "Diva Plavalaguna" },
///     NewUser { name: "Father Vito Cornelius" },
/// ];
///
/// let count = diesel::copy_from(users::table)
///     .from_insertable(&data)
///     .execute(connection)?;
///
/// assert_eq!(count, 2);
/// # Ok(())
/// # }
/// # fn main() {
/// #    run_test().unwrap();
/// # }
/// ```
///
/// ## Via [`CopyFromQuery::from_raw_data`]
///
/// ```rust
/// # include!("../../../doctest_setup.rs");
/// # fn run_test() -> QueryResult<()> {
/// # use crate::schema::users;
/// use diesel::pg::CopyFormat;
/// # let connection = &mut establish_connection();
/// let count = diesel::copy_from(users::table)
///     .from_raw_data(users::table, |copy| {
///         writeln!(copy, "3,Diva Plavalaguna").unwrap();
///         writeln!(copy, "4,Father Vito Cornelius").unwrap();
///         diesel::QueryResult::Ok(())
///     })
///     .with_format(CopyFormat::Csv)
///     .execute(connection)?;
///
/// assert_eq!(count, 2);
/// # Ok(())
/// # }
/// # fn main() {
/// #    run_test().unwrap();
/// # }
/// ```
#[cfg(feature = "postgres_backend")]
pub fn copy_from<T>(table: T) -> CopyFromQuery<T, NotSet>
where
    T: Table,
{
    CopyFromQuery {
        table,
        action: NotSet,
    }
}