diesel/pg/connection/
raw.rs

1#![allow(clippy::too_many_arguments)]
2#![allow(unsafe_code)] // ffi code
3
4extern crate pq_sys;
5
6use self::pq_sys::*;
7use std::ffi::{CStr, CString};
8use std::os::raw as libc;
9use std::ptr::NonNull;
10use std::{ptr, str};
11
12use crate::result::*;
13
14use super::result::PgResult;
15use crate::pg::PgNotification;
16
/// Thin owner of a libpq connection handle.
///
/// The handle is stored as a `NonNull<PGconn>`; the `Drop` implementation in
/// this file passes it to `PQfinish`.
#[allow(missing_debug_implementations, missing_copy_implementations)]
pub(super) struct RawConnection {
    /// The underlying libpq connection pointer, never null.
    pub(super) internal_connection: NonNull<PGconn>,
}
21
22impl RawConnection {
23    pub(super) fn establish(database_url: &str) -> ConnectionResult<Self> {
24        let connection_string = CString::new(database_url)?;
25        let connection_ptr = unsafe { PQconnectdb(connection_string.as_ptr()) };
26        let connection_status = unsafe { PQstatus(connection_ptr) };
27
28        match connection_status {
29            ConnStatusType::CONNECTION_OK => {
30                let connection_ptr = unsafe { NonNull::new_unchecked(connection_ptr) };
31                Ok(RawConnection {
32                    internal_connection: connection_ptr,
33                })
34            }
35            _ => {
36                let message = last_error_message(connection_ptr);
37
38                if !connection_ptr.is_null() {
39                    // Note that even if the server connection attempt fails (as indicated by PQstatus),
40                    // the application should call PQfinish to free the memory used by the PGconn object.
41                    // https://www.postgresql.org/docs/current/libpq-connect.html
42                    unsafe { PQfinish(connection_ptr) }
43                }
44
45                Err(ConnectionError::BadConnection(message))
46            }
47        }
48    }
49
50    pub(super) fn last_error_message(&self) -> String {
51        last_error_message(self.internal_connection.as_ptr())
52    }
53
54    pub(super) fn set_notice_processor(&self, notice_processor: NoticeProcessor) {
55        unsafe {
56            PQsetNoticeProcessor(
57                self.internal_connection.as_ptr(),
58                Some(notice_processor),
59                ptr::null_mut(),
60            );
61        }
62    }
63
64    pub(super) unsafe fn exec(&self, query: *const libc::c_char) -> QueryResult<RawResult> {
65        RawResult::new(PQexec(self.internal_connection.as_ptr(), query), self)
66    }
67
68    /// Sends a query and parameters to the server without using the prepare/bind cycle.
69    ///
70    /// This method uses PQsendQueryParams which combines the prepare and bind steps
71    /// and is more compatible with connection poolers like PgBouncer.
72    pub(super) unsafe fn send_query_params(
73        &self,
74        query: *const libc::c_char,
75        param_count: libc::c_int,
76        param_types: *const Oid,
77        param_values: *const *const libc::c_char,
78        param_lengths: *const libc::c_int,
79        param_formats: *const libc::c_int,
80        result_format: libc::c_int,
81    ) -> QueryResult<()> {
82        let res = PQsendQueryParams(
83            self.internal_connection.as_ptr(),
84            query,
85            param_count,
86            param_types,
87            param_values,
88            param_lengths,
89            param_formats,
90            result_format,
91        );
92        if res == 1 {
93            Ok(())
94        } else {
95            Err(Error::DatabaseError(
96                DatabaseErrorKind::UnableToSendCommand,
97                Box::new(self.last_error_message()),
98            ))
99        }
100    }
101
102    pub(super) unsafe fn send_query_prepared(
103        &self,
104        stmt_name: *const libc::c_char,
105        param_count: libc::c_int,
106        param_values: *const *const libc::c_char,
107        param_lengths: *const libc::c_int,
108        param_formats: *const libc::c_int,
109        result_format: libc::c_int,
110    ) -> QueryResult<()> {
111        let res = PQsendQueryPrepared(
112            self.internal_connection.as_ptr(),
113            stmt_name,
114            param_count,
115            param_values,
116            param_lengths,
117            param_formats,
118            result_format,
119        );
120        if res == 1 {
121            Ok(())
122        } else {
123            Err(Error::DatabaseError(
124                DatabaseErrorKind::UnableToSendCommand,
125                Box::new(self.last_error_message()),
126            ))
127        }
128    }
129
130    pub(super) unsafe fn prepare(
131        &self,
132        stmt_name: *const libc::c_char,
133        query: *const libc::c_char,
134        param_count: libc::c_int,
135        param_types: *const Oid,
136    ) -> QueryResult<RawResult> {
137        let ptr = PQprepare(
138            self.internal_connection.as_ptr(),
139            stmt_name,
140            query,
141            param_count,
142            param_types,
143        );
144        RawResult::new(ptr, self)
145    }
146
147    /// This is reasonably inexpensive as it just accesses variables internal to the connection
148    /// that are kept up to date by the `ReadyForQuery` messages from the PG server
149    pub(super) fn transaction_status(&self) -> PgTransactionStatus {
150        unsafe { PQtransactionStatus(self.internal_connection.as_ptr()) }.into()
151    }
152
153    pub(super) fn get_status(&self) -> ConnStatusType {
154        unsafe { PQstatus(self.internal_connection.as_ptr()) }
155    }
156
157    pub(crate) fn get_next_result(&self) -> Result<Option<PgResult>, Error> {
158        let res = unsafe { PQgetResult(self.internal_connection.as_ptr()) };
159        if res.is_null() {
160            Ok(None)
161        } else {
162            let raw = RawResult::new(res, self)?;
163            Ok(Some(PgResult::new(raw, self)?))
164        }
165    }
166
167    pub(crate) fn enable_row_by_row_mode(&self) -> QueryResult<()> {
168        let res = unsafe { PQsetSingleRowMode(self.internal_connection.as_ptr()) };
169        if res == 1 {
170            Ok(())
171        } else {
172            Err(Error::DatabaseError(
173                DatabaseErrorKind::Unknown,
174                Box::new(self.last_error_message()),
175            ))
176        }
177    }
178
179    pub(super) fn put_copy_data(&mut self, buf: &[u8]) -> QueryResult<()> {
180        for c in buf.chunks(i32::MAX as usize) {
181            let res = unsafe {
182                pq_sys::PQputCopyData(
183                    self.internal_connection.as_ptr(),
184                    c.as_ptr() as *const libc::c_char,
185                    c.len()
186                        .try_into()
187                        .map_err(|e| Error::SerializationError(Box::new(e)))?,
188                )
189            };
190            if res != 1 {
191                return Err(Error::DatabaseError(
192                    DatabaseErrorKind::Unknown,
193                    Box::new(self.last_error_message()),
194                ));
195            }
196        }
197        Ok(())
198    }
199
200    pub(crate) fn finish_copy_from(&self, err: Option<String>) -> QueryResult<()> {
201        let error = err.map(CString::new).map(|r| {
202            r.unwrap_or_else(|_| {
203                CString::new("Error message contains a \\0 byte")
204                    .expect("Does not contain a null byte")
205            })
206        });
207        let error = error
208            .as_ref()
209            .map(|l| l.as_ptr())
210            .unwrap_or(std::ptr::null());
211        let ret = unsafe { pq_sys::PQputCopyEnd(self.internal_connection.as_ptr(), error) };
212        if ret == 1 {
213            Ok(())
214        } else {
215            Err(Error::DatabaseError(
216                DatabaseErrorKind::Unknown,
217                Box::new(self.last_error_message()),
218            ))
219        }
220    }
221
222    pub(super) fn pq_notifies(&self) -> Result<Option<PgNotification>, Error> {
223        let conn = self.internal_connection;
224        let ret = unsafe { PQconsumeInput(conn.as_ptr()) };
225        if ret == 0 {
226            return Err(Error::DatabaseError(
227                DatabaseErrorKind::Unknown,
228                Box::new(self.last_error_message()),
229            ));
230        }
231
232        let pgnotify = unsafe { PQnotifies(conn.as_ptr()) };
233        if pgnotify.is_null() {
234            Ok(None)
235        } else {
236            // we use a drop guard here to
237            // make sure that we always free
238            // the provided pointer, even if we
239            // somehow return an error below
240            struct Guard<'a> {
241                value: &'a mut pgNotify,
242            }
243
244            impl Drop for Guard<'_> {
245                fn drop(&mut self) {
246                    unsafe {
247                        // SAFETY: We know that this value is not null here
248                        PQfreemem(self.value as *mut pgNotify as *mut std::ffi::c_void)
249                    };
250                }
251            }
252
253            let pgnotify = unsafe {
254                // SAFETY: We checked for null values above
255                Guard {
256                    value: &mut *pgnotify,
257                }
258            };
259            if pgnotify.value.relname.is_null() {
260                return Err(Error::DeserializationError(
261                    "Received an unexpected null value for `relname` from the notification".into(),
262                ));
263            }
264            if pgnotify.value.extra.is_null() {
265                return Err(Error::DeserializationError(
266                    "Received an unexpected null value for `extra` from the notification".into(),
267                ));
268            }
269
270            let channel = unsafe {
271                // SAFETY: We checked for null values above
272                CStr::from_ptr(pgnotify.value.relname)
273            }
274            .to_str()
275            .map_err(|e| Error::DeserializationError(e.into()))?
276            .to_string();
277            let payload = unsafe {
278                // SAFETY: We checked for null values above
279                CStr::from_ptr(pgnotify.value.extra)
280            }
281            .to_str()
282            .map_err(|e| Error::DeserializationError(e.into()))?
283            .to_string();
284            let ret = PgNotification {
285                process_id: pgnotify.value.be_pid,
286                channel,
287                payload,
288            };
289            Ok(Some(ret))
290        }
291    }
292}
293
/// Represents the current in-transaction status of the connection.
///
/// Each variant corresponds to one `PGTransactionStatusType` value; see the
/// `From<PGTransactionStatusType>` implementation in this file for the mapping.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub(super) enum PgTransactionStatus {
    /// Currently idle
    Idle,
    /// A command is in progress (sent to the server but not yet completed)
    Active,
    /// Idle, in a valid transaction block
    InTransaction,
    /// Idle, in a failed transaction block
    InError,
    /// Bad connection
    Unknown,
}
308
309impl From<PGTransactionStatusType> for PgTransactionStatus {
310    fn from(trans_status_type: PGTransactionStatusType) -> Self {
311        match trans_status_type {
312            PGTransactionStatusType::PQTRANS_IDLE => PgTransactionStatus::Idle,
313            PGTransactionStatusType::PQTRANS_ACTIVE => PgTransactionStatus::Active,
314            PGTransactionStatusType::PQTRANS_INTRANS => PgTransactionStatus::InTransaction,
315            PGTransactionStatusType::PQTRANS_INERROR => PgTransactionStatus::InError,
316            PGTransactionStatusType::PQTRANS_UNKNOWN => PgTransactionStatus::Unknown,
317        }
318    }
319}
320
/// Signature of the C callback accepted by `RawConnection::set_notice_processor`.
///
/// `arg` is the user-data pointer registered alongside the callback (this
/// file always registers a null pointer) and `message` is the notice text as
/// a C string.
pub(super) type NoticeProcessor =
    extern "C" fn(arg: *mut libc::c_void, message: *const libc::c_char);
323
324impl Drop for RawConnection {
325    fn drop(&mut self) {
326        unsafe { PQfinish(self.internal_connection.as_ptr()) };
327    }
328}
329
330fn last_error_message(conn: *const PGconn) -> String {
331    unsafe {
332        let error_ptr = PQerrorMessage(conn);
333        let bytes = CStr::from_ptr(error_ptr).to_bytes();
334        String::from_utf8_lossy(bytes).to_string()
335    }
336}
337
/// Internal wrapper around a `*mut PGresult` which is known to be not-null, and
/// to have no aliases. This wrapper exists to ensure that the result is always
/// properly released: the `Drop` implementation in this file calls `PQclear`.
///
/// If `Unique` is ever stabilized, we should use it here.
#[allow(missing_debug_implementations)]
pub(super) struct RawResult(NonNull<PGresult>);

// NOTE(review): these impls assert that a `PGresult` may be moved to, and
// shared between, threads. The justification is not visible in this file —
// confirm against libpq's threading guarantees.
unsafe impl Send for RawResult {}
unsafe impl Sync for RawResult {}
348
349impl RawResult {
350    #[allow(clippy::new_ret_no_self)]
351    fn new(ptr: *mut PGresult, conn: &RawConnection) -> QueryResult<Self> {
352        NonNull::new(ptr).map(RawResult).ok_or_else(|| {
353            Error::DatabaseError(
354                DatabaseErrorKind::UnableToSendCommand,
355                Box::new(conn.last_error_message()),
356            )
357        })
358    }
359
360    pub(super) fn as_ptr(&self) -> *mut PGresult {
361        self.0.as_ptr()
362    }
363
364    pub(super) fn error_message(&self) -> &str {
365        let ptr = unsafe { PQresultErrorMessage(self.0.as_ptr()) };
366        let cstr = unsafe { CStr::from_ptr(ptr) };
367        cstr.to_str().unwrap_or_default()
368    }
369}
370
371impl Drop for RawResult {
372    fn drop(&mut self) {
373        unsafe { PQclear(self.0.as_ptr()) }
374    }
375}