#![allow(unsafe_code)] extern crate pq_sys;
use self::pq_sys::*;
use std::ffi::CStr;
use std::num::NonZeroU32;
use std::os::raw as libc;
use std::rc::Rc;
use std::{slice, str};
use super::raw::{RawConnection, RawResult};
use super::row::PgRow;
use crate::result::{DatabaseErrorInformation, DatabaseErrorKind, Error, QueryResult};
use std::cell::OnceCell;
/// A query result obtained from libpq, together with metadata read once when it was wrapped.
///
/// Instances are created through `PgResult::new`, which only succeeds for
/// non-error result statuses.
// Silences the lint that would otherwise ask for a `Debug` implementation on this public type.
#[allow(missing_debug_implementations)]
pub struct PgResult {
/// The raw result whose pointer is passed to every libpq call made by this type.
internal_result: RawResult,
/// Value reported by `PQnfields` at construction time.
column_count: libc::c_int,
/// Value reported by `PQntuples` at construction time.
row_count: libc::c_int,
/// Lazily built table of column names, indexed by column position.
///
/// Filled on the first `column_name` call from `PQfname`; an entry is `None`
/// when libpq returned a null pointer for that column.
// NOTE(review): the stored raw pointers point into memory owned by `internal_result`;
// presumably they stay valid for as long as that result is alive — confirm against
// the libpq documentation for `PQfname`.
column_name_map: OnceCell<Vec<Option<*const str>>>,
}
impl PgResult {
#[allow(clippy::new_ret_no_self)]
pub(super) fn new(internal_result: RawResult, conn: &RawConnection) -> QueryResult<Self> {
let result_status = unsafe { PQresultStatus(internal_result.as_ptr()) };
match result_status {
ExecStatusType::PGRES_SINGLE_TUPLE
| ExecStatusType::PGRES_COMMAND_OK
| ExecStatusType::PGRES_COPY_IN
| ExecStatusType::PGRES_COPY_OUT
| ExecStatusType::PGRES_TUPLES_OK => {
let column_count = unsafe { PQnfields(internal_result.as_ptr()) };
let row_count = unsafe { PQntuples(internal_result.as_ptr()) };
Ok(PgResult {
internal_result,
column_count,
row_count,
column_name_map: OnceCell::new(),
})
}
ExecStatusType::PGRES_EMPTY_QUERY => {
let error_message = "Received an empty query".to_string();
Err(Error::DatabaseError(
DatabaseErrorKind::Unknown,
Box::new(error_message),
))
}
_ => {
while conn.get_next_result().map_or(true, |r| r.is_some()) {}
let mut error_kind =
match get_result_field(internal_result.as_ptr(), ResultField::SqlState) {
Some(error_codes::UNIQUE_VIOLATION) => DatabaseErrorKind::UniqueViolation,
Some(error_codes::FOREIGN_KEY_VIOLATION) => {
DatabaseErrorKind::ForeignKeyViolation
}
Some(error_codes::SERIALIZATION_FAILURE) => {
DatabaseErrorKind::SerializationFailure
}
Some(error_codes::READ_ONLY_TRANSACTION) => {
DatabaseErrorKind::ReadOnlyTransaction
}
Some(error_codes::NOT_NULL_VIOLATION) => {
DatabaseErrorKind::NotNullViolation
}
Some(error_codes::CHECK_VIOLATION) => DatabaseErrorKind::CheckViolation,
Some(error_codes::CONNECTION_EXCEPTION)
| Some(error_codes::CONNECTION_FAILURE)
| Some(error_codes::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION)
| Some(error_codes::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION) => {
DatabaseErrorKind::ClosedConnection
}
_ => DatabaseErrorKind::Unknown,
};
let error_information = Box::new(PgErrorInformation(internal_result));
let conn_status = conn.get_status();
if conn_status == ConnStatusType::CONNECTION_BAD {
error_kind = DatabaseErrorKind::ClosedConnection;
}
Err(Error::DatabaseError(error_kind, error_information))
}
}
}
pub(super) fn rows_affected(&self) -> usize {
unsafe {
let count_char_ptr = PQcmdTuples(self.internal_result.as_ptr());
let count_bytes = CStr::from_ptr(count_char_ptr).to_bytes();
let count_str = str::from_utf8_unchecked(count_bytes);
match count_str {
"" => 0,
_ => count_str
.parse()
.expect("Error parsing `rows_affected` as integer value"),
}
}
}
pub(super) fn num_rows(&self) -> usize {
self.row_count.try_into().expect(
"Diesel expects to run on a >= 32 bit OS \
(or libpq is giving out negative row count)",
)
}
pub(super) fn get_row(self: Rc<Self>, idx: usize) -> PgRow {
PgRow::new(self, idx)
}
pub(super) fn get(&self, row_idx: usize, col_idx: usize) -> Option<&[u8]> {
if self.is_null(row_idx, col_idx) {
None
} else {
let row_idx = row_idx.try_into().ok()?;
let col_idx = col_idx.try_into().ok()?;
unsafe {
let value_ptr =
PQgetvalue(self.internal_result.as_ptr(), row_idx, col_idx) as *const u8;
let num_bytes = PQgetlength(self.internal_result.as_ptr(), row_idx, col_idx);
Some(slice::from_raw_parts(
value_ptr,
num_bytes
.try_into()
.expect("Diesel expects at least a 32 bit operating system"),
))
}
}
}
pub(super) fn is_null(&self, row_idx: usize, col_idx: usize) -> bool {
let row_idx = row_idx
.try_into()
.expect("Row indices are expected to fit into 32 bit");
let col_idx = col_idx
.try_into()
.expect("Column indices are expected to fit into 32 bit");
unsafe { 0 != PQgetisnull(self.internal_result.as_ptr(), row_idx, col_idx) }
}
pub(in crate::pg) fn column_type(&self, col_idx: usize) -> NonZeroU32 {
let col_idx: i32 = col_idx
.try_into()
.expect("Column indices are expected to fit into 32 bit");
let type_oid = unsafe { PQftype(self.internal_result.as_ptr(), col_idx) };
NonZeroU32::new(type_oid).expect(
"Got a zero oid from postgres. If you see this error message \
please report it as issue on the diesel github bug tracker.",
)
}
#[inline(always)] pub(super) fn column_name(&self, col_idx: usize) -> Option<&str> {
self.column_name_map
.get_or_init(|| {
(0..self.column_count)
.map(|idx| unsafe {
let ptr = PQfname(self.internal_result.as_ptr(), idx as libc::c_int);
if ptr.is_null() {
None
} else {
Some(CStr::from_ptr(ptr).to_str().expect(
"Expect postgres field names to be UTF-8, because we \
requested UTF-8 encoding on connection setup",
) as *const str)
}
})
.collect()
})
.get(col_idx)
.and_then(|n| {
n.map(|n: *const str| unsafe {
&*n
})
})
}
pub(super) fn column_count(&self) -> usize {
self.column_count.try_into().expect(
"Diesel expects to run on a >= 32 bit OS \
(or libpq is giving out negative column count)",
)
}
}
struct PgErrorInformation(RawResult);
impl DatabaseErrorInformation for PgErrorInformation {
fn message(&self) -> &str {
get_result_field(self.0.as_ptr(), ResultField::MessagePrimary)
.unwrap_or_else(|| self.0.error_message())
}
fn details(&self) -> Option<&str> {
get_result_field(self.0.as_ptr(), ResultField::MessageDetail)
}
fn hint(&self) -> Option<&str> {
get_result_field(self.0.as_ptr(), ResultField::MessageHint)
}
fn table_name(&self) -> Option<&str> {
get_result_field(self.0.as_ptr(), ResultField::TableName)
}
fn column_name(&self) -> Option<&str> {
get_result_field(self.0.as_ptr(), ResultField::ColumnName)
}
fn constraint_name(&self) -> Option<&str> {
get_result_field(self.0.as_ptr(), ResultField::ConstraintName)
}
fn statement_position(&self) -> Option<i32> {
let str_pos = get_result_field(self.0.as_ptr(), ResultField::StatementPosition)?;
str_pos.parse::<i32>().ok()
}
}
/// Selects which error field `get_result_field` requests from `PQresultErrorField`.
///
/// Each discriminant is the character code of a single ASCII letter;
/// `get_result_field` casts the variant to `c_int` and passes it to libpq unchanged.
// NOTE(review): the letters presumably mirror libpq's `PG_DIAG_*` field identifiers —
// confirm against the libpq headers before adding or changing a variant.
#[repr(i32)]
enum ResultField {
// Read by `PgResult::new` to choose the `DatabaseErrorKind`.
SqlState = 'C' as i32,
// Backs `DatabaseErrorInformation::message`.
MessagePrimary = 'M' as i32,
// Backs `DatabaseErrorInformation::details`.
MessageDetail = 'D' as i32,
// Backs `DatabaseErrorInformation::hint`.
MessageHint = 'H' as i32,
// Backs `DatabaseErrorInformation::table_name`.
TableName = 't' as i32,
// Backs `DatabaseErrorInformation::column_name`.
ColumnName = 'c' as i32,
// Backs `DatabaseErrorInformation::constraint_name`.
ConstraintName = 'n' as i32,
// Backs `DatabaseErrorInformation::statement_position`, where it is parsed as `i32`.
StatementPosition = 'P' as i32,
}
/// Reads one error field from a libpq result as a string slice.
///
/// Returns `None` when `PQresultErrorField` yields a null pointer or when the
/// returned bytes are not valid UTF-8.
///
/// The lifetime `'a` is chosen by the caller and is not tied to `res` by the
/// signature.
// NOTE(review): callers presumably must keep the result behind `res` alive for as long
// as they use the returned slice — confirm against the `PQresultErrorField` docs.
fn get_result_field<'a>(res: *mut PGresult, field: ResultField) -> Option<&'a str> {
    // SAFETY: assumes `res` is a pointer libpq accepts; every caller in this file
    // obtains it from `RawResult::as_ptr`.
    let raw_value = unsafe { PQresultErrorField(res, field as libc::c_int) };
    if raw_value.is_null() {
        None
    } else {
        // SAFETY: non-null was checked above; assumes libpq returns a NUL-terminated
        // string for error fields.
        let text = unsafe { CStr::from_ptr(raw_value) };
        text.to_str().ok()
    }
}
/// SQLSTATE strings that `PgResult::new` compares against the `SqlState` error field
/// in order to pick a `DatabaseErrorKind`.
// NOTE(review): the values presumably come from the PostgreSQL error code appendix —
// confirm there before adding new entries.
mod error_codes {
// The four `08xxx` codes below are all mapped to `DatabaseErrorKind::ClosedConnection`.
pub(in crate::pg::connection) const CONNECTION_EXCEPTION: &str = "08000";
pub(in crate::pg::connection) const CONNECTION_FAILURE: &str = "08006";
pub(in crate::pg::connection) const SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION: &str = "08001";
pub(in crate::pg::connection) const SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION: &str =
"08004";
// Each remaining code is mapped to the `DatabaseErrorKind` variant of the same name.
pub(in crate::pg::connection) const NOT_NULL_VIOLATION: &str = "23502";
pub(in crate::pg::connection) const FOREIGN_KEY_VIOLATION: &str = "23503";
pub(in crate::pg::connection) const UNIQUE_VIOLATION: &str = "23505";
pub(in crate::pg::connection) const CHECK_VIOLATION: &str = "23514";
pub(in crate::pg::connection) const READ_ONLY_TRANSACTION: &str = "25006";
pub(in crate::pg::connection) const SERIALIZATION_FAILURE: &str = "40001";
}