Skip to content

Commit 4376077

Browse files
[PM-24482] Refresh access token preemptively and log out on 401/403 refresh errors (#2024)
1 parent 039495e commit 4376077

19 files changed

+340
-30
lines changed

BitwardenShared/Core/Platform/Services/API/AccountTokenProvider.swift

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import BitwardenKit
2+
import Foundation
13
import Networking
24

35
// MARK: - AccountTokenProvider
@@ -21,28 +23,34 @@ actor DefaultAccountTokenProvider: AccountTokenProvider {
2123
private weak var accountTokenProviderDelegate: AccountTokenProviderDelegate?
2224

2325
/// The `HTTPService` used to make the API call to refresh the access token.
24-
let httpService: HTTPService
26+
private let httpService: HTTPService
2527

2628
/// The task associated with refreshing the token, if one is in progress.
2729
private(set) var refreshTask: Task<String, Error>?
2830

31+
/// The service used to get the present time.
32+
private let timeProvider: TimeProvider
33+
2934
/// The `TokenService` used to get the current tokens from.
30-
let tokenService: TokenService
35+
private let tokenService: TokenService
3136

3237
// MARK: Initialization
3338

3439
/// Initialize an `AccountTokenProvider`.
3540
///
3641
/// - Parameters:
3742
/// - httpService: The service used to make the API call to refresh the access token.
43+
/// - timeProvider: The service used to get the present time.
3844
/// - tokenService: The service used to get the current tokens from.
3945
///
4046
init(
4147
httpService: HTTPService,
48+
timeProvider: TimeProvider = CurrentTime(),
4249
tokenService: TokenService,
4350
) {
44-
self.tokenService = tokenService
4551
self.httpService = httpService
52+
self.timeProvider = timeProvider
53+
self.tokenService = tokenService
4654
}
4755

4856
// MARK: Methods
@@ -54,15 +62,19 @@ actor DefaultAccountTokenProvider: AccountTokenProvider {
5462
return try await refreshTask.value
5563
}
5664

57-
return try await tokenService.getAccessToken()
65+
let token = try await tokenService.getAccessToken()
66+
if await shouldRefresh(accessToken: token) {
67+
return try await refreshToken()
68+
} else {
69+
return token
70+
}
5871
}
5972

60-
func refreshToken() async throws {
73+
func refreshToken() async throws -> String {
6174
if let refreshTask {
6275
// If there's a refresh in progress, wait for it to complete rather than triggering
6376
// another refresh.
64-
_ = try await refreshTask.value
65-
return
77+
return try await refreshTask.value
6678
}
6779

6880
let refreshTask = Task {
@@ -73,9 +85,12 @@ actor DefaultAccountTokenProvider: AccountTokenProvider {
7385
let response = try await httpService.send(
7486
IdentityTokenRefreshRequest(refreshToken: refreshToken),
7587
)
88+
let expirationDate = timeProvider.presentTime.addingTimeInterval(TimeInterval(response.expiresIn))
89+
7690
try await tokenService.setTokens(
7791
accessToken: response.accessToken,
7892
refreshToken: response.refreshToken,
93+
expirationDate: expirationDate,
7994
)
8095

8196
return response.accessToken
@@ -88,17 +103,35 @@ actor DefaultAccountTokenProvider: AccountTokenProvider {
88103
}
89104
self.refreshTask = refreshTask
90105

91-
_ = try await refreshTask.value
106+
return try await refreshTask.value
92107
}
93108

94109
func setDelegate(delegate: AccountTokenProviderDelegate) async {
95110
accountTokenProviderDelegate = delegate
96111
}
112+
113+
// MARK: Private
114+
115+
/// Returns whether the access token needs to be refreshed based on the last stored access token
116+
/// expiration date. This is used to preemptively refresh the token prior to its expiration.
117+
///
118+
/// - Parameter accessToken: The access token to determine whether it needs to be refreshed.
119+
/// - Returns: Whether the access token needs to be refreshed.
120+
///
121+
private func shouldRefresh(accessToken: String) async -> Bool {
122+
guard let expirationDate = try? await tokenService.getAccessTokenExpirationDate() else {
123+
// If there's no stored expiration date, don't preemptively refresh the token.
124+
return false
125+
}
126+
127+
let refreshThreshold = timeProvider.presentTime.addingTimeInterval(Constants.tokenRefreshThreshold)
128+
return expirationDate <= refreshThreshold
129+
}
97130
}
98131

99132
/// Delegate to be used by the `AccountTokenProvider`.
100133
protocol AccountTokenProviderDelegate: AnyObject {
101-
/// Callbac to be used when an error is thrown when refreshing the access token.
134+
/// Callback to be used when an error is thrown when refreshing the access token.
102135
/// - Parameter error: `Error` thrown.
103136
func onRefreshTokenError(error: Error) async throws
104137
}

BitwardenShared/Core/Platform/Services/API/AccountTokenProviderTests.swift

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import BitwardenKitMocks
12
import Networking
23
import TestHelpers
34
import XCTest
@@ -9,18 +10,25 @@ class AccountTokenProviderTests: BitwardenTestCase {
910

1011
var client: MockHTTPClient!
1112
var subject: DefaultAccountTokenProvider!
13+
var timeProvider: MockTimeProvider!
1214
var tokenService: MockTokenService!
1315

16+
let expirationDateExpired = Date(year: 2025, month: 10, day: 1, hour: 23, minute: 59, second: 0)
17+
let expirationDateExpiringSoon = Date(year: 2025, month: 10, day: 2, hour: 0, minute: 2, second: 0)
18+
let expirationDateUnexpired = Date(year: 2025, month: 10, day: 2, hour: 0, minute: 6, second: 0)
19+
1420
// MARK: Setup & Teardown
1521

1622
override func setUp() {
1723
super.setUp()
1824

1925
client = MockHTTPClient()
26+
timeProvider = MockTimeProvider(.mockTime(Date(year: 2025, month: 10, day: 2)))
2027
tokenService = MockTokenService()
2128

2229
subject = DefaultAccountTokenProvider(
2330
httpService: HTTPService(baseURL: URL(string: "https://example.com")!, client: client),
31+
timeProvider: timeProvider,
2432
tokenService: tokenService,
2533
)
2634
}
@@ -30,13 +38,55 @@ class AccountTokenProviderTests: BitwardenTestCase {
3038

3139
client = nil
3240
subject = nil
41+
timeProvider = nil
3342
tokenService = nil
3443
}
3544

3645
// MARK: Tests
3746

38-
/// `getToken()` returns the current access token.
39-
func test_getToken() async throws {
47+
/// `getToken()` returns the current access token if fetching the expiration date returns an error.
48+
func test_getToken_tokenError() async throws {
49+
tokenService.accessToken = "ACCESS_TOKEN"
50+
tokenService.accessTokenExpirationDateResult = .failure(BitwardenTestError.example)
51+
52+
let token = try await subject.getToken()
53+
XCTAssertEqual(token, "ACCESS_TOKEN")
54+
}
55+
56+
/// `getToken()` returns a refreshed access token if the current one is expired.
57+
func test_getToken_tokenExpired() async throws {
58+
client.result = .httpSuccess(testData: .identityTokenRefresh)
59+
tokenService.accessToken = "EXPIRED"
60+
tokenService.accessTokenExpirationDateResult = .success(expirationDateExpired)
61+
62+
let token = try await subject.getToken()
63+
XCTAssertEqual(token, "ACCESS_TOKEN")
64+
}
65+
66+
/// `getToken()` returns a refreshed access token if the current one is expiring soon.
67+
func test_getToken_tokenExpiringSoon() async throws {
68+
client.result = .httpSuccess(testData: .identityTokenRefresh)
69+
tokenService.accessToken = "EXPIRING_SOON"
70+
tokenService.accessTokenExpirationDateResult = .success(expirationDateExpiringSoon)
71+
72+
let token = try await subject.getToken()
73+
XCTAssertEqual(token, "ACCESS_TOKEN")
74+
}
75+
76+
/// `getToken()` returns the current access token if it is unexpired.
77+
func test_getToken_tokenUnexpired() async throws {
78+
tokenService.accessToken = "ACCESS_TOKEN"
79+
tokenService.accessTokenExpirationDateResult = .success(expirationDateUnexpired)
80+
81+
let token = try await subject.getToken()
82+
XCTAssertEqual(token, "ACCESS_TOKEN")
83+
}
84+
85+
/// `getToken()` returns the current access token if the expiration date doesn't yet exist.
86+
func test_getToken_tokenNil() async throws {
87+
tokenService.accessToken = "ACCESS_TOKEN"
88+
tokenService.accessTokenExpirationDateResult = .success(nil)
89+
4090
let token = try await subject.getToken()
4191
XCTAssertEqual(token, "ACCESS_TOKEN")
4292
}
@@ -58,12 +108,12 @@ class AccountTokenProviderTests: BitwardenTestCase {
58108

59109
client.result = .httpSuccess(testData: .identityTokenRefresh)
60110

61-
try await subject.refreshToken()
111+
let newAccessToken = try await subject.refreshToken()
62112

63-
let newAccessToken = try await subject.getToken()
64113
XCTAssertEqual(newAccessToken, "ACCESS_TOKEN")
65114
XCTAssertEqual(tokenService.accessToken, "ACCESS_TOKEN")
66115
XCTAssertEqual(tokenService.refreshToken, "REFRESH_TOKEN")
116+
XCTAssertEqual(tokenService.expirationDate, Date(year: 2025, month: 10, day: 2, hour: 1, minute: 0, second: 0))
67117

68118
let refreshTask = await subject.refreshTask
69119
XCTAssertNil(refreshTask)
@@ -76,14 +126,15 @@ class AccountTokenProviderTests: BitwardenTestCase {
76126

77127
client.result = .httpSuccess(testData: .identityTokenRefresh)
78128

79-
async let refreshTask1: Void = subject.refreshToken()
80-
async let refreshTask2: Void = subject.refreshToken()
129+
async let refreshTask1: String = subject.refreshToken()
130+
async let refreshTask2: String = subject.refreshToken()
81131

82132
_ = try await (refreshTask1, refreshTask2)
83133

84134
XCTAssertEqual(client.requests.count, 1)
85135
XCTAssertEqual(tokenService.accessToken, "ACCESS_TOKEN")
86136
XCTAssertEqual(tokenService.refreshToken, "REFRESH_TOKEN")
137+
XCTAssertEqual(tokenService.expirationDate, Date(year: 2025, month: 10, day: 2, hour: 1, minute: 0, second: 0))
87138

88139
let refreshTask = await subject.refreshTask
89140
XCTAssertNil(refreshTask)
@@ -101,7 +152,7 @@ class AccountTokenProviderTests: BitwardenTestCase {
101152
client.result = .failure(BitwardenTestError.example)
102153

103154
await assertAsyncThrows(error: BitwardenTestError.example) {
104-
try await subject.refreshToken()
155+
_ = try await subject.refreshToken()
105156
}
106157
XCTAssertTrue(delegate.onRefreshTokenErrorCalled)
107158
}
@@ -115,7 +166,7 @@ class AccountTokenProviderTests: BitwardenTestCase {
115166
client.result = .failure(BitwardenTestError.example)
116167

117168
await assertAsyncThrows(error: BitwardenTestError.example) {
118-
try await subject.refreshToken()
169+
_ = try await subject.refreshToken()
119170
}
120171
}
121172
}

BitwardenShared/Core/Platform/Services/API/RefreshableAPIService.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,6 @@ protocol RefreshableAPIService { // sourcery: AutoMockable
99

1010
extension APIService: RefreshableAPIService {
1111
func refreshAccessToken() async throws {
12-
try await accountTokenProvider.refreshToken()
12+
_ = try await accountTokenProvider.refreshToken()
1313
}
1414
}

BitwardenShared/Core/Platform/Services/API/TestHelpers/MockAccountTokenProvider.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@ class MockAccountTokenProvider: AccountTokenProvider {
99
var delegate: AccountTokenProviderDelegate?
1010
var getTokenResult: Result<String, Error> = .success("ACCESS_TOKEN")
1111
var refreshTokenCalled = false
12-
var refreshTokenResult: Result<Void, Error> = .success(())
12+
var refreshTokenResult: Result<String, Error> = .success("ACCESS_TOKEN")
1313

1414
func getToken() async throws -> String {
1515
try getTokenResult.get()
1616
}
1717

18-
func refreshToken() async throws {
18+
func refreshToken() async throws -> String {
1919
refreshTokenCalled = true
20-
try refreshTokenResult.get()
20+
return try refreshTokenResult.get()
2121
}
2222

2323
func setDelegate(delegate: AccountTokenProviderDelegate) async {

BitwardenShared/Core/Platform/Services/StateService.swift

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,13 @@ protocol StateService: AnyObject {
4444
///
4545
func doesActiveAccountHavePremium() async -> Bool
4646

47+
/// Gets the access token's expiration date for an account.
48+
///
49+
/// - Parameter userId: The user ID associated with the access token expiration date.
50+
/// - Returns: The user's access token expiration date.
51+
///
52+
func getAccessTokenExpirationDate(userId: String) async -> Date?
53+
4754
/// Gets the account for an id.
4855
///
4956
/// - Parameter userId: The id for an account. If nil, the active account will be returned.
@@ -429,6 +436,14 @@ protocol StateService: AnyObject {
429436
///
430437
func pinUnlockRequiresPasswordAfterRestart() async throws -> Bool
431438

439+
/// Sets the access token's expiration date for an account.
440+
///
441+
/// - Parameters:
442+
/// - expirationDate: The user's access token expiration date.
443+
/// - userId: The user ID associated with the access token expiration date.
444+
///
445+
func setAccessTokenExpirationDate(_ expirationDate: Date?, userId: String) async
446+
432447
/// Sets the account encryption keys for an account.
433448
///
434449
/// - Parameters:
@@ -855,6 +870,14 @@ extension StateService {
855870
await setPendingAppIntentActions(actions: actions)
856871
}
857872

873+
/// Gets the access token's expiration date for the active account.
874+
///
875+
/// - Returns: The user's access token expiration date.
876+
///
877+
func getAccessTokenExpirationDate() async throws -> Date? {
878+
try await getAccessTokenExpirationDate(userId: getActiveAccountId())
879+
}
880+
858881
/// Gets the account encryptions keys for the active account.
859882
///
860883
/// - Returns: The account encryption keys.
@@ -1143,6 +1166,14 @@ extension StateService {
11431166
try await pinProtectedUserKeyEnvelope(userId: nil)
11441167
}
11451168

1169+
/// Sets the access token's expiration date for the active account.
1170+
///
1171+
/// - Parameter expirationDate: The user's access token expiration date.
1172+
///
1173+
func setAccessTokenExpirationDate(_ expirationDate: Date?) async throws {
1174+
try await setAccessTokenExpirationDate(expirationDate, userId: getActiveAccountId())
1175+
}
1176+
11461177
/// Sets the account encryption keys for the active account.
11471178
///
11481179
/// - Parameter encryptionKeys: The account encryption keys.
@@ -1542,6 +1573,10 @@ actor DefaultStateService: StateService, ConfigStateService { // swiftlint:disab
15421573
}
15431574
}
15441575

1576+
func getAccessTokenExpirationDate(userId: String) -> Date? {
1577+
appSettingsStore.accessTokenExpirationDate(userId: userId)
1578+
}
1579+
15451580
func getAccount(userId: String?) throws -> Account {
15461581
guard let accounts = appSettingsStore.state?.accounts else {
15471582
throw StateServiceError.noAccounts
@@ -1844,6 +1879,7 @@ actor DefaultStateService: StateService, ConfigStateService { // swiftlint:disab
18441879
state.activeUserId = state.accounts.first?.key
18451880
}
18461881

1882+
appSettingsStore.setAccessTokenExpirationDate(nil, userId: knownUserId)
18471883
appSettingsStore.setBiometricAuthenticationEnabled(nil, for: knownUserId)
18481884
appSettingsStore.setDefaultUriMatchType(nil, userId: knownUserId)
18491885
appSettingsStore.setDisableAutoTotpCopy(nil, userId: knownUserId)
@@ -1876,6 +1912,10 @@ actor DefaultStateService: StateService, ConfigStateService { // swiftlint:disab
18761912
&& appSettingsStore.pinProtectedUserKey(userId: userId) == nil
18771913
}
18781914

1915+
func setAccessTokenExpirationDate(_ expirationDate: Date?, userId: String) async {
1916+
appSettingsStore.setAccessTokenExpirationDate(expirationDate, userId: userId)
1917+
}
1918+
18791919
func setAccountKdf(_ kdfConfig: KdfConfig, userId: String) async throws {
18801920
try updateAccountProfile(userId: userId) { profile in
18811921
profile.kdfType = kdfConfig.kdfType

0 commit comments

Comments
 (0)