From a7ee65b0016549873bb9acce0e9040a1ae429dff Mon Sep 17 00:00:00 2001 From: Amy Wu Date: Wed, 11 Feb 2026 16:48:15 -0800 Subject: [PATCH] No public description PiperOrigin-RevId: 868902411 --- google/genai/_api_client.py | 316 +++++++++++------- google/genai/_common.py | 3 + google/genai/errors.py | 62 +++- .../genai/tests/client/test_client_close.py | 8 + .../client/test_client_initialization.py | 16 +- google/genai/tests/client/test_retries.py | 1 + 6 files changed, 260 insertions(+), 146 deletions(-) diff --git a/google/genai/_api_client.py b/google/genai/_api_client.py index 2dfc0366b..f9ac074c5 100644 --- a/google/genai/_api_client.py +++ b/google/genai/_api_client.py @@ -1,4 +1,4 @@ -# Copyright 2025 Google LLC +# Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -34,7 +34,7 @@ import sys import threading import time -from typing import Any, AsyncIterator, Iterator, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, AsyncIterator, Iterator, Optional, TYPE_CHECKING, Tuple, Union from urllib.parse import urlparse from urllib.parse import urlunparse import warnings @@ -44,9 +44,12 @@ import google.auth import google.auth.credentials from google.auth.credentials import Credentials +from google.auth.transport import mtls import httpx from pydantic import BaseModel from pydantic import ValidationError +import requests +from requests.structures import CaseInsensitiveDict import tenacity from . import _common @@ -59,6 +62,16 @@ from .types import ResourceScope +try: + from google.auth.transport.requests import AuthorizedSession + from google.auth.aio.credentials import StaticCredentials + from google.auth.aio.transport.sessions import AsyncAuthorizedSession +except ImportError: + # This try/except is for TAP + StaticCredentials = None + AsyncAuthorizedSession = None + mtls = None + try: from websockets.asyncio.client import connect as ws_connect except ModuleNotFoundError: @@ -182,12 +195,6 @@ def join_url_path(base_url: str, path: str) -> str: def load_auth(*, project: Union[str, None]) -> Tuple[Credentials, str]: """Loads google auth credentials and project id.""" - ## Set GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES to false - ## to disable bound token sharing. Tracking on - ## https://github.com/googleapis/python-genai/issues/1956 - os.environ['GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES'] = ( - 'false' - ) credentials, loaded_project_id = google.auth.default( # type: ignore[no-untyped-call] scopes=['https://www.googleapis.com/auth/cloud-platform'], ) @@ -235,7 +242,12 @@ class HttpResponse: def __init__( self, - headers: Union[dict[str, str], httpx.Headers, 'CIMultiDictProxy[str]'], + headers: Union[ + dict[str, str], + httpx.Headers, + 'CIMultiDictProxy[str]', + CaseInsensitiveDict, + ], response_stream: Union[Any, str] = None, byte_stream: Union[Any, bytes] = None, ): @@ -245,6 +257,10 @@ def __init__( self.headers = { key: ', '.join(headers.get_list(key)) for key in headers.keys() } + elif isinstance(headers, CaseInsensitiveDict): + self.headers = { + key: value for key,value in headers.items() + } elif type(headers).__name__ == 'CIMultiDictProxy': self.headers = { key: ', '.join(headers.getall(key)) for key in headers.keys() @@ -321,7 +337,10 @@ def _copy_to_dict(self, response_payload: dict[str, object]) -> None: def _iter_response_stream(self) -> Iterator[str]: """Iterates over chunks retrieved from the API.""" - if not isinstance(self.response_stream, httpx.Response): + if not ( + isinstance(self.response_stream, httpx.Response) + or isinstance(self.response_stream, requests.Response) + ): raise TypeError( 'Expected self.response_stream to be an httpx.Response object, ' f'but got {type(self.response_stream).__name__}.' @@ -329,7 +348,11 @@ def _iter_response_stream(self) -> Iterator[str]: chunk = '' balance = 0 - for line in self.response_stream.iter_lines(): + if isinstance(self.response_stream, httpx.Response): + response_stream = self.response_stream.iter_lines() + else: + response_stream = self.response_stream.iter_lines(decode_unicode=True) + for line in response_stream: if not line: continue @@ -729,8 +752,11 @@ def __init__( self._http_options ) self._async_httpx_client_args = async_client_args + self.authorized_session: Optional[AuthorizedSession] = None - if self._http_options.httpx_client: + if self._use_google_auth_sync(): + self._httpx_client = None + elif self._http_options.httpx_client: self._httpx_client = self._http_options.httpx_client else: self._httpx_client = SyncHttpxClient(**client_args) @@ -759,13 +785,36 @@ def __init__( self._retry = tenacity.Retrying(**retry_kwargs) self._async_retry = tenacity.AsyncRetrying(**retry_kwargs) - async def _get_aiohttp_session(self) -> 'aiohttp.ClientSession': + def _use_google_auth_sync(self) -> bool: + return self.vertexai and not ( + self._http_options.httpx_client or self._http_options.client_args + ) + + def _use_google_auth_async(self) -> bool: + return ( + StaticCredentials + and AsyncAuthorizedSession + and self.vertexai + and not ( + self._http_options.aiohttp_client + or self._http_options.async_client_args + ) + ) + + async def _get_aiohttp_session( + self, + ) -> Union['aiohttp.ClientSession', 'AsyncAuthorizedSession']: """Returns the aiohttp client session.""" - if ( - self._aiohttp_session is None - or self._aiohttp_session.closed - or self._aiohttp_session._loop.is_closed() # pylint: disable=protected-access - ): + + # Use aiohttp directly + if self._aiohttp_session is None or ( + isinstance(self._aiohttp_session, aiohttp.ClientSession) + and ( + self._aiohttp_session.closed + or self._aiohttp_session._loop.is_closed() + ) + ): # pylint: disable=protected-access + # Initialize the aiohttp client session if it's not set up or closed. class AiohttpClientSession(aiohttp.ClientSession): # type: ignore[misc] @@ -802,6 +851,17 @@ def __del__(self, _warnings: Any = warnings) -> None: trust_env=True, read_bufsize=READ_BUFFER_SIZE, ) + # Use google.auth if available. + if self._use_google_auth_async(): + token = await self._async_access_token() + async_creds = StaticCredentials(token=token) + auth_request = google.auth.aio.transport.aiohttp.Request( + session=self._aiohttp_session, + ) + self._aiohttp_session = AsyncAuthorizedSession( + async_creds, auth_request + ) + return self._aiohttp_session return self._aiohttp_session @staticmethod @@ -1190,31 +1250,50 @@ def _request_once( else: data = http_request.data - if stream: - httpx_request = self._httpx_client.build_request( - method=http_request.method, - url=http_request.url, - content=data, + if self._use_google_auth_sync(): + url = http_request.url + if self.authorized_session is None: + self.authorized_session = AuthorizedSession( + self._credentials, + max_refresh_attempts=1, + ) + # Application default SSL credentials will be used to configure mtls + # channel. + self.authorized_session.configure_mtls_channel() + if ( + self.authorized_session._is_mtls + and 'googleapis.com' in http_request.url + ): + if 'sandbox' in http_request.url: + url = http_request.url.replace( + 'sandbox.googleapis.com', 'mtls.sandbox.googleapis.com' + ) + else: + url = http_request.url.replace( + 'googleapis.com', 'mtls.googleapis.com' + ) + print('request.url: %s' % url) + response = self.authorized_session.request( + method=http_request.method.upper(), + url=url, + data=data, headers=http_request.headers, timeout=http_request.timeout, - ) - response = self._httpx_client.send(httpx_request, stream=stream) - errors.APIError.raise_for_response(response) - return HttpResponse( - response.headers, response if stream else [response.text] + stream=stream, ) else: - response = self._httpx_client.request( + httpx_request = self._httpx_client.build_request( method=http_request.method, url=http_request.url, - headers=http_request.headers, content=data, + headers=http_request.headers, timeout=http_request.timeout, ) - errors.APIError.raise_for_response(response) - return HttpResponse( - response.headers, response if stream else [response.text] - ) + response = self._httpx_client.send(httpx_request, stream=stream) + errors.APIError.raise_for_response(response) + return HttpResponse( + response.headers, response if stream else [response.text] + ) def _request( self, @@ -1258,107 +1337,87 @@ async def _async_request_once( else: data = http_request.data - if stream: - if self._use_aiohttp(): - self._aiohttp_session = await self._get_aiohttp_session() - try: - response = await self._aiohttp_session.request( - method=http_request.method, - url=http_request.url, - headers=http_request.headers, - data=data, - timeout=aiohttp.ClientTimeout(total=http_request.timeout), - **self._async_client_session_request_args, - ) - except ( - aiohttp.ClientConnectorError, - aiohttp.ClientConnectorDNSError, - aiohttp.ClientOSError, - aiohttp.ServerDisconnectedError, - ) as e: - await asyncio.sleep(1 + random.randint(0, 9)) - logger.info('Retrying due to aiohttp error: %s' % e) - # Retrieve the SSL context from the session. - self._async_client_session_request_args = ( - self._ensure_aiohttp_ssl_ctx(self._http_options) - ) - # Instantiate a new session with the updated SSL context. - self._aiohttp_session = await self._get_aiohttp_session() - response = await self._aiohttp_session.request( - method=http_request.method, - url=http_request.url, - headers=http_request.headers, - data=data, - timeout=aiohttp.ClientTimeout(total=http_request.timeout), - **self._async_client_session_request_args, - ) - - await errors.APIError.raise_for_async_response(response) - return HttpResponse(response.headers, response) - else: - # aiohttp is not available. Fall back to httpx. - httpx_request = self._async_httpx_client.build_request( + if self._use_aiohttp(): + self._aiohttp_session = await self._get_aiohttp_session() + url = http_request.url + if self._use_google_auth_async(): + self._async_client_session_request_args['max_allowed_time'] = float( + 'inf' + ) + self._async_client_session_request_args['total_attempts'] = 1 + # Application default SSL credentials will be used to configure mtls + # channel. + await self._aiohttp_session.configure_mtls_channel() + if ( + self._aiohttp_session._is_mtls + and 'googleapis.com' in http_request.url + ): + if 'sandbox' in http_request.url: + url = http_request.url.replace( + 'sandbox.googleapis.com', 'mtls.sandbox.googleapis.com' + ) + else: + url = http_request.url.replace( + 'googleapis.com', 'mtls.googleapis.com' + ) + try: + print('request.url: %s' % url) + response = await self._aiohttp_session.request( method=http_request.method, - url=http_request.url, - content=data, + url=url, headers=http_request.headers, - timeout=http_request.timeout, + data=data, + timeout=aiohttp.ClientTimeout(total=http_request.timeout), + **self._async_client_session_request_args, ) - client_response = await self._async_httpx_client.send( - httpx_request, - stream=stream, + except ( + aiohttp.ClientConnectorError, + aiohttp.ClientConnectorDNSError, + aiohttp.ClientOSError, + aiohttp.ServerDisconnectedError, + ) as e: + await asyncio.sleep(1 + random.randint(0, 9)) + logger.info('Retrying due to aiohttp error: %s' % e) + # Retrieve the SSL context from the session. + self._async_client_session_request_args = self._ensure_aiohttp_ssl_ctx( + self._http_options ) - await errors.APIError.raise_for_async_response(client_response) - return HttpResponse(client_response.headers, client_response) - else: - if self._use_aiohttp(): + # Instantiate a new session with the updated SSL context. self._aiohttp_session = await self._get_aiohttp_session() - try: - response = await self._aiohttp_session.request( - method=http_request.method, - url=http_request.url, - headers=http_request.headers, - data=data, - timeout=aiohttp.ClientTimeout(total=http_request.timeout), - **self._async_client_session_request_args, - ) - await errors.APIError.raise_for_async_response(response) - return HttpResponse(response.headers, [await response.text()]) - except ( - aiohttp.ClientConnectorError, - aiohttp.ClientConnectorDNSError, - aiohttp.ClientOSError, - aiohttp.ServerDisconnectedError, - ) as e: - await asyncio.sleep(1 + random.randint(0, 9)) - logger.info('Retrying due to aiohttp error: %s' % e) - # Retrieve the SSL context from the session. - self._async_client_session_request_args = ( - self._ensure_aiohttp_ssl_ctx(self._http_options) - ) - # Instantiate a new session with the updated SSL context. - self._aiohttp_session = await self._get_aiohttp_session() - response = await self._aiohttp_session.request( - method=http_request.method, - url=http_request.url, - headers=http_request.headers, - data=data, - timeout=aiohttp.ClientTimeout(total=http_request.timeout), - **self._async_client_session_request_args, - ) - await errors.APIError.raise_for_async_response(response) - return HttpResponse(response.headers, [await response.text()]) - else: - # aiohttp is not available. Fall back to httpx. - client_response = await self._async_httpx_client.request( + response = await self._aiohttp_session.request( method=http_request.method, - url=http_request.url, + url=url, headers=http_request.headers, - content=data, - timeout=http_request.timeout, + data=data, + timeout=aiohttp.ClientTimeout(total=http_request.timeout), + **self._async_client_session_request_args, ) - await errors.APIError.raise_for_async_response(client_response) - return HttpResponse(client_response.headers, [client_response.text]) + await errors.APIError.raise_for_async_response(response) + if self._use_google_auth_async() and response: + # Extract the underlying aiohttp.ClientResponse from the + # AsyncAuthorizedSession Response. + response = response._response + return HttpResponse( + response.headers, response if stream else [await response.text()] + ) + else: + # aiohttp is not available. Fall back to httpx. + httpx_request = self._async_httpx_client.build_request( + method=http_request.method, + url=http_request.url, + content=data, + headers=http_request.headers, + timeout=http_request.timeout, + ) + client_response = await self._async_httpx_client.send( + httpx_request, + stream=stream, + ) + await errors.APIError.raise_for_async_response(client_response) + return HttpResponse( + client_response.headers, + client_response if stream else [client_response.text], + ) async def _async_request( self, @@ -1909,7 +1968,7 @@ def close(self) -> None: """Closes the API client.""" # Let users close the custom client explicitly by themselves. Otherwise, # close the client when the object is garbage collected. - if not self._http_options.httpx_client: + if not self._http_options.httpx_client and self._httpx_client: self._httpx_client.close() async def aclose(self) -> None: @@ -1951,6 +2010,7 @@ def get_token_from_credentials( raise RuntimeError('Could not resolve API token from the environment') return credentials.token # type: ignore[no-any-return] + async def async_get_token_from_credentials( client: 'BaseApiClient', credentials: google.auth.credentials.Credentials diff --git a/google/genai/_common.py b/google/genai/_common.py index f4a6984f7..187e9968b 100644 --- a/google/genai/_common.py +++ b/google/genai/_common.py @@ -21,11 +21,14 @@ import enum import functools import logging +import os import re import typing from typing import Any, Callable, FrozenSet, Optional, Union, get_args, get_origin import uuid import warnings +from google.auth.exceptions import MutualTLSChannelError +from google.auth.transport import mtls import pydantic from pydantic import alias_generators from typing_extensions import TypeAlias diff --git a/google/genai/errors.py b/google/genai/errors.py index 63d9334b9..b2f7df478 100644 --- a/google/genai/errors.py +++ b/google/genai/errors.py @@ -18,9 +18,13 @@ from typing import Any, Callable, Optional, TYPE_CHECKING, Union import httpx import json -import websockets +import requests from . import _common +try: + from google.auth.aio.transport.aiohttp import Response as AsyncAuthorizedSessionResponse +except ImportError: + AsyncAuthorizedSessionResponse = Any if TYPE_CHECKING: from .replay_api_client import ReplayResponse @@ -30,7 +34,12 @@ class APIError(Exception): """General errors raised by the GenAI API.""" code: int - response: Union['ReplayResponse', httpx.Response, 'aiohttp.ClientResponse'] + response: Union[ + requests.Response, + 'ReplayResponse', + httpx.Response, + 'AsyncAuthorizedSessionResponse', + ] status: Optional[str] = None message: Optional[str] = None @@ -40,7 +49,12 @@ def __init__( code: int, response_json: Any, response: Optional[ - Union['ReplayResponse', httpx.Response, 'aiohttp.ClientResponse'] + Union[ + requests.Response, + 'ReplayResponse', + httpx.Response, + 'AsyncAuthorizedSessionResponse', + ] ] = None, ): if isinstance(response_json, list) and len(response_json) == 1: @@ -112,7 +126,7 @@ def _to_replay_record(self) -> _common.StringDict: @classmethod def raise_for_response( - cls, response: Union['ReplayResponse', httpx.Response] + cls, response: Union['ReplayResponse', httpx.Response, requests.Response] ) -> None: """Raises an error with detailed error message if the response has an error status.""" if response.status_code == 200: @@ -128,6 +142,16 @@ def raise_for_response( 'message': message, 'status': response.reason_phrase, } + elif isinstance(response, requests.Response): + try: + # do not do any extra muanipulation on the response. + # return the raw response json as is. + response_json = response.json() + except requests.exceptions.JSONDecodeError: + response_json = { + 'message': response.text, + 'status': response.reason, + } else: response_json = response.body_segments[0].get('error', {}) @@ -139,7 +163,12 @@ def raise_error( status_code: int, response_json: Any, response: Optional[ - Union['ReplayResponse', httpx.Response, 'aiohttp.ClientResponse'] + Union[ + 'ReplayResponse', + httpx.Response, + 'AsyncAuthorizedSessionResponse', + requests.Response, + ] ], ) -> None: """Raises an appropriate APIError subclass based on the status code. @@ -166,7 +195,10 @@ def raise_error( async def raise_for_async_response( cls, response: Union[ - 'ReplayResponse', httpx.Response, 'aiohttp.ClientResponse' + 'ReplayResponse', + httpx.Response, + 'AsyncAuthorizedSessionResponse', + 'google.auth.aio.transport.aiohttp.Response', ], ) -> None: """Raises an error with detailed error message if the response has an error status.""" @@ -196,18 +228,24 @@ async def raise_for_async_response( try: import aiohttp # pylint: disable=g-import-not-at-top - if isinstance(response, aiohttp.ClientResponse): - if response.status == 200: + if isinstance(response, aiohttp.ClientResponse) or isinstance( + response, AsyncAuthorizedSessionResponse + ): + if isinstance(response, AsyncAuthorizedSessionResponse): + aiohttp_response = response._response + else: + aiohttp_response = response + if aiohttp_response.status == 200: return try: - response_json = await response.json() + response_json = await aiohttp_response.json() except aiohttp.client_exceptions.ContentTypeError: - message = await response.text() + message = await aiohttp_response.text() response_json = { 'message': message, - 'status': response.reason, + 'status': aiohttp_response.reason, } - status_code = response.status + status_code = aiohttp_response.status else: raise ValueError(f'Unsupported response type: {type(response)}') except ImportError: diff --git a/google/genai/tests/client/test_client_close.py b/google/genai/tests/client/test_client_close.py index 2beaf7ea8..00e2bc6bc 100644 --- a/google/genai/tests/client/test_client_close.py +++ b/google/genai/tests/client/test_client_close.py @@ -43,6 +43,7 @@ def test_close_httpx_client(): vertexai=True, project='test_project', location='global', + http_options=api_client.HttpOptions(client_args={'max_redirects': 10}), ) client.close() assert client._api_client._httpx_client.is_closed @@ -55,6 +56,7 @@ def test_httpx_client_context_manager(): vertexai=True, project='test_project', location='global', + http_options=api_client.HttpOptions(client_args={'max_redirects': 10}), ) as client: pass assert not client._api_client._httpx_client.is_closed @@ -135,6 +137,9 @@ async def run(): vertexai=True, project='test_project', location='global', + http_options=api_client.HttpOptions( + async_client_args={'trust_env': False} + ), ).aio # aiohttp session is created in the first request instead of client # initialization. @@ -176,6 +181,9 @@ async def run(): vertexai=True, project='test_project', location='global', + http_options=api_client.HttpOptions( + async_client_args={'trust_env': False} + ), ).aio as async_client: # aiohttp session is created in the first request instead of client # initialization. diff --git a/google/genai/tests/client/test_client_initialization.py b/google/genai/tests/client/test_client_initialization.py index 7b0136044..1e2daea10 100644 --- a/google/genai/tests/client/test_client_initialization.py +++ b/google/genai/tests/client/test_client_initialization.py @@ -20,6 +20,7 @@ import concurrent.futures import logging import os +import requests import ssl from unittest import mock @@ -1332,13 +1333,16 @@ def refresh_side_effect(request): mock_creds.refresh = mock_refresh # Mock the actual request to avoid network calls - mock_httpx_response = httpx.Response( - status_code=200, - headers={}, - text='{"candidates": [{"content": {"parts": [{"text": "response"}]}}]}', + mock_http_response = requests.Response() + mock_http_response.status_code = 200 + mock_http_response.headers = {} + mock_http_response._content = ( + b'{"candidates": [{"content": {"parts": [{"text": "response"}]}}]}' + ) + mock_request = mock.Mock(return_value=mock_http_response) + monkeypatch.setattr( + google.auth.transport.requests.AuthorizedSession, "request", mock_request ) - mock_request = mock.Mock(return_value=mock_httpx_response) - monkeypatch.setattr(api_client.SyncHttpxClient, "request", mock_request) client = Client( vertexai=True, project="fake_project_id", location="fake-location" diff --git a/google/genai/tests/client/test_retries.py b/google/genai/tests/client/test_retries.py index da2eea752..4aa841154 100644 --- a/google/genai/tests/client/test_retries.py +++ b/google/genai/tests/client/test_retries.py @@ -834,6 +834,7 @@ async def run(): vertexai=True, project='test_project', location='global', + http_options={'aiohttp_client': aiohttp.ClientSession(trust_env=False)}, ) with _patch_auth_default():