diesel/pg/connection/
raw.rs

1#![allow(clippy::too_many_arguments)]
2#![allow(unsafe_code)] // ffi code
3
4extern crate pq_sys;
5
6use self::pq_sys::*;
7use std::ffi::{CStr, CString};
8use std::os::raw as libc;
9use std::ptr::NonNull;
10use std::{ptr, str};
11
12use crate::result::*;
13
14use super::result::PgResult;
15use crate::pg::PgNotification;
16
17#[allow(missing_debug_implementations, missing_copy_implementations)]
18pub(super) struct RawConnection {
19    pub(super) internal_connection: NonNull<PGconn>,
20}
21
22impl RawConnection {
23    pub(super) fn establish(database_url: &str) -> ConnectionResult<Self> {
24        let connection_string = CString::new(database_url)?;
25        let connection_ptr = unsafe { PQconnectdb(connection_string.as_ptr()) };
26        let connection_status = unsafe { PQstatus(connection_ptr) };
27
28        match connection_status {
29            ConnStatusType::CONNECTION_OK => {
30                let connection_ptr = unsafe { NonNull::new_unchecked(connection_ptr) };
31                Ok(RawConnection {
32                    internal_connection: connection_ptr,
33                })
34            }
35            _ => {
36                let message = last_error_message(connection_ptr);
37
38                if !connection_ptr.is_null() {
39                    // Note that even if the server connection attempt fails (as indicated by PQstatus),
40                    // the application should call PQfinish to free the memory used by the PGconn object.
41                    // https://www.postgresql.org/docs/current/libpq-connect.html
42                    unsafe { PQfinish(connection_ptr) }
43                }
44
45                Err(ConnectionError::BadConnection(message))
46            }
47        }
48    }
49
50    pub(super) fn last_error_message(&self) -> String {
51        last_error_message(self.internal_connection.as_ptr())
52    }
53
54    pub(super) fn set_notice_processor(&self, notice_processor: NoticeProcessor) {
55        unsafe {
56            PQsetNoticeProcessor(
57                self.internal_connection.as_ptr(),
58                Some(notice_processor),
59                ptr::null_mut(),
60            );
61        }
62    }
63
64    pub(super) unsafe fn exec(&self, query: *const libc::c_char) -> QueryResult<RawResult> {
65        let result_ptr = unsafe { PQexec(self.internal_connection.as_ptr(), query) };
66        RawResult::new(result_ptr, self)
67    }
68
69    /// Sends a query and parameters to the server without using the prepare/bind cycle.
70    ///
71    /// This method uses PQsendQueryParams which combines the prepare and bind steps
72    /// and is more compatible with connection poolers like PgBouncer.
73    pub(super) unsafe fn send_query_params(
74        &self,
75        query: *const libc::c_char,
76        param_count: libc::c_int,
77        param_types: *const Oid,
78        param_values: *const *const libc::c_char,
79        param_lengths: *const libc::c_int,
80        param_formats: *const libc::c_int,
81        result_format: libc::c_int,
82    ) -> QueryResult<()> {
83        let res = unsafe {
84            PQsendQueryParams(
85                self.internal_connection.as_ptr(),
86                query,
87                param_count,
88                param_types,
89                param_values,
90                param_lengths,
91                param_formats,
92                result_format,
93            )
94        };
95        if res == 1 {
96            Ok(())
97        } else {
98            Err(Error::DatabaseError(
99                DatabaseErrorKind::UnableToSendCommand,
100                Box::new(self.last_error_message()),
101            ))
102        }
103    }
104
105    pub(super) unsafe fn send_query_prepared(
106        &self,
107        stmt_name: *const libc::c_char,
108        param_count: libc::c_int,
109        param_values: *const *const libc::c_char,
110        param_lengths: *const libc::c_int,
111        param_formats: *const libc::c_int,
112        result_format: libc::c_int,
113    ) -> QueryResult<()> {
114        let res = unsafe {
115            PQsendQueryPrepared(
116                self.internal_connection.as_ptr(),
117                stmt_name,
118                param_count,
119                param_values,
120                param_lengths,
121                param_formats,
122                result_format,
123            )
124        };
125        if res == 1 {
126            Ok(())
127        } else {
128            Err(Error::DatabaseError(
129                DatabaseErrorKind::UnableToSendCommand,
130                Box::new(self.last_error_message()),
131            ))
132        }
133    }
134
135    pub(super) unsafe fn prepare(
136        &self,
137        stmt_name: *const libc::c_char,
138        query: *const libc::c_char,
139        param_count: libc::c_int,
140        param_types: *const Oid,
141    ) -> QueryResult<RawResult> {
142        let ptr = unsafe {
143            PQprepare(
144                self.internal_connection.as_ptr(),
145                stmt_name,
146                query,
147                param_count,
148                param_types,
149            )
150        };
151        RawResult::new(ptr, self)
152    }
153
154    /// This is reasonably inexpensive as it just accesses variables internal to the connection
155    /// that are kept up to date by the `ReadyForQuery` messages from the PG server
156    pub(super) fn transaction_status(&self) -> PgTransactionStatus {
157        unsafe { PQtransactionStatus(self.internal_connection.as_ptr()) }.into()
158    }
159
160    pub(super) fn get_status(&self) -> ConnStatusType {
161        unsafe { PQstatus(self.internal_connection.as_ptr()) }
162    }
163
164    pub(crate) fn get_next_result(&self) -> Result<Option<PgResult>, Error> {
165        let res = unsafe { PQgetResult(self.internal_connection.as_ptr()) };
166        if res.is_null() {
167            Ok(None)
168        } else {
169            let raw = RawResult::new(res, self)?;
170            Ok(Some(PgResult::new(raw, self)?))
171        }
172    }
173
174    pub(crate) fn enable_row_by_row_mode(&self) -> QueryResult<()> {
175        let res = unsafe { PQsetSingleRowMode(self.internal_connection.as_ptr()) };
176        if res == 1 {
177            Ok(())
178        } else {
179            Err(Error::DatabaseError(
180                DatabaseErrorKind::Unknown,
181                Box::new(self.last_error_message()),
182            ))
183        }
184    }
185
186    pub(super) fn put_copy_data(&mut self, buf: &[u8]) -> QueryResult<()> {
187        for c in buf.chunks(i32::MAX as usize) {
188            let res = unsafe {
189                pq_sys::PQputCopyData(
190                    self.internal_connection.as_ptr(),
191                    c.as_ptr() as *const libc::c_char,
192                    c.len()
193                        .try_into()
194                        .map_err(|e| Error::SerializationError(Box::new(e)))?,
195                )
196            };
197            if res != 1 {
198                return Err(Error::DatabaseError(
199                    DatabaseErrorKind::Unknown,
200                    Box::new(self.last_error_message()),
201                ));
202            }
203        }
204        Ok(())
205    }
206
207    pub(crate) fn finish_copy_from(&self, err: Option<String>) -> QueryResult<()> {
208        let error = err.map(CString::new).map(|r| {
209            r.unwrap_or_else(|_| {
210                CString::new("Error message contains a \\0 byte")
211                    .expect("Does not contain a null byte")
212            })
213        });
214        let error = error
215            .as_ref()
216            .map(|l| l.as_ptr())
217            .unwrap_or(std::ptr::null());
218        let ret = unsafe { pq_sys::PQputCopyEnd(self.internal_connection.as_ptr(), error) };
219        if ret == 1 {
220            Ok(())
221        } else {
222            Err(Error::DatabaseError(
223                DatabaseErrorKind::Unknown,
224                Box::new(self.last_error_message()),
225            ))
226        }
227    }
228
229    pub(super) fn pq_notifies(&self) -> Result<Option<PgNotification>, Error> {
230        let conn = self.internal_connection;
231        let ret = unsafe { PQconsumeInput(conn.as_ptr()) };
232        if ret == 0 {
233            return Err(Error::DatabaseError(
234                DatabaseErrorKind::Unknown,
235                Box::new(self.last_error_message()),
236            ));
237        }
238
239        let pgnotify = unsafe { PQnotifies(conn.as_ptr()) };
240        if pgnotify.is_null() {
241            Ok(None)
242        } else {
243            // we use a drop guard here to
244            // make sure that we always free
245            // the provided pointer, even if we
246            // somehow return an error below
247            struct Guard<'a> {
248                value: &'a mut pgNotify,
249            }
250
251            impl Drop for Guard<'_> {
252                fn drop(&mut self) {
253                    unsafe {
254                        // SAFETY: We know that this value is not null here
255                        PQfreemem(self.value as *mut pgNotify as *mut std::ffi::c_void)
256                    };
257                }
258            }
259
260            let pgnotify = unsafe {
261                // SAFETY: We checked for null values above
262                Guard {
263                    value: &mut *pgnotify,
264                }
265            };
266            if pgnotify.value.relname.is_null() {
267                return Err(Error::DeserializationError(
268                    "Received an unexpected null value for `relname` from the notification".into(),
269                ));
270            }
271            if pgnotify.value.extra.is_null() {
272                return Err(Error::DeserializationError(
273                    "Received an unexpected null value for `extra` from the notification".into(),
274                ));
275            }
276
277            let channel = unsafe {
278                // SAFETY: We checked for null values above
279                CStr::from_ptr(pgnotify.value.relname)
280            }
281            .to_str()
282            .map_err(|e| Error::DeserializationError(e.into()))?
283            .to_string();
284            let payload = unsafe {
285                // SAFETY: We checked for null values above
286                CStr::from_ptr(pgnotify.value.extra)
287            }
288            .to_str()
289            .map_err(|e| Error::DeserializationError(e.into()))?
290            .to_string();
291            let ret = PgNotification {
292                process_id: pgnotify.value.be_pid,
293                channel,
294                payload,
295            };
296            Ok(Some(ret))
297        }
298    }
299}
300
301/// Represents the current in-transaction status of the connection
302#[derive(Debug, PartialEq, Eq, Clone, Copy)]
303pub(super) enum PgTransactionStatus {
304    /// Currently idle
305    Idle,
306    /// A command is in progress (sent to the server but not yet completed)
307    Active,
308    /// Idle, in a valid transaction block
309    InTransaction,
310    /// Idle, in a failed transaction block
311    InError,
312    /// Bad connection
313    Unknown,
314}
315
316impl From<PGTransactionStatusType> for PgTransactionStatus {
317    fn from(trans_status_type: PGTransactionStatusType) -> Self {
318        match trans_status_type {
319            PGTransactionStatusType::PQTRANS_IDLE => PgTransactionStatus::Idle,
320            PGTransactionStatusType::PQTRANS_ACTIVE => PgTransactionStatus::Active,
321            PGTransactionStatusType::PQTRANS_INTRANS => PgTransactionStatus::InTransaction,
322            PGTransactionStatusType::PQTRANS_INERROR => PgTransactionStatus::InError,
323            PGTransactionStatusType::PQTRANS_UNKNOWN => PgTransactionStatus::Unknown,
324        }
325    }
326}
327
328pub(super) type NoticeProcessor =
329    extern "C" fn(arg: *mut libc::c_void, message: *const libc::c_char);
330
331impl Drop for RawConnection {
332    fn drop(&mut self) {
333        unsafe { PQfinish(self.internal_connection.as_ptr()) };
334    }
335}
336
337fn last_error_message(conn: *const PGconn) -> String {
338    unsafe {
339        let error_ptr = PQerrorMessage(conn);
340        let bytes = CStr::from_ptr(error_ptr).to_bytes();
341        String::from_utf8_lossy(bytes).to_string()
342    }
343}
344
345/// Internal wrapper around a `*mut PGresult` which is known to be not-null, and
346/// have no aliases.  This wrapper is to ensure that it's always properly
347/// dropped.
348///
349/// If `Unique` is ever stabilized, we should use it here.
350#[allow(missing_debug_implementations)]
351pub(super) struct RawResult(NonNull<PGresult>);
352
353unsafe impl Send for RawResult {}
354unsafe impl Sync for RawResult {}
355
356impl RawResult {
357    #[allow(clippy::new_ret_no_self)]
358    fn new(ptr: *mut PGresult, conn: &RawConnection) -> QueryResult<Self> {
359        NonNull::new(ptr).map(RawResult).ok_or_else(|| {
360            Error::DatabaseError(
361                DatabaseErrorKind::UnableToSendCommand,
362                Box::new(conn.last_error_message()),
363            )
364        })
365    }
366
367    pub(super) fn as_ptr(&self) -> *mut PGresult {
368        self.0.as_ptr()
369    }
370
371    pub(super) fn error_message(&self) -> &str {
372        let ptr = unsafe { PQresultErrorMessage(self.0.as_ptr()) };
373        let cstr = unsafe { CStr::from_ptr(ptr) };
374        cstr.to_str().unwrap_or_default()
375    }
376}
377
378impl Drop for RawResult {
379    fn drop(&mut self) {
380        unsafe { PQclear(self.0.as_ptr()) }
381    }
382}