1use std::borrow::Cow;
2use std::marker::PhantomData;
3
4use byteorder::NetworkEndian;
5use byteorder::WriteBytesExt;
6
7use super::CommonOptions;
8use super::CopyFormat;
9use super::CopyTarget;
10use crate::expression::bound::Bound;
11use crate::insertable::ColumnInsertValue;
12use crate::pg::backend::FailedToLookupTypeError;
13use crate::pg::metadata_lookup::PgMetadataCacheKey;
14use crate::pg::Pg;
15use crate::pg::PgMetadataLookup;
16use crate::query_builder::BatchInsert;
17use crate::query_builder::QueryFragment;
18use crate::query_builder::QueryId;
19use crate::query_builder::ValuesClause;
20use crate::serialize::IsNull;
21use crate::serialize::ToSql;
22use crate::Connection;
23use crate::Insertable;
24use crate::QueryResult;
25use crate::{Column, Table};
26
/// Value of the `HEADER` option of a `COPY ... FROM` statement.
///
/// Rendered as `HEADER 1`, `HEADER 0` or `HEADER MATCH` respectively.
#[derive(Clone, Copy, Debug)]
pub enum CopyHeader {
    /// `Set(true)` renders `HEADER 1`, `Set(false)` renders `HEADER 0`.
    Set(bool),
    /// Renders `HEADER MATCH` (PostgreSQL checks the header line against the
    /// column names; requires a server version that supports it).
    Match,
}
37
38#[derive(Debug, Default)]
39pub struct CopyFromOptions {
40 common: CommonOptions,
41 default: Option<String>,
42 header: Option<CopyHeader>,
43}
44
45impl QueryFragment<Pg> for CopyFromOptions {
46 fn walk_ast<'b>(
47 &'b self,
48 mut pass: crate::query_builder::AstPass<'_, 'b, Pg>,
49 ) -> crate::QueryResult<()> {
50 if self.any_set() {
51 let mut comma = "";
52 pass.push_sql(" WITH (");
53 self.common.walk_ast(pass.reborrow(), &mut comma);
54 if let Some(ref default) = self.default {
55 pass.push_sql(comma);
56 comma = ", ";
57 pass.push_sql("DEFAULT '");
58 pass.push_sql(default);
60 pass.push_sql("'");
61 }
62 if let Some(ref header) = self.header {
63 pass.push_sql(comma);
64 pass.push_sql("HEADER ");
67 match header {
68 CopyHeader::Set(true) => pass.push_sql("1"),
69 CopyHeader::Set(false) => pass.push_sql("0"),
70 CopyHeader::Match => pass.push_sql("MATCH"),
71 }
72 }
73
74 pass.push_sql(")");
75 }
76 Ok(())
77 }
78}
79
80impl CopyFromOptions {
81 fn any_set(&self) -> bool {
82 self.common.any_set() || self.default.is_some() || self.header.is_some()
83 }
84}
85
86#[derive(Debug)]
87pub struct CopyFrom<S, F> {
88 options: CopyFromOptions,
89 copy_callback: F,
90 p: PhantomData<S>,
91}
92
/// The `COPY <target> FROM STDIN ...` statement built from a data source `S`
/// for the table type `T`.
pub(crate) struct InternalCopyFromQuery<S, T> {
    /// The data source; it also provides the copy target and the options.
    pub(crate) target: S,
    p: PhantomData<T>,
}
97
#[cfg(feature = "postgres")]
impl<S, T> InternalCopyFromQuery<S, T> {
    /// Wraps the data source `target` into a `COPY FROM` statement.
    pub(crate) fn new(target: S) -> Self {
        InternalCopyFromQuery {
            p: PhantomData,
            target,
        }
    }
}
107
108impl<S, T> QueryId for InternalCopyFromQuery<S, T>
109where
110 S: CopyFromExpression<T>,
111{
112 const HAS_STATIC_QUERY_ID: bool = false;
113 type QueryId = ();
114}
115
116impl<S, T> QueryFragment<Pg> for InternalCopyFromQuery<S, T>
117where
118 S: CopyFromExpression<T>,
119{
120 fn walk_ast<'b>(
121 &'b self,
122 mut pass: crate::query_builder::AstPass<'_, 'b, Pg>,
123 ) -> crate::QueryResult<()> {
124 pass.unsafe_to_cache_prepared();
125 pass.push_sql("COPY ");
126 self.target.walk_target(pass.reborrow())?;
127 pass.push_sql(" FROM STDIN");
128 self.target.options().walk_ast(pass.reborrow())?;
129 Ok(())
130 }
131}
132
133pub trait CopyFromExpression<T> {
134 type Error: From<crate::result::Error> + std::error::Error;
135
136 fn callback(&mut self, copy: &mut impl std::io::Write) -> Result<(), Self::Error>;
137
138 fn walk_target<'b>(
139 &'b self,
140 pass: crate::query_builder::AstPass<'_, 'b, Pg>,
141 ) -> crate::QueryResult<()>;
142
143 fn options(&self) -> &CopyFromOptions;
144}
145
146impl<S, F, E> CopyFromExpression<S::Table> for CopyFrom<S, F>
147where
148 E: From<crate::result::Error> + std::error::Error,
149 S: CopyTarget,
150 F: Fn(&mut dyn std::io::Write) -> Result<(), E>,
151{
152 type Error = E;
153
154 fn callback(&mut self, copy: &mut impl std::io::Write) -> Result<(), Self::Error> {
155 (self.copy_callback)(copy)
156 }
157
158 fn options(&self) -> &CopyFromOptions {
159 &self.options
160 }
161
162 fn walk_target<'b>(
163 &'b self,
164 pass: crate::query_builder::AstPass<'_, 'b, Pg>,
165 ) -> crate::QueryResult<()> {
166 S::walk_target(pass)
167 }
168}
169
/// Metadata lookup used while serializing `COPY` data: every lookup fails.
struct Dummy;
171
172impl PgMetadataLookup for Dummy {
173 fn lookup_type(&mut self, type_name: &str, schema: Option<&str>) -> crate::pg::PgTypeMetadata {
174 let cache_key = PgMetadataCacheKey::new(
175 schema.map(Into::into).map(Cow::Owned),
176 Cow::Owned(type_name.into()),
177 );
178 crate::pg::PgTypeMetadata(Err(FailedToLookupTypeError::new_internal(cache_key)))
179 }
180}
181
182trait CopyFromInsertableHelper {
183 type Target: CopyTarget;
184 const COLUMN_COUNT: i16;
185
186 fn write_to_buffer(&self, idx: i16, out: &mut Vec<u8>) -> QueryResult<IsNull>;
187}
188
// Implements `CopyFromInsertableHelper` for `ValuesClause`s whose values are a
// tuple of `ColumnInsertValue<Column, Bound<SqlType, Value>>`, once for owned
// `Bound` values and once for borrowed `&'a Bound` values.
//
// Each input group is `$Tuple { ($idx) -> $T, $ST, $TT, ... }`, where (presumably)
// `$Tuple` is the tuple arity, `$T` the SQL type, `$ST` the column and `$TT` the
// Rust value type of field `$idx`. NOTE(review): the groups are generated by
// `diesel_derives::__diesel_for_each_tuple!`; confirm the exact shape there.
macro_rules! impl_copy_from_insertable_helper_for_values_clause {
    ($(
        $Tuple:tt {
            $(($idx:tt) -> $T:ident, $ST:ident, $TT:ident,)+
        }
    )+) => {
        $(
            // Owned values: `ColumnInsertValue<Column, Bound<SqlType, Value>>`.
            impl<T, $($ST,)* $($T,)* $($TT,)*> CopyFromInsertableHelper for ValuesClause<
                ($(ColumnInsertValue<$ST, Bound<$T, $TT>>,)*),
                T>
            where
                T: Table,
                $($ST: Column<Table = T>,)*
                ($($ST,)*): CopyTarget,
                $($TT: ToSql<$T, Pg>,)*
            {
                type Target = ($($ST,)*);

                // `$Tuple` is the tuple arity; the cast cannot truncate as long as the
                // generated arities fit in an `i16`.
                #[allow(clippy::cast_possible_truncation)]
                const COLUMN_COUNT: i16 = $Tuple as i16;

                fn write_to_buffer(&self, idx: i16, out: &mut Vec<u8>) -> QueryResult<IsNull> {
                    use crate::query_builder::ByteWrapper;
                    use crate::serialize::Output;

                    let values = &self.values;
                    match idx {
                        $($idx =>{
                            let item = &values.$idx.expr.item;
                            // Serialize straight into `out`; `Dummy` makes every
                            // type-metadata lookup fail.
                            let is_null = ToSql::<$T, Pg>::to_sql(
                                item,
                                &mut Output::new( ByteWrapper(out), &mut Dummy as _)
                            ).map_err(crate::result::Error::SerializationError)?;
                            return Ok(is_null);
                        })*
                        // The caller iterates `0..COLUMN_COUNT`, so no other index occurs.
                        _ => unreachable!(),
                    }
                }
            }

            // Borrowed values: `ColumnInsertValue<Column, &'a Bound<SqlType, Value>>`.
            // Identical to the impl above apart from the reference.
            impl<'a, T, $($ST,)* $($T,)* $($TT,)*> CopyFromInsertableHelper for ValuesClause<
                ($(ColumnInsertValue<$ST, &'a Bound<$T, $TT>>,)*),
                T>
            where
                T: Table,
                $($ST: Column<Table = T>,)*
                ($($ST,)*): CopyTarget,
                $($TT: ToSql<$T, Pg>,)*
            {
                type Target = ($($ST,)*);

                // `$Tuple` is the tuple arity; the cast cannot truncate as long as the
                // generated arities fit in an `i16`.
                #[allow(clippy::cast_possible_truncation)]
                const COLUMN_COUNT: i16 = $Tuple as i16;

                fn write_to_buffer(&self, idx: i16, out: &mut Vec<u8>) -> QueryResult<IsNull> {
                    use crate::query_builder::ByteWrapper;
                    use crate::serialize::Output;

                    let values = &self.values;
                    match idx {
                        $($idx =>{
                            let item = &values.$idx.expr.item;
                            // Serialize straight into `out`; `Dummy` makes every
                            // type-metadata lookup fail.
                            let is_null = ToSql::<$T, Pg>::to_sql(
                                item,
                                &mut Output::new( ByteWrapper(out), &mut Dummy as _)
                            ).map_err(crate::result::Error::SerializationError)?;
                            return Ok(is_null);
                        })*
                        // The caller iterates `0..COLUMN_COUNT`, so no other index occurs.
                        _ => unreachable!(),
                    }
                }
            }
        )*
    }
}

// Generate the impls above for every supported tuple size.
diesel_derives::__diesel_for_each_tuple!(impl_copy_from_insertable_helper_for_values_clause);
270
/// `COPY FROM` data source built from an insertable value.
///
/// The value is held in an `Option` so the callback can take it by value
/// exactly once.
#[derive(Debug)]
pub struct InsertableWrapper<I>(Option<I>);
273
274impl<I, T, V, QId, const STATIC_QUERY_ID: bool> CopyFromExpression<T> for InsertableWrapper<I>
275where
276 I: Insertable<T, Values = BatchInsert<Vec<V>, T, QId, STATIC_QUERY_ID>>,
277 V: CopyFromInsertableHelper,
278{
279 type Error = crate::result::Error;
280
281 fn callback(&mut self, copy: &mut impl std::io::Write) -> Result<(), Self::Error> {
282 let io_result_mapper = |e| crate::result::Error::DeserializationError(Box::new(e));
283 copy.write_all(&super::COPY_MAGIC_HEADER)
290 .map_err(io_result_mapper)?;
291 copy.write_i32::<NetworkEndian>(0)
292 .map_err(io_result_mapper)?;
293 copy.write_i32::<NetworkEndian>(0)
294 .map_err(io_result_mapper)?;
295 let mut buffer = Vec::<u8>::new();
300 let values = self
301 .0
302 .take()
303 .expect("We only call this callback once")
304 .values();
305 for i in values.values {
306 buffer
308 .write_i16::<NetworkEndian>(V::COLUMN_COUNT)
309 .map_err(io_result_mapper)?;
310 for idx in 0..V::COLUMN_COUNT {
311 buffer
313 .write_i32::<NetworkEndian>(-1)
314 .map_err(io_result_mapper)?;
315 let len_before = buffer.len();
316 let is_null = i.write_to_buffer(idx, &mut buffer)?;
317 if is_null == IsNull::No {
318 let len_after = buffer.len();
320 let diff = (len_after - len_before)
321 .try_into()
322 .map_err(|e| crate::result::Error::SerializationError(Box::new(e)))?;
323 let bytes = i32::to_be_bytes(diff);
324 for (b, t) in bytes.into_iter().zip(&mut buffer[len_before - 4..]) {
325 *t = b;
326 }
327 }
328 }
329 copy.write_all(&buffer).map_err(io_result_mapper)?;
330 buffer.clear();
331 }
332 copy.write_i16::<NetworkEndian>(-1)
334 .map_err(io_result_mapper)?;
335 Ok(())
336 }
337
338 fn options(&self) -> &CopyFromOptions {
339 &CopyFromOptions {
340 common: CommonOptions {
341 format: Some(CopyFormat::Binary),
342 freeze: None,
343 delimiter: None,
344 null: None,
345 quote: None,
346 escape: None,
347 },
348 default: None,
349 header: None,
350 }
351 }
352
353 fn walk_target<'b>(
354 &'b self,
355 pass: crate::query_builder::AstPass<'_, 'b, Pg>,
356 ) -> crate::QueryResult<()> {
357 <V as CopyFromInsertableHelper>::Target::walk_target(pass)
358 }
359}
360
/// A `COPY ... FROM STDIN` query for `table`, with its data source `action`.
///
/// Built with `copy_from`, configured with the builder methods, and run with
/// `ExecuteCopyFromDsl::execute`.
#[cfg(feature = "postgres_backend")]
#[derive(Debug)]
#[must_use = "`COPY FROM` statements are only executed when calling `.execute()`."]
pub struct CopyFromQuery<T, Action> {
    /// The table being copied into.
    table: T,
    /// The data source, or `NotSet` before one is chosen.
    action: Action,
}
378
379impl<T> CopyFromQuery<T, NotSet>
380where
381 T: Table,
382{
383 #[allow(clippy::wrong_self_convention)] pub fn from_raw_data<F, C, E>(self, _target: C, action: F) -> CopyFromQuery<T, CopyFrom<C, F>>
390 where
391 C: CopyTarget<Table = T>,
392 F: Fn(&mut dyn std::io::Write) -> Result<(), E>,
393 {
394 CopyFromQuery {
395 table: self.table,
396 action: CopyFrom {
397 p: PhantomData,
398 options: Default::default(),
399 copy_callback: action,
400 },
401 }
402 }
403
404 #[allow(clippy::wrong_self_convention)] pub fn from_insertable<I>(self, insertable: I) -> CopyFromQuery<T, InsertableWrapper<I>>
415 where
416 InsertableWrapper<I>: CopyFromExpression<T>,
417 {
418 CopyFromQuery {
419 table: self.table,
420 action: InsertableWrapper(Some(insertable)),
421 }
422 }
423}
424
425impl<T, C, F> CopyFromQuery<T, CopyFrom<C, F>> {
426 pub fn with_format(mut self, format: CopyFormat) -> Self {
431 self.action.options.common.format = Some(format);
432 self
433 }
434
435 pub fn with_freeze(mut self, freeze: bool) -> Self {
440 self.action.options.common.freeze = Some(freeze);
441 self
442 }
443
444 pub fn with_delimiter(mut self, delimiter: char) -> Self {
449 self.action.options.common.delimiter = Some(delimiter);
450 self
451 }
452
453 pub fn with_null(mut self, null: impl Into<String>) -> Self {
459 self.action.options.common.null = Some(null.into());
460 self
461 }
462
463 pub fn with_quote(mut self, quote: char) -> Self {
468 self.action.options.common.quote = Some(quote);
469 self
470 }
471
472 pub fn with_escape(mut self, escape: char) -> Self {
477 self.action.options.common.escape = Some(escape);
478 self
479 }
480
481 pub fn with_default(mut self, default: impl Into<String>) -> Self {
490 self.action.options.default = Some(default.into());
491 self
492 }
493
494 pub fn with_header(mut self, header: CopyHeader) -> Self {
499 self.action.options.header = Some(header);
500 self
501 }
502}
503
504pub trait ExecuteCopyFromDsl<C>
509where
510 C: Connection<Backend = Pg>,
511{
512 type Error: std::error::Error;
514
515 fn execute(self, conn: &mut C) -> Result<usize, Self::Error>;
517}
518
#[cfg(feature = "postgres")]
impl<T, A> ExecuteCopyFromDsl<crate::PgConnection> for CopyFromQuery<T, A>
where
    A: CopyFromExpression<T>,
{
    type Error = A::Error;

    /// Hands the data source to the connection's `copy_from` implementation.
    fn execute(self, conn: &mut crate::PgConnection) -> Result<usize, Self::Error> {
        let Self { action, .. } = self;
        conn.copy_from::<A, T>(action)
    }
}
530
#[cfg(feature = "r2d2")]
impl<T, A, C> ExecuteCopyFromDsl<crate::r2d2::PooledConnection<crate::r2d2::ConnectionManager<C>>>
    for CopyFromQuery<T, A>
where
    A: CopyFromExpression<T>,
    C: crate::r2d2::R2D2Connection<Backend = Pg> + 'static,
    Self: ExecuteCopyFromDsl<C>,
{
    type Error = <Self as ExecuteCopyFromDsl<C>>::Error;

    /// Dereferences the pooled connection and delegates to the implementation
    /// for the underlying connection type `C`.
    fn execute(
        self,
        conn: &mut crate::r2d2::PooledConnection<crate::r2d2::ConnectionManager<C>>,
    ) -> Result<usize, Self::Error> {
        let inner: &mut C = &mut **conn;
        <Self as ExecuteCopyFromDsl<C>>::execute(self, inner)
    }
}
548
/// Placeholder data source of a `CopyFromQuery` that has not been given one yet.
#[derive(Clone, Copy, Debug)]
pub struct NotSet;
551
/// Starts building a `COPY ... FROM STDIN` query that copies into `table`.
///
/// The returned query has no data source yet; pick one with `from_raw_data`
/// or `from_insertable`, then run it with `execute`.
#[cfg(feature = "postgres_backend")]
pub fn copy_from<T>(table: T) -> CopyFromQuery<T, NotSet>
where
    T: Table,
{
    CopyFromQuery {
        action: NotSet,
        table,
    }
}