Skip to content

Commit c11ea33

Browse files
committed
GH-48076: [C++][Flight] fix GeneratorStream for Tables
After the changes in #47115, GeneratorStreams backed by anything else than RecordBatches failed. This includes Tables and RecordBatchReaders. This was caused by a too strict assumption that the RecordBatchStream#GetSchemaPayload would always get called, which is not the case when the GeneratorStream is backed by a Table or a RecordBatchReader. So to fix this, remove the assertion and instead initialize the writer on first access.
1 parent cd23a76 commit c11ea33

File tree

2 files changed

+78
-9
lines changed

2 files changed

+78
-9
lines changed

cpp/src/arrow/flight/server.cc

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -285,12 +285,7 @@ class RecordBatchStream::RecordBatchStreamImpl {
285285

286286
Status GetSchemaPayload(FlightPayload* payload) {
287287
if (!writer_) {
288-
// Create the IPC writer on first call
289-
auto payload_writer =
290-
std::make_unique<ServerRecordBatchPayloadWriter>(&payload_deque_);
291-
ARROW_ASSIGN_OR_RAISE(
292-
writer_, ipc::internal::OpenRecordBatchWriter(std::move(payload_writer),
293-
reader_->schema(), options_));
288+
RETURN_NOT_OK(InitializeWriter());
294289
}
295290

296291
// Return the expected schema payload.
@@ -317,8 +312,7 @@ class RecordBatchStream::RecordBatchStreamImpl {
317312
return Status::OK();
318313
}
319314
if (!writer_) {
320-
return Status::UnknownError(
321-
"Writer should be initialized before reading Next batches");
315+
RETURN_NOT_OK(InitializeWriter());
322316
}
323317
// One WriteRecordBatch call might generate multiple payloads, so we
324318
// need to collect them in a deque.
@@ -370,6 +364,15 @@ class RecordBatchStream::RecordBatchStreamImpl {
370364
ipc::IpcWriteOptions options_;
371365
std::unique_ptr<ipc::RecordBatchWriter> writer_;
372366
std::deque<FlightPayload> payload_deque_;
367+
368+
Status InitializeWriter() {
369+
auto payload_writer =
370+
std::make_unique<ServerRecordBatchPayloadWriter>(&payload_deque_);
371+
ARROW_ASSIGN_OR_RAISE(
372+
writer_, ipc::internal::OpenRecordBatchWriter(std::move(payload_writer),
373+
reader_->schema(), options_));
374+
return Status::OK();
375+
}
373376
};
374377

375378
FlightMetadataWriter::~FlightMetadataWriter() = default;

python/pyarrow/tests/test_flight.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,40 @@ def do_action(self, context, action):
246246
raise NotImplementedError
247247

248248

249+
class EchoTableStreamFlightServer(EchoFlightServer):
250+
"""An echo server that streams the whole table."""
251+
252+
def do_get(self, context, ticket):
253+
return flight.GeneratorStream(
254+
self.last_message.schema,
255+
[self.last_message])
256+
257+
def list_actions(self, context):
258+
return []
259+
260+
def do_action(self, context, action):
261+
if action.type == "who-am-i":
262+
return [context.peer_identity(), context.peer().encode("utf-8")]
263+
raise NotImplementedError
264+
265+
266+
class EchoRecordBatchReaderStreamFlightServer(EchoFlightServer):
267+
"""An echo server that streams the whole table as a RecordBatchReader."""
268+
269+
def do_get(self, context, ticket):
270+
return flight.GeneratorStream(
271+
self.last_message.schema,
272+
[self.last_message.to_reader()])
273+
274+
def list_actions(self, context):
275+
return []
276+
277+
def do_action(self, context, action):
278+
if action.type == "who-am-i":
279+
return [context.peer_identity(), context.peer().encode("utf-8")]
280+
raise NotImplementedError
281+
282+
249283
class GetInfoFlightServer(FlightServerBase):
250284
"""A Flight server that tests GetFlightInfo."""
251285

@@ -1362,7 +1396,7 @@ def test_flight_large_message():
13621396
assert result.equals(data)
13631397

13641398

1365-
def test_flight_generator_stream():
1399+
def test_flight_generator_stream_of_batches():
13661400
"""Try downloading a flight of RecordBatches in a GeneratorStream."""
13671401
data = pa.Table.from_arrays([
13681402
pa.array(range(0, 10 * 1024))
@@ -1378,6 +1412,38 @@ def test_flight_generator_stream():
13781412
assert result.equals(data)
13791413

13801414

1415+
def test_flight_generator_stream_of_table():
1416+
"""Try downloading a flight of Table in a GeneratorStream."""
1417+
data = pa.Table.from_arrays([
1418+
pa.array(range(0, 10 * 1024))
1419+
], names=['a'])
1420+
1421+
with EchoTableStreamFlightServer() as server, \
1422+
FlightClient(('localhost', server.port)) as client:
1423+
writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'),
1424+
data.schema)
1425+
writer.write_table(data)
1426+
writer.close()
1427+
result = client.do_get(flight.Ticket(b'')).read_all()
1428+
assert result.equals(data)
1429+
1430+
1431+
def test_flight_generator_stream_of_record_batch_reader():
1432+
"""Try downloading a flight of RecordBatchReader in a GeneratorStream."""
1433+
data = pa.Table.from_arrays([
1434+
pa.array(range(0, 10 * 1024))
1435+
], names=['a'])
1436+
1437+
with EchoRecordBatchReaderStreamFlightServer() as server, \
1438+
FlightClient(('localhost', server.port)) as client:
1439+
writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'),
1440+
data.schema)
1441+
writer.write_table(data)
1442+
writer.close()
1443+
result = client.do_get(flight.Ticket(b'')).read_all()
1444+
assert result.equals(data)
1445+
1446+
13811447
def test_flight_invalid_generator_stream():
13821448
"""Try streaming data with mismatched schemas."""
13831449
with InvalidStreamFlightServer() as server, \

0 commit comments

Comments
 (0)