
Commit 72d4258

GH-48076: [C++][Flight] fix GeneratorStream for Tables
After the changes in #47115, GeneratorStreams backed by anything other than RecordBatches failed; this includes Tables and RecordBatchReaders. The failure was caused by an overly strict assumption that RecordBatchStream#GetSchemaPayload would always be called, which is not the case when the GeneratorStream is backed by a Table or a RecordBatchReader. To fix this, remove the assertion and instead initialize the writer on first access. Also, to accommodate this case, drop the incoming message when initializing the writer in Next, since the message at the front of the queue is a SCHEMA message and a RECORD_BATCH one is expected.
1 parent cd23a76 commit 72d4258
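
For context, the usage that regressed looks roughly like the sketch below: a Python Flight server whose do_get backs a GeneratorStream with a whole Table (or a RecordBatchReader) instead of yielding individual RecordBatches. This is a minimal illustration only; the class name TableStreamServer and the sample data are assumptions, and the real coverage is in the tests added below.

    # Minimal sketch (assumed names) of the case this commit fixes: do_get()
    # backs a GeneratorStream with a pyarrow.Table instead of RecordBatches.
    import pyarrow as pa
    import pyarrow.flight as flight


    class TableStreamServer(flight.FlightServerBase):  # hypothetical example server
        def __init__(self, table, **kwargs):
            super().__init__(**kwargs)
            self._table = table

        def do_get(self, context, ticket):
            # Passing the whole Table (or self._table.to_reader()) here takes
            # the code path that previously returned the UnknownError removed
            # by this commit.
            return flight.GeneratorStream(self._table.schema, [self._table])


    table = pa.table({'a': pa.array(range(10 * 1024))})
    with TableStreamServer(table) as server, \
            flight.FlightClient(('localhost', server.port)) as client:
        result = client.do_get(flight.Ticket(b'')).read_all()
        assert result.equals(table)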

2 files changed: +84, -9 lines


cpp/src/arrow/flight/server.cc

Lines changed: 17 additions & 8 deletions

@@ -285,12 +285,7 @@ class RecordBatchStream::RecordBatchStreamImpl {
 
   Status GetSchemaPayload(FlightPayload* payload) {
     if (!writer_) {
-      // Create the IPC writer on first call
-      auto payload_writer =
-          std::make_unique<ServerRecordBatchPayloadWriter>(&payload_deque_);
-      ARROW_ASSIGN_OR_RAISE(
-          writer_, ipc::internal::OpenRecordBatchWriter(std::move(payload_writer),
-                                                        reader_->schema(), options_));
+      RETURN_NOT_OK(InitializeWriter());
     }
 
     // Return the expected schema payload.
@@ -317,8 +312,13 @@ class RecordBatchStream::RecordBatchStreamImpl {
       return Status::OK();
     }
     if (!writer_) {
-      return Status::UnknownError(
-          "Writer should be initialized before reading Next batches");
+      RETURN_NOT_OK(InitializeWriter());
+      // If the writer is not initialized yet, the first batch in the payload
+      // queue is going to be a schema-only one. In this context, that is
+      // unexpected, so drop it from the queue so that there is a RECORD_BATCH
+      // message on the top (same as would be if the writer had been initialized
+      // in GetSchemaPayload).
+      payload_deque_.pop_front();
     }
     // One WriteRecordBatch call might generate multiple payloads, so we
     // need to collect them in a deque.
@@ -370,6 +370,15 @@ class RecordBatchStream::RecordBatchStreamImpl {
   ipc::IpcWriteOptions options_;
   std::unique_ptr<ipc::RecordBatchWriter> writer_;
   std::deque<FlightPayload> payload_deque_;
+
+  Status InitializeWriter() {
+    auto payload_writer =
+        std::make_unique<ServerRecordBatchPayloadWriter>(&payload_deque_);
+    ARROW_ASSIGN_OR_RAISE(
+        writer_, ipc::internal::OpenRecordBatchWriter(std::move(payload_writer),
+                                                      reader_->schema(), options_));
+    return Status::OK();
+  }
 };
 
 FlightMetadataWriter::~FlightMetadataWriter() = default;

python/pyarrow/tests/test_flight.py

Lines changed: 67 additions & 1 deletion

@@ -246,6 +246,40 @@ def do_action(self, context, action):
         raise NotImplementedError
 
 
+class EchoTableStreamFlightServer(EchoFlightServer):
+    """An echo server that streams the whole table."""
+
+    def do_get(self, context, ticket):
+        return flight.GeneratorStream(
+            self.last_message.schema,
+            [self.last_message])
+
+    def list_actions(self, context):
+        return []
+
+    def do_action(self, context, action):
+        if action.type == "who-am-i":
+            return [context.peer_identity(), context.peer().encode("utf-8")]
+        raise NotImplementedError
+
+
+class EchoRecordBatchReaderStreamFlightServer(EchoFlightServer):
+    """An echo server that streams the whole table as a RecordBatchReader."""
+
+    def do_get(self, context, ticket):
+        return flight.GeneratorStream(
+            self.last_message.schema,
+            [self.last_message.to_reader()])
+
+    def list_actions(self, context):
+        return []
+
+    def do_action(self, context, action):
+        if action.type == "who-am-i":
+            return [context.peer_identity(), context.peer().encode("utf-8")]
+        raise NotImplementedError
+
+
 class GetInfoFlightServer(FlightServerBase):
     """A Flight server that tests GetFlightInfo."""
 
@@ -1362,7 +1396,7 @@ def test_flight_large_message():
     assert result.equals(data)
 
 
-def test_flight_generator_stream():
+def test_flight_generator_stream_of_batches():
     """Try downloading a flight of RecordBatches in a GeneratorStream."""
     data = pa.Table.from_arrays([
         pa.array(range(0, 10 * 1024))
@@ -1378,6 +1412,38 @@ def test_flight_generator_stream():
     assert result.equals(data)
 
 
+def test_flight_generator_stream_of_table():
+    """Try downloading a flight of Table in a GeneratorStream."""
+    data = pa.Table.from_arrays([
+        pa.array(range(0, 10 * 1024))
+    ], names=['a'])
+
+    with EchoTableStreamFlightServer() as server, \
+            FlightClient(('localhost', server.port)) as client:
+        writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'),
+                                  data.schema)
+        writer.write_table(data)
+        writer.close()
+        result = client.do_get(flight.Ticket(b'')).read_all()
+        assert result.equals(data)
+
+
+def test_flight_generator_stream_of_record_batch_reader():
+    """Try downloading a flight of RecordBatchReader in a GeneratorStream."""
+    data = pa.Table.from_arrays([
+        pa.array(range(0, 10 * 1024))
+    ], names=['a'])
+
+    with EchoRecordBatchReaderStreamFlightServer() as server, \
+            FlightClient(('localhost', server.port)) as client:
+        writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'),
+                                  data.schema)
+        writer.write_table(data)
+        writer.close()
+        result = client.do_get(flight.Ticket(b'')).read_all()
+        assert result.equals(data)
+
+
 def test_flight_invalid_generator_stream():
     """Try streaming data with mismatched schemas."""
     with InvalidStreamFlightServer() as server, \
