1616
1717package org .springframework .ai .mistralai ;
1818
19+ import java .util .Arrays ;
1920import java .util .List ;
2021
2122import org .junit .jupiter .api .Test ;
2829
2930import static org .assertj .core .api .Assertions .assertThat ;
3031import static org .mockito .ArgumentMatchers .any ;
32+ import static org .mockito .Mockito .verify ;
3133import static org .mockito .Mockito .when ;
3234
3335/**
3436 * Unit tests for {@link MistralAiEmbeddingModel}.
3537 *
38+ * @author Mark Pollack
3639 * @author Nicolas Krier
3740 */
3841class MistralAiEmbeddingModelTests {
@@ -77,7 +80,7 @@ void testDimensionsForCodestralEmbedModel() {
7780 void testDimensionsFallbackForUnknownModel () {
7881 MistralAiApi mockApi = createMockApiWithEmbeddingResponse (512 );
7982
80- // Use a model name that doesn't exist in KNOWN_EMBEDDING_DIMENSIONS
83+ // Use a model name that doesn't exist in knownEmbeddingDimensions.
8184 MistralAiEmbeddingOptions options = MistralAiEmbeddingOptions .builder ().withModel ("unknown-model" ).build ();
8285
8386 MistralAiEmbeddingModel model = MistralAiEmbeddingModel .builder ()
@@ -87,17 +90,23 @@ void testDimensionsFallbackForUnknownModel() {
8790 .retryTemplate (RetryUtils .DEFAULT_RETRY_TEMPLATE )
8891 .build ();
8992
90- // Should fall back to super.dimensions() which detects dimensions from the API
91- // response
93+ // For the first call, it should fall back to super.dimensions() which detects
94+ // dimensions from the API response.
9295 assertThat (model .dimensions ()).isEqualTo (512 );
96+
97+ // For the second call, it should use the cache mechanism.
98+ assertThat (model .dimensions ()).isEqualTo (512 );
99+
100+ // Verify that super.dimensions() has been called once.
101+ verify (mockApi ).embeddings (any ());
93102 }
94103
95104 @ Test
96105 void testAllEmbeddingModelsHaveDimensionMapping () {
97- // This test ensures that KNOWN_EMBEDDING_DIMENSIONS map stays in sync with the
98- // EmbeddingModel enum
106+ // This test ensures that knownEmbeddingDimensions map stays in sync with the
107+ // EmbeddingModel enum.
99108 // If a new model is added to the enum but not to the dimensions map, this test
100- // will help catch it
109+ // will help catch it.
101110
102111 for (MistralAiApi .EmbeddingModel embeddingModel : MistralAiApi .EmbeddingModel .values ()) {
103112 MistralAiApi mockApi = createMockApiWithEmbeddingResponse (1024 );
@@ -138,16 +147,13 @@ private MistralAiApi createMockApiWithEmbeddingResponse(int dimensions) {
138147
139148 // Create a mock embedding response with the specified dimensions
140149 float [] embedding = new float [dimensions ];
141- for (int i = 0 ; i < dimensions ; i ++) {
142- embedding [i ] = 0.1f ;
143- }
150+ Arrays .fill (embedding , 0.1f );
144151
145152 MistralAiApi .Embedding embeddingData = new MistralAiApi .Embedding (0 , embedding , "embedding" );
146153
147154 MistralAiApi .Usage usage = new MistralAiApi .Usage (10 , 0 , 10 );
148155
149- MistralAiApi .EmbeddingList embeddingList = new MistralAiApi .EmbeddingList ("object" , List .of (embeddingData ),
150- "model" , usage );
156+ var embeddingList = new MistralAiApi .EmbeddingList <>("object" , List .of (embeddingData ), "model" , usage );
151157
152158 when (mockApi .embeddings (any ())).thenReturn (ResponseEntity .ok (embeddingList ));
153159
0 commit comments