Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,58 @@ public ClaudeClient(RestClient.Builder restClientBuilder,
this.objectMapper = objectMapper;
}

/**
* Process user query with tool use support.
* Supports multi-turn tool calls for schema discovery and query execution.
*
* @param userMessage User's question
* @return AI response
*/
public String chat(List<Map<String, Object>> initialMessages) {
try {
List<Map<String, Object>> messages = new ArrayList<>(initialMessages);

for (int i = 0; i < MAX_TOOL_ITERATIONS; i++) {
Map<String, Object> request = buildRequest(messages);
String response = callClaudeApi(request);

if (response == null) {
log.error("Claude API returned null response");
throw new ClaudeException("Claude API returned null response", null);
}

JsonNode responseNode = objectMapper.readTree(response);
String stopReason = responseNode.path("stop_reason").asText();
JsonNode content = responseNode.path("content");

if ("end_turn".equals(stopReason)) {
return extractTextResponse(content);
}

if ("tool_use".equals(stopReason)) {
List<Map<String, Object>> toolResults = processToolCalls(content);

if (toolResults.isEmpty()) {
log.warn("stop_reason is tool_use but no tool_use blocks found");
return extractTextResponse(content);
}

messages.add(Map.of(
"role", "assistant",
"content", objectMapper.convertValue(content, List.class)
));
messages.add(Map.of(
"role", "user",
"content", toolResults
));
} else {
log.warn("Unexpected stop_reason: {}", stopReason);
return extractTextResponse(content);
}
}

log.warn("Maximum tool iterations reached");
throw new ClaudeException("Maximum tool iterations reached", null);

} catch (JsonProcessingException e) {
log.error("Failed to parse Claude API response", e);
throw new ClaudeException("Failed to parse Claude API response", e);
}
}

public String chat(String userMessage) {
try {
List<Map<String, Object>> messages = new ArrayList<>();
Expand All @@ -144,7 +189,6 @@ public String chat(String userMessage) {
}

JsonNode responseNode = objectMapper.readTree(response);

String stopReason = responseNode.path("stop_reason").asText();
JsonNode content = responseNode.path("content");

Expand All @@ -153,26 +197,22 @@ public String chat(String userMessage) {
}

if ("tool_use".equals(stopReason)) {
// Process tool calls and add results
List<Map<String, Object>> toolResults = processToolCalls(content);

if (toolResults.isEmpty()) {
log.warn("stop_reason is tool_use but no tool_use blocks found");
return extractTextResponse(content);
}

// Add assistant's response to messages
messages.add(Map.of(
"role", "assistant",
"content", objectMapper.convertValue(content, List.class)
));

messages.add(Map.of(
"role", "user",
"content", toolResults
));
} else {
// Unexpected stop reason
log.warn("Unexpected stop_reason: {}", stopReason);
return extractTextResponse(content);
}
Expand Down
120 changes: 105 additions & 15 deletions src/main/java/gg/agit/konect/infrastructure/slack/ai/SlackAIService.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
package gg.agit.konect.infrastructure.slack.ai;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

Expand All @@ -8,7 +12,6 @@

import gg.agit.konect.infrastructure.claude.client.ClaudeClient;
import gg.agit.konect.infrastructure.slack.client.SlackClient;
import gg.agit.konect.infrastructure.slack.config.SlackProperties;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;

Expand All @@ -19,10 +22,15 @@ public class SlackAIService {

private static final Pattern AI_PREFIX_PATTERN = Pattern.compile("^[Aa][Ii]\\)\\s*(.+)$");
private static final Pattern MENTION_PATTERN = Pattern.compile("^<@[^>]+>\\s*");
private static final String AI_RESPONSE_PREFIX = ":robot_face: *AI ์‘๋‹ต*\n";
private static final int MAX_HISTORY_MESSAGES = 10;
private static final String EMPTY_QUERY_MESSAGE =
"์งˆ๋ฌธ ๋‚ด์šฉ์ด ๋น„์–ด์žˆ์Šต๋‹ˆ๋‹ค. ์˜ˆ: `AI) ๊ฐ€์ž…์ž ์ˆ˜ ์•Œ๋ ค์ค˜` ๋˜๋Š” `@๋ด‡์ด๋ฆ„ ๋™์•„๋ฆฌ ์ˆ˜๋Š”?`";
private static final String ERROR_MESSAGE =
":warning: ์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. ์š”์ฒญ์„ ์ฒ˜๋ฆฌํ•˜๋Š” ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค.";

private final ClaudeClient claudeClient;
private final SlackClient slackClient;
private final SlackProperties slackProperties;

public boolean isAIQuery(String text) {
if (text == null) {
Expand All @@ -31,6 +39,13 @@ public boolean isAIQuery(String text) {
return AI_PREFIX_PATTERN.matcher(text.trim()).matches();
}

public boolean isAppMention(String text) {
if (text == null) {
return false;
}
return MENTION_PATTERN.matcher(text.trim()).find();
}

public String extractQuery(String text) {
Matcher matcher = AI_PREFIX_PATTERN.matcher(text.trim());
if (matcher.matches()) {
Expand All @@ -46,37 +61,112 @@ public String normalizeAppMentionText(String text) {
return MENTION_PATTERN.matcher(text).replaceFirst("").trim();
}

public List<Map<String, Object>> fetchAIThreadReplies(String channelId, String threadTs) {
List<Map<String, Object>> replies = slackClient.getThreadReplies(channelId, threadTs);
if (replies.isEmpty()) {
return new ArrayList<>();
}
Map<String, Object> rootMessage = replies.get(0);
String rootText = (String)rootMessage.get("text");
if (rootText != null && isAIQuery(rootText)) {
return replies;
}
if (replies.stream().anyMatch(r -> r.get("bot_id") != null)) {
return replies;
}
return new ArrayList<>();
}

@Async
public void processAIQuery(String text) {
public void processAIQuery(String text, String channelId, String threadTs,
List<Map<String, Object>> cachedReplies) {
try {
String userQuery = extractQuery(text);

// ๋นˆ ์งˆ๋ฌธ์€ ์ฒ˜๋ฆฌํ•˜์ง€ ์•Š์Œ
if (userQuery == null || userQuery.isBlank()) {
log.debug("๋นˆ ์งˆ๋ฌธ์œผ๋กœ ์ฒ˜๋ฆฌ ์ค‘๋‹จ");
String guidanceMessage = formatSlackResponse(
"์งˆ๋ฌธ ๋‚ด์šฉ์ด ๋น„์–ด์žˆ์Šต๋‹ˆ๋‹ค. ์˜ˆ: `AI) ๊ฐ€์ž…์ž ์ˆ˜ ์•Œ๋ ค์ค˜` ๋˜๋Š” `@๋ด‡์ด๋ฆ„ ๋™์•„๋ฆฌ ์ˆ˜๋Š”?`"
);
slackClient.sendMessage(guidanceMessage, slackProperties.webhooks().event());
slackClient.postThreadReply(channelId, threadTs,
formatSlackResponse(EMPTY_QUERY_MESSAGE));
return;
}

log.debug("AI ์งˆ๋ฌธ ์ฒ˜๋ฆฌ ์‹œ์ž‘: {}", userQuery);

// ClaudeClient๊ฐ€ MCP๋ฅผ ํ†ตํ•ด ์ž๋™์œผ๋กœ SQL ๊ฒฐ์ • ๋ฐ ์‹คํ–‰
String response = claudeClient.chat(userQuery);
List<Map<String, Object>> replies =
cachedReplies != null ? cachedReplies : new ArrayList<>();
List<Map<String, Object>> messages = buildConversationHistory(replies);

if (messages.isEmpty()) {
messages = new ArrayList<>();
messages.add(Map.of("role", "user", "content", userQuery));
}

String response = claudeClient.chat(messages);

log.debug("AI ์‘๋‹ต ์ƒ์„ฑ ์™„๋ฃŒ");

// Slack์— ์‘๋‹ต ์ „์†ก
String slackMessage = formatSlackResponse(response);
slackClient.sendMessage(slackMessage, slackProperties.webhooks().event());
slackClient.postThreadReply(channelId, threadTs, formatSlackResponse(response));

} catch (Exception e) {
log.error("AI ์งˆ๋ฌธ ์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ", e);
String errorMessage = ":warning: ์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. ์š”์ฒญ์„ ์ฒ˜๋ฆฌํ•˜๋Š” ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค.";
slackClient.sendMessage(errorMessage, slackProperties.webhooks().event());
slackClient.postThreadReply(channelId, threadTs, ERROR_MESSAGE);
}
}

private List<Map<String, Object>> buildConversationHistory(List<Map<String, Object>> replies) {
if (replies.isEmpty()) {
return new ArrayList<>();
}

List<Map<String, Object>> messages = new ArrayList<>();
for (Map<String, Object> reply : replies) {
String replyText = (String)reply.get("text");

if (replyText == null) {
continue;
}

if (reply.get("bot_id") != null) {
String content = replyText.startsWith(AI_RESPONSE_PREFIX)
? replyText.substring(AI_RESPONSE_PREFIX.length())
: replyText;
messages.add(Map.of("role", "assistant", "content", content));
} else {
String normalizedText = normalizeAppMentionText(replyText);
String userText = isAIQuery(normalizedText)
? extractQuery(normalizedText)
: normalizedText;
messages.add(Map.of("role", "user", "content", userText));
}
}

List<Map<String, Object>> merged = mergeConsecutiveRoles(messages);

if (!merged.isEmpty() && "assistant".equals(merged.get(0).get("role"))) {
merged = new ArrayList<>(merged.subList(1, merged.size()));
}

if (merged.size() > MAX_HISTORY_MESSAGES) {
merged = new ArrayList<>(
merged.subList(merged.size() - MAX_HISTORY_MESSAGES, merged.size())
);
}
return merged;
}

private List<Map<String, Object>> mergeConsecutiveRoles(List<Map<String, Object>> messages) {
List<Map<String, Object>> merged = new ArrayList<>();
for (Map<String, Object> msg : messages) {
if (!merged.isEmpty()
&& merged.get(merged.size() - 1).get("role").equals(msg.get("role"))) {
Map<String, Object> last = new HashMap<>(merged.get(merged.size() - 1));
last.put("content", last.get("content") + "\n" + msg.get("content"));
merged.set(merged.size() - 1, last);
} else {
merged.add(msg);
}
}
return merged;
}

private String formatSlackResponse(String response) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package gg.agit.konect.infrastructure.slack.ai;

import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
Expand Down Expand Up @@ -28,6 +31,9 @@ public class SlackEventController {

private static final String SLACK_TIMESTAMP_HEADER = "X-Slack-Request-Timestamp";
private static final String SLACK_SIGNATURE_HEADER = "X-Slack-Signature";
private static final int EVENT_CACHE_MAX_SIZE = 500;

private final Set<String> processedEventIds = ConcurrentHashMap.newKeySet();

private final SlackAIService slackAIService;
private final SlackSignatureVerifier signatureVerifier;
Expand All @@ -48,30 +54,34 @@ public ResponseEntity<Object> handleSlackEvent(

String type = (String)payload.get("type");

// URL ๊ฒ€์ฆ์€ ์„œ๋ช… ๊ฒ€์ฆ ์—†์ด ์ฒ˜๋ฆฌ (์ตœ์ดˆ ์„ค์ • ์‹œ)
if ("url_verification".equals(type)) {
String challenge = (String)payload.get("challenge");
log.info("Slack URL ๊ฒ€์ฆ ์š”์ฒญ ์ฒ˜๋ฆฌ");
return ResponseEntity.ok(Map.of("challenge", challenge));
}

// ์„œ๋ช… ๊ฒ€์ฆ - ์›๋ณธ ์š”์ฒญ ๋ณธ๋ฌธ ์‚ฌ์šฉ
if (!signatureVerifier.isValidRequest(timestamp, signature, rawBody)) {
log.warn("Slack ์„œ๋ช… ๊ฒ€์ฆ ์‹คํŒจ");
return ResponseEntity.status(HttpStatus.UNAUTHORIZED).build();
}

log.debug("Slack ์ด๋ฒคํŠธ ์ˆ˜์‹ : type={}", type);

// ์ด๋ฒคํŠธ ์ฝœ๋ฐฑ ์ฒ˜๋ฆฌ
if ("event_callback".equals(type)) {
String eventId = (String)payload.get("event_id");
if (eventId != null && !processedEventIds.add(eventId)) {
log.debug("์ค‘๋ณต ์ด๋ฒคํŠธ ๋ฌด์‹œ: event_id={}", eventId);
return ResponseEntity.ok().build();
}
if (processedEventIds.size() > EVENT_CACHE_MAX_SIZE) {
processedEventIds.remove(processedEventIds.iterator().next());
}
Map<String, Object> event = (Map<String, Object>)payload.get("event");
if (event != null) {
handleEvent(event);
}
}

// Slack์€ 3์ดˆ ๋‚ด ์‘๋‹ต์„ ๊ธฐ๋Œ€ํ•˜๋ฏ€๋กœ ๋น ๋ฅด๊ฒŒ 200 ๋ฐ˜ํ™˜
return ResponseEntity.ok().build();
}

Expand All @@ -89,27 +99,37 @@ private void handleEvent(Map<String, Object> event) {
String eventType = (String)event.get("type");
String text = (String)event.get("text");
String subtype = (String)event.get("subtype");
String channelId = (String)event.get("channel");
String ts = (String)event.get("ts");
String threadTs = (String)event.get("thread_ts");

log.debug("์ด๋ฒคํŠธ ์ฒ˜๋ฆฌ: eventType={}", eventType);

// bot ๋ฉ”์‹œ์ง€๋‚˜ ๋ณ€๊ฒฝ ์ด๋ฒคํŠธ๋Š” ๋ฌด์‹œ
if (subtype != null) {
return;
}

// ๋ฉ”์‹œ์ง€ ์ด๋ฒคํŠธ ์ฒ˜๋ฆฌ
String effectiveThreadTs = threadTs != null ? threadTs : ts;

if ("message".equals(eventType) && text != null) {
if (slackAIService.isAIQuery(text)) {
log.debug("AI ์งˆ๋ฌธ ๊ฐ์ง€");
slackAIService.processAIQuery(text);
slackAIService.processAIQuery(text, channelId, effectiveThreadTs, null);
} else if (threadTs != null && slackAIService.isAppMention(text)) {
List<Map<String, Object>> aiReplies =
slackAIService.fetchAIThreadReplies(channelId, threadTs);
if (!aiReplies.isEmpty()) {
log.debug("AI ์Šค๋ ˆ๋“œ ๋‚ด ํ›„์† ์งˆ๋ฌธ ๊ฐ์ง€");
slackAIService.processAIQuery(
text, channelId, effectiveThreadTs, aiReplies);
}
}
}

// ์•ฑ ๋ฉ˜์…˜ ์ด๋ฒคํŠธ ์ฒ˜๋ฆฌ
if ("app_mention".equals(eventType) && text != null) {
String normalizedText = slackAIService.normalizeAppMentionText(text);
log.debug("์•ฑ ๋ฉ˜์…˜ ๊ฐ์ง€");
slackAIService.processAIQuery(normalizedText);
slackAIService.processAIQuery(normalizedText, channelId, effectiveThreadTs, null);
}
}
}
Loading
Loading