diff --git a/Cargo.lock b/Cargo.lock index f434dc2bfa3d..a87f5ffc3bb9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3415,6 +3415,7 @@ dependencies = [ "match-template", "memchr", "micromarshal 0.5.0", + "num-bigint", "num-traits", "ordered-float 4.2.0", "pretty_assertions", @@ -10841,12 +10842,11 @@ dependencies = [ [[package]] name = "num-bigint" -version = "0.4.4" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" dependencies = [ "arbitrary", - "autocfg", "num-integer", "num-traits", ] diff --git a/src/query/expression/Cargo.toml b/src/query/expression/Cargo.toml index b15d88948b61..a8d9a55c35ff 100644 --- a/src/query/expression/Cargo.toml +++ b/src/query/expression/Cargo.toml @@ -44,6 +44,7 @@ log = { workspace = true } match-template = { workspace = true } memchr = { version = "2", default-features = false } micromarshal = "0.5.0" +num-bigint = "0.4.6" num-traits = "0.2.15" ordered-float = { workspace = true, features = ["serde", "rand", "borsh"] } rand = { workspace = true } diff --git a/src/query/expression/src/types/decimal.rs b/src/query/expression/src/types/decimal.rs index 4530ebdb56a7..99ada93822e5 100644 --- a/src/query/expression/src/types/decimal.rs +++ b/src/query/expression/src/types/decimal.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use std::cmp::Ordering; use std::fmt::Debug; use std::marker::PhantomData; use std::ops::Range; @@ -25,9 +26,13 @@ use databend_common_io::display_decimal_128; use databend_common_io::display_decimal_256; use enum_as_inner::EnumAsInner; use ethnum::i256; +use ethnum::u256; use ethnum::AsI256; use itertools::Itertools; +use num_bigint::BigInt; +use num_traits::FromBytes; use num_traits::NumCast; +use num_traits::ToPrimitive; use serde::Deserialize; use serde::Serialize; @@ -356,7 +361,7 @@ pub trait Decimal: fn checked_mul(self, rhs: Self) -> Option; fn checked_rem(self, rhs: Self) -> Option; - fn do_round_div(self, rhs: Self, mul: Self) -> Option; + fn do_round_div(self, rhs: Self, mul_scale: u32) -> Option; // mul two decimals and return a decimal with rounding option fn do_round_mul(self, rhs: Self, shift_scale: u32) -> Option; @@ -368,6 +373,7 @@ pub trait Decimal: fn from_float(value: f64) -> Self; fn from_i128>(value: U) -> Self; + fn from_bigint(value: BigInt) -> Option; fn de_binary(bytes: &mut &[u8]) -> Self; fn display(self, scale: u8) -> String; @@ -471,7 +477,8 @@ impl Decimal for i128 { } } - fn do_round_div(self, rhs: Self, mul: Self) -> Option { + fn do_round_div(self, rhs: Self, mul_scale: u32) -> Option { + let mul = i256::e(mul_scale); if self.is_negative() == rhs.is_negative() { let res = (i256::from(self) * i256::from(mul) + i256::from(rhs) / 2) / i256::from(rhs); Some(*res.low()) @@ -535,6 +542,10 @@ impl Decimal for i128 { value.into() } + fn from_bigint(value: BigInt) -> Option { + value.to_i128() + } + fn de_binary(bytes: &mut &[u8]) -> Self { let bs: [u8; std::mem::size_of::()] = bytes[0..std::mem::size_of::()].try_into().unwrap(); @@ -687,19 +698,48 @@ impl Decimal for i256 { fn do_round_mul(self, rhs: Self, shift_scale: u32) -> Option { let div = i256::e(shift_scale); - if self.is_negative() == rhs.is_negative() { + let ret: Option = if self.is_negative() == rhs.is_negative() { self.checked_mul(rhs).map(|x| (x + div / 2) / div) } else { 
self.checked_mul(rhs).map(|x| (x - div / 2) / div) - } + }; + + ret.or_else(|| { + let a = BigInt::from_le_bytes(&self.to_le_bytes()); + let b = BigInt::from_le_bytes(&rhs.to_le_bytes()); + let div = BigInt::from(10).pow(shift_scale); + if self.is_negative() == rhs.is_negative() { + Self::from_bigint((a * b + div.clone() / 2) / div) + } else { + Self::from_bigint((a * b - div.clone() / 2) / div) + } + }) } - fn do_round_div(self, rhs: Self, mul: Self) -> Option { - if self.is_negative() == rhs.is_negative() { + fn do_round_div(self, rhs: Self, mul_scale: u32) -> Option { + let fallback = || { + let a = BigInt::from_le_bytes(&self.to_le_bytes()); + let b = BigInt::from_le_bytes(&rhs.to_le_bytes()); + let mul = BigInt::from(10).pow(mul_scale); + if self.is_negative() == rhs.is_negative() { + Self::from_bigint((a * mul + b.clone() / 2) / b) + } else { + Self::from_bigint((a * mul - b.clone() / 2) / b) + } + }; + + if mul_scale >= MAX_DECIMAL256_PRECISION as _ { + return fallback(); + } + + let mul = i256::e(mul_scale); + let ret: Option = if self.is_negative() == rhs.is_negative() { self.checked_mul(mul).map(|x| (x + rhs / 2) / rhs) } else { self.checked_mul(mul).map(|x| (x - rhs / 2) / rhs) - } + }; + + ret.or_else(fallback) } fn min_for_precision(to_precision: u8) -> Self { @@ -725,6 +765,32 @@ impl Decimal for i256 { i256::from(value.into()) } + fn from_bigint(value: BigInt) -> Option { + let mut ret: u256 = u256::ZERO; + let mut bits = 0; + + for i in value.iter_u64_digits() { + if bits >= 256 { + return None; + } + ret |= u256::from(i) << bits; + bits += 64; + } + + match value.sign() { + num_bigint::Sign::Plus => i256::try_from(ret).ok(), + num_bigint::Sign::NoSign => Some(i256::ZERO), + num_bigint::Sign::Minus => { + let m: u256 = u256::ONE << 255; + match ret.cmp(&m) { + Ordering::Less => Some(-i256::try_from(ret).unwrap()), + Ordering::Equal => Some(i256::MIN), + Ordering::Greater => None, + } + } + } + } + fn de_binary(bytes: &mut &[u8]) -> Self { let bs: 
[u8; std::mem::size_of::()] = bytes[0..std::mem::size_of::()].try_into().unwrap(); @@ -947,10 +1013,9 @@ impl DecimalDataType { let l = a.leading_digits() + b.leading_digits(); precision = l + scale; } else if is_divide { - let l = a.leading_digits() + b.scale(); - scale = a.scale().max((a.scale() + 6).min(12)); - // P = L + S - precision = l + scale; + scale = a.scale().max((a.scale() + 6).min(12)); // scale must be >= a.scale() + let l = a.leading_digits() + b.scale(); // l must be >= a.leading_digits() + precision = l + scale; // so precision must be >= a.precision() } else if is_plus_minus { scale = std::cmp::max(a.scale(), b.scale()); // for addition/subtraction, we add 1 to the width to ensure we don't overflow @@ -984,8 +1049,18 @@ impl DecimalDataType { result_type, )) } else if is_divide { - let (a, b) = Self::div_common_type(a, b, result_type.size())?; - Ok((a, b, result_type)) + let p = precision.max(a.precision()).max(b.precision()); + Ok(( + Self::from_size(DecimalSize { + precision: p, + scale: a.scale(), + })?, + Self::from_size(DecimalSize { + precision: p, + scale: b.scale(), + })?, + result_type, + )) } else { Ok((result_type, result_type, result_type)) } diff --git a/src/query/expression/tests/it/decimal.rs b/src/query/expression/tests/it/decimal.rs index 14a3beb5fc91..bd8a87701a0e 100644 --- a/src/query/expression/tests/it/decimal.rs +++ b/src/query/expression/tests/it/decimal.rs @@ -21,6 +21,8 @@ use databend_common_expression::types::decimal::DecimalSize; use databend_common_expression::types::DataType; use databend_common_expression::types::DecimalDataType; use databend_common_expression::types::NumberDataType; +use ethnum::i256; +use num_bigint::BigInt; use pretty_assertions::assert_eq; #[test] @@ -168,3 +170,51 @@ fn test_float_to_128() { assert_eq!(r, b); } } + +#[test] +fn test_from_bigint() { + let cases = vec![ + ("0", 0i128), + ("12345", 12345i128), + ("-1", -1i128), + ("-170141183460469231731687303715884105728", i128::MIN), + 
("170141183460469231731687303715884105727", i128::MAX), + ]; + + for (a, b) in cases { + let r = BigInt::parse_bytes(a.as_bytes(), 10).unwrap(); + assert_eq!(i128::from_bigint(r), Some(b)); + } + + let cases = vec![ + ("0".to_string(), i256::ZERO), + ("12345".to_string(), i256::from(12345)), + ("-1".to_string(), i256::from(-1)), + ( + "12".repeat(25), + i256::from_str_radix(&"12".repeat(25), 10).unwrap(), + ), + ( + "1".repeat(26), + i256::from_str_radix(&"1".repeat(26), 10).unwrap(), + ), + (i256::MIN.to_string(), i256::MIN), + (i256::MAX.to_string(), i256::MAX), + ]; + + for (a, b) in cases { + let r = BigInt::parse_bytes(a.as_bytes(), 10).unwrap(); + assert_eq!(i256::from_bigint(r), Some(b)); + } + + let cases = vec![ + ("1".repeat(78), None), + ("12".repeat(78), None), + ("234".repeat(78), None), + ]; + + for (a, b) in cases { + let r = BigInt::parse_bytes(a.as_bytes(), 10).unwrap(); + assert_eq!(i256::from_bigint(r), b); + } +} diff --git a/src/query/functions/src/scalars/decimal/arithmetic.rs b/src/query/functions/src/scalars/decimal/arithmetic.rs index 625888a5c868..3b72256c9f8d 100644 --- a/src/query/functions/src/scalars/decimal/arithmetic.rs +++ b/src/query/functions/src/scalars/decimal/arithmetic.rs @@ -89,8 +89,7 @@ macro_rules! binary_decimal { let scale_b = $right.scale(); // Note: the result scale is always larger than the left scale - let scale_mul = scale_b + $size.scale - scale_a; - let multiplier = T::e(scale_mul as u32); + let scale_mul = (scale_b + $size.scale - scale_a) as u32; let func = |a: T, b: T, result: &mut Vec, ctx: &mut EvalContext| { // We are using round div here which follow snowflake's behavior: https://docs.snowflake.com/sql-reference/operators-arithmetic // For example: @@ -102,7 +101,7 @@ macro_rules! 
binary_decimal { ctx.set_error(result.len(), "divided by zero"); result.push(one); } else { - match a.do_round_div(b, multiplier) { + match a.do_round_div(b, scale_mul) { Some(t) => result.push(t), None => { ctx.set_error( diff --git a/tests/sqllogictests/suites/base/11_data_type/11_0006_data_type_decimal.test b/tests/sqllogictests/suites/base/11_data_type/11_0006_data_type_decimal.test index 7a42560c7aa6..51643cefdab6 100644 --- a/tests/sqllogictests/suites/base/11_data_type/11_0006_data_type_decimal.test +++ b/tests/sqllogictests/suites/base/11_data_type/11_0006_data_type_decimal.test @@ -295,6 +295,13 @@ SELECT CAST(987654321.34 AS DECIMAL(20, 2)) / CAST(1.23 AS DECIMAL(6, 2)) AS res ---- 802970992.95934959 +query IIIIII +select 3.33 a , ('3.' || repeat('3', 72))::Decimal(76, 72) b, a / b, a * b, (-a) /b, (-a) * b +---- +3.33 3.333333333333333333333333333333333333333333333333333333333333333333333333 0.99900000 11.099999999999999999999999999999999999999999999999999999999999999999999999 -0.99900000 -11.099999999999999999999999999999999999999999999999999999999999999999999999 + +statement error +select (repeat('9', 38) || '.3')::Decimal(76, 72) a, a * a query I SELECT CAST(987654321.34 AS DECIMAL(76, 2)) / CAST(1.23 AS DECIMAL(76, 2)) AS result;