diff --git a/arrow_util/src/optimize.rs b/arrow_util/src/optimize.rs
index 0806dda6f2..eff425f078 100644
--- a/arrow_util/src/optimize.rs
+++ b/arrow_util/src/optimize.rs
@@ -85,9 +85,13 @@ fn optimize_dict_col(
             None => -1,
         });
 
-    Ok(Arc::new(
-        new_dictionary.to_arrow(new_keys, keys.data().null_buffer().cloned()),
-    ))
+    let offset = keys.data().offset();
+    let nulls = keys
+        .data()
+        .null_buffer()
+        .map(|buffer| buffer.bit_slice(offset, keys.len()));
+
+    Ok(Arc::new(new_dictionary.to_arrow(new_keys, nulls)))
 }
 
 #[cfg(test)]
@@ -118,19 +122,9 @@ mod tests {
             Some(3),
         ]);
 
-        let data = ArrayDataBuilder::new(DataType::Dictionary(
-            Box::new(DataType::Int32),
-            Box::new(DataType::Utf8),
-        ))
-        .len(keys.len())
-        .add_buffer(keys.data().buffers()[0].clone())
-        .null_bit_buffer(keys.data().null_buffer().unwrap().clone())
-        .add_child_data(values.data().clone())
-        .build();
-
         let batch = RecordBatch::try_from_iter(vec![(
             "foo",
-            Arc::new(DictionaryArray::<Int32Type>::from(data)) as ArrayRef,
+            Arc::new(build_dict(keys, values)) as ArrayRef,
         )])
         .unwrap();
 
@@ -244,4 +238,66 @@ mod tests {
             &[optimized]
         );
     }
+
+    #[test]
+    fn test_null() {
+        let values = StringArray::from(vec!["bananas"]);
+        let keys = Int32Array::from(vec![None, None, Some(0)]);
+        let col = Arc::new(build_dict(keys, values)) as ArrayRef;
+
+        let col = optimize_dict_col(&col, &DataType::Int32, &DataType::Utf8).unwrap();
+
+        let batch = RecordBatch::try_from_iter(vec![("t", col)]).unwrap();
+
+        assert_batches_eq!(
+            vec![
+                "+---------+",
+                "| t       |",
+                "+---------+",
+                "|         |",
+                "|         |",
+                "| bananas |",
+                "+---------+",
+            ],
+            &[batch]
+        );
+    }
+
+    #[test]
+    fn test_slice() {
+        let values = StringArray::from(vec!["bananas"]);
+        let keys = Int32Array::from(vec![None, Some(0), None]);
+        let col = Arc::new(build_dict(keys, values)) as ArrayRef;
+        let col = col.slice(1, 2);
+
+        let col = optimize_dict_col(&col, &DataType::Int32, &DataType::Utf8).unwrap();
+
+        let batch = RecordBatch::try_from_iter(vec![("t", col)]).unwrap();
+
+        assert_batches_eq!(
+            vec![
+                "+---------+",
+                "| t       |",
+                "+---------+",
+                "| bananas |",
+                "|         |",
+                "+---------+",
+            ],
+            &[batch]
+        );
+    }
+
+    fn build_dict(keys: Int32Array, values: StringArray) -> DictionaryArray<Int32Type> {
+        let data = ArrayDataBuilder::new(DataType::Dictionary(
+            Box::new(DataType::Int32),
+            Box::new(DataType::Utf8),
+        ))
+        .len(keys.len())
+        .add_buffer(keys.data().buffers()[0].clone())
+        .null_bit_buffer(keys.data().null_buffer().unwrap().clone())
+        .add_child_data(values.data().clone())
+        .build();
+
+        DictionaryArray::from(data)
+    }
 }
diff --git a/query/src/provider/deduplicate.rs b/query/src/provider/deduplicate.rs
index 8800da9e55..9313dfe73d 100644
--- a/query/src/provider/deduplicate.rs
+++ b/query/src/provider/deduplicate.rs
@@ -251,7 +251,7 @@ async fn deduplicate(
 #[cfg(test)]
 mod test {
     use arrow::compute::SortOptions;
-    use arrow::datatypes::SchemaRef;
+    use arrow::datatypes::{Int32Type, SchemaRef};
     use arrow::{
         array::{ArrayRef, Float64Array, StringArray},
         record_batch::RecordBatch,
     };
     use datafusion::physical_plan::{collect, expressions::col, memory::MemoryExec};
 
     use super::*;
+    use arrow::array::DictionaryArray;
+    use std::iter::FromIterator;
 
     #[tokio::test]
     async fn test_single_tag() {
@@ -777,6 +779,90 @@ mod test {
         );
     }
 
+    #[tokio::test]
+    async fn test_dictionary() {
+        let t1 = DictionaryArray::<Int32Type>::from_iter(vec![Some("a"), Some("a"), Some("b")]);
+        let t2 = DictionaryArray::<Int32Type>::from_iter(vec![Some("b"), Some("c"), Some("c")]);
+        let f1 = Float64Array::from(vec![Some(1.0), Some(3.0), Some(4.0)]);
+        let f2 = Float64Array::from(vec![Some(2.0), None, Some(5.0)]);
+
+        let batch1 = RecordBatch::try_from_iter(vec![
+            ("t1", Arc::new(t1) as ArrayRef),
+            ("t2", Arc::new(t2) as ArrayRef),
+            ("f1", Arc::new(f1) as ArrayRef),
+            ("f2", Arc::new(f2) as ArrayRef),
+        ])
+        .unwrap();
+
+        let t1 = DictionaryArray::<Int32Type>::from_iter(vec![Some("b"), Some("c")]);
+        let t2 = DictionaryArray::<Int32Type>::from_iter(vec![Some("c"), Some("d")]);
+        let f1 = Float64Array::from(vec![None, Some(7.0)]);
+        let f2 = Float64Array::from(vec![Some(6.0), Some(8.0)]);
+
+        let batch2 = RecordBatch::try_from_iter(vec![
+            ("t1", Arc::new(t1) as ArrayRef),
+            ("t2", Arc::new(t2) as ArrayRef),
+            ("f1", Arc::new(f1) as ArrayRef),
+            ("f2", Arc::new(f2) as ArrayRef),
+        ])
+        .unwrap();
+
+        let sort_keys = vec![
+            PhysicalSortExpr {
+                expr: col("t1"),
+                options: SortOptions {
+                    descending: false,
+                    nulls_first: false,
+                },
+            },
+            PhysicalSortExpr {
+                expr: col("t2"),
+                options: SortOptions {
+                    descending: false,
+                    nulls_first: false,
+                },
+            },
+        ];
+
+        let results = dedupe(vec![batch1, batch2], sort_keys).await;
+
+        let cols: Vec<_> = results
+            .output
+            .iter()
+            .map(|batch| {
+                batch
+                    .column(batch.schema().column_with_name("t1").unwrap().0)
+                    .as_any()
+                    .downcast_ref::<DictionaryArray<Int32Type>>()
+                    .unwrap()
+            })
+            .collect();
+
+        // Should produce optimised dictionaries
+        // The batching is not important
+        assert_eq!(cols.len(), 3);
+        assert_eq!(cols[0].keys().len(), 2);
+        assert_eq!(cols[0].values().len(), 1); // "a"
+        assert_eq!(cols[1].keys().len(), 1);
+        assert_eq!(cols[1].values().len(), 1); // "b"
+        assert_eq!(cols[2].keys().len(), 1);
+        assert_eq!(cols[2].values().len(), 1); // "c"
+
+        let expected = vec![
+            "+----+----+----+----+",
+            "| t1 | t2 | f1 | f2 |",
+            "+----+----+----+----+",
+            "| a  | b  | 1  | 2  |",
+            "| a  | c  | 3  |    |",
+            "| b  | c  | 4  | 6  |",
+            "| c  | d  | 7  | 8  |",
+            "+----+----+----+----+",
+        ];
+        assert_batches_eq!(&expected, &results.output);
+        // 5 rows in initial input, 4 rows in output ==> 1 dupes
+        assert_eq!(results.num_dupes(), 5 - 4);
+    }
+
     struct TestResults {
         output: Vec<RecordBatch>,
         exec: Arc<DeduplicateExec>,
diff --git a/query/src/provider/deduplicate/algo.rs b/query/src/provider/deduplicate/algo.rs
index 271d276a74..d874d9d9d1 100644
--- a/query/src/provider/deduplicate/algo.rs
+++ b/query/src/provider/deduplicate/algo.rs
@@ -9,6 +9,7 @@ use arrow::{
     record_batch::RecordBatch,
 };
 
+use arrow_util::optimize::optimize_dictionaries;
 use datafusion::physical_plan::{
     coalesce_batches::concat_batches, expressions::PhysicalSortExpr, PhysicalExpr, SQLMetric,
 };
@@ -177,7 +178,12 @@ impl RecordBatchDeduplicator {
                 }
             })
            .collect::<ArrowResult<Vec<ArrayRef>>>()?;
-        RecordBatch::try_new(batch.schema(), new_columns)
+
+        let batch = RecordBatch::try_new(batch.schema(), new_columns)?;
+        // At time of writing, `MutableArrayData` concatenates the
+        // contents of dictionaries as well; Do a post pass to remove the
+        // redundancy if possible
+        optimize_dictionaries(&batch)
     }
 }
 
@@ -233,7 +239,12 @@ impl RecordBatchDeduplicator {
             .map(|old_column| old_column.slice(offset, len))
             .collect();
 
-        RecordBatch::try_new(schema, new_columns)
+        let batch = RecordBatch::try_new(schema, new_columns)?;
+
+        // At time of writing, `concat_batches` concatenates the
+        // contents of dictionaries as well; Do a post pass to remove the
+        // redundancy if possible
+        optimize_dictionaries(&batch)
     }
 }