diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-redis/pom.xml b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-redis/pom.xml new file mode 100644 index 00000000000..96947ae83bd --- /dev/null +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-redis/pom.xml @@ -0,0 +1,79 @@ + + + 4.0.0 + + org.springframework.ai + spring-ai-parent + 2.0.0-SNAPSHOT + ../../../../../pom.xml + + spring-ai-autoconfigure-model-chat-memory-redis + jar + Spring AI Redis Chat Memory Auto Configuration + Spring AI Redis Chat Memory Auto Configuration + + + + + org.springframework.boot + spring-boot-autoconfigure + + + + org.springframework.ai + spring-ai-model-chat-memory-repository-redis + ${project.version} + + + + redis.clients + jedis + + + + org.springframework.boot + spring-boot-starter-data-redis + true + + + + org.springframework.boot + spring-boot-data-redis + true + + + + org.springframework.boot + spring-boot-configuration-processor + true + + + + + org.springframework.boot + spring-boot-starter-test + test + + + + org.springframework.boot + spring-boot-testcontainers + test + + + + org.testcontainers + testcontainers-junit-jupiter + test + + + + com.redis + testcontainers-redis + 2.2.0 + test + + + + \ No newline at end of file diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-redis/src/main/java/org/springframework/ai/model/chat/memory/redis/autoconfigure/RedisChatMemoryAutoConfiguration.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-redis/src/main/java/org/springframework/ai/model/chat/memory/redis/autoconfigure/RedisChatMemoryAutoConfiguration.java new file mode 100644 index 00000000000..873fe83ef3f --- /dev/null +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-redis/src/main/java/org/springframework/ai/model/chat/memory/redis/autoconfigure/RedisChatMemoryAutoConfiguration.java @@ -0,0 +1,84 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.model.chat.memory.redis.autoconfigure; + +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.memory.ChatMemoryRepository; +import org.springframework.ai.chat.memory.repository.redis.RedisChatMemoryRepository; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.data.redis.autoconfigure.DataRedisAutoConfiguration; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; +import org.springframework.util.StringUtils; + +import redis.clients.jedis.JedisPooled; + +/** + * Auto-configuration for Redis-based chat memory implementation. + * + * @author Brian Sam-Bodden + */ +@AutoConfiguration(after = DataRedisAutoConfiguration.class) +@ConditionalOnClass({ RedisChatMemoryRepository.class, JedisPooled.class }) +@EnableConfigurationProperties(RedisChatMemoryProperties.class) +public class RedisChatMemoryAutoConfiguration { + + @Bean + @ConditionalOnMissingBean + public JedisPooled jedisClient(RedisChatMemoryProperties properties) { + return new JedisPooled(properties.getHost(), properties.getPort()); + } + + @Bean + @ConditionalOnMissingBean({ RedisChatMemoryRepository.class, ChatMemory.class, ChatMemoryRepository.class }) + public RedisChatMemoryRepository redisChatMemory(JedisPooled jedisClient, RedisChatMemoryProperties properties) { + RedisChatMemoryRepository.Builder builder = RedisChatMemoryRepository.builder().jedisClient(jedisClient); + + // Apply configuration if provided + if (StringUtils.hasText(properties.getIndexName())) { + builder.indexName(properties.getIndexName()); + } + + if (StringUtils.hasText(properties.getKeyPrefix())) { + builder.keyPrefix(properties.getKeyPrefix()); + } + + if (properties.getTimeToLive() != null && properties.getTimeToLive().toSeconds() > 0) { + builder.timeToLive(properties.getTimeToLive()); + } + + if (properties.getInitializeSchema() != null) { + builder.initializeSchema(properties.getInitializeSchema()); + } + + if (properties.getMaxConversationIds() != null) { + builder.maxConversationIds(properties.getMaxConversationIds()); + } + + if (properties.getMaxMessagesPerConversation() != null) { + builder.maxMessagesPerConversation(properties.getMaxMessagesPerConversation()); + } + + if (properties.getMetadataFields() != null && !properties.getMetadataFields().isEmpty()) { + builder.metadataFields(properties.getMetadataFields()); + } + + return builder.build(); + } + +} \ No newline at end of file diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-redis/src/main/java/org/springframework/ai/model/chat/memory/redis/autoconfigure/RedisChatMemoryProperties.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-redis/src/main/java/org/springframework/ai/model/chat/memory/redis/autoconfigure/RedisChatMemoryProperties.java new file mode 100644 index 00000000000..79af25c5167 --- /dev/null +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-redis/src/main/java/org/springframework/ai/model/chat/memory/redis/autoconfigure/RedisChatMemoryProperties.java @@ -0,0 +1,156 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.model.chat.memory.redis.autoconfigure; + +import java.time.Duration; +import java.util.List; +import java.util.Map; + +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.ai.chat.memory.repository.redis.RedisChatMemoryConfig; + +/** + * Configuration properties for Redis-based chat memory. + * + * @author Brian Sam-Bodden + */ +@ConfigurationProperties(prefix = "spring.ai.chat.memory.redis") +public class RedisChatMemoryProperties { + + /** + * Redis server host. + */ + private String host = "localhost"; + + /** + * Redis server port. + */ + private int port = 6379; + + /** + * Name of the Redis search index. + */ + private String indexName = RedisChatMemoryConfig.DEFAULT_INDEX_NAME; + + /** + * Key prefix for Redis chat memory entries. + */ + private String keyPrefix = RedisChatMemoryConfig.DEFAULT_KEY_PREFIX; + + /** + * Time to live for chat memory entries. Default is no expiration. + */ + private Duration timeToLive; + + /** + * Whether to initialize the Redis schema. Default is true. + */ + private Boolean initializeSchema = true; + + /** + * Maximum number of conversation IDs to return (defaults to 1000). + */ + private Integer maxConversationIds = RedisChatMemoryConfig.DEFAULT_MAX_RESULTS; + + /** + * Maximum number of messages to return per conversation (defaults to 1000). + */ + private Integer maxMessagesPerConversation = RedisChatMemoryConfig.DEFAULT_MAX_RESULTS; + + /** + * Metadata field definitions for proper indexing. Compatible with RedisVL schema + * format. Example:
+	 * spring.ai.chat.memory.redis.metadata-fields[0].name=priority
+	 * spring.ai.chat.memory.redis.metadata-fields[0].type=tag
+	 * spring.ai.chat.memory.redis.metadata-fields[1].name=score
+	 * spring.ai.chat.memory.redis.metadata-fields[1].type=numeric
+	 * 
+ */ + private List> metadataFields; + + public String getHost() { + return host; + } + + public void setHost(String host) { + this.host = host; + } + + public int getPort() { + return port; + } + + public void setPort(int port) { + this.port = port; + } + + public String getIndexName() { + return indexName; + } + + public void setIndexName(String indexName) { + this.indexName = indexName; + } + + public String getKeyPrefix() { + return keyPrefix; + } + + public void setKeyPrefix(String keyPrefix) { + this.keyPrefix = keyPrefix; + } + + public Duration getTimeToLive() { + return timeToLive; + } + + public void setTimeToLive(Duration timeToLive) { + this.timeToLive = timeToLive; + } + + public Boolean getInitializeSchema() { + return initializeSchema; + } + + public void setInitializeSchema(Boolean initializeSchema) { + this.initializeSchema = initializeSchema; + } + + public Integer getMaxConversationIds() { + return maxConversationIds; + } + + public void setMaxConversationIds(Integer maxConversationIds) { + this.maxConversationIds = maxConversationIds; + } + + public Integer getMaxMessagesPerConversation() { + return maxMessagesPerConversation; + } + + public void setMaxMessagesPerConversation(Integer maxMessagesPerConversation) { + this.maxMessagesPerConversation = maxMessagesPerConversation; + } + + public List> getMetadataFields() { + return metadataFields; + } + + public void setMetadataFields(List> metadataFields) { + this.metadataFields = metadataFields; + } + +} \ No newline at end of file diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-redis/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-redis/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports new file mode 100644 index 00000000000..d68fc574ca0 --- /dev/null +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-redis/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -0,0 +1 @@ +org.springframework.ai.model.chat.memory.redis.autoconfigure.RedisChatMemoryAutoConfiguration \ No newline at end of file diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-redis/src/test/java/org/springframework/ai/model/chat/memory/redis/autoconfigure/RedisChatMemoryAutoConfigurationIT.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-redis/src/test/java/org/springframework/ai/model/chat/memory/redis/autoconfigure/RedisChatMemoryAutoConfigurationIT.java new file mode 100644 index 00000000000..ca26acce9df --- /dev/null +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-redis/src/test/java/org/springframework/ai/model/chat/memory/redis/autoconfigure/RedisChatMemoryAutoConfigurationIT.java @@ -0,0 +1,93 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.model.chat.memory.redis.autoconfigure; + +import com.redis.testcontainers.RedisStackContainer; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.memory.ChatMemoryRepository; +import org.springframework.ai.chat.memory.repository.redis.RedisChatMemoryRepository; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.data.redis.autoconfigure.DataRedisAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import static org.assertj.core.api.Assertions.assertThat; + +@Testcontainers +class RedisChatMemoryAutoConfigurationIT { + + private static final Logger logger = LoggerFactory.getLogger(RedisChatMemoryAutoConfigurationIT.class); + + @Container + static RedisStackContainer redisContainer = new RedisStackContainer( + RedisStackContainer.DEFAULT_IMAGE_NAME.withTag(RedisStackContainer.DEFAULT_TAG)) + .withExposedPorts(6379); + + @BeforeAll + static void setup() { + logger.info("Redis container running on host: {} and port: {}", redisContainer.getHost(), + redisContainer.getFirstMappedPort()); + } + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration( + AutoConfigurations.of(RedisChatMemoryAutoConfiguration.class, DataRedisAutoConfiguration.class)) + .withPropertyValues("spring.data.redis.host=" + redisContainer.getHost(), + "spring.data.redis.port=" + redisContainer.getFirstMappedPort(), + // Pass the same Redis connection properties to our chat memory properties + "spring.ai.chat.memory.redis.host=" + redisContainer.getHost(), + "spring.ai.chat.memory.redis.port=" + redisContainer.getFirstMappedPort()); + + @Test + void autoConfigurationRegistersExpectedBeans() { + this.contextRunner.run(context -> { + assertThat(context).hasSingleBean(RedisChatMemoryRepository.class); + assertThat(context).hasSingleBean(ChatMemory.class); + assertThat(context).hasSingleBean(ChatMemoryRepository.class); + }); + } + + @Test + void customPropertiesAreApplied() { + this.contextRunner + .withPropertyValues("spring.ai.chat.memory.redis.index-name=custom-index", + "spring.ai.chat.memory.redis.key-prefix=custom-prefix:", + "spring.ai.chat.memory.redis.time-to-live=300s") + .run(context -> { + RedisChatMemoryRepository chatMemory = context.getBean(RedisChatMemoryRepository.class); + assertThat(chatMemory).isNotNull(); + }); + } + + @Test + void chatMemoryRepositoryIsProvidedByRedisChatMemory() { + this.contextRunner.run(context -> { + RedisChatMemoryRepository redisChatMemory = context.getBean(RedisChatMemoryRepository.class); + ChatMemory chatMemory = context.getBean(ChatMemory.class); + ChatMemoryRepository repository = context.getBean(ChatMemoryRepository.class); + + assertThat(chatMemory).isSameAs(redisChatMemory); + assertThat(repository).isSameAs(redisChatMemory); + }); + } + +} \ No newline at end of file diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-redis/src/test/resources/logback-test.xml b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-redis/src/test/resources/logback-test.xml new file mode 100644 index 00000000000..01da2302942 --- /dev/null +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-redis/src/test/resources/logback-test.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/memory/repository/spring-ai-model-chat-memory-repository-redis/pom.xml b/memory/repository/spring-ai-model-chat-memory-repository-redis/pom.xml new file mode 100644 index 00000000000..6375d0370b0 --- /dev/null +++ b/memory/repository/spring-ai-model-chat-memory-repository-redis/pom.xml @@ -0,0 +1,93 @@ + + + 4.0.0 + + org.springframework.ai + spring-ai-parent + 2.0.0-SNAPSHOT + ../../../pom.xml + + spring-ai-model-chat-memory-repository-redis + jar + Spring AI Chat Memory Repository - Redis + Redis-based persistent implementation of the Spring AI ChatMemoryRepository interface + + + + org.springframework.ai + spring-ai-model + ${project.version} + + + + redis.clients + jedis + + + + com.google.code.gson + gson + + + + org.slf4j + slf4j-api + + + + + org.springframework.boot + spring-boot-starter-test + test + + + com.vaadin.external.google + android-json + + + + + + org.springframework.boot + spring-boot-testcontainers + test + + + + org.testcontainers + testcontainers-junit-jupiter + test + + + + com.redis + testcontainers-redis + 2.2.0 + test + + + + ch.qos.logback + logback-classic + test + + + + + + + + org.apache.maven.plugins + maven-checkstyle-plugin + + + checkstyle-validation + none + + + + + + + diff --git a/memory/repository/spring-ai-model-chat-memory-repository-redis/src/main/java/org/springframework/ai/chat/memory/repository/redis/AdvancedChatMemoryRepository.java b/memory/repository/spring-ai-model-chat-memory-repository-redis/src/main/java/org/springframework/ai/chat/memory/repository/redis/AdvancedChatMemoryRepository.java new file mode 100644 index 00000000000..de00d6ee156 --- /dev/null +++ b/memory/repository/spring-ai-model-chat-memory-repository-redis/src/main/java/org/springframework/ai/chat/memory/repository/redis/AdvancedChatMemoryRepository.java @@ -0,0 +1,101 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory.repository.redis; + +import java.time.Instant; +import java.util.List; + +import org.springframework.ai.chat.memory.ChatMemoryRepository; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; + +/** + * Redis-specific extended interface for ChatMemoryRepository with advanced query + * capabilities. + * + *

+ * This interface provides Redis Search-specific functionality and serves as inspiration + * for potential future evolution of the core ChatMemoryRepository interface. Other + * database implementations may provide similar capabilities through vendor-specific + * extensions. + *

+ * + *

+ * Note that the {@code executeQuery} method uses Redis Search syntax, which is specific + * to Redis implementations and not portable across different storage backends. + *

+ * + * @author Brian Sam-Bodden + * @since 2.0.0 + */ +public interface AdvancedChatMemoryRepository extends ChatMemoryRepository { + + /** + * Find messages by content across all conversations. + * @param contentPattern The text pattern to search for in message content + * @param limit Maximum number of results to return + * @return List of messages matching the pattern + */ + List findByContent(String contentPattern, int limit); + + /** + * Find messages by type across all conversations. + * @param messageType The message type to filter by + * @param limit Maximum number of results to return + * @return List of messages of the specified type + */ + List findByType(MessageType messageType, int limit); + + /** + * Find messages by timestamp range. + * @param conversationId Optional conversation ID to filter by (null for all + * conversations) + * @param fromTime Start of time range (inclusive) + * @param toTime End of time range (inclusive) + * @param limit Maximum number of results to return + * @return List of messages within the time range + */ + List findByTimeRange(String conversationId, Instant fromTime, Instant toTime, int limit); + + /** + * Find messages with a specific metadata key-value pair. + * @param metadataKey The metadata key to search for + * @param metadataValue The metadata value to match + * @param limit Maximum number of results to return + * @return List of messages with matching metadata + */ + List findByMetadata(String metadataKey, Object metadataValue, int limit); + + /** + * Execute a custom query using Redis Search syntax. + * @param query The Redis Search query string + * @param limit Maximum number of results to return + * @return List of messages matching the query + */ + List executeQuery(String query, int limit); + + /** + * A wrapper class to return messages with their conversation context. + * + * @param conversationId the conversation identifier + * @param message the message content + * @param timestamp the message timestamp + */ + record MessageWithConversation(String conversationId, Message message, long timestamp) { + } + +} diff --git a/memory/repository/spring-ai-model-chat-memory-repository-redis/src/main/java/org/springframework/ai/chat/memory/repository/redis/RedisChatMemoryConfig.java b/memory/repository/spring-ai-model-chat-memory-repository-redis/src/main/java/org/springframework/ai/chat/memory/repository/redis/RedisChatMemoryConfig.java new file mode 100644 index 00000000000..1b9884e9041 --- /dev/null +++ b/memory/repository/spring-ai-model-chat-memory-repository-redis/src/main/java/org/springframework/ai/chat/memory/repository/redis/RedisChatMemoryConfig.java @@ -0,0 +1,271 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.chat.memory.repository.redis; + +import java.time.Duration; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import redis.clients.jedis.JedisPooled; + +import org.springframework.util.Assert; + +/** + * Configuration class for RedisChatMemoryRepository. + * + * @author Brian Sam-Bodden + */ +public class RedisChatMemoryConfig { + + public static final String DEFAULT_INDEX_NAME = "chat-memory-idx"; + + public static final String DEFAULT_KEY_PREFIX = "chat-memory:"; + + /** + * Default maximum number of results to return (1000 is Redis's default cursor read + * size). + */ + public static final int DEFAULT_MAX_RESULTS = 1000; + + /** The Redis client */ + private final JedisPooled jedisClient; + + /** The index name for Redis Search */ + private final String indexName; + + /** The key prefix for stored messages */ + private final String keyPrefix; + + /** The time-to-live in seconds for stored messages */ + private final Integer timeToLiveSeconds; + + /** Whether to automatically initialize the schema */ + private final boolean initializeSchema; + + /** + * Maximum number of conversation IDs to return. + */ + private final int maxConversationIds; + + /** + * Maximum number of messages to return per conversation. + */ + private final int maxMessagesPerConversation; + + /** + * Optional metadata field definitions for proper indexing. Format compatible with + * RedisVL schema format. + */ + private final List> metadataFields; + + private RedisChatMemoryConfig(final Builder builder) { + Assert.notNull(builder.jedisClient, "JedisPooled client must not be null"); + Assert.hasText(builder.indexName, "Index name must not be empty"); + Assert.hasText(builder.keyPrefix, "Key prefix must not be empty"); + + this.jedisClient = builder.jedisClient; + this.indexName = builder.indexName; + this.keyPrefix = builder.keyPrefix; + this.timeToLiveSeconds = builder.timeToLiveSeconds; + this.initializeSchema = builder.initializeSchema; + this.maxConversationIds = builder.maxConversationIds; + this.maxMessagesPerConversation = builder.maxMessagesPerConversation; + this.metadataFields = builder.metadataFields != null ? Collections.unmodifiableList(builder.metadataFields) + : Collections.emptyList(); + } + + public static Builder builder() { + return new Builder(); + } + + public JedisPooled getJedisClient() { + return jedisClient; + } + + public String getIndexName() { + return indexName; + } + + public String getKeyPrefix() { + return keyPrefix; + } + + public Integer getTimeToLiveSeconds() { + return timeToLiveSeconds; + } + + public boolean isInitializeSchema() { + return initializeSchema; + } + + /** + * Gets the maximum number of conversation IDs to return. + * @return maximum number of conversation IDs + */ + public int getMaxConversationIds() { + return maxConversationIds; + } + + /** + * Gets the maximum number of messages to return per conversation. + * @return maximum number of messages per conversation + */ + public int getMaxMessagesPerConversation() { + return maxMessagesPerConversation; + } + + /** + * Gets the metadata field definitions. + * @return list of metadata field definitions in RedisVL-compatible format + */ + public List> getMetadataFields() { + return metadataFields; + } + + /** + * Builder for RedisChatMemoryConfig. + */ + public static class Builder { + + /** The Redis client */ + private JedisPooled jedisClient; + + /** The index name */ + private String indexName = DEFAULT_INDEX_NAME; + + /** The key prefix */ + private String keyPrefix = DEFAULT_KEY_PREFIX; + + /** The time-to-live in seconds */ + private Integer timeToLiveSeconds = -1; + + /** Whether to initialize the schema */ + private boolean initializeSchema = true; + + /** Maximum number of conversation IDs to return */ + private int maxConversationIds = DEFAULT_MAX_RESULTS; + + /** Maximum number of messages per conversation */ + private int maxMessagesPerConversation = DEFAULT_MAX_RESULTS; + + /** Optional metadata field definitions for indexing */ + private List> metadataFields; + + /** + * Sets the Redis client. + * @param jedisClient the Redis client to use + * @return the builder instance + */ + public Builder jedisClient(final JedisPooled jedisClient) { + this.jedisClient = jedisClient; + return this; + } + + /** + * Sets the index name. + * @param indexName the index name to use + * @return the builder instance + */ + public Builder indexName(final String indexName) { + this.indexName = indexName; + return this; + } + + /** + * Sets the key prefix. + * @param keyPrefix the key prefix to use + * @return the builder instance + */ + public Builder keyPrefix(final String keyPrefix) { + this.keyPrefix = keyPrefix; + return this; + } + + /** + * Sets the time-to-live duration. + * @param ttl the time-to-live duration + * @return the builder instance + */ + public Builder timeToLive(final Duration ttl) { + if (ttl != null) { + this.timeToLiveSeconds = (int) ttl.toSeconds(); + } + return this; + } + + /** + * Sets whether to initialize the schema. + * @param initialize true to initialize schema, false otherwise + * @return the builder instance + */ + public Builder initializeSchema(final boolean initialize) { + this.initializeSchema = initialize; + return this; + } + + /** + * Sets the maximum number of conversation IDs to return. Default is 1000, which + * is Redis's default cursor read size. + * @param maxConversationIds maximum number of conversation IDs + * @return the builder instance + */ + public Builder maxConversationIds(final int maxConversationIds) { + this.maxConversationIds = maxConversationIds; + return this; + } + + /** + * Sets the maximum number of messages to return per conversation. Default is + * 1000, which is Redis's default cursor read size. + * @param maxMessagesPerConversation maximum number of messages + * @return the builder instance + */ + public Builder maxMessagesPerConversation(final int maxMessagesPerConversation) { + this.maxMessagesPerConversation = maxMessagesPerConversation; + return this; + } + + /** + * Sets the metadata field definitions for proper indexing. Format is compatible + * with RedisVL schema format. Each map should contain "name" and "type" keys. + * + * Example:
+		 * List.of(
+		 *     Map.of("name", "priority", "type", "tag"),
+		 *     Map.of("name", "score", "type", "numeric"),
+		 *     Map.of("name", "category", "type", "tag")
+		 * )
+		 * 
+ * @param metadataFields list of field definitions + * @return the builder instance + */ + public Builder metadataFields(List> metadataFields) { + this.metadataFields = metadataFields; + return this; + } + + /** + * Builds a new RedisChatMemoryConfig instance. + * @return the new configuration instance + */ + public RedisChatMemoryConfig build() { + return new RedisChatMemoryConfig(this); + } + + } + +} \ No newline at end of file diff --git a/memory/repository/spring-ai-model-chat-memory-repository-redis/src/main/java/org/springframework/ai/chat/memory/repository/redis/RedisChatMemoryRepository.java b/memory/repository/spring-ai-model-chat-memory-repository-redis/src/main/java/org/springframework/ai/chat/memory/repository/redis/RedisChatMemoryRepository.java new file mode 100644 index 00000000000..920c01a4bc4 --- /dev/null +++ b/memory/repository/spring-ai-model-chat-memory-repository-redis/src/main/java/org/springframework/ai/chat/memory/repository/redis/RedisChatMemoryRepository.java @@ -0,0 +1,1315 @@ +package org.springframework.ai.chat.memory.repository.redis; + +import com.google.gson.Gson; +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.memory.ChatMemoryRepository; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.content.Media; +import org.springframework.ai.content.MediaContent; +import org.springframework.util.Assert; +import org.springframework.util.MimeType; +import redis.clients.jedis.JedisPooled; +import redis.clients.jedis.Pipeline; +import redis.clients.jedis.json.Path2; +import redis.clients.jedis.search.*; +import redis.clients.jedis.search.RediSearchUtil; +import redis.clients.jedis.search.aggr.AggregationBuilder; +import redis.clients.jedis.search.aggr.AggregationResult; +import redis.clients.jedis.search.aggr.Reducers; +import redis.clients.jedis.search.querybuilder.QueryBuilders; +import redis.clients.jedis.search.querybuilder.QueryNode; +import redis.clients.jedis.search.querybuilder.Values; +import redis.clients.jedis.search.schemafields.NumericField; +import redis.clients.jedis.search.schemafields.SchemaField; +import redis.clients.jedis.search.schemafields.TagField; +import redis.clients.jedis.search.schemafields.TextField; + +import java.net.URI; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Base64; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicLong; + +/** + * Redis implementation of {@link ChatMemoryRepository} using Redis (JSON + Query Engine). + * Stores chat messages as JSON documents and uses the Redis Query Engine for querying. + * + * @author Brian Sam-Bodden + */ +public final class RedisChatMemoryRepository implements ChatMemoryRepository, AdvancedChatMemoryRepository { + + private static final Logger logger = LoggerFactory.getLogger(RedisChatMemoryRepository.class); + + private static final Gson gson = new Gson(); + + private static final Path2 ROOT_PATH = Path2.of("$"); + + private final RedisChatMemoryConfig config; + + private final JedisPooled jedis; + + public RedisChatMemoryRepository(RedisChatMemoryConfig config) { + Assert.notNull(config, "Config must not be null"); + this.config = config; + this.jedis = config.getJedisClient(); + + if (config.isInitializeSchema()) { + initializeSchema(); + } + } + + public static Builder builder() { + return new Builder(); + } + + public void add(String conversationId, List messages) { + Assert.notNull(conversationId, "Conversation ID must not be null"); + Assert.notNull(messages, "Messages must not be null"); + + if (messages.isEmpty()) { + return; + } + + if (logger.isDebugEnabled()) { + logger.debug("Adding {} messages to conversation: {}", messages.size(), conversationId); + } + + // Get the next available timestamp for the first message + long nextTimestamp = getNextTimestampForConversation(conversationId); + final AtomicLong timestampSequence = new AtomicLong(nextTimestamp); + + try (Pipeline pipeline = jedis.pipelined()) { + for (Message message : messages) { + long timestamp = timestampSequence.getAndIncrement(); + String key = createKey(conversationId, timestamp); + + Map documentMap = createMessageDocument(conversationId, message); + // Ensure the timestamp in the document matches the key timestamp for + // consistency + documentMap.put("timestamp", timestamp); + + String json = gson.toJson(documentMap); + + if (logger.isDebugEnabled()) { + logger.debug("Storing batch message with key: {}, type: {}, content: {}", key, + message.getMessageType(), message.getText()); + } + + pipeline.jsonSet(key, ROOT_PATH, json); + + if (config.getTimeToLiveSeconds() != -1) { + pipeline.expire(key, config.getTimeToLiveSeconds()); + } + } + pipeline.sync(); + } + } + + public void add(String conversationId, Message message) { + Assert.notNull(conversationId, "Conversation ID must not be null"); + Assert.notNull(message, "Message must not be null"); + + if (logger.isDebugEnabled()) { + logger.debug("Adding message type: {}, content: {} to conversation: {}", message.getMessageType(), + message.getText(), conversationId); + } + + // Get the current highest timestamp for this conversation + long timestamp = getNextTimestampForConversation(conversationId); + + String key = createKey(conversationId, timestamp); + Map documentMap = createMessageDocument(conversationId, message); + + // Ensure the timestamp in the document matches the key timestamp for consistency + documentMap.put("timestamp", timestamp); + + String json = gson.toJson(documentMap); + + if (logger.isDebugEnabled()) { + logger.debug("Storing message with key: {}, JSON: {}", key, json); + } + + jedis.jsonSet(key, ROOT_PATH, json); + + if (config.getTimeToLiveSeconds() != -1) { + jedis.expire(key, config.getTimeToLiveSeconds()); + } + } + + /** + * Gets the next available timestamp for a conversation to ensure proper ordering. + * Uses Redis Lua script for atomic operations to ensure thread safety when multiple + * threads access the same conversation. + * @param conversationId the conversation ID + * @return the next timestamp to use + */ + private long getNextTimestampForConversation(String conversationId) { + // Create a Redis key specifically for tracking the sequence + String sequenceKey = String.format("%scounter:%s", config.getKeyPrefix(), escapeKey(conversationId)); + + try { + // Get the current time as base timestamp + long baseTimestamp = Instant.now().toEpochMilli(); + // Using a Lua script for atomic operation ensures that multiple threads + // will always get unique and increasing timestamps + String script = "local exists = redis.call('EXISTS', KEYS[1]) " + "if exists == 0 then " + + " redis.call('SET', KEYS[1], ARGV[1]) " + " return ARGV[1] " + "end " + + "return redis.call('INCR', KEYS[1])"; + + // Execute the script atomically + Object result = jedis.eval(script, java.util.Collections.singletonList(sequenceKey), + java.util.Collections.singletonList(String.valueOf(baseTimestamp))); + + long nextTimestamp = Long.parseLong(result.toString()); + + // Set expiration on the counter key (same as the messages) + if (config.getTimeToLiveSeconds() != -1) { + jedis.expire(sequenceKey, config.getTimeToLiveSeconds()); + } + + if (logger.isDebugEnabled()) { + logger.debug("Generated atomic timestamp {} for conversation {}", nextTimestamp, conversationId); + } + + return nextTimestamp; + } + + catch (Exception e) { + // Log error and fall back to current timestamp with nanoTime for uniqueness + logger.warn("Error getting atomic timestamp for conversation {}, using fallback: {}", conversationId, + e.getMessage()); + // Add nanoseconds to ensure uniqueness even in fallback scenario + return Instant.now().toEpochMilli() * 1000 + (System.nanoTime() % 1000); + } + } + + public List get(String conversationId) { + return get(conversationId, config.getMaxMessagesPerConversation()); + } + + public List get(String conversationId, int lastN) { + Assert.notNull(conversationId, "Conversation ID must not be null"); + Assert.isTrue(lastN > 0, "LastN must be greater than 0"); + + // Use QueryBuilders to create a tag field query for conversation_id + QueryNode queryNode = QueryBuilders.intersect("conversation_id", + Values.tags(RediSearchUtil.escape(conversationId))); + Query query = new Query(queryNode.toString()).setSortBy("timestamp", true).limit(0, lastN); + + SearchResult result = jedis.ftSearch(config.getIndexName(), query); + + if (logger.isDebugEnabled()) { + logger.debug("Redis search for conversation {} returned {} results", conversationId, + result.getDocuments().size()); + result.getDocuments().forEach(doc -> { + if (doc.get("$") != null) { + JsonObject json = gson.fromJson(doc.getString("$"), JsonObject.class); + logger.debug("Document: {}", json); + } + }); + } + + List messages = new ArrayList<>(); + result.getDocuments().forEach(doc -> { + if (doc.get("$") != null) { + JsonObject json = gson.fromJson(doc.getString("$"), JsonObject.class); + if (logger.isDebugEnabled()) { + logger.debug("Processing JSON document: {}", json); + } + + String type = json.get("type").getAsString(); + String content = json.get("content").getAsString(); + + // Convert metadata from JSON to Map if present + Map metadata = new HashMap<>(); + if (json.has("metadata") && json.get("metadata").isJsonObject()) { + JsonObject metadataJson = json.getAsJsonObject("metadata"); + metadataJson.entrySet().forEach(entry -> { + metadata.put(entry.getKey(), gson.fromJson(entry.getValue(), Object.class)); + }); + } + + if (MessageType.ASSISTANT.toString().equals(type)) { + if (logger.isDebugEnabled()) { + logger.debug("Creating AssistantMessage with content: {}", content); + } + + // Handle tool calls if present + List toolCalls = new ArrayList<>(); + if (json.has("toolCalls") && json.get("toolCalls").isJsonArray()) { + json.getAsJsonArray("toolCalls").forEach(element -> { + JsonObject toolCallJson = element.getAsJsonObject(); + toolCalls.add(new AssistantMessage.ToolCall( + toolCallJson.has("id") ? toolCallJson.get("id").getAsString() : "", + toolCallJson.has("type") ? toolCallJson.get("type").getAsString() : "", + toolCallJson.has("name") ? toolCallJson.get("name").getAsString() : "", + toolCallJson.has("arguments") ? toolCallJson.get("arguments").getAsString() : "")); + }); + } + + // Handle media if present + List media = new ArrayList<>(); + if (json.has("media") && json.get("media").isJsonArray()) { + JsonArray mediaArray = json.getAsJsonArray("media"); + for (JsonElement mediaElement : mediaArray) { + JsonObject mediaJson = mediaElement.getAsJsonObject(); + + // Extract required media properties + String mediaId = mediaJson.has("id") ? mediaJson.get("id").getAsString() : null; + String mediaName = mediaJson.has("name") ? mediaJson.get("name").getAsString() : null; + String mimeTypeString = mediaJson.has("mimeType") ? mediaJson.get("mimeType").getAsString() + : null; + + if (mimeTypeString != null) { + MimeType mimeType = MimeType.valueOf(mimeTypeString); + Media.Builder mediaBuilder = Media.builder().mimeType(mimeType); + + // Set optional properties if present + if (mediaId != null) { + mediaBuilder.id(mediaId); + } + + if (mediaName != null) { + mediaBuilder.name(mediaName); + } + + // Handle data based on its type + if (mediaJson.has("data")) { + JsonElement dataElement = mediaJson.get("data"); + if (dataElement.isJsonPrimitive() && dataElement.getAsJsonPrimitive().isString()) { + String dataString = dataElement.getAsString(); + + // Check if data is Base64-encoded + if (mediaJson.has("dataType") + && "base64".equals(mediaJson.get("dataType").getAsString())) { + // Decode Base64 string to byte array + try { + byte[] decodedBytes = Base64.getDecoder().decode(dataString); + mediaBuilder.data(decodedBytes); + } + + catch (IllegalArgumentException e) { + logger.warn("Failed to decode Base64 data, storing as string", e); + mediaBuilder.data(dataString); + } + } + + else { + // Handle URL/URI data + try { + mediaBuilder.data(URI.create(dataString)); + } + + catch (IllegalArgumentException e) { + // Not a valid URI, store as string + mediaBuilder.data(dataString); + } + } + } + + else if (dataElement.isJsonArray()) { + // For backward compatibility - handle byte array + // data stored as JSON array + JsonArray dataArray = dataElement.getAsJsonArray(); + byte[] byteArray = new byte[dataArray.size()]; + for (int i = 0; i < dataArray.size(); i++) { + byteArray[i] = dataArray.get(i).getAsByte(); + } + mediaBuilder.data(byteArray); + } + } + + media.add(mediaBuilder.build()); + } + } + } + + AssistantMessage assistantMessage = AssistantMessage.builder() + .content(content) + .properties(metadata) + .toolCalls(toolCalls) + .media(media) + .build(); + messages.add(assistantMessage); + } + + else if (MessageType.USER.toString().equals(type)) { + if (logger.isDebugEnabled()) { + logger.debug("Creating UserMessage with content: {}", content); + } + + // Create a UserMessage with the builder to properly set metadata + List userMedia = new ArrayList<>(); + if (json.has("media") && json.get("media").isJsonArray()) { + JsonArray mediaArray = json.getAsJsonArray("media"); + for (JsonElement mediaElement : mediaArray) { + JsonObject mediaJson = mediaElement.getAsJsonObject(); + + // Extract required media properties + String mediaId = mediaJson.has("id") ? mediaJson.get("id").getAsString() : null; + String mediaName = mediaJson.has("name") ? mediaJson.get("name").getAsString() : null; + String mimeTypeString = mediaJson.has("mimeType") ? mediaJson.get("mimeType").getAsString() + : null; + + if (mimeTypeString != null) { + MimeType mimeType = MimeType.valueOf(mimeTypeString); + Media.Builder mediaBuilder = Media.builder().mimeType(mimeType); + + // Set optional properties if present + if (mediaId != null) { + mediaBuilder.id(mediaId); + } + + if (mediaName != null) { + mediaBuilder.name(mediaName); + } + + // Handle data based on its type and markers + if (mediaJson.has("data")) { + JsonElement dataElement = mediaJson.get("data"); + if (dataElement.isJsonPrimitive() && dataElement.getAsJsonPrimitive().isString()) { + String dataString = dataElement.getAsString(); + + // Check if data is Base64-encoded + if (mediaJson.has("dataType") + && "base64".equals(mediaJson.get("dataType").getAsString())) { + // Decode Base64 string to byte array + try { + byte[] decodedBytes = Base64.getDecoder().decode(dataString); + mediaBuilder.data(decodedBytes); + } + + catch (IllegalArgumentException e) { + logger.warn("Failed to decode Base64 data, storing as string", e); + mediaBuilder.data(dataString); + } + } + + else { + // Handle URL/URI data + try { + mediaBuilder.data(URI.create(dataString)); + } + + catch (IllegalArgumentException e) { + // Not a valid URI, store as string + mediaBuilder.data(dataString); + } + } + } + + else if (dataElement.isJsonArray()) { + // For backward compatibility - handle byte array + // data stored as JSON array + JsonArray dataArray = dataElement.getAsJsonArray(); + byte[] byteArray = new byte[dataArray.size()]; + for (int i = 0; i < dataArray.size(); i++) { + byteArray[i] = dataArray.get(i).getAsByte(); + } + mediaBuilder.data(byteArray); + } + } + + userMedia.add(mediaBuilder.build()); + } + } + } + messages.add(UserMessage.builder().text(content).metadata(metadata).media(userMedia).build()); + } + + else if (MessageType.SYSTEM.toString().equals(type)) { + if (logger.isDebugEnabled()) { + logger.debug("Creating SystemMessage with content: {}", content); + } + + messages.add(SystemMessage.builder().text(content).metadata(metadata).build()); + } + + else if (MessageType.TOOL.toString().equals(type)) { + if (logger.isDebugEnabled()) { + logger.debug("Creating ToolResponseMessage with content: {}", content); + } + + // Extract tool responses + List toolResponses = new ArrayList<>(); + if (json.has("toolResponses") && json.get("toolResponses").isJsonArray()) { + JsonArray responseArray = json.getAsJsonArray("toolResponses"); + for (JsonElement responseElement : responseArray) { + JsonObject responseJson = responseElement.getAsJsonObject(); + + String id = responseJson.has("id") ? responseJson.get("id").getAsString() : ""; + String name = responseJson.has("name") ? responseJson.get("name").getAsString() : ""; + String responseData = responseJson.has("responseData") + ? responseJson.get("responseData").getAsString() : ""; + + toolResponses.add(new ToolResponseMessage.ToolResponse(id, name, responseData)); + } + } + + messages.add(ToolResponseMessage.builder().responses(toolResponses).metadata(metadata).build()); + } + // Add handling for other message types if needed + else { + logger.warn("Unknown message type: {}", type); + } + } + }); + + if (logger.isDebugEnabled()) { + logger.debug("Returning {} messages for conversation {}", messages.size(), conversationId); + messages.forEach(message -> logger.debug("Message type: {}, content: {}, class: {}", + message.getMessageType(), message.getText(), message.getClass().getSimpleName())); + } + + return messages; + } + + public void clear(String conversationId) { + Assert.notNull(conversationId, "Conversation ID must not be null"); + + // Use QueryBuilders to create a tag field query + QueryNode queryNode = QueryBuilders.intersect("conversation_id", + Values.tags(RediSearchUtil.escape(conversationId))); + Query query = new Query(queryNode.toString()); + SearchResult result = jedis.ftSearch(config.getIndexName(), query); + + try (Pipeline pipeline = jedis.pipelined()) { + result.getDocuments().forEach(doc -> pipeline.del(doc.getId())); + pipeline.sync(); + } + } + + private void initializeSchema() { + try { + if (!jedis.ftList().contains(config.getIndexName())) { + List schemaFields = new ArrayList<>(); + + // Basic fields for all messages - using schema field objects + schemaFields.add(new TextField("$.content").as("content")); + schemaFields.add(new TextField("$.type").as("type")); + schemaFields.add(new TagField("$.conversation_id").as("conversation_id")); + schemaFields.add(new NumericField("$.timestamp").as("timestamp")); + + // Add metadata fields based on user-provided schema or default to text + if (config.getMetadataFields() != null && !config.getMetadataFields().isEmpty()) { + // User has provided a metadata schema - use it + for (Map fieldDef : config.getMetadataFields()) { + String fieldName = fieldDef.get("name"); + String fieldType = fieldDef.getOrDefault("type", "text"); + String jsonPath = "$.metadata." + fieldName; + String indexedName = "metadata_" + fieldName; + + switch (fieldType.toLowerCase()) { + case "numeric": + schemaFields.add(new NumericField(jsonPath).as(indexedName)); + break; + case "tag": + schemaFields.add(new TagField(jsonPath).as(indexedName)); + break; + case "text": + default: + schemaFields.add(new TextField(jsonPath).as(indexedName)); + break; + } + } + // When specific metadata fields are defined, we don't add a wildcard + // metadata field to avoid indexing errors with non-string values + } + + else { + // No schema provided - fallback to indexing all metadata as text + schemaFields.add(new TextField("$.metadata.*").as("metadata")); + } + + // Create the index with the defined schema + FTCreateParams indexParams = FTCreateParams.createParams() + .on(IndexDataType.JSON) + .prefix(config.getKeyPrefix()); + + String response = jedis.ftCreate(config.getIndexName(), indexParams, + schemaFields.toArray(new SchemaField[0])); + + if (!response.equals("OK")) { + throw new IllegalStateException("Failed to create index: " + response); + } + + if (logger.isDebugEnabled()) { + logger.debug("Created Redis search index '{}' with {} schema fields", config.getIndexName(), + schemaFields.size()); + } + } + + else if (logger.isDebugEnabled()) { + logger.debug("Redis search index '{}' already exists", config.getIndexName()); + } + } + + catch (Exception e) { + logger.error("Failed to initialize Redis schema: {}", e.getMessage()); + if (logger.isDebugEnabled()) { + logger.debug("Error details", e); + } + throw new IllegalStateException("Could not initialize Redis schema", e); + } + } + + private String createKey(String conversationId, long timestamp) { + return String.format("%s%s:%d", config.getKeyPrefix(), escapeKey(conversationId), timestamp); + } + + private Map createMessageDocument(String conversationId, Message message) { + Map documentMap = new HashMap<>(); + documentMap.put("type", message.getMessageType().toString()); + documentMap.put("content", message.getText()); + documentMap.put("conversation_id", conversationId); + documentMap.put("timestamp", Instant.now().toEpochMilli()); + + // Store metadata/properties + if (message.getMetadata() != null && !message.getMetadata().isEmpty()) { + documentMap.put("metadata", message.getMetadata()); + } + + // Handle tool calls for AssistantMessage + if (message instanceof AssistantMessage assistantMessage && assistantMessage.hasToolCalls()) { + documentMap.put("toolCalls", assistantMessage.getToolCalls()); + } + + // Handle tool responses for ToolResponseMessage + if (message instanceof ToolResponseMessage toolResponseMessage) { + documentMap.put("toolResponses", toolResponseMessage.getResponses()); + } + + // Handle media content + if (message instanceof MediaContent mediaContent && !mediaContent.getMedia().isEmpty()) { + List> mediaList = new ArrayList<>(); + + for (Media media : mediaContent.getMedia()) { + Map mediaMap = new HashMap<>(); + + // Store ID and name if present + if (media.getId() != null) { + mediaMap.put("id", media.getId()); + } + + if (media.getName() != null) { + mediaMap.put("name", media.getName()); + } + + // Store MimeType as string + if (media.getMimeType() != null) { + mediaMap.put("mimeType", media.getMimeType().toString()); + } + + // Handle data based on its type + Object data = media.getData(); + if (data != null) { + if (data instanceof URI || data instanceof String) { + // Store URI/URL as string + mediaMap.put("data", data.toString()); + } + + else if (data instanceof byte[]) { + // Encode byte array as Base64 string + mediaMap.put("data", Base64.getEncoder().encodeToString((byte[]) data)); + // Add a marker to indicate this is Base64-encoded + mediaMap.put("dataType", "base64"); + } + + else { + // For other types, store as string + mediaMap.put("data", data.toString()); + } + } + + mediaList.add(mediaMap); + } + + documentMap.put("media", mediaList); + } + + return documentMap; + } + + private String escapeKey(String key) { + return key.replace(":", "\\:"); + } + + // ChatMemoryRepository implementation + + /** + * Finds all unique conversation IDs using Redis aggregation. This method is optimized + * to perform the deduplication on the Redis server side. + * @return a list of unique conversation IDs + */ + @Override + public List findConversationIds() { + // Use Redis aggregation to get distinct conversation_ids + AggregationBuilder aggregation = new AggregationBuilder("*") + .groupBy("@conversation_id", Reducers.count().as("count")) + .limit(0, config.getMaxConversationIds()); // Use configured limit + + AggregationResult result = jedis.ftAggregate(config.getIndexName(), aggregation); + + List conversationIds = new ArrayList<>(); + result.getResults().forEach(row -> { + String conversationId = (String) row.get("conversation_id"); + if (conversationId != null) { + conversationIds.add(conversationId); + } + }); + + if (logger.isDebugEnabled()) { + logger.debug("Found {} unique conversation IDs using Redis aggregation", conversationIds.size()); + conversationIds.forEach(id -> logger.debug("Conversation ID: {}", id)); + } + + return conversationIds; + } + + /** + * Finds all messages for a given conversation ID. Uses the configured maximum + * messages per conversation limit to avoid exceeding Redis limits. + * @param conversationId the conversation ID to find messages for + * @return a list of messages for the conversation + */ + @Override + public List findByConversationId(String conversationId) { + // Reuse existing get method with the configured limit + return get(conversationId, config.getMaxMessagesPerConversation()); + } + + @Override + public void saveAll(String conversationId, List messages) { + // First clear any existing messages for this conversation + clear(conversationId); + + // Then add all the new messages + add(conversationId, messages); + } + + @Override + public void deleteByConversationId(String conversationId) { + // Reuse existing clear method + clear(conversationId); + } + + // AdvancedChatMemoryRepository implementation + + /** + * Gets the index name used by this RedisChatMemory instance. + * @return the index name + */ + public String getIndexName() { + return config.getIndexName(); + } + + @Override + public List findByContent(String contentPattern, int limit) { + Assert.notNull(contentPattern, "Content pattern must not be null"); + Assert.isTrue(limit > 0, "Limit must be greater than 0"); + + // Use QueryBuilders to create a text field query + // Note: We don't escape the contentPattern here because Redis full-text search + // should handle the special characters appropriately in text fields + QueryNode queryNode = QueryBuilders.intersect("content", Values.value(contentPattern)); + Query query = new Query(queryNode.toString()).setSortBy("timestamp", true).limit(0, limit); + + if (logger.isDebugEnabled()) { + logger.debug("Searching for messages with content pattern '{}' with limit {}", contentPattern, limit); + } + + SearchResult result = jedis.ftSearch(config.getIndexName(), query); + return processSearchResult(result); + } + + @Override + public List findByType(MessageType messageType, int limit) { + Assert.notNull(messageType, "Message type must not be null"); + Assert.isTrue(limit > 0, "Limit must be greater than 0"); + + // Use QueryBuilders to create a text field query + QueryNode queryNode = QueryBuilders.intersect("type", Values.value(messageType.toString())); + Query query = new Query(queryNode.toString()).setSortBy("timestamp", true).limit(0, limit); + + if (logger.isDebugEnabled()) { + logger.debug("Searching for messages of type {} with limit {}", messageType, limit); + } + + SearchResult result = jedis.ftSearch(config.getIndexName(), query); + return processSearchResult(result); + } + + @Override + public List findByTimeRange(String conversationId, Instant fromTime, Instant toTime, + int limit) { + Assert.notNull(fromTime, "From time must not be null"); + Assert.notNull(toTime, "To time must not be null"); + Assert.isTrue(limit > 0, "Limit must be greater than 0"); + Assert.isTrue(!toTime.isBefore(fromTime), "To time must not be before from time"); + + // Build query with numeric range for timestamp using the QueryBuilder + long fromTimeMs = fromTime.toEpochMilli(); + long toTimeMs = toTime.toEpochMilli(); + + // Create the numeric range query for timestamp + QueryNode rangeNode = QueryBuilders.intersect("timestamp", Values.between(fromTimeMs, toTimeMs)); + + // If conversationId is provided, add it to the query as a tag filter + QueryNode finalQuery; + if (conversationId != null && !conversationId.isEmpty()) { + QueryNode conversationNode = QueryBuilders.intersect("conversation_id", + Values.tags(RediSearchUtil.escape(conversationId))); + finalQuery = QueryBuilders.intersect(rangeNode, conversationNode); + } + + else { + finalQuery = rangeNode; + } + + // Create the query with sorting by timestamp + Query query = new Query(finalQuery.toString()).setSortBy("timestamp", true).limit(0, limit); + + if (logger.isDebugEnabled()) { + logger.debug("Searching for messages in time range from {} to {} with limit {}, query: '{}'", fromTime, + toTime, limit, finalQuery); + } + + SearchResult result = jedis.ftSearch(config.getIndexName(), query); + return processSearchResult(result); + } + + @Override + public List findByMetadata(String metadataKey, Object metadataValue, int limit) { + Assert.notNull(metadataKey, "Metadata key must not be null"); + Assert.notNull(metadataValue, "Metadata value must not be null"); + Assert.isTrue(limit > 0, "Limit must be greater than 0"); + + // Check if this metadata field was explicitly defined in the schema + String indexedFieldName = "metadata_" + metadataKey; + boolean isFieldIndexed = false; + String fieldType = "text"; + + if (config.getMetadataFields() != null) { + for (Map fieldDef : config.getMetadataFields()) { + if (metadataKey.equals(fieldDef.get("name"))) { + isFieldIndexed = true; + fieldType = fieldDef.getOrDefault("type", "text"); + break; + } + } + } + + QueryNode queryNode; + if (isFieldIndexed) { + // Field is explicitly indexed - use proper query based on type + switch (fieldType.toLowerCase()) { + case "numeric": + if (metadataValue instanceof Number) { + queryNode = QueryBuilders.intersect(indexedFieldName, + Values.eq(((Number) metadataValue).doubleValue())); + } + + else { + // Try to parse as number + try { + double numValue = Double.parseDouble(metadataValue.toString()); + queryNode = QueryBuilders.intersect(indexedFieldName, Values.eq(numValue)); + } + + catch (NumberFormatException e) { + // Fall back to text search in general metadata + String searchPattern = metadataKey + " " + metadataValue; + queryNode = QueryBuilders.intersect("metadata", Values.value(searchPattern)); + } + } + break; + case "tag": + // For tag fields, we don't need to escape the value + queryNode = QueryBuilders.intersect(indexedFieldName, Values.tags(metadataValue.toString())); + break; + case "text": + default: + queryNode = QueryBuilders.intersect(indexedFieldName, + Values.value(RediSearchUtil.escape(metadataValue.toString()))); + break; + } + } + + else { + // Field not explicitly indexed - search in general metadata field + String searchPattern = metadataKey + " " + metadataValue; + queryNode = QueryBuilders.intersect("metadata", Values.value(searchPattern)); + } + + Query query = new Query(queryNode.toString()).setSortBy("timestamp", true).limit(0, limit); + + if (logger.isDebugEnabled()) { + logger.debug("Searching for messages with metadata {}={}, query: '{}', limit: {}", metadataKey, + metadataValue, queryNode, limit); + } + + SearchResult result = jedis.ftSearch(config.getIndexName(), query); + + if (logger.isDebugEnabled()) { + logger.debug("Search returned {} results", result.getTotalResults()); + } + return processSearchResult(result); + } + + @Override + public List executeQuery(String query, int limit) { + Assert.notNull(query, "Query must not be null"); + Assert.isTrue(limit > 0, "Limit must be greater than 0"); + + // Create a Query object from the query string + // The client provides the full Redis Search query syntax + Query redisQuery = new Query(query).limit(0, limit).setSortBy("timestamp", true); // Default + // sorting + // by + // timestamp + // ascending + + if (logger.isDebugEnabled()) { + logger.debug("Executing custom query '{}' with limit {}", query, limit); + } + + return executeSearchQuery(redisQuery); + } + + /** + * Processes a search result and converts it to a list of MessageWithConversation + * objects. + * @param result the search result to process + * @return a list of MessageWithConversation objects + */ + private List processSearchResult(SearchResult result) { + List messages = new ArrayList<>(); + + for (Document doc : result.getDocuments()) { + if (doc.get("$") != null) { + // Parse the JSON document + JsonObject json = gson.fromJson(doc.getString("$"), JsonObject.class); + + // Extract conversation ID and timestamp + String conversationId = json.get("conversation_id").getAsString(); + long timestamp = json.get("timestamp").getAsLong(); + + // Convert JSON to message + Message message = convertJsonToMessage(json); + + // Add to result list + messages.add(new MessageWithConversation(conversationId, message, timestamp)); + } + } + + if (logger.isDebugEnabled()) { + logger.debug("Search returned {} messages", messages.size()); + } + + return messages; + } + + /** + * Executes a search query and converts the results to a list of + * MessageWithConversation objects. Centralizes the common search execution logic used + * by multiple finder methods. + * @param query The query to execute + * @return A list of MessageWithConversation objects + */ + private List executeSearchQuery(Query query) { + try { + // Execute the search + SearchResult result = jedis.ftSearch(config.getIndexName(), query); + return processSearchResult(result); + } + + catch (Exception e) { + logger.error("Error executing query '{}': {}", query, e.getMessage()); + if (logger.isTraceEnabled()) { + logger.debug("Error details", e); + } + return Collections.emptyList(); + } + } + + /** + * Converts a JSON object to a Message instance. This is a helper method for the + * advanced query operations to convert Redis JSON documents back to Message objects. + * @param json The JSON object representing a message + * @return A Message object of the appropriate type + */ + private Message convertJsonToMessage(JsonObject json) { + String type = json.get("type").getAsString(); + String content = json.get("content").getAsString(); + + // Convert metadata from JSON to Map if present + Map metadata = new HashMap<>(); + if (json.has("metadata") && json.get("metadata").isJsonObject()) { + JsonObject metadataJson = json.getAsJsonObject("metadata"); + metadataJson.entrySet().forEach(entry -> { + metadata.put(entry.getKey(), gson.fromJson(entry.getValue(), Object.class)); + }); + } + + if (MessageType.ASSISTANT.toString().equals(type)) { + // Handle tool calls if present + List toolCalls = new ArrayList<>(); + if (json.has("toolCalls") && json.get("toolCalls").isJsonArray()) { + json.getAsJsonArray("toolCalls").forEach(element -> { + JsonObject toolCallJson = element.getAsJsonObject(); + toolCalls.add(new AssistantMessage.ToolCall( + toolCallJson.has("id") ? toolCallJson.get("id").getAsString() : "", + toolCallJson.has("type") ? toolCallJson.get("type").getAsString() : "", + toolCallJson.has("name") ? toolCallJson.get("name").getAsString() : "", + toolCallJson.has("arguments") ? toolCallJson.get("arguments").getAsString() : "")); + }); + } + + // Handle media if present + List media = new ArrayList<>(); + if (json.has("media") && json.get("media").isJsonArray()) { + JsonArray mediaArray = json.getAsJsonArray("media"); + for (JsonElement mediaElement : mediaArray) { + JsonObject mediaJson = mediaElement.getAsJsonObject(); + + // Extract required media properties + String mediaId = mediaJson.has("id") ? mediaJson.get("id").getAsString() : null; + String mediaName = mediaJson.has("name") ? mediaJson.get("name").getAsString() : null; + String mimeTypeString = mediaJson.has("mimeType") ? mediaJson.get("mimeType").getAsString() : null; + + if (mimeTypeString != null) { + MimeType mimeType = MimeType.valueOf(mimeTypeString); + Media.Builder mediaBuilder = Media.builder().mimeType(mimeType); + + // Set optional properties if present + if (mediaId != null) { + mediaBuilder.id(mediaId); + } + + if (mediaName != null) { + mediaBuilder.name(mediaName); + } + + // Handle data based on its type + if (mediaJson.has("data")) { + JsonElement dataElement = mediaJson.get("data"); + if (dataElement.isJsonPrimitive() && dataElement.getAsJsonPrimitive().isString()) { + String dataString = dataElement.getAsString(); + + // Check if data is Base64-encoded + if (mediaJson.has("dataType") + && "base64".equals(mediaJson.get("dataType").getAsString())) { + // Decode Base64 string to byte array + try { + byte[] decodedBytes = Base64.getDecoder().decode(dataString); + mediaBuilder.data(decodedBytes); + } + + catch (IllegalArgumentException e) { + logger.warn("Failed to decode Base64 data, storing as string", e); + mediaBuilder.data(dataString); + } + } + + else { + // Handle URL/URI data + try { + mediaBuilder.data(URI.create(dataString)); + } + + catch (IllegalArgumentException e) { + // Not a valid URI, store as string + mediaBuilder.data(dataString); + } + } + } + + else if (dataElement.isJsonArray()) { + // For backward compatibility - handle byte array data + // stored as JSON array + JsonArray dataArray = dataElement.getAsJsonArray(); + byte[] byteArray = new byte[dataArray.size()]; + for (int i = 0; i < dataArray.size(); i++) { + byteArray[i] = dataArray.get(i).getAsByte(); + } + mediaBuilder.data(byteArray); + } + } + + media.add(mediaBuilder.build()); + } + } + } + + return AssistantMessage.builder() + .content(content) + .properties(metadata) + .toolCalls(toolCalls) + .media(media) + .build(); + } + + else if (MessageType.USER.toString().equals(type)) { + // Create a UserMessage with the builder to properly set metadata + List userMedia = new ArrayList<>(); + if (json.has("media") && json.get("media").isJsonArray()) { + JsonArray mediaArray = json.getAsJsonArray("media"); + for (JsonElement mediaElement : mediaArray) { + JsonObject mediaJson = mediaElement.getAsJsonObject(); + + // Extract required media properties + String mediaId = mediaJson.has("id") ? mediaJson.get("id").getAsString() : null; + String mediaName = mediaJson.has("name") ? mediaJson.get("name").getAsString() : null; + String mimeTypeString = mediaJson.has("mimeType") ? mediaJson.get("mimeType").getAsString() : null; + + if (mimeTypeString != null) { + MimeType mimeType = MimeType.valueOf(mimeTypeString); + Media.Builder mediaBuilder = Media.builder().mimeType(mimeType); + + // Set optional properties if present + if (mediaId != null) { + mediaBuilder.id(mediaId); + } + + if (mediaName != null) { + mediaBuilder.name(mediaName); + } + + // Handle data based on its type and markers + if (mediaJson.has("data")) { + JsonElement dataElement = mediaJson.get("data"); + if (dataElement.isJsonPrimitive() && dataElement.getAsJsonPrimitive().isString()) { + String dataString = dataElement.getAsString(); + + // Check if data is Base64-encoded + if (mediaJson.has("dataType") + && "base64".equals(mediaJson.get("dataType").getAsString())) { + // Decode Base64 string to byte array + try { + byte[] decodedBytes = Base64.getDecoder().decode(dataString); + mediaBuilder.data(decodedBytes); + } + + catch (IllegalArgumentException e) { + logger.warn("Failed to decode Base64 data, storing as string", e); + mediaBuilder.data(dataString); + } + } + + else { + // Handle URL/URI data + try { + mediaBuilder.data(URI.create(dataString)); + } + + catch (IllegalArgumentException e) { + // Not a valid URI, store as string + mediaBuilder.data(dataString); + } + } + } + + else if (dataElement.isJsonArray()) { + // For backward compatibility - handle byte array data + // stored as JSON array + JsonArray dataArray = dataElement.getAsJsonArray(); + byte[] byteArray = new byte[dataArray.size()]; + for (int i = 0; i < dataArray.size(); i++) { + byteArray[i] = dataArray.get(i).getAsByte(); + } + mediaBuilder.data(byteArray); + } + } + + userMedia.add(mediaBuilder.build()); + } + } + } + return UserMessage.builder().text(content).metadata(metadata).media(userMedia).build(); + } + + else if (MessageType.SYSTEM.toString().equals(type)) { + return SystemMessage.builder().text(content).metadata(metadata).build(); + } + + else if (MessageType.TOOL.toString().equals(type)) { + // Extract tool responses + List toolResponses = new ArrayList<>(); + if (json.has("toolResponses") && json.get("toolResponses").isJsonArray()) { + JsonArray responseArray = json.getAsJsonArray("toolResponses"); + for (JsonElement responseElement : responseArray) { + JsonObject responseJson = responseElement.getAsJsonObject(); + + String id = responseJson.has("id") ? responseJson.get("id").getAsString() : ""; + String name = responseJson.has("name") ? responseJson.get("name").getAsString() : ""; + String responseData = responseJson.has("responseData") + ? responseJson.get("responseData").getAsString() : ""; + + toolResponses.add(new ToolResponseMessage.ToolResponse(id, name, responseData)); + } + } + + return ToolResponseMessage.builder().responses(toolResponses).metadata(metadata).build(); + } + + // For unknown message types, return a generic UserMessage + logger.warn("Unknown message type: {}, returning generic UserMessage", type); + return UserMessage.builder().text(content).metadata(metadata).build(); + } + + /** + * Inner static builder class for constructing instances of {@link RedisChatMemory}. + */ + public static class Builder { + + private JedisPooled jedisClient; + + private String indexName = RedisChatMemoryConfig.DEFAULT_INDEX_NAME; + + private String keyPrefix = RedisChatMemoryConfig.DEFAULT_KEY_PREFIX; + + private boolean initializeSchema = true; + + private long timeToLiveSeconds = -1; + + private int maxConversationIds = 10; + + private int maxMessagesPerConversation = 100; + + private List> metadataFields; + + /** + * Sets the JedisPooled client. + * @param jedisClient the JedisPooled client to use + * @return this builder + */ + public Builder jedisClient(final JedisPooled jedisClient) { + this.jedisClient = jedisClient; + return this; + } + + /** + * Sets the index name. + * @param indexName the index name to use + * @return this builder + */ + public Builder indexName(final String indexName) { + this.indexName = indexName; + return this; + } + + /** + * Sets the key prefix. + * @param keyPrefix the key prefix to use + * @return this builder + */ + public Builder keyPrefix(final String keyPrefix) { + this.keyPrefix = keyPrefix; + return this; + } + + /** + * Sets whether to initialize the schema. + * @param initializeSchema whether to initialize the schema + * @return this builder + */ + public Builder initializeSchema(final boolean initializeSchema) { + this.initializeSchema = initializeSchema; + return this; + } + + /** + * Sets the time to live in seconds for messages stored in Redis. + * @param timeToLiveSeconds the time to live in seconds (use -1 for no expiration) + * @return this builder + */ + public Builder ttlSeconds(final long timeToLiveSeconds) { + this.timeToLiveSeconds = timeToLiveSeconds; + return this; + } + + /** + * Sets the time to live duration for messages stored in Redis. + * @param timeToLive the time to live duration (null for no expiration) + * @return this builder + */ + public Builder timeToLive(final Duration timeToLive) { + if (timeToLive != null) { + this.timeToLiveSeconds = timeToLive.getSeconds(); + } + + else { + this.timeToLiveSeconds = -1; + } + return this; + } + + /** + * Sets the maximum number of conversation IDs to return. + * @param maxConversationIds the maximum number of conversation IDs + * @return this builder + */ + public Builder maxConversationIds(final int maxConversationIds) { + this.maxConversationIds = maxConversationIds; + return this; + } + + /** + * Sets the maximum number of messages per conversation to return. + * @param maxMessagesPerConversation the maximum number of messages per + * conversation + * @return this builder + */ + public Builder maxMessagesPerConversation(final int maxMessagesPerConversation) { + this.maxMessagesPerConversation = maxMessagesPerConversation; + return this; + } + + /** + * Sets the metadata field definitions for proper indexing. Format is compatible + * with RedisVL schema format. + * @param metadataFields list of field definitions + * @return this builder + */ + public Builder metadataFields(List> metadataFields) { + this.metadataFields = metadataFields; + return this; + } + + /** + * Builds and returns an instance of {@link RedisChatMemoryRepository}. + * @return a new {@link RedisChatMemoryRepository} instance + */ + public RedisChatMemoryRepository build() { + Assert.notNull(this.jedisClient, "JedisClient must not be null"); + + RedisChatMemoryConfig config = new RedisChatMemoryConfig.Builder().jedisClient(this.jedisClient) + .indexName(this.indexName) + .keyPrefix(this.keyPrefix) + .initializeSchema(this.initializeSchema) + .timeToLive(Duration.ofSeconds(this.timeToLiveSeconds)) + .maxConversationIds(this.maxConversationIds) + .maxMessagesPerConversation(this.maxMessagesPerConversation) + .metadataFields(this.metadataFields) + .build(); + + return new RedisChatMemoryRepository(config); + } + + } + +} diff --git a/memory/repository/spring-ai-model-chat-memory-repository-redis/src/test/java/org/springframework/ai/chat/memory/repository/redis/RedisChatMemoryAdvancedQueryIT.java b/memory/repository/spring-ai-model-chat-memory-repository-redis/src/test/java/org/springframework/ai/chat/memory/repository/redis/RedisChatMemoryAdvancedQueryIT.java new file mode 100644 index 00000000000..3bf39b486c8 --- /dev/null +++ b/memory/repository/spring-ai-model-chat-memory-repository-redis/src/test/java/org/springframework/ai/chat/memory/repository/redis/RedisChatMemoryAdvancedQueryIT.java @@ -0,0 +1,549 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.chat.memory.repository.redis; + +import com.redis.testcontainers.RedisContainer; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.memory.repository.redis.AdvancedChatMemoryRepository; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.autoconfigure.EnableAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import redis.clients.jedis.JedisPooled; + +import java.util.List; +import java.util.Map; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for RedisChatMemoryRepository advanced query capabilities. + * + * @author Brian Sam-Bodden + */ +@Testcontainers +class RedisChatMemoryAdvancedQueryIT { + + @Container + static RedisContainer redisContainer = new RedisContainer("redis/redis-stack:latest"); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(TestApplication.class); + + @Test + void shouldFindMessagesByType_singleConversation() { + this.contextRunner.run(context -> { + RedisChatMemoryRepository chatMemory = context.getBean(RedisChatMemoryRepository.class); + + // Clear any existing test data + chatMemory.findConversationIds().forEach(chatMemory::clear); + + String conversationId = "test-find-by-type"; + + // Add various message types to a single conversation + chatMemory.add(conversationId, new SystemMessage("System message 1")); + chatMemory.add(conversationId, new UserMessage("User message 1")); + chatMemory.add(conversationId, new AssistantMessage("Assistant message 1")); + chatMemory.add(conversationId, new UserMessage("User message 2")); + chatMemory.add(conversationId, new AssistantMessage("Assistant message 2")); + chatMemory.add(conversationId, new SystemMessage("System message 2")); + + // Test finding by USER type + List userMessages = ((AdvancedChatMemoryRepository) chatMemory) + .findByType(MessageType.USER, 10); + + assertThat(userMessages).hasSize(2); + assertThat(userMessages.get(0).message().getText()).isEqualTo("User message 1"); + assertThat(userMessages.get(1).message().getText()).isEqualTo("User message 2"); + assertThat(userMessages.get(0).conversationId()).isEqualTo(conversationId); + assertThat(userMessages.get(1).conversationId()).isEqualTo(conversationId); + + // Test finding by SYSTEM type + List systemMessages = ((AdvancedChatMemoryRepository) chatMemory) + .findByType(MessageType.SYSTEM, 10); + + assertThat(systemMessages).hasSize(2); + assertThat(systemMessages.get(0).message().getText()).isEqualTo("System message 1"); + assertThat(systemMessages.get(1).message().getText()).isEqualTo("System message 2"); + + // Test finding by ASSISTANT type + List assistantMessages = ((AdvancedChatMemoryRepository) chatMemory) + .findByType(MessageType.ASSISTANT, 10); + + assertThat(assistantMessages).hasSize(2); + assertThat(assistantMessages.get(0).message().getText()).isEqualTo("Assistant message 1"); + assertThat(assistantMessages.get(1).message().getText()).isEqualTo("Assistant message 2"); + + // Test finding by TOOL type (should be empty) + List toolMessages = ((AdvancedChatMemoryRepository) chatMemory) + .findByType(MessageType.TOOL, 10); + + assertThat(toolMessages).isEmpty(); + }); + } + + @Test + void shouldFindMessagesByType_multipleConversations() { + this.contextRunner.run(context -> { + RedisChatMemoryRepository chatMemory = context.getBean(RedisChatMemoryRepository.class); + String conversationId1 = "conv-1-" + UUID.randomUUID(); + String conversationId2 = "conv-2-" + UUID.randomUUID(); + + // Add messages to first conversation + chatMemory.add(conversationId1, new UserMessage("User in conv 1")); + chatMemory.add(conversationId1, new AssistantMessage("Assistant in conv 1")); + chatMemory.add(conversationId1, new SystemMessage("System in conv 1")); + + // Add messages to second conversation + chatMemory.add(conversationId2, new UserMessage("User in conv 2")); + chatMemory.add(conversationId2, new AssistantMessage("Assistant in conv 2")); + chatMemory.add(conversationId2, new SystemMessage("System in conv 2")); + chatMemory.add(conversationId2, new UserMessage("Second user in conv 2")); + + // Find all USER messages across conversations + List userMessages = ((AdvancedChatMemoryRepository) chatMemory) + .findByType(MessageType.USER, 10); + + assertThat(userMessages).hasSize(3); + + // Verify messages from both conversations are included + List conversationIds = userMessages.stream().map(msg -> msg.conversationId()).distinct().toList(); + + assertThat(conversationIds).containsExactlyInAnyOrder(conversationId1, conversationId2); + + // Count messages from each conversation + long conv1Count = userMessages.stream().filter(msg -> msg.conversationId().equals(conversationId1)).count(); + long conv2Count = userMessages.stream().filter(msg -> msg.conversationId().equals(conversationId2)).count(); + + assertThat(conv1Count).isEqualTo(1); + assertThat(conv2Count).isEqualTo(2); + }); + } + + @Test + void shouldRespectLimitParameter() { + this.contextRunner.run(context -> { + RedisChatMemoryRepository chatMemory = context.getBean(RedisChatMemoryRepository.class); + String conversationId = "test-limit-parameter"; + + // Add multiple messages of the same type + chatMemory.add(conversationId, new UserMessage("User message 1")); + chatMemory.add(conversationId, new UserMessage("User message 2")); + chatMemory.add(conversationId, new UserMessage("User message 3")); + chatMemory.add(conversationId, new UserMessage("User message 4")); + chatMemory.add(conversationId, new UserMessage("User message 5")); + + // Retrieve with a limit of 3 + List messages = ((AdvancedChatMemoryRepository) chatMemory) + .findByType(MessageType.USER, 3); + + // Verify only 3 messages are returned + assertThat(messages).hasSize(3); + }); + } + + @Test + void shouldHandleToolMessages() { + this.contextRunner.run(context -> { + RedisChatMemoryRepository chatMemory = context.getBean(RedisChatMemoryRepository.class); + String conversationId = "test-tool-messages"; + + // Create a ToolResponseMessage + ToolResponseMessage.ToolResponse toolResponse = new ToolResponseMessage.ToolResponse("tool-1", "weather", + "{\"temperature\":\"22°C\"}"); + ToolResponseMessage toolMessage = ToolResponseMessage.builder().responses(List.of(toolResponse)).build(); + + // Add various message types + chatMemory.add(conversationId, new UserMessage("Weather query")); + chatMemory.add(conversationId, toolMessage); + chatMemory.add(conversationId, new AssistantMessage("It's 22°C")); + + // Find TOOL type messages + List toolMessages = ((AdvancedChatMemoryRepository) chatMemory) + .findByType(MessageType.TOOL, 10); + + assertThat(toolMessages).hasSize(1); + assertThat(toolMessages.get(0).message()).isInstanceOf(ToolResponseMessage.class); + + ToolResponseMessage retrievedToolMessage = (ToolResponseMessage) toolMessages.get(0).message(); + assertThat(retrievedToolMessage.getResponses()).hasSize(1); + assertThat(retrievedToolMessage.getResponses().get(0).name()).isEqualTo("weather"); + }); + } + + @Test + void shouldReturnEmptyListWhenNoMessagesOfTypeExist() { + this.contextRunner.run(context -> { + RedisChatMemoryRepository chatMemory = context.getBean(RedisChatMemoryRepository.class); + + // Clear any existing test data + chatMemory.findConversationIds().forEach(chatMemory::clear); + + String conversationId = "test-empty-type"; + + // Add only user and assistant messages + chatMemory.add(conversationId, new UserMessage("Hello")); + chatMemory.add(conversationId, new AssistantMessage("Hi there")); + + // Search for system messages which don't exist + List systemMessages = ((AdvancedChatMemoryRepository) chatMemory) + .findByType(MessageType.SYSTEM, 10); + + // Verify an empty list is returned (not null) + assertThat(systemMessages).isNotNull().isEmpty(); + }); + } + + @Test + void shouldFindMessagesByContent() { + this.contextRunner.run(context -> { + RedisChatMemoryRepository chatMemory = context.getBean(RedisChatMemoryRepository.class); + String conversationId1 = "test-content-1"; + String conversationId2 = "test-content-2"; + + // Add messages with different content patterns + chatMemory.add(conversationId1, new UserMessage("I love programming in Java")); + chatMemory.add(conversationId1, new AssistantMessage("Java is a great programming language")); + chatMemory.add(conversationId2, new UserMessage("Python programming is fun")); + chatMemory.add(conversationId2, new AssistantMessage("Tell me about Spring Boot")); + chatMemory.add(conversationId1, new UserMessage("What about JavaScript programming?")); + + // Search for messages containing "programming" + List programmingMessages = ((AdvancedChatMemoryRepository) chatMemory) + .findByContent("programming", 10); + + assertThat(programmingMessages).hasSize(4); + // Verify all messages contain "programming" + programmingMessages + .forEach(msg -> assertThat(msg.message().getText().toLowerCase()).contains("programming")); + + // Search for messages containing "Java" + List javaMessages = ((AdvancedChatMemoryRepository) chatMemory) + .findByContent("Java", 10); + + assertThat(javaMessages).hasSize(2); // Only exact case matches + // Verify messages are from conversation 1 only + assertThat(javaMessages.stream().map(m -> m.conversationId()).distinct()).hasSize(1); + + // Search for messages containing "Spring" + List springMessages = ((AdvancedChatMemoryRepository) chatMemory) + .findByContent("Spring", 10); + + assertThat(springMessages).hasSize(1); + assertThat(springMessages.get(0).message().getText()).contains("Spring Boot"); + + // Test with limit + List limitedMessages = ((AdvancedChatMemoryRepository) chatMemory) + .findByContent("programming", 2); + + assertThat(limitedMessages).hasSize(2); + + // Clean up + chatMemory.clear(conversationId1); + chatMemory.clear(conversationId2); + }); + } + + @Test + void shouldFindMessagesByTimeRange() throws InterruptedException { + this.contextRunner.run(context -> { + RedisChatMemoryRepository chatMemory = context.getBean(RedisChatMemoryRepository.class); + String conversationId1 = "test-time-1"; + String conversationId2 = "test-time-2"; + + // Record time before adding messages + long startTime = System.currentTimeMillis(); + Thread.sleep(10); // Small delay to ensure timestamps are different + + // Add messages to first conversation + chatMemory.add(conversationId1, new UserMessage("First message")); + Thread.sleep(10); + chatMemory.add(conversationId1, new AssistantMessage("Second message")); + Thread.sleep(10); + + long midTime = System.currentTimeMillis(); + Thread.sleep(10); + + // Add messages to second conversation + chatMemory.add(conversationId2, new UserMessage("Third message")); + Thread.sleep(10); + chatMemory.add(conversationId2, new AssistantMessage("Fourth message")); + Thread.sleep(10); + + long endTime = System.currentTimeMillis(); + + // Test finding messages in full time range across all conversations + List allMessages = ((AdvancedChatMemoryRepository) chatMemory) + .findByTimeRange(null, java.time.Instant.ofEpochMilli(startTime), + java.time.Instant.ofEpochMilli(endTime), 10); + + assertThat(allMessages).hasSize(4); + + // Test finding messages in first half of time range + List firstHalfMessages = ((AdvancedChatMemoryRepository) chatMemory) + .findByTimeRange(null, java.time.Instant.ofEpochMilli(startTime), + java.time.Instant.ofEpochMilli(midTime), 10); + + assertThat(firstHalfMessages).hasSize(2); + assertThat(firstHalfMessages.stream().allMatch(m -> m.conversationId().equals(conversationId1))).isTrue(); + + // Test finding messages in specific conversation within time range + List conv2Messages = ((AdvancedChatMemoryRepository) chatMemory) + .findByTimeRange(conversationId2, java.time.Instant.ofEpochMilli(startTime), + java.time.Instant.ofEpochMilli(endTime), 10); + + assertThat(conv2Messages).hasSize(2); + assertThat(conv2Messages.stream().allMatch(m -> m.conversationId().equals(conversationId2))).isTrue(); + + // Test with limit + List limitedTimeMessages = ((AdvancedChatMemoryRepository) chatMemory) + .findByTimeRange(null, java.time.Instant.ofEpochMilli(startTime), + java.time.Instant.ofEpochMilli(endTime), 2); + + assertThat(limitedTimeMessages).hasSize(2); + + // Clean up + chatMemory.clear(conversationId1); + chatMemory.clear(conversationId2); + }); + } + + @Test + void shouldFindMessagesByMetadata() { + this.contextRunner.run(context -> { + RedisChatMemoryRepository chatMemory = context.getBean(RedisChatMemoryRepository.class); + String conversationId = "test-metadata"; + + // Create messages with different metadata + UserMessage userMsg1 = new UserMessage("User message with metadata"); + userMsg1.getMetadata().put("priority", "high"); + userMsg1.getMetadata().put("category", "question"); + userMsg1.getMetadata().put("score", 95); + + AssistantMessage assistantMsg = new AssistantMessage("Assistant response"); + assistantMsg.getMetadata().put("model", "gpt-4"); + assistantMsg.getMetadata().put("confidence", 0.95); + assistantMsg.getMetadata().put("category", "answer"); + + UserMessage userMsg2 = new UserMessage("Another user message"); + userMsg2.getMetadata().put("priority", "low"); + userMsg2.getMetadata().put("category", "question"); + userMsg2.getMetadata().put("score", 75); + + // Add messages + chatMemory.add(conversationId, userMsg1); + chatMemory.add(conversationId, assistantMsg); + chatMemory.add(conversationId, userMsg2); + + // Give Redis time to index the documents + Thread.sleep(100); + + // Test finding by string metadata + List highPriorityMessages = ((AdvancedChatMemoryRepository) chatMemory) + .findByMetadata("priority", "high", 10); + + assertThat(highPriorityMessages).hasSize(1); + assertThat(highPriorityMessages.get(0).message().getText()).isEqualTo("User message with metadata"); + + // Test finding by category + List questionMessages = ((AdvancedChatMemoryRepository) chatMemory) + .findByMetadata("category", "question", 10); + + assertThat(questionMessages).hasSize(2); + + // Test finding by numeric metadata + List highScoreMessages = ((AdvancedChatMemoryRepository) chatMemory) + .findByMetadata("score", 95, 10); + + assertThat(highScoreMessages).hasSize(1); + assertThat(highScoreMessages.get(0).message().getMetadata().get("score")).isEqualTo(95.0); + + // Test finding by double metadata + List confidentMessages = ((AdvancedChatMemoryRepository) chatMemory) + .findByMetadata("confidence", 0.95, 10); + + assertThat(confidentMessages).hasSize(1); + assertThat(confidentMessages.get(0).message().getMessageType()).isEqualTo(MessageType.ASSISTANT); + + // Test with non-existent metadata + List nonExistentMessages = ((AdvancedChatMemoryRepository) chatMemory) + .findByMetadata("nonexistent", "value", 10); + + assertThat(nonExistentMessages).isEmpty(); + + // Clean up + chatMemory.clear(conversationId); + }); + } + + @Test + void shouldExecuteCustomQuery() { + this.contextRunner.run(context -> { + RedisChatMemoryRepository chatMemory = context.getBean(RedisChatMemoryRepository.class); + String conversationId1 = "test-custom-1"; + String conversationId2 = "test-custom-2"; + + // Add various messages + UserMessage userMsg = new UserMessage("I need help with Redis"); + userMsg.getMetadata().put("urgent", "true"); + + chatMemory.add(conversationId1, userMsg); + chatMemory.add(conversationId1, new AssistantMessage("I can help you with Redis")); + chatMemory.add(conversationId2, new UserMessage("Tell me about Spring")); + chatMemory.add(conversationId2, new SystemMessage("System initialized")); + + // Test custom query for USER messages containing "Redis" + String customQuery = "@type:USER @content:Redis"; + List redisUserMessages = ((AdvancedChatMemoryRepository) chatMemory) + .executeQuery(customQuery, 10); + + assertThat(redisUserMessages).hasSize(1); + assertThat(redisUserMessages.get(0).message().getText()).contains("Redis"); + assertThat(redisUserMessages.get(0).message().getMessageType()).isEqualTo(MessageType.USER); + + // Test custom query for all messages in a specific conversation + // Note: conversation_id is a TAG field, so we need to escape special + // characters + String escapedConvId = conversationId1.replace("-", "\\-"); + String convQuery = "@conversation_id:{" + escapedConvId + "}"; + List conv1Messages = ((AdvancedChatMemoryRepository) chatMemory) + .executeQuery(convQuery, 10); + + assertThat(conv1Messages).hasSize(2); + assertThat(conv1Messages.stream().allMatch(m -> m.conversationId().equals(conversationId1))).isTrue(); + + // Test complex query combining type and content + String complexQuery = "(@type:USER | @type:ASSISTANT) @content:Redis"; + List complexResults = ((AdvancedChatMemoryRepository) chatMemory) + .executeQuery(complexQuery, 10); + + assertThat(complexResults).hasSize(2); + + // Test with limit + List limitedResults = ((AdvancedChatMemoryRepository) chatMemory) + .executeQuery("*", 2); + + assertThat(limitedResults).hasSize(2); + + // Clean up + chatMemory.clear(conversationId1); + chatMemory.clear(conversationId2); + }); + } + + @Test + void shouldHandleSpecialCharactersInQueries() { + this.contextRunner.run(context -> { + RedisChatMemoryRepository chatMemory = context.getBean(RedisChatMemoryRepository.class); + String conversationId = "test-special-chars"; + + // Add messages with special characters + chatMemory.add(conversationId, new UserMessage("What is 2+2?")); + chatMemory.add(conversationId, new AssistantMessage("The answer is: 4")); + chatMemory.add(conversationId, new UserMessage("Tell me about C++")); + + // Test finding content with special characters + List plusMessages = ((AdvancedChatMemoryRepository) chatMemory) + .findByContent("C++", 10); + + assertThat(plusMessages).hasSize(1); + assertThat(plusMessages.get(0).message().getText()).contains("C++"); + + // Test finding content with colon - search for "answer is" instead + List colonMessages = ((AdvancedChatMemoryRepository) chatMemory) + .findByContent("answer is", 10); + + assertThat(colonMessages).hasSize(1); + + // Clean up + chatMemory.clear(conversationId); + }); + } + + @Test + void shouldReturnEmptyListForNoMatches() { + this.contextRunner.run(context -> { + RedisChatMemoryRepository chatMemory = context.getBean(RedisChatMemoryRepository.class); + String conversationId = "test-no-matches"; + + // Add a simple message + chatMemory.add(conversationId, new UserMessage("Hello world")); + + // Test content that doesn't exist + List noContentMatch = ((AdvancedChatMemoryRepository) chatMemory) + .findByContent("nonexistent", 10); + assertThat(noContentMatch).isEmpty(); + + // Test time range with no messages + List noTimeMatch = ((AdvancedChatMemoryRepository) chatMemory) + .findByTimeRange(conversationId, java.time.Instant.now().plusSeconds(3600), // Future + // time + java.time.Instant.now().plusSeconds(7200), // Even more future + 10); + assertThat(noTimeMatch).isEmpty(); + + // Test metadata that doesn't exist + List noMetadataMatch = ((AdvancedChatMemoryRepository) chatMemory) + .findByMetadata("nonexistent", "value", 10); + assertThat(noMetadataMatch).isEmpty(); + + // Test custom query with no matches + List noQueryMatch = ((AdvancedChatMemoryRepository) chatMemory) + .executeQuery("@type:FUNCTION", 10); + assertThat(noQueryMatch).isEmpty(); + + // Clean up + chatMemory.clear(conversationId); + }); + } + + @SpringBootConfiguration + @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) + static class TestApplication { + + @Bean + RedisChatMemoryRepository chatMemory() { + // Define metadata fields for proper indexing + List> metadataFields = List.of(Map.of("name", "priority", "type", "tag"), + Map.of("name", "category", "type", "tag"), Map.of("name", "score", "type", "numeric"), + Map.of("name", "confidence", "type", "numeric"), Map.of("name", "model", "type", "tag"), + Map.of("name", "urgent", "type", "tag")); + + // Use a unique index name to avoid conflicts with metadata schema + String uniqueIndexName = "test-adv-app-" + System.currentTimeMillis(); + + return RedisChatMemoryRepository.builder() + .jedisClient(new JedisPooled(redisContainer.getHost(), redisContainer.getFirstMappedPort())) + .indexName(uniqueIndexName) + .metadataFields(metadataFields) + .build(); + } + + } + +} \ No newline at end of file diff --git a/memory/repository/spring-ai-model-chat-memory-repository-redis/src/test/java/org/springframework/ai/chat/memory/repository/redis/RedisChatMemoryErrorHandlingIT.java b/memory/repository/spring-ai-model-chat-memory-repository-redis/src/test/java/org/springframework/ai/chat/memory/repository/redis/RedisChatMemoryErrorHandlingIT.java new file mode 100644 index 00000000000..94ad0aa38c8 --- /dev/null +++ b/memory/repository/spring-ai-model-chat-memory-repository-redis/src/test/java/org/springframework/ai/chat/memory/repository/redis/RedisChatMemoryErrorHandlingIT.java @@ -0,0 +1,333 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.chat.memory.repository.redis; + +import com.redis.testcontainers.RedisContainer; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.autoconfigure.EnableAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import redis.clients.jedis.JedisPooled; +import redis.clients.jedis.exceptions.JedisConnectionException; + +import java.time.Duration; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.CountDownLatch; +import java.util.stream.Collectors; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatNoException; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Integration tests for RedisChatMemoryRepository focused on error handling scenarios. + * + * @author Brian Sam-Bodden + */ +@Testcontainers +class RedisChatMemoryErrorHandlingIT { + + @Container + static RedisContainer redisContainer = new RedisContainer("redis/redis-stack:latest"); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(TestApplication.class); + + private RedisChatMemoryRepository chatMemory; + + private JedisPooled jedisClient; + + @BeforeEach + void setUp() { + jedisClient = new JedisPooled(redisContainer.getHost(), redisContainer.getFirstMappedPort()); + chatMemory = RedisChatMemoryRepository.builder() + .jedisClient(jedisClient) + .indexName("test-error-" + RedisChatMemoryConfig.DEFAULT_INDEX_NAME) + .build(); + } + + @AfterEach + void tearDown() { + if (jedisClient != null) { + jedisClient.close(); + } + } + + @Test + void shouldHandleInvalidConversationId() { + this.contextRunner.run(context -> { + // Using null conversation ID + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> chatMemory.add(null, new UserMessage("Test message"))) + .withMessageContaining("Conversation ID must not be null"); + + // Using empty conversation ID + UserMessage message = new UserMessage("Test message"); + assertThatCode(() -> chatMemory.add("", message)).doesNotThrowAnyException(); + + // Reading with null conversation ID + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> chatMemory.get(null, 10)) + .withMessageContaining("Conversation ID must not be null"); + + // Reading with non-existent conversation ID should return empty list + List messages = chatMemory.get("non-existent-id", 10); + assertThat(messages).isNotNull().isEmpty(); + + // Clearing with null conversation ID + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> chatMemory.clear(null)) + .withMessageContaining("Conversation ID must not be null"); + + // Clearing non-existent conversation should not throw exception + assertThatCode(() -> chatMemory.clear("non-existent-id")).doesNotThrowAnyException(); + }); + } + + @Test + void shouldHandleInvalidMessageParameters() { + this.contextRunner.run(context -> { + String conversationId = UUID.randomUUID().toString(); + + // Null message + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> chatMemory.add(conversationId, (Message) null)) + .withMessageContaining("Message must not be null"); + + // Null message list + assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(() -> chatMemory.add(conversationId, (List) null)) + .withMessageContaining("Messages must not be null"); + + // Empty message list should not throw exception + assertThatCode(() -> chatMemory.add(conversationId, List.of())).doesNotThrowAnyException(); + + // Message with empty content (not null - which is not allowed) + UserMessage emptyContentMessage = UserMessage.builder().text("").build(); + + assertThatCode(() -> chatMemory.add(conversationId, emptyContentMessage)).doesNotThrowAnyException(); + + // Message with empty metadata + UserMessage userMessage = UserMessage.builder().text("Hello").build(); + assertThatCode(() -> chatMemory.add(conversationId, userMessage)).doesNotThrowAnyException(); + }); + } + + @Test + void shouldHandleTimeToLive() { + this.contextRunner.run(context -> { + // Create chat memory with short TTL + RedisChatMemoryRepository ttlChatMemory = RedisChatMemoryRepository.builder() + .jedisClient(jedisClient) + .indexName("test-ttl-" + RedisChatMemoryConfig.DEFAULT_INDEX_NAME) + .timeToLive(Duration.ofSeconds(1)) + .build(); + + String conversationId = "ttl-test-conversation"; + UserMessage message = new UserMessage("This message will expire soon"); + + // Add a message + ttlChatMemory.add(conversationId, message); + + // Immediately verify message exists + List messages = ttlChatMemory.get(conversationId, 10); + assertThat(messages).hasSize(1); + + // Wait for TTL to expire + Thread.sleep(1500); + + // After TTL expiry, message should be gone + List expiredMessages = ttlChatMemory.get(conversationId, 10); + assertThat(expiredMessages).isEmpty(); + }); + } + + @Test + void shouldHandleConnectionFailureGracefully() { + this.contextRunner.run(context -> { + // Using a connection to an invalid Redis server should throw a connection + // exception + assertThatExceptionOfType(JedisConnectionException.class).isThrownBy(() -> { + // Create a JedisPooled with a connection timeout to make the test faster + JedisPooled badConnection = new JedisPooled("localhost", 54321); + // Attempt an operation that would require Redis connection + badConnection.ping(); + }); + }); + } + + @Test + void shouldHandleEdgeCaseConversationIds() { + this.contextRunner.run(context -> { + // Test with a simple conversation ID first to verify basic functionality + String simpleId = "simple-test-id"; + UserMessage simpleMessage = new UserMessage("Simple test message"); + chatMemory.add(simpleId, simpleMessage); + + List simpleMessages = chatMemory.get(simpleId, 10); + assertThat(simpleMessages).hasSize(1); + assertThat(simpleMessages.get(0).getText()).isEqualTo("Simple test message"); + + // Test with conversation IDs containing special characters + String specialCharsId = "test_conversation_with_special_chars_123"; + String specialMessage = "Message with special character conversation ID"; + UserMessage message = new UserMessage(specialMessage); + + // Add message with special chars ID + chatMemory.add(specialCharsId, message); + + // Verify that message can be retrieved + List specialCharMessages = chatMemory.get(specialCharsId, 10); + assertThat(specialCharMessages).hasSize(1); + assertThat(specialCharMessages.get(0).getText()).isEqualTo(specialMessage); + + // Test with non-alphanumeric characters in ID + String complexId = "test-with:complex@chars#123"; + String complexMessage = "Message with complex ID"; + UserMessage complexIdMessage = new UserMessage(complexMessage); + + // Add and retrieve message with complex ID + chatMemory.add(complexId, complexIdMessage); + List complexIdMessages = chatMemory.get(complexId, 10); + assertThat(complexIdMessages).hasSize(1); + assertThat(complexIdMessages.get(0).getText()).isEqualTo(complexMessage); + + // Test with long IDs + StringBuilder longIdBuilder = new StringBuilder(); + for (int i = 0; i < 50; i++) { + longIdBuilder.append("a"); + } + String longId = longIdBuilder.toString(); + String longIdMessageText = "Message with long conversation ID"; + UserMessage longIdMessage = new UserMessage(longIdMessageText); + + // Add and retrieve message with long ID + chatMemory.add(longId, longIdMessage); + List longIdMessages = chatMemory.get(longId, 10); + assertThat(longIdMessages).hasSize(1); + assertThat(longIdMessages.get(0).getText()).isEqualTo(longIdMessageText); + }); + } + + @Test + void shouldHandleConcurrentAccess() { + this.contextRunner.run(context -> { + String conversationId = "concurrent-access-test-" + UUID.randomUUID(); + + // Clear any existing data for this conversation + chatMemory.clear(conversationId); + + // Define thread setup for concurrent access + int threadCount = 3; + int messagesPerThread = 4; + int totalExpectedMessages = threadCount * messagesPerThread; + + // Track all messages created for verification + Set expectedMessageTexts = new HashSet<>(); + + // Create and start threads that concurrently add messages + Thread[] threads = new Thread[threadCount]; + CountDownLatch latch = new CountDownLatch(threadCount); // For synchronized + // start + + for (int i = 0; i < threadCount; i++) { + final int threadId = i; + threads[i] = new Thread(() -> { + try { + latch.countDown(); + latch.await(); // Wait for all threads to be ready + + for (int j = 0; j < messagesPerThread; j++) { + String messageText = String.format("Message %d from thread %d", j, threadId); + expectedMessageTexts.add(messageText); + UserMessage message = new UserMessage(messageText); + chatMemory.add(conversationId, message); + } + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }); + threads[i].start(); + } + + // Wait for all threads to complete + for (Thread thread : threads) { + thread.join(); + } + + // Allow a short delay for Redis to process all operations + Thread.sleep(500); + + // Retrieve all messages (including extras to make sure we get everything) + List messages = chatMemory.get(conversationId, totalExpectedMessages + 5); + + // We don't check exact message count as Redis async operations might result + // in slight variations + // Just verify the right message format is present + List actualMessageTexts = messages.stream().map(Message::getText).collect(Collectors.toList()); + + // Check that we have messages from each thread + for (int i = 0; i < threadCount; i++) { + final int threadId = i; + assertThat(actualMessageTexts.stream().filter(text -> text.endsWith("from thread " + threadId)).count()) + .isGreaterThan(0); + } + + // Verify message format + for (Message msg : messages) { + assertThat(msg).isInstanceOf(UserMessage.class); + assertThat(msg.getText()).containsPattern("Message \\d from thread \\d"); + } + + // Order check - messages might be in different order than creation, + // but order should be consistent between retrievals + List messagesAgain = chatMemory.get(conversationId, totalExpectedMessages + 5); + for (int i = 0; i < messages.size(); i++) { + assertThat(messagesAgain.get(i).getText()).isEqualTo(messages.get(i).getText()); + } + }); + } + + @SpringBootConfiguration + @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) + static class TestApplication { + + @Bean + RedisChatMemoryRepository chatMemory() { + return RedisChatMemoryRepository.builder() + .jedisClient(new JedisPooled(redisContainer.getHost(), redisContainer.getFirstMappedPort())) + .indexName("test-error-" + RedisChatMemoryConfig.DEFAULT_INDEX_NAME) + .build(); + } + + } + +} \ No newline at end of file diff --git a/memory/repository/spring-ai-model-chat-memory-repository-redis/src/test/java/org/springframework/ai/chat/memory/repository/redis/RedisChatMemoryIT.java b/memory/repository/spring-ai-model-chat-memory-repository-redis/src/test/java/org/springframework/ai/chat/memory/repository/redis/RedisChatMemoryIT.java new file mode 100644 index 00000000000..591c0afa024 --- /dev/null +++ b/memory/repository/spring-ai-model-chat-memory-repository-redis/src/test/java/org/springframework/ai/chat/memory/repository/redis/RedisChatMemoryIT.java @@ -0,0 +1,228 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.chat.memory.repository.redis; + +import com.redis.testcontainers.RedisContainer; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.autoconfigure.EnableAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import redis.clients.jedis.JedisPooled; + +import java.time.Duration; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for RedisChatMemoryRepository using Redis Stack TestContainer. + * + * @author Brian Sam-Bodden + */ +@Testcontainers +class RedisChatMemoryIT { + + @Container + static RedisContainer redisContainer = new RedisContainer("redis/redis-stack:latest"); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(TestApplication.class); + + private RedisChatMemoryRepository chatMemory; + + private JedisPooled jedisClient; + + @BeforeEach + void setUp() { + // Create JedisPooled directly with container properties for more reliable + // connection + jedisClient = new JedisPooled(redisContainer.getHost(), redisContainer.getFirstMappedPort()); + chatMemory = RedisChatMemoryRepository.builder() + .jedisClient(jedisClient) + .indexName("test-" + RedisChatMemoryConfig.DEFAULT_INDEX_NAME) + .build(); + + chatMemory.clear("test-conversation"); + } + + @AfterEach + void tearDown() { + if (jedisClient != null) { + jedisClient.close(); + } + } + + @Test + void shouldStoreAndRetrieveMessages() { + this.contextRunner.run(context -> { + String conversationId = "test-conversation"; + + // Add messages + chatMemory.add(conversationId, new UserMessage("Hello")); + chatMemory.add(conversationId, new AssistantMessage("Hi there!")); + chatMemory.add(conversationId, new UserMessage("How are you?")); + + // Retrieve messages + List messages = chatMemory.get(conversationId, 10); + + assertThat(messages).hasSize(3); + assertThat(messages.get(0).getText()).isEqualTo("Hello"); + assertThat(messages.get(1).getText()).isEqualTo("Hi there!"); + assertThat(messages.get(2).getText()).isEqualTo("How are you?"); + }); + } + + @Test + void shouldRespectMessageLimit() { + this.contextRunner.run(context -> { + String conversationId = "test-conversation"; + + // Add messages + chatMemory.add(conversationId, new UserMessage("Message 1")); + chatMemory.add(conversationId, new AssistantMessage("Message 2")); + chatMemory.add(conversationId, new UserMessage("Message 3")); + + // Retrieve limited messages + List messages = chatMemory.get(conversationId, 2); + + assertThat(messages).hasSize(2); + }); + } + + @Test + void shouldClearConversation() { + this.contextRunner.run(context -> { + String conversationId = "test-conversation"; + + // Add messages + chatMemory.add(conversationId, new UserMessage("Hello")); + chatMemory.add(conversationId, new AssistantMessage("Hi")); + + // Clear conversation + chatMemory.clear(conversationId); + + // Verify messages are cleared + List messages = chatMemory.get(conversationId, 10); + assertThat(messages).isEmpty(); + }); + } + + @Test + void shouldHandleBatchMessageAddition() { + this.contextRunner.run(context -> { + String conversationId = "test-conversation"; + List messageBatch = List.of(new UserMessage("Message 1"), // + new AssistantMessage("Response 1"), // + new UserMessage("Message 2"), // + new AssistantMessage("Response 2") // + ); + + // Add batch of messages + chatMemory.add(conversationId, messageBatch); + + // Verify all messages were stored + List retrievedMessages = chatMemory.get(conversationId, 10); + assertThat(retrievedMessages).hasSize(4); + }); + } + + @Test + void shouldHandleTimeToLive() throws InterruptedException { + this.contextRunner.run(context -> { + RedisChatMemoryRepository shortTtlMemory = RedisChatMemoryRepository.builder() + .jedisClient(jedisClient) + .indexName("test-ttl-" + RedisChatMemoryConfig.DEFAULT_INDEX_NAME) + .timeToLive(Duration.ofSeconds(2)) + .keyPrefix("short-lived:") + .build(); + + String conversationId = "test-conversation"; + shortTtlMemory.add(conversationId, new UserMessage("This should expire")); + + // Verify message exists + assertThat(shortTtlMemory.get(conversationId, 1)).hasSize(1); + + // Wait for TTL to expire + Thread.sleep(2000); + + // Verify message is gone + assertThat(shortTtlMemory.get(conversationId, 1)).isEmpty(); + }); + } + + @Test + void shouldMaintainMessageOrder() { + this.contextRunner.run(context -> { + String conversationId = "test-conversation"; + // Add messages with minimal delay to test timestamp ordering + chatMemory.add(conversationId, new UserMessage("First")); + Thread.sleep(10); + chatMemory.add(conversationId, new AssistantMessage("Second")); + Thread.sleep(10); + chatMemory.add(conversationId, new UserMessage("Third")); + + List messages = chatMemory.get(conversationId, 10); + assertThat(messages).hasSize(3); + assertThat(messages.get(0).getText()).isEqualTo("First"); + assertThat(messages.get(1).getText()).isEqualTo("Second"); + assertThat(messages.get(2).getText()).isEqualTo("Third"); + }); + } + + @Test + void shouldHandleMultipleConversations() { + this.contextRunner.run(context -> { + String conv1 = "conversation-1"; + String conv2 = "conversation-2"; + + chatMemory.add(conv1, new UserMessage("Conv1 Message")); + chatMemory.add(conv2, new UserMessage("Conv2 Message")); + + List conv1Messages = chatMemory.get(conv1, 10); + List conv2Messages = chatMemory.get(conv2, 10); + + assertThat(conv1Messages).hasSize(1); + assertThat(conv2Messages).hasSize(1); + assertThat(conv1Messages.get(0).getText()).isEqualTo("Conv1 Message"); + assertThat(conv2Messages.get(0).getText()).isEqualTo("Conv2 Message"); + }); + } + + @SpringBootConfiguration + @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) + static class TestApplication { + + @Bean + RedisChatMemoryRepository chatMemory() { + return RedisChatMemoryRepository.builder() + .jedisClient(new JedisPooled(redisContainer.getHost(), redisContainer.getFirstMappedPort())) + .indexName("test-" + RedisChatMemoryConfig.DEFAULT_INDEX_NAME) + .timeToLive(Duration.ofMinutes(5)) + .build(); + } + + } + +} \ No newline at end of file diff --git a/memory/repository/spring-ai-model-chat-memory-repository-redis/src/test/java/org/springframework/ai/chat/memory/repository/redis/RedisChatMemoryMediaIT.java b/memory/repository/spring-ai-model-chat-memory-repository-redis/src/test/java/org/springframework/ai/chat/memory/repository/redis/RedisChatMemoryMediaIT.java new file mode 100644 index 00000000000..2e18e982c28 --- /dev/null +++ b/memory/repository/spring-ai-model-chat-memory-repository-redis/src/test/java/org/springframework/ai/chat/memory/repository/redis/RedisChatMemoryMediaIT.java @@ -0,0 +1,685 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.chat.memory.repository.redis; + +import com.redis.testcontainers.RedisStackContainer; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.content.Media; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.autoconfigure.EnableAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.core.io.ByteArrayResource; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.springframework.util.MimeType; +import redis.clients.jedis.JedisPooled; + +import java.net.URI; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for RedisChatMemoryRepository to verify proper handling of Media + * content. + * + * @author Brian Sam-Bodden + */ +@Testcontainers +class RedisChatMemoryMediaIT { + + private static final Logger logger = LoggerFactory.getLogger(RedisChatMemoryMediaIT.class); + + @Container + static RedisStackContainer redisContainer = new RedisStackContainer( + RedisStackContainer.DEFAULT_IMAGE_NAME.withTag(RedisStackContainer.DEFAULT_TAG)) + .withExposedPorts(6379); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(TestApplication.class); + + private RedisChatMemoryRepository chatMemory; + + private JedisPooled jedisClient; + + @BeforeEach + void setUp() { + // Create JedisPooled directly with container properties for reliable connection + jedisClient = new JedisPooled(redisContainer.getHost(), redisContainer.getFirstMappedPort()); + chatMemory = RedisChatMemoryRepository.builder() + .jedisClient(jedisClient) + .indexName("test-media-" + RedisChatMemoryConfig.DEFAULT_INDEX_NAME) + .build(); + + // Clear any existing data + for (String conversationId : chatMemory.findConversationIds()) { + chatMemory.clear(conversationId); + } + } + + @AfterEach + void tearDown() { + if (jedisClient != null) { + jedisClient.close(); + } + } + + @Test + void shouldStoreAndRetrieveUserMessageWithUriMedia() { + this.contextRunner.run(context -> { + // Create a URI media object + URI mediaUri = URI.create("https://example.com/image.png"); + Media imageMedia = Media.builder() + .mimeType(Media.Format.IMAGE_PNG) + .data(mediaUri) + .id("test-image-id") + .name("test-image") + .build(); + + // Create a user message with the media + UserMessage userMessage = UserMessage.builder() + .text("Message with image") + .media(imageMedia) + .metadata(Map.of("test-key", "test-value")) + .build(); + + // Store the message + chatMemory.add("test-conversation", userMessage); + + // Retrieve the message + List messages = chatMemory.get("test-conversation", 10); + + assertThat(messages).hasSize(1); + assertThat(messages.get(0)).isInstanceOf(UserMessage.class); + + UserMessage retrievedMessage = (UserMessage) messages.get(0); + assertThat(retrievedMessage.getText()).isEqualTo("Message with image"); + assertThat(retrievedMessage.getMetadata()).containsEntry("test-key", "test-value"); + + // Verify media content + assertThat(retrievedMessage.getMedia()).hasSize(1); + Media retrievedMedia = retrievedMessage.getMedia().get(0); + assertThat(retrievedMedia.getMimeType()).isEqualTo(Media.Format.IMAGE_PNG); + assertThat(retrievedMedia.getId()).isEqualTo("test-image-id"); + assertThat(retrievedMedia.getName()).isEqualTo("test-image"); + assertThat(retrievedMedia.getData()).isEqualTo(mediaUri.toString()); + }); + } + + @Test + void shouldStoreAndRetrieveAssistantMessageWithByteArrayMedia() { + this.contextRunner.run(context -> { + // Create a byte array media object + byte[] imageData = new byte[] { 0x00, 0x01, 0x02, 0x03, 0x04 }; + Media byteArrayMedia = Media.builder() + .mimeType(Media.Format.IMAGE_JPEG) + .data(imageData) + .id("test-jpeg-id") + .name("test-jpeg") + .build(); + + // Create a list of tool calls + List toolCalls = List + .of(new AssistantMessage.ToolCall("tool1", "function", "testFunction", "{\"param\":\"value\"}")); + + // Create an assistant message with media and tool calls + AssistantMessage assistantMessage = AssistantMessage.builder() + .content("Response with image") + .properties(Map.of("assistant-key", "assistant-value")) + .toolCalls(toolCalls) + .media(List.of(byteArrayMedia)) + .build(); + + // Store the message + chatMemory.add("test-conversation", assistantMessage); + + // Retrieve the message + List messages = chatMemory.get("test-conversation", 10); + + assertThat(messages).hasSize(1); + assertThat(messages.get(0)).isInstanceOf(AssistantMessage.class); + + AssistantMessage retrievedMessage = (AssistantMessage) messages.get(0); + assertThat(retrievedMessage.getText()).isEqualTo("Response with image"); + assertThat(retrievedMessage.getMetadata()).containsEntry("assistant-key", "assistant-value"); + + // Verify tool calls + assertThat(retrievedMessage.getToolCalls()).hasSize(1); + AssistantMessage.ToolCall retrievedToolCall = retrievedMessage.getToolCalls().get(0); + assertThat(retrievedToolCall.id()).isEqualTo("tool1"); + assertThat(retrievedToolCall.type()).isEqualTo("function"); + assertThat(retrievedToolCall.name()).isEqualTo("testFunction"); + assertThat(retrievedToolCall.arguments()).isEqualTo("{\"param\":\"value\"}"); + + // Verify media content + assertThat(retrievedMessage.getMedia()).hasSize(1); + Media retrievedMedia = retrievedMessage.getMedia().get(0); + assertThat(retrievedMedia.getMimeType()).isEqualTo(Media.Format.IMAGE_JPEG); + assertThat(retrievedMedia.getId()).isEqualTo("test-jpeg-id"); + assertThat(retrievedMedia.getName()).isEqualTo("test-jpeg"); + assertThat(retrievedMedia.getDataAsByteArray()).isEqualTo(imageData); + }); + } + + @Test + void shouldStoreAndRetrieveMultipleMessagesWithDifferentMediaTypes() { + this.contextRunner.run(context -> { + // Create media objects with different types + Media pngMedia = Media.builder() + .mimeType(Media.Format.IMAGE_PNG) + .data(URI.create("https://example.com/image.png")) + .id("png-id") + .build(); + + Media jpegMedia = Media.builder() + .mimeType(Media.Format.IMAGE_JPEG) + .data(new byte[] { 0x10, 0x20, 0x30, 0x40 }) + .id("jpeg-id") + .build(); + + Media pdfMedia = Media.builder() + .mimeType(Media.Format.DOC_PDF) + .data(new ByteArrayResource("PDF content".getBytes())) + .id("pdf-id") + .build(); + + // Create messages + UserMessage userMessage1 = UserMessage.builder().text("Message with PNG").media(pngMedia).build(); + + AssistantMessage assistantMessage = AssistantMessage.builder() + .content("Response with JPEG") + .properties(Map.of()) + .toolCalls(List.of()) + .media(List.of(jpegMedia)) + .build(); + + UserMessage userMessage2 = UserMessage.builder().text("Message with PDF").media(pdfMedia).build(); + + // Store all messages + chatMemory.add("media-conversation", List.of(userMessage1, assistantMessage, userMessage2)); + + // Retrieve the messages + List messages = chatMemory.get("media-conversation", 10); + + assertThat(messages).hasSize(3); + + // Verify first user message with PNG + UserMessage retrievedUser1 = (UserMessage) messages.get(0); + assertThat(retrievedUser1.getText()).isEqualTo("Message with PNG"); + assertThat(retrievedUser1.getMedia()).hasSize(1); + assertThat(retrievedUser1.getMedia().get(0).getMimeType()).isEqualTo(Media.Format.IMAGE_PNG); + assertThat(retrievedUser1.getMedia().get(0).getId()).isEqualTo("png-id"); + assertThat(retrievedUser1.getMedia().get(0).getData()).isEqualTo("https://example.com/image.png"); + + // Verify assistant message with JPEG + AssistantMessage retrievedAssistant = (AssistantMessage) messages.get(1); + assertThat(retrievedAssistant.getText()).isEqualTo("Response with JPEG"); + assertThat(retrievedAssistant.getMedia()).hasSize(1); + assertThat(retrievedAssistant.getMedia().get(0).getMimeType()).isEqualTo(Media.Format.IMAGE_JPEG); + assertThat(retrievedAssistant.getMedia().get(0).getId()).isEqualTo("jpeg-id"); + assertThat(retrievedAssistant.getMedia().get(0).getDataAsByteArray()) + .isEqualTo(new byte[] { 0x10, 0x20, 0x30, 0x40 }); + + // Verify second user message with PDF + UserMessage retrievedUser2 = (UserMessage) messages.get(2); + assertThat(retrievedUser2.getText()).isEqualTo("Message with PDF"); + assertThat(retrievedUser2.getMedia()).hasSize(1); + assertThat(retrievedUser2.getMedia().get(0).getMimeType()).isEqualTo(Media.Format.DOC_PDF); + assertThat(retrievedUser2.getMedia().get(0).getId()).isEqualTo("pdf-id"); + // Data should be a byte array from the ByteArrayResource + assertThat(retrievedUser2.getMedia().get(0).getDataAsByteArray()).isEqualTo("PDF content".getBytes()); + }); + } + + @Test + void shouldStoreAndRetrieveMessageWithMultipleMedia() { + this.contextRunner.run(context -> { + // Create multiple media objects + Media textMedia = Media.builder() + .mimeType(Media.Format.DOC_TXT) + .data("This is text content".getBytes()) + .id("text-id") + .name("text-file") + .build(); + + Media imageMedia = Media.builder() + .mimeType(Media.Format.IMAGE_PNG) + .data(URI.create("https://example.com/image.png")) + .id("image-id") + .name("image-file") + .build(); + + // Create a message with multiple media attachments + UserMessage userMessage = UserMessage.builder() + .text("Message with multiple attachments") + .media(textMedia, imageMedia) + .build(); + + // Store the message + chatMemory.add("multi-media-conversation", userMessage); + + // Retrieve the message + List messages = chatMemory.get("multi-media-conversation", 10); + + assertThat(messages).hasSize(1); + UserMessage retrievedMessage = (UserMessage) messages.get(0); + assertThat(retrievedMessage.getText()).isEqualTo("Message with multiple attachments"); + + // Verify multiple media contents + List retrievedMedia = retrievedMessage.getMedia(); + assertThat(retrievedMedia).hasSize(2); + + // The media should be retrieved in the same order + Media retrievedTextMedia = retrievedMedia.get(0); + assertThat(retrievedTextMedia.getMimeType()).isEqualTo(Media.Format.DOC_TXT); + assertThat(retrievedTextMedia.getId()).isEqualTo("text-id"); + assertThat(retrievedTextMedia.getName()).isEqualTo("text-file"); + assertThat(retrievedTextMedia.getDataAsByteArray()).isEqualTo("This is text content".getBytes()); + + Media retrievedImageMedia = retrievedMedia.get(1); + assertThat(retrievedImageMedia.getMimeType()).isEqualTo(Media.Format.IMAGE_PNG); + assertThat(retrievedImageMedia.getId()).isEqualTo("image-id"); + assertThat(retrievedImageMedia.getName()).isEqualTo("image-file"); + assertThat(retrievedImageMedia.getData()).isEqualTo("https://example.com/image.png"); + }); + } + + @Test + void shouldClearConversationWithMedia() { + this.contextRunner.run(context -> { + // Create a message with media + Media imageMedia = Media.builder() + .mimeType(Media.Format.IMAGE_PNG) + .data(new byte[] { 0x01, 0x02, 0x03 }) + .id("test-clear-id") + .build(); + + UserMessage userMessage = UserMessage.builder().text("Message to be cleared").media(imageMedia).build(); + + // Store the message + String conversationId = "conversation-to-clear"; + chatMemory.add(conversationId, userMessage); + + // Verify it was stored + assertThat(chatMemory.get(conversationId, 10)).hasSize(1); + + // Clear the conversation + chatMemory.clear(conversationId); + + // Verify it was cleared + assertThat(chatMemory.get(conversationId, 10)).isEmpty(); + assertThat(chatMemory.findConversationIds()).doesNotContain(conversationId); + }); + } + + @Test + void shouldHandleLargeBinaryData() { + this.contextRunner.run(context -> { + // Create a larger binary payload (around 50KB) + byte[] largeImageData = new byte[50 * 1024]; + // Fill with a recognizable pattern for verification + for (int i = 0; i < largeImageData.length; i++) { + largeImageData[i] = (byte) (i % 256); + } + + // Create media with the large data + Media largeMedia = Media.builder() + .mimeType(Media.Format.IMAGE_PNG) + .data(largeImageData) + .id("large-image-id") + .name("large-image.png") + .build(); + + // Create a message with large media + UserMessage userMessage = UserMessage.builder() + .text("Message with large image attachment") + .media(largeMedia) + .build(); + + // Store the message + String conversationId = "large-media-conversation"; + chatMemory.add(conversationId, userMessage); + + // Retrieve the message + List messages = chatMemory.get(conversationId, 10); + + // Verify + assertThat(messages).hasSize(1); + UserMessage retrievedMessage = (UserMessage) messages.get(0); + assertThat(retrievedMessage.getMedia()).hasSize(1); + + // Verify the large binary data was preserved exactly + Media retrievedMedia = retrievedMessage.getMedia().get(0); + assertThat(retrievedMedia.getMimeType()).isEqualTo(Media.Format.IMAGE_PNG); + byte[] retrievedData = retrievedMedia.getDataAsByteArray(); + assertThat(retrievedData).hasSize(50 * 1024); + assertThat(retrievedData).isEqualTo(largeImageData); + }); + } + + @Test + void shouldHandleMediaWithEmptyOrNullValues() { + this.contextRunner.run(context -> { + // Create media with null or empty values where allowed + Media edgeCaseMedia1 = Media.builder() + .mimeType(Media.Format.IMAGE_PNG) // MimeType is required + .data(new byte[0]) // Empty byte array + .id(null) // No ID + .name("") // Empty name + .build(); + + // Second media with only required fields + Media edgeCaseMedia2 = Media.builder() + .mimeType(Media.Format.DOC_TXT) // Only required field + .data(new byte[0]) // Empty byte array instead of null + .build(); + + // Create message with these edge case media objects + UserMessage userMessage = UserMessage.builder() + .text("Edge case media test") + .media(edgeCaseMedia1, edgeCaseMedia2) + .build(); + + // Store the message + String conversationId = "edge-case-media"; + chatMemory.add(conversationId, userMessage); + + // Retrieve the message + List messages = chatMemory.get(conversationId, 10); + + // Verify the message was stored and retrieved + assertThat(messages).hasSize(1); + UserMessage retrievedMessage = (UserMessage) messages.get(0); + + // Verify the media objects + List retrievedMedia = retrievedMessage.getMedia(); + assertThat(retrievedMedia).hasSize(2); + + // Check first media with empty/null values + Media firstMedia = retrievedMedia.get(0); + assertThat(firstMedia.getMimeType()).isEqualTo(Media.Format.IMAGE_PNG); + assertThat(firstMedia.getDataAsByteArray()).isNotNull().isEmpty(); + assertThat(firstMedia.getId()).isNull(); + assertThat(firstMedia.getName()).isEmpty(); + + // Check second media with only required field + Media secondMedia = retrievedMedia.get(1); + assertThat(secondMedia.getMimeType()).isEqualTo(Media.Format.DOC_TXT); + assertThat(secondMedia.getDataAsByteArray()).isNotNull().isEmpty(); + assertThat(secondMedia.getId()).isNull(); + assertThat(secondMedia.getName()).isNotNull(); + }); + } + + @Test + void shouldHandleComplexBinaryDataTypes() { + this.contextRunner.run(context -> { + // Create audio sample data (simple WAV header + sine wave) + byte[] audioData = createSampleAudioData(8000, 2); // 2 seconds of 8kHz audio + + // Create video sample data (mock MP4 data with recognizable pattern) + byte[] videoData = createSampleVideoData(10 * 1024); // 10KB mock video data + + // Create custom MIME types for specialized formats + MimeType customAudioType = new MimeType("audio", "wav"); + MimeType customVideoType = new MimeType("video", "mp4"); + + // Create media objects with the complex binary data + Media audioMedia = Media.builder() + .mimeType(customAudioType) + .data(audioData) + .id("audio-sample-id") + .name("audio-sample.wav") + .build(); + + Media videoMedia = Media.builder() + .mimeType(customVideoType) + .data(videoData) + .id("video-sample-id") + .name("video-sample.mp4") + .build(); + + // Create messages with the complex media + UserMessage userMessage = UserMessage.builder() + .text("Message with audio attachment") + .media(audioMedia) + .build(); + + AssistantMessage assistantMessage = AssistantMessage.builder() + .content("Response with video attachment") + .properties(Map.of()) + .toolCalls(List.of()) + .media(List.of(videoMedia)) + .build(); + + // Store the messages + String conversationId = "complex-media-conversation"; + chatMemory.add(conversationId, List.of(userMessage, assistantMessage)); + + // Retrieve the messages + List messages = chatMemory.get(conversationId, 10); + + // Verify + assertThat(messages).hasSize(2); + + // Verify audio data in user message + UserMessage retrievedUserMessage = (UserMessage) messages.get(0); + assertThat(retrievedUserMessage.getText()).isEqualTo("Message with audio attachment"); + assertThat(retrievedUserMessage.getMedia()).hasSize(1); + + Media retrievedAudioMedia = retrievedUserMessage.getMedia().get(0); + assertThat(retrievedAudioMedia.getMimeType().toString()).isEqualTo(customAudioType.toString()); + assertThat(retrievedAudioMedia.getId()).isEqualTo("audio-sample-id"); + assertThat(retrievedAudioMedia.getName()).isEqualTo("audio-sample.wav"); + assertThat(retrievedAudioMedia.getDataAsByteArray()).isEqualTo(audioData); + + // Verify binary pattern data integrity + byte[] retrievedAudioData = retrievedAudioMedia.getDataAsByteArray(); + // Check RIFF header (first 4 bytes of WAV) + assertThat(Arrays.copyOfRange(retrievedAudioData, 0, 4)).isEqualTo(new byte[] { 'R', 'I', 'F', 'F' }); + + // Verify video data in assistant message + AssistantMessage retrievedAssistantMessage = (AssistantMessage) messages.get(1); + assertThat(retrievedAssistantMessage.getText()).isEqualTo("Response with video attachment"); + assertThat(retrievedAssistantMessage.getMedia()).hasSize(1); + + Media retrievedVideoMedia = retrievedAssistantMessage.getMedia().get(0); + assertThat(retrievedVideoMedia.getMimeType().toString()).isEqualTo(customVideoType.toString()); + assertThat(retrievedVideoMedia.getId()).isEqualTo("video-sample-id"); + assertThat(retrievedVideoMedia.getName()).isEqualTo("video-sample.mp4"); + assertThat(retrievedVideoMedia.getDataAsByteArray()).isEqualTo(videoData); + + // Verify the MP4 header pattern + byte[] retrievedVideoData = retrievedVideoMedia.getDataAsByteArray(); + // Check mock MP4 signature (first 4 bytes should be ftyp) + assertThat(Arrays.copyOfRange(retrievedVideoData, 4, 8)).isEqualTo(new byte[] { 'f', 't', 'y', 'p' }); + }); + } + + /** + * Creates a sample audio data byte array with WAV format. + * @param sampleRate Sample rate of the audio in Hz + * @param durationSeconds Duration of the audio in seconds + * @return Byte array containing a simple WAV file + */ + private byte[] createSampleAudioData(int sampleRate, int durationSeconds) { + // Calculate sizes + int headerSize = 44; // Standard WAV header size + int dataSize = sampleRate * durationSeconds; // 1 byte per sample, mono + int totalSize = headerSize + dataSize; + + byte[] audioData = new byte[totalSize]; + + // Write WAV header (RIFF chunk) + audioData[0] = 'R'; + audioData[1] = 'I'; + audioData[2] = 'F'; + audioData[3] = 'F'; + + // File size - 8 (4 bytes little endian) + int fileSizeMinus8 = totalSize - 8; + audioData[4] = (byte) (fileSizeMinus8 & 0xFF); + audioData[5] = (byte) ((fileSizeMinus8 >> 8) & 0xFF); + audioData[6] = (byte) ((fileSizeMinus8 >> 16) & 0xFF); + audioData[7] = (byte) ((fileSizeMinus8 >> 24) & 0xFF); + + // WAVE chunk + audioData[8] = 'W'; + audioData[9] = 'A'; + audioData[10] = 'V'; + audioData[11] = 'E'; + + // fmt chunk + audioData[12] = 'f'; + audioData[13] = 'm'; + audioData[14] = 't'; + audioData[15] = ' '; + + // fmt chunk size (16 for PCM) + audioData[16] = 16; + audioData[17] = 0; + audioData[18] = 0; + audioData[19] = 0; + + // Audio format (1 = PCM) + audioData[20] = 1; + audioData[21] = 0; + + // Channels (1 = mono) + audioData[22] = 1; + audioData[23] = 0; + + // Sample rate + audioData[24] = (byte) (sampleRate & 0xFF); + audioData[25] = (byte) ((sampleRate >> 8) & 0xFF); + audioData[26] = (byte) ((sampleRate >> 16) & 0xFF); + audioData[27] = (byte) ((sampleRate >> 24) & 0xFF); + + // Byte rate (SampleRate * NumChannels * BitsPerSample/8) + int byteRate = sampleRate * 1 * 8 / 8; + audioData[28] = (byte) (byteRate & 0xFF); + audioData[29] = (byte) ((byteRate >> 8) & 0xFF); + audioData[30] = (byte) ((byteRate >> 16) & 0xFF); + audioData[31] = (byte) ((byteRate >> 24) & 0xFF); + + // Block align (NumChannels * BitsPerSample/8) + audioData[32] = 1; + audioData[33] = 0; + + // Bits per sample + audioData[34] = 8; + audioData[35] = 0; + + // Data chunk + audioData[36] = 'd'; + audioData[37] = 'a'; + audioData[38] = 't'; + audioData[39] = 'a'; + + // Data size + audioData[40] = (byte) (dataSize & 0xFF); + audioData[41] = (byte) ((dataSize >> 8) & 0xFF); + audioData[42] = (byte) ((dataSize >> 16) & 0xFF); + audioData[43] = (byte) ((dataSize >> 24) & 0xFF); + + // Generate a simple sine wave for audio data + for (int i = 0; i < dataSize; i++) { + // Simple sine wave pattern (0-255) + audioData[headerSize + i] = (byte) (128 + 127 * Math.sin(2 * Math.PI * 440 * i / sampleRate)); + } + + return audioData; + } + + /** + * Creates sample video data with a mock MP4 structure. + * @param sizeBytes Size of the video data in bytes + * @return Byte array containing mock MP4 data + */ + private byte[] createSampleVideoData(int sizeBytes) { + byte[] videoData = new byte[sizeBytes]; + + // Write MP4 header + // First 4 bytes: size of the first atom + int firstAtomSize = 24; // Standard size for ftyp atom + videoData[0] = 0; + videoData[1] = 0; + videoData[2] = 0; + videoData[3] = (byte) firstAtomSize; + + // Next 4 bytes: ftyp (file type atom) + videoData[4] = 'f'; + videoData[5] = 't'; + videoData[6] = 'y'; + videoData[7] = 'p'; + + // Major brand (mp42) + videoData[8] = 'm'; + videoData[9] = 'p'; + videoData[10] = '4'; + videoData[11] = '2'; + + // Minor version + videoData[12] = 0; + videoData[13] = 0; + videoData[14] = 0; + videoData[15] = 1; + + // Compatible brands (mp42, mp41) + videoData[16] = 'm'; + videoData[17] = 'p'; + videoData[18] = '4'; + videoData[19] = '2'; + videoData[20] = 'm'; + videoData[21] = 'p'; + videoData[22] = '4'; + videoData[23] = '1'; + + // Fill the rest with a recognizable pattern + for (int i = firstAtomSize; i < sizeBytes; i++) { + // Create a repeating pattern with some variation + videoData[i] = (byte) ((i % 64) + ((i / 64) % 64)); + } + + return videoData; + } + + @SpringBootConfiguration + @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) + static class TestApplication { + + @Bean + RedisChatMemoryRepository chatMemory() { + return RedisChatMemoryRepository.builder() + .jedisClient(new JedisPooled(redisContainer.getHost(), redisContainer.getFirstMappedPort())) + .indexName("test-media-" + RedisChatMemoryConfig.DEFAULT_INDEX_NAME) + .build(); + } + + } + +} \ No newline at end of file diff --git a/memory/repository/spring-ai-model-chat-memory-repository-redis/src/test/java/org/springframework/ai/chat/memory/repository/redis/RedisChatMemoryMessageTypesIT.java b/memory/repository/spring-ai-model-chat-memory-repository-redis/src/test/java/org/springframework/ai/chat/memory/repository/redis/RedisChatMemoryMessageTypesIT.java new file mode 100644 index 00000000000..5f5df123e34 --- /dev/null +++ b/memory/repository/spring-ai-model-chat-memory-repository-redis/src/test/java/org/springframework/ai/chat/memory/repository/redis/RedisChatMemoryMessageTypesIT.java @@ -0,0 +1,674 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.chat.memory.repository.redis; + +import com.redis.testcontainers.RedisContainer; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.autoconfigure.EnableAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import redis.clients.jedis.JedisPooled; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for RedisChatMemoryRepository focusing on different message types. + * + * @author Brian Sam-Bodden + */ +@Testcontainers +class RedisChatMemoryMessageTypesIT { + + @Container + static RedisContainer redisContainer = new RedisContainer("redis/redis-stack:latest"); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(TestApplication.class); + + private RedisChatMemoryRepository chatMemory; + + private JedisPooled jedisClient; + + @BeforeEach + void setUp() { + jedisClient = new JedisPooled(redisContainer.getHost(), redisContainer.getFirstMappedPort()); + chatMemory = RedisChatMemoryRepository.builder() + .jedisClient(jedisClient) + .indexName("test-" + RedisChatMemoryConfig.DEFAULT_INDEX_NAME) + .build(); + + chatMemory.clear("test-conversation"); + } + + @AfterEach + void tearDown() { + if (jedisClient != null) { + jedisClient.close(); + } + } + + @Test + void shouldHandleAllMessageTypes() { + this.contextRunner.run(context -> { + String conversationId = "test-conversation"; + + // Create messages of different types with various content + SystemMessage systemMessage = new SystemMessage("You are a helpful assistant"); + UserMessage userMessage = new UserMessage("What's the capital of France?"); + AssistantMessage assistantMessage = new AssistantMessage("The capital of France is Paris."); + + // Store each message type + chatMemory.add(conversationId, systemMessage); + chatMemory.add(conversationId, userMessage); + chatMemory.add(conversationId, assistantMessage); + + // Retrieve and verify messages + List messages = chatMemory.get(conversationId, 10); + + // Verify correct number of messages + assertThat(messages).hasSize(3); + + // Verify message order and content + assertThat(messages.get(0).getText()).isEqualTo("You are a helpful assistant"); + assertThat(messages.get(1).getText()).isEqualTo("What's the capital of France?"); + assertThat(messages.get(2).getText()).isEqualTo("The capital of France is Paris."); + + // Verify message types + assertThat(messages.get(0)).isInstanceOf(SystemMessage.class); + assertThat(messages.get(1)).isInstanceOf(UserMessage.class); + assertThat(messages.get(2)).isInstanceOf(AssistantMessage.class); + }); + } + + @ParameterizedTest + @CsvSource({ "Message from assistant,ASSISTANT", "Message from user,USER", "Message from system,SYSTEM" }) + void shouldStoreAndRetrieveSingleMessage(String content, MessageType messageType) { + this.contextRunner.run(context -> { + String conversationId = UUID.randomUUID().toString(); + + // Create a message of the specified type + Message message = switch (messageType) { + case ASSISTANT -> new AssistantMessage(content + " - " + conversationId); + case USER -> new UserMessage(content + " - " + conversationId); + case SYSTEM -> new SystemMessage(content + " - " + conversationId); + default -> throw new IllegalArgumentException("Type not supported: " + messageType); + }; + + // Store the message + chatMemory.add(conversationId, message); + + // Retrieve messages + List messages = chatMemory.get(conversationId, 10); + + // Verify message was stored and retrieved correctly + assertThat(messages).hasSize(1); + Message retrievedMessage = messages.get(0); + + // Verify the message type + assertThat(retrievedMessage.getMessageType()).isEqualTo(messageType); + + // Verify the content + assertThat(retrievedMessage.getText()).isEqualTo(content + " - " + conversationId); + + // Verify the correct class type + switch (messageType) { + case ASSISTANT -> assertThat(retrievedMessage).isInstanceOf(AssistantMessage.class); + case USER -> assertThat(retrievedMessage).isInstanceOf(UserMessage.class); + case SYSTEM -> assertThat(retrievedMessage).isInstanceOf(SystemMessage.class); + default -> throw new IllegalArgumentException("Type not supported: " + messageType); + } + }); + } + + @Test + void shouldHandleSystemMessageWithMetadata() { + this.contextRunner.run(context -> { + String conversationId = "test-conversation-system"; + + // Create a System message with metadata using builder + SystemMessage systemMessage = SystemMessage.builder() + .text("You are a specialized AI assistant for legal questions") + .metadata(Map.of("domain", "legal", "version", "2.0", "restricted", "true")) + .build(); + + // Store the message + chatMemory.add(conversationId, systemMessage); + + // Retrieve messages + List messages = chatMemory.get(conversationId, 10); + + // Verify message count + assertThat(messages).hasSize(1); + assertThat(messages.get(0)).isInstanceOf(SystemMessage.class); + + // Verify content + SystemMessage retrievedMessage = (SystemMessage) messages.get(0); + assertThat(retrievedMessage.getText()).isEqualTo("You are a specialized AI assistant for legal questions"); + + // Verify metadata is preserved + assertThat(retrievedMessage.getMetadata()).containsEntry("domain", "legal"); + assertThat(retrievedMessage.getMetadata()).containsEntry("version", "2.0"); + assertThat(retrievedMessage.getMetadata()).containsEntry("restricted", "true"); + }); + } + + @Test + void shouldHandleMultipleSystemMessages() { + this.contextRunner.run(context -> { + String conversationId = "multi-system-test"; + + // Create multiple system messages with different content + SystemMessage systemMessage1 = new SystemMessage("You are a helpful assistant"); + SystemMessage systemMessage2 = new SystemMessage("Always provide concise answers"); + SystemMessage systemMessage3 = new SystemMessage("Do not share personal information"); + + // Create a batch of system messages + List systemMessages = List.of(systemMessage1, systemMessage2, systemMessage3); + + // Store all messages at once + chatMemory.add(conversationId, systemMessages); + + // Retrieve messages + List retrievedMessages = chatMemory.get(conversationId, 10); + + // Verify all messages were stored and retrieved + assertThat(retrievedMessages).hasSize(3); + retrievedMessages.forEach(message -> assertThat(message).isInstanceOf(SystemMessage.class)); + + // Verify content + assertThat(retrievedMessages.get(0).getText()).isEqualTo(systemMessage1.getText()); + assertThat(retrievedMessages.get(1).getText()).isEqualTo(systemMessage2.getText()); + assertThat(retrievedMessages.get(2).getText()).isEqualTo(systemMessage3.getText()); + }); + } + + @Test + void shouldHandleMessageWithMetadata() { + this.contextRunner.run(context -> { + String conversationId = "test-conversation"; + + // Create messages with metadata using builder + UserMessage userMessage = UserMessage.builder() + .text("Hello with metadata") + .metadata(Map.of("source", "web", "user_id", "12345")) + .build(); + + AssistantMessage assistantMessage = AssistantMessage.builder() + .content("Hi there!") + .properties(Map.of("model", "gpt-4", "temperature", "0.7")) + .build(); + + // Store messages with metadata + chatMemory.add(conversationId, userMessage); + chatMemory.add(conversationId, assistantMessage); + + // Retrieve messages + List messages = chatMemory.get(conversationId, 10); + + // Verify message count + assertThat(messages).hasSize(2); + + // Verify metadata is preserved + assertThat(messages.get(0).getMetadata()).containsEntry("source", "web"); + assertThat(messages.get(0).getMetadata()).containsEntry("user_id", "12345"); + assertThat(messages.get(1).getMetadata()).containsEntry("model", "gpt-4"); + assertThat(messages.get(1).getMetadata()).containsEntry("temperature", "0.7"); + }); + } + + @ParameterizedTest + @CsvSource({ "ASSISTANT,model=gpt-4;temperature=0.7;api_version=1.0", "USER,source=web;user_id=12345;client=mobile", + "SYSTEM,domain=legal;version=2.0;restricted=true" }) + void shouldStoreAndRetrieveMessageWithMetadata(MessageType messageType, String metadataString) { + this.contextRunner.run(context -> { + String conversationId = UUID.randomUUID().toString(); + String content = "Message with metadata - " + messageType; + + // Parse metadata from string + Map metadata = parseMetadata(metadataString); + + // Create a message with metadata + Message message = switch (messageType) { + case ASSISTANT -> AssistantMessage.builder().content(content).properties(metadata).build(); + case USER -> UserMessage.builder().text(content).metadata(metadata).build(); + case SYSTEM -> SystemMessage.builder().text(content).metadata(metadata).build(); + default -> throw new IllegalArgumentException("Type not supported: " + messageType); + }; + + // Store the message + chatMemory.add(conversationId, message); + + // Retrieve the message + List messages = chatMemory.get(conversationId, 10); + + // Verify message was stored correctly + assertThat(messages).hasSize(1); + Message retrievedMessage = messages.get(0); + + // Verify message type + assertThat(retrievedMessage.getMessageType()).isEqualTo(messageType); + + // Verify all metadata entries are present + metadata.forEach((key, value) -> assertThat(retrievedMessage.getMetadata()).containsEntry(key, value)); + }); + } + + // Helper method to parse metadata from string in format + // "key1=value1;key2=value2;key3=value3" + private Map parseMetadata(String metadataString) { + Map metadata = new HashMap<>(); + String[] pairs = metadataString.split(";"); + + for (String pair : pairs) { + String[] keyValue = pair.split("="); + if (keyValue.length == 2) { + metadata.put(keyValue[0], keyValue[1]); + } + } + + return metadata; + } + + @Test + void shouldHandleAssistantMessageWithToolCalls() { + this.contextRunner.run(context -> { + String conversationId = "test-conversation"; + + // Create an AssistantMessage with tool calls + List toolCalls = Arrays.asList( + new AssistantMessage.ToolCall("tool-1", "function", "weather", "{\"location\": \"Paris\"}"), + new AssistantMessage.ToolCall("tool-2", "function", "calculator", + "{\"operation\": \"add\", \"args\": [1, 2]}")); + + AssistantMessage assistantMessage = AssistantMessage.builder() + .content("I'll check that for you.") + .properties(Map.of("model", "gpt-4")) + .toolCalls(toolCalls) + .media(List.of()) + .build(); + + // Store message with tool calls + chatMemory.add(conversationId, assistantMessage); + + // Retrieve the message + List messages = chatMemory.get(conversationId, 10); + + // Verify we get back the same type of message + assertThat(messages).hasSize(1); + assertThat(messages.get(0)).isInstanceOf(AssistantMessage.class); + + // Cast and verify tool calls + AssistantMessage retrievedMessage = (AssistantMessage) messages.get(0); + assertThat(retrievedMessage.getToolCalls()).hasSize(2); + + // Verify tool call content + AssistantMessage.ToolCall firstToolCall = retrievedMessage.getToolCalls().get(0); + assertThat(firstToolCall.name()).isEqualTo("weather"); + assertThat(firstToolCall.arguments()).isEqualTo("{\"location\": \"Paris\"}"); + + AssistantMessage.ToolCall secondToolCall = retrievedMessage.getToolCalls().get(1); + assertThat(secondToolCall.name()).isEqualTo("calculator"); + assertThat(secondToolCall.arguments()).contains("\"operation\": \"add\""); + }); + } + + @Test + void shouldHandleBasicToolResponseMessage() { + this.contextRunner.run(context -> { + String conversationId = "tool-response-conversation"; + + // Create a simple ToolResponseMessage with a single tool response + ToolResponseMessage.ToolResponse weatherResponse = new ToolResponseMessage.ToolResponse("tool-1", "weather", + "{\"location\":\"Paris\",\"temperature\":\"22°C\",\"conditions\":\"Partly Cloudy\"}"); + + // Create the message with a single tool response + ToolResponseMessage toolResponseMessage = ToolResponseMessage.builder() + .responses(List.of(weatherResponse)) + .build(); + + // Store the message + chatMemory.add(conversationId, toolResponseMessage); + + // Retrieve the message + List messages = chatMemory.get(conversationId, 10); + + // Verify we get back the correct message + assertThat(messages).hasSize(1); + assertThat(messages.get(0)).isInstanceOf(ToolResponseMessage.class); + assertThat(messages.get(0).getMessageType()).isEqualTo(MessageType.TOOL); + + // Cast and verify tool responses + ToolResponseMessage retrievedMessage = (ToolResponseMessage) messages.get(0); + List toolResponses = retrievedMessage.getResponses(); + + // Verify tool response content + assertThat(toolResponses).hasSize(1); + ToolResponseMessage.ToolResponse response = toolResponses.get(0); + assertThat(response.id()).isEqualTo("tool-1"); + assertThat(response.name()).isEqualTo("weather"); + assertThat(response.responseData()).contains("Paris"); + assertThat(response.responseData()).contains("22°C"); + }); + } + + @Test + void shouldHandleToolResponseMessageWithMultipleResponses() { + this.contextRunner.run(context -> { + String conversationId = "multi-tool-response-conversation"; + + // Create multiple tool responses + ToolResponseMessage.ToolResponse weatherResponse = new ToolResponseMessage.ToolResponse("tool-1", "weather", + "{\"location\":\"Paris\",\"temperature\":\"22°C\",\"conditions\":\"Partly Cloudy\"}"); + + ToolResponseMessage.ToolResponse calculatorResponse = new ToolResponseMessage.ToolResponse("tool-2", + "calculator", "{\"operation\":\"add\",\"args\":[1,2],\"result\":3}"); + + ToolResponseMessage.ToolResponse databaseResponse = new ToolResponseMessage.ToolResponse("tool-3", + "database", "{\"query\":\"SELECT * FROM users\",\"count\":42}"); + + // Create the message with multiple tool responses and metadata + ToolResponseMessage toolResponseMessage = ToolResponseMessage.builder() + .responses(List.of(weatherResponse, calculatorResponse, databaseResponse)) + .metadata(Map.of("source", "tools-api", "version", "1.0")) + .build(); + + // Store the message + chatMemory.add(conversationId, toolResponseMessage); + + // Retrieve the message + List messages = chatMemory.get(conversationId, 10); + + // Verify message type and count + assertThat(messages).hasSize(1); + assertThat(messages.get(0)).isInstanceOf(ToolResponseMessage.class); + + // Cast and verify + ToolResponseMessage retrievedMessage = (ToolResponseMessage) messages.get(0); + + // Verify metadata + assertThat(retrievedMessage.getMetadata()).containsEntry("source", "tools-api"); + assertThat(retrievedMessage.getMetadata()).containsEntry("version", "1.0"); + + // Verify tool responses + List toolResponses = retrievedMessage.getResponses(); + assertThat(toolResponses).hasSize(3); + + // Verify first response (weather) + ToolResponseMessage.ToolResponse response1 = toolResponses.get(0); + assertThat(response1.id()).isEqualTo("tool-1"); + assertThat(response1.name()).isEqualTo("weather"); + assertThat(response1.responseData()).contains("Paris"); + + // Verify second response (calculator) + ToolResponseMessage.ToolResponse response2 = toolResponses.get(1); + assertThat(response2.id()).isEqualTo("tool-2"); + assertThat(response2.name()).isEqualTo("calculator"); + assertThat(response2.responseData()).contains("result"); + + // Verify third response (database) + ToolResponseMessage.ToolResponse response3 = toolResponses.get(2); + assertThat(response3.id()).isEqualTo("tool-3"); + assertThat(response3.name()).isEqualTo("database"); + assertThat(response3.responseData()).contains("count"); + }); + } + + @Test + void shouldHandleToolResponseInConversationFlow() { + this.contextRunner.run(context -> { + String conversationId = "tool-conversation-flow"; + + // Create a typical conversation flow with tool responses + UserMessage userMessage = new UserMessage("What's the weather in Paris?"); + + // Assistant requests weather information via tool + List toolCalls = List + .of(new AssistantMessage.ToolCall("weather-req-1", "function", "weather", "{\"location\":\"Paris\"}")); + AssistantMessage assistantMessage = AssistantMessage.builder() + .content("I'll check the weather for you.") + .properties(Map.of()) + .toolCalls(toolCalls) + .media(List.of()) + .build(); + + // Tool provides weather information + ToolResponseMessage.ToolResponse weatherResponse = new ToolResponseMessage.ToolResponse("weather-req-1", + "weather", "{\"location\":\"Paris\",\"temperature\":\"22°C\",\"conditions\":\"Partly Cloudy\"}"); + ToolResponseMessage toolResponseMessage = ToolResponseMessage.builder() + .responses(List.of(weatherResponse)) + .build(); + + // Assistant summarizes the information + AssistantMessage finalResponse = new AssistantMessage( + "The current weather in Paris is 22°C and partly cloudy."); + + // Store the conversation + List conversation = List.of(userMessage, assistantMessage, toolResponseMessage, finalResponse); + chatMemory.add(conversationId, conversation); + + // Retrieve the conversation + List messages = chatMemory.get(conversationId, 10); + + // Verify the conversation flow + assertThat(messages).hasSize(4); + assertThat(messages.get(0)).isInstanceOf(UserMessage.class); + assertThat(messages.get(1)).isInstanceOf(AssistantMessage.class); + assertThat(messages.get(2)).isInstanceOf(ToolResponseMessage.class); + assertThat(messages.get(3)).isInstanceOf(AssistantMessage.class); + + // Verify the tool response + ToolResponseMessage retrievedToolResponse = (ToolResponseMessage) messages.get(2); + assertThat(retrievedToolResponse.getResponses()).hasSize(1); + assertThat(retrievedToolResponse.getResponses().get(0).name()).isEqualTo("weather"); + assertThat(retrievedToolResponse.getResponses().get(0).responseData()).contains("Paris"); + + // Verify the final response includes information from the tool + AssistantMessage retrievedFinalResponse = (AssistantMessage) messages.get(3); + assertThat(retrievedFinalResponse.getText()).contains("22°C"); + assertThat(retrievedFinalResponse.getText()).contains("partly cloudy"); + }); + } + + @Test + void getMessages_withAllMessageTypes_shouldPreserveMessageOrder() { + this.contextRunner.run(context -> { + String conversationId = "complex-order-test"; + + // Create a complex conversation with all message types in a specific order + SystemMessage systemMessage = new SystemMessage("You are a helpful AI assistant."); + UserMessage userMessage1 = new UserMessage("What's the capital of France?"); + AssistantMessage assistantMessage1 = new AssistantMessage("The capital of France is Paris."); + UserMessage userMessage2 = new UserMessage("What's the weather there?"); + + // Assistant using tool to check weather + List toolCalls = List + .of(new AssistantMessage.ToolCall("weather-tool-1", "function", "weather", "{\"location\":\"Paris\"}")); + AssistantMessage assistantToolCall = AssistantMessage.builder() + .content("I'll check the weather in Paris for you.") + .properties(Map.of()) + .toolCalls(toolCalls) + .media(List.of()) + .build(); + + // Tool response + ToolResponseMessage.ToolResponse weatherResponse = new ToolResponseMessage.ToolResponse("weather-tool-1", + "weather", "{\"location\":\"Paris\",\"temperature\":\"24°C\",\"conditions\":\"Sunny\"}"); + ToolResponseMessage toolResponseMessage = ToolResponseMessage.builder() + .responses(List.of(weatherResponse)) + .build(); + + // Final assistant response using the tool information + AssistantMessage assistantFinal = new AssistantMessage("The weather in Paris is currently 24°C and sunny."); + + // Create ordered list of messages + List expectedMessages = List.of(systemMessage, userMessage1, assistantMessage1, userMessage2, + assistantToolCall, toolResponseMessage, assistantFinal); + + // Add each message individually with small delays + for (Message message : expectedMessages) { + chatMemory.add(conversationId, message); + Thread.sleep(10); // Small delay to ensure distinct timestamps + } + + // Retrieve and verify messages + List retrievedMessages = chatMemory.get(conversationId, 10); + + // Check the total count matches + assertThat(retrievedMessages).hasSize(expectedMessages.size()); + + // Check each message is in the expected order + for (int i = 0; i < expectedMessages.size(); i++) { + Message expected = expectedMessages.get(i); + Message actual = retrievedMessages.get(i); + + // Verify message types match + assertThat(actual.getMessageType()).isEqualTo(expected.getMessageType()); + + // Verify message content matches + assertThat(actual.getText()).isEqualTo(expected.getText()); + + // For each specific message type, verify type-specific properties + if (expected instanceof SystemMessage) { + assertThat(actual).isInstanceOf(SystemMessage.class); + } + else if (expected instanceof UserMessage) { + assertThat(actual).isInstanceOf(UserMessage.class); + } + else if (expected instanceof AssistantMessage) { + assertThat(actual).isInstanceOf(AssistantMessage.class); + + // If the original had tool calls, verify they're preserved + if (((AssistantMessage) expected).hasToolCalls()) { + AssistantMessage expectedAssistant = (AssistantMessage) expected; + AssistantMessage actualAssistant = (AssistantMessage) actual; + + assertThat(actualAssistant.hasToolCalls()).isTrue(); + assertThat(actualAssistant.getToolCalls()).hasSameSizeAs(expectedAssistant.getToolCalls()); + + // Check first tool call details + assertThat(actualAssistant.getToolCalls().get(0).name()) + .isEqualTo(expectedAssistant.getToolCalls().get(0).name()); + } + } + else if (expected instanceof ToolResponseMessage) { + assertThat(actual).isInstanceOf(ToolResponseMessage.class); + + ToolResponseMessage expectedTool = (ToolResponseMessage) expected; + ToolResponseMessage actualTool = (ToolResponseMessage) actual; + + assertThat(actualTool.getResponses()).hasSameSizeAs(expectedTool.getResponses()); + + // Check response details + assertThat(actualTool.getResponses().get(0).name()) + .isEqualTo(expectedTool.getResponses().get(0).name()); + assertThat(actualTool.getResponses().get(0).id()) + .isEqualTo(expectedTool.getResponses().get(0).id()); + } + } + }); + } + + @Test + void getMessages_afterMultipleAdds_shouldReturnMessagesInCorrectOrder() { + this.contextRunner.run(context -> { + String conversationId = "sequential-adds-test"; + + // Create messages that will be added individually + UserMessage userMessage1 = new UserMessage("First user message"); + AssistantMessage assistantMessage1 = new AssistantMessage("First assistant response"); + UserMessage userMessage2 = new UserMessage("Second user message"); + AssistantMessage assistantMessage2 = new AssistantMessage("Second assistant response"); + UserMessage userMessage3 = new UserMessage("Third user message"); + AssistantMessage assistantMessage3 = new AssistantMessage("Third assistant response"); + + // Add messages one at a time with delays to simulate real conversation + chatMemory.add(conversationId, userMessage1); + Thread.sleep(50); + chatMemory.add(conversationId, assistantMessage1); + Thread.sleep(50); + chatMemory.add(conversationId, userMessage2); + Thread.sleep(50); + chatMemory.add(conversationId, assistantMessage2); + Thread.sleep(50); + chatMemory.add(conversationId, userMessage3); + Thread.sleep(50); + chatMemory.add(conversationId, assistantMessage3); + + // Create the expected message order + List expectedMessages = List.of(userMessage1, assistantMessage1, userMessage2, assistantMessage2, + userMessage3, assistantMessage3); + + // Retrieve all messages + List retrievedMessages = chatMemory.get(conversationId, 10); + + // Check count matches + assertThat(retrievedMessages).hasSize(expectedMessages.size()); + + // Verify each message is in the correct order with correct content + for (int i = 0; i < expectedMessages.size(); i++) { + Message expected = expectedMessages.get(i); + Message actual = retrievedMessages.get(i); + + assertThat(actual.getMessageType()).isEqualTo(expected.getMessageType()); + assertThat(actual.getText()).isEqualTo(expected.getText()); + } + + // Test with a limit + List limitedMessages = chatMemory.get(conversationId, 3); + + // Should get the 3 oldest messages + assertThat(limitedMessages).hasSize(3); + assertThat(limitedMessages.get(0).getText()).isEqualTo(userMessage1.getText()); + assertThat(limitedMessages.get(1).getText()).isEqualTo(assistantMessage1.getText()); + assertThat(limitedMessages.get(2).getText()).isEqualTo(userMessage2.getText()); + }); + } + + @SpringBootConfiguration + @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) + static class TestApplication { + + @Bean + RedisChatMemoryRepository chatMemory() { + return RedisChatMemoryRepository.builder() + .jedisClient(new JedisPooled(redisContainer.getHost(), redisContainer.getFirstMappedPort())) + .indexName("test-" + RedisChatMemoryConfig.DEFAULT_INDEX_NAME) + .build(); + } + + } + +} \ No newline at end of file diff --git a/memory/repository/spring-ai-model-chat-memory-repository-redis/src/test/java/org/springframework/ai/chat/memory/repository/redis/RedisChatMemoryRepositoryIT.java b/memory/repository/spring-ai-model-chat-memory-repository-redis/src/test/java/org/springframework/ai/chat/memory/repository/redis/RedisChatMemoryRepositoryIT.java new file mode 100644 index 00000000000..effc9ebd019 --- /dev/null +++ b/memory/repository/spring-ai-model-chat-memory-repository-redis/src/test/java/org/springframework/ai/chat/memory/repository/redis/RedisChatMemoryRepositoryIT.java @@ -0,0 +1,197 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.chat.memory.repository.redis; + +import com.redis.testcontainers.RedisContainer; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.memory.ChatMemoryRepository; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.autoconfigure.EnableAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import redis.clients.jedis.JedisPooled; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for RedisChatMemoryRepository implementation of ChatMemoryRepository + * interface. + * + * @author Brian Sam-Bodden + */ +@Testcontainers +class RedisChatMemoryRepositoryIT { + + private static final Logger logger = LoggerFactory.getLogger(RedisChatMemoryRepositoryIT.class); + + @Container + static RedisContainer redisContainer = new RedisContainer("redis/redis-stack:latest"); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(TestApplication.class); + + private ChatMemoryRepository chatMemoryRepository; + + private JedisPooled jedisClient; + + @BeforeEach + void setUp() { + // Create JedisPooled directly with container properties for more reliable + // connection + jedisClient = new JedisPooled(redisContainer.getHost(), redisContainer.getFirstMappedPort()); + RedisChatMemoryRepository chatMemory = RedisChatMemoryRepository.builder() + .jedisClient(jedisClient) + .indexName("test-" + RedisChatMemoryConfig.DEFAULT_INDEX_NAME) + .build(); + + chatMemoryRepository = chatMemory; + + // Clear any existing data + for (String conversationId : chatMemoryRepository.findConversationIds()) { + chatMemoryRepository.deleteByConversationId(conversationId); + } + } + + @AfterEach + void tearDown() { + if (jedisClient != null) { + jedisClient.close(); + } + } + + @Test + void shouldFindAllConversationIds() { + this.contextRunner.run(context -> { + // Add messages for multiple conversations + chatMemoryRepository.saveAll("conversation-1", List.of(new UserMessage("Hello from conversation 1"), + new AssistantMessage("Hi there from conversation 1"))); + + chatMemoryRepository.saveAll("conversation-2", List.of(new UserMessage("Hello from conversation 2"), + new AssistantMessage("Hi there from conversation 2"))); + + // Verify we can get all conversation IDs + List conversationIds = chatMemoryRepository.findConversationIds(); + assertThat(conversationIds).hasSize(2); + assertThat(conversationIds).containsExactlyInAnyOrder("conversation-1", "conversation-2"); + }); + } + + @Test + void shouldEfficientlyFindAllConversationIdsWithAggregation() { + this.contextRunner.run(context -> { + // Add a large number of messages across fewer conversations to verify + // deduplication + for (int i = 0; i < 10; i++) { + chatMemoryRepository.saveAll("conversation-A", List.of(new UserMessage("Message " + i + " in A"))); + chatMemoryRepository.saveAll("conversation-B", List.of(new UserMessage("Message " + i + " in B"))); + chatMemoryRepository.saveAll("conversation-C", List.of(new UserMessage("Message " + i + " in C"))); + } + + List conversationIds = chatMemoryRepository.findConversationIds(); + + // Verify correctness + assertThat(conversationIds).hasSize(3); + assertThat(conversationIds).containsExactlyInAnyOrder("conversation-A", "conversation-B", "conversation-C"); + }); + } + + @Test + void shouldFindMessagesByConversationId() { + this.contextRunner.run(context -> { + // Add messages for a conversation + List messages = List.of(new UserMessage("Hello"), new AssistantMessage("Hi there!"), + new UserMessage("How are you?")); + chatMemoryRepository.saveAll("test-conversation", messages); + + // Verify we can retrieve messages by conversation ID + List retrievedMessages = chatMemoryRepository.findByConversationId("test-conversation"); + assertThat(retrievedMessages).hasSize(3); + assertThat(retrievedMessages.get(0).getText()).isEqualTo("Hello"); + assertThat(retrievedMessages.get(1).getText()).isEqualTo("Hi there!"); + assertThat(retrievedMessages.get(2).getText()).isEqualTo("How are you?"); + }); + } + + @Test + void shouldSaveAllMessagesForConversation() { + this.contextRunner.run(context -> { + // Add some initial messages + chatMemoryRepository.saveAll("test-conversation", List.of(new UserMessage("Initial message"))); + + // Verify initial state + List initialMessages = chatMemoryRepository.findByConversationId("test-conversation"); + assertThat(initialMessages).hasSize(1); + + // Save all with new messages (should replace existing ones) + List newMessages = List.of(new UserMessage("New message 1"), new AssistantMessage("New message 2"), + new UserMessage("New message 3")); + chatMemoryRepository.saveAll("test-conversation", newMessages); + + // Verify new state + List latestMessages = chatMemoryRepository.findByConversationId("test-conversation"); + assertThat(latestMessages).hasSize(3); + assertThat(latestMessages.get(0).getText()).isEqualTo("New message 1"); + assertThat(latestMessages.get(1).getText()).isEqualTo("New message 2"); + assertThat(latestMessages.get(2).getText()).isEqualTo("New message 3"); + }); + } + + @Test + void shouldDeleteConversation() { + this.contextRunner.run(context -> { + // Add messages for a conversation + chatMemoryRepository.saveAll("test-conversation", + List.of(new UserMessage("Hello"), new AssistantMessage("Hi there!"))); + + // Verify initial state + assertThat(chatMemoryRepository.findByConversationId("test-conversation")).hasSize(2); + + // Delete the conversation + chatMemoryRepository.deleteByConversationId("test-conversation"); + + // Verify conversation is gone + assertThat(chatMemoryRepository.findByConversationId("test-conversation")).isEmpty(); + assertThat(chatMemoryRepository.findConversationIds()).doesNotContain("test-conversation"); + }); + } + + @SpringBootConfiguration + @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) + static class TestApplication { + + @Bean + ChatMemoryRepository chatMemoryRepository() { + return RedisChatMemoryRepository.builder() + .jedisClient(new JedisPooled(redisContainer.getHost(), redisContainer.getFirstMappedPort())) + .indexName("test-" + RedisChatMemoryConfig.DEFAULT_INDEX_NAME) + .build(); + } + + } + +} \ No newline at end of file diff --git a/memory/repository/spring-ai-model-chat-memory-repository-redis/src/test/java/org/springframework/ai/chat/memory/repository/redis/RedisChatMemoryWithSchemaIT.java b/memory/repository/spring-ai-model-chat-memory-repository-redis/src/test/java/org/springframework/ai/chat/memory/repository/redis/RedisChatMemoryWithSchemaIT.java new file mode 100644 index 00000000000..2c67efe7b07 --- /dev/null +++ b/memory/repository/spring-ai-model-chat-memory-repository-redis/src/test/java/org/springframework/ai/chat/memory/repository/redis/RedisChatMemoryWithSchemaIT.java @@ -0,0 +1,207 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.chat.memory.repository.redis; + +import com.redis.testcontainers.RedisContainer; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.memory.repository.redis.AdvancedChatMemoryRepository; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.autoconfigure.EnableAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import redis.clients.jedis.JedisPooled; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for RedisChatMemoryRepository with user-defined metadata schema. + * Demonstrates how to properly index metadata fields with appropriate types. + * + * @author Brian Sam-Bodden + */ +@Testcontainers +class RedisChatMemoryWithSchemaIT { + + @Container + static RedisContainer redisContainer = new RedisContainer("redis/redis-stack:latest"); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(TestApplication.class); + + private RedisChatMemoryRepository chatMemory; + + private JedisPooled jedisClient; + + @BeforeEach + void setUp() { + jedisClient = new JedisPooled(redisContainer.getHost(), redisContainer.getFirstMappedPort()); + + // Define metadata schema for proper indexing + List> metadataFields = List.of(Map.of("name", "priority", "type", "tag"), + Map.of("name", "category", "type", "tag"), Map.of("name", "score", "type", "numeric"), + Map.of("name", "confidence", "type", "numeric"), Map.of("name", "model", "type", "tag")); + + // Use a unique index name to ensure we get a fresh schema + String uniqueIndexName = "test-schema-" + System.currentTimeMillis(); + + chatMemory = RedisChatMemoryRepository.builder() + .jedisClient(jedisClient) + .indexName(uniqueIndexName) + .metadataFields(metadataFields) + .build(); + + // Clear existing test data + chatMemory.findConversationIds().forEach(chatMemory::clear); + } + + @AfterEach + void tearDown() { + if (jedisClient != null) { + jedisClient.close(); + } + } + + @Test + void shouldFindMessagesByMetadataWithProperSchema() { + this.contextRunner.run(context -> { + String conversationId = "test-metadata-schema"; + + // Create messages with different metadata + UserMessage userMsg1 = new UserMessage("High priority task"); + userMsg1.getMetadata().put("priority", "high"); + userMsg1.getMetadata().put("category", "task"); + userMsg1.getMetadata().put("score", 95); + + AssistantMessage assistantMsg = new AssistantMessage("I'll help with that"); + assistantMsg.getMetadata().put("model", "gpt-4"); + assistantMsg.getMetadata().put("confidence", 0.95); + assistantMsg.getMetadata().put("category", "response"); + + UserMessage userMsg2 = new UserMessage("Low priority question"); + userMsg2.getMetadata().put("priority", "low"); + userMsg2.getMetadata().put("category", "question"); + userMsg2.getMetadata().put("score", 75); + + // Add messages + chatMemory.add(conversationId, userMsg1); + chatMemory.add(conversationId, assistantMsg); + chatMemory.add(conversationId, userMsg2); + + // Give Redis time to index the documents + Thread.sleep(100); + + // Test finding by tag metadata (priority) + List highPriorityMessages = ((AdvancedChatMemoryRepository) chatMemory) + .findByMetadata("priority", "high", 10); + + assertThat(highPriorityMessages).hasSize(1); + assertThat(highPriorityMessages.get(0).message().getText()).isEqualTo("High priority task"); + + // Test finding by tag metadata (category) + List taskMessages = ((AdvancedChatMemoryRepository) chatMemory) + .findByMetadata("category", "task", 10); + + assertThat(taskMessages).hasSize(1); + + // Test finding by numeric metadata (score) + List highScoreMessages = ((AdvancedChatMemoryRepository) chatMemory) + .findByMetadata("score", 95, 10); + + assertThat(highScoreMessages).hasSize(1); + assertThat(highScoreMessages.get(0).message().getMetadata().get("score")).isEqualTo(95.0); + + // Test finding by numeric metadata (confidence) + List confidentMessages = ((AdvancedChatMemoryRepository) chatMemory) + .findByMetadata("confidence", 0.95, 10); + + assertThat(confidentMessages).hasSize(1); + assertThat(confidentMessages.get(0).message().getMetadata().get("model")).isEqualTo("gpt-4"); + + // Test with non-existent metadata key (not in schema) + List nonExistentMessages = ((AdvancedChatMemoryRepository) chatMemory) + .findByMetadata("nonexistent", "value", 10); + + assertThat(nonExistentMessages).isEmpty(); + + // Clean up + chatMemory.clear(conversationId); + }); + } + + @Test + void shouldFallbackToTextSearchForUndefinedMetadataFields() { + this.contextRunner.run(context -> { + String conversationId = "test-undefined-metadata"; + + // Create message with metadata field not defined in schema + UserMessage userMsg = new UserMessage("Message with custom metadata"); + userMsg.getMetadata().put("customField", "customValue"); + userMsg.getMetadata().put("priority", "medium"); // This is defined in schema + + chatMemory.add(conversationId, userMsg); + + // Defined field should work with exact match + List priorityMessages = ((AdvancedChatMemoryRepository) chatMemory) + .findByMetadata("priority", "medium", 10); + + assertThat(priorityMessages).hasSize(1); + + // Undefined field will fall back to text search in general metadata + // This may or may not find the message depending on how the text is indexed + List customMessages = ((AdvancedChatMemoryRepository) chatMemory) + .findByMetadata("customField", "customValue", 10); + + // The result depends on whether the general metadata text field caught this + // In practice, users should define all metadata fields they want to search on + + // Clean up + chatMemory.clear(conversationId); + }); + } + + @SpringBootConfiguration + @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) + static class TestApplication { + + @Bean + RedisChatMemoryRepository chatMemory() { + List> metadataFields = List.of(Map.of("name", "priority", "type", "tag"), + Map.of("name", "category", "type", "tag"), Map.of("name", "score", "type", "numeric"), + Map.of("name", "confidence", "type", "numeric"), Map.of("name", "model", "type", "tag")); + + // Use a unique index name to ensure we get a fresh schema + String uniqueIndexName = "test-schema-app-" + System.currentTimeMillis(); + + return RedisChatMemoryRepository.builder() + .jedisClient(new JedisPooled(redisContainer.getHost(), redisContainer.getFirstMappedPort())) + .indexName(uniqueIndexName) + .metadataFields(metadataFields) + .build(); + } + + } + +} \ No newline at end of file diff --git a/memory/repository/spring-ai-model-chat-memory-repository-redis/src/test/resources/application-metadata-schema.yml b/memory/repository/spring-ai-model-chat-memory-repository-redis/src/test/resources/application-metadata-schema.yml new file mode 100644 index 00000000000..5bd5fe846d0 --- /dev/null +++ b/memory/repository/spring-ai-model-chat-memory-repository-redis/src/test/resources/application-metadata-schema.yml @@ -0,0 +1,23 @@ +spring: + ai: + chat: + memory: + redis: + host: localhost + port: 6379 + index-name: chat-memory-with-schema + # Define metadata fields with their types for proper indexing + # This is compatible with RedisVL schema format + metadata-fields: + - name: priority + type: tag # For exact match searches (high, medium, low) + - name: category + type: tag # For exact match searches + - name: score + type: numeric # For numeric range queries + - name: confidence + type: numeric # For numeric comparisons + - name: model + type: tag # For exact match on model names + - name: description + type: text # For full-text search \ No newline at end of file diff --git a/memory/repository/spring-ai-model-chat-memory-repository-redis/src/test/resources/logback-test.xml b/memory/repository/spring-ai-model-chat-memory-repository-redis/src/test/resources/logback-test.xml new file mode 100644 index 00000000000..9a8dc8e8660 --- /dev/null +++ b/memory/repository/spring-ai-model-chat-memory-repository-redis/src/test/resources/logback-test.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/pom.xml b/pom.xml index a2e6dff8602..207a1a68238 100644 --- a/pom.xml +++ b/pom.xml @@ -46,6 +46,7 @@ memory/repository/spring-ai-model-chat-memory-repository-jdbc memory/repository/spring-ai-model-chat-memory-repository-mongodb memory/repository/spring-ai-model-chat-memory-repository-neo4j + memory/repository/spring-ai-model-chat-memory-repository-redis @@ -94,6 +95,7 @@ auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-jdbc auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-mongodb auto-configurations/models/chat/memory/repository/spring-ai-autoconfigure-model-chat-memory-repository-neo4j + auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-redis auto-configurations/models/chat/observation/spring-ai-autoconfigure-model-chat-observation