diff --git a/examples/github-provider-example.yaml b/examples/github-provider-example.yaml new file mode 100644 index 0000000..71f3679 --- /dev/null +++ b/examples/github-provider-example.yaml @@ -0,0 +1,10 @@ +id: github-model-demo +namespace: demo.ai +tasks: + - id: github_chat + type: io.kestra.plugin.ai.chat.ChatCompletion + provider: + type: github + token: "{{ secrets.GITHUB_MODELS_TOKEN }}" + model: "gpt-4.1-mini" + prompt: "Explain Kestra in 2 lines." diff --git a/src/main/java/io/kestra/plugin/ai/provider/github/GitHubModelsClient.java b/src/main/java/io/kestra/plugin/ai/provider/github/GitHubModelsClient.java new file mode 100644 index 0000000..d33bdda --- /dev/null +++ b/src/main/java/io/kestra/plugin/ai/provider/github/GitHubModelsClient.java @@ -0,0 +1,158 @@ +package io.kestra.plugin.ai.provider.github; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.time.Duration; +import java.util.*; + +/** + * Minimal GitHub Models client (chat + embeddings). + * + * Notes: + * - Uses the base URL from GitHubModelsConfig (defaults to https://api.github.com). + * - Supports chat with `messages` (preferred) or `input`. + * - Parses typical response shapes: choices[0].message.content, choices[0].output, data[].embedding. + */ +public class GitHubModelsClient { + private final GitHubModelsConfig config; + private final HttpClient http; + private final ObjectMapper mapper = new ObjectMapper(); + + public GitHubModelsClient(GitHubModelsConfig config) { + this.config = config; + int timeout = config.getTimeoutMs() != null ? config.getTimeoutMs() : 60_000; + this.http = HttpClient.newBuilder() + .connectTimeout(Duration.ofMillis(timeout)) + .build(); + } + + private String base() { + return (config.getBaseUrl() != null && !config.getBaseUrl().isBlank()) + ? config.getBaseUrl() + : "https://api.github.com"; + } + + /** + * Call GitHub chat completion endpoint. + * + * @param model model id (e.g. "gpt-4.1-mini") + * @param messages list of message maps (role, content) OR null if using input + * @param input fallback input string (used if messages == null) + * @param params optional additional parameters (temperature, max_tokens, etc.) + * @return assistant textual response (best-effort parsing) + * @throws Exception on network / parsing errors + */ + public String chat(String model, List> messages, String input, Map params) throws Exception { + String url = base() + "/inference/chat/completions"; + + Map payload = new LinkedHashMap<>(); + payload.put("model", model != null ? model : config.getDefaultModel()); + if (messages != null) { + payload.put("messages", messages); + } else { + payload.put("input", input != null ? input : ""); + } + + if (params != null) { + payload.putAll(params); + } + + HttpRequest req = HttpRequest.newBuilder() + .uri(URI.create(url)) + .timeout(Duration.ofMillis(config.getTimeoutMs() != null ? config.getTimeoutMs() : 60_000)) + .header("Authorization", "Bearer " + config.getToken()) + .header("Accept", "application/vnd.github+json") + .header("Content-Type", "application/json") + .POST(HttpRequest.BodyPublishers.ofString(mapper.writeValueAsString(payload))) + .build(); + + HttpResponse resp = http.send(req, HttpResponse.BodyHandlers.ofString()); + int status = resp.statusCode(); + String body = resp.body(); + + if (status < 200 || status >= 300) { + throw new RuntimeException("GitHub Models API error: " + status + " -> " + body); + } + + JsonNode root = mapper.readTree(body); + + // 1) Try choices[0].message.content + JsonNode choices = root.path("choices"); + if (choices.isArray() && choices.size() > 0) { + JsonNode first = choices.get(0); + JsonNode message = first.path("message"); + if (message.isObject() && message.has("content")) { + return message.get("content").asText(); + } + + // 2) fallback to output field on choice + if (first.has("output")) { + return first.get("output").asText(); + } + } + + // 3) fallback to top-level output + if (root.has("output")) { + return root.get("output").asText(); + } + + // 4) fallback: return whole body + return root.toString(); + } + + /** + * Call GitHub embeddings endpoint and return the first embedding vector as List. + * + * @param model model id for embeddings (e.g. "text-embedding-3-large") + * @param input text to embed + * @return embedding vector as List + * @throws Exception on network / parsing errors + */ + public List embeddings(String model, String input) throws Exception { + String url = base() + "/inference/embeddings"; + + Map payload = Map.of( + "model", model != null ? model : config.getDefaultModel(), + "input", input + ); + + HttpRequest req = HttpRequest.newBuilder() + .uri(URI.create(url)) + .timeout(Duration.ofMillis(config.getTimeoutMs() != null ? config.getTimeoutMs() : 60_000)) + .header("Authorization", "Bearer " + config.getToken()) + .header("Accept", "application/vnd.github+json") + .header("Content-Type", "application/json") + .POST(HttpRequest.BodyPublishers.ofString(mapper.writeValueAsString(payload))) + .build(); + + HttpResponse resp = http.send(req, HttpResponse.BodyHandlers.ofString()); + int status = resp.statusCode(); + String body = resp.body(); + + if (status < 200 || status >= 300) { + throw new RuntimeException("GitHub Models API embeddings error: " + status + " -> " + body); + } + + JsonNode root = mapper.readTree(body); + JsonNode data = root.path("data"); + if (data.isArray() && data.size() > 0) { + JsonNode first = data.get(0); + JsonNode emb = first.path("embedding"); + if (emb.isArray()) { + List vec = new ArrayList<>(); + for (JsonNode n : emb) { + vec.add(n.asDouble()); + } + return vec; + } + } + + throw new RuntimeException("No embedding found in response: " + body); + } +} diff --git a/src/main/java/io/kestra/plugin/ai/provider/github/GitHubProvider.java b/src/main/java/io/kestra/plugin/ai/provider/github/GitHubProvider.java new file mode 100644 index 0000000..0a42b68 --- /dev/null +++ b/src/main/java/io/kestra/plugin/ai/provider/github/GitHubProvider.java @@ -0,0 +1,73 @@ +package io.kestra.plugin.ai.provider.github; + +import java.util.*; + +/** + * Adapter between Kestra plugin AI tasks and GitHubModelsClient. + * + * IMPORTANT: + * - This class is a neutral adapter. If your project requires implementation of a specific + * provider interface (e.g., ChatProvider), you'll need to adapt method signatures accordingly. + * - The methods below are simple and designed to be easy to call from task code. + */ +public class GitHubProvider { + private final GitHubModelsConfig config; + private final GitHubModelsClient client; + + public GitHubProvider(GitHubModelsConfig config) { + this.config = config; + this.client = new GitHubModelsClient(config); + } + + /** + * Send a chat prompt and return the assistant reply as a String. + * + * @param model model id (or null to use config.defaultModel) + * @param prompt user prompt text + * @return assistant reply text + * @throws Exception on errors + */ + public String chat(String model, String prompt) throws Exception { + // Build a minimal messages list (user role) + Map userMsg = Map.of( + "role", "user", + "content", prompt + ); + List> messages = Collections.singletonList(userMsg); + + // No extra params for now; can be extended to pass temperature, max_tokens, etc. + return client.chat(model != null ? model : config.getDefaultModel(), messages, null, null); + } + + /** + * Send pre-constructed messages (for more advanced use). + * + * @param model model id + * @param messages list of message maps with keys "role" and "content" + * @param params extra parameters for the model call + * @return assistant text + * @throws Exception on errors + */ + public String chatWithMessages(String model, List> messages, Map params) throws Exception { + return client.chat(model != null ? model : config.getDefaultModel(), messages, null, params); + } + + /** + * Get embeddings for an input text. + * + * @param model model id for embeddings + * @param input text to embed + * @return embedding vector + * @throws Exception on errors + */ + public List embeddings(String model, String input) throws Exception { + return client.embeddings(model != null ? model : config.getDefaultModel(), input); + } + + /** + * Expose config for tests or consumers. + */ + public GitHubModelsConfig getConfig() { + return this.config; + } +} diff --git a/src/test/java/io/kestra/plugin/ai/provider/github/GitHubProviderTest.java b/src/test/java/io/kestra/plugin/ai/provider/github/GitHubProviderTest.java new file mode 100644 index 0000000..9a91f89 --- /dev/null +++ b/src/test/java/io/kestra/plugin/ai/provider/github/GitHubProviderTest.java @@ -0,0 +1,72 @@ +package io.kestra.plugin.ai.provider.github; + +import org.junit.jupiter.api.Test; + +import java.lang.reflect.Field; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.*; + +/** + * Unit tests for GitHubProvider that mock GitHubModelsClient so tests don't hit network/Docker. + */ +public class GitHubProviderTest { + @Test + public void chatParsesMessageContent() throws Exception { + // Arrange + GitHubModelsConfig cfg = GitHubModelsConfig.builder() + .token("dummy") + .defaultModel("gpt-4.1-mini") + .build(); + + // Create provider instance + GitHubProvider provider = new GitHubProvider(cfg); + + // Mock client and response + GitHubModelsClient mockClient = mock(GitHubModelsClient.class); + Map choice = Map.of( + "message", Map.of("content", "hello from mock") + ); + when(mockClient.chat(anyString(), anyList(), any(), anyMap())) + .thenReturn("hello from mock"); + + // Inject mock client into provider (reflection because client is private) + Field clientField = GitHubProvider.class.getDeclaredField("client"); + clientField.setAccessible(true); + clientField.set(provider, mockClient); + + // Act + String out = provider.chat("gpt-4.1-mini", "some prompt"); + + // Assert + assertEquals("hello from mock", out); + verify(mockClient, times(1)).chat(eq("gpt-4.1-mini"), anyList(), eq(null), eq(null)); + } + + @Test + public void embeddingsReturnsList() throws Exception { + // Arrange + GitHubModelsConfig cfg = GitHubModelsConfig.builder() + .token("dummy") + .defaultModel("embed-model") + .build(); + GitHubProvider provider = new GitHubProvider(cfg); + + GitHubModelsClient mockClient = mock(GitHubModelsClient.class); + when(mockClient.embeddings(anyString(), anyString())) + .thenReturn(List.of(0.1, 0.2)); + + Field clientField = GitHubProvider.class.getDeclaredField("client"); + clientField.setAccessible(true); + clientField.set(provider, mockClient); + + // Act + var resp = provider.embeddings("embed-model", "text to embed"); + + // Assert + assertEquals(2, resp.size()); + verify(mockClient, times(1)).embeddings(eq("embed-model"), eq("text to embed")); + } +}