diesel/pg/query_builder/copy/
copy_to.rs

1use std::io::BufRead;
2use std::marker::PhantomData;
3
4use super::CommonOptions;
5use super::CopyFormat;
6use super::CopyTarget;
7use crate::deserialize::FromSqlRow;
8#[cfg(feature = "postgres")]
9use crate::pg::value::TypeOidLookup;
10use crate::pg::Pg;
11use crate::query_builder::QueryFragment;
12use crate::query_builder::QueryId;
13use crate::row::Row;
14#[cfg(feature = "postgres")]
15use crate::row::{self, Field, PartialRow, RowIndex, RowSealed};
16use crate::AppearsOnTable;
17use crate::Connection;
18use crate::Expression;
19use crate::QueryResult;
20use crate::Selectable;
21
22#[derive(Default, Debug)]
23pub struct CopyToOptions {
24    common: CommonOptions,
25    header: Option<bool>,
26}
27
28impl CopyToOptions {
29    fn any_set(&self) -> bool {
30        self.common.any_set() || self.header.is_some()
31    }
32}
33
34impl QueryFragment<Pg> for CopyToOptions {
35    fn walk_ast<'b>(
36        &'b self,
37        mut pass: crate::query_builder::AstPass<'_, 'b, Pg>,
38    ) -> crate::QueryResult<()> {
39        if self.any_set() {
40            let mut comma = "";
41            pass.push_sql(" WITH (");
42            self.common.walk_ast(pass.reborrow(), &mut comma);
43            if let Some(header_is_set) = self.header {
44                pass.push_sql(comma);
45                // commented out because rustc complains otherwise
46                //comma = ", ";
47                pass.push_sql("HEADER ");
48                pass.push_sql(if header_is_set { "1" } else { "0" });
49            }
50
51            pass.push_sql(")");
52        }
53        Ok(())
54    }
55}
56
57#[derive(Debug)]
58pub struct CopyToCommand<S> {
59    options: CopyToOptions,
60    p: PhantomData<S>,
61}
62
63impl<S> QueryId for CopyToCommand<S>
64where
65    S: CopyTarget,
66{
67    type QueryId = ();
68
69    const HAS_STATIC_QUERY_ID: bool = false;
70}
71
72impl<S> QueryFragment<Pg> for CopyToCommand<S>
73where
74    S: CopyTarget,
75{
76    fn walk_ast<'b>(
77        &'b self,
78        mut pass: crate::query_builder::AstPass<'_, 'b, Pg>,
79    ) -> crate::QueryResult<()> {
80        pass.unsafe_to_cache_prepared();
81        pass.push_sql("COPY ");
82        S::walk_target(pass.reborrow())?;
83        pass.push_sql(" TO STDOUT");
84        self.options.walk_ast(pass.reborrow())?;
85        Ok(())
86    }
87}
88
/// Marker used as the option state of a [`CopyToQuery`] that has
/// not been configured via any of the `with_*` methods yet
#[derive(Clone, Copy, Debug)]
pub struct NotSet;
91
92pub trait CopyToMarker: Sized {
93    fn setup_options<T>(q: CopyToQuery<T, Self>) -> CopyToQuery<T, CopyToOptions>;
94}
95
96impl CopyToMarker for NotSet {
97    fn setup_options<T>(q: CopyToQuery<T, Self>) -> CopyToQuery<T, CopyToOptions> {
98        CopyToQuery {
99            target: q.target,
100            options: CopyToOptions::default(),
101        }
102    }
103}
104impl CopyToMarker for CopyToOptions {
105    fn setup_options<T>(q: CopyToQuery<T, Self>) -> CopyToQuery<T, CopyToOptions> {
106        q
107    }
108}
/// The structure returned by [`copy_to`]
///
/// Use [`load`] or [`load_raw`] to execute the statement and to receive
/// the requested data from the database. Unless you have special needs,
/// prefer the more convenient `load` method.
///
/// The `with_*` methods configure the settings used for the
/// copy statement.
///
/// [`load`]: CopyToQuery::load
/// [`load_raw`]: CopyToQuery::load_raw
#[derive(Debug)]
#[must_use = "`COPY TO` statements are only executed when calling `.load()` or `load_raw()`."]
#[cfg(feature = "postgres_backend")]
pub struct CopyToQuery<T, O> {
    /// The copy target passed to [`copy_to`]
    target: T,
    /// Either [`NotSet`] or the explicitly configured [`CopyToOptions`]
    options: O,
}
128
/// A single row of a binary `COPY TO` response
#[cfg(feature = "postgres")]
struct CopyRow<'a> {
    /// One entry per field; `None` represents a `NULL` value
    buffers: Vec<Option<&'a [u8]>>,
    /// Result used to look up the type of each column
    result: &'a crate::pg::connection::PgResult,
}
134
/// A single field of a [`CopyRow`]
#[cfg(feature = "postgres")]
struct CopyField<'a> {
    /// Raw bytes of the field; `None` represents a `NULL` value
    field: &'a Option<&'a [u8]>,
    /// Result used to look up the type of the column
    result: &'a crate::pg::connection::PgResult,
    /// Index of the column this field belongs to
    col_idx: usize,
}
141
#[cfg(feature = "postgres")]
impl<'f> Field<'f, Pg> for CopyField<'f> {
    /// Fields of a copy row do not carry a name
    fn field_name(&self) -> Option<&str> {
        None
    }

    /// Returns the raw value of this field, or `None` for a `NULL` field
    ///
    /// The field itself is passed along as type oid lookup.
    fn value(&self) -> Option<<Pg as crate::backend::Backend>::RawValue<'_>> {
        (*self.field).map(|bytes| crate::pg::PgValue::new_internal(bytes, self))
    }
}
153
#[cfg(feature = "postgres")]
impl TypeOidLookup for CopyField<'_> {
    /// Asks the underlying result for the type of this field's column
    fn lookup(&self) -> std::num::NonZeroU32 {
        let column = self.col_idx;
        self.result.column_type(column)
    }
}
160
// Marker impl accompanying the `Row` impl for `CopyRow`.
#[cfg(feature = "postgres")]
impl RowSealed for CopyRow<'_> {}
163
#[cfg(feature = "postgres")]
impl RowIndex<usize> for CopyRow<'_> {
    /// Accepts every position smaller than the number of fields in the row
    fn idx(&self, idx: usize) -> Option<usize> {
        (idx < self.field_count()).then_some(idx)
    }
}
174
#[cfg(feature = "postgres")]
impl<'a> RowIndex<&'a str> for CopyRow<'_> {
    /// Lookup by name is not supported for copy rows, so this is always `None`
    fn idx(&self, _name: &'a str) -> Option<usize> {
        None
    }
}
181
#[cfg(feature = "postgres")]
impl<'a> Row<'a, Pg> for CopyRow<'_> {
    type Field<'f>
        = CopyField<'f>
    where
        'a: 'f,
        Self: 'f;

    type InnerPartialRow = Self;

    /// The number of fields equals the number of collected field buffers
    fn field_count(&self) -> usize {
        self.buffers.len()
    }

    /// Returns the field at `idx`, or `None` if the index cannot be resolved
    fn get<'b, I>(&'b self, idx: I) -> Option<Self::Field<'b>>
    where
        'a: 'b,
        Self: RowIndex<I>,
    {
        let col_idx = self.idx(idx)?;
        self.buffers.get(col_idx).map(|field| CopyField {
            field,
            result: self.result,
            col_idx,
        })
    }

    /// Returns a view onto the fields inside of `range`
    fn partial_row(
        &self,
        range: std::ops::Range<usize>,
    ) -> row::PartialRow<'_, Self::InnerPartialRow> {
        PartialRow::new(self, range)
    }
}
217
218pub trait ExecuteCopyToConnection: Connection<Backend = Pg> {
219    type CopyToBuffer<'a>: BufRead;
220
221    fn make_row<'a, 'b>(
222        out: &'a Self::CopyToBuffer<'_>,
223        buffers: Vec<Option<&'a [u8]>>,
224    ) -> impl Row<'b, Pg> + 'a;
225
226    fn get_buffer<'a>(out: &'a Self::CopyToBuffer<'_>) -> &'a [u8];
227
228    fn execute<T>(&mut self, command: CopyToCommand<T>) -> QueryResult<Self::CopyToBuffer<'_>>
229    where
230        T: CopyTarget;
231}
232
#[cfg(feature = "postgres")]
impl ExecuteCopyToConnection for crate::PgConnection {
    type CopyToBuffer<'a> = crate::pg::connection::copy::CopyToBuffer<'a>;

    /// Combines the field buffers with the result stored in `out`
    fn make_row<'a, 'b>(
        out: &'a Self::CopyToBuffer<'_>,
        buffers: Vec<Option<&'a [u8]>>,
    ) -> impl Row<'b, Pg> + 'a {
        let result = out.get_result();
        CopyRow { buffers, result }
    }

    /// Exposes the data slice of the copy buffer
    fn get_buffer<'a>(out: &'a Self::CopyToBuffer<'_>) -> &'a [u8] {
        out.data_slice()
    }

    /// Forwards to the `copy_to` method of the connection
    fn execute<T>(&mut self, command: CopyToCommand<T>) -> QueryResult<Self::CopyToBuffer<'_>>
    where
        T: CopyTarget,
    {
        self.copy_to(command)
    }
}
258
/// Pooled connections forward everything to the wrapped connection type `C`.
#[cfg(feature = "r2d2")]
impl<C> ExecuteCopyToConnection for crate::r2d2::PooledConnection<crate::r2d2::ConnectionManager<C>>
where
    C: ExecuteCopyToConnection + crate::r2d2::R2D2Connection + 'static,
{
    type CopyToBuffer<'a> = C::CopyToBuffer<'a>;

    fn make_row<'a, 'b>(
        out: &'a Self::CopyToBuffer<'_>,
        buffers: Vec<Option<&'a [u8]>>,
    ) -> impl Row<'b, Pg> + 'a {
        <C as ExecuteCopyToConnection>::make_row(out, buffers)
    }

    fn get_buffer<'a>(out: &'a Self::CopyToBuffer<'_>) -> &'a [u8] {
        <C as ExecuteCopyToConnection>::get_buffer(out)
    }

    fn execute<T>(&mut self, command: CopyToCommand<T>) -> QueryResult<Self::CopyToBuffer<'_>>
    where
        T: CopyTarget,
    {
        // Dereference the pooled connection to reach the inner connection.
        let inner: &mut C = &mut **self;
        C::execute(inner, command)
    }
}
284
285impl<T> CopyToQuery<T, NotSet>
286where
287    T: CopyTarget,
288{
289    /// Copy data from the database by returning an iterator of deserialized data
290    ///
291    /// This function allows to easily load data from the database via a `COPY TO` statement.
292    /// It does **not** allow to configure any settings via the `with_*` method, as it internally
293    /// sets the required options itself. It will use the binary format to deserialize the result
294    /// into the specified type `U`. Column selection is performed via [`Selectable`].
295    pub fn load<U, C>(self, conn: &mut C) -> QueryResult<impl Iterator<Item = QueryResult<U>> + '_>
296    where
297        U: FromSqlRow<<U::SelectExpression as Expression>::SqlType, Pg> + Selectable<Pg>,
298        U::SelectExpression: AppearsOnTable<T::Table> + CopyTarget<Table = T::Table>,
299        C: ExecuteCopyToConnection,
300    {
301        let io_result_mapper = |e| crate::result::Error::DeserializationError(Box::new(e));
302
303        let command = CopyToCommand {
304            p: PhantomData::<U::SelectExpression>,
305            options: CopyToOptions {
306                header: None,
307                common: CommonOptions {
308                    format: Some(CopyFormat::Binary),
309                    ..Default::default()
310                },
311            },
312        };
313        // see https://www.postgresql.org/docs/current/sql-copy.html for
314        // a description of the binary format
315        //
316        // We don't write oids
317
318        let mut out = ExecuteCopyToConnection::execute(conn, command)?;
319        let buffer = out.fill_buf().map_err(io_result_mapper)?;
320        if buffer[..super::COPY_MAGIC_HEADER.len()] != super::COPY_MAGIC_HEADER {
321            return Err(crate::result::Error::DeserializationError(
322                "Unexpected protocol header".into(),
323            ));
324        }
325        // we care only about bit 16-31 here, so we can just skip the bytes in between
326        let flags_backward_incompatible = i16::from_be_bytes(
327            (&buffer[super::COPY_MAGIC_HEADER.len() + 2..super::COPY_MAGIC_HEADER.len() + 4])
328                .try_into()
329                .expect("Exactly 2 byte"),
330        );
331        if flags_backward_incompatible != 0 {
332            return Err(crate::result::Error::DeserializationError(
333                format!("Unexpected flag value: {flags_backward_incompatible:x}").into(),
334            ));
335        }
336        let header_size = usize::try_from(i32::from_be_bytes(
337            (&buffer[super::COPY_MAGIC_HEADER.len() + 4..super::COPY_MAGIC_HEADER.len() + 8])
338                .try_into()
339                .expect("Exactly 4 byte"),
340        ))
341        .map_err(|e| crate::result::Error::DeserializationError(Box::new(e)))?;
342        out.consume(super::COPY_MAGIC_HEADER.len() + 8 + header_size);
343        let mut len = None;
344        Ok(std::iter::from_fn(move || {
345            if let Some(len) = len {
346                out.consume(len);
347                if let Err(e) = out.fill_buf().map_err(io_result_mapper) {
348                    return Some(Err(e));
349                }
350            }
351            let buffer = C::get_buffer(&out);
352            len = Some(buffer.len());
353            let tuple_count =
354                i16::from_be_bytes((&buffer[..2]).try_into().expect("Exactly 2 bytes"));
355            if tuple_count > 0 {
356                let tuple_count = match usize::try_from(tuple_count) {
357                    Ok(o) => o,
358                    Err(e) => {
359                        return Some(Err(crate::result::Error::DeserializationError(Box::new(e))))
360                    }
361                };
362                let mut buffers = Vec::with_capacity(tuple_count);
363                let mut offset = 2;
364                for _t in 0..tuple_count {
365                    let data_size = i32::from_be_bytes(
366                        (&buffer[offset..offset + 4])
367                            .try_into()
368                            .expect("Exactly 4 bytes"),
369                    );
370
371                    if data_size < 0 {
372                        buffers.push(None);
373                    } else {
374                        match usize::try_from(data_size) {
375                            Ok(data_size) => {
376                                buffers.push(Some(&buffer[offset + 4..offset + 4 + data_size]));
377                                offset = offset + 4 + data_size;
378                            }
379                            Err(e) => {
380                                return Some(Err(crate::result::Error::DeserializationError(
381                                    Box::new(e),
382                                )));
383                            }
384                        }
385                    }
386                }
387
388                let row = C::make_row(&out, buffers);
389                Some(U::build_from_row(&row).map_err(crate::result::Error::DeserializationError))
390            } else {
391                None
392            }
393        }))
394    }
395}
396
397impl<T, O> CopyToQuery<T, O>
398where
399    O: CopyToMarker,
400    T: CopyTarget,
401{
402    /// Copy data from the database by directly accessing the provided response
403    ///
404    /// This function returns a type that implements [`std::io::BufRead`] which allows to directly read
405    /// the data as provided by the database. The exact format depends on what options are
406    /// set via the various `with_*` methods.
407    pub fn load_raw<C>(self, conn: &mut C) -> QueryResult<impl BufRead + '_>
408    where
409        C: ExecuteCopyToConnection,
410    {
411        let q = O::setup_options(self);
412        let command = CopyToCommand {
413            p: PhantomData::<T>,
414            options: q.options,
415        };
416        ExecuteCopyToConnection::execute(conn, command)
417    }
418
419    /// The format used for the copy statement
420    ///
421    /// See the [PostgreSQL documentation](https://www.postgresql.org/docs/current/sql-copy.html)
422    /// for more details.
423    pub fn with_format(self, format: CopyFormat) -> CopyToQuery<T, CopyToOptions> {
424        let mut out = O::setup_options(self);
425        out.options.common.format = Some(format);
426        out
427    }
428
429    /// Whether or not the `freeze` option is set
430    ///
431    /// See the [PostgreSQL documentation](https://www.postgresql.org/docs/current/sql-copy.html)
432    /// for more details.
433    pub fn with_freeze(self, freeze: bool) -> CopyToQuery<T, CopyToOptions> {
434        let mut out = O::setup_options(self);
435        out.options.common.freeze = Some(freeze);
436        out
437    }
438
439    /// Which delimiter should be used for textual output formats
440    ///
441    /// See the [PostgreSQL documentation](https://www.postgresql.org/docs/current/sql-copy.html)
442    /// for more details.
443    pub fn with_delimiter(self, delimiter: char) -> CopyToQuery<T, CopyToOptions> {
444        let mut out = O::setup_options(self);
445        out.options.common.delimiter = Some(delimiter);
446        out
447    }
448
449    /// Which string should be used in place of a `NULL` value
450    /// for textual output formats
451    ///
452    /// See the [PostgreSQL documentation](https://www.postgresql.org/docs/current/sql-copy.html)
453    /// for more details.
454    pub fn with_null(self, null: impl Into<String>) -> CopyToQuery<T, CopyToOptions> {
455        let mut out = O::setup_options(self);
456        out.options.common.null = Some(null.into());
457        out
458    }
459
460    /// Which quote character should be used for textual output formats
461    ///
462    /// See the [PostgreSQL documentation](https://www.postgresql.org/docs/current/sql-copy.html)
463    /// for more details.
464    pub fn with_quote(self, quote: char) -> CopyToQuery<T, CopyToOptions> {
465        let mut out = O::setup_options(self);
466        out.options.common.quote = Some(quote);
467        out
468    }
469
470    /// Which escape character should be used for textual output formats
471    ///
472    /// See the [PostgreSQL documentation](https://www.postgresql.org/docs/current/sql-copy.html)
473    /// for more details.
474    pub fn with_escape(self, escape: char) -> CopyToQuery<T, CopyToOptions> {
475        let mut out = O::setup_options(self);
476        out.options.common.escape = Some(escape);
477        out
478    }
479
480    /// Is a header provided as part of the textual input or not
481    ///
482    /// See the [PostgreSQL documentation](https://www.postgresql.org/docs/current/sql-copy.html)
483    /// for more details.
484    pub fn with_header(self, set: bool) -> CopyToQuery<T, CopyToOptions> {
485        let mut out = O::setup_options(self);
486        out.options.header = Some(set);
487        out
488    }
489}
490
/// Creates a `COPY TO` statement
///
/// This function constructs a `COPY TO` statement which copies data
/// from the database **to** a client side target. It's designed to move
/// larger amounts of data out of the database.
///
/// This function accepts a target selection (table name or list of columns) as argument.
///
/// There are two ways to use a `COPY TO` statement with diesel:
///
/// * By using [`CopyToQuery::load`] to load the deserialized result
///   directly into a specified type
/// * By using the `with_*` methods to configure the format sent by the database
///   and then by calling [`CopyToQuery::load_raw`] to receive the raw data
///   sent by the database.
///
/// The first variant uses the `BINARY` format internally to receive
/// the selected data efficiently. It automatically sets the right options
/// and does not allow changing them via the `with_*` methods.
///
/// The second variant allows you to control the behaviour of the
/// generated `COPY TO` statement in detail. You can use the various
/// `with_*` methods for that before issuing the statement via [`CopyToQuery::load_raw`].
/// That method returns a type that implements [`std::io::BufRead`], which
/// allows you to directly read the response from the database in the configured
/// format.
/// See [the postgresql documentation](https://www.postgresql.org/docs/current/sql-copy.html)
/// for more details about the supported formats.
///
/// If you don't have any specific needs you should prefer using the more
/// convenient first variant.
///
/// This functionality is postgresql specific.
///
/// # Examples
///
/// ## Via [`CopyToQuery::load()`]
///
/// ```rust
/// # include!("../../../doctest_setup.rs");
/// # use crate::schema::users;
///
/// #[derive(Queryable, Selectable, PartialEq, Debug)]
/// #[diesel(table_name = users)]
/// #[diesel(check_for_backend(diesel::pg::Pg))]
/// struct User {
///     name: String,
/// }
///
/// # fn run_test() -> QueryResult<()> {
/// # let connection = &mut establish_connection();
/// let out = diesel::copy_to(users::table)
///     .load::<User, _>(connection)?
///     .collect::<Result<Vec<_>, _>>()?;
///
/// assert_eq!(out, vec![User{ name: "Sean".into() }, User{ name: "Tess".into() }]);
/// # Ok(())
/// # }
/// # fn main() {
/// #    run_test().unwrap();
/// # }
/// ```
///
/// ## Via [`CopyToQuery::load_raw()`]
///
/// ```rust
/// # include!("../../../doctest_setup.rs");
/// # fn run_test() -> QueryResult<()> {
/// # use crate::schema::users;
/// use diesel::pg::CopyFormat;
/// use std::io::Read;
/// # let connection = &mut establish_connection();
///
/// let mut copy = diesel::copy_to(users::table)
///     .with_format(CopyFormat::Csv)
///     .load_raw(connection)?;
///
/// let mut out = String::new();
/// copy.read_to_string(&mut out).unwrap();
/// assert_eq!(out, "1,Sean\n2,Tess\n");
/// # Ok(())
/// # }
/// # fn main() {
/// #    run_test().unwrap();
/// # }
/// ```
#[cfg(feature = "postgres_backend")]
pub fn copy_to<T>(target: T) -> CopyToQuery<T, NotSet>
where
    T: CopyTarget,
{
    // No option is configured yet; the `with_*` methods switch to explicit options.
    let options = NotSet;
    CopyToQuery { target, options }
}
586}