@@ -246,6 +246,40 @@ def do_action(self, context, action):
246246 raise NotImplementedError
247247
248248
249+ class EchoTableStreamFlightServer (EchoFlightServer ):
250+ """An echo server that streams the whole table."""
251+
252+ def do_get (self , context , ticket ):
253+ return flight .GeneratorStream (
254+ self .last_message .schema ,
255+ [self .last_message ])
256+
257+ def list_actions (self , context ):
258+ return []
259+
260+ def do_action (self , context , action ):
261+ if action .type == "who-am-i" :
262+ return [context .peer_identity (), context .peer ().encode ("utf-8" )]
263+ raise NotImplementedError
264+
265+
266+ class EchoRecordBatchReaderStreamFlightServer (EchoFlightServer ):
267+ """An echo server that streams the whole table as a RecordBatchReader."""
268+
269+ def do_get (self , context , ticket ):
270+ return flight .GeneratorStream (
271+ self .last_message .schema ,
272+ [self .last_message .to_reader ()])
273+
274+ def list_actions (self , context ):
275+ return []
276+
277+ def do_action (self , context , action ):
278+ if action .type == "who-am-i" :
279+ return [context .peer_identity (), context .peer ().encode ("utf-8" )]
280+ raise NotImplementedError
281+
282+
249283class GetInfoFlightServer (FlightServerBase ):
250284 """A Flight server that tests GetFlightInfo."""
251285
@@ -1362,7 +1396,7 @@ def test_flight_large_message():
13621396 assert result .equals (data )
13631397
13641398
1365- def test_flight_generator_stream ():
1399+ def test_flight_generator_stream_of_batches ():
13661400 """Try downloading a flight of RecordBatches in a GeneratorStream."""
13671401 data = pa .Table .from_arrays ([
13681402 pa .array (range (0 , 10 * 1024 ))
@@ -1378,6 +1412,38 @@ def test_flight_generator_stream():
13781412 assert result .equals (data )
13791413
13801414
1415+ def test_flight_generator_stream_of_table ():
1416+ """Try downloading a flight of Table in a GeneratorStream."""
1417+ data = pa .Table .from_arrays ([
1418+ pa .array (range (0 , 10 * 1024 ))
1419+ ], names = ['a' ])
1420+
1421+ with EchoTableStreamFlightServer () as server , \
1422+ FlightClient (('localhost' , server .port )) as client :
1423+ writer , _ = client .do_put (flight .FlightDescriptor .for_path ('test' ),
1424+ data .schema )
1425+ writer .write_table (data )
1426+ writer .close ()
1427+ result = client .do_get (flight .Ticket (b'' )).read_all ()
1428+ assert result .equals (data )
1429+
1430+
1431+ def test_flight_generator_stream_of_record_batch_reader ():
1432+ """Try downloading a flight of RecordBatchReader in a GeneratorStream."""
1433+ data = pa .Table .from_arrays ([
1434+ pa .array (range (0 , 10 * 1024 ))
1435+ ], names = ['a' ])
1436+
1437+ with EchoRecordBatchReaderStreamFlightServer () as server , \
1438+ FlightClient (('localhost' , server .port )) as client :
1439+ writer , _ = client .do_put (flight .FlightDescriptor .for_path ('test' ),
1440+ data .schema )
1441+ writer .write_table (data )
1442+ writer .close ()
1443+ result = client .do_get (flight .Ticket (b'' )).read_all ()
1444+ assert result .equals (data )
1445+
1446+
13811447def test_flight_invalid_generator_stream ():
13821448 """Try streaming data with mismatched schemas."""
13831449 with InvalidStreamFlightServer () as server , \
0 commit comments