Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions examples/github-provider-example.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
id: github-model-demo
namespace: demo.ai
tasks:
- id: github_chat
type: io.kestra.plugin.ai.chat.ChatCompletion
provider:
type: github
token: "{{ secrets.GITHUB_MODELS_TOKEN }}"
model: "gpt-4.1-mini"
prompt: "Explain Kestra in 2 lines."
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
package io.kestra.plugin.ai.provider.github;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;

import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.*;

/**
* Minimal GitHub Models client (chat + embeddings).
*
* Notes:
* - Uses the base URL from GitHubModelsConfig (defaults to https://api.github.com).
* - Supports chat with `messages` (preferred) or `input`.
* - Parses typical response shapes: choices[0].message.content, choices[0].output, data[].embedding.
*/
public class GitHubModelsClient {
private final GitHubModelsConfig config;
private final HttpClient http;
private final ObjectMapper mapper = new ObjectMapper();

public GitHubModelsClient(GitHubModelsConfig config) {
this.config = config;
int timeout = config.getTimeoutMs() != null ? config.getTimeoutMs() : 60_000;
this.http = HttpClient.newBuilder()
.connectTimeout(Duration.ofMillis(timeout))
.build();
}

private String base() {
return (config.getBaseUrl() != null && !config.getBaseUrl().isBlank())
? config.getBaseUrl()
: "https://api.github.com";
}

/**
* Call GitHub chat completion endpoint.
*
* @param model model id (e.g. "gpt-4.1-mini")
* @param messages list of message maps (role, content) OR null if using input
* @param input fallback input string (used if messages == null)
* @param params optional additional parameters (temperature, max_tokens, etc.)
* @return assistant textual response (best-effort parsing)
* @throws Exception on network / parsing errors
*/
public String chat(String model, List<Map<String, String>> messages, String input, Map<String, Object> params) throws Exception {
String url = base() + "/inference/chat/completions";

Map<String, Object> payload = new LinkedHashMap<>();
payload.put("model", model != null ? model : config.getDefaultModel());
if (messages != null) {
payload.put("messages", messages);
} else {
payload.put("input", input != null ? input : "");
}

if (params != null) {
payload.putAll(params);
}

HttpRequest req = HttpRequest.newBuilder()
.uri(URI.create(url))
.timeout(Duration.ofMillis(config.getTimeoutMs() != null ? config.getTimeoutMs() : 60_000))
.header("Authorization", "Bearer " + config.getToken())
.header("Accept", "application/vnd.github+json")
.header("Content-Type", "application/json")
.POST(HttpRequest.BodyPublishers.ofString(mapper.writeValueAsString(payload)))
.build();

HttpResponse<String> resp = http.send(req, HttpResponse.BodyHandlers.ofString());
int status = resp.statusCode();
String body = resp.body();

if (status < 200 || status >= 300) {
throw new RuntimeException("GitHub Models API error: " + status + " -> " + body);
}

JsonNode root = mapper.readTree(body);

// 1) Try choices[0].message.content
JsonNode choices = root.path("choices");
if (choices.isArray() && choices.size() > 0) {
JsonNode first = choices.get(0);
JsonNode message = first.path("message");
if (message.isObject() && message.has("content")) {
return message.get("content").asText();
}

// 2) fallback to output field on choice
if (first.has("output")) {
return first.get("output").asText();
}
}

// 3) fallback to top-level output
if (root.has("output")) {
return root.get("output").asText();
}

// 4) fallback: return whole body
return root.toString();
}

/**
* Call GitHub embeddings endpoint and return the first embedding vector as List<Double>.
*
* @param model model id for embeddings (e.g. "text-embedding-3-large")
* @param input text to embed
* @return embedding vector as List<Double>
* @throws Exception on network / parsing errors
*/
public List<Double> embeddings(String model, String input) throws Exception {
String url = base() + "/inference/embeddings";

Map<String, Object> payload = Map.of(
"model", model != null ? model : config.getDefaultModel(),
"input", input
);

HttpRequest req = HttpRequest.newBuilder()
.uri(URI.create(url))
.timeout(Duration.ofMillis(config.getTimeoutMs() != null ? config.getTimeoutMs() : 60_000))
.header("Authorization", "Bearer " + config.getToken())
.header("Accept", "application/vnd.github+json")
.header("Content-Type", "application/json")
.POST(HttpRequest.BodyPublishers.ofString(mapper.writeValueAsString(payload)))
.build();

HttpResponse<String> resp = http.send(req, HttpResponse.BodyHandlers.ofString());
int status = resp.statusCode();
String body = resp.body();

if (status < 200 || status >= 300) {
throw new RuntimeException("GitHub Models API embeddings error: " + status + " -> " + body);
}

JsonNode root = mapper.readTree(body);
JsonNode data = root.path("data");
if (data.isArray() && data.size() > 0) {
JsonNode first = data.get(0);
JsonNode emb = first.path("embedding");
if (emb.isArray()) {
List<Double> vec = new ArrayList<>();
for (JsonNode n : emb) {
vec.add(n.asDouble());
}
return vec;
}
}

throw new RuntimeException("No embedding found in response: " + body);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package io.kestra.plugin.ai.provider.github;

import java.util.*;

/**
* Adapter between Kestra plugin AI tasks and GitHubModelsClient.
*
* IMPORTANT:
* - This class is a neutral adapter. If your project requires implementation of a specific
* provider interface (e.g., ChatProvider), you'll need to adapt method signatures accordingly.
* - The methods below are simple and designed to be easy to call from task code.
*/
public class GitHubProvider {
private final GitHubModelsConfig config;
private final GitHubModelsClient client;

public GitHubProvider(GitHubModelsConfig config) {
this.config = config;
this.client = new GitHubModelsClient(config);
}

/**
* Send a chat prompt and return the assistant reply as a String.
*
* @param model model id (or null to use config.defaultModel)
* @param prompt user prompt text
* @return assistant reply text
* @throws Exception on errors
*/
public String chat(String model, String prompt) throws Exception {
// Build a minimal messages list (user role)
Map<String, String> userMsg = Map.of(
"role", "user",
"content", prompt
);
List<Map<String, String>> messages = Collections.singletonList(userMsg);

// No extra params for now; can be extended to pass temperature, max_tokens, etc.
return client.chat(model != null ? model : config.getDefaultModel(), messages, null, null);
}

/**
* Send pre-constructed messages (for more advanced use).
*
* @param model model id
* @param messages list of message maps with keys "role" and "content"
* @param params extra parameters for the model call
* @return assistant text
* @throws Exception on errors
*/
public String chatWithMessages(String model, List<Map<String, String>> messages, Map<String, Object> params) throws Exception {
return client.chat(model != null ? model : config.getDefaultModel(), messages, null, params);
}

/**
* Get embeddings for an input text.
*
* @param model model id for embeddings
* @param input text to embed
* @return embedding vector
* @throws Exception on errors
*/
public List<Double> embeddings(String model, String input) throws Exception {
return client.embeddings(model != null ? model : config.getDefaultModel(), input);
}

/**
* Expose config for tests or consumers.
*/
public GitHubModelsConfig getConfig() {
return this.config;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package io.kestra.plugin.ai.provider.github;

import org.junit.jupiter.api.Test;

import java.lang.reflect.Field;
import java.util.List;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.*;

/**
* Unit tests for GitHubProvider that mock GitHubModelsClient so tests don't hit network/Docker.
*/
public class GitHubProviderTest {
@Test
public void chatParsesMessageContent() throws Exception {
// Arrange
GitHubModelsConfig cfg = GitHubModelsConfig.builder()
.token("dummy")
.defaultModel("gpt-4.1-mini")
.build();

// Create provider instance
GitHubProvider provider = new GitHubProvider(cfg);

// Mock client and response
GitHubModelsClient mockClient = mock(GitHubModelsClient.class);
Map<String, Object> choice = Map.of(
"message", Map.of("content", "hello from mock")
);
when(mockClient.chat(anyString(), anyList(), any(), anyMap()))
.thenReturn("hello from mock");

// Inject mock client into provider (reflection because client is private)
Field clientField = GitHubProvider.class.getDeclaredField("client");
clientField.setAccessible(true);
clientField.set(provider, mockClient);

// Act
String out = provider.chat("gpt-4.1-mini", "some prompt");

// Assert
assertEquals("hello from mock", out);
verify(mockClient, times(1)).chat(eq("gpt-4.1-mini"), anyList(), eq(null), eq(null));
}

@Test
public void embeddingsReturnsList() throws Exception {
// Arrange
GitHubModelsConfig cfg = GitHubModelsConfig.builder()
.token("dummy")
.defaultModel("embed-model")
.build();
GitHubProvider provider = new GitHubProvider(cfg);

GitHubModelsClient mockClient = mock(GitHubModelsClient.class);
when(mockClient.embeddings(anyString(), anyString()))
.thenReturn(List.of(0.1, 0.2));

Field clientField = GitHubProvider.class.getDeclaredField("client");
clientField.setAccessible(true);
clientField.set(provider, mockClient);

// Act
var resp = provider.embeddings("embed-model", "text to embed");

// Assert
assertEquals(2, resp.size());
verify(mockClient, times(1)).embeddings(eq("embed-model"), eq("text to embed"));
}
}
Loading