Skip to content

Commit

Permalink
refactor(query): use bigint to handle the fallback of decimal op over…
Browse files Browse the repository at this point in the history
…flow (#16215)

* refactor(query): use bigint to handle fallback overflow

* refactor(query): use bigint to handle fallback overflow
  • Loading branch information
sundy-li authored Aug 9, 2024
1 parent aed29dc commit 5240970
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 19 deletions.
6 changes: 3 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/query/expression/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ log = { workspace = true }
match-template = { workspace = true }
memchr = { version = "2", default-features = false }
micromarshal = "0.5.0"
num-bigint = "0.4.6"
num-traits = "0.2.15"
ordered-float = { workspace = true, features = ["serde", "rand", "borsh"] }
rand = { workspace = true }
Expand Down
101 changes: 88 additions & 13 deletions src/query/expression/src/types/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::cmp::Ordering;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::ops::Range;
Expand All @@ -25,9 +26,13 @@ use databend_common_io::display_decimal_128;
use databend_common_io::display_decimal_256;
use enum_as_inner::EnumAsInner;
use ethnum::i256;
use ethnum::u256;
use ethnum::AsI256;
use itertools::Itertools;
use num_bigint::BigInt;
use num_traits::FromBytes;
use num_traits::NumCast;
use num_traits::ToPrimitive;
use serde::Deserialize;
use serde::Serialize;

Expand Down Expand Up @@ -356,7 +361,7 @@ pub trait Decimal:
fn checked_mul(self, rhs: Self) -> Option<Self>;
fn checked_rem(self, rhs: Self) -> Option<Self>;

fn do_round_div(self, rhs: Self, mul: Self) -> Option<Self>;
fn do_round_div(self, rhs: Self, mul_scale: u32) -> Option<Self>;

// mul two decimals and return a decimal with rounding option
fn do_round_mul(self, rhs: Self, shift_scale: u32) -> Option<Self>;
Expand All @@ -368,6 +373,7 @@ pub trait Decimal:

fn from_float(value: f64) -> Self;
fn from_i128<U: Into<i128>>(value: U) -> Self;
fn from_bigint(value: BigInt) -> Option<Self>;

fn de_binary(bytes: &mut &[u8]) -> Self;
fn display(self, scale: u8) -> String;
Expand Down Expand Up @@ -471,7 +477,8 @@ impl Decimal for i128 {
}
}

fn do_round_div(self, rhs: Self, mul: Self) -> Option<Self> {
fn do_round_div(self, rhs: Self, mul_scale: u32) -> Option<Self> {
let mul = i256::e(mul_scale);
if self.is_negative() == rhs.is_negative() {
let res = (i256::from(self) * i256::from(mul) + i256::from(rhs) / 2) / i256::from(rhs);
Some(*res.low())
Expand Down Expand Up @@ -535,6 +542,10 @@ impl Decimal for i128 {
value.into()
}

fn from_bigint(value: BigInt) -> Option<Self> {
value.to_i128()
}

fn de_binary(bytes: &mut &[u8]) -> Self {
let bs: [u8; std::mem::size_of::<Self>()] =
bytes[0..std::mem::size_of::<Self>()].try_into().unwrap();
Expand Down Expand Up @@ -687,19 +698,48 @@ impl Decimal for i256 {

fn do_round_mul(self, rhs: Self, shift_scale: u32) -> Option<Self> {
let div = i256::e(shift_scale);
if self.is_negative() == rhs.is_negative() {
let ret: Option<i256> = if self.is_negative() == rhs.is_negative() {
self.checked_mul(rhs).map(|x| (x + div / 2) / div)
} else {
self.checked_mul(rhs).map(|x| (x - div / 2) / div)
}
};

ret.or_else(|| {
let a = BigInt::from_le_bytes(&self.to_le_bytes());
let b = BigInt::from_le_bytes(&rhs.to_le_bytes());
let div = BigInt::from(10).pow(shift_scale);
if self.is_negative() == rhs.is_negative() {
Self::from_bigint((a * b + div.clone() / 2) / div)
} else {
Self::from_bigint((a * b - div.clone() / 2) / div)
}
})
}

fn do_round_div(self, rhs: Self, mul: Self) -> Option<Self> {
if self.is_negative() == rhs.is_negative() {
fn do_round_div(self, rhs: Self, mul_scale: u32) -> Option<Self> {
let fallback = || {
let a = BigInt::from_le_bytes(&self.to_le_bytes());
let b = BigInt::from_le_bytes(&rhs.to_le_bytes());
let mul = BigInt::from(10).pow(mul_scale);
if self.is_negative() == rhs.is_negative() {
Self::from_bigint((a * mul + b.clone() / 2) / b)
} else {
Self::from_bigint((a * mul - b.clone() / 2) / b)
}
};

if mul_scale >= MAX_DECIMAL256_PRECISION as _ {
return fallback();
}

let mul = i256::e(mul_scale);
let ret: Option<i256> = if self.is_negative() == rhs.is_negative() {
self.checked_mul(mul).map(|x| (x + rhs / 2) / rhs)
} else {
self.checked_mul(mul).map(|x| (x - rhs / 2) / rhs)
}
};

ret.or_else(fallback)
}

fn min_for_precision(to_precision: u8) -> Self {
Expand All @@ -725,6 +765,32 @@ impl Decimal for i256 {
i256::from(value.into())
}

fn from_bigint(value: BigInt) -> Option<Self> {
let mut ret: u256 = u256::ZERO;
let mut bits = 0;

for i in value.iter_u64_digits() {
if bits >= 256 {
return None;
}
ret |= u256::from(i) << bits;
bits += 64;
}

match value.sign() {
num_bigint::Sign::Plus => i256::try_from(ret).ok(),
num_bigint::Sign::NoSign => Some(i256::ZERO),
num_bigint::Sign::Minus => {
let m: u256 = u256::ONE << 255;
match ret.cmp(&m) {
Ordering::Less => Some(-i256::try_from(ret).unwrap()),
Ordering::Equal => Some(i256::MIN),
Ordering::Greater => None,
}
}
}
}

fn de_binary(bytes: &mut &[u8]) -> Self {
let bs: [u8; std::mem::size_of::<Self>()] =
bytes[0..std::mem::size_of::<Self>()].try_into().unwrap();
Expand Down Expand Up @@ -947,10 +1013,9 @@ impl DecimalDataType {
let l = a.leading_digits() + b.leading_digits();
precision = l + scale;
} else if is_divide {
let l = a.leading_digits() + b.scale();
scale = a.scale().max((a.scale() + 6).min(12));
// P = L + S
precision = l + scale;
scale = a.scale().max((a.scale() + 6).min(12)); // scale must be >= a.sale()
let l = a.leading_digits() + b.scale(); // l must be >= a.leading_digits()
precision = l + scale; // so precision must be >= a.precision()
} else if is_plus_minus {
scale = std::cmp::max(a.scale(), b.scale());
// for addition/subtraction, we add 1 to the width to ensure we don't overflow
Expand Down Expand Up @@ -984,8 +1049,18 @@ impl DecimalDataType {
result_type,
))
} else if is_divide {
let (a, b) = Self::div_common_type(a, b, result_type.size())?;
Ok((a, b, result_type))
let p = precision.max(a.precision()).max(b.precision());
Ok((
Self::from_size(DecimalSize {
precision: p,
scale: a.scale(),
})?,
Self::from_size(DecimalSize {
precision: p,
scale: b.scale(),
})?,
result_type,
))
} else {
Ok((result_type, result_type, result_type))
}
Expand Down
50 changes: 50 additions & 0 deletions src/query/expression/tests/it/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ use databend_common_expression::types::decimal::DecimalSize;
use databend_common_expression::types::DataType;
use databend_common_expression::types::DecimalDataType;
use databend_common_expression::types::NumberDataType;
use ethnum::i256;
use num_bigint::BigInt;
use pretty_assertions::assert_eq;

#[test]
Expand Down Expand Up @@ -168,3 +170,51 @@ fn test_float_to_128() {
assert_eq!(r, b);
}
}

#[test]
fn test_from_bigint() {
let cases = vec![
("0", 0i128),
("12345", 12345i128),
("-1", -1i128),
("-170141183460469231731687303715884105728", i128::MIN),
("170141183460469231731687303715884105727", i128::MAX),
];

for (a, b) in cases {
let r = BigInt::parse_bytes(a.as_bytes(), 10).unwrap();
assert_eq!(i128::from_bigint(r), Some(b));
}

let cases = vec![
("0".to_string(), i256::ZERO),
("12345".to_string(), i256::from(12345)),
("-1".to_string(), i256::from(-1)),
(
"12".repeat(25),
i256::from_str_radix(&"12".repeat(25), 10).unwrap(),
),
(
"1".repeat(26),
i256::from_str_radix(&"1".repeat(26), 10).unwrap(),
),
(i256::MIN.to_string(), i256::MIN),
(i256::MAX.to_string(), i256::MAX),
];

for (a, b) in cases {
let r = BigInt::parse_bytes(a.as_bytes(), 10).unwrap();
assert_eq!(i256::from_bigint(r), Some(b));
}

let cases = vec![
("1".repeat(78), None),
("12".repeat(78), None),
("234".repeat(78), None),
];

for (a, b) in cases {
let r = BigInt::parse_bytes(a.as_bytes(), 10).unwrap();
assert_eq!(i256::from_bigint(r), b);
}
}
5 changes: 2 additions & 3 deletions src/query/functions/src/scalars/decimal/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,7 @@ macro_rules! binary_decimal {
let scale_b = $right.scale();

// Note: the result scale is always larger than the left scale
let scale_mul = scale_b + $size.scale - scale_a;
let multiplier = T::e(scale_mul as u32);
let scale_mul = (scale_b + $size.scale - scale_a) as u32;
let func = |a: T, b: T, result: &mut Vec<T>, ctx: &mut EvalContext| {
// We are using round div here which follow snowflake's behavior: https://docs.snowflake.com/sql-reference/operators-arithmetic
// For example:
Expand All @@ -102,7 +101,7 @@ macro_rules! binary_decimal {
ctx.set_error(result.len(), "divided by zero");
result.push(one);
} else {
match a.do_round_div(b, multiplier) {
match a.do_round_div(b, scale_mul) {
Some(t) => result.push(t),
None => {
ctx.set_error(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,13 @@ SELECT CAST(987654321.34 AS DECIMAL(20, 2)) / CAST(1.23 AS DECIMAL(6, 2)) AS res
----
802970992.95934959

query IIIIII
select 3.33 a , ('3.' || repeat('3', 72))::Decimal(76, 72) b, a / b, a * b, (-a) /b, (-a) * b
----
3.33 3.333333333333333333333333333333333333333333333333333333333333333333333333 0.99900000 11.099999999999999999999999999999999999999999999999999999999999999999999999 -0.99900000 -11.099999999999999999999999999999999999999999999999999999999999999999999999

statement error
select (repeat('9', 38) || '.3')::Decimal(76, 72) a, a * a

query I
SELECT CAST(987654321.34 AS DECIMAL(76, 2)) / CAST(1.23 AS DECIMAL(76, 2)) AS result;
Expand Down

0 comments on commit 5240970

Please sign in to comment.