diesel/pg/query_builder/copy/
copy_to.rs1use std::io::BufRead;
2use std::marker::PhantomData;
3
4use super::CommonOptions;
5use super::CopyFormat;
6use super::CopyTarget;
7use crate::deserialize::FromSqlRow;
8#[cfg(feature = "postgres")]
9use crate::pg::value::TypeOidLookup;
10use crate::pg::Pg;
11use crate::query_builder::QueryFragment;
12use crate::query_builder::QueryId;
13use crate::row::Row;
14#[cfg(feature = "postgres")]
15use crate::row::{self, Field, PartialRow, RowIndex, RowSealed};
16use crate::AppearsOnTable;
17use crate::Connection;
18use crate::Expression;
19use crate::QueryResult;
20use crate::Selectable;
21
22#[derive(Default, Debug)]
23pub struct CopyToOptions {
24 common: CommonOptions,
25 header: Option<bool>,
26}
27
28impl CopyToOptions {
29 fn any_set(&self) -> bool {
30 self.common.any_set() || self.header.is_some()
31 }
32}
33
34impl QueryFragment<Pg> for CopyToOptions {
35 fn walk_ast<'b>(
36 &'b self,
37 mut pass: crate::query_builder::AstPass<'_, 'b, Pg>,
38 ) -> crate::QueryResult<()> {
39 if self.any_set() {
40 let mut comma = "";
41 pass.push_sql(" WITH (");
42 self.common.walk_ast(pass.reborrow(), &mut comma);
43 if let Some(header_is_set) = self.header {
44 pass.push_sql(comma);
45 pass.push_sql("HEADER ");
48 pass.push_sql(if header_is_set { "1" } else { "0" });
49 }
50
51 pass.push_sql(")");
52 }
53 Ok(())
54 }
55}
56
57#[derive(Debug)]
58pub struct CopyToCommand<S> {
59 options: CopyToOptions,
60 p: PhantomData<S>,
61}
62
63impl<S> QueryId for CopyToCommand<S>
64where
65 S: CopyTarget,
66{
67 type QueryId = ();
68
69 const HAS_STATIC_QUERY_ID: bool = false;
70}
71
72impl<S> QueryFragment<Pg> for CopyToCommand<S>
73where
74 S: CopyTarget,
75{
76 fn walk_ast<'b>(
77 &'b self,
78 mut pass: crate::query_builder::AstPass<'_, 'b, Pg>,
79 ) -> crate::QueryResult<()> {
80 pass.unsafe_to_cache_prepared();
81 pass.push_sql("COPY ");
82 S::walk_target(pass.reborrow())?;
83 pass.push_sql(" TO STDOUT");
84 self.options.walk_ast(pass.reborrow())?;
85 Ok(())
86 }
87}
88
/// Marker for the options parameter of a [`CopyToQuery`] on which no option
/// has been configured yet.
#[derive(Clone, Copy, Debug)]
pub struct NotSet;
91
92pub trait CopyToMarker: Sized {
93 fn setup_options<T>(q: CopyToQuery<T, Self>) -> CopyToQuery<T, CopyToOptions>;
94}
95
96impl CopyToMarker for NotSet {
97 fn setup_options<T>(q: CopyToQuery<T, Self>) -> CopyToQuery<T, CopyToOptions> {
98 CopyToQuery {
99 target: q.target,
100 options: CopyToOptions::default(),
101 }
102 }
103}
104impl CopyToMarker for CopyToOptions {
105 fn setup_options<T>(q: CopyToQuery<T, Self>) -> CopyToQuery<T, CopyToOptions> {
106 q
107 }
108}
/// A `COPY ... TO` query, as returned by `copy_to`.
///
/// `O` tracks the option state: it is [`NotSet`] until one of the `with_*`
/// builder methods is called and [`CopyToOptions`] afterwards.
#[derive(Debug)]
#[must_use = "`COPY TO` statements are only executed when calling `.load()` or `.load_raw()`."]
#[cfg(feature = "postgres_backend")]
pub struct CopyToQuery<T, O> {
    /// What is copied (expected to implement [`CopyTarget`]).
    target: T,
    /// Option state, see the type-level documentation.
    options: O,
}
128
/// A single row decoded from binary `COPY` output.
///
/// `buffers` holds one raw payload per field (`None` for SQL `NULL`), and
/// `result` is used to look up each column's type.
#[cfg(feature = "postgres")]
struct CopyRow<'a> {
    buffers: Vec<Option<&'a [u8]>>,
    result: &'a crate::pg::connection::PgResult,
}
134
/// One field of a [`CopyRow`].
#[cfg(feature = "postgres")]
struct CopyField<'a> {
    /// Raw payload of the field; `None` for SQL `NULL`.
    field: &'a Option<&'a [u8]>,
    /// Result used to look up the column's type OID.
    result: &'a crate::pg::connection::PgResult,
    /// Position of this field within its row.
    col_idx: usize,
}
141
#[cfg(feature = "postgres")]
impl<'f> Field<'f, Pg> for CopyField<'f> {
    /// Binary `COPY` output carries no column names.
    fn field_name(&self) -> Option<&str> {
        None
    }

    /// Wraps the raw bytes in a `PgValue`; SQL `NULL` yields `None`.
    fn value(&self) -> Option<<Pg as crate::backend::Backend>::RawValue<'_>> {
        self.field
            .map(|bytes| crate::pg::PgValue::new_internal(bytes, self))
    }
}
153
#[cfg(feature = "postgres")]
impl TypeOidLookup for CopyField<'_> {
    /// Returns the type OID the result reports for this field's column.
    fn lookup(&self) -> std::num::NonZeroU32 {
        let column = self.col_idx;
        self.result.column_type(column)
    }
}
160
// Allows `CopyRow` to implement the sealed `Row` trait.
#[cfg(feature = "postgres")]
impl RowSealed for CopyRow<'_> {}
163
#[cfg(feature = "postgres")]
impl RowIndex<usize> for CopyRow<'_> {
    /// Accepts a numeric index only if it addresses an existing field.
    fn idx(&self, idx: usize) -> Option<usize> {
        (idx < self.field_count()).then_some(idx)
    }
}
174
#[cfg(feature = "postgres")]
impl<'a> RowIndex<&'a str> for CopyRow<'_> {
    /// Name-based lookup is unsupported: fields of a `COPY` row have no names,
    /// so this always returns `None`.
    fn idx(&self, _field_name: &'a str) -> Option<usize> {
        None
    }
}
181
#[cfg(feature = "postgres")]
impl<'a> Row<'a, Pg> for CopyRow<'_> {
    type Field<'f>
        = CopyField<'f>
    where
        'a: 'f,
        Self: 'f;

    type InnerPartialRow = Self;

    /// Number of fields decoded for this row.
    fn field_count(&self) -> usize {
        self.buffers.len()
    }

    /// Returns the field addressed by `idx`, or `None` if `idx` does not
    /// resolve to an existing field.
    fn get<'b, I>(&'b self, idx: I) -> Option<Self::Field<'b>>
    where
        'a: 'b,
        Self: RowIndex<I>,
    {
        let col_idx = self.idx(idx)?;
        self.buffers.get(col_idx).map(|field| CopyField {
            field,
            result: self.result,
            col_idx,
        })
    }

    /// Exposes the fields within `range` as a partial row.
    fn partial_row(
        &self,
        range: std::ops::Range<usize>,
    ) -> row::PartialRow<'_, Self::InnerPartialRow> {
        PartialRow::new(self, range)
    }
}
217
218pub trait ExecuteCopyToConnection: Connection<Backend = Pg> {
219 type CopyToBuffer<'a>: BufRead;
220
221 fn make_row<'a, 'b>(
222 out: &'a Self::CopyToBuffer<'_>,
223 buffers: Vec<Option<&'a [u8]>>,
224 ) -> impl Row<'b, Pg> + 'a;
225
226 fn get_buffer<'a>(out: &'a Self::CopyToBuffer<'_>) -> &'a [u8];
227
228 fn execute<T>(&mut self, command: CopyToCommand<T>) -> QueryResult<Self::CopyToBuffer<'_>>
229 where
230 T: CopyTarget;
231}
232
#[cfg(feature = "postgres")]
impl ExecuteCopyToConnection for crate::PgConnection {
    type CopyToBuffer<'a> = crate::pg::connection::copy::CopyToBuffer<'a>;

    fn execute<T>(&mut self, command: CopyToCommand<T>) -> QueryResult<Self::CopyToBuffer<'_>>
    where
        T: CopyTarget,
    {
        // Delegates to `PgConnection::copy_to`.
        self.copy_to(command)
    }

    fn get_buffer<'a>(out: &'a Self::CopyToBuffer<'_>) -> &'a [u8] {
        out.data_slice()
    }

    fn make_row<'a, 'b>(
        out: &'a Self::CopyToBuffer<'_>,
        buffers: Vec<Option<&'a [u8]>>,
    ) -> impl Row<'b, Pg> + 'a {
        // Column types are looked up in the result held by the buffer.
        let result = out.get_result();
        CopyRow { buffers, result }
    }
}
258
// Pooled r2d2 connections forward every operation to the connection they wrap.
#[cfg(feature = "r2d2")]
impl<C> ExecuteCopyToConnection for crate::r2d2::PooledConnection<crate::r2d2::ConnectionManager<C>>
where
    C: ExecuteCopyToConnection + crate::r2d2::R2D2Connection + 'static,
{
    type CopyToBuffer<'a> = C::CopyToBuffer<'a>;

    fn execute<T>(&mut self, command: CopyToCommand<T>) -> QueryResult<Self::CopyToBuffer<'_>>
    where
        T: CopyTarget,
    {
        let inner: &mut C = &mut **self;
        <C as ExecuteCopyToConnection>::execute(inner, command)
    }

    fn get_buffer<'a>(out: &'a Self::CopyToBuffer<'_>) -> &'a [u8] {
        <C as ExecuteCopyToConnection>::get_buffer(out)
    }

    fn make_row<'a, 'b>(
        out: &'a Self::CopyToBuffer<'_>,
        buffers: Vec<Option<&'a [u8]>>,
    ) -> impl Row<'b, Pg> + 'a {
        <C as ExecuteCopyToConnection>::make_row(out, buffers)
    }
}
284
285impl<T> CopyToQuery<T, NotSet>
286where
287 T: CopyTarget,
288{
289 pub fn load<U, C>(self, conn: &mut C) -> QueryResult<impl Iterator<Item = QueryResult<U>> + '_>
296 where
297 U: FromSqlRow<<U::SelectExpression as Expression>::SqlType, Pg> + Selectable<Pg>,
298 U::SelectExpression: AppearsOnTable<T::Table> + CopyTarget<Table = T::Table>,
299 C: ExecuteCopyToConnection,
300 {
301 let io_result_mapper = |e| crate::result::Error::DeserializationError(Box::new(e));
302
303 let command = CopyToCommand {
304 p: PhantomData::<U::SelectExpression>,
305 options: CopyToOptions {
306 header: None,
307 common: CommonOptions {
308 format: Some(CopyFormat::Binary),
309 ..Default::default()
310 },
311 },
312 };
313 let mut out = ExecuteCopyToConnection::execute(conn, command)?;
319 let buffer = out.fill_buf().map_err(io_result_mapper)?;
320 if buffer[..super::COPY_MAGIC_HEADER.len()] != super::COPY_MAGIC_HEADER {
321 return Err(crate::result::Error::DeserializationError(
322 "Unexpected protocol header".into(),
323 ));
324 }
325 let flags_backward_incompatible = i16::from_be_bytes(
327 (&buffer[super::COPY_MAGIC_HEADER.len() + 2..super::COPY_MAGIC_HEADER.len() + 4])
328 .try_into()
329 .expect("Exactly 2 byte"),
330 );
331 if flags_backward_incompatible != 0 {
332 return Err(crate::result::Error::DeserializationError(
333 format!("Unexpected flag value: {flags_backward_incompatible:x}").into(),
334 ));
335 }
336 let header_size = usize::try_from(i32::from_be_bytes(
337 (&buffer[super::COPY_MAGIC_HEADER.len() + 4..super::COPY_MAGIC_HEADER.len() + 8])
338 .try_into()
339 .expect("Exactly 4 byte"),
340 ))
341 .map_err(|e| crate::result::Error::DeserializationError(Box::new(e)))?;
342 out.consume(super::COPY_MAGIC_HEADER.len() + 8 + header_size);
343 let mut len = None;
344 Ok(std::iter::from_fn(move || {
345 if let Some(len) = len {
346 out.consume(len);
347 if let Err(e) = out.fill_buf().map_err(io_result_mapper) {
348 return Some(Err(e));
349 }
350 }
351 let buffer = C::get_buffer(&out);
352 len = Some(buffer.len());
353 let tuple_count =
354 i16::from_be_bytes((&buffer[..2]).try_into().expect("Exactly 2 bytes"));
355 if tuple_count > 0 {
356 let tuple_count = match usize::try_from(tuple_count) {
357 Ok(o) => o,
358 Err(e) => {
359 return Some(Err(crate::result::Error::DeserializationError(Box::new(e))))
360 }
361 };
362 let mut buffers = Vec::with_capacity(tuple_count);
363 let mut offset = 2;
364 for _t in 0..tuple_count {
365 let data_size = i32::from_be_bytes(
366 (&buffer[offset..offset + 4])
367 .try_into()
368 .expect("Exactly 4 bytes"),
369 );
370
371 if data_size < 0 {
372 buffers.push(None);
373 } else {
374 match usize::try_from(data_size) {
375 Ok(data_size) => {
376 buffers.push(Some(&buffer[offset + 4..offset + 4 + data_size]));
377 offset = offset + 4 + data_size;
378 }
379 Err(e) => {
380 return Some(Err(crate::result::Error::DeserializationError(
381 Box::new(e),
382 )));
383 }
384 }
385 }
386 }
387
388 let row = C::make_row(&out, buffers);
389 Some(U::build_from_row(&row).map_err(crate::result::Error::DeserializationError))
390 } else {
391 None
392 }
393 }))
394 }
395}
396
397impl<T, O> CopyToQuery<T, O>
398where
399 O: CopyToMarker,
400 T: CopyTarget,
401{
402 pub fn load_raw<C>(self, conn: &mut C) -> QueryResult<impl BufRead + '_>
408 where
409 C: ExecuteCopyToConnection,
410 {
411 let q = O::setup_options(self);
412 let command = CopyToCommand {
413 p: PhantomData::<T>,
414 options: q.options,
415 };
416 ExecuteCopyToConnection::execute(conn, command)
417 }
418
419 pub fn with_format(self, format: CopyFormat) -> CopyToQuery<T, CopyToOptions> {
424 let mut out = O::setup_options(self);
425 out.options.common.format = Some(format);
426 out
427 }
428
429 pub fn with_freeze(self, freeze: bool) -> CopyToQuery<T, CopyToOptions> {
434 let mut out = O::setup_options(self);
435 out.options.common.freeze = Some(freeze);
436 out
437 }
438
439 pub fn with_delimiter(self, delimiter: char) -> CopyToQuery<T, CopyToOptions> {
444 let mut out = O::setup_options(self);
445 out.options.common.delimiter = Some(delimiter);
446 out
447 }
448
449 pub fn with_null(self, null: impl Into<String>) -> CopyToQuery<T, CopyToOptions> {
455 let mut out = O::setup_options(self);
456 out.options.common.null = Some(null.into());
457 out
458 }
459
460 pub fn with_quote(self, quote: char) -> CopyToQuery<T, CopyToOptions> {
465 let mut out = O::setup_options(self);
466 out.options.common.quote = Some(quote);
467 out
468 }
469
470 pub fn with_escape(self, escape: char) -> CopyToQuery<T, CopyToOptions> {
475 let mut out = O::setup_options(self);
476 out.options.common.escape = Some(escape);
477 out
478 }
479
480 pub fn with_header(self, set: bool) -> CopyToQuery<T, CopyToOptions> {
485 let mut out = O::setup_options(self);
486 out.options.header = Some(set);
487 out
488 }
489}
490
/// Creates a `COPY ... TO` query for `target`.
///
/// Options can be configured through the `with_*` builder methods. The query
/// runs once `load` or `load_raw` is called on the returned value.
#[cfg(feature = "postgres_backend")]
pub fn copy_to<T>(target: T) -> CopyToQuery<T, NotSet>
where
    T: CopyTarget,
{
    CopyToQuery {
        options: NotSet,
        target,
    }
}