Skip to content

Commit c2a5136

Browse files
committed
fix(cloud): use RetryUtils to retry (#204)
1 parent ed0e979 commit c2a5136

File tree

2 files changed

+28
-61
lines changed

2 files changed

+28
-61
lines changed

src/main/java/io/kestra/plugin/dbt/cloud/AbstractDbtCloud.java

Lines changed: 20 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
import io.kestra.core.http.client.configurations.HttpConfiguration;
1313
import io.kestra.core.models.property.Property;
1414
import io.kestra.core.models.tasks.Task;
15+
import io.kestra.core.models.tasks.retrys.Exponential;
1516
import io.kestra.core.runners.RunContext;
17+
import io.kestra.core.utils.RetryUtils;
1618
import io.swagger.v3.oas.annotations.media.Schema;
1719
import jakarta.validation.constraints.NotNull;
1820
import lombok.*;
@@ -21,8 +23,6 @@
2123
import java.io.IOException;
2224
import java.time.Duration;
2325

24-
import static org.awaitility.Awaitility.await;
25-
2626
@SuperBuilder
2727
@ToString
2828
@EqualsAndHashCode
@@ -63,15 +63,6 @@ public abstract class AbstractDbtCloud extends Task {
6363
@Builder.Default
6464
Property<Long> initialDelayMs = Property.ofValue(1000L);
6565

66-
/**
67-
* Perform an HTTP request using Kestra HttpClient with retry logic.
68-
*
69-
* @param runContext The Kestra execution context.
70-
* @param requestBuilder The prepared HTTP request builder.
71-
* @param responseType The expected response type.
72-
* @param <RES> The response class.
73-
* @return HttpResponse of type RES.
74-
*/
7566
protected <RES> HttpResponse<RES> request(
7667
RunContext runContext,
7768
HttpRequest.HttpRequestBuilder requestBuilder,
@@ -83,55 +74,33 @@ protected <RES> HttpResponse<RES> request(
8374
.addHeader("Content-Type", "application/json")
8475
.build();
8576

86-
int rMaxRetries = runContext.render(this.maxRetries).as(Integer.class).orElse(3);
87-
long rInitialDelay = runContext.render(this.initialDelayMs).as(Long.class).orElse(1000L);
88-
89-
int attempt = 0;
77+
var rMaxRetries = runContext.render(this.maxRetries).as(Integer.class).orElse(3);
78+
var rInitialDelay = runContext.render(this.initialDelayMs).as(Long.class).orElse(1000L);
9079

9180
try (var client = new HttpClient(runContext, options)) {
92-
while (true) {
93-
try {
94-
HttpResponse<String> response = client.request(request, String.class);
95-
96-
RES parsedResponse = MAPPER.readValue(response.getBody(), responseType);
81+
return new RetryUtils().<HttpResponse<RES>, HttpClientException>of(
82+
Exponential.builder()
83+
.delayFactor(2.0)
84+
.interval(Duration.ofMillis(rInitialDelay))
85+
.maxInterval(Duration.ofSeconds(30))
86+
.maxAttempts(rMaxRetries)
87+
.build()
88+
).run(
89+
(res, throwable) -> throwable instanceof HttpClientResponseException ex &&
90+
(ex.getResponse().getStatus().getCode() == 502 ||
91+
ex.getResponse().getStatus().getCode() == 503 ||
92+
ex.getResponse().getStatus().getCode() == 504),
93+
() -> {
94+
var response = client.request(request, String.class);
95+
var parsedResponse = MAPPER.readValue(response.getBody(), responseType);
9796
return HttpResponse.<RES>builder()
9897
.request(request)
9998
.body(parsedResponse)
10099
.headers(response.getHeaders())
101100
.status(response.getStatus())
102101
.build();
103-
104-
} catch (HttpClientException e) {
105-
int statusCode = extractStatusCode(e);
106-
107-
if ((statusCode == 502 || statusCode == 503 || statusCode == 504) && attempt < rMaxRetries) {
108-
long backoff = (long) (rInitialDelay * Math.pow(2, attempt));
109-
runContext.logger().warn(
110-
"Request failed with status {}. Retrying in {} ms (attempt {}/{})",
111-
statusCode, backoff, attempt + 1, rMaxRetries
112-
);
113-
114-
await()
115-
.pollDelay(Duration.ofMillis(backoff))
116-
.atMost(Duration.ofMillis(backoff + 50))
117-
.until(() -> true);
118-
119-
attempt++;
120-
continue;
121-
}
122-
123-
throw e;
124-
} catch (IOException e) {
125-
throw new RuntimeException("Error executing HTTP request", e);
126102
}
127-
}
128-
}
129-
}
130-
131-
private int extractStatusCode(HttpClientException e) {
132-
if (e instanceof HttpClientResponseException ex) {
133-
return ex.getResponse().getStatus().getCode();
103+
);
134104
}
135-
return -1;
136105
}
137106
}

src/test/java/io/kestra/plugin/dbt/cloud/CheckStatusRetryTest.java

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@
88
import io.kestra.core.models.property.Property;
99
import io.kestra.core.runners.RunContextFactory;
1010
import io.kestra.core.utils.IdUtils;
11+
import io.kestra.core.utils.RetryUtils;
1112
import jakarta.inject.Inject;
1213
import org.junit.jupiter.api.Test;
1314
import org.mockito.Mockito;
1415

1516
import java.net.URI;
1617
import java.util.Map;
1718

18-
import static org.junit.jupiter.api.Assertions.assertEquals;
19-
import static org.junit.jupiter.api.Assertions.assertThrows;
19+
import static org.junit.jupiter.api.Assertions.*;
2020
import static org.mockito.ArgumentMatchers.any;
2121
import static org.mockito.ArgumentMatchers.eq;
2222
import static org.mockito.Mockito.*;
@@ -86,12 +86,6 @@ void shouldFailAfterMaxRetries() throws Exception {
8686
.status(HttpResponse.Status.builder().code(502).build())
8787
.build()
8888
))
89-
.thenThrow(new HttpClientResponseException(
90-
"Bad Gateway",
91-
HttpResponse.<String>builder()
92-
.status(HttpResponse.Status.builder().code(502).build())
93-
.build()
94-
))
9589
.thenThrow(new HttpClientResponseException(
9690
"Bad Gateway",
9791
HttpResponse.<String>builder()
@@ -109,12 +103,16 @@ void shouldFailAfterMaxRetries() throws Exception {
109103
.initialDelayMs(Property.ofValue(100L))
110104
.build();
111105

112-
assertThrows(HttpClientResponseException.class,
106+
var ex = assertThrows(RetryUtils.RetryFailed.class,
113107
() -> task.request(runContext, requestBuilder, Map.class)
114108
);
115109

110+
assertInstanceOf(HttpClientResponseException.class, ex.getCause());
111+
var cause = (HttpClientResponseException) ex.getCause();
112+
assertEquals(502, cause.getResponse().getStatus().getCode());
113+
116114
var mockClient = mocked.constructed().getFirst();
117-
verify(mockClient, times(3)).request(any(HttpRequest.class), eq(String.class));
115+
verify(mockClient, times(2)).request(any(HttpRequest.class), eq(String.class));
118116
}
119117
}
120118
}

0 commit comments

Comments
 (0)