diff --git a/google/genai/_api_client.py b/google/genai/_api_client.py index 2dfc0366b..f22ce0c9b 100644 --- a/google/genai/_api_client.py +++ b/google/genai/_api_client.py @@ -1461,7 +1461,7 @@ async def async_request_streamed( http_method, path, request_dict, http_options ) - response = await self._async_request(http_request=http_request, stream=True) + response = await self._async_request(http_request=http_request, http_options=http_options, stream=True) async def async_generator(): # type: ignore[no-untyped-def] async for chunk in response: diff --git a/google/genai/tests/client/test_retries.py b/google/genai/tests/client/test_retries.py index da2eea752..c01c9091c 100644 --- a/google/genai/tests/client/test_retries.py +++ b/google/genai/tests/client/test_retries.py @@ -588,13 +588,17 @@ async def run(): # Async aiohttp -async def _aiohttp_async_response(status: int): +async def _aiohttp_async_response(status: int, streamable: bool = False): """Has to return a coroutine hence async.""" response = mock.Mock(spec=aiohttp.ClientResponse) response.status = status response.headers = {'status-code': str(status)} response.json.return_value = {} response.text.return_value = 'test' + if streamable: + response.content = mock.Mock() + response.content.readline = mock.AsyncMock(return_value=b'') + response.release = mock.MagicMock() return response @@ -810,6 +814,409 @@ async def run(): asyncio.run(run()) +# Sync Streaming + + +def test_retries_streamed_failed_request_retries_successfully(): + mock_transport = mock.Mock(spec=httpx.BaseTransport) + mock_transport.handle_request.side_effect = ( + _httpx_response(429), + _httpx_response(200), + ) + + client = api_client.BaseApiClient( + vertexai=True, + project='test_project', + location='global', + http_options=_transport_options( + http_options=types.HttpOptions(retry_options=_RETRY_OPTIONS), + transport=mock_transport, + ), + ) + + with _patch_auth_default(): + stream = client.request_streamed( + http_method='GET', path='path', request_dict={} + ) + list(stream) + mock_transport.handle_request.assert_called() + assert mock_transport.handle_request.call_count == 2 + + +def test_retries_streamed_failed_request_retries_successfully_at_request_level(): + mock_transport = mock.Mock(spec=httpx.BaseTransport) + mock_transport.handle_request.side_effect = ( + _httpx_response(429), + _httpx_response(200), + ) + + client = api_client.BaseApiClient( + vertexai=True, + project='test_project', + location='global', + http_options=_transport_options( + transport=mock_transport, + ), + ) + + with _patch_auth_default(): + stream = client.request_streamed( + http_method='GET', + path='path', + request_dict={}, + http_options=types.HttpOptions(retry_options=_RETRY_OPTIONS), + ) + list(stream) + mock_transport.handle_request.assert_called() + assert mock_transport.handle_request.call_count == 2 + + +def test_retries_streamed_failed_request_retries_unsuccessfully(): + mock_transport = mock.Mock(spec=httpx.BaseTransport) + mock_transport.handle_request.side_effect = ( + _httpx_response(429), + _httpx_response(504), + ) + + client = api_client.BaseApiClient( + vertexai=True, + project='test_project', + location='global', + http_options=_transport_options( + http_options=types.HttpOptions(retry_options=_RETRY_OPTIONS), + transport=mock_transport, + ), + ) + + with _patch_auth_default(): + try: + stream = client.request_streamed( + http_method='GET', path='path', request_dict={} + ) + list(stream) + assert False, 'Expected APIError to be raised.' + except errors.APIError as e: + assert e.code == 504 + mock_transport.handle_request.assert_called() + + +def test_retries_streamed_failed_request_retries_unsuccessfully_at_request_level(): + mock_transport = mock.Mock(spec=httpx.BaseTransport) + mock_transport.handle_request.side_effect = ( + _httpx_response(429), + _httpx_response(504), + ) + + client = api_client.BaseApiClient( + vertexai=True, + project='test_project', + location='global', + http_options=_transport_options( + transport=mock_transport, + ), + ) + + with _patch_auth_default(): + try: + stream = client.request_streamed( + http_method='GET', + path='path', + request_dict={}, + http_options={'retry_options': _RETRY_OPTIONS}, + ) + list(stream) + assert False, 'Expected APIError to be raised.' + except errors.APIError as e: + assert e.code == 504 + mock_transport.handle_request.assert_called() + + +# Async httpx Streaming + + +def test_async_retries_streamed_failed_request_retries_successfully(): + api_client.has_aiohttp = False + + async def run(): + mock_transport = mock.Mock(spec=httpx.AsyncBaseTransport) + mock_transport.handle_async_request.side_effect = ( + _httpx_response(429), + _httpx_response(200), + ) + + client = api_client.BaseApiClient( + vertexai=True, + project='test_project', + location='global', + http_options=_transport_options( + http_options=types.HttpOptions(retry_options=_RETRY_OPTIONS), + async_transport=mock_transport, + ), + ) + + with _patch_auth_default(): + stream = await client.async_request_streamed( + http_method='GET', path='path', request_dict={} + ) + async for _ in stream: + pass + mock_transport.handle_async_request.assert_called() + assert mock_transport.handle_async_request.call_count == 2 + + asyncio.run(run()) + + +def test_async_retries_streamed_failed_request_retries_successfully_at_request_level(): + api_client.has_aiohttp = False + + async def run(): + mock_transport = mock.Mock(spec=httpx.AsyncBaseTransport) + mock_transport.handle_async_request.side_effect = ( + _httpx_response(429), + _httpx_response(200), + ) + + client = api_client.BaseApiClient( + vertexai=True, + project='test_project', + location='global', + http_options=_transport_options( + async_transport=mock_transport, + ), + ) + + with _patch_auth_default(): + stream = await client.async_request_streamed( + http_method='GET', + path='path', + request_dict={}, + http_options=types.HttpOptions(retry_options=_RETRY_OPTIONS), + ) + async for _ in stream: + pass + mock_transport.handle_async_request.assert_called() + assert mock_transport.handle_async_request.call_count == 2 + + asyncio.run(run()) + + +def test_async_retries_streamed_failed_request_retries_unsuccessfully(): + api_client.has_aiohttp = False + + async def run(): + mock_transport = mock.Mock(spec=httpx.AsyncBaseTransport) + mock_transport.handle_async_request.side_effect = ( + _httpx_response(429), + _httpx_response(504), + ) + + client = api_client.BaseApiClient( + vertexai=True, + project='test_project', + location='global', + http_options=_transport_options( + http_options=types.HttpOptions(retry_options=_RETRY_OPTIONS), + async_transport=mock_transport, + ), + ) + + with _patch_auth_default(): + try: + stream = await client.async_request_streamed( + http_method='GET', path='path', request_dict={} + ) + async for _ in stream: + pass + assert False, 'Expected APIError to be raised.' + except errors.APIError as e: + assert e.code == 504 + mock_transport.handle_async_request.assert_called() + + asyncio.run(run()) + + +def test_async_retries_streamed_failed_request_retries_unsuccessfully_at_request_level(): + api_client.has_aiohttp = False + + async def run(): + mock_transport = mock.Mock(spec=httpx.AsyncBaseTransport) + mock_transport.handle_async_request.side_effect = ( + _httpx_response(429), + _httpx_response(504), + ) + + client = api_client.BaseApiClient( + vertexai=True, + project='test_project', + location='global', + http_options=_transport_options( + async_transport=mock_transport, + ), + ) + + with _patch_auth_default(): + try: + stream = await client.async_request_streamed( + http_method='GET', + path='path', + request_dict={}, + http_options=types.HttpOptions(retry_options=_RETRY_OPTIONS), + ) + async for _ in stream: + pass + assert False, 'Expected APIError to be raised.' + except errors.APIError as e: + assert e.code == 504 + mock_transport.handle_async_request.assert_called() + + asyncio.run(run()) + + +# Async aiohttp Streaming + + +@requires_aiohttp +@mock.patch.object(aiohttp.ClientSession, 'request', autospec=True) +def test_aiohttp_retries_streamed_failed_request_retries_successfully( + mock_request, +): + api_client.has_aiohttp = True + + async def run(): + mock_request.side_effect = ( + _aiohttp_async_response(429), + _aiohttp_async_response(200, streamable=True), + ) + + client = api_client.BaseApiClient( + vertexai=True, + project='test_project', + location='global', + http_options=_transport_options( + http_options=types.HttpOptions(retry_options=_RETRY_OPTIONS), + ), + ) + + with _patch_auth_default(): + stream = await client.async_request_streamed( + http_method='GET', path='path', request_dict={} + ) + async for _ in stream: + pass + mock_request.assert_called() + assert mock_request.call_count == 2 + + asyncio.run(run()) + + +@requires_aiohttp +@mock.patch.object(aiohttp.ClientSession, 'request', autospec=True) +def test_aiohttp_retries_streamed_failed_request_retries_successfully_at_request_level( + mock_request, +): + api_client.has_aiohttp = True + + async def run(): + mock_request.side_effect = ( + _aiohttp_async_response(429), + _aiohttp_async_response(200, streamable=True), + ) + + client = api_client.BaseApiClient( + vertexai=True, + project='test_project', + location='global', + ) + + with _patch_auth_default(): + stream = await client.async_request_streamed( + http_method='GET', + path='path', + request_dict={}, + http_options=types.HttpOptions(retry_options=_RETRY_OPTIONS), + ) + async for _ in stream: + pass + mock_request.assert_called() + assert mock_request.call_count == 2 + + asyncio.run(run()) + + +@requires_aiohttp +@mock.patch.object(aiohttp.ClientSession, 'request', autospec=True) +def test_aiohttp_retries_streamed_failed_request_retries_unsuccessfully( + mock_request, +): + api_client.has_aiohttp = True + + async def run(): + mock_request.side_effect = ( + _aiohttp_async_response(429), + _aiohttp_async_response(504), + ) + + client = api_client.BaseApiClient( + vertexai=True, + project='test_project', + location='global', + http_options=_transport_options( + http_options=types.HttpOptions(retry_options=_RETRY_OPTIONS), + ), + ) + + with _patch_auth_default(): + try: + stream = await client.async_request_streamed( + http_method='GET', path='path', request_dict={} + ) + async for _ in stream: + pass + assert False, 'Expected APIError to be raised.' + except errors.APIError as e: + assert e.code == 504 + mock_request.assert_called() + + asyncio.run(run()) + + +@requires_aiohttp +@mock.patch.object(aiohttp.ClientSession, 'request', autospec=True) +def test_aiohttp_retries_streamed_failed_request_retries_unsuccessfully_at_request_level( + mock_request, +): + api_client.has_aiohttp = True + + async def run(): + mock_request.side_effect = ( + _aiohttp_async_response(429), + _aiohttp_async_response(504), + ) + + client = api_client.BaseApiClient( + vertexai=True, + project='test_project', + location='global', + ) + + with _patch_auth_default(): + try: + stream = await client.async_request_streamed( + http_method='GET', + path='path', + request_dict={}, + http_options={'retry_options': _RETRY_OPTIONS}, + ) + async for _ in stream: + pass + assert False, 'Expected APIError to be raised.' + except errors.APIError as e: + assert e.code == 504 + mock_request.assert_called() + + asyncio.run(run()) + + @requires_aiohttp @mock.patch.object(aiohttp.ClientSession, 'request', autospec=True) def test_aiohttp_retries_client_connector_error_retries_successfully(