cpp/src/arrow/flight/server.cc (19 additions, 8 deletions)
@@ -285,12 +285,7 @@ class RecordBatchStream::RecordBatchStreamImpl {
 
   Status GetSchemaPayload(FlightPayload* payload) {
     if (!writer_) {
-      // Create the IPC writer on first call
-      auto payload_writer =
-          std::make_unique<ServerRecordBatchPayloadWriter>(&payload_deque_);
-      ARROW_ASSIGN_OR_RAISE(
-          writer_, ipc::internal::OpenRecordBatchWriter(std::move(payload_writer),
-                                                        reader_->schema(), options_));
+      RETURN_NOT_OK(InitializeWriter());
     }
 
     // Return the expected schema payload.
@@ -317,8 +312,15 @@ class RecordBatchStream::RecordBatchStreamImpl {
       return Status::OK();
     }
     if (!writer_) {
-      return Status::UnknownError(
-          "Writer should be initialized before reading Next batches");
+      RETURN_NOT_OK(InitializeWriter());
Member:
I think in this specific scenario we might have to drop the first message produced after calling writer_->WriteRecordBatch, which will be a schema message (since the ipc::internal::OpenRecordBatchWriter has just been created), and keep the rest of the messages, which should be the expected RecordBatch payloads in those cases.
I would have to debug to confirm, but from what I understand that might be the cause of the current problem.

Contributor Author:
Thank you for the suggestion, that is a good idea. I will take a look as soon as I can; unfortunately, something more urgent cropped up that I have to deal with first :( I will let you know what I find out.

+      // If the writer has not been initialized yet, the first payload in the
+      // queue is going to be a SCHEMA message. In this context, that is
+      // unexpected, so drop it from the queue so that a RECORD_BATCH message
+      // is at the front (the same as if the writer had been initialized in
+      // GetSchemaPayload).
+      if (payload_deque_.front().ipc_message.type == ipc::MessageType::SCHEMA) {
+        payload_deque_.pop_front();
+      }
     }
     // One WriteRecordBatch call might generate multiple payloads, so we
     // need to collect them in a deque.
@@ -370,6 +372,15 @@ class RecordBatchStream::RecordBatchStreamImpl {
   ipc::IpcWriteOptions options_;
   std::unique_ptr<ipc::RecordBatchWriter> writer_;
   std::deque<FlightPayload> payload_deque_;
+
+  Status InitializeWriter() {
+    auto payload_writer =
+        std::make_unique<ServerRecordBatchPayloadWriter>(&payload_deque_);
+    ARROW_ASSIGN_OR_RAISE(
+        writer_, ipc::internal::OpenRecordBatchWriter(std::move(payload_writer),
+                                                      reader_->schema(), options_));
+    return Status::OK();
+  }
 };
 
 FlightMetadataWriter::~FlightMetadataWriter() = default;
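To tie the two hunks together: the review thread above suggests that when the IPC writer is created lazily inside Next() rather than in GetSchemaPayload, the first payload enqueued by WriteRecordBatch is a SCHEMA message that has to be dropped. The following is a minimal sketch of that flow, not the actual Arrow implementation; it assumes the members shown in this diff (writer_, payload_deque_, InitializeWriter(), FlightPayload, ipc::MessageType) and uses a hypothetical helper name, EmitNextBatch.

// Hypothetical member-function sketch inside RecordBatchStream::RecordBatchStreamImpl.
Status EmitNextBatch(const std::shared_ptr<arrow::RecordBatch>& batch,
                     FlightPayload* payload) {
  const bool lazily_initialized = !writer_;
  if (lazily_initialized) {
    RETURN_NOT_OK(InitializeWriter());
  }
  // A single WriteRecordBatch call may enqueue several payloads; a freshly
  // opened writer enqueues a SCHEMA payload before the RECORD_BATCH one(s).
  RETURN_NOT_OK(writer_->WriteRecordBatch(*batch));
  if (lazily_initialized &&
      payload_deque_.front().ipc_message.type == ipc::MessageType::SCHEMA) {
    // The caller expects a RECORD_BATCH payload at the front here, so drop the
    // leading SCHEMA payload, as discussed in the review thread above.
    payload_deque_.pop_front();
  }
  *payload = std::move(payload_deque_.front());
  payload_deque_.pop_front();
  return Status::OK();
}

In the actual diff this logic lives directly in Next(); the sketch only isolates the lazy-initialization branch that the added MessageType::SCHEMA check handles.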
python/pyarrow/tests/test_flight.py (130 additions, 1 deletion)
@@ -246,6 +246,40 @@ def do_action(self, context, action):
         raise NotImplementedError
 
+
+class EchoTableStreamFlightServer(EchoFlightServer):
+    """An echo server that streams the whole table."""
+
+    def do_get(self, context, ticket):
+        return flight.GeneratorStream(
+            self.last_message.schema,
+            [self.last_message])
+
+    def list_actions(self, context):
+        return []
+
+    def do_action(self, context, action):
+        if action.type == "who-am-i":
+            return [context.peer_identity(), context.peer().encode("utf-8")]
+        raise NotImplementedError
+
+
+class EchoRecordBatchReaderStreamFlightServer(EchoFlightServer):
+    """An echo server that streams the whole table as a RecordBatchReader."""
+
+    def do_get(self, context, ticket):
+        return flight.GeneratorStream(
+            self.last_message.schema,
+            [self.last_message.to_reader()])
+
+    def list_actions(self, context):
+        return []
+
+    def do_action(self, context, action):
+        if action.type == "who-am-i":
+            return [context.peer_identity(), context.peer().encode("utf-8")]
+        raise NotImplementedError
+
 
 class GetInfoFlightServer(FlightServerBase):
     """A Flight server that tests GetFlightInfo."""
 
@@ -1362,7 +1396,7 @@ def test_flight_large_message():
         assert result.equals(data)
 
 
-def test_flight_generator_stream():
+def test_flight_generator_stream_of_batches():
     """Try downloading a flight of RecordBatches in a GeneratorStream."""
     data = pa.Table.from_arrays([
         pa.array(range(0, 10 * 1024))
@@ -1378,6 +1412,101 @@ def test_flight_generator_stream():
         assert result.equals(data)
 
 
+def test_flight_generator_stream_of_batches_with_dict():
+    """
+    Try downloading a flight of RecordBatches with dictionaries
+    in a GeneratorStream.
+    """
+    data = pa.Table.from_arrays([
+        pa.array(["foo", "bar", "baz", "foo", "foo"],
+                 pa.dictionary(pa.int64(), pa.utf8())),
+        pa.array([123, 234, 345, 456, 567])
+    ], names=['a', 'b'])
+
+    with EchoRecordBatchReaderStreamFlightServer() as server, \
+            FlightClient(('localhost', server.port)) as client:
+        writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'),
+                                  data.schema)
+        writer.write_table(data)
+        writer.close()
+        result = client.do_get(flight.Ticket(b'')).read_all()
+        assert result.equals(data)
+
+
+def test_flight_generator_stream_of_table():
+    """Try downloading a flight of Table in a GeneratorStream."""
+    data = pa.Table.from_arrays([
+        pa.array(range(0, 10 * 1024))
+    ], names=['a'])
+
+    with EchoTableStreamFlightServer() as server, \
+            FlightClient(('localhost', server.port)) as client:
+        writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'),
+                                  data.schema)
+        writer.write_table(data)
+        writer.close()
+        result = client.do_get(flight.Ticket(b'')).read_all()
+        assert result.equals(data)
+
+
+def test_flight_generator_stream_of_table_with_dict():
+    """
+    Try downloading a flight of Table with dictionaries
+    in a GeneratorStream.
+    """
+    data = pa.Table.from_arrays([
+        pa.array(["foo", "bar", "baz", "foo", "foo"],
+                 pa.dictionary(pa.int64(), pa.utf8())),
+        pa.array([123, 234, 345, 456, 567])
+    ], names=['a', 'b'])
+
+    with EchoTableStreamFlightServer() as server, \
+            FlightClient(('localhost', server.port)) as client:
+        writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'),
+                                  data.schema)
+        writer.write_table(data)
+        writer.close()
+        result = client.do_get(flight.Ticket(b'')).read_all()
+        assert result.equals(data)
+
+
+def test_flight_generator_stream_of_record_batch_reader():
+    """Try downloading a flight of RecordBatchReader in a GeneratorStream."""
+    data = pa.Table.from_arrays([
+        pa.array(range(0, 10 * 1024))
+    ], names=['a'])
+
+    with EchoRecordBatchReaderStreamFlightServer() as server, \
+            FlightClient(('localhost', server.port)) as client:
+        writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'),
+                                  data.schema)
+        writer.write_table(data)
+        writer.close()
+        result = client.do_get(flight.Ticket(b'')).read_all()
+        assert result.equals(data)
+
+
+def test_flight_generator_stream_of_record_batch_reader_with_dict():
+    """
+    Try downloading a flight of RecordBatchReader with dictionaries
+    in a GeneratorStream.
+    """
+    data = pa.Table.from_arrays([
+        pa.array(["foo", "bar", "baz", "foo", "foo"],
+                 pa.dictionary(pa.int64(), pa.utf8())),
+        pa.array([123, 234, 345, 456, 567])
+    ], names=['a', 'b'])
+
+    with EchoRecordBatchReaderStreamFlightServer() as server, \
+            FlightClient(('localhost', server.port)) as client:
+        writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'),
+                                  data.schema)
+        writer.write_table(data)
+        writer.close()
+        result = client.do_get(flight.Ticket(b'')).read_all()
+        assert result.equals(data)
+
+
 def test_flight_invalid_generator_stream():
     """Try streaming data with mismatched schemas."""
     with InvalidStreamFlightServer() as server, \