// diesel/pg/connection/raw.rs

1#![allow(clippy::too_many_arguments)]
2#![allow(unsafe_code)] // ffi code
3
4extern crate pq_sys;
5
6use self::pq_sys::*;
7use alloc::ffi::CString;
8use core::ffi as libc;
9use core::ffi::CStr;
10use core::ptr::NonNull;
11use core::{ptr, str};
12
13use crate::result::*;
14
15use super::result::PgResult;
16use crate::pg::PgNotification;
17
18#[allow(missing_debug_implementations, missing_copy_implementations)]
19pub(super) struct RawConnection {
20    pub(super) internal_connection: NonNull<PGconn>,
21}
22
23impl RawConnection {
24    pub(super) fn establish(database_url: &str) -> ConnectionResult<Self> {
25        let connection_string = CString::new(database_url)?;
26        let connection_ptr = unsafe { PQconnectdb(connection_string.as_ptr()) };
27        let connection_status = unsafe { PQstatus(connection_ptr) };
28
29        match connection_status {
30            ConnStatusType::CONNECTION_OK => {
31                let connection_ptr = unsafe { NonNull::new_unchecked(connection_ptr) };
32                Ok(RawConnection {
33                    internal_connection: connection_ptr,
34                })
35            }
36            _ => {
37                let message = last_error_message(connection_ptr);
38
39                if !connection_ptr.is_null() {
40                    // Note that even if the server connection attempt fails (as indicated by PQstatus),
41                    // the application should call PQfinish to free the memory used by the PGconn object.
42                    // https://www.postgresql.org/docs/current/libpq-connect.html
43                    unsafe { PQfinish(connection_ptr) }
44                }
45
46                Err(ConnectionError::BadConnection(message))
47            }
48        }
49    }
50
51    pub(super) fn last_error_message(&self) -> String {
52        last_error_message(self.internal_connection.as_ptr())
53    }
54
55    pub(super) fn set_notice_processor(&self, notice_processor: NoticeProcessor) {
56        unsafe {
57            PQsetNoticeProcessor(
58                self.internal_connection.as_ptr(),
59                Some(notice_processor),
60                ptr::null_mut(),
61            );
62        }
63    }
64
65    pub(super) unsafe fn exec(&self, query: *const libc::c_char) -> QueryResult<RawResult> {
66        let result_ptr = unsafe { PQexec(self.internal_connection.as_ptr(), query) };
67        RawResult::new(result_ptr, self)
68    }
69
70    /// Sends a query and parameters to the server without using the prepare/bind cycle.
71    ///
72    /// This method uses PQsendQueryParams which combines the prepare and bind steps
73    /// and is more compatible with connection poolers like PgBouncer.
74    pub(super) unsafe fn send_query_params(
75        &self,
76        query: *const libc::c_char,
77        param_count: libc::c_int,
78        param_types: *const Oid,
79        param_values: *const *const libc::c_char,
80        param_lengths: *const libc::c_int,
81        param_formats: *const libc::c_int,
82        result_format: libc::c_int,
83    ) -> QueryResult<()> {
84        let res = unsafe {
85            PQsendQueryParams(
86                self.internal_connection.as_ptr(),
87                query,
88                param_count,
89                param_types,
90                param_values,
91                param_lengths,
92                param_formats,
93                result_format,
94            )
95        };
96        if res == 1 {
97            Ok(())
98        } else {
99            Err(Error::DatabaseError(
100                DatabaseErrorKind::UnableToSendCommand,
101                Box::new(self.last_error_message()),
102            ))
103        }
104    }
105
106    pub(super) unsafe fn send_query_prepared(
107        &self,
108        stmt_name: *const libc::c_char,
109        param_count: libc::c_int,
110        param_values: *const *const libc::c_char,
111        param_lengths: *const libc::c_int,
112        param_formats: *const libc::c_int,
113        result_format: libc::c_int,
114    ) -> QueryResult<()> {
115        let res = unsafe {
116            PQsendQueryPrepared(
117                self.internal_connection.as_ptr(),
118                stmt_name,
119                param_count,
120                param_values,
121                param_lengths,
122                param_formats,
123                result_format,
124            )
125        };
126        if res == 1 {
127            Ok(())
128        } else {
129            Err(Error::DatabaseError(
130                DatabaseErrorKind::UnableToSendCommand,
131                Box::new(self.last_error_message()),
132            ))
133        }
134    }
135
136    pub(super) unsafe fn prepare(
137        &self,
138        stmt_name: *const libc::c_char,
139        query: *const libc::c_char,
140        param_count: libc::c_int,
141        param_types: *const Oid,
142    ) -> QueryResult<RawResult> {
143        let ptr = unsafe {
144            PQprepare(
145                self.internal_connection.as_ptr(),
146                stmt_name,
147                query,
148                param_count,
149                param_types,
150            )
151        };
152        RawResult::new(ptr, self)
153    }
154
155    /// This is reasonably inexpensive as it just accesses variables internal to the connection
156    /// that are kept up to date by the `ReadyForQuery` messages from the PG server
157    pub(super) fn transaction_status(&self) -> PgTransactionStatus {
158        unsafe { PQtransactionStatus(self.internal_connection.as_ptr()) }.into()
159    }
160
161    pub(super) fn get_status(&self) -> ConnStatusType {
162        unsafe { PQstatus(self.internal_connection.as_ptr()) }
163    }
164
165    pub(crate) fn get_next_result(&self) -> Result<Option<PgResult>, Error> {
166        let res = unsafe { PQgetResult(self.internal_connection.as_ptr()) };
167        if res.is_null() {
168            Ok(None)
169        } else {
170            let raw = RawResult::new(res, self)?;
171            Ok(Some(PgResult::new(raw, self)?))
172        }
173    }
174
175    pub(crate) fn enable_row_by_row_mode(&self) -> QueryResult<()> {
176        let res = unsafe { PQsetSingleRowMode(self.internal_connection.as_ptr()) };
177        if res == 1 {
178            Ok(())
179        } else {
180            Err(Error::DatabaseError(
181                DatabaseErrorKind::Unknown,
182                Box::new(self.last_error_message()),
183            ))
184        }
185    }
186
187    pub(super) fn put_copy_data(&mut self, buf: &[u8]) -> QueryResult<()> {
188        for c in buf.chunks(i32::MAX as usize) {
189            let res = unsafe {
190                pq_sys::PQputCopyData(
191                    self.internal_connection.as_ptr(),
192                    c.as_ptr() as *const libc::c_char,
193                    c.len()
194                        .try_into()
195                        .map_err(|e| Error::SerializationError(Box::new(e)))?,
196                )
197            };
198            if res != 1 {
199                return Err(Error::DatabaseError(
200                    DatabaseErrorKind::Unknown,
201                    Box::new(self.last_error_message()),
202                ));
203            }
204        }
205        Ok(())
206    }
207
208    pub(crate) fn finish_copy_from(&self, err: Option<String>) -> QueryResult<()> {
209        let error = err.map(CString::new).map(|r| {
210            r.unwrap_or_else(|_| {
211                CString::new("Error message contains a \\0 byte")
212                    .expect("Does not contain a null byte")
213            })
214        });
215        let error = error
216            .as_ref()
217            .map(|l| l.as_ptr())
218            .unwrap_or(core::ptr::null());
219        let ret = unsafe { pq_sys::PQputCopyEnd(self.internal_connection.as_ptr(), error) };
220        if ret == 1 {
221            Ok(())
222        } else {
223            Err(Error::DatabaseError(
224                DatabaseErrorKind::Unknown,
225                Box::new(self.last_error_message()),
226            ))
227        }
228    }
229
230    pub(super) fn pq_notifies(&self) -> Result<Option<PgNotification>, Error> {
231        let conn = self.internal_connection;
232        let ret = unsafe { PQconsumeInput(conn.as_ptr()) };
233        if ret == 0 {
234            return Err(Error::DatabaseError(
235                DatabaseErrorKind::Unknown,
236                Box::new(self.last_error_message()),
237            ));
238        }
239
240        let pgnotify = unsafe { PQnotifies(conn.as_ptr()) };
241        if pgnotify.is_null() {
242            Ok(None)
243        } else {
244            // we use a drop guard here to
245            // make sure that we always free
246            // the provided pointer, even if we
247            // somehow return an error below
248            struct Guard<'a> {
249                value: &'a mut pgNotify,
250            }
251
252            impl Drop for Guard<'_> {
253                fn drop(&mut self) {
254                    unsafe {
255                        // SAFETY: We know that this value is not null here
256                        PQfreemem(self.value as *mut pgNotify as *mut core::ffi::c_void)
257                    };
258                }
259            }
260
261            let pgnotify = unsafe {
262                // SAFETY: We checked for null values above
263                Guard {
264                    value: &mut *pgnotify,
265                }
266            };
267            if pgnotify.value.relname.is_null() {
268                return Err(Error::DeserializationError(
269                    "Received an unexpected null value for `relname` from the notification".into(),
270                ));
271            }
272            if pgnotify.value.extra.is_null() {
273                return Err(Error::DeserializationError(
274                    "Received an unexpected null value for `extra` from the notification".into(),
275                ));
276            }
277
278            let channel = unsafe {
279                // SAFETY: We checked for null values above
280                CStr::from_ptr(pgnotify.value.relname)
281            }
282            .to_str()
283            .map_err(|e| Error::DeserializationError(e.into()))?
284            .to_string();
285            let payload = unsafe {
286                // SAFETY: We checked for null values above
287                CStr::from_ptr(pgnotify.value.extra)
288            }
289            .to_str()
290            .map_err(|e| Error::DeserializationError(e.into()))?
291            .to_string();
292            let ret = PgNotification {
293                process_id: pgnotify.value.be_pid,
294                channel,
295                payload,
296            };
297            Ok(Some(ret))
298        }
299    }
300}
301
302/// Represents the current in-transaction status of the connection
303#[derive(#[automatically_derived]
impl ::core::fmt::Debug for PgTransactionStatus {
    #[inline]
    fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
        ::core::fmt::Formatter::write_str(f,
            match self {
                PgTransactionStatus::Idle => "Idle",
                PgTransactionStatus::Active => "Active",
                PgTransactionStatus::InTransaction => "InTransaction",
                PgTransactionStatus::InError => "InError",
                PgTransactionStatus::Unknown => "Unknown",
            })
    }
}Debug, #[automatically_derived]
impl ::core::cmp::PartialEq for PgTransactionStatus {
    #[inline]
    fn eq(&self, other: &PgTransactionStatus) -> bool {
        let __self_discr = ::core::intrinsics::discriminant_value(self);
        let __arg1_discr = ::core::intrinsics::discriminant_value(other);
        __self_discr == __arg1_discr
    }
}PartialEq, #[automatically_derived]
impl ::core::cmp::Eq for PgTransactionStatus {
    #[inline]
    #[doc(hidden)]
    #[coverage(off)]
    fn assert_receiver_is_total_eq(&self) {}
}Eq, #[automatically_derived]
impl ::core::clone::Clone for PgTransactionStatus {
    #[inline]
    fn clone(&self) -> PgTransactionStatus { *self }
}Clone, #[automatically_derived]
impl ::core::marker::Copy for PgTransactionStatus { }Copy)]
304pub(super) enum PgTransactionStatus {
305    /// Currently idle
306    Idle,
307    /// A command is in progress (sent to the server but not yet completed)
308    Active,
309    /// Idle, in a valid transaction block
310    InTransaction,
311    /// Idle, in a failed transaction block
312    InError,
313    /// Bad connection
314    Unknown,
315}
316
317impl From<PGTransactionStatusType> for PgTransactionStatus {
318    fn from(trans_status_type: PGTransactionStatusType) -> Self {
319        match trans_status_type {
320            PGTransactionStatusType::PQTRANS_IDLE => PgTransactionStatus::Idle,
321            PGTransactionStatusType::PQTRANS_ACTIVE => PgTransactionStatus::Active,
322            PGTransactionStatusType::PQTRANS_INTRANS => PgTransactionStatus::InTransaction,
323            PGTransactionStatusType::PQTRANS_INERROR => PgTransactionStatus::InError,
324            PGTransactionStatusType::PQTRANS_UNKNOWN => PgTransactionStatus::Unknown,
325        }
326    }
327}
328
329pub(super) type NoticeProcessor =
330    extern "C" fn(arg: *mut libc::c_void, message: *const libc::c_char);
331
332impl Drop for RawConnection {
333    fn drop(&mut self) {
334        unsafe { PQfinish(self.internal_connection.as_ptr()) };
335    }
336}
337
338fn last_error_message(conn: *const PGconn) -> String {
339    unsafe {
340        let error_ptr = PQerrorMessage(conn);
341        let bytes = CStr::from_ptr(error_ptr).to_bytes();
342        String::from_utf8_lossy(bytes).to_string()
343    }
344}
345
346/// Internal wrapper around a `*mut PGresult` which is known to be not-null, and
347/// have no aliases.  This wrapper is to ensure that it's always properly
348/// dropped.
349///
350/// If `Unique` is ever stabilized, we should use it here.
351#[allow(missing_debug_implementations)]
352pub(super) struct RawResult(NonNull<PGresult>);
353
354unsafe impl Send for RawResult {}
355unsafe impl Sync for RawResult {}
356
357impl RawResult {
358    #[allow(clippy::new_ret_no_self)]
359    fn new(ptr: *mut PGresult, conn: &RawConnection) -> QueryResult<Self> {
360        NonNull::new(ptr).map(RawResult).ok_or_else(|| {
361            Error::DatabaseError(
362                DatabaseErrorKind::UnableToSendCommand,
363                Box::new(conn.last_error_message()),
364            )
365        })
366    }
367
368    pub(super) fn as_ptr(&self) -> *mut PGresult {
369        self.0.as_ptr()
370    }
371
372    pub(super) fn error_message(&self) -> &str {
373        let ptr = unsafe { PQresultErrorMessage(self.0.as_ptr()) };
374        let cstr = unsafe { CStr::from_ptr(ptr) };
375        cstr.to_str().unwrap_or_default()
376    }
377}
378
379impl Drop for RawResult {
380    fn drop(&mut self) {
381        unsafe { PQclear(self.0.as_ptr()) }
382    }
383}