diff --git a/google/cloud/storage/_experimental/asyncio/async_read_object_stream.py b/google/cloud/storage/_experimental/asyncio/async_read_object_stream.py index 7adcdd1c9..5e84485d5 100644 --- a/google/cloud/storage/_experimental/asyncio/async_read_object_stream.py +++ b/google/cloud/storage/_experimental/asyncio/async_read_object_stream.py @@ -151,9 +151,16 @@ async def close(self) -> None: """Closes the bidi-gRPC connection.""" if not self._is_stream_open: raise ValueError("Stream is not open") + await self.requests_done() await self.socket_like_rpc.close() self._is_stream_open = False + async def requests_done(self): + """Signals that all requests have been sent.""" + + await self.socket_like_rpc.send(None) + await self.socket_like_rpc.recv() + async def send( self, bidi_read_object_request: _storage_v2.BidiReadObjectRequest ) -> None: diff --git a/google/cloud/storage/_experimental/asyncio/async_write_object_stream.py b/google/cloud/storage/_experimental/asyncio/async_write_object_stream.py index 183a8eeb1..a0ebaa498 100644 --- a/google/cloud/storage/_experimental/asyncio/async_write_object_stream.py +++ b/google/cloud/storage/_experimental/asyncio/async_write_object_stream.py @@ -152,9 +152,16 @@ async def close(self) -> None: """Closes the bidi-gRPC connection.""" if not self._is_stream_open: raise ValueError("Stream is not open") + await self.requests_done() await self.socket_like_rpc.close() self._is_stream_open = False + async def requests_done(self): + """Signals that all requests have been sent.""" + + await self.socket_like_rpc.send(None) + await self.socket_like_rpc.recv() + async def send( self, bidi_write_object_request: _storage_v2.BidiWriteObjectRequest ) -> None: @@ -186,3 +193,4 @@ async def recv(self) -> _storage_v2.BidiWriteObjectResponse: @property def is_stream_open(self) -> bool: return self._is_stream_open + diff --git a/tests/unit/asyncio/test_async_read_object_stream.py b/tests/unit/asyncio/test_async_read_object_stream.py index 4ba8d34a1..18e0c464d 100644 --- a/tests/unit/asyncio/test_async_read_object_stream.py +++ b/tests/unit/asyncio/test_async_read_object_stream.py @@ -197,15 +197,41 @@ async def test_close(mock_client, mock_cls_async_bidi_rpc): read_obj_stream = await instantiate_read_obj_stream( mock_client, mock_cls_async_bidi_rpc, open=True ) + read_obj_stream.requests_done = AsyncMock() # act await read_obj_stream.close() # assert + read_obj_stream.requests_done.assert_called_once() read_obj_stream.socket_like_rpc.close.assert_called_once() assert not read_obj_stream.is_stream_open +@mock.patch( + "google.cloud.storage._experimental.asyncio.async_read_object_stream.AsyncBidiRpc" +) +@mock.patch( + "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" +) +@pytest.mark.asyncio +async def test_requests_done(mock_client, mock_cls_async_bidi_rpc): + """Test that requests_done signals the end of requests.""" + # Arrange + read_obj_stream = await instantiate_read_obj_stream( + mock_client, mock_cls_async_bidi_rpc, open=True + ) + read_obj_stream.socket_like_rpc.send = AsyncMock() + read_obj_stream.socket_like_rpc.recv = AsyncMock() + + # Act + await read_obj_stream.requests_done() + + # Assert + read_obj_stream.socket_like_rpc.send.assert_called_once_with(None) + read_obj_stream.socket_like_rpc.recv.assert_called_once() + + @mock.patch( "google.cloud.storage._experimental.asyncio.async_read_object_stream.AsyncBidiRpc" ) diff --git a/tests/unit/asyncio/test_async_write_object_stream.py b/tests/unit/asyncio/test_async_write_object_stream.py index c6ea8a8ff..63b5495bd 100644 --- a/tests/unit/asyncio/test_async_write_object_stream.py +++ b/tests/unit/asyncio/test_async_write_object_stream.py @@ -289,11 +289,13 @@ async def test_close(mock_cls_async_bidi_rpc, mock_client): write_obj_stream = await instantiate_write_obj_stream( mock_client, mock_cls_async_bidi_rpc, open=True ) + write_obj_stream.requests_done = AsyncMock() # Act await write_obj_stream.close() # Assert + write_obj_stream.requests_done.assert_called_once() write_obj_stream.socket_like_rpc.close.assert_called_once() assert not write_obj_stream.is_stream_open @@ -394,3 +396,24 @@ async def test_recv_without_open_should_raise_error( # Act & Assert with pytest.raises(ValueError, match="Stream is not open"): await write_obj_stream.recv() + + +@pytest.mark.asyncio +@mock.patch( + "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" +) +async def test_requests_done(mock_cls_async_bidi_rpc, mock_client): + """Test that requests_done signals the end of requests.""" + # Arrange + write_obj_stream = await instantiate_write_obj_stream( + mock_client, mock_cls_async_bidi_rpc, open=True + ) + write_obj_stream.socket_like_rpc.send = AsyncMock() + write_obj_stream.socket_like_rpc.recv = AsyncMock() + + # Act + await write_obj_stream.requests_done() + + # Assert + write_obj_stream.socket_like_rpc.send.assert_called_once_with(None) + write_obj_stream.socket_like_rpc.recv.assert_called_once() \ No newline at end of file