1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
use crate::deserialize::{self, FromSql, FromSqlRow};
use crate::expression::AsExpression;
use crate::pg::{Pg, PgValue};
use crate::serialize::{self, IsNull, Output, ToSql};
use crate::sql_types;
use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt};
use std::error::Error;

#[cfg(feature = "quickcheck")]
mod quickcheck_impls;

#[derive(Debug, Clone, PartialEq, Eq, AsExpression, FromSqlRow)]
#[diesel(sql_type = sql_types::Numeric)]
/// Represents a NUMERIC value, closely mirroring the PG wire protocol
/// representation.
///
/// The variant carries the sign. In the binary encoding used by this
/// module's `FromSql`/`ToSql` impls, `Positive` corresponds to the sign word
/// `0x0000`, `Negative` to `0x4000` and `NaN` to `0xC000`; `NaN` is written
/// with no digits and with a weight and scale of 0.
///
/// No invariants are checked, either when a value is built (all fields are
/// public) or when one is decoded: any `weight`, `scale` and digit values
/// are accepted as-is.
pub enum PgNumeric {
    /// A positive number (sign word `0x0000` on the wire).
    Positive {
        /// Position of the first base-10000 digit relative to the decimal
        /// point. NOTE(review): in PostgreSQL's `numeric.c` this is the power
        /// of 10000 that the first digit is multiplied by (e.g. `0` for
        /// `1234.5`, `-1` for `0.5`) — confirm before relying on it.
        weight: i16,
        /// Scale of the number. NOTE(review): corresponds to PostgreSQL's
        /// `dscale` (display scale), i.e. the count of decimal digits after
        /// the decimal point rather than a count of significant digits —
        /// confirm.
        scale: u16,
        /// The digits in this number, stored in base 10000. NOTE(review):
        /// presumably most significant first with each value in `0..=9999`
        /// as PostgreSQL expects; nothing here enforces that range.
        digits: Vec<i16>,
    },
    /// A negative number (sign word `0x4000` on the wire).
    Negative {
        /// Position of the first base-10000 digit relative to the decimal
        /// point. NOTE(review): in PostgreSQL's `numeric.c` this is the power
        /// of 10000 that the first digit is multiplied by — confirm.
        weight: i16,
        /// Scale of the number. NOTE(review): corresponds to PostgreSQL's
        /// `dscale` (digits after the decimal point) — confirm.
        scale: u16,
        /// The digits of the magnitude, stored in base 10000. NOTE(review):
        /// presumably most significant first, each in `0..=9999`; not
        /// enforced here.
        digits: Vec<i16>,
    },
    /// Not a number (sign word `0xC000` on the wire).
    NaN,
}

/// Error produced when the sign word of a binary `NUMERIC` value is not one
/// that `PgNumeric` can represent.
///
/// The wrapped value is the raw sign word read from the wire. Only `0x0000`
/// (positive), `0x4000` (negative) and `0xC000` (NaN) are accepted by the
/// decoder. NOTE(review): PostgreSQL 14+ can also send `0xD000` / `0xF000`
/// for `+Infinity` / `-Infinity`; `PgNumeric` has no variant for those, so
/// they are reported through this error.
#[derive(Debug, Clone, Copy)]
struct InvalidNumericSign(u16);

impl ::std::fmt::Display for InvalidNumericSign {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> ::std::fmt::Result {
        // Report the offending sign word so decode failures (e.g. the
        // infinity encodings) can be diagnosed from the error text alone.
        write!(
            f,
            "sign for numeric field was not one of 0, 0x4000, 0xC000 (got {:#06X})",
            self.0
        )
    }
}

impl Error for InvalidNumericSign {}

#[cfg(feature = "postgres_backend")]
impl FromSql<sql_types::Numeric, Pg> for PgNumeric {
    /// Decodes the binary `NUMERIC` layout produced by the matching `ToSql`
    /// impl.
    ///
    /// Four big-endian 16-bit header fields come first: digit count (`u16`),
    /// weight (`i16`), sign word (`u16`) and scale (`u16`). They are followed
    /// by that many big-endian `i16` base-10000 digits. Any bytes after the
    /// last digit are ignored. For `NaN` the weight, scale and digits are
    /// discarded.
    ///
    /// # Errors
    ///
    /// Fails if the input ends before the header and all digits have been
    /// read. Also fails with an `InvalidNumericSign` error if the sign word
    /// is not `0x0000`, `0x4000` or `0xC000`. The sign is only inspected
    /// after every digit has been consumed.
    fn from_sql(bytes: PgValue<'_>) -> deserialize::Result<Self> {
        let mut buf = bytes.as_bytes();

        // Header: the fields must be read in exactly this wire order.
        let digit_count = usize::from(buf.read_u16::<NetworkEndian>()?);
        let weight = buf.read_i16::<NetworkEndian>()?;
        let sign = buf.read_u16::<NetworkEndian>()?;
        let scale = buf.read_u16::<NetworkEndian>()?;

        // Body: `digit_count` base-10000 digits.
        let mut digits = Vec::with_capacity(digit_count);
        while digits.len() < digit_count {
            digits.push(buf.read_i16::<NetworkEndian>()?);
        }

        match sign {
            0x0000 => Ok(PgNumeric::Positive {
                weight,
                scale,
                digits,
            }),
            0x4000 => Ok(PgNumeric::Negative {
                weight,
                scale,
                digits,
            }),
            0xC000 => Ok(PgNumeric::NaN),
            other => Err(Box::new(InvalidNumericSign(other))),
        }
    }
}

#[cfg(feature = "postgres_backend")]
impl ToSql<sql_types::Numeric, Pg> for PgNumeric {
    /// Encodes the value in the binary `NUMERIC` layout read back by the
    /// matching `FromSql` impl.
    ///
    /// The digit count, weight, sign word and scale are written as big-endian
    /// 16-bit integers. Each base-10000 digit follows as a big-endian `i16`.
    /// `NaN` is written with sign word `0xC000`, weight 0, scale 0 and no
    /// digits.
    ///
    /// # Errors
    ///
    /// Fails if there are more digits than fit in a `u16` count. In that case
    /// nothing is written. Also fails if writing to `out` fails.
    fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result {
        // Flatten the variant into the header fields plus the digit slice.
        let (sign, weight, scale, digits): (u16, i16, u16, &[i16]) = match self {
            PgNumeric::Positive {
                weight,
                scale,
                digits,
            } => (0x0000, *weight, *scale, digits),
            PgNumeric::Negative {
                weight,
                scale,
                digits,
            } => (0x4000, *weight, *scale, digits),
            PgNumeric::NaN => (0xC000, 0, 0, &[]),
        };
        // Convert the count before touching `out`, so an oversized value
        // leaves the buffer untouched.
        let digit_count: u16 = digits.len().try_into()?;

        out.write_u16::<NetworkEndian>(digit_count)?;
        out.write_i16::<NetworkEndian>(weight)?;
        out.write_u16::<NetworkEndian>(sign)?;
        out.write_u16::<NetworkEndian>(scale)?;
        for &digit in digits {
            out.write_i16::<NetworkEndian>(digit)?;
        }

        Ok(IsNull::No)
    }
}

#[cfg(feature = "postgres_backend")]
impl FromSql<sql_types::Float, Pg> for f32 {
    /// Decodes a `sql_types::Float` value from exactly four big-endian
    /// (network-order) bytes.
    ///
    /// # Errors
    ///
    /// Fails when the payload is not exactly four bytes long. The message
    /// suggests which SQL type was most likely mis-declared.
    fn from_sql(value: PgValue<'_>) -> deserialize::Result<Self> {
        let raw = value.as_bytes();
        // A slice converts into a fixed-size array only when the lengths match exactly.
        let word = <[u8; 4]>::try_from(raw).map_err(|_| {
            if raw.len() < 4 {
                "Received less than 4 bytes while decoding an f32. \
                 Was a numeric accidentally marked as float?"
            } else {
                "Received more than 4 bytes while decoding an f32. \
                 Was a double accidentally marked as float?"
            }
        })?;
        // Bit-for-bit the same result as `read_f32::<NetworkEndian>`.
        Ok(f32::from_be_bytes(word))
    }
}

#[cfg(feature = "postgres_backend")]
impl FromSql<sql_types::Double, Pg> for f64 {
    /// Decodes a `sql_types::Double` value from exactly eight big-endian
    /// (network-order) bytes.
    ///
    /// # Errors
    ///
    /// Fails when the payload is not exactly eight bytes long. The message
    /// suggests which SQL type was most likely mis-declared.
    fn from_sql(value: PgValue<'_>) -> deserialize::Result<Self> {
        let raw = value.as_bytes();
        // A slice converts into a fixed-size array only when the lengths match exactly.
        let word = <[u8; 8]>::try_from(raw).map_err(|_| {
            if raw.len() < 8 {
                "Received less than 8 bytes while decoding an f64. \
                 Was a float accidentally marked as double?"
            } else {
                "Received more than 8 bytes while decoding an f64. \
                 Was a numeric accidentally marked as double?"
            }
        })?;
        // Bit-for-bit the same result as `read_f64::<NetworkEndian>`.
        Ok(f64::from_be_bytes(word))
    }
}

#[cfg(feature = "postgres_backend")]
impl ToSql<sql_types::Float, Pg> for f32 {
    /// Writes the value as four big-endian (network-order) bytes.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error raised while writing to `out`.
    fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result {
        out.write_f32::<NetworkEndian>(*self)?;
        Ok(IsNull::No)
    }
}

#[cfg(feature = "postgres_backend")]
impl ToSql<sql_types::Double, Pg> for f64 {
    /// Writes the value as eight big-endian (network-order) bytes.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error raised while writing to `out`.
    fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Pg>) -> serialize::Result {
        out.write_f64::<NetworkEndian>(*self)?;
        Ok(IsNull::No)
    }
}