diesel/pg/connection/raw.rs

1#![allow(clippy::too_many_arguments)]
2#![allow(unsafe_code)] // ffi code
3
4extern crate pq_sys;
5
6use self::pq_sys::*;
7use std::ffi::{CStr, CString};
8use std::os::raw as libc;
9use std::ptr::NonNull;
10use std::{ptr, str};
11
12use crate::result::*;
13
14use super::result::PgResult;
15
16#[allow(missing_debug_implementations, missing_copy_implementations)]
17pub(super) struct RawConnection {
18    pub(super) internal_connection: NonNull<PGconn>,
19}
20
21impl RawConnection {
22    pub(super) fn establish(database_url: &str) -> ConnectionResult<Self> {
23        let connection_string = CString::new(database_url)?;
24        let connection_ptr = unsafe { PQconnectdb(connection_string.as_ptr()) };
25        let connection_status = unsafe { PQstatus(connection_ptr) };
26
27        match connection_status {
28            ConnStatusType::CONNECTION_OK => {
29                let connection_ptr = unsafe { NonNull::new_unchecked(connection_ptr) };
30                Ok(RawConnection {
31                    internal_connection: connection_ptr,
32                })
33            }
34            _ => {
35                let message = last_error_message(connection_ptr);
36
37                if !connection_ptr.is_null() {
38                    // Note that even if the server connection attempt fails (as indicated by PQstatus),
39                    // the application should call PQfinish to free the memory used by the PGconn object.
40                    // https://www.postgresql.org/docs/current/libpq-connect.html
41                    unsafe { PQfinish(connection_ptr) }
42                }
43
44                Err(ConnectionError::BadConnection(message))
45            }
46        }
47    }
48
49    pub(super) fn last_error_message(&self) -> String {
50        last_error_message(self.internal_connection.as_ptr())
51    }
52
53    pub(super) fn set_notice_processor(&self, notice_processor: NoticeProcessor) {
54        unsafe {
55            PQsetNoticeProcessor(
56                self.internal_connection.as_ptr(),
57                Some(notice_processor),
58                ptr::null_mut(),
59            );
60        }
61    }
62
63    pub(super) unsafe fn exec(&self, query: *const libc::c_char) -> QueryResult<RawResult> {
64        RawResult::new(PQexec(self.internal_connection.as_ptr(), query), self)
65    }
66
67    pub(super) unsafe fn send_query_prepared(
68        &self,
69        stmt_name: *const libc::c_char,
70        param_count: libc::c_int,
71        param_values: *const *const libc::c_char,
72        param_lengths: *const libc::c_int,
73        param_formats: *const libc::c_int,
74        result_format: libc::c_int,
75    ) -> QueryResult<()> {
76        let res = PQsendQueryPrepared(
77            self.internal_connection.as_ptr(),
78            stmt_name,
79            param_count,
80            param_values,
81            param_lengths,
82            param_formats,
83            result_format,
84        );
85        if res == 1 {
86            Ok(())
87        } else {
88            Err(Error::DatabaseError(
89                DatabaseErrorKind::UnableToSendCommand,
90                Box::new(self.last_error_message()),
91            ))
92        }
93    }
94
95    pub(super) unsafe fn prepare(
96        &self,
97        stmt_name: *const libc::c_char,
98        query: *const libc::c_char,
99        param_count: libc::c_int,
100        param_types: *const Oid,
101    ) -> QueryResult<RawResult> {
102        let ptr = PQprepare(
103            self.internal_connection.as_ptr(),
104            stmt_name,
105            query,
106            param_count,
107            param_types,
108        );
109        RawResult::new(ptr, self)
110    }
111
112    /// This is reasonably inexpensive as it just accesses variables internal to the connection
113    /// that are kept up to date by the `ReadyForQuery` messages from the PG server
114    pub(super) fn transaction_status(&self) -> PgTransactionStatus {
115        unsafe { PQtransactionStatus(self.internal_connection.as_ptr()) }.into()
116    }
117
118    pub(super) fn get_status(&self) -> ConnStatusType {
119        unsafe { PQstatus(self.internal_connection.as_ptr()) }
120    }
121
122    pub(crate) fn get_next_result(&self) -> Result<Option<PgResult>, Error> {
123        let res = unsafe { PQgetResult(self.internal_connection.as_ptr()) };
124        if res.is_null() {
125            Ok(None)
126        } else {
127            let raw = RawResult::new(res, self)?;
128            Ok(Some(PgResult::new(raw, self)?))
129        }
130    }
131
132    pub(crate) fn enable_row_by_row_mode(&self) -> QueryResult<()> {
133        let res = unsafe { PQsetSingleRowMode(self.internal_connection.as_ptr()) };
134        if res == 1 {
135            Ok(())
136        } else {
137            Err(Error::DatabaseError(
138                DatabaseErrorKind::Unknown,
139                Box::new(self.last_error_message()),
140            ))
141        }
142    }
143
144    pub(super) fn put_copy_data(&mut self, buf: &[u8]) -> QueryResult<()> {
145        for c in buf.chunks(i32::MAX as usize) {
146            let res = unsafe {
147                pq_sys::PQputCopyData(
148                    self.internal_connection.as_ptr(),
149                    c.as_ptr() as *const libc::c_char,
150                    c.len()
151                        .try_into()
152                        .map_err(|e| Error::SerializationError(Box::new(e)))?,
153                )
154            };
155            if res != 1 {
156                return Err(Error::DatabaseError(
157                    DatabaseErrorKind::Unknown,
158                    Box::new(self.last_error_message()),
159                ));
160            }
161        }
162        Ok(())
163    }
164
165    pub(crate) fn finish_copy_from(&self, err: Option<String>) -> QueryResult<()> {
166        let error = err.map(CString::new).map(|r| {
167            r.unwrap_or_else(|_| {
168                CString::new("Error message contains a \\0 byte")
169                    .expect("Does not contain a null byte")
170            })
171        });
172        let error = error
173            .as_ref()
174            .map(|l| l.as_ptr())
175            .unwrap_or(std::ptr::null());
176        let ret = unsafe { pq_sys::PQputCopyEnd(self.internal_connection.as_ptr(), error) };
177        if ret == 1 {
178            Ok(())
179        } else {
180            Err(Error::DatabaseError(
181                DatabaseErrorKind::Unknown,
182                Box::new(self.last_error_message()),
183            ))
184        }
185    }
186}
187
188/// Represents the current in-transaction status of the connection
189#[derive(Debug, PartialEq, Eq, Clone, Copy)]
190pub(super) enum PgTransactionStatus {
191    /// Currently idle
192    Idle,
193    /// A command is in progress (sent to the server but not yet completed)
194    Active,
195    /// Idle, in a valid transaction block
196    InTransaction,
197    /// Idle, in a failed transaction block
198    InError,
199    /// Bad connection
200    Unknown,
201}
202
203impl From<PGTransactionStatusType> for PgTransactionStatus {
204    fn from(trans_status_type: PGTransactionStatusType) -> Self {
205        match trans_status_type {
206            PGTransactionStatusType::PQTRANS_IDLE => PgTransactionStatus::Idle,
207            PGTransactionStatusType::PQTRANS_ACTIVE => PgTransactionStatus::Active,
208            PGTransactionStatusType::PQTRANS_INTRANS => PgTransactionStatus::InTransaction,
209            PGTransactionStatusType::PQTRANS_INERROR => PgTransactionStatus::InError,
210            PGTransactionStatusType::PQTRANS_UNKNOWN => PgTransactionStatus::Unknown,
211        }
212    }
213}
214
215pub(super) type NoticeProcessor =
216    extern "C" fn(arg: *mut libc::c_void, message: *const libc::c_char);
217
218impl Drop for RawConnection {
219    fn drop(&mut self) {
220        unsafe { PQfinish(self.internal_connection.as_ptr()) };
221    }
222}
223
224fn last_error_message(conn: *const PGconn) -> String {
225    unsafe {
226        let error_ptr = PQerrorMessage(conn);
227        let bytes = CStr::from_ptr(error_ptr).to_bytes();
228        String::from_utf8_lossy(bytes).to_string()
229    }
230}
231
232/// Internal wrapper around a `*mut PGresult` which is known to be not-null, and
233/// have no aliases.  This wrapper is to ensure that it's always properly
234/// dropped.
235///
236/// If `Unique` is ever stabilized, we should use it here.
237#[allow(missing_debug_implementations)]
238pub(super) struct RawResult(NonNull<PGresult>);
239
240unsafe impl Send for RawResult {}
241unsafe impl Sync for RawResult {}
242
243impl RawResult {
244    #[allow(clippy::new_ret_no_self)]
245    fn new(ptr: *mut PGresult, conn: &RawConnection) -> QueryResult<Self> {
246        NonNull::new(ptr).map(RawResult).ok_or_else(|| {
247            Error::DatabaseError(
248                DatabaseErrorKind::UnableToSendCommand,
249                Box::new(conn.last_error_message()),
250            )
251        })
252    }
253
254    pub(super) fn as_ptr(&self) -> *mut PGresult {
255        self.0.as_ptr()
256    }
257
258    pub(super) fn error_message(&self) -> &str {
259        let ptr = unsafe { PQresultErrorMessage(self.0.as_ptr()) };
260        let cstr = unsafe { CStr::from_ptr(ptr) };
261        cstr.to_str().unwrap_or_default()
262    }
263}
264
265impl Drop for RawResult {
266    fn drop(&mut self) {
267        unsafe { PQclear(self.0.as_ptr()) }
268    }
269}