diff --git a/android-test/src/androidTest/java/okhttp/android/test/AndroidNetworkPinningTest.kt b/android-test/src/androidTest/java/okhttp/android/test/AndroidNetworkPinningTest.kt new file mode 100644 index 000000000000..f3aee73068f8 --- /dev/null +++ b/android-test/src/androidTest/java/okhttp/android/test/AndroidNetworkPinningTest.kt @@ -0,0 +1,125 @@ +/* + * Copyright (C) 2025 Square, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp.android.test + +import android.content.Context +import android.net.ConnectivityManager +import android.net.Network +import android.net.NetworkRequest +import android.os.Build +import androidx.test.core.app.ApplicationProvider +import androidx.test.filters.SdkSuppress +import mockwebserver3.MockResponse +import mockwebserver3.MockWebServer +import mockwebserver3.junit5.StartStop +import okhttp3.OkHttpClient +import okhttp3.OkHttpClientTestRule +import okhttp3.Request +import okhttp3.android.AndroidNetworkPinning +import okhttp3.internal.connection.RealCall +import okhttp3.internal.platform.PlatformRegistry +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertNotEquals +import org.junit.jupiter.api.Assumptions.assumeTrue +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Tag +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.RegisterExtension + +@Tag("Slow") +@SdkSuppress(minSdkVersion = Build.VERSION_CODES.Q) +class AndroidNetworkPinningTest { + @Suppress("RedundantVisibilityModifier") + @JvmField + @RegisterExtension + public val clientTestRule = OkHttpClientTestRule() + + val applicationContext = ApplicationProvider.getApplicationContext() + val connectivityManager = applicationContext.getSystemService(ConnectivityManager::class.java) + + val pinning = AndroidNetworkPinning() + + private var client: OkHttpClient = + clientTestRule + .newClientBuilder() + .addCallDecorator(pinning) + .addCallDecorator { + it.proceed( + it.request + .newBuilder() + .header("second-decorator", "true") + .build(), + ) + }.addInterceptor { + val call = (it.call() as RealCall) + val dns = call.client.dns + it + .proceed(it.request()) + .newBuilder() + .header("used-dns", dns.javaClass.simpleName) + .build() + }.build() + + @StartStop + private val server = MockWebServer() + + @BeforeEach + fun setup() { + // Needed because of Platform.resetForTests + PlatformRegistry.applicationContext = applicationContext + + connectivityManager.registerNetworkCallback(NetworkRequest.Builder().build(), pinning.networkCallback) + } + + @Test + fun testDefaultRequest() { + server.enqueue(MockResponse(200, body = "Hello")) + + val request = Request.Builder().url(server.url("/")).build() + + val response = client.newCall(request).execute() + + response.use { + assertEquals(200, response.code) + assertNotEquals("AndroidDns", response.header("used-dns")) + assertEquals("true", response.request.header("second-decorator")) + } + } + + @Test + fun testPinnedRequest() { + server.enqueue(MockResponse(200, body = "Hello")) + + val network = connectivityManager.activeNetwork + + assumeTrue(network != null) + + val request = + Request + .Builder() + .url(server.url("/")) + .tag(network) + .build() + + val response = client.newCall(request).execute() + + response.use { + assertEquals(200, response.code) + assertEquals("AndroidDns", response.header("used-dns")) + assertEquals("true", response.request.header("second-decorator")) + } + } +} diff --git a/okhttp/api/android/okhttp.api b/okhttp/api/android/okhttp.api index 1f0b9839ac73..5b9f15291bb1 100644 --- a/okhttp/api/android/okhttp.api +++ b/okhttp/api/android/okhttp.api @@ -129,6 +129,16 @@ public abstract interface class okhttp3/Call : java/lang/Cloneable { public abstract fun timeout ()Lokio/Timeout; } +public abstract interface class okhttp3/Call$Chain { + public abstract fun getClient ()Lokhttp3/OkHttpClient; + public abstract fun getRequest ()Lokhttp3/Request; + public abstract fun proceed (Lokhttp3/Request;)Lokhttp3/Call; +} + +public abstract interface class okhttp3/Call$Decorator { + public abstract fun newCall (Lokhttp3/Call$Chain;)Lokhttp3/Call; +} + public abstract interface class okhttp3/Call$Factory { public abstract fun newCall (Lokhttp3/Request;)Lokhttp3/Call; } @@ -902,6 +912,7 @@ public class okhttp3/OkHttpClient : okhttp3/Call$Factory, okhttp3/WebSocket$Fact public final fun fastFallback ()Z public final fun followRedirects ()Z public final fun followSslRedirects ()Z + public final fun getCallDecorators ()Ljava/util/List; public final fun hostnameVerifier ()Ljavax/net/ssl/HostnameVerifier; public final fun interceptors ()Ljava/util/List; public final fun minWebSocketMessageToCompress ()J @@ -927,6 +938,7 @@ public final class okhttp3/OkHttpClient$Builder { public final fun -addInterceptor (Lkotlin/jvm/functions/Function1;)Lokhttp3/OkHttpClient$Builder; public final fun -addNetworkInterceptor (Lkotlin/jvm/functions/Function1;)Lokhttp3/OkHttpClient$Builder; public fun ()V + public final fun addCallDecorator (Lokhttp3/Call$Decorator;)Lokhttp3/OkHttpClient$Builder; public final fun addInterceptor (Lokhttp3/Interceptor;)Lokhttp3/OkHttpClient$Builder; public final fun addNetworkInterceptor (Lokhttp3/Interceptor;)Lokhttp3/OkHttpClient$Builder; public final fun authenticator (Lokhttp3/Authenticator;)Lokhttp3/OkHttpClient$Builder; @@ -1274,3 +1286,9 @@ public abstract class okhttp3/WebSocketListener { public fun onOpen (Lokhttp3/WebSocket;Lokhttp3/Response;)V } +public final class okhttp3/android/AndroidNetworkPinning : okhttp3/Call$Decorator { + public fun ()V + public final fun getNetworkCallback ()Landroid/net/ConnectivityManager$NetworkCallback; + public fun newCall (Lokhttp3/Call$Chain;)Lokhttp3/Call; +} + diff --git a/okhttp/api/jvm/okhttp.api b/okhttp/api/jvm/okhttp.api index ca4df1afdcfa..d4a6d7b9d50d 100644 --- a/okhttp/api/jvm/okhttp.api +++ b/okhttp/api/jvm/okhttp.api @@ -129,6 +129,16 @@ public abstract interface class okhttp3/Call : java/lang/Cloneable { public abstract fun timeout ()Lokio/Timeout; } +public abstract interface class okhttp3/Call$Chain { + public abstract fun getClient ()Lokhttp3/OkHttpClient; + public abstract fun getRequest ()Lokhttp3/Request; + public abstract fun proceed (Lokhttp3/Request;)Lokhttp3/Call; +} + +public abstract interface class okhttp3/Call$Decorator { + public abstract fun newCall (Lokhttp3/Call$Chain;)Lokhttp3/Call; +} + public abstract interface class okhttp3/Call$Factory { public abstract fun newCall (Lokhttp3/Request;)Lokhttp3/Call; } @@ -901,6 +911,7 @@ public class okhttp3/OkHttpClient : okhttp3/Call$Factory, okhttp3/WebSocket$Fact public final fun fastFallback ()Z public final fun followRedirects ()Z public final fun followSslRedirects ()Z + public final fun getCallDecorators ()Ljava/util/List; public final fun hostnameVerifier ()Ljavax/net/ssl/HostnameVerifier; public final fun interceptors ()Ljava/util/List; public final fun minWebSocketMessageToCompress ()J @@ -926,6 +937,7 @@ public final class okhttp3/OkHttpClient$Builder { public final fun -addInterceptor (Lkotlin/jvm/functions/Function1;)Lokhttp3/OkHttpClient$Builder; public final fun -addNetworkInterceptor (Lkotlin/jvm/functions/Function1;)Lokhttp3/OkHttpClient$Builder; public fun ()V + public final fun addCallDecorator (Lokhttp3/Call$Decorator;)Lokhttp3/OkHttpClient$Builder; public final fun addInterceptor (Lokhttp3/Interceptor;)Lokhttp3/OkHttpClient$Builder; public final fun addNetworkInterceptor (Lokhttp3/Interceptor;)Lokhttp3/OkHttpClient$Builder; public final fun authenticator (Lokhttp3/Authenticator;)Lokhttp3/OkHttpClient$Builder; diff --git a/okhttp/src/androidMain/kotlin/okhttp3/android/AndroidNetworkPinning.kt b/okhttp/src/androidMain/kotlin/okhttp3/android/AndroidNetworkPinning.kt new file mode 100644 index 000000000000..06579d86312c --- /dev/null +++ b/okhttp/src/androidMain/kotlin/okhttp3/android/AndroidNetworkPinning.kt @@ -0,0 +1,66 @@ +/* + * Copyright (C) 2024 Block, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3.android + +import android.net.ConnectivityManager +import android.net.Network +import android.os.Build +import androidx.annotation.RequiresApi +import java.util.Collections +import okhttp3.Call +import okhttp3.OkHttpClient +import okhttp3.android.internal.AndroidDns +import okhttp3.internal.SuppressSignatureCheck + +/** + * Decorator that supports Network Pinning on Android via Request tags. + */ +@RequiresApi(Build.VERSION_CODES.Q) +@SuppressSignatureCheck +class AndroidNetworkPinning : Call.Decorator { + private val pinnedClients = Collections.synchronizedMap(mutableMapOf()) + + /** ConnectivityManager.NetworkCallback that will clean up after networks are lost. */ + val networkCallback = + object : ConnectivityManager.NetworkCallback() { + override fun onLost(network: Network) { + pinnedClients.remove(network.toString()) + } + } + + override fun newCall(chain: Call.Chain): Call { + val request = chain.request + + val pinnedNetwork = request.tag() ?: return chain.proceed(request) + + val pinnedClient = + // API 24+ + pinnedClients.computeIfAbsent(pinnedNetwork.toString()) { + chain.client.withNetwork(network = pinnedNetwork) + } + + return pinnedClient.newCall(request) + } + + private fun OkHttpClient.withNetwork(network: Network): OkHttpClient = + newBuilder() + .dns(AndroidDns(network)) + .socketFactory(network.socketFactory) + .apply { + // Keep decorators after this one in the new client + callDecorators.subList(0, callDecorators.indexOf(this@AndroidNetworkPinning) + 1).clear() + }.build() +} diff --git a/okhttp/src/androidMain/kotlin/okhttp3/android/internal/AndroidDns.kt b/okhttp/src/androidMain/kotlin/okhttp3/android/internal/AndroidDns.kt new file mode 100644 index 000000000000..70fc149b0781 --- /dev/null +++ b/okhttp/src/androidMain/kotlin/okhttp3/android/internal/AndroidDns.kt @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2025 Block, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package okhttp3.android.internal + +import android.net.DnsResolver +import android.net.Network +import android.os.Build +import androidx.annotation.RequiresApi +import java.net.InetAddress +import java.net.UnknownHostException +import java.util.concurrent.CompletableFuture +import okhttp3.Dns +import okhttp3.internal.SuppressSignatureCheck + +@RequiresApi(Build.VERSION_CODES.Q) +@SuppressSignatureCheck +internal class AndroidDns( + val network: Network, +) : Dns { + // API 29+ + private val dnsResolver = DnsResolver.getInstance() + + override fun lookup(hostname: String): List { + // API 24+ + val result = CompletableFuture>() + + dnsResolver.query( + network, + hostname, + DnsResolver.FLAG_EMPTY, + { it.run() }, + null, + object : DnsResolver.Callback> { + override fun onAnswer( + answer: List, + rcode: Int, + ) { + result.complete(answer) + } + + override fun onError(error: DnsResolver.DnsException) { + result.completeExceptionally( + UnknownHostException(error.message).apply { + initCause(error) + }, + ) + } + }, + ) + + return result.get() + } +} diff --git a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/Call.kt b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/Call.kt index fdd3d3da294e..fc7f87876769 100644 --- a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/Call.kt +++ b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/Call.kt @@ -96,4 +96,34 @@ interface Call : Cloneable { fun interface Factory { fun newCall(request: Request): Call } + + /** + * The equivalent of an Interceptor for [Call.Factory], but critically supported within an [OkHttpClient]. + * + * While an [Interceptor] forms a chain as part of execution of a Call. Call.Decorator intercepts + * [Call.Factory.newCall] with similar flexibility to Application [OkHttpClient.interceptors]. + * + * That is, it may do any of + * - Modify the request such as adding Tracing Context + * - Wrap the [Call] returned + * - Return some [Call] implementation that will immediately fail avoiding network calls based on network or + * authentication state. + * - Redirect the [Call], such as using an alternative [Call.Factory]. + * - Defer execution, something not safe in an Interceptor. + * + * It should not throw an exception and instead return a Call that will fail on [Call.execute]. + * + * This flexibility means that the app developer configuring the decorators on [OkHttpClient] must be responsible + * for how these are composed in a chain. + */ + fun interface Decorator { + fun newCall(chain: Chain): Call + } + + interface Chain { + val client: OkHttpClient + val request: Request + + fun proceed(request: Request): Call + } } diff --git a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/OkHttpClient.kt b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/OkHttpClient.kt index de3e75d5e701..170a4df81431 100644 --- a/okhttp/src/commonJvmAndroid/kotlin/okhttp3/OkHttpClient.kt +++ b/okhttp/src/commonJvmAndroid/kotlin/okhttp3/OkHttpClient.kt @@ -145,6 +145,14 @@ open class OkHttpClient internal constructor( val interceptors: List = builder.interceptors.toImmutableList() + /** + * Returns an immutable list of Call decorators that have a chance to return a different, likely + * decorating, implementation of Call. This allows functionality such as fail fast without normal Call + * execution based on network conditions, or setting Tracing context on the calling thread. + */ + val callDecorators: List = + builder.callDecorators.toImmutableList() + /** * Returns an immutable list of interceptors that observe a single network request and response. * These interceptors must call [Interceptor.Chain.proceed] exactly once: it is an error for @@ -265,6 +273,26 @@ open class OkHttpClient internal constructor( internal val routeDatabase: RouteDatabase = builder.routeDatabase ?: RouteDatabase() internal val taskRunner: TaskRunner = builder.taskRunner ?: TaskRunner.INSTANCE + private val decoratedCallFactory = + callDecorators.foldRight( + Call.Factory { request -> + RealCall(client = this, originalRequest = request, forWebSocket = false) + }, + ) { callDecorator, next -> + Call.Factory { request -> + callDecorator.newCall( + object : Call.Chain { + override val client: OkHttpClient + get() = this@OkHttpClient + override val request: Request + get() = request + + override fun proceed(request: Request): Call = next.newCall(request) + }, + ) + } + } + @get:JvmName("connectionPool") val connectionPool: ConnectionPool = builder.connectionPool ?: ConnectionPool( @@ -359,7 +387,7 @@ open class OkHttpClient internal constructor( } /** Prepares the [request] to be executed at some point in the future. */ - override fun newCall(request: Request): Call = RealCall(this, request, forWebSocket = false) + override fun newCall(request: Request): Call = decoratedCallFactory.newCall(request) /** Uses [request] to connect a new web socket. */ override fun newWebSocket( @@ -596,6 +624,7 @@ open class OkHttpClient internal constructor( internal var dispatcher: Dispatcher = Dispatcher() internal var connectionPool: ConnectionPool? = null internal val interceptors: MutableList = mutableListOf() + internal val callDecorators: MutableList = mutableListOf() internal val networkInterceptors: MutableList = mutableListOf() internal var eventListenerFactory: EventListener.Factory = EventListener.NONE.asFactory() internal var retryOnConnectionFailure = true @@ -631,6 +660,7 @@ open class OkHttpClient internal constructor( this.dispatcher = okHttpClient.dispatcher this.connectionPool = okHttpClient.connectionPool this.interceptors += okHttpClient.interceptors + this.callDecorators += okHttpClient.callDecorators this.networkInterceptors += okHttpClient.networkInterceptors this.eventListenerFactory = okHttpClient.eventListenerFactory this.retryOnConnectionFailure = okHttpClient.retryOnConnectionFailure @@ -735,6 +765,11 @@ open class OkHttpClient internal constructor( this.eventListenerFactory = eventListenerFactory } + fun addCallDecorator(decorator: Call.Decorator) = + apply { + callDecorators += decorator + } + /** * Configure this client to retry or not when a connectivity problem is encountered. By default, * this client silently recovers from the following problems: