Skip to content

Commit

Permalink
fix RecordBatch size in topK (#13906)
Browse files Browse the repository at this point in the history
  • Loading branch information
getChan authored Dec 26, 2024
1 parent 30660e0 commit 2d985b4
Showing 1 changed file with 45 additions and 4 deletions.
49 changes: 45 additions & 4 deletions datafusion/physical-plan/src/topk/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ use arrow::{
use std::mem::size_of;
use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc};

use super::metrics::{BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder};
use crate::spill::get_record_batch_memory_size;
use crate::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream};
use arrow_array::{Array, ArrayRef, RecordBatch};
use arrow_schema::SchemaRef;
Expand All @@ -36,8 +38,6 @@ use datafusion_execution::{
use datafusion_physical_expr::PhysicalSortExpr;
use datafusion_physical_expr_common::sort_expr::LexOrdering;

use super::metrics::{BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder};

/// Global TopK
///
/// # Background
Expand Down Expand Up @@ -575,7 +575,7 @@ impl RecordBatchStore {
pub fn insert(&mut self, entry: RecordBatchEntry) {
// uses of 0 means that none of the rows in the batch were stored in the topk
if entry.uses > 0 {
self.batches_size += entry.batch.get_array_memory_size();
self.batches_size += get_record_batch_memory_size(&entry.batch);
self.batches.insert(entry.id, entry);
}
}
Expand Down Expand Up @@ -630,7 +630,7 @@ impl RecordBatchStore {
let old_entry = self.batches.remove(&id).unwrap();
self.batches_size = self
.batches_size
.checked_sub(old_entry.batch.get_array_memory_size())
.checked_sub(get_record_batch_memory_size(&old_entry.batch))
.unwrap();
}
}
Expand All @@ -643,3 +643,44 @@ impl RecordBatchStore {
+ self.batches_size
}
}

#[cfg(test)]
mod tests {
use super::*;
use arrow::array::Int32Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use arrow_array::Float64Array;

/// This test ensures the size calculation is correct for RecordBatches with multiple columns.
#[test]
fn test_record_batch_store_size() {
// given
let schema = Arc::new(Schema::new(vec![
Field::new("ints", DataType::Int32, true),
Field::new("float64", DataType::Float64, false),
]));
let mut record_batch_store = RecordBatchStore::new(Arc::clone(&schema));
let int_array =
Int32Array::from(vec![Some(1), Some(2), Some(3), Some(4), Some(5)]); // 5 * 4 = 20
let float64_array = Float64Array::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]); // 5 * 8 = 40

let record_batch_entry = RecordBatchEntry {
id: 0,
batch: RecordBatch::try_new(
schema,
vec![Arc::new(int_array), Arc::new(float64_array)],
)
.unwrap(),
uses: 1,
};

// when insert record batch entry
record_batch_store.insert(record_batch_entry);
assert_eq!(record_batch_store.batches_size, 60);

// when unuse record batch entry
record_batch_store.unuse(0);
assert_eq!(record_batch_store.batches_size, 0);
}
}

0 comments on commit 2d985b4

Please sign in to comment.