13 changes: 10 additions & 3 deletions cpp/src/arrow/csv/writer.cc
@@ -541,7 +541,7 @@ class CSVWriterImpl : public ipc::RecordBatchWriter {
for (auto maybe_slice : iterator) {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<RecordBatch> slice, maybe_slice);
RETURN_NOT_OK(TranslateMinimalBatch(*slice));
- RETURN_NOT_OK(sink_->Write(data_buffer_));
+ RETURN_NOT_OK(FlushToSink());
stats_.num_record_batches++;
}
return Status::OK();
@@ -554,7 +554,7 @@ class CSVWriterImpl : public ipc::RecordBatchWriter {
RETURN_NOT_OK(reader.ReadNext(&batch));
while (batch != nullptr) {
RETURN_NOT_OK(TranslateMinimalBatch(*batch));
- RETURN_NOT_OK(sink_->Write(data_buffer_));
+ RETURN_NOT_OK(FlushToSink());
RETURN_NOT_OK(reader.ReadNext(&batch));
stats_.num_record_batches++;
}
@@ -590,6 +590,13 @@ class CSVWriterImpl : public ipc::RecordBatchWriter {
return Status::OK();
}

// GH-36889: Flush buffer to sink and clear it to avoid stale content
// being written again if the next batch is empty.
Status FlushToSink() {
RETURN_NOT_OK(sink_->Write(data_buffer_));
return data_buffer_->Resize(0, /*shrink_to_fit=*/false);
}

int64_t CalculateHeaderSize(QuotingStyle quoting_style) const {
int64_t header_length = 0;
for (int col = 0; col < schema_->num_fields(); col++) {
@@ -654,7 +661,7 @@ class CSVWriterImpl : public ipc::RecordBatchWriter {
next += options_.eol.size();
DCHECK_EQ(reinterpret_cast<uint8_t*>(next),
data_buffer_->data() + data_buffer_->size());
- return sink_->Write(data_buffer_);
+ return FlushToSink();
}

Status TranslateMinimalBatch(const RecordBatch& batch) {
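The writer.cc change routes every flush through the new FlushToSink() helper, which writes data_buffer_ and then truncates it with Resize(0, /*shrink_to_fit=*/false), keeping the allocation for reuse while ensuring stale bytes are not written again when a later batch translates to nothing. As a rough illustration of the user-visible behavior this fixes, here is a minimal pyarrow sketch mirroring the tests added below (the expected output is taken directly from those tests):

import io

import pyarrow as pa
from pyarrow import csv

# A table whose chunks include an empty record batch (GH-36889).
schema = pa.schema([("col1", pa.string())])
empty = pa.table({"col1": []}).cast(schema)
combined = pa.concat_tables(
    [pa.table({"col1": ["a"]}), empty, pa.table({"col1": ["b"]})])

buf = io.BytesIO()
csv.write_csv(combined, buf)
# The empty chunk must contribute nothing; previously the writer could
# flush its internal buffer again with leftover contents for such a chunk.
assert buf.getvalue() == b'"col1"\n"a"\n"b"\n'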
32 changes: 32 additions & 0 deletions cpp/src/arrow/csv/writer_test.cc
@@ -28,6 +28,7 @@
#include "arrow/ipc/writer.h"
#include "arrow/record_batch.h"
#include "arrow/result.h"
#include "arrow/table.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/matchers.h"
#include "arrow/type.h"
@@ -405,5 +406,36 @@ INSTANTIATE_TEST_SUITE_P(
"\n2016-02-29 10:42:23-0700,2016-02-29 17:42:23Z\n")));
#endif

TEST(TestWriteCSV, EmptyBatchShouldNotPolluteOutput) {
auto schema = arrow::schema({field("col1", utf8())});
auto empty_batch = RecordBatchFromJSON(schema, "[]");
auto batch_a = RecordBatchFromJSON(schema, R"([{"col1": "a"}])");
auto batch_b = RecordBatchFromJSON(schema, R"([{"col1": "b"}])");

struct TestParam {
std::shared_ptr<Table> table;
std::string expected_output;
};

std::vector<TestParam> test_params = {
// Empty batch in the beginning
{Table::FromRecordBatches(schema, {empty_batch, batch_a, batch_b}).ValueOrDie(),
"\"col1\"\n\"a\"\n\"b\"\n"},
// Empty batch in the middle
{Table::FromRecordBatches(schema, {batch_a, empty_batch, batch_b}).ValueOrDie(),
"\"col1\"\n\"a\"\n\"b\"\n"},
// Empty batch in the end
{Table::FromRecordBatches(schema, {batch_a, batch_b, empty_batch}).ValueOrDie(),
"\"col1\"\n\"a\"\n\"b\"\n"},
};

for (const auto& param : test_params) {
ASSERT_OK_AND_ASSIGN(auto out, io::BufferOutputStream::Create());
ASSERT_OK(WriteCSV(*param.table, WriteOptions::Defaults(), out.get()));
ASSERT_OK_AND_ASSIGN(auto buffer, out->Finish());
EXPECT_EQ(buffer->ToString(), param.expected_output);
}
}

} // namespace csv
} // namespace arrow
34 changes: 34 additions & 0 deletions python/pyarrow/tests/test_csv.py
@@ -2065,3 +2065,37 @@ def readinto(self, *args):
for i in range(20):
with pytest.raises(pa.ArrowInvalid):
read_csv(MyBytesIO(data))


@pytest.mark.parametrize("tables,expected", [
# GH-36889: Empty batch at the beginning
(
lambda: [pa.table({"col1": []}).cast(pa.schema([("col1", pa.string())])),
pa.table({"col1": ["a"]}),
pa.table({"col1": ["b"]})],
b'"col1"\n"a"\n"b"\n'
),
# GH-36889: Empty batch in the middle
(
lambda: [pa.table({"col1": ["a"]}),
pa.table({"col1": []}).cast(pa.schema([("col1", pa.string())])),
pa.table({"col1": ["b"]})],
b'"col1"\n"a"\n"b"\n'
),
# GH-36889: Empty batch at the end
(
lambda: [pa.table({"col1": ["a"]}),
pa.table({"col1": ["b"]}),
pa.table({"col1": []}).cast(pa.schema([("col1", pa.string())]))],
b'"col1"\n"a"\n"b"\n'
),
])
def test_write_csv_empty_batch_should_not_pollute_output(tables, expected):
combined = pa.concat_tables(tables())

buf = io.BytesIO()
write_csv(combined, buf)
buf.seek(0)
result = buf.read()

assert result == expected
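One detail worth calling out in these parameters: pa.table({"col1": []}) infers the null type for the empty column, so each empty table is cast to a string schema before pa.concat_tables combines it with the string-typed tables. A quick sketch of that inference:

import pyarrow as pa

empty = pa.table({"col1": []})
print(empty.schema)  # col1: null, inferred from the empty Python list
typed = empty.cast(pa.schema([("col1", pa.string())]))
print(typed.schema)  # col1: string, now compatible with the other tables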