Skip to content

Commit 62d7717

Browse files
authored
Merge pull request #3 from iNikitaGricenko/master
Implement Apache Airflow Task to Trigger Dag Runs
2 parents a8b1bf3 + eb0291b commit 62d7717

File tree

13 files changed

+532
-183
lines changed

13 files changed

+532
-183
lines changed

build.gradle

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ sourceCompatibility = 21
2626
targetCompatibility = 21
2727

2828
group "io.kestra.plugin"
29-
description 'Plugin template for Kestra'
29+
description 'Plugin Airflow for Kestra'
3030

3131
tasks.withType(JavaCompile) {
3232
options.encoding = "UTF-8"
@@ -49,10 +49,6 @@ dependencies {
4949

5050
// Logs
5151
compileOnly "org.slf4j:slf4j-api"
52-
53-
// libs included in the final jar
54-
// TODO remove it after using the GitHub template as your plugin may not need it
55-
api "com.google.code.gson:gson:2.11.0"
5652
}
5753

5854

@@ -159,8 +155,8 @@ jar {
159155
manifest {
160156
attributes(
161157
"X-Kestra-Name": project.name,
162-
"X-Kestra-Title": "Template",
163-
"X-Kestra-Group": project.group + ".templates",
158+
"X-Kestra-Title": "Airflow",
159+
"X-Kestra-Group": project.group + ".airflow",
164160
"X-Kestra-Description": project.description,
165161
"X-Kestra-Version": project.version
166162
)

settings.gradle

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
rootProject.name = 'plugin-template'
1+
rootProject.name = 'plugin-airflow'
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
package io.kestra.plugin.airflow;
2+
3+
import com.fasterxml.jackson.databind.ObjectMapper;
4+
import io.kestra.core.exceptions.IllegalVariableEvaluationException;
5+
import io.kestra.core.models.annotations.PluginProperty;
6+
import io.kestra.core.models.tasks.Task;
7+
import io.kestra.core.runners.RunContext;
8+
import io.kestra.core.serializers.JacksonMapper;
9+
import io.kestra.plugin.airflow.model.DagRunResponse;
10+
import io.kestra.plugin.core.http.HttpInterface;
11+
import io.swagger.v3.oas.annotations.media.Schema;
12+
import jakarta.validation.constraints.NotNull;
13+
import lombok.EqualsAndHashCode;
14+
import lombok.Getter;
15+
import lombok.NoArgsConstructor;
16+
import lombok.ToString;
17+
import lombok.experimental.SuperBuilder;
18+
import lombok.extern.slf4j.Slf4j;
19+
20+
import java.net.URI;
21+
import java.net.http.HttpClient;
22+
import java.net.http.HttpRequest;
23+
import java.net.http.HttpResponse;
24+
import java.util.Base64;
25+
import java.util.Map;
26+
27+
@SuperBuilder
28+
@ToString
29+
@EqualsAndHashCode
30+
@Getter
31+
@NoArgsConstructor
32+
@Slf4j
33+
public abstract class AirflowConnection extends Task {
34+
35+
protected final static ObjectMapper objectMapper = JacksonMapper.ofJson();
36+
37+
public static final String DAG_RUNS_ENDPOINT_FORMAT = "%s/api/v1/dags/%s/dagRuns";
38+
39+
public static final String JSON_CONTENT_TYPE = "application/json";
40+
41+
@Schema(
42+
title = "The base URL of the Airflow instance"
43+
)
44+
@NotNull
45+
@PluginProperty(dynamic = true)
46+
private String baseUrl;
47+
48+
@Schema(
49+
title = "Adds custom headers"
50+
)
51+
@PluginProperty
52+
private Map<String, String> headers;
53+
54+
@Schema(
55+
title = "Request options"
56+
)
57+
@PluginProperty
58+
protected HttpInterface.RequestOptions options;
59+
60+
protected DagRunResponse triggerDag(RunContext runContext, String dagId, String requestBody) throws Exception {
61+
String baseUrl = runContext.render(this.baseUrl);
62+
URI triggerUri = URI.create(DAG_RUNS_ENDPOINT_FORMAT.formatted(baseUrl, dagId));
63+
64+
try (HttpClient client = getClientBuilder().build()) {
65+
HttpRequest request = getRequestBuilder(runContext, triggerUri)
66+
.POST(HttpRequest.BodyPublishers.ofString(requestBody))
67+
.build();
68+
69+
HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
70+
71+
if (response.statusCode() != 200) {
72+
throw new IllegalStateException("Failed to trigger DAG: " + response.body());
73+
}
74+
75+
return objectMapper.readValue(response.body(), DagRunResponse.class);
76+
}
77+
}
78+
79+
protected DagRunResponse getDagStatus(RunContext runContext, String dagId, String dagRunId) throws Exception {
80+
URI statusUri = URI.create(DAG_RUNS_ENDPOINT_FORMAT.formatted(getBaseUrl(), dagId) + "/" + dagRunId);
81+
82+
try (HttpClient client = getClientBuilder().build()) {
83+
HttpRequest statusRequest = getRequestBuilder(runContext, statusUri)
84+
.GET()
85+
.build();
86+
87+
HttpResponse<String> response = client.send(statusRequest, HttpResponse.BodyHandlers.ofString());
88+
89+
if (response.statusCode() != 200) {
90+
throw new IllegalStateException("Failed to get DAG run status: " + response.body());
91+
}
92+
93+
return objectMapper.readValue(response.body(), DagRunResponse.class);
94+
}
95+
}
96+
97+
private HttpClient.Builder getClientBuilder() {
98+
HttpClient.Builder clientBuilder = HttpClient.newBuilder();
99+
100+
if (this.options != null && this.options.getConnectTimeout() != null) {
101+
clientBuilder.connectTimeout(options.getConnectTimeout());
102+
}
103+
104+
return clientBuilder;
105+
}
106+
107+
private HttpRequest.Builder getRequestBuilder(RunContext runContext, URI uri) throws IllegalVariableEvaluationException {
108+
HttpRequest.Builder requestBuilder = HttpRequest.newBuilder()
109+
.uri(uri)
110+
.header("Content-Type", JSON_CONTENT_TYPE);
111+
112+
setupCustomHeaders(runContext, requestBuilder);
113+
114+
return requestBuilder;
115+
}
116+
117+
private void setupCustomHeaders(RunContext runContext, HttpRequest.Builder requestBuilder) throws IllegalVariableEvaluationException {
118+
if (this.options != null && this.options.getBasicAuthUser() != null && this.options.getBasicAuthPassword() != null) {
119+
String authorizationString = "%s:%s"
120+
.formatted(
121+
runContext.render(this.options.getBasicAuthUser()),
122+
runContext.render(this.options.getBasicAuthPassword())
123+
);
124+
125+
String auth = Base64
126+
.getEncoder()
127+
.encodeToString(authorizationString.getBytes());
128+
129+
requestBuilder.header("Authorization", "Basic " + auth);
130+
}
131+
132+
if (this.headers != null && !headers.isEmpty()) {
133+
this.headers.forEach(requestBuilder::header);
134+
}
135+
}
136+
137+
}
Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
package io.kestra.plugin.airflow.dags;
2+
3+
import com.fasterxml.jackson.core.JsonProcessingException;
4+
import io.kestra.core.models.annotations.Example;
5+
import io.kestra.core.models.annotations.Plugin;
6+
import io.kestra.core.models.annotations.PluginProperty;
7+
import io.kestra.core.models.tasks.RunnableTask;
8+
import io.kestra.core.runners.RunContext;
9+
import io.kestra.core.utils.Await;
10+
import io.kestra.plugin.airflow.AirflowConnection;
11+
import io.kestra.plugin.airflow.model.DagRunResponse;
12+
import io.swagger.v3.oas.annotations.media.Schema;
13+
import jakarta.validation.constraints.NotNull;
14+
import lombok.*;
15+
import lombok.experimental.SuperBuilder;
16+
import lombok.extern.slf4j.Slf4j;
17+
18+
import java.time.Duration;
19+
import java.time.LocalDateTime;
20+
import java.time.ZonedDateTime;
21+
import java.util.Map;
22+
import java.util.concurrent.atomic.AtomicInteger;
23+
24+
import static io.kestra.core.utils.Rethrow.throwSupplier;
25+
26+
@SuperBuilder
27+
@ToString
28+
@EqualsAndHashCode
29+
@Getter
30+
@NoArgsConstructor
31+
@Slf4j
32+
@Schema(
33+
title = "Trigger Airflow DAG",
34+
description = "Trigger an Airflow DAG run and wait for its completion."
35+
)
36+
@Plugin(
37+
examples = {
38+
@Example(
39+
title = "Basic authorization",
40+
code = {
41+
" - id: trigger_dag",
42+
" type: io.kestra.plugin.airflow.TriggerDagRun",
43+
" baseUrl: http://airflow.example.com",
44+
" dagId: example_dag",
45+
" checkFrequency: PT30S",
46+
" interval: PT30S",
47+
" maxIterations: 100",
48+
" maxDuration: PT1H",
49+
" options:",
50+
" basicAuthUser: myusername",
51+
" basicAuthPassword: mypassword"
52+
}
53+
),
54+
@Example(
55+
title = "Bearer authorization",
56+
code = {
57+
" - id: trigger_dag",
58+
" type: io.kestra.plugin.airflow.TriggerDagRun",
59+
" baseUrl: http://airflow.example.com",
60+
" dagId: example_dag",
61+
" checkFrequency: PT30S",
62+
" interval: PT30S",
63+
" headers:",
64+
" authorization: 'Bearer {{ TOKEN }}'"
65+
}
66+
),
67+
@Example(
68+
title = "Basic authorization. Custom body",
69+
code = {
70+
" - id: trigger_dag",
71+
" type: io.kestra.plugin.airflow.TriggerDagRun",
72+
" baseUrl: http://airflow.example.com",
73+
" dagId: example_dag",
74+
" checkFrequency: PT30S",
75+
" interval: PT30S",
76+
" options:",
77+
" basicAuthUser: myusername",
78+
" basicAuthPassword: mypassword",
79+
" body: |",
80+
" {",
81+
" \"conf\": {",
82+
" \"source\": \"kestra\",",
83+
" \"flow\": \"{{ flow.id }}\",",
84+
" \"namespace\": \"{{ flow.namespace }}\",",
85+
" \"task\": \"{{ task.id }}\",",
86+
" \"execution\": \"{{ execution.id }}\"",
87+
" }",
88+
" }"
89+
}
90+
),
91+
}
92+
)
93+
public class TriggerDagRun extends AirflowConnection implements RunnableTask<TriggerDagRun.Output> {
94+
95+
@Schema(
96+
title = "The ID of the DAG to trigger"
97+
)
98+
@NotNull
99+
@PluginProperty(dynamic = true)
100+
private String dagId;
101+
102+
@Schema(
103+
title = "The job ID to check status for."
104+
)
105+
@PluginProperty(dynamic = true)
106+
private String jobId;
107+
108+
@Schema(
109+
title = "The maximum total wait duration."
110+
)
111+
@PluginProperty
112+
@Builder.Default
113+
Duration maxDuration = Duration.ofMinutes(60);
114+
115+
@Schema(
116+
title = "Specify how often the task should poll for the DAG run status."
117+
)
118+
@PluginProperty
119+
@Builder.Default
120+
Duration pollFrequency = Duration.ofSeconds(1);
121+
122+
@Schema(
123+
title = "Whether task should wait for the DAG to run to completion",
124+
description = "Default value is false"
125+
)
126+
@PluginProperty
127+
@Builder.Default
128+
private Boolean wait = Boolean.FALSE;
129+
130+
@Schema(
131+
title = "Overrides the default configuration payload"
132+
)
133+
@PluginProperty
134+
private Map<String, Object> body;
135+
136+
@Override
137+
public Output run(RunContext runContext) throws Exception {
138+
String dagId = runContext.render(this.dagId);
139+
140+
DagRunResponse triggerResult = triggerDag(runContext, dagId, buildBody(runContext));
141+
String dagRunId = triggerResult.getDagRunId();
142+
143+
Output.OutputBuilder outputBuilder = Output.builder()
144+
.dagId(dagId)
145+
.dagRunId(dagRunId)
146+
.state(triggerResult.getState());
147+
148+
if (this.wait.equals(Boolean.FALSE)) {
149+
return outputBuilder.build();
150+
}
151+
152+
DagRunResponse statusResult = Await.until(
153+
throwSupplier(() -> {
154+
DagRunResponse result = getDagStatus(runContext, dagId, dagRunId);
155+
String state = result.getState();
156+
157+
if ("success".equalsIgnoreCase(state) || "failed".equalsIgnoreCase(state)) {
158+
return result;
159+
}
160+
161+
return null;
162+
}),
163+
this.pollFrequency,
164+
this.maxDuration
165+
);
166+
167+
if (statusResult == null) {
168+
throw new IllegalStateException("DAG run did not complete within the specified timeout");
169+
}
170+
171+
return outputBuilder
172+
.state(statusResult.getState())
173+
.started(statusResult.getStartDate().toLocalDateTime())
174+
.ended(statusResult.getEndDate().toLocalDateTime())
175+
.build();
176+
}
177+
178+
private String buildBody(RunContext runContext) throws JsonProcessingException {
179+
RunContext.FlowInfo flowInfo = runContext.flowInfo();
180+
181+
Map<String, Object> conf = this.body;
182+
183+
if (this.body == null) {
184+
conf = Map.of(
185+
"source", "kestra",
186+
"flow", flowInfo.id(),
187+
"namespace", flowInfo.namespace(),
188+
"task", this.id,
189+
"execution", runContext.getTriggerExecutionId()
190+
);
191+
}
192+
193+
return objectMapper.writeValueAsString(conf);
194+
}
195+
196+
@Getter
197+
@Builder
198+
public static class Output implements io.kestra.core.models.tasks.Output {
199+
@Schema(
200+
title = "DAG ID"
201+
)
202+
private String dagId;
203+
204+
@Schema(
205+
title = "DAG run ID"
206+
)
207+
private String dagRunId;
208+
209+
@Schema(
210+
title = "State"
211+
)
212+
private String state;
213+
214+
@Schema(
215+
title = "DAG run started date"
216+
)
217+
private LocalDateTime started;
218+
219+
@Schema(
220+
title = "DAG run completed date"
221+
)
222+
private LocalDateTime ended;
223+
}
224+
225+
}

0 commit comments

Comments
 (0)