diff --git a/mutable_batch/src/writer.rs b/mutable_batch/src/writer.rs index b3340a8d73..9ed436080e 100644 --- a/mutable_batch/src/writer.rs +++ b/mutable_batch/src/writer.rs @@ -42,6 +42,8 @@ pub struct Writer<'a> { statistics: Vec<(usize, Statistics)>, /// The initial number of rows in the MutableBatch initial_rows: usize, + /// The initial number of columns in the MutableBatch + initial_cols: usize, /// The number of rows to insert to_insert: usize, /// If this Writer committed successfully @@ -54,10 +56,12 @@ impl<'a> Writer<'a> { /// If the writer is dropped without calling commit all changes will be rolled back pub fn new(batch: &'a mut MutableBatch, to_insert: usize) -> Self { let initial_rows = batch.rows(); + let initial_cols = batch.columns.len(); Self { batch, statistics: vec![], initial_rows, + initial_cols, to_insert, success: false, } @@ -743,6 +747,13 @@ impl<'a> Drop for Writer<'a> { fn drop(&mut self) { if !self.success { let initial_rows = self.initial_rows; + let initial_cols = self.initial_cols; + + if self.batch.columns.len() != initial_cols { + self.batch.columns.truncate(initial_cols); + self.batch.column_names.retain(|_, v| *v < initial_cols) + } + for col in &mut self.batch.columns { col.valid.truncate(initial_rows); match &mut col.data { diff --git a/mutable_batch/tests/writer_drop.rs b/mutable_batch/tests/writer_drop.rs new file mode 100644 index 0000000000..9460415c95 --- /dev/null +++ b/mutable_batch/tests/writer_drop.rs @@ -0,0 +1,37 @@ +use arrow_util::assert_batches_eq; +use mutable_batch::writer::Writer; +use mutable_batch::MutableBatch; +use schema::selection::Selection; + +#[test] +fn test_new_column() { + let mut batch = MutableBatch::new(); + let mut writer = Writer::new(&mut batch, 2); + + writer + .write_bool("b1", None, vec![true, false].into_iter()) + .unwrap(); + + writer.commit(); + + let expected = &[ + "+-------+", + "| b1 |", + "+-------+", + "| true |", + "| false |", + "+-------+", + ]; + + assert_batches_eq!(expected, &[batch.to_arrow(Selection::All).unwrap()]); + + let mut writer = Writer::new(&mut batch, 1); + writer + .write_string("tag1", None, vec!["v1"].into_iter()) + .unwrap(); + + std::mem::drop(writer); + + // Should not include tag1 column + assert_batches_eq!(expected, &[batch.to_arrow(Selection::All).unwrap()]); +}