diff --git a/Directory.Packages.props b/Directory.Packages.props index b761ccef354..928d36a3852 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -30,5 +30,6 @@ + \ No newline at end of file diff --git a/libs/client/ClientSession/GarnetClientSessionMigrationExtensions.cs b/libs/client/ClientSession/GarnetClientSessionMigrationExtensions.cs index 7662b533f83..9ac7428ef40 100644 --- a/libs/client/ClientSession/GarnetClientSessionMigrationExtensions.cs +++ b/libs/client/ClientSession/GarnetClientSessionMigrationExtensions.cs @@ -25,6 +25,7 @@ public sealed unsafe partial class GarnetClientSession : IServerHook, IMessageCo static ReadOnlySpan MAIN_STORE => "SSTORE"u8; static ReadOnlySpan OBJECT_STORE => "OSTORE"u8; + static ReadOnlySpan VECTOR_STORE => "VSTORE"u8; static ReadOnlySpan T => "T"u8; static ReadOnlySpan F => "F"u8; @@ -170,14 +171,30 @@ public Task SetSlotRange(Memory state, string nodeid, List<(int, i /// /// /// - public void SetClusterMigrateHeader(string sourceNodeId, bool replace, bool isMainStore) + public void SetClusterMigrateHeader(string sourceNodeId, bool replace, bool isMainStore, bool isVectorSets) { currTcsIterationTask = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); tcsQueue.Enqueue(currTcsIterationTask); curr = offset; this.isMainStore = isMainStore; this.ist = IncrementalSendType.MIGRATE; - var storeType = isMainStore ? MAIN_STORE : OBJECT_STORE; + ReadOnlySpan storeType; + if (isMainStore) + { + if (isVectorSets) + { + storeType = VECTOR_STORE; + } + else + { + storeType = MAIN_STORE; + } + } + else + { + storeType = OBJECT_STORE; + } + var replaceOption = replace ? T : F; var arraySize = 6; @@ -249,7 +266,7 @@ public void SetClusterMigrateHeader(string sourceNodeId, bool replace, bool isMa /// public Task CompleteMigrate(string sourceNodeId, bool replace, bool isMainStore) { - SetClusterMigrateHeader(sourceNodeId, replace, isMainStore); + SetClusterMigrateHeader(sourceNodeId, replace, isMainStore, isVectorSets: false); Debug.Assert(end - curr >= 2); *curr++ = (byte)'\r'; diff --git a/libs/cluster/Server/ClusterManager.cs b/libs/cluster/Server/ClusterManager.cs index 2cfbcaf5f07..1dbef4dbed2 100644 --- a/libs/cluster/Server/ClusterManager.cs +++ b/libs/cluster/Server/ClusterManager.cs @@ -239,22 +239,27 @@ public string GetInfo() public static string GetRange(int[] slots) { var range = "> "; - var start = slots[0]; - var end = slots[0]; - for (var i = 1; i < slots.Length + 1; i++) + if (slots.Length >= 1) { - if (i < slots.Length && slots[i] == end + 1) - end = slots[i]; - else + + var start = slots[0]; + var end = slots[0]; + for (var i = 1; i < slots.Length + 1; i++) { - range += $"{start}-{end} "; - if (i < slots.Length) - { - start = slots[i]; + if (i < slots.Length && slots[i] == end + 1) end = slots[i]; + else + { + range += $"{start}-{end} "; + if (i < slots.Length) + { + start = slots[i]; + end = slots[i]; + } } } } + return range; } diff --git a/libs/cluster/Server/ClusterManagerSlotState.cs b/libs/cluster/Server/ClusterManagerSlotState.cs index a35e474a263..0ef36402b84 100644 --- a/libs/cluster/Server/ClusterManagerSlotState.cs +++ b/libs/cluster/Server/ClusterManagerSlotState.cs @@ -17,7 +17,10 @@ namespace Garnet.cluster SpanByteAllocator>>, BasicContext>, - GenericAllocator>>>>; + GenericAllocator>>>, + BasicContext, + SpanByteAllocator>>>; /// /// Cluster manager diff --git a/libs/cluster/Server/ClusterProvider.cs b/libs/cluster/Server/ClusterProvider.cs index 51ac87401f0..9af5cf6a02e 100644 --- 
a/libs/cluster/Server/ClusterProvider.cs +++ b/libs/cluster/Server/ClusterProvider.cs @@ -15,12 +15,21 @@ namespace Garnet.cluster { + using BasicContext = BasicContext, + SpanByteAllocator>>; + using BasicGarnetApi = GarnetApi, SpanByteAllocator>>, BasicContext>, - GenericAllocator>>>>; + GenericAllocator>>>, + BasicContext, + SpanByteAllocator>>>; + + using VectorContext = BasicContext, SpanByteAllocator>>; /// /// Cluster provider @@ -100,8 +109,8 @@ public void Start() } /// - public IClusterSession CreateClusterSession(TransactionManager txnManager, IGarnetAuthenticator authenticator, UserHandle userHandle, GarnetSessionMetrics garnetSessionMetrics, BasicGarnetApi basicGarnetApi, INetworkSender networkSender, ILogger logger = null) - => new ClusterSession(this, txnManager, authenticator, userHandle, garnetSessionMetrics, basicGarnetApi, networkSender, logger); + public IClusterSession CreateClusterSession(TransactionManager txnManager, IGarnetAuthenticator authenticator, UserHandle userHandle, GarnetSessionMetrics garnetSessionMetrics, BasicGarnetApi basicGarnetApi, BasicContext basicContext, VectorContext vectorContext, INetworkSender networkSender, ILogger logger = null) + => new ClusterSession(this, txnManager, authenticator, userHandle, garnetSessionMetrics, basicGarnetApi, basicContext, vectorContext, networkSender, logger); /// public void UpdateClusterAuth(string clusterUsername, string clusterPassword) diff --git a/libs/cluster/Server/Migration/MigrateOperation.cs b/libs/cluster/Server/Migration/MigrateOperation.cs index d4f069a8189..3f677c959ee 100644 --- a/libs/cluster/Server/Migration/MigrateOperation.cs +++ b/libs/cluster/Server/Migration/MigrateOperation.cs @@ -2,9 +2,11 @@ // Licensed under the MIT license. using System; +using System.Collections.Concurrent; using System.Collections.Generic; using Garnet.client; using Garnet.server; +using Microsoft.Extensions.Logging; using Tsavorite.core; namespace Garnet.cluster @@ -18,16 +20,25 @@ internal sealed partial class MigrateOperation public MainStoreScan mss; public ObjectStoreScan oss; + private readonly ConcurrentDictionary vectorSetsIndexKeysToMigrate; + readonly MigrateSession session; readonly GarnetClientSession gcs; readonly LocalServerSession localServerSession; public GarnetClientSession Client => gcs; + public IEnumerable> VectorSets => vectorSetsIndexKeysToMigrate; + public void ThrowIfCancelled() => session._cts.Token.ThrowIfCancellationRequested(); public bool Contains(int slot) => session._sslots.Contains(slot); + public bool ContainsNamespace(ulong ns) => session._namespaces?.Contains(ns) ?? 
false; + + public void EncounteredVectorSet(byte[] key, byte[] value) + => vectorSetsIndexKeysToMigrate.TryAdd(key, value); + public MigrateOperation(MigrateSession session, Sketch sketch = null, int batchSize = 1 << 18) { this.session = session; @@ -37,6 +48,7 @@ public MigrateOperation(MigrateSession session, Sketch sketch = null, int batchS mss = new MainStoreScan(this); oss = new ObjectStoreScan(this); keysToDelete = []; + vectorSetsIndexKeysToMigrate = new(ByteArrayComparer.Instance); } public bool Initialize() @@ -72,7 +84,7 @@ public void Scan(StoreType storeType, ref long currentAddress, long endAddress) /// /// /// - public bool TrasmitSlots(StoreType storeType) + public bool TransmitSlots(StoreType storeType) { var bufferSize = 1 << 10; SectorAlignedMemory buffer = new(bufferSize, 1); @@ -87,7 +99,7 @@ public bool TrasmitSlots(StoreType storeType) { foreach (var key in sketch.argSliceVector) { - var spanByte = key.SpanByte; + var spanByte = key; if (!session.WriteOrSendMainStoreKeyValuePair(gcs, localServerSession, ref spanByte, ref input, ref o, out _)) return false; @@ -117,7 +129,10 @@ public bool TrasmitSlots(StoreType storeType) return true; } - public bool TransmitKeys(StoreType storeType) + /// + /// Move keys in sketch out of the given store, UNLESS they are also in . + /// + public bool TransmitKeys(StoreType storeType, Dictionary vectorSetKeysToIgnore) { var bufferSize = 1 << 10; SectorAlignedMemory buffer = new(bufferSize, 1); @@ -131,12 +146,30 @@ public bool TransmitKeys(StoreType storeType) var keys = sketch.Keys; if (storeType == StoreType.Main) { +#if NET9_0_OR_GREATER + var ignoreLookup = vectorSetKeysToIgnore.GetAlternateLookup>(); +#endif + for (var i = 0; i < keys.Count; i++) { if (keys[i].Item2) continue; var spanByte = keys[i].Item1.SpanByte; + + // Don't transmit if a Vector Set + var isVectorSet = + vectorSetKeysToIgnore.Count > 0 && +#if NET9_0_OR_GREATER + ignoreLookup.ContainsKey(spanByte.AsReadOnlySpan()); +#else + vectorSetKeysToIgnore.ContainsKey(spanByte.ToByteArray()); +#endif + if (isVectorSet) + { + continue; + } + if (!session.WriteOrSendMainStoreKeyValuePair(gcs, localServerSession, ref spanByte, ref input, ref o, out var status)) return false; @@ -158,8 +191,8 @@ public bool TransmitKeys(StoreType storeType) if (keys[i].Item2) continue; - var argSlice = keys[i].Item1; - if (!session.WriteOrSendObjectStoreKeyValuePair(gcs, localServerSession, ref argSlice, out var status)) + var spanByte = keys[i].Item1.SpanByte; + if (!session.WriteOrSendObjectStoreKeyValuePair(gcs, localServerSession, ref spanByte, out var status)) return false; // Skip if key NOTFOUND @@ -182,6 +215,54 @@ public bool TransmitKeys(StoreType storeType) return true; } + /// + /// Transmit data in namespaces during a MIGRATE ... KEYS operation. + /// + /// Doesn't delete anything, just scans and transmits. 
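
A note on the NET9_0_OR_GREATER branch in TransmitKeys above: Dictionary<byte[], byte[]>.GetAlternateLookup<ReadOnlySpan<byte>>() only compiles when the dictionary's comparer also implements IAlternateEqualityComparer<ReadOnlySpan<byte>, byte[]>, so the call implies Garnet's ByteArrayComparer does. The point is to let the hot loop probe with spanByte.AsReadOnlySpan() and skip the per-key byte[] allocation the #else branch pays for. A minimal, self-contained sketch of the pattern (this comparer is a stand-in, not Garnet's type):

using System;
using System.Collections.Generic;

sealed class ByteArrayComparerSketch :
    IEqualityComparer<byte[]>,
    IAlternateEqualityComparer<ReadOnlySpan<byte>, byte[]>
{
    public static readonly ByteArrayComparerSketch Instance = new();

    public bool Equals(byte[] x, byte[] y) => x.AsSpan().SequenceEqual(y);
    public int GetHashCode(byte[] obj) => GetHashCode((ReadOnlySpan<byte>)obj);

    // Span-based members let callers probe without materializing a byte[]
    public bool Equals(ReadOnlySpan<byte> alternate, byte[] other) => alternate.SequenceEqual(other);
    public int GetHashCode(ReadOnlySpan<byte> alternate)
    {
        var hc = new HashCode();
        hc.AddBytes(alternate);
        return hc.ToHashCode();
    }

    // Invoked only if an alternate key is ever inserted through the lookup
    public byte[] Create(ReadOnlySpan<byte> alternate) => alternate.ToArray();
}

class AlternateLookupDemo
{
    static void Main()
    {
        var ignore = new Dictionary<byte[], byte[]>(ByteArrayComparerSketch.Instance)
        {
            ["vset:index"u8.ToArray()] = [1, 2, 3]
        };

        // Allocation-free membership test against a span, as in TransmitKeys
        var lookup = ignore.GetAlternateLookup<ReadOnlySpan<byte>>();
        Console.WriteLine(lookup.ContainsKey("vset:index"u8)); // True
    }
}
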
+ /// + public bool TransmitKeysNamespaces(ILogger logger) + { + var migrateOperation = this; + + if (!migrateOperation.Initialize()) + return false; + + var workerStartAddress = migrateOperation.session.clusterProvider.storeWrapper.store.Log.BeginAddress; + var workerEndAddress = migrateOperation.session.clusterProvider.storeWrapper.store.Log.TailAddress; + + var cursor = workerStartAddress; + logger?.LogWarning(" migrate keys (namespaces) scan range [{workerStartAddress}, {workerEndAddress}]", workerStartAddress, workerEndAddress); + while (true) + { + var current = cursor; + // Build Sketch + migrateOperation.sketch.SetStatus(SketchStatus.INITIALIZING); + migrateOperation.Scan(StoreType.Main, ref current, workerEndAddress); + + // Stop if no keys have been found + if (migrateOperation.sketch.argSliceVector.IsEmpty) break; + + logger?.LogWarning("Scan from {cursor} to {current} and discovered {count} keys", cursor, current, migrateOperation.sketch.argSliceVector.Count); + + // Transition EPSM to MIGRATING + migrateOperation.sketch.SetStatus(SketchStatus.TRANSMITTING); + migrateOperation.session.WaitForConfigPropagation(); + + // Transmit all keys gathered + migrateOperation.TransmitSlots(StoreType.Main); + + // Transition EPSM to DELETING + migrateOperation.sketch.SetStatus(SketchStatus.DELETING); + migrateOperation.session.WaitForConfigPropagation(); + + // Clear keys from buffer + migrateOperation.sketch.Clear(); + cursor = current; + } + + return true; + } + /// /// Delete keys after migration if copyOption is not set /// @@ -193,7 +274,13 @@ public void DeleteKeys() { foreach (var key in sketch.argSliceVector) { - var spanByte = key.SpanByte; + if (key.MetadataSize == 1) + { + // Namespace'd keys are not deleted here, but when migration finishes + continue; + } + + var spanByte = key; _ = localServerSession.BasicGarnetApi.DELETE(ref spanByte); } } @@ -209,6 +296,19 @@ public void DeleteKeys() } } } + + /// + /// Delete a Vector Set after migration if _copyOption is not set. + /// + public void DeleteVectorSet(ref SpanByte key) + { + if (session._copyOption) + return; + + var delRes = localServerSession.BasicGarnetApi.DELETE(ref key); + + session.logger?.LogDebug("Deleting Vector Set {key} after migration: {delRes}", System.Text.Encoding.UTF8.GetString(key.AsReadOnlySpan()), delRes); + } } } } \ No newline at end of file diff --git a/libs/cluster/Server/Migration/MigrateScanFunctions.cs b/libs/cluster/Server/Migration/MigrateScanFunctions.cs index 03cb23d1af8..25d9f5da3d3 100644 --- a/libs/cluster/Server/Migration/MigrateScanFunctions.cs +++ b/libs/cluster/Server/Migration/MigrateScanFunctions.cs @@ -36,10 +36,34 @@ public unsafe bool SingleReader(ref SpanByte key, ref SpanByte value, RecordMeta if (ClusterSession.Expired(ref value)) return true; - var s = HashSlotUtils.HashSlot(ref key); - // Check if key belongs to slot that is being migrated and if it can be added to our buffer - if (mss.Contains(s) && !mss.sketch.TryHashAndStore(key.AsSpan())) - return false; + // TODO: Some other way to detect namespaces + if (key.MetadataSize == 1) + { + var ns = key.GetNamespaceInPayload(); + + if (mss.ContainsNamespace(ns) && !mss.sketch.TryHashAndStore(ns, key.AsSpan())) + return false; + } + else + { + var s = HashSlotUtils.HashSlot(ref key); + + // Check if key belongs to slot that is being migrated... 
+ if (mss.Contains(s)) + { + if (recordMetadata.RecordInfo.VectorSet) + { + // We can't delete the vector set _yet_ nor can we migrate it, + // we just need to remember it to migrate once the associated namespaces are all moved over + mss.EncounteredVectorSet(key.ToByteArray(), value.ToByteArray()); + } + else if (!mss.sketch.TryHashAndStore(key.AsSpan())) + { + // Out of space, end scan for now + return false; + } + } + } return true; } diff --git a/libs/cluster/Server/Migration/MigrateSession.cs b/libs/cluster/Server/Migration/MigrateSession.cs index 16c4cb481dd..cd59a66d347 100644 --- a/libs/cluster/Server/Migration/MigrateSession.cs +++ b/libs/cluster/Server/Migration/MigrateSession.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. using System; +using System.Collections.Frozen; using System.Collections.Generic; using System.Linq; using System.Net; @@ -48,6 +49,9 @@ internal sealed unsafe partial class MigrateSession : IDisposable readonly HashSet _sslots; readonly CancellationTokenSource _cts = new(); + HashSet _namespaces; + FrozenDictionary _namespaceMap; + /// /// Get endpoint of target node /// @@ -276,9 +280,10 @@ public bool TrySetSlotRanges(string nodeid, MigrateState state) Status = MigrateState.FAIL; return false; } - logger?.LogTrace("[Completed] SETSLOT {slots} {state} {nodeid}", ClusterManager.GetRange([.. _sslots]), state, nodeid == null ? "" : nodeid); + logger?.LogTrace("[Completed] SETSLOT {slots} {state} {nodeid}", ClusterManager.GetRange([.. _sslots]), state, nodeid ?? ""); return true; - }, TaskContinuationOptions.OnlyOnRanToCompletion).WaitAsync(_timeout, _cts.Token).Result; + }, TaskContinuationOptions.OnlyOnRanToCompletion) + .WaitAsync(_timeout, _cts.Token).Result; } catch (Exception ex) { @@ -338,6 +343,8 @@ public bool TryRecoverFromFailure() // This will execute the equivalent of SETSLOTRANGE STABLE for the slots of the failed migration task ResetLocalSlot(); + // TODO: Need to relinquish any migrating Vector Set contexts from target node + // Log explicit migration failure. Status = MigrateState.FAIL; return true; diff --git a/libs/cluster/Server/Migration/MigrateSessionCommonUtils.cs b/libs/cluster/Server/Migration/MigrateSessionCommonUtils.cs index 835f755a4b8..a11059bfe49 100644 --- a/libs/cluster/Server/Migration/MigrateSessionCommonUtils.cs +++ b/libs/cluster/Server/Migration/MigrateSessionCommonUtils.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. 
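
Both the scan dispatch above and the helpers in the hunks below key off MetadataSize to tell record kinds apart, and the PR narrows Expired in MigrateCommand.cs from MetadataSize > 0 to MetadataSize == 8 so the two encodings cannot be confused. A rough model of the convention the TODOs gesture at (helper names are mine, not Garnet's):

static class MetadataConvention
{
    // 1 byte of key metadata = a Vector Set namespace id (new in this PR)
    public static bool HasNamespace(int metadataSize) => metadataSize == 1;

    // 8 bytes of value metadata = expiration ticks (the pre-existing case)
    public static bool HasExpiration(int metadataSize) => metadataSize == 8;
}
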
using System; +using System.Diagnostics; using System.Threading.Tasks; using Garnet.client; using Garnet.server; @@ -29,6 +30,18 @@ private bool WriteOrSendMainStoreKeyValuePair(GarnetClientSession gcs, LocalServ value = ref SpanByte.ReinterpretWithoutLength(o.Memory.Memory.Span); } + // Map up any namespaces as needed + // TODO: Better way to do "has namespace" + if (key.MetadataSize == 1) + { + var oldNs = key.GetNamespaceInPayload(); + if (_namespaceMap.TryGetValue(oldNs, out var newNs)) + { + Debug.Assert(newNs <= byte.MaxValue, "Namespace too large"); + key.SetNamespaceInPayload((byte)newNs); + } + } + // Write key to network buffer if it has not expired if (!ClusterSession.Expired(ref value) && !WriteOrSendMainStoreKeyValuePair(gcs, ref key, ref value)) return false; @@ -39,7 +52,7 @@ bool WriteOrSendMainStoreKeyValuePair(GarnetClientSession gcs, ref SpanByte key, { // Check if we need to initialize cluster migrate command arguments if (gcs.NeedsInitialization) - gcs.SetClusterMigrateHeader(_sourceNodeId, _replaceOption, isMainStore: true); + gcs.SetClusterMigrateHeader(_sourceNodeId, _replaceOption, isMainStore: true, isVectorSets: false); // Try write serialized key value to client buffer while (!gcs.TryWriteKeyValueSpanByte(ref key, ref value, out var task)) @@ -49,15 +62,15 @@ bool WriteOrSendMainStoreKeyValuePair(GarnetClientSession gcs, ref SpanByte key, return false; // re-initialize cluster migrate command parameters - gcs.SetClusterMigrateHeader(_sourceNodeId, _replaceOption, isMainStore: true); + gcs.SetClusterMigrateHeader(_sourceNodeId, _replaceOption, isMainStore: true, isVectorSets: false); } return true; } } - private bool WriteOrSendObjectStoreKeyValuePair(GarnetClientSession gcs, LocalServerSession localServerSession, ref ArgSlice key, out GarnetStatus status) + private bool WriteOrSendObjectStoreKeyValuePair(GarnetClientSession gcs, LocalServerSession localServerSession, ref SpanByte key, out GarnetStatus status) { - var keyByteArray = key.ToArray(); + var keyByteArray = key.AsReadOnlySpan().ToArray(); ObjectInput input = default; GarnetObjectStoreOutput value = default; @@ -81,14 +94,14 @@ bool WriteOrSendObjectStoreKeyValuePair(GarnetClientSession gcs, byte[] key, byt { // Check if we need to initialize cluster migrate command arguments if (gcs.NeedsInitialization) - gcs.SetClusterMigrateHeader(_sourceNodeId, _replaceOption, isMainStore: false); + gcs.SetClusterMigrateHeader(_sourceNodeId, _replaceOption, isMainStore: false, isVectorSets: false); while (!gcs.TryWriteKeyValueByteArray(key, value, expiration, out var task)) { // Flush key value pairs in the buffer if (!HandleMigrateTaskResponse(task)) return false; - gcs.SetClusterMigrateHeader(_sourceNodeId, _replaceOption, isMainStore: false); + gcs.SetClusterMigrateHeader(_sourceNodeId, _replaceOption, isMainStore: false, isVectorSets: false); } return true; } diff --git a/libs/cluster/Server/Migration/MigrateSessionKeys.cs b/libs/cluster/Server/Migration/MigrateSessionKeys.cs index 294b4ae3172..a49b5eabf45 100644 --- a/libs/cluster/Server/Migration/MigrateSessionKeys.cs +++ b/libs/cluster/Server/Migration/MigrateSessionKeys.cs @@ -2,6 +2,8 @@ // Licensed under the MIT license. 
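
The WriteOrSend helpers above all share one shape: write into the client session buffer, and when it fills, flush, validate the target's response, and re-emit the CLUSTER MIGRATE header, because flushing resets the iteration buffer. A schematic of that control flow only, with delegates standing in for gcs.TryWriteKeyValueSpanByte, HandleMigrateTaskResponse, and gcs.SetClusterMigrateHeader (a sketch, not Garnet's API):

using System;

static class MigrateSendLoopSketch
{
    public static bool SendWithRetry(Func<bool> tryWrite, Func<bool> flushAndCheck, Action writeHeader)
    {
        writeHeader(); // mirrors the gcs.NeedsInitialization check

        while (!tryWrite())
        {
            // Buffer is full: flush the batch and confirm the target accepted it
            if (!flushAndCheck())
                return false;

            // The buffer was reset, so the MIGRATE header must be written again
            writeHeader();
        }

        return true;
    }
}
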
using System; +using System.Collections.Generic; +using System.Linq; using Garnet.server; using Microsoft.Extensions.Logging; using Tsavorite.core; @@ -33,13 +35,78 @@ private bool MigrateKeysFromMainStore() migrateTask.sketch.SetStatus(SketchStatus.TRANSMITTING); WaitForConfigPropagation(); + // Discover Vector Sets linked namespaces + var indexesToMigrate = new Dictionary(ByteArrayComparer.Instance); + _namespaces = clusterProvider.storeWrapper.DefaultDatabase.VectorManager.GetNamespacesForKeys(clusterProvider.storeWrapper, migrateTask.sketch.Keys.Select(t => t.Item1.ToArray()), indexesToMigrate); + + // If we have any namespaces, that implies Vector Sets, and if we have any of THOSE + // we need to reserve destination sets on the other side + if ((_namespaces?.Count ?? 0) > 0 && !ReserveDestinationVectorSetsAsync().GetAwaiter().GetResult()) + { + logger?.LogError("Failed to reserve destination vector sets, migration failed"); + return false; + } + // Transmit keys from main store - if (!migrateTask.TransmitKeys(StoreType.Main)) + if (!migrateTask.TransmitKeys(StoreType.Main, indexesToMigrate)) { logger?.LogError("Failed transmitting keys from main store"); return false; } + if ((_namespaces?.Count ?? 0) > 0) + { + // Actually move element data over + if (!migrateTask.TransmitKeysNamespaces(logger)) + { + logger?.LogError("Failed to transmit vector set (namespaced) element data, migration failed"); + return false; + } + + // Move the indexes over + var gcs = migrateTask.Client; + + foreach (var (key, value) in indexesToMigrate) + { + // Update the index context as we move it, so it arrives on the destination node pointed at the appropriate + // namespaces for element data + VectorManager.ReadIndex(value, out var oldContext, out _, out _, out _, out _, out _, out _, out _); + + var newContext = _namespaceMap[oldContext]; + VectorManager.SetContextForMigration(value, newContext); + + unsafe + { + fixed (byte* keyPtr = key, valuePtr = value) + { + var keySpan = SpanByte.FromPinnedPointer(keyPtr, key.Length); + var valSpan = SpanByte.FromPinnedPointer(valuePtr, value.Length); + + if (gcs.NeedsInitialization) + gcs.SetClusterMigrateHeader(_sourceNodeId, _replaceOption, isMainStore: true, isVectorSets: true); + + while (!gcs.TryWriteKeyValueSpanByte(ref keySpan, ref valSpan, out var task)) + { + if (!HandleMigrateTaskResponse(task)) + { + logger?.LogCritical("Failed to migrate Vector Set key {key} during migration", keySpan); + return false; + } + + gcs.SetClusterMigrateHeader(_sourceNodeId, _replaceOption, isMainStore: true, isVectorSets: true); + } + } + } + } + + if (!HandleMigrateTaskResponse(gcs.SendAndResetIterationBuffer())) + { + logger?.LogCritical("Final flush after Vector Set migration failed"); + return false; + } + } + + // Final cleanup, which will also delete Vector Sets DeleteKeys(); } finally @@ -68,7 +135,7 @@ private bool MigrateKeysFromObjectStore() WaitForConfigPropagation(); // Transmit keys from object store - if (!migrateTask.TransmitKeys(StoreType.Object)) + if (!migrateTask.TransmitKeys(StoreType.Object, new(ByteArrayComparer.Instance))) { logger?.LogError("Failed transmitting keys from object store"); return false; diff --git a/libs/cluster/Server/Migration/MigrateSessionSlots.cs b/libs/cluster/Server/Migration/MigrateSessionSlots.cs index 0d153cc4aa0..7ce25a4048d 100644 --- a/libs/cluster/Server/Migration/MigrateSessionSlots.cs +++ b/libs/cluster/Server/Migration/MigrateSessionSlots.cs @@ -2,17 +2,68 @@ // Licensed under the MIT license. 
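
MigrateKeysFromMainStore above reserves destination contexts before transmitting; the _namespaceMap it later consults is built by ReserveDestinationVectorSetsAsync just below. A worked, self-contained version of that mapping loop (ContextStep's real value lives in VectorManager; 2 is assumed here purely for illustration):

using System;
using System.Collections.Generic;
using System.Linq;

class NamespaceMapDemo
{
    const ulong ContextStep = 2; // assumption for the example only

    // Every migrating root context pairs with one context reserved on the
    // destination, and each namespace in its step range moves with it.
    static Dictionary<ulong, ulong> BuildMap(IEnumerable<ulong> namespaces, ulong[] reserved)
    {
        var map = new Dictionary<ulong, ulong>();
        var next = 0;
        foreach (var root in namespaces.Where(n => n % ContextStep == 0))
        {
            for (var i = 0UL; i < ContextStep; i++)
                map[root + i] = reserved[next] + i;
            next++;
        }
        return map;
    }

    static void Main()
    {
        // Source node migrates namespaces {4,5,8,9}; destination reserved contexts 20 and 30
        var map = BuildMap([4UL, 5, 8, 9], [20UL, 30]);
        Console.WriteLine(string.Join(", ", map)); // [4, 20], [5, 21], [8, 30], [9, 31]
    }
}
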
using System; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; using System.Threading.Tasks; #if DEBUG using Garnet.common; #endif using Garnet.server; using Microsoft.Extensions.Logging; +using Tsavorite.core; namespace Garnet.cluster { internal sealed partial class MigrateSession : IDisposable { + /// + /// Attempts to reserve contexts on the destination node for migrating vector sets. + /// + /// This maps roughly to "for each namespace, reserve one context, record the mapping". + /// + public async Task ReserveDestinationVectorSetsAsync() + { + Debug.Assert((_namespaces.Count % (int)VectorManager.ContextStep) == 0, "Expected to be migrating Vector Sets, and thus to have an even number of namespaces"); + + var neededContexts = _namespaces.Count / (int)VectorManager.ContextStep; + + try + { + var reservedCtxs = await migrateOperation[0].Client.ExecuteForArrayAsync("CLUSTER", "RESERVE", "VECTOR_SET_CONTEXTS", neededContexts.ToString()); + + var rootNamespacesMigrating = _namespaces.Where(static x => (x % VectorManager.ContextStep) == 0); + + var nextReservedIx = 0; + + var namespaceMap = new Dictionary(); + + foreach (var migratingContext in rootNamespacesMigrating) + { + var toMapTo = ulong.Parse(reservedCtxs[nextReservedIx]); + for (var i = 0U; i < VectorManager.ContextStep; i++) + { + var fromCtx = migratingContext + i; + var toCtx = toMapTo + i; + + namespaceMap[fromCtx] = toCtx; + } + + nextReservedIx++; + } + + _namespaceMap = namespaceMap.ToFrozenDictionary(); + + return true; + } + catch (Exception ex) + { + logger?.LogError(ex, "Failed to reserve {count} Vector Set contexts on destination node {node}", neededContexts, _targetNodeId); + return false; + } + } + /// /// Migrate Slots inline driver /// @@ -61,6 +112,60 @@ async Task CreateAndRunMigrateTasks(StoreType storeType, long beginAddress try { await Task.WhenAll(migrateOperationRunners).WaitAsync(_timeout, _cts.Token).ConfigureAwait(false); + + // Handle migration of discovered Vector Set keys now that their namespaces have been moved + if (storeType == StoreType.Main) + { + var vectorSets = migrateOperation.SelectMany(static mo => mo.VectorSets).GroupBy(static g => g.Key, ByteArrayComparer.Instance).ToDictionary(static g => g.Key, g => g.First().Value, ByteArrayComparer.Instance); + + if (vectorSets.Count > 0) + { + var gcs = migrateOperation[0].Client; + + foreach (var (key, value) in vectorSets) + { + // Update the index context as we move it, so it arrives on the destination node pointed at the appropriate + // namespaces for element data + VectorManager.ReadIndex(value, out var oldContext, out _, out _, out _, out _, out _, out _, out _); + + var newContext = _namespaceMap[oldContext]; + VectorManager.SetContextForMigration(value, newContext); + + unsafe + { + fixed (byte* keyPtr = key, valuePtr = value) + { + var keySpan = SpanByte.FromPinnedPointer(keyPtr, key.Length); + var valSpan = SpanByte.FromPinnedPointer(valuePtr, value.Length); + + if (gcs.NeedsInitialization) + gcs.SetClusterMigrateHeader(_sourceNodeId, _replaceOption, isMainStore: true, isVectorSets: true); + + while (!gcs.TryWriteKeyValueSpanByte(ref keySpan, ref valSpan, out var task)) + { + if (!HandleMigrateTaskResponse(task)) + { + logger?.LogCritical("Failed to migrate Vector Set key {key} during migration", keySpan); + return false; + } + + gcs.SetClusterMigrateHeader(_sourceNodeId, _replaceOption, isMainStore: true, isVectorSets: true); + } + + // Force a flush before doing the delete, in case that fails + if (!HandleMigrateTaskResponse(gcs.SendAndResetIterationBuffer())) + { + logger?.LogCritical("Flush failed before deletion of Vector Set {key} during migration", keySpan); + return false; + } + + // Delete the index on this node now that it's moved over to the destination node + migrateOperation[0].DeleteVectorSet(ref keySpan); + } + } + } + } + } } catch (Exception ex) { @@ -68,6 +173,7 @@ async Task CreateAndRunMigrateTasks(StoreType storeType, long beginAddress _cts.Cancel(); return false; } + return true; } @@ -103,7 +209,7 @@ Task ScanStoreTask(int taskId, StoreType storeType, long beginAddress, lon WaitForConfigPropagation(); // Transmit all keys gathered - migrateOperation.TrasmitSlots(storeType); + migrateOperation.TransmitSlots(storeType); // Transition EPSM to DELETING migrateOperation.sketch.SetStatus(SketchStatus.DELETING);
diff --git a/libs/cluster/Server/Migration/MigrationDriver.cs index d2e6af5c1c2..eeda6d6d7e2 100644 --- a/libs/cluster/Server/Migration/MigrationDriver.cs +++ b/libs/cluster/Server/Migration/MigrationDriver.cs @@ -78,6 +78,19 @@ private async Task BeginAsyncMigrationTask() if (!clusterProvider.BumpAndWaitForEpochTransition()) return; #endregion + // Acquire namespaces at this point, after slots have been switched to migration + _namespaces = clusterProvider.storeWrapper.DefaultDatabase.VectorManager.GetNamespacesForHashSlots(_sslots); + + // If we have any namespaces, that implies Vector Sets, and if we have any of THOSE + // we need to reserve destination sets on the other side + if ((_namespaces?.Count ?? 0) > 0 && !await ReserveDestinationVectorSetsAsync()) + { + logger?.LogError("Failed to reserve destination vector sets, migration failed"); + TryRecoverFromFailure(); + Status = MigrateState.FAIL; + return; + } + #region migrateData // Migrate actual data if (!await MigrateSlotsDriverInline()) @@ -87,6 +100,7 @@ private async Task BeginAsyncMigrationTask() Status = MigrateState.FAIL; return; } + #endregion #region transferSlotOwnnershipToTargetNode
diff --git a/libs/cluster/Server/Migration/Sketch.cs index 4c1ff3e376e..59f3d0bc4a5 100644 --- a/libs/cluster/Server/Migration/Sketch.cs +++ b/libs/cluster/Server/Migration/Sketch.cs @@ -44,6 +44,19 @@ public bool TryHashAndStore(Span key) return true; } + public bool TryHashAndStore(ulong ns, Span key) + { + if (!argSliceVector.TryAddItem(ns, key)) + return false; + + var slot = (int)HashUtils.MurmurHash2x64A(key, seed: (uint)ns) & (size - 1); + var byteOffset = slot >> 3; + var bitOffset = slot & 7; + bitmap[byteOffset] = (byte)(bitmap[byteOffset] | (1UL << bitOffset)); + + return true; + } + /// /// Hash key to bloomfilter and store it for future use (NOTE: Use only with KEYS option) /// @@ -65,7 +78,19 @@ public unsafe void HashAndStore(ref ArgSlice key) /// public unsafe bool Probe(SpanByte key, out SketchStatus status) { - var slot = (int)HashUtils.MurmurHash2x64A(key.ToPointer(), key.Length) & (size - 1); + int slot; + + // TODO: better way to detect namespace + if (key.MetadataSize == 1) + { + var ns = key.GetNamespaceInPayload(); + slot = (int)HashUtils.MurmurHash2x64A(key.ToPointer(), key.Length, seed: (uint)ns) & (size - 1); + } + else + { + slot = (int)HashUtils.MurmurHash2x64A(key.ToPointer(), key.Length) & (size - 1); + } + var byteOffset = slot >> 3; var bitOffset = slot & 7;
diff --git a/libs/cluster/Session/ClusterCommands.cs index 
104e05144b7..d938b710340 100644 --- a/libs/cluster/Session/ClusterCommands.cs +++ b/libs/cluster/Session/ClusterCommands.cs @@ -135,7 +135,7 @@ private bool TryParseSlots(int startIdx, out HashSet slots, out ReadOnlySpa /// Subcommand to execute. /// True if number of parameters is invalid /// True if command is fully processed, false if more processing is needed. - private void ProcessClusterCommands(RespCommand command, out bool invalidParameters) + private void ProcessClusterCommands(RespCommand command, VectorManager vectorManager, out bool invalidParameters) { _ = command switch { @@ -173,6 +173,7 @@ private void ProcessClusterCommands(RespCommand command, out bool invalidParamet RespCommand.CLUSTER_PUBLISH or RespCommand.CLUSTER_SPUBLISH => NetworkClusterPublish(out invalidParameters), RespCommand.CLUSTER_REPLICAS => NetworkClusterReplicas(out invalidParameters), RespCommand.CLUSTER_REPLICATE => NetworkClusterReplicate(out invalidParameters), + RespCommand.CLUSTER_RESERVE => NetworkClusterReserve(vectorManager, out invalidParameters), RespCommand.CLUSTER_RESET => NetworkClusterReset(out invalidParameters), RespCommand.CLUSTER_SEND_CKPT_FILE_SEGMENT => NetworkClusterSendCheckpointFileSegment(out invalidParameters), RespCommand.CLUSTER_SEND_CKPT_METADATA => NetworkClusterSendCheckpointMetadata(out invalidParameters), diff --git a/libs/cluster/Session/ClusterKeyIterationFunctions.cs b/libs/cluster/Session/ClusterKeyIterationFunctions.cs index 54d91d6cd3d..af011f3798c 100644 --- a/libs/cluster/Session/ClusterKeyIterationFunctions.cs +++ b/libs/cluster/Session/ClusterKeyIterationFunctions.cs @@ -34,6 +34,14 @@ internal sealed class MainStoreCountKeys : IScanIteratorFunctions keys, int slot, int maxKeyCount) public bool SingleReader(ref SpanByte key, ref SpanByte value, RecordMetadata recordMetadata, long numberOfRecords, out CursorRecordResult cursorRecordResult) { + // TODO: better way to detect namespace + if (key.MetadataSize == 1) + { + // Namespace means not visible + cursorRecordResult = CursorRecordResult.Skip; + return true; + } + cursorRecordResult = CursorRecordResult.Accept; // default; not used here, out CursorRecordResult cursorRecordResult + if (HashSlotUtils.HashSlot(ref key) == slot && !Expired(ref value)) keys.Add(key.ToByteArray()); return keys.Count < maxKeyCount; diff --git a/libs/cluster/Session/ClusterSession.cs b/libs/cluster/Session/ClusterSession.cs index 45780b2d2bf..bfe1f6c475a 100644 --- a/libs/cluster/Session/ClusterSession.cs +++ b/libs/cluster/Session/ClusterSession.cs @@ -12,12 +12,21 @@ namespace Garnet.cluster { + using BasicContext = BasicContext, + SpanByteAllocator>>; + using BasicGarnetApi = GarnetApi, SpanByteAllocator>>, BasicContext>, - GenericAllocator>>>>; + GenericAllocator>>>, + BasicContext, + SpanByteAllocator>>>; + + using VectorContext = BasicContext, SpanByteAllocator>>; internal sealed unsafe partial class ClusterSession : IClusterSession { @@ -57,7 +66,20 @@ internal sealed unsafe partial class ClusterSession : IClusterSession /// public IGarnetServer Server { get; set; } - public ClusterSession(ClusterProvider clusterProvider, TransactionManager txnManager, IGarnetAuthenticator authenticator, UserHandle userHandle, GarnetSessionMetrics sessionMetrics, BasicGarnetApi basicGarnetApi, INetworkSender networkSender, ILogger logger = null) + private VectorContext vectorContext; + private BasicContext basicContext; + + public ClusterSession( + ClusterProvider clusterProvider, + TransactionManager txnManager, + IGarnetAuthenticator 
authenticator, + UserHandle userHandle, + GarnetSessionMetrics sessionMetrics, + BasicGarnetApi basicGarnetApi, + BasicContext basicContext, + VectorContext vectorContext, + INetworkSender networkSender, + ILogger logger = null) { this.clusterProvider = clusterProvider; this.authenticator = authenticator; @@ -65,11 +87,13 @@ public ClusterSession(ClusterProvider clusterProvider, TransactionManager txnMan this.txnManager = txnManager; this.sessionMetrics = sessionMetrics; this.basicGarnetApi = basicGarnetApi; + this.basicContext = basicContext; + this.vectorContext = vectorContext; this.networkSender = networkSender; this.logger = logger; } - public void ProcessClusterCommands(RespCommand command, ref SessionParseState parseState, ref byte* dcurr, ref byte* dend) + public void ProcessClusterCommands(RespCommand command, VectorManager vectorManager, ref SessionParseState parseState, ref byte* dcurr, ref byte* dend) { this.dcurr = dcurr; this.dend = dend; @@ -89,7 +113,7 @@ public void ProcessClusterCommands(RespCommand command, ref SessionParseState pa return; } - ProcessClusterCommands(command, out invalidParameters); + ProcessClusterCommands(command, vectorManager, out invalidParameters); } else { diff --git a/libs/cluster/Session/MigrateCommand.cs b/libs/cluster/Session/MigrateCommand.cs index 897ca187e02..f884d4d27b1 100644 --- a/libs/cluster/Session/MigrateCommand.cs +++ b/libs/cluster/Session/MigrateCommand.cs @@ -13,7 +13,7 @@ namespace Garnet.cluster { internal sealed unsafe partial class ClusterSession : IClusterSession { - public static bool Expired(ref SpanByte value) => value.MetadataSize > 0 && value.ExtraMetadata < DateTimeOffset.UtcNow.Ticks; + public static bool Expired(ref SpanByte value) => value.MetadataSize == 8 && value.ExtraMetadata < DateTimeOffset.UtcNow.Ticks; public static bool Expired(ref IGarnetObject value) => value.Expiration != 0 && value.Expiration < DateTimeOffset.UtcNow.Ticks; diff --git a/libs/cluster/Session/RespClusterMigrateCommands.cs b/libs/cluster/Session/RespClusterMigrateCommands.cs index 3dd58cf82a1..5fe9c8d1c4c 100644 --- a/libs/cluster/Session/RespClusterMigrateCommands.cs +++ b/libs/cluster/Session/RespClusterMigrateCommands.cs @@ -17,7 +17,10 @@ namespace Garnet.cluster SpanByteAllocator>>, BasicContext>, - GenericAllocator>>>>; + GenericAllocator>>>, + BasicContext, + SpanByteAllocator>>>; internal sealed unsafe partial class ClusterSession : IClusterSession { @@ -103,18 +106,30 @@ void Process(BasicGarnetApi basicGarnetApi, byte[] input, string storeTypeSpan, continue; } - var slot = HashSlotUtils.HashSlot(ref key); - if (!currentConfig.IsImportingSlot(slot)) // Slot is not in importing state + // TODO: better way to handle namespaces + if (key.MetadataSize == 1) { - migrateState = 1; - i++; - continue; + // This is a Vector Set namespace key being migrated - it won't necessarily look like it's "in" a hash slot + // because it's dependent on some other key (the index key) being migrated which itself is in a moving hash slot + + clusterProvider.storeWrapper.DefaultDatabase.VectorManager.HandleMigratedElementKey(ref basicContext, ref vectorContext, ref key, ref value); + } + else + { + var slot = HashSlotUtils.HashSlot(ref key); + if (!currentConfig.IsImportingSlot(slot)) // Slot is not in importing state + { + migrateState = 1; + i++; + continue; + } + + // Set if key replace flag is set or key does not exist + var keySlice = new ArgSlice(key.ToPointer(), key.Length); + if (replaceOption || !Exists(ref keySlice)) + _ = basicGarnetApi.SET(ref 
key, ref value); } - // Set if key replace flag is set or key does not exist - var keySlice = new ArgSlice(key.ToPointer(), key.Length); - if (replaceOption || !Exists(ref keySlice)) - _ = basicGarnetApi.SET(ref key, ref value); i++; } } @@ -150,6 +165,35 @@ void Process(BasicGarnetApi basicGarnetApi, byte[] input, string storeTypeSpan, i++; } } + else if (storeTypeSpan.Equals("VSTORE", StringComparison.OrdinalIgnoreCase)) + { + // This is the subset of the main store that holds Vector Set _index_ keys + // + // Namespace'd element keys are handled by the SSTORE path + + var keyCount = *(int*)payloadPtr; + payloadPtr += 4; + var i = 0; + + TrackImportProgress(keyCount, isMainStore: true, keyCount == 0); + while (i < keyCount) + { + ref var key = ref SpanByte.Reinterpret(payloadPtr); + payloadPtr += key.TotalSize; + ref var value = ref SpanByte.Reinterpret(payloadPtr); + payloadPtr += value.TotalSize; + + // An error has occurred + if (migrateState > 0) + { + i++; + continue; + } + + clusterProvider.storeWrapper.DefaultDatabase.VectorManager.HandleMigratedIndexKey(clusterProvider.storeWrapper.DefaultDatabase, clusterProvider.storeWrapper, ref key, ref value); + i++; + } + } else { throw new Exception("CLUSTER MIGRATE STORE TYPE ERROR!"); }
diff --git a/libs/cluster/Session/RespClusterReplicationCommands.cs index b30a3ff00f4..8bc596c8e54 100644 --- a/libs/cluster/Session/RespClusterReplicationCommands.cs +++ b/libs/cluster/Session/RespClusterReplicationCommands.cs @@ -115,6 +115,59 @@ private bool NetworkClusterReplicate(out bool invalidParameters) return true; } + /// + /// Implements CLUSTER reserve command (only for internode use). + /// + /// Allows for pre-migration reservation of certain resources. + /// + /// For now, this is only used for Vector Sets. + /// + private bool NetworkClusterReserve(VectorManager vectorManager, out bool invalidParameters) + { + if (parseState.Count < 2) + { + invalidParameters = true; + return true; + } + + var kind = parseState.GetArgSliceByRef(0); + if (!kind.ReadOnlySpan.EqualsUpperCaseSpanIgnoringCase("VECTOR_SET_CONTEXTS"u8)) + { + while (!RespWriteUtils.TryWriteError("Unrecognized reservation type"u8, ref dcurr, dend)) + SendAndReset(); + + invalidParameters = false; + return true; + } + + if (!parseState.TryGetInt(1, out var numVectorSetContexts) || numVectorSetContexts <= 0) + { + invalidParameters = true; + return true; + } + + invalidParameters = false; + + if (!vectorManager.TryReserveContextsForMigration(ref vectorContext, numVectorSetContexts, out var newContexts)) + { + while (!RespWriteUtils.TryWriteError("Insufficient contexts available to reserve"u8, ref dcurr, dend)) + SendAndReset(); + + return true; + } + + while (!RespWriteUtils.TryWriteArrayLength(newContexts.Count, ref dcurr, dend)) + SendAndReset(); + + foreach (var ctx in newContexts) + { + while (!RespWriteUtils.TryWriteInt64AsSimpleString((long)ctx, ref dcurr, dend)) + SendAndReset(); + } + + return true; + } + /// /// Implements CLUSTER aofsync command (only for internode use) ///
diff --git a/libs/cluster/Session/SlotVerification/ClusterSlotVerify.cs index 0416c064d43..b2bdc1fba17 100644 --- a/libs/cluster/Session/SlotVerification/ClusterSlotVerify.cs +++ b/libs/cluster/Session/SlotVerification/ClusterSlotVerify.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. using System; +using System.Diagnostics; using System.Runtime.CompilerServices; using System.Threading; using Garnet.server; @@ -23,9 +24,13 @@ private bool CheckIfKeyExists(byte[] key) } } - private ClusterSlotVerificationResult SingleKeySlotVerify(ref ClusterConfig config, ref ArgSlice keySlice, bool readOnly, byte SessionAsking, int slot = -1) + private ClusterSlotVerificationResult SingleKeySlotVerify(ref ClusterConfig config, ref ArgSlice keySlice, bool readOnly, byte SessionAsking, bool isVectorSetWriteCommand, int slot = -1) { - return readOnly ? SingleKeyReadSlotVerify(ref config, ref keySlice) : SingleKeyReadWriteSlotVerify(ref config, ref keySlice); + Debug.Assert(!isVectorSetWriteCommand || (isVectorSetWriteCommand && !readOnly), "Shouldn't see Vector Set writes and readonly at same time"); + + var ret = readOnly ? SingleKeyReadSlotVerify(ref config, ref keySlice) : SingleKeyReadWriteSlotVerify(isVectorSetWriteCommand, ref config, ref keySlice); + + return ret; [MethodImpl(MethodImplOptions.AggressiveInlining)] ClusterSlotVerificationResult SingleKeyReadSlotVerify(ref ClusterConfig config, ref ArgSlice keySlice) @@ -69,12 +74,20 @@ ClusterSlotVerificationResult SingleKeyReadSlotVerify(ref ClusterConfig config, } [MethodImpl(MethodImplOptions.AggressiveInlining)] - ClusterSlotVerificationResult SingleKeyReadWriteSlotVerify(ref ClusterConfig config, ref ArgSlice keySlice) + ClusterSlotVerificationResult SingleKeyReadWriteSlotVerify(bool isVectorSetWriteCommand, ref ClusterConfig config, ref ArgSlice keySlice) { var _slot = slot == -1 ? ArgSliceUtils.HashSlot(ref keySlice) : (ushort)slot; + + tryAgain: var IsLocal = config.IsLocal(_slot, readWriteSession: readWriteSession); var state = config.GetState(_slot); + if (isVectorSetWriteCommand && state is SlotState.IMPORTING or SlotState.MIGRATING) + { + WaitForSlotToStabilize(_slot, ref keySlice, ref config); + goto tryAgain; + } + // Redirect r/w requests towards primary if (config.LocalNodeRole == NodeRole.REPLICA && !readWriteSession) return new(SlotVerifiedState.MOVED, _slot); @@ -123,18 +136,35 @@ bool CanOperateOnKey(ref ArgSlice key, int slot, bool readOnly) } return Exists(ref key); } + + void WaitForSlotToStabilize(ushort slot, ref ArgSlice keySlice, ref ClusterConfig config) + { + // For Vector Set ops specifically, we need a slot to be stable (or faulted, but not migrating) before writes can proceed + // + // This isn't key specific because we can't know the Vector Sets being migrated in advance, only that the slot is moving + + do + { + ReleaseCurrentEpoch(); + _ = Thread.Yield(); + AcquireCurrentEpoch(); + + config = clusterProvider.clusterManager.CurrentConfig; + } + while (config.GetState(slot) is SlotState.IMPORTING or SlotState.MIGRATING); + } } - ClusterSlotVerificationResult MultiKeySlotVerify(ClusterConfig config, ref Span keys, bool readOnly, byte sessionAsking, int count) + ClusterSlotVerificationResult MultiKeySlotVerify(ClusterConfig config, ref Span keys, bool readOnly, byte sessionAsking, bool isVectorSetWriteCommand, int count) { var _end = count < 0 ? 
keys.Length : count; var slot = ArgSliceUtils.HashSlot(ref keys[0]); - var verifyResult = SingleKeySlotVerify(ref config, ref keys[0], readOnly, sessionAsking, slot); + var verifyResult = SingleKeySlotVerify(ref config, ref keys[0], readOnly, sessionAsking, isVectorSetWriteCommand, slot); for (var i = 1; i < _end; i++) { var _slot = ArgSliceUtils.HashSlot(ref keys[i]); - var _verifyResult = SingleKeySlotVerify(ref config, ref keys[i], readOnly, sessionAsking, _slot); + var _verifyResult = SingleKeySlotVerify(ref config, ref keys[i], readOnly, sessionAsking, isVectorSetWriteCommand, _slot); // Check if slot changes between keys if (_slot != slot) @@ -152,7 +182,7 @@ ClusterSlotVerificationResult MultiKeySlotVerify(ClusterConfig config, ref Sessi { ref var key = ref parseState.GetArgSliceByRef(csvi.firstKey); var slot = ArgSliceUtils.HashSlot(ref key); - var verifyResult = SingleKeySlotVerify(ref config, ref key, csvi.readOnly, csvi.sessionAsking, slot); + var verifyResult = SingleKeySlotVerify(ref config, ref key, csvi.readOnly, csvi.sessionAsking, csvi.isVectorSetWriteCommand, slot); var secondKey = csvi.firstKey + csvi.step; for (var i = secondKey; i < csvi.lastKey; i += csvi.step) @@ -161,7 +191,7 @@ ClusterSlotVerificationResult MultiKeySlotVerify(ClusterConfig config, ref Sessi continue; key = ref parseState.GetArgSliceByRef(i); var _slot = ArgSliceUtils.HashSlot(ref key); - var _verifyResult = SingleKeySlotVerify(ref config, ref key, csvi.readOnly, csvi.sessionAsking, _slot); + var _verifyResult = SingleKeySlotVerify(ref config, ref key, csvi.readOnly, csvi.sessionAsking, csvi.isVectorSetWriteCommand, _slot); // Check if slot changes between keys if (_slot != slot) diff --git a/libs/cluster/Session/SlotVerification/RespClusterIterativeSlotVerify.cs b/libs/cluster/Session/SlotVerification/RespClusterIterativeSlotVerify.cs index 3fe36867e9c..7bae12b778c 100644 --- a/libs/cluster/Session/SlotVerification/RespClusterIterativeSlotVerify.cs +++ b/libs/cluster/Session/SlotVerification/RespClusterIterativeSlotVerify.cs @@ -28,14 +28,14 @@ public void ResetCachedSlotVerificationResult() /// /// /// - public bool NetworkIterativeSlotVerify(ArgSlice keySlice, bool readOnly, byte SessionAsking) + public bool NetworkIterativeSlotVerify(ArgSlice keySlice, bool readOnly, byte SessionAsking, bool isVectorSetWriteCommand) { ClusterSlotVerificationResult verifyResult; // If it is the first verification initialize the result cache if (!initialized) { - verifyResult = SingleKeySlotVerify(ref configSnapshot, ref keySlice, readOnly, SessionAsking); + verifyResult = SingleKeySlotVerify(ref configSnapshot, ref keySlice, readOnly, SessionAsking, isVectorSetWriteCommand); cachedVerificationResult = verifyResult; initialized = true; return verifyResult.state == SlotVerifiedState.OK; @@ -45,7 +45,7 @@ public bool NetworkIterativeSlotVerify(ArgSlice keySlice, bool readOnly, byte Se if (cachedVerificationResult.state != SlotVerifiedState.OK) return false; - verifyResult = SingleKeySlotVerify(ref configSnapshot, ref keySlice, readOnly, SessionAsking); + verifyResult = SingleKeySlotVerify(ref configSnapshot, ref keySlice, readOnly, SessionAsking, isVectorSetWriteCommand); // Check if slot changes between keys if (verifyResult.slot != cachedVerificationResult.slot) diff --git a/libs/cluster/Session/SlotVerification/RespClusterSlotVerify.cs b/libs/cluster/Session/SlotVerification/RespClusterSlotVerify.cs index af69ed8d2b2..d61822b1f4a 100644 --- a/libs/cluster/Session/SlotVerification/RespClusterSlotVerify.cs +++ 
b/libs/cluster/Session/SlotVerification/RespClusterSlotVerify.cs @@ -95,13 +95,13 @@ private void WriteClusterSlotVerificationMessage(ClusterConfig config, ClusterSl /// /// /// - public bool NetworkKeyArraySlotVerify(Span keys, bool readOnly, byte sessionAsking, ref byte* dcurr, ref byte* dend, int count = -1) + public bool NetworkKeyArraySlotVerify(Span keys, bool readOnly, byte sessionAsking, bool isVectorSetWriteCommand, ref byte* dcurr, ref byte* dend, int count = -1) { // If cluster is not enabled or a transaction is running skip slot check if (!clusterProvider.serverOptions.EnableCluster || txnManager.state == TxnState.Running) return false; var config = clusterProvider.clusterManager.CurrentConfig; - var vres = MultiKeySlotVerify(config, ref keys, readOnly, sessionAsking, count); + var vres = MultiKeySlotVerify(config, ref keys, readOnly, sessionAsking, isVectorSetWriteCommand, count); if (vres.state == SlotVerifiedState.OK) return false;
diff --git a/libs/common/HashSlotUtils.cs index f1811ce3a7e..67fbc4d29fd 100644 --- a/libs/common/HashSlotUtils.cs +++ b/libs/common/HashSlotUtils.cs @@ -10,6 +10,8 @@ namespace Garnet.common { public static unsafe class HashSlotUtils { + public const ushort MaxHashSlot = 16_383; + /// /// This table is based on the CRC-16-CCITT polynomial (0x1021) /// @@ -101,14 +103,14 @@ public static unsafe ushort HashSlot(byte* keyPtr, int ksize) var startPtr = keyPtr; var end = keyPtr + ksize; - // Find first occurence of '{' + // Find first occurrence of '{' while (startPtr < end && *startPtr != '{') { startPtr++; } // Return early if did not find '{' - if (startPtr == end) return (ushort)(Hash(keyPtr, ksize) & 16383); + if (startPtr == end) return (ushort)(Hash(keyPtr, ksize) & MaxHashSlot); var endPtr = startPtr + 1; @@ -116,10 +118,10 @@ while (endPtr < end && *endPtr != '}') { endPtr++; } // Return early if did not find '}' after '{' - if (endPtr == end || endPtr == startPtr + 1) return (ushort)(Hash(keyPtr, ksize) & 16383); + if (endPtr == end || endPtr == startPtr + 1) return (ushort)(Hash(keyPtr, ksize) & MaxHashSlot); // Return hash for byte sequence between brackets - return (ushort)(Hash(startPtr + 1, (int)(endPtr - startPtr - 1)) & 16383); + return (ushort)(Hash(startPtr + 1, (int)(endPtr - startPtr - 1)) & MaxHashSlot); } } } \ No newline at end of file
diff --git a/libs/common/ReadOptimizedLock.cs new file mode 100644 index 00000000000..2d3e76b8861 --- /dev/null +++ b/libs/common/ReadOptimizedLock.cs @@ -0,0 +1,408 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System; +using System.Diagnostics; +using System.Numerics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Threading; + +namespace Garnet.common { + /// + /// Holds a set of RW-esque locks, optimized for reads. + /// + /// This was originally created for Vector Sets, but is general enough for reuse. + /// For Vector Sets, these are acquired and released as needed to prevent concurrent creation/deletion operations or deletion concurrent with read operations. + /// + /// These are outside of Tsavorite for re-entrancy reasons. + /// + /// + /// This is a counter-based r/w lock scheme, with a bit of biasing for cache line awareness. + /// + /// Each "key" acquires locks based on its hash. + /// Each hash is mapped to a range of indexes, each range is lockShardCount in length. + /// When acquiring a shared lock, we take one index out of the key's range and acquire a read lock. + /// This will block exclusive locks, but not impact other readers. + /// When acquiring an exclusive lock, we acquire write locks for all indexes in the key's range IN INCREASING _LOGICAL_ ORDER. + /// The order is necessary to avoid deadlocks. + /// By ensuring all exclusive locks walk "up" we guarantee no two exclusive lock acquisitions end up waiting for each other. + /// + /// Locks themselves are just ints, where a negative value indicates an exclusive lock and a positive value is the number of active readers. + /// Read locks are acquired optimistically, so actual lock values will fluctuate above int.MinValue when an exclusive lock is held. + /// + /// The last set of optimizations is around cache-line coherency: + /// We assume cache lines of 64 bytes (the x86 default, which is also true for some [but not all] ARM processors) and size counters-per-core in multiples of that + /// We access array elements via reference, to avoid thrashing cache lines due to length checks + /// Each shard is placed, in so much as is possible, into a different cache line rather than grouping a hash's counts physically near each other + /// This will tend to allow a core to retain ownership of the same cache lines even as it moves between different hashes + /// + /// Experimentally (using some rough microbenchmarks) various optimizations are worth (on either shared or exclusive acquisition paths): + /// - Split shards across cache lines : 7x (read path), 2.5x (write path) + /// - Fast math instead of mod and mult : 50% (read path), 20% (write path) + /// - Unsafe ref instead of array access: 0% (read path), 10% (write path) + /// + public struct ReadOptimizedLock + { + // Beyond 4K bytes per core we're well past "this is worth the tradeoff", so cut off then. + // + // Must be a power of 2. + private const int MaxPerCoreContexts = 1_024; + + /// + /// Estimated size of cache lines on a processor. + /// + /// Generally correct for x86-derived processors, sometimes correct for ARM-derived ones. + /// + public const int CacheLineSizeBytes = 64; + + [ThreadStatic] + private static int ProcessorHint; + + private readonly int[] lockCounts; + private readonly int coreSelectionMask; + private readonly int perCoreCounts; + private readonly ulong perCoreCountsFastMod; + private readonly byte perCoreCountsMultShift; + + /// + /// Create a new <see cref="ReadOptimizedLock"/>. + /// + /// <paramref name="estimatedSimultaneousActiveLockers"/> accuracy impacts performance, not correctness. + /// + /// Too low and unrelated locks will end up delaying each other. + /// Too high and more memory than is necessary will be used. + /// + public ReadOptimizedLock(int estimatedSimultaneousActiveLockers) + { + Debug.Assert(estimatedSimultaneousActiveLockers > 0); + + // ~1 per core + var coreCount = (int)BitOperations.RoundUpToPowerOf2((uint)Environment.ProcessorCount); + coreSelectionMask = coreCount - 1; + + // Use estimatedSimultaneousActiveLockers to determine number of shards per lock. + // + // We scale up to a whole multiple of CacheLineSizeBytes to reduce cache line thrashing. + // + // We scale to a power of 2 to avoid divisions (and some multiplies) in index calculation. + perCoreCounts = estimatedSimultaneousActiveLockers; + if (perCoreCounts % (CacheLineSizeBytes / sizeof(int)) != 0) + { + perCoreCounts += (CacheLineSizeBytes / sizeof(int)) - (perCoreCounts % (CacheLineSizeBytes / sizeof(int))); + } + Debug.Assert(perCoreCounts % (CacheLineSizeBytes / sizeof(int)) == 0, "Each core should be whole cache lines of data"); + + perCoreCounts = (int)BitOperations.RoundUpToPowerOf2((uint)perCoreCounts); + + // Put an upper bound of ~1 page worth of locks per core (which is still quite high). + // + // For the largest realistic machines out there (384 cores) this will put us at around ~2M of lock data, max. + if (perCoreCounts is <= 0 or > MaxPerCoreContexts) + { + perCoreCounts = MaxPerCoreContexts; + } + + // Pre-calculate an alternative to %, as that division will be in the hot path + perCoreCountsFastMod = (ulong.MaxValue / (uint)perCoreCounts) + 1; + + // Avoid two multiplies in the hot path + perCoreCountsMultShift = (byte)BitOperations.Log2((uint)perCoreCounts); + + var numInts = coreCount * perCoreCounts; + lockCounts = new int[numInts]; + } + + /// + /// Take a hash and a _hint_ about the current processor and determine which count should be used. + /// + /// Walking <paramref name="currentProcessorHint"/> from 0 to (<see cref="coreSelectionMask"/> + 1) [exclusive] will return + /// all possible counts for a given hash. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public readonly int CalculateIndex(long hashLong, int currentProcessorHint) + { + // Throw away the top half of the hash + // + // This set of locks will be small enough that the extra bits shouldn't matter + var hash = (int)hashLong; + + // Hint might be out of range, so force it into the space we expect + var currentProcessor = currentProcessorHint & coreSelectionMask; + + var startOfCoreCounts = currentProcessor << perCoreCountsMultShift; + + // Avoid doing a division in the hot path + // Based on: https://github.com/dotnet/runtime/blob/3a95842304008b9ca84c14b4bec9ec99ed5802db/src/libraries/System.Private.CoreLib/src/System/Collections/HashHelpers.cs#L99 + var hashOffset = (uint)(((((perCoreCountsFastMod * (uint)hash) >> 32) + 1) << perCoreCountsMultShift) >> 32); + + Debug.Assert(hashOffset == ((uint)hash % perCoreCounts), "Replacing mod with multiplies failed"); + + var ix = (int)(startOfCoreCounts + hashOffset); + + Debug.Assert(ix >= 0 && ix < lockCounts.Length, "About to do something out of bounds"); + + return ix; + } + + /// + /// Attempt to acquire a shared lock for the given hash. + /// + /// Will block exclusive locks until released. + /// + public readonly bool TryAcquireSharedLock(long hash, out int lockToken) + { + var ix = CalculateIndex(hash, GetProcessorHint()); + + ref var acquireRef = ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(lockCounts), ix); + + var res = Interlocked.Increment(ref acquireRef); + if (res < 0) + { + // Exclusively locked + _ = Interlocked.Decrement(ref acquireRef); + Unsafe.SkipInit(out lockToken); + return false; + } + + lockToken = ix; + return true; + } + + /// + /// Acquire a shared lock for the given hash, blocking until that succeeds. + /// + /// Will block exclusive locks until released.
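
The constructor above precomputes perCoreCountsFastMod and perCoreCountsMultShift so CalculateIndex can replace % (and one multiply) with multiplies and shifts; the Debug.Assert and the linked HashHelpers.cs indicate this is .NET's FastMod trick specialized to a power-of-two divisor. A standalone check of the identity (the divisor value is arbitrary here; it must be a power of two, as the constructor guarantees):

using System;
using System.Numerics;

class FastModCheck
{
    static void Main()
    {
        const uint divisor = 8;                          // stands in for perCoreCounts
        var fastMod = (ulong.MaxValue / divisor) + 1;    // perCoreCountsFastMod
        var shift = BitOperations.Log2(divisor);         // perCoreCountsMultShift

        var rng = new Random(42);
        for (var i = 0; i < 1_000_000; i++)
        {
            var hash = (uint)rng.Next();
            var fast = (uint)(((((fastMod * hash) >> 32) + 1) << shift) >> 32);
            if (fast != hash % divisor)
                throw new Exception($"mismatch for {hash}");
        }

        Console.WriteLine("fast-mod path agrees with % on all sampled hashes");
    }
}
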
+ /// + public readonly void AcquireSharedLock(long hash, out int lockToken) + { + var ix = CalculateIndex(hash, GetProcessorHint()); + + ref var acquireRef = ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(lockCounts), ix); + + while (true) + { + var res = Interlocked.Increment(ref acquireRef); + if (res < 0) + { + // Exclusively locked + _ = Interlocked.Decrement(ref acquireRef); + + // Spin until we can grab this one + _ = Thread.Yield(); + } + else + { + lockToken = ix; + return; + } + } + } + + /// + /// Release a lock previously acquired with or . + /// + public readonly void ReleaseSharedLock(int lockToken) + { + Debug.Assert(lockToken >= 0 && lockToken < lockCounts.Length, "Invalid lock token"); + + ref var releaseRef = ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(lockCounts), lockToken); + + _ = Interlocked.Decrement(ref releaseRef); + } + + /// + /// Attempt to acquire an exclusive lock for the given hash. + /// + /// Will block all other locks until released. + /// + public readonly bool TryAcquireExclusiveLock(long hash, out int lockToken) + { + ref var countRef = ref MemoryMarshal.GetArrayDataReference(lockCounts); + + var coreCount = coreSelectionMask + 1; + for (var i = 0; i < coreCount; i++) + { + var acquireIx = CalculateIndex(hash, i); + ref var acquireRef = ref Unsafe.Add(ref countRef, acquireIx); + + if (Interlocked.CompareExchange(ref acquireRef, int.MinValue, 0) != 0) + { + // Failed, release previously acquired + for (var j = 0; j < i; j++) + { + var releaseIx = CalculateIndex(hash, j); + + ref var releaseRef = ref Unsafe.Add(ref countRef, releaseIx); + while (Interlocked.CompareExchange(ref releaseRef, 0, int.MinValue) != int.MinValue) + { + // Optimistic shared lock got us, back off and try again + _ = Thread.Yield(); + } + } + + Unsafe.SkipInit(out lockToken); + return false; + } + } + + // Successfully acquired all shards exclusively + + // Throwing away half the hash shouldn't affect correctness since we do the same thing when processing the full hash + lockToken = (int)hash; + + return true; + } + + /// + /// Acquire an exclusive lock for the given hash, blocking until that succeeds. + /// + /// Will block all other locks until released. + /// + public readonly void AcquireExclusiveLock(long hash, out int lockToken) + { + ref var countRef = ref MemoryMarshal.GetArrayDataReference(lockCounts); + + var coreCount = coreSelectionMask + 1; + for (var i = 0; i < coreCount; i++) + { + var acquireIx = CalculateIndex(hash, i); + + ref var acquireRef = ref Unsafe.Add(ref countRef, acquireIx); + while (Interlocked.CompareExchange(ref acquireRef, int.MinValue, 0) != 0) + { + // Optimistic shared lock got us, or conflict with some other excluive lock acquisition + // + // Backoff and try again + _ = Thread.Yield(); + } + } + + // Throwing away half the hash shouldn't affect correctness since we do the same thing when processing the full hash + lockToken = (int)hash; + } + + /// + /// Release a lock previously acquired with , , or . 
+ + /// + /// Release a lock previously acquired with , , or . + /// + public readonly void ReleaseExclusiveLock(int lockToken) + { + // The lockToken is a hash, so no range check here + + ref var countRef = ref MemoryMarshal.GetArrayDataReference(lockCounts); + + var hash = lockToken; + + var coreCount = coreSelectionMask + 1; + for (var i = 0; i < coreCount; i++) + { + var releaseIx = CalculateIndex(hash, i); + + ref var releaseRef = ref Unsafe.Add(ref countRef, releaseIx); + while (Interlocked.CompareExchange(ref releaseRef, 0, int.MinValue) != int.MinValue) + { + // Optimistic shared lock got us, back off and try again + _ = Thread.Yield(); + } + } + } + + /// + /// Attempt to promote a shared lock previously acquired via or to an exclusive lock. + /// + /// If successful, will block all other locks until released. + /// + /// If successful, must be released with . + /// + /// If unsuccessful, shared lock will still be held and must be released with . + /// + public readonly bool TryPromoteSharedLock(long hash, int lockToken, out int newLockToken) + { + Debug.Assert(Interlocked.CompareExchange(ref lockCounts[lockToken], 0, 0) > 0, "Illegal call when not holding shared lock"); + + Debug.Assert(lockToken >= 0 && lockToken < lockCounts.Length, "Invalid lock token"); + + ref var countRef = ref MemoryMarshal.GetArrayDataReference(lockCounts); + + var coreCount = coreSelectionMask + 1; + for (var i = 0; i < coreCount; i++) + { + var acquireIx = CalculateIndex(hash, i); + ref var acquireRef = ref Unsafe.Add(ref countRef, acquireIx); + + if (acquireIx == lockToken) + { + // Do the promote + if (Interlocked.CompareExchange(ref acquireRef, int.MinValue, 1) != 1) + { + // Failed, release previously acquired locks, all of which are exclusive + for (var j = 0; j < i; j++) + { + var releaseIx = CalculateIndex(hash, j); + + ref var releaseRef = ref Unsafe.Add(ref countRef, releaseIx); + while (Interlocked.CompareExchange(ref releaseRef, 0, int.MinValue) != int.MinValue) + { + // Optimistic shared lock got us, back off and try again + _ = Thread.Yield(); + } + } + + // Note we're still holding the shared lock here + Unsafe.SkipInit(out newLockToken); + return false; + } + } + else + { + // Otherwise attempt an exclusive acquire + if (Interlocked.CompareExchange(ref acquireRef, int.MinValue, 0) != 0) + { + // Failed, release previously acquired locks - one of which MIGHT be the shared lock + for (var j = 0; j < i; j++) + { + var releaseIx = CalculateIndex(hash, j); + var releaseTargetValue = releaseIx == lockToken ? 1 : 0; + + ref var releaseRef = ref Unsafe.Add(ref countRef, releaseIx); + while (Interlocked.CompareExchange(ref releaseRef, releaseTargetValue, int.MinValue) != int.MinValue) + { + // Optimistic shared lock got us, back off and try again + _ = Thread.Yield(); + } + } + + // Note we're still holding the shared lock here + Unsafe.SkipInit(out newLockToken); + return false; + } + } + } + + // Throwing away half the hash shouldn't affect correctness since we do the same thing when processing the full hash + newLockToken = (int)hash; + return true; + }
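Promotion reuses the same sentinel, with one twist: on the shard where the caller already holds its shared lock, the CAS expects a count of exactly 1 (the caller itself) rather than 0. A single-slot sketch of just that step, again purely illustrative:

using System;
using System.Threading;

class PromotableSlot
{
    int count; // >= 0 shared holders, int.MinValue when exclusive

    public bool TryEnterShared()
    {
        if (Interlocked.Increment(ref count) < 0)
        {
            _ = Interlocked.Decrement(ref count);
            return false;
        }
        return true;
    }

    // Caller must hold exactly one shared reference on this slot.
    // Succeeds only if the caller is the sole reader (count == 1);
    // on failure the shared reference is still held, as in the patch.
    public bool TryPromote()
        => Interlocked.CompareExchange(ref count, int.MinValue, 1) == 1;

    public void ExitExclusive()
    {
        while (Interlocked.CompareExchange(ref count, 0, int.MinValue) != int.MinValue)
            _ = Thread.Yield();
    }

    public void ExitShared() => _ = Interlocked.Decrement(ref count);

    static void Main()
    {
        var slot = new PromotableSlot();
        _ = slot.TryEnterShared();            // reader in
        Console.WriteLine(slot.TryPromote()); // True: sole reader promotes in place
        slot.ExitExclusive();

        _ = slot.TryEnterShared();
        _ = slot.TryEnterShared();            // a second reader blocks promotion
        Console.WriteLine(slot.TryPromote()); // False: shared lock still held
        slot.ExitShared();
        slot.ExitShared();
    }
}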
+ + /// + /// Get a somewhat-correlated-to-processor value. + /// + /// While we could use , that isn't fast on all platforms. + /// + /// For our purposes, we just need something that will tend to keep different active processors + /// from touching each other. ManagedThreadId works well enough. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static int GetProcessorHint() + { + var ret = ProcessorHint; + if (ret == 0) + { + ProcessorHint = ret = Environment.CurrentManagedThreadId; + } + + return ret; + } + } +} \ No newline at end of file diff --git a/libs/common/RespReadUtils.cs b/libs/common/RespReadUtils.cs index 92c41ec4739..1202e8c0e09 100644 --- a/libs/common/RespReadUtils.cs +++ b/libs/common/RespReadUtils.cs @@ -1341,5 +1341,40 @@ public static bool TryReadInfinity(ReadOnlySpan value, out double number) number = default; return false; } + + /// + /// Parses "[+/-]inf" string and returns float.PositiveInfinity/float.NegativeInfinity respectively. + /// If string is not an infinity, parsing fails. + /// + /// input data + /// If parsing was successful, contains positive or negative infinity + /// True if infinity was read, false otherwise + public static bool TryReadInfinity(ReadOnlySpan value, out float number) + { + if (value.Length == 3) + { + if (value.EqualsUpperCaseSpanIgnoringCase(RespStrings.INFINITY)) + { + number = float.PositiveInfinity; + return true; + } + } + else if (value.Length == 4) + { + if (value.EqualsUpperCaseSpanIgnoringCase(RespStrings.POS_INFINITY, true)) + { + number = float.PositiveInfinity; + return true; + } + else if (value.EqualsUpperCaseSpanIgnoringCase(RespStrings.NEG_INFINITY, true)) + { + number = float.NegativeInfinity; + return true; + } + } + + number = default; + return false; + } } } \ No newline at end of file diff --git a/libs/host/Configuration/Options.cs b/libs/host/Configuration/Options.cs index 2ba5b3302d4..ba87d1c59a9 100644 --- a/libs/host/Configuration/Options.cs +++ b/libs/host/Configuration/Options.cs @@ -666,6 +666,9 @@ public IEnumerable LuaAllowedFunctions [Option("cluster-replica-resume-with-data", Required = false, HelpText = "If a Cluster Replica resumes with data, allow it to be served prior to a Primary being available")] public bool ClusterReplicaResumeWithData { get; set; } + [Option("enable-vector-set-preview", Required = false, HelpText = "Enable Vector Sets (preview) - this feature (and associated commands) is incomplete, unstable, and subject to change while still in preview")] + public bool EnableVectorSetPreview { get; set; } + /// /// This property contains all arguments that were not parsed by the command line argument parser /// @@ -952,6 +955,7 @@ public GarnetServerOptions GetServerOptions(ILogger logger = null) ExpiredKeyDeletionScanFrequencySecs = ExpiredKeyDeletionScanFrequencySecs, ClusterReplicationReestablishmentTimeout = ClusterReplicationReestablishmentTimeout, ClusterReplicaResumeWithData = ClusterReplicaResumeWithData, + EnableVectorSetPreview = EnableVectorSetPreview, }; } diff --git a/libs/host/GarnetServer.cs b/libs/host/GarnetServer.cs index 2d12f43a0d4..05103d39106 100644 --- a/libs/host/GarnetServer.cs +++ b/libs/host/GarnetServer.cs @@ -303,9 +303,18 @@ private GarnetDatabase CreateDatabase(int dbId, GarnetServerOptions serverOptions var store = CreateMainStore(dbId, clusterFactory, out var epoch, out var stateMachineDriver); var objectStore = CreateObjectStore(dbId, clusterFactory, customCommandManager, epoch, stateMachineDriver, out var objectStoreSizeTracker); var (aofDevice, aof) = CreateAOF(dbId); + + var vectorManager = new VectorManager( + serverOptions.EnableVectorSetPreview, + dbId, + () => Provider.GetSession(WireFormat.ASCII, null), + loggerFactory + ); + return new GarnetDatabase(dbId, store, objectStore, epoch, stateMachineDriver, objectStoreSizeTracker,
aofDevice, aof, serverOptions.AdjustedIndexMaxCacheLines == 0, - serverOptions.AdjustedObjectStoreIndexMaxCacheLines == 0); + serverOptions.AdjustedObjectStoreIndexMaxCacheLines == 0, + vectorManager); } private void LoadModules(CustomCommandManager customCommandManager) diff --git a/libs/host/defaults.conf b/libs/host/defaults.conf index a14f15f5052..0b0d025144a 100644 --- a/libs/host/defaults.conf +++ b/libs/host/defaults.conf @@ -447,5 +447,8 @@ "ClusterReplicationReestablishmentTimeout": 0, /* If a Cluster Replica has on disk checkpoints or AOF, if that data should be loaded on restart instead of waiting for a Primary to sync with */ - "ClusterReplicaResumeWithData": false + "ClusterReplicaResumeWithData": false, + + /* Enable Vector Sets (preview) - this feature (and associated commands) is incomplete, unstable, and subject to change while still in preview */ + "EnableVectorSetPreview": false } \ No newline at end of file diff --git a/libs/resources/RespCommandsDocs.json b/libs/resources/RespCommandsDocs.json index be77703a3ed..049b46fd80f 100644 --- a/libs/resources/RespCommandsDocs.json +++ b/libs/resources/RespCommandsDocs.json @@ -7719,6 +7719,204 @@ "Group": "Transactions", "Complexity": "O(1)" }, + { + "Command": "VADD", + "Name": "VADD", + "Summary": "Add a new element into the vector set.", + "Group": "Vector", + "Complexity": "O(log(N))", + "Arguments": [ + { + "TypeDiscriminator": "RespCommandKeyArgument", + "Name": "KEY", + "DisplayText": "key", + "Type": "Key", + "KeySpecIndex": 0 + } + ] + }, + { + "Command": "VCARD", + "Name": "VCARD", + "Summary": "Return the number of elements in a vector set.", + "Group": "Vector", + "Complexity": "O(1)", + "Arguments": [ + { + "TypeDiscriminator": "RespCommandKeyArgument", + "Name": "KEY", + "DisplayText": "key", + "Type": "Key", + "KeySpecIndex": 0 + } + ] + }, + { + "Command": "VDIM", + "Name": "VDIM", + "Summary": "Return the number of dimensions in a vector set.", + "Group": "Vector", + "Complexity": "O(1)", + "Arguments": [ + { + "TypeDiscriminator": "RespCommandKeyArgument", + "Name": "KEY", + "DisplayText": "key", + "Type": "Key", + "KeySpecIndex": 0 + } + ] + }, + { + "Command": "VEMB", + "Name": "VEMB", + "Summary": "Return the approximate vector associated with an element in a vector set.", + "Group": "Vector", + "Complexity": "O(1)", + "Arguments": [ + { + "TypeDiscriminator": "RespCommandKeyArgument", + "Name": "KEY", + "DisplayText": "key", + "Type": "Key", + "KeySpecIndex": 0 + } + ] + }, + { + "Command": "VGETATTR", + "Name": "VGETATTR", + "Summary": "Return the JSON attributes associated with the element in the vector set.", + "Group": "Vector", + "Complexity": "O(1)", + "Arguments": [ + { + "TypeDiscriminator": "RespCommandKeyArgument", + "Name": "KEY", + "DisplayText": "key", + "Type": "Key", + "KeySpecIndex": 0 + } + ] + }, + { + "Command": "VINFO", + "Name": "VINFO", + "Summary": "Return details about a vector set, including dimensions, quantization, and structure.", + "Group": "Vector", + "Complexity": "O(1)", + "Arguments": [ + { + "TypeDiscriminator": "RespCommandKeyArgument", + "Name": "KEY", + "DisplayText": "key", + "Type": "Key", + "KeySpecIndex": 0 + } + ] + }, + { + "Command": "VISMEMBER", + "Name": "VISMEMBER", + "Summary": "Determines whether a member belongs to a vector set.", + "Group": "Vector", + "Complexity": "O(1)", + "Arguments": [ + { + "TypeDiscriminator": "RespCommandKeyArgument", + "Name": "KEY", + "DisplayText": "key", + "Type": "Key", + "KeySpecIndex": 0 + }, + { + "TypeDiscriminator":
"RespCommandBasicArgument", + "Name": "ELEMENT", + "DisplayText": "element", + "Type": "String" + } + ] + }, + { + "Command": "VLINKS", + "Name": "VLINKS", + "Summary": "Return the neighbors of an element in a vector set.", + "Group": "Vector", + "Complexity": "O(1)", + "Arguments": [ + { + "TypeDiscriminator": "RespCommandKeyArgument", + "Name": "KEY", + "DisplayText": "key", + "Type": "Key", + "KeySpecIndex": 0 + } + ] + }, + { + "Command": "VRANDMEMBER", + "Name": "VRANDMEMBER", + "Summary": "Return some number of random elements from a vector set.", + "Group": "Vector", + "Complexity": "O(1)", + "Arguments": [ + { + "TypeDiscriminator": "RespCommandKeyArgument", + "Name": "KEY", + "DisplayText": "key", + "Type": "Key", + "KeySpecIndex": 0 + } + ] + }, + { + "Command": "VREM", + "Name": "VREM", + "Summary": "Remove an element from a vector set.", + "Group": "Vector", + "Complexity": "O(log(N))", + "Arguments": [ + { + "TypeDiscriminator": "RespCommandKeyArgument", + "Name": "KEY", + "DisplayText": "key", + "Type": "Key", + "KeySpecIndex": 0 + } + ] + }, + { + "Command": "VSETATTR", + "Name": "VSETATTR", + "Summary": "Store attributes alongside a member of a vector set.", + "Group": "Vector", + "Complexity": "O(1)", + "Arguments": [ + { + "TypeDiscriminator": "RespCommandKeyArgument", + "Name": "KEY", + "DisplayText": "key", + "Type": "Key", + "KeySpecIndex": 0 + } + ] + }, + { + "Command": "VSIM", + "Name": "VSIM", + "Summary": "Return elements similar to a given vector or existing element of a vector set.", + "Group": "Vector", + "Complexity": "O(log(N))", + "Arguments": [ + { + "TypeDiscriminator": "RespCommandKeyArgument", + "Name": "KEY", + "DisplayText": "key", + "Type": "Key", + "KeySpecIndex": 0 + } + ] + }, { "Command": "WATCH", "Name": "WATCH", diff --git a/libs/resources/RespCommandsInfo.json b/libs/resources/RespCommandsInfo.json index daa1acd29d3..40aa9686505 100644 --- a/libs/resources/RespCommandsInfo.json +++ b/libs/resources/RespCommandsInfo.json @@ -811,6 +811,14 @@ "Flags": "Admin, NoMulti, NoScript", "AclCategories": "Admin, Dangerous, Slow, Garnet" }, + { + "Command": "CLUSTER_RESERVE", + "Name": "CLUSTER|RESERVE", + "IsInternal": true, + "Arity": 4, + "Flags": "Admin, NoMulti, NoScript", + "AclCategories": "Admin, Dangerous, Garnet" + }, { "Command": "CLUSTER_MTASKS", "Name": "CLUSTER|MTASKS", @@ -5093,6 +5101,306 @@ "Flags": "Fast, Loading, NoScript, Stale, AllowBusy", "AclCategories": "Fast, Transaction" }, + { + "Command": "VADD", + "Name": "VADD", + "Arity": -1, + "Flags": "DenyOom, Write, Module", + "FirstKey": 1, + "LastKey": 1, + "Step": 1, + "AclCategories": "Fast, Vector, Write", + "KeySpecifications": [ + { + "BeginSearch": { + "TypeDiscriminator": "BeginSearchIndex", + "Index": 1 + }, + "FindKeys": { + "TypeDiscriminator": "FindKeysRange", + "LastKey": 0, + "KeyStep": 1, + "Limit": 0 + }, + "Flags": "RW, Insert" + } + ] + }, + { + "Command": "VCARD", + "Name": "VCARD", + "Arity": -1, + "Flags": "Fast, ReadOnly, Module", + "FirstKey": 1, + "LastKey": 1, + "Step": 1, + "AclCategories": "Fast, Read, Vector", + "KeySpecifications": [ + { + "BeginSearch": { + "TypeDiscriminator": "BeginSearchIndex", + "Index": 1 + }, + "FindKeys": { + "TypeDiscriminator": "FindKeysRange", + "LastKey": 0, + "KeyStep": 1, + "Limit": 0 + }, + "Flags": "RO" + } + ] + }, + { + "Command": "VDIM", + "Name": "VDIM", + "Arity": -1, + "Flags": "Fast, ReadOnly, Module", + "FirstKey": 1, + "LastKey": 1, + "Step": 1, + "AclCategories": "Fast, Read, Vector", + "KeySpecifications": [ + { + 
"BeginSearch": { + "TypeDiscriminator": "BeginSearchIndex", + "Index": 1 + }, + "FindKeys": { + "TypeDiscriminator": "FindKeysRange", + "LastKey": 0, + "KeyStep": 1, + "Limit": 0 + }, + "Flags": "RO" + } + ] + }, + { + "Command": "VEMB", + "Name": "VEMB", + "Arity": -1, + "Flags": "Fast, ReadOnly, Module", + "FirstKey": 1, + "LastKey": 1, + "Step": 1, + "AclCategories": "Fast, Read, Vector", + "KeySpecifications": [ + { + "BeginSearch": { + "TypeDiscriminator": "BeginSearchIndex", + "Index": 1 + }, + "FindKeys": { + "TypeDiscriminator": "FindKeysRange", + "LastKey": 0, + "KeyStep": 1, + "Limit": 0 + }, + "Flags": "RO" + } + ] + }, + { + "Command": "VGETATTR", + "Name": "VGETATTR", + "Arity": -1, + "Flags": "Fast, ReadOnly, Module", + "FirstKey": 1, + "LastKey": 1, + "Step": 1, + "AclCategories": "Fast, Read, Vector", + "KeySpecifications": [ + { + "BeginSearch": { + "TypeDiscriminator": "BeginSearchIndex", + "Index": 1 + }, + "FindKeys": { + "TypeDiscriminator": "FindKeysRange", + "LastKey": 0, + "KeyStep": 1, + "Limit": 0 + }, + "Flags": "RO" + } + ] + }, + { + "Command": "VINFO", + "Name": "VINFO", + "Arity": -1, + "Flags": "Fast, ReadOnly, Module", + "FirstKey": 1, + "LastKey": 1, + "Step": 1, + "AclCategories": "Fast, Read, Vector", + "KeySpecifications": [ + { + "BeginSearch": { + "TypeDiscriminator": "BeginSearchIndex", + "Index": 1 + }, + "FindKeys": { + "TypeDiscriminator": "FindKeysRange", + "LastKey": 0, + "KeyStep": 1, + "Limit": 0 + }, + "Flags": "RO" + } + ] + }, + { + "Command": "VISMEMBER", + "Name": "VISMEMBER", + "Arity": 3, + "Flags": "Fast, ReadOnly", + "FirstKey": 1, + "LastKey": 1, + "Step": 1, + "AclCategories": "Fast, Read, Vector", + "KeySpecifications": [ + { + "BeginSearch": { + "TypeDiscriminator": "BeginSearchIndex", + "Index": 1 + }, + "FindKeys": { + "TypeDiscriminator": "FindKeysRange", + "LastKey": 0, + "KeyStep": 1, + "Limit": 0 + }, + "Flags": "RO" + } + ] + }, + { + "Command": "VLINKS", + "Name": "VLINKS", + "Arity": -1, + "Flags": "Fast, ReadOnly, Module", + "FirstKey": 1, + "LastKey": 1, + "Step": 1, + "AclCategories": "Fast, Read, Vector", + "KeySpecifications": [ + { + "BeginSearch": { + "TypeDiscriminator": "BeginSearchIndex", + "Index": 1 + }, + "FindKeys": { + "TypeDiscriminator": "FindKeysRange", + "LastKey": 0, + "KeyStep": 1, + "Limit": 0 + }, + "Flags": "RO" + } + ] + }, + { + "Command": "VRANDMEMBER", + "Name": "VRANDMEMBER", + "Arity": -1, + "Flags": "ReadOnly, Module", + "FirstKey": 1, + "LastKey": 1, + "Step": 1, + "AclCategories": "Slow, Read, Vector", + "KeySpecifications": [ + { + "BeginSearch": { + "TypeDiscriminator": "BeginSearchIndex", + "Index": 1 + }, + "FindKeys": { + "TypeDiscriminator": "FindKeysRange", + "LastKey": 0, + "KeyStep": 1, + "Limit": 0 + }, + "Flags": "RO" + } + ] + }, + { + "Command": "VREM", + "Name": "VREM", + "Arity": -1, + "Flags": "Write, Module", + "FirstKey": 1, + "LastKey": 1, + "Step": 1, + "AclCategories": "Slow, Write, Vector", + "KeySpecifications": [ + { + "BeginSearch": { + "TypeDiscriminator": "BeginSearchIndex", + "Index": 1 + }, + "FindKeys": { + "TypeDiscriminator": "FindKeysRange", + "LastKey": 0, + "KeyStep": 1, + "Limit": 0 + }, + "Flags": "RW, Delete" + } + ] + }, + { + "Command": "VSETATTR", + "Name": "VSETATTR", + "Arity": -1, + "Flags": "Fast, Write, Module", + "FirstKey": 1, + "LastKey": 1, + "Step": 1, + "AclCategories": "Fast, Write, Vector", + "KeySpecifications": [ + { + "BeginSearch": { + "TypeDiscriminator": "BeginSearchIndex", + "Index": 1 + }, + "FindKeys": { + "TypeDiscriminator": 
"FindKeysRange", + "LastKey": 0, + "KeyStep": 1, + "Limit": 0 + }, + "Flags": "RW, Insert" + } + ] + }, + { + "Command": "VSIM", + "Name": "VSIM", + "Arity": -1, + "Flags": "ReadOnly, Module", + "FirstKey": 1, + "LastKey": 1, + "Step": 1, + "AclCategories": "Slow, Read, Vector", + "KeySpecifications": [ + { + "BeginSearch": { + "TypeDiscriminator": "BeginSearchIndex", + "Index": 1 + }, + "FindKeys": { + "TypeDiscriminator": "FindKeysRange", + "LastKey": 0, + "KeyStep": 1, + "Limit": 0 + }, + "Flags": "RO" + } + ] + }, { "Command": "WATCH", "Name": "WATCH", diff --git a/libs/server/ACL/ACLParser.cs b/libs/server/ACL/ACLParser.cs index 621d7a44e8e..2ee3297867c 100644 --- a/libs/server/ACL/ACLParser.cs +++ b/libs/server/ACL/ACLParser.cs @@ -33,6 +33,7 @@ class ACLParser ["stream"] = RespAclCategories.Stream, ["string"] = RespAclCategories.String, ["transaction"] = RespAclCategories.Transaction, + ["vector"] = RespAclCategories.Vector, ["write"] = RespAclCategories.Write, ["garnet"] = RespAclCategories.Garnet, ["custom"] = RespAclCategories.Custom, diff --git a/libs/server/AOF/AofProcessor.cs b/libs/server/AOF/AofProcessor.cs index f4c21c5e7f5..766d83864d8 100644 --- a/libs/server/AOF/AofProcessor.cs +++ b/libs/server/AOF/AofProcessor.cs @@ -34,6 +34,7 @@ public sealed unsafe partial class AofProcessor private readonly SessionParseState parseState; int activeDbId; + VectorManager activeVectorManager; /// /// Set ReadWriteSession on the cluster session (NOTE: used for replaying stored procedures only) @@ -57,6 +58,9 @@ public sealed unsafe partial class AofProcessor readonly ILogger logger; + readonly StoreWrapper replayAofStoreWrapper; + readonly IClusterProvider clusterProvider; + MemoryResult output; /// @@ -70,10 +74,11 @@ public AofProcessor( { this.storeWrapper = storeWrapper; - var replayAofStoreWrapper = new StoreWrapper(storeWrapper, recordToAof); + replayAofStoreWrapper = new StoreWrapper(storeWrapper, recordToAof); + this.clusterProvider = clusterProvider; this.activeDbId = 0; - this.respServerSession = new RespServerSession(0, networkSender: null, storeWrapper: replayAofStoreWrapper, subscribeBroker: null, authenticator: null, enableScripts: false, clusterProvider: clusterProvider); + this.respServerSession = ObtainServerSession(); // Switch current contexts to match the default database SwitchActiveDatabaseContext(storeWrapper.DefaultDatabase, true); @@ -90,6 +95,9 @@ public AofProcessor( this.logger = logger; } + private RespServerSession ObtainServerSession() + => new(0, networkSender: null, storeWrapper: replayAofStoreWrapper, subscribeBroker: null, authenticator: null, enableScripts: false, clusterProvider: clusterProvider); + /// /// Dispose /// @@ -191,6 +199,12 @@ public unsafe void ProcessAofRecordInternal(byte* ptr, int length, bool asReplic AofHeader header = *(AofHeader*)ptr; isCheckpointStart = false; + // Aggressively do not move data if VADD are being replayed + if (header.opType != AofEntryType.StoreRMW) + { + activeVectorManager.WaitForVectorOperationsToComplete(); + } + if (inflightTxns.ContainsKey(header.sessionID)) { switch (header.opType) @@ -332,6 +346,14 @@ private unsafe bool ReplayOp(byte* entryPtr, int length, bool replayAsReplica) { AofHeader header = *(AofHeader*)entryPtr; + // StoreRMW can queue VADDs onto different threads + // but everything else needs to WAIT for those to complete + // otherwise we might loose consistency + if (header.opType != AofEntryType.StoreRMW) + { + activeVectorManager.WaitForVectorOperationsToComplete(); + } + // Skips (1) 
entries with versions that were part of prior checkpoint; and (2) future entries in fuzzy region if (SkipRecord(entryPtr, length, replayAsReplica)) return false; @@ -341,10 +363,10 @@ private unsafe bool ReplayOp(byte* entryPtr, int length, bool replayAsReplica) StoreUpsert(basicContext, storeInput, entryPtr); break; case AofEntryType.StoreRMW: - StoreRMW(basicContext, storeInput, entryPtr); + StoreRMW(basicContext, storeInput, activeVectorManager, respServerSession, ObtainServerSession, entryPtr); break; case AofEntryType.StoreDelete: - StoreDelete(basicContext, entryPtr); + StoreDelete(basicContext, activeVectorManager, respServerSession.storageSession, entryPtr); break; case AofEntryType.ObjectStoreRMW: ObjectStoreRMW(objectStoreBasicContext, objectStoreInput, entryPtr, bufferPtr, buffer.Length); @@ -396,6 +418,8 @@ private void SwitchActiveDatabaseContext(GarnetDatabase db, bool initialSetup = objectStoreBasicContext = objectStoreSession.BasicContext; this.activeDbId = db.Id; } + + activeVectorManager = db.VectorManager; } static void StoreUpsert(BasicContext basicContext, @@ -419,7 +443,14 @@ static void StoreUpsert(BasicContext basicContext, RawStringInput storeInput, byte* ptr) + static void StoreRMW( + BasicContext basicContext, + RawStringInput storeInput, + VectorManager vectorManager, + RespServerSession currentSession, + Func obtainServerSession, + byte* ptr + ) { var curr = ptr + sizeof(AofHeader); ref var key = ref Unsafe.AsRef(curr); @@ -428,21 +459,52 @@ static void StoreRMW(BasicContext basicContext, byte* ptr) + static void StoreDelete( + BasicContext basicContext, + VectorManager vectorManager, + StorageSession storageSession, + byte* ptr) { ref var key = ref Unsafe.AsRef(ptr + sizeof(AofHeader)); - basicContext.Delete(ref key); + var res = basicContext.Delete(ref key); + + if (res.IsCanceled) + { + // Might be a vector set + res = vectorManager.TryDeleteVectorSet(storageSession, ref key); + if (res.IsPending) + _ = basicContext.CompletePending(true); + } } static void ObjectStoreUpsert(BasicContext basicContext, diff --git a/libs/server/API/GarnetApi.cs b/libs/server/API/GarnetApi.cs index 09d23aad563..8f9e19200e9 100644 --- a/libs/server/API/GarnetApi.cs +++ b/libs/server/API/GarnetApi.cs @@ -21,9 +21,10 @@ namespace Garnet.server /// /// Garnet API implementation /// - public partial struct GarnetApi : IGarnetApi, IGarnetWatchApi + public partial struct GarnetApi : IGarnetApi, IGarnetWatchApi where TContext : ITsavoriteContext where TObjectContext : ITsavoriteContext + where TVectorContext : ITsavoriteContext { readonly StorageSession storageSession; TContext context; @@ -48,8 +49,12 @@ public void WATCH(byte[] key, StoreType type) #region GET /// - public GarnetStatus GET(ref SpanByte key, ref RawStringInput input, ref SpanByteAndMemory output) - => storageSession.GET(ref key, ref input, ref output, ref context); + public GarnetStatus GET(ArgSlice key, ref RawStringInput input, ref SpanByteAndMemory output) + { + var asSpanByte = key.SpanByte; + + return storageSession.GET(ref asSpanByte, ref input, ref output, ref context); + } /// public GarnetStatus GET_WithPending(ref SpanByte key, ref RawStringInput input, ref SpanByteAndMemory output, long ctx, out bool pending) @@ -68,7 +73,9 @@ public unsafe GarnetStatus GETForMemoryResult(ArgSlice key, out MemoryResult public unsafe GarnetStatus GET(ArgSlice key, out ArgSlice value) - => storageSession.GET(key, out value, ref context); + { + return storageSession.GET(key, out value, ref context); + } /// public GarnetStatus 
GET(byte[] key, out GarnetObjectStoreOutput value) @@ -118,33 +125,52 @@ public GarnetStatus PEXPIRETIME(ref SpanByte key, StoreType storeType, ref SpanB #endregion #region SET - /// + public GarnetStatus SET(ref SpanByte key, ref SpanByte value) - => storageSession.SET(ref key, ref value, ref context); + => storageSession.SET(ref key, ref value, ref context); /// - public GarnetStatus SET(ref SpanByte key, ref RawStringInput input, ref SpanByte value) - => storageSession.SET(ref key, ref input, ref value, ref context); + public GarnetStatus SET(ArgSlice key, ref RawStringInput input, ref SpanByte value) + { + var asSpanByte = key.SpanByte; - /// - public GarnetStatus SET_Conditional(ref SpanByte key, ref RawStringInput input) - => storageSession.SET_Conditional(ref key, ref input, ref context); + return storageSession.SET(ref asSpanByte, ref input, ref value, ref context); + } /// public GarnetStatus DEL_Conditional(ref SpanByte key, ref RawStringInput input) => storageSession.DEL_Conditional(ref key, ref input, ref context); /// - public GarnetStatus SET_Conditional(ref SpanByte key, ref RawStringInput input, ref SpanByteAndMemory output) - => storageSession.SET_Conditional(ref key, ref input, ref output, ref context); + public GarnetStatus SET_Conditional(ArgSlice key, ref RawStringInput input, ref SpanByteAndMemory output) + { + var asSpanByte = key.SpanByte; + + return storageSession.SET_Conditional(ref asSpanByte, ref input, ref output, ref context); + } + + /// + public GarnetStatus SET_Conditional(ArgSlice key, ref RawStringInput input) + { + var asSpanByte = key.SpanByte; + + return storageSession.SET_Conditional(ref asSpanByte, ref input, ref context); + } /// public GarnetStatus SET(ArgSlice key, Memory value) - => storageSession.SET(key, value, ref context); + { + return storageSession.SET(key, value, ref context); + } /// public GarnetStatus SET(ArgSlice key, ArgSlice value) - => storageSession.SET(key, value, ref context); + { + var asSpanByte = key.SpanByte; + var valSpanByte = value.SpanByte; + + return storageSession.SET(ref asSpanByte, ref valSpanByte, ref context); + } /// public GarnetStatus SET(byte[] key, IGarnetObject value) @@ -302,7 +328,7 @@ public GarnetStatus DELETE(ArgSlice key, StoreType storeType = StoreType.All) /// public GarnetStatus DELETE(ref SpanByte key, StoreType storeType = StoreType.All) - => storageSession.DELETE(ref key, storeType, ref context, ref objectContext); + => storageSession.DELETE(ref key, storeType, ref context, ref objectContext); /// public GarnetStatus DELETE(byte[] key, StoreType storeType = StoreType.All) @@ -482,5 +508,33 @@ public int GetScratchBufferOffset() public bool ResetScratchBuffer(int offset) => storageSession.scratchBufferBuilder.ResetScratchBuffer(offset); #endregion + + #region VectorSet commands + + /// + public unsafe GarnetStatus VectorSetAdd(ArgSlice key, int reduceDims, VectorValueType valueType, ArgSlice values, ArgSlice element, VectorQuantType quantizer, int buildExplorationFactor, ArgSlice attributes, int numLinks, out VectorManagerResult result, out ReadOnlySpan errorMsg) + => storageSession.VectorSetAdd(SpanByte.FromPinnedPointer(key.ptr, key.length), reduceDims, valueType, values, element, quantizer, buildExplorationFactor, attributes, numLinks, out result, out errorMsg); + + /// + public unsafe GarnetStatus VectorSetRemove(ArgSlice key, ArgSlice element) + => storageSession.VectorSetRemove(SpanByte.FromPinnedPointer(key.ptr, key.length), SpanByte.FromPinnedPointer(element.ptr, element.length)); + + /// + 
public unsafe GarnetStatus VectorSetValueSimilarity(ArgSlice key, VectorValueType valueType, ArgSlice values, int count, float delta, int searchExplorationFactor, ArgSlice filter, int maxFilteringEffort, bool includeAttributes, ref SpanByteAndMemory outputIds, out VectorIdFormat outputIdFormat, ref SpanByteAndMemory outputDistances, ref SpanByteAndMemory outputAttributes, out VectorManagerResult result) + => storageSession.VectorSetValueSimilarity(SpanByte.FromPinnedPointer(key.ptr, key.length), valueType, values, count, delta, searchExplorationFactor, filter.ReadOnlySpan, maxFilteringEffort, includeAttributes, ref outputIds, out outputIdFormat, ref outputDistances, ref outputAttributes, out result); + + /// + public unsafe GarnetStatus VectorSetElementSimilarity(ArgSlice key, ArgSlice element, int count, float delta, int searchExplorationFactor, ArgSlice filter, int maxFilteringEffort, bool includeAttributes, ref SpanByteAndMemory outputIds, out VectorIdFormat outputIdFormat, ref SpanByteAndMemory outputDistances, ref SpanByteAndMemory outputAttributes, out VectorManagerResult result) + => storageSession.VectorSetElementSimilarity(SpanByte.FromPinnedPointer(key.ptr, key.length), element.ReadOnlySpan, count, delta, searchExplorationFactor, filter.ReadOnlySpan, maxFilteringEffort, includeAttributes, ref outputIds, out outputIdFormat, ref outputDistances, ref outputAttributes, out result); + + /// + public unsafe GarnetStatus VectorSetEmbedding(ArgSlice key, ArgSlice element, ref SpanByteAndMemory outputDistances) + => storageSession.VectorSetEmbedding(SpanByte.FromPinnedPointer(key.ptr, key.length), element.ReadOnlySpan, ref outputDistances); + + /// + public unsafe GarnetStatus VectorSetDimensions(ArgSlice key, out int dimensions) + => storageSession.VectorSetDimensions(SpanByte.FromPinnedPointer(key.ptr, key.length), out dimensions); + + #endregion } } \ No newline at end of file diff --git a/libs/server/API/GarnetApiObjectCommands.cs b/libs/server/API/GarnetApiObjectCommands.cs index b0a72473b8e..9ba483e08d7 100644 --- a/libs/server/API/GarnetApiObjectCommands.cs +++ b/libs/server/API/GarnetApiObjectCommands.cs @@ -16,9 +16,10 @@ namespace Garnet.server /// /// Garnet API implementation /// - public partial struct GarnetApi : IGarnetApi, IGarnetWatchApi + public partial struct GarnetApi : IGarnetApi, IGarnetWatchApi where TContext : ITsavoriteContext where TObjectContext : ITsavoriteContext + where TVectorContext : ITsavoriteContext { #region SortedSet Methods diff --git a/libs/server/API/GarnetWatchApi.cs b/libs/server/API/GarnetWatchApi.cs index ac68e97e66f..ff0f3a2063f 100644 --- a/libs/server/API/GarnetWatchApi.cs +++ b/libs/server/API/GarnetWatchApi.cs @@ -23,10 +23,10 @@ public GarnetWatchApi(TGarnetApi garnetApi) #region GET /// - public GarnetStatus GET(ref SpanByte key, ref RawStringInput input, ref SpanByteAndMemory output) + public GarnetStatus GET(ArgSlice key, ref RawStringInput input, ref SpanByteAndMemory output) { - garnetApi.WATCH(new ArgSlice(ref key), StoreType.Main); - return garnetApi.GET(ref key, ref input, ref output); + garnetApi.WATCH(key, StoreType.Main); + return garnetApi.GET(key, ref input, ref output); } /// @@ -647,5 +647,35 @@ public bool ResetScratchBuffer(int offset) => garnetApi.ResetScratchBuffer(offset); #endregion + + #region Vector Sets + /// + public GarnetStatus VectorSetValueSimilarity(ArgSlice key, VectorValueType valueType, ArgSlice value, int count, float delta, int searchExplorationFactor, ArgSlice filter, int maxFilteringEffort, bool 
includeAttributes, ref SpanByteAndMemory outputIds, out VectorIdFormat outputIdFormat, ref SpanByteAndMemory outputDistances, ref SpanByteAndMemory outputAttributes, out VectorManagerResult result) + { + garnetApi.WATCH(key, StoreType.Main); + return garnetApi.VectorSetValueSimilarity(key, valueType, value, count, delta, searchExplorationFactor, filter, maxFilteringEffort, includeAttributes, ref outputIds, out outputIdFormat, ref outputDistances, ref outputAttributes, out result); + } + + /// + public GarnetStatus VectorSetElementSimilarity(ArgSlice key, ArgSlice element, int count, float delta, int searchExplorationFactor, ArgSlice filter, int maxFilteringEffort, bool includeAttributes, ref SpanByteAndMemory outputIds, out VectorIdFormat outputIdFormat, ref SpanByteAndMemory outputDistances, ref SpanByteAndMemory outputAttributes, out VectorManagerResult result) + { + garnetApi.WATCH(key, StoreType.Main); + return garnetApi.VectorSetElementSimilarity(key, element, count, delta, searchExplorationFactor, filter, maxFilteringEffort, includeAttributes, ref outputIds, out outputIdFormat, ref outputDistances, ref outputAttributes, out result); + } + + /// + public GarnetStatus VectorSetEmbedding(ArgSlice key, ArgSlice element, ref SpanByteAndMemory outputDistances) + { + garnetApi.WATCH(key, StoreType.Main); + return garnetApi.VectorSetEmbedding(key, element, ref outputDistances); + } + + /// + public GarnetStatus VectorSetDimensions(ArgSlice key, out int dimensions) + { + garnetApi.WATCH(key, StoreType.Main); + return garnetApi.VectorSetDimensions(key, out dimensions); + } + #endregion } } \ No newline at end of file diff --git a/libs/server/API/IGarnetApi.cs b/libs/server/API/IGarnetApi.cs index a78ac22f556..f81597912fa 100644 --- a/libs/server/API/IGarnetApi.cs +++ b/libs/server/API/IGarnetApi.cs @@ -26,17 +26,12 @@ public interface IGarnetApi : IGarnetReadApi, IGarnetAdvancedApi /// /// SET /// - GarnetStatus SET(ref SpanByte key, ref SpanByte value); - - /// - /// SET - /// - GarnetStatus SET(ref SpanByte key, ref RawStringInput input, ref SpanByte value); + GarnetStatus SET(ArgSlice key, ref RawStringInput input, ref SpanByte value); /// /// SET Conditional /// - GarnetStatus SET_Conditional(ref SpanByte key, ref RawStringInput input); + GarnetStatus SET_Conditional(ArgSlice key, ref RawStringInput input); /// /// DEL Conditional @@ -46,7 +41,7 @@ public interface IGarnetApi : IGarnetReadApi, IGarnetAdvancedApi /// /// SET Conditional /// - GarnetStatus SET_Conditional(ref SpanByte key, ref RawStringInput input, ref SpanByteAndMemory output); + GarnetStatus SET_Conditional(ArgSlice key, ref RawStringInput input, ref SpanByteAndMemory output); /// /// SET @@ -1206,6 +1201,18 @@ GarnetStatus GeoSearchStore(ArgSlice key, ArgSlice destinationKey, ref GeoSearch GarnetStatus HyperLogLogMerge(ref RawStringInput input, out bool error); #endregion + + #region VectorSet Methods + /// + /// Adds to (and may create) a vector set with the given parameters. + /// + GarnetStatus VectorSetAdd(ArgSlice key, int reduceDims, VectorValueType valueType, ArgSlice value, ArgSlice element, VectorQuantType quantizer, int buildExplorationFactor, ArgSlice attributes, int numLinks, out VectorManagerResult result, out ReadOnlySpan errorMsg); + + /// + /// Remove a member from a vector set, if it is present and the key exists. 
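For orientation, the command surface these methods back mirrors Redis vector sets. A client-side smoke test against a server started with --enable-vector-set-preview might look like the following (StackExchange.Redis shown; the endpoint and argument shapes such as VALUES <dim> follow the Redis syntax and should be treated as illustrative while the feature is in preview):

using System;
using StackExchange.Redis;

class VectorSetSmokeTest
{
    static void Main()
    {
        // Assumes a local Garnet instance started with --enable-vector-set-preview;
        // endpoint and argument shapes are illustrative
        var muxer = ConnectionMultiplexer.Connect("localhost:6379");
        var db = muxer.GetDatabase();

        // VADD key VALUES <dim> <v1..vn> <element> - add two 3-dimensional elements
        db.Execute("VADD", "points", "VALUES", "3", "0.1", "0.2", "0.3", "a");
        db.Execute("VADD", "points", "VALUES", "3", "0.9", "0.8", "0.7", "b");

        Console.WriteLine(db.Execute("VDIM", "points"));  // 3
        Console.WriteLine(db.Execute("VCARD", "points")); // 2

        // VSIM key VALUES <dim> <v1..vn> - elements nearest to the query vector
        var similar = (RedisResult[])db.Execute("VSIM", "points", "VALUES", "3", "0.1", "0.2", "0.3");
        foreach (var element in similar)
            Console.WriteLine(element);
    }
}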
+ /// + GarnetStatus VectorSetRemove(ArgSlice key, ArgSlice element); + #endregion } /// @@ -1217,7 +1224,7 @@ public interface IGarnetReadApi /// /// GET /// - GarnetStatus GET(ref SpanByte key, ref RawStringInput input, ref SpanByteAndMemory output); + GarnetStatus GET(ArgSlice key, ref RawStringInput input, ref SpanByteAndMemory output); /// /// GET @@ -2026,6 +2033,36 @@ public bool IterateObjectStore(ref TScanFunctions scanFunctions, #endregion + #region Vector Sets + + /// + /// Perform a similarity search given a vector and these parameters. + /// + /// Ids are encoded in as length prefixed blobs of bytes. + /// Attributes are encoded in as length prefixed blobs of bytes. + /// + GarnetStatus VectorSetValueSimilarity(ArgSlice key, VectorValueType valueType, ArgSlice value, int count, float delta, int searchExplorationFactor, ArgSlice filter, int maxFilteringEffort, bool includeAttributes, ref SpanByteAndMemory outputIds, out VectorIdFormat outputIdFormat, ref SpanByteAndMemory outputDistances, ref SpanByteAndMemory outputAttributes, out VectorManagerResult result); + + /// + /// Perform a similarity search given an element already in the vector set and these parameters. + /// + /// Ids are encoded in as length prefixed blobs of bytes. + /// Attributes are encoded in as length prefixed blobs of bytes. + /// + GarnetStatus VectorSetElementSimilarity(ArgSlice key, ArgSlice element, int count, float delta, int searchExplorationFactor, ArgSlice filter, int maxFilteringEffort, bool includeAttributes, ref SpanByteAndMemory outputIds, out VectorIdFormat outputIdFormat, ref SpanByteAndMemory outputDistances, ref SpanByteAndMemory outputAttributes, out VectorManagerResult result); + + /// + /// Fetch the embedding of a given element in a Vector set. + /// + GarnetStatus VectorSetEmbedding(ArgSlice key, ArgSlice element, ref SpanByteAndMemory outputDistances); + + /// + /// Fetch the dimensionality of the given Vector Set. + /// + /// If the Vector Set was created with reduced dimensions, reports the reduced dimensions. + /// + GarnetStatus VectorSetDimensions(ArgSlice key, out int dimensions); + #endregion } /// diff --git a/libs/server/ArgSlice/ArgSliceVector.cs b/libs/server/ArgSlice/ArgSliceVector.cs index 07091e1b130..26e792d4f56 100644 --- a/libs/server/ArgSlice/ArgSliceVector.cs +++ b/libs/server/ArgSlice/ArgSliceVector.cs @@ -4,6 +4,8 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Diagnostics; +using Tsavorite.core; namespace Garnet.server { @@ -11,13 +13,13 @@ namespace Garnet.server /// Vector of ArgSlices /// /// - public unsafe class ArgSliceVector(int maxItemNum = 1 << 18) : IEnumerable + public unsafe class ArgSliceVector(int maxItemNum = 1 << 18) : IEnumerable { ScratchBufferBuilder bufferManager = new(); readonly int maxCount = maxItemNum; public int Count => items.Count; public bool IsEmpty => items.Count == 0; - readonly List items = []; + readonly List items = []; /// /// Try to add ArgSlice /// /// /// @@ -29,7 +31,32 @@ public bool TryAddItem(Span item) if (Count + 1 >= maxCount) return false; - items.Add(bufferManager.CreateArgSlice(item)); + var argSlice = bufferManager.CreateArgSlice(item); + + items.Add(argSlice.SpanByte); + return true; + } + + /// + /// Try to add ArgSlice + /// + /// + /// True if the ArgSlice was added, false if maxCount has been reached.
+ public bool TryAddItem(ulong ns, Span item) + { + Debug.Assert(ns <= byte.MaxValue, "Only byte-size namespaces supported currently"); + + if (Count + 1 >= maxCount) + return false; + + var argSlice = bufferManager.CreateArgSlice(item.Length + 1); + var sb = argSlice.SpanByte; + + sb.MarkNamespace(); + sb.SetNamespaceInPayload((byte)ns); + item.CopyTo(sb.AsSpan()); + + items.Add(sb); + return true; + } @@ -42,7 +69,7 @@ public void Clear() bufferManager.Reset(); } - public IEnumerator GetEnumerator() + public IEnumerator GetEnumerator() { foreach (var item in items) yield return item; diff --git a/libs/server/Cluster/ClusterSlotVerificationInput.cs b/libs/server/Cluster/ClusterSlotVerificationInput.cs index 8b673189add..0d72b177363 100644 --- a/libs/server/Cluster/ClusterSlotVerificationInput.cs +++ b/libs/server/Cluster/ClusterSlotVerificationInput.cs @@ -34,5 +34,12 @@ public struct ClusterSlotVerificationInput /// Offset of key num if any /// public int keyNumOffset; + + /// + /// If the command being executed modifies a Vector Set. + /// + /// This requires special handling during migrations. + /// + public bool isVectorSetWriteCommand; } } \ No newline at end of file diff --git a/libs/server/Cluster/IClusterProvider.cs b/libs/server/Cluster/IClusterProvider.cs index 344c88c41e2..f8d854ed409 100644 --- a/libs/server/Cluster/IClusterProvider.cs +++ b/libs/server/Cluster/IClusterProvider.cs @@ -12,22 +12,33 @@ namespace Garnet.server { + using BasicContext = BasicContext, + SpanByteAllocator>>; + using BasicGarnetApi = GarnetApi, SpanByteAllocator>>, BasicContext>, - GenericAllocator>>>>; + GenericAllocator>>>, + BasicContext, + SpanByteAllocator>>>; + + using VectorContext = BasicContext, SpanByteAllocator>>; /// /// Cluster provider /// public interface IClusterProvider : IDisposable { + // TODO: I really hate having to pass Vector and Basic contexts here...
cleanup + /// /// Create cluster session /// - IClusterSession CreateClusterSession(TransactionManager txnManager, IGarnetAuthenticator authenticator, UserHandle userHandle, GarnetSessionMetrics garnetSessionMetrics, BasicGarnetApi basicGarnetApi, INetworkSender networkSender, ILogger logger = null); + IClusterSession CreateClusterSession(TransactionManager txnManager, IGarnetAuthenticator authenticator, UserHandle userHandle, GarnetSessionMetrics garnetSessionMetrics, BasicGarnetApi basicGarnetApi, BasicContext basicContext, VectorContext vectorContext, INetworkSender networkSender, ILogger logger = null); /// diff --git a/libs/server/Cluster/IClusterSession.cs b/libs/server/Cluster/IClusterSession.cs index 045d4de959b..af7ceaa5cfd 100644 --- a/libs/server/Cluster/IClusterSession.cs +++ b/libs/server/Cluster/IClusterSession.cs @@ -62,7 +62,7 @@ public interface IClusterSession /// /// Process cluster commands /// - unsafe void ProcessClusterCommands(RespCommand command, ref SessionParseState parseState, ref byte* dcurr, ref byte* dend); + unsafe void ProcessClusterCommands(RespCommand command, VectorManager vectorManager, ref SessionParseState parseState, ref byte* dcurr, ref byte* dend); /// /// Reset cached slot verification result @@ -77,7 +77,7 @@ public interface IClusterSession /// /// /// - bool NetworkIterativeSlotVerify(ArgSlice keySlice, bool readOnly, byte SessionAsking); + bool NetworkIterativeSlotVerify(ArgSlice keySlice, bool readOnly, byte SessionAsking, bool isVectorSetWriteCommand); /// /// Write cached slot verification message to output @@ -88,7 +88,7 @@ public interface IClusterSession /// /// Key array slot verify (write result to network) /// - unsafe bool NetworkKeyArraySlotVerify(Span keys, bool readOnly, byte SessionAsking, ref byte* dcurr, ref byte* dend, int count = -1); + unsafe bool NetworkKeyArraySlotVerify(Span keys, bool readOnly, byte SessionAsking, bool isVectorSetWriteCommand, ref byte* dcurr, ref byte* dend, int count = -1); /// /// Array slot verify (write result to network) diff --git a/libs/server/Databases/DatabaseManagerBase.cs b/libs/server/Databases/DatabaseManagerBase.cs index 2700eaa088c..04c823a8727 100644 --- a/libs/server/Databases/DatabaseManagerBase.cs +++ b/libs/server/Databases/DatabaseManagerBase.cs @@ -414,7 +414,7 @@ protected void ExecuteObjectCollection(GarnetDatabase db, ILogger logger = null) { var scratchBufferManager = new ScratchBufferBuilder(); db.ObjectStoreCollectionDbStorageSession = - new StorageSession(StoreWrapper, scratchBufferManager, null, null, db.Id, Logger); + new StorageSession(StoreWrapper, scratchBufferManager, null, null, db.Id, db.VectorManager, Logger); } ExecuteHashCollect(db.ObjectStoreCollectionDbStorageSession); @@ -722,7 +722,7 @@ private static void ExecuteSortedSetCollect(StorageSession storageSession) if (db.MainStoreExpiredKeyDeletionDbStorageSession == null) { var scratchBufferManager = new ScratchBufferBuilder(); - db.MainStoreExpiredKeyDeletionDbStorageSession = new StorageSession(StoreWrapper, scratchBufferManager, null, null, db.Id, Logger); + db.MainStoreExpiredKeyDeletionDbStorageSession = new StorageSession(StoreWrapper, scratchBufferManager, null, null, db.Id, db.VectorManager, Logger); } var scanFrom = StoreWrapper.store.Log.ReadOnlyAddress; @@ -738,7 +738,7 @@ private static void ExecuteSortedSetCollect(StorageSession storageSession) if (db.ObjectStoreExpiredKeyDeletionDbStorageSession == null) { var scratchBufferManager = new ScratchBufferBuilder(); - 
db.ObjectStoreExpiredKeyDeletionDbStorageSession = new StorageSession(StoreWrapper, scratchBufferManager, null, null, db.Id, Logger); + db.ObjectStoreExpiredKeyDeletionDbStorageSession = new StorageSession(StoreWrapper, scratchBufferManager, null, null, db.Id, db.VectorManager, Logger); } var scanFrom = StoreWrapper.objectStore.Log.ReadOnlyAddress; @@ -778,7 +778,7 @@ private HybridLogScanMetrics CollectHybridLogStats>(sessionFunctions); diff --git a/libs/server/Databases/MultiDatabaseManager.cs b/libs/server/Databases/MultiDatabaseManager.cs index 6d5855f5dd5..14fe8f4d685 100644 --- a/libs/server/Databases/MultiDatabaseManager.cs +++ b/libs/server/Databases/MultiDatabaseManager.cs @@ -147,6 +147,9 @@ public override void RecoverCheckpoint(bool replicaRecover = false, bool recover if (StoreWrapper.serverOptions.FailOnRecoveryError) throw new GarnetException("Main store and object store checkpoint versions do not match"); } + + // Once everything is setup, initialize the VectorManager + db.VectorManager.Initialize(); } } @@ -712,7 +715,7 @@ public override FunctionsState CreateFunctionsState(int dbId = 0, byte respProto throw new GarnetException($"Database with ID {dbId} was not found."); return new(db.AppendOnlyFile, db.VersionMap, StoreWrapper.customCommandManager, null, db.ObjectStoreSizeTracker, - StoreWrapper.GarnetObjectSerializer, respProtocolVersion); + StoreWrapper.GarnetObjectSerializer, db.VectorManager, respProtocolVersion); } /// diff --git a/libs/server/Databases/SingleDatabaseManager.cs b/libs/server/Databases/SingleDatabaseManager.cs index 605262dd098..15a3423f88c 100644 --- a/libs/server/Databases/SingleDatabaseManager.cs +++ b/libs/server/Databases/SingleDatabaseManager.cs @@ -111,6 +111,9 @@ public override void RecoverCheckpoint(bool replicaRecover = false, bool recover if (StoreWrapper.serverOptions.FailOnRecoveryError) throw new GarnetException("Main store and object store checkpoint versions do not match"); } + + // Once everything is setup, initialize the VectorManager + defaultDatabase.VectorManager.Initialize(); } /// @@ -391,7 +394,7 @@ public override FunctionsState CreateFunctionsState(int dbId = 0, byte respProto ArgumentOutOfRangeException.ThrowIfNotEqual(dbId, 0); return new(AppendOnlyFile, VersionMap, StoreWrapper.customCommandManager, null, ObjectStoreSizeTracker, - StoreWrapper.GarnetObjectSerializer, respProtocolVersion); + StoreWrapper.GarnetObjectSerializer, DefaultDatabase.VectorManager, respProtocolVersion); } private async Task TryPauseCheckpointsContinuousAsync(int dbId, diff --git a/libs/server/Garnet.server.csproj b/libs/server/Garnet.server.csproj index 2c351e80f45..dc679f37e8f 100644 --- a/libs/server/Garnet.server.csproj +++ b/libs/server/Garnet.server.csproj @@ -22,6 +22,7 @@ + \ No newline at end of file diff --git a/libs/server/GarnetDatabase.cs b/libs/server/GarnetDatabase.cs index 41eb4784f6d..ef3788c7e85 100644 --- a/libs/server/GarnetDatabase.cs +++ b/libs/server/GarnetDatabase.cs @@ -100,6 +100,14 @@ public class GarnetDatabase : IDisposable /// public SingleWriterMultiReaderLock CheckpointingLock; + /// + /// Per-DB VectorManager + /// + /// Contexts, metadata, and associated namespaces are DB-specific, and meaningless + /// outside of the container DB. 
+ /// + public readonly VectorManager VectorManager; + /// /// Storage session intended for store-wide object collection operations /// @@ -124,7 +132,7 @@ public GarnetDatabase(int id, TsavoriteKV objectStore, LightEpoch epoch, StateMachineDriver stateMachineDriver, CacheSizeTracker objectStoreSizeTracker, IDevice aofDevice, TsavoriteLog appendOnlyFile, - bool mainStoreIndexMaxedOut, bool objectStoreIndexMaxedOut) : this() + bool mainStoreIndexMaxedOut, bool objectStoreIndexMaxedOut, VectorManager vectorManager) : this() { Id = id; MainStore = mainStore; @@ -136,6 +144,7 @@ public GarnetDatabase(int id, TsavoriteKV + /// Header for Garnet Main Store inputs, used for Vector element read/write/delete ops + /// + public struct VectorInput : IStoreInput + { + public int SerializedLength => throw new NotImplementedException(); + + public int ReadDesiredSize { get; set; } + + public int WriteDesiredSize { get; set; } + + public int Index { get; set; } + public nint CallbackContext { get; set; } + public nint Callback { get; set; } + + public VectorInput() + { + } + + public unsafe int CopyTo(byte* dest, int length) => throw new NotImplementedException(); + public unsafe int DeserializeFrom(byte* src) => throw new NotImplementedException(); + } } \ No newline at end of file diff --git a/libs/server/Resp/AdminCommands.cs b/libs/server/Resp/AdminCommands.cs index 73851314355..fa134a1498f 100644 --- a/libs/server/Resp/AdminCommands.cs +++ b/libs/server/Resp/AdminCommands.cs @@ -703,7 +703,7 @@ private bool NetworkProcessClusterCommand(RespCommand command) return AbortWithErrorMessage(CmdStrings.RESP_ERR_GENERIC_CLUSTER_DISABLED); } - clusterSession.ProcessClusterCommands(command, ref parseState, ref dcurr, ref dend); + clusterSession.ProcessClusterCommands(command, storageSession.vectorManager, ref parseState, ref dcurr, ref dend); return true; } diff --git a/libs/server/Resp/BasicCommands.cs b/libs/server/Resp/BasicCommands.cs index 6cc37408b4a..838e65d3b21 100644 --- a/libs/server/Resp/BasicCommands.cs +++ b/libs/server/Resp/BasicCommands.cs @@ -31,12 +31,13 @@ bool NetworkGET(ref TGarnetApi storageApi) RawStringInput input = default; - var key = parseState.GetArgSliceByRef(0).SpanByte; + ref var key = ref parseState.GetArgSliceByRef(0); var o = new SpanByteAndMemory(dcurr, (int)(dend - dcurr)); - var status = storageApi.GET(ref key, ref input, ref o); + var status = storageApi.GET(key, ref input, ref o); switch (status) { + case GarnetStatus.WRONGTYPE: case GarnetStatus.OK: if (!o.IsSpanByte) SendAndReset(o.Memory, o.Length); @@ -278,10 +279,10 @@ private bool NetworkSET(ref TGarnetApi storageApi) where TGarnetApi : IGarnetApi { Debug.Assert(parseState.Count == 2); - var key = parseState.GetArgSliceByRef(0).SpanByte; - var value = parseState.GetArgSliceByRef(1).SpanByte; + var key = parseState.GetArgSliceByRef(0); + var value = parseState.GetArgSliceByRef(1); - storageApi.SET(ref key, ref value); + storageApi.SET(key, value); while (!RespWriteUtils.TryWriteDirect(CmdStrings.RESP_OK, ref dcurr, dend)) SendAndReset(); @@ -296,9 +297,9 @@ private bool NetworkGETSET(ref TGarnetApi storageApi) where TGarnetApi : IGarnetApi { Debug.Assert(parseState.Count == 2); - var key = parseState.GetArgSliceByRef(0).SpanByte; + var key = parseState.GetArgSliceByRef(0); - return NetworkSET_Conditional(RespCommand.SET, 0, ref key, true, + return NetworkSET_Conditional(RespCommand.SET, 0, key, true, false, false, ref storageApi); } @@ -377,7 +378,7 @@ private bool NetworkGetRange(ref TGarnetApi storageApi) private bool
NetworkSETEX(bool highPrecision, ref TGarnetApi storageApi) where TGarnetApi : IGarnetApi { - var key = parseState.GetArgSliceByRef(0).SpanByte; + var key = parseState.GetArgSliceByRef(0); // Validate expiry if (!parseState.TryGetInt(1, out var expiry)) @@ -398,7 +399,7 @@ private bool NetworkSETEX(bool highPrecision, ref TGarnetApi storage var sbVal = parseState.GetArgSliceByRef(2).SpanByte; var input = new RawStringInput(RespCommand.SETEX, 0, valMetadata); - _ = storageApi.SET(ref key, ref input, ref sbVal); + _ = storageApi.SET(key, ref input, ref sbVal); while (!RespWriteUtils.TryWriteDirect(CmdStrings.RESP_OK, ref dcurr, dend)) SendAndReset(); @@ -418,10 +419,9 @@ private bool NetworkSETNX(bool highPrecision, ref TGarnetApi storage } var key = parseState.GetArgSliceByRef(0); - var sbKey = key.SpanByte; var input = new RawStringInput(RespCommand.SETEXNX, ref parseState, startIdx: 1); - var status = storageApi.SET_Conditional(ref sbKey, ref input); + var status = storageApi.SET_Conditional(key, ref input); // The status returned for SETNX as NOTFOUND is the expected status in the happy path var retVal = status == GarnetStatus.NOTFOUND ? 1 : 0; @@ -573,14 +573,14 @@ private bool NetworkSETEXNX(ref TGarnetApi storageApi) { case ExistOptions.None: return getValue || withEtag - ? NetworkSET_Conditional(RespCommand.SET, expiry, ref sbKey, getValue, + ? NetworkSET_Conditional(RespCommand.SET, expiry, key, getValue, isHighPrecision, withEtag, ref storageApi) - : NetworkSET_EX(RespCommand.SET, expOption, expiry, ref sbKey, ref sbVal, ref storageApi); // Can perform a blind update + : NetworkSET_EX(RespCommand.SET, expOption, expiry, key, ref sbVal, ref storageApi); // Can perform a blind update case ExistOptions.XX: - return NetworkSET_Conditional(RespCommand.SETEXXX, expiry, ref sbKey, + return NetworkSET_Conditional(RespCommand.SETEXXX, expiry, key, getValue, isHighPrecision, withEtag, ref storageApi); case ExistOptions.NX: - return NetworkSET_Conditional(RespCommand.SETEXNX, expiry, ref sbKey, + return NetworkSET_Conditional(RespCommand.SETEXNX, expiry, key, getValue, isHighPrecision, withEtag, ref storageApi); } break; @@ -590,13 +590,13 @@ private bool NetworkSETEXNX(ref TGarnetApi storageApi) { case ExistOptions.None: // We can never perform a blind update due to KEEPTTL - return NetworkSET_Conditional(RespCommand.SETKEEPTTL, expiry, ref sbKey + return NetworkSET_Conditional(RespCommand.SETKEEPTTL, expiry, key , getValue, highPrecision: false, withEtag, ref storageApi); case ExistOptions.XX: - return NetworkSET_Conditional(RespCommand.SETKEEPTTLXX, expiry, ref sbKey, + return NetworkSET_Conditional(RespCommand.SETKEEPTTLXX, expiry, key, getValue, highPrecision: false, withEtag, ref storageApi); case ExistOptions.NX: - return NetworkSET_Conditional(RespCommand.SETEXNX, expiry, ref sbKey, + return NetworkSET_Conditional(RespCommand.SETEXNX, expiry, key, getValue, highPrecision: false, withEtag, ref storageApi); } break; @@ -608,7 +608,7 @@ private bool NetworkSETEXNX(ref TGarnetApi storageApi) } private unsafe bool NetworkSET_EX(RespCommand cmd, ExpirationOption expOption, int expiry, - ref SpanByte key, ref SpanByte val, ref TGarnetApi storageApi) + ArgSlice key, ref SpanByte val, ref TGarnetApi storageApi) where TGarnetApi : IGarnetApi { Debug.Assert(cmd == RespCommand.SET); @@ -621,14 +621,14 @@ private unsafe bool NetworkSET_EX(RespCommand cmd, ExpirationOption var input = new RawStringInput(cmd, 0, valMetadata); - storageApi.SET(ref key, ref input, ref val); + storageApi.SET(key, ref 
input, ref val); while (!RespWriteUtils.TryWriteDirect(CmdStrings.RESP_OK, ref dcurr, dend)) SendAndReset(); return true; } - private bool NetworkSET_Conditional(RespCommand cmd, int expiry, ref SpanByte key, bool getValue, bool highPrecision, bool withEtag, ref TGarnetApi storageApi) + private bool NetworkSET_Conditional(RespCommand cmd, int expiry, ArgSlice key, bool getValue, bool highPrecision, bool withEtag, ref TGarnetApi storageApi) where TGarnetApi : IGarnetApi { var inputArg = expiry == 0 @@ -645,7 +645,7 @@ private bool NetworkSET_Conditional(RespCommand cmd, int expiry, ref // the following debug assertion is the catch any edge case leading to SETIFMATCH, or SETIFGREATER skipping the above block Debug.Assert(cmd is not (RespCommand.SETIFMATCH or RespCommand.SETIFGREATER), "SETIFMATCH should have gone though pointing to right output variable"); - var status = storageApi.SET_Conditional(ref key, ref input); + var status = storageApi.SET_Conditional(key, ref input); // KEEPTTL without flags doesn't care whether it was found or not. if (cmd == RespCommand.SETKEEPTTL) @@ -684,7 +684,7 @@ private bool NetworkSET_Conditional(RespCommand cmd, int expiry, ref // anything with getValue or withEtag always writes to the buffer in the happy path SpanByteAndMemory outputBuffer = new SpanByteAndMemory(dcurr, (int)(dend - dcurr)); - GarnetStatus status = storageApi.SET_Conditional(ref key, ref input, ref outputBuffer); + GarnetStatus status = storageApi.SET_Conditional(key, ref input, ref outputBuffer); // The data will be on the buffer either when we know the response is ok or when the withEtag flag is set. bool ok = status != GarnetStatus.NOTFOUND || withEtag; diff --git a/libs/server/Resp/BasicEtagCommands.cs b/libs/server/Resp/BasicEtagCommands.cs index 59ef098eaa7..2fee440918d 100644 --- a/libs/server/Resp/BasicEtagCommands.cs +++ b/libs/server/Resp/BasicEtagCommands.cs @@ -22,10 +22,10 @@ private bool NetworkGETWITHETAG(ref TGarnetApi storageApi) { Debug.Assert(parseState.Count == 1); - var key = parseState.GetArgSliceByRef(0).SpanByte; + var key = parseState.GetArgSliceByRef(0); var input = new RawStringInput(RespCommand.GETWITHETAG); var output = new SpanByteAndMemory(dcurr, (int)(dend - dcurr)); - var status = storageApi.GET(ref key, ref input, ref output); + var status = storageApi.GET(key, ref input, ref output); switch (status) { @@ -53,10 +53,10 @@ private bool NetworkGETIFNOTMATCH(ref TGarnetApi storageApi) { Debug.Assert(parseState.Count == 2); - var key = parseState.GetArgSliceByRef(0).SpanByte; + var key = parseState.GetArgSliceByRef(0); var input = new RawStringInput(RespCommand.GETIFNOTMATCH, ref parseState, startIdx: 1); var output = new SpanByteAndMemory(dcurr, (int)(dend - dcurr)); - var status = storageApi.GET(ref key, ref input, ref output); + var status = storageApi.GET(key, ref input, ref output); switch (status) { @@ -213,9 +213,9 @@ private bool NetworkSetETagConditional(RespCommand cmd, ref TGarnetA return true; } - SpanByte key = parseState.GetArgSliceByRef(0).SpanByte; + var key = parseState.GetArgSliceByRef(0); - NetworkSET_Conditional(cmd, expiry, ref key, getValue: !noGet, highPrecision: expOption == ExpirationOption.PX, withEtag: true, ref storageApi); + NetworkSET_Conditional(cmd, expiry, key, getValue: !noGet, highPrecision: expOption == ExpirationOption.PX, withEtag: true, ref storageApi); return true; } diff --git a/libs/server/Resp/CmdStrings.cs b/libs/server/Resp/CmdStrings.cs index cd3263aa808..e8c5ba5fb9e 100644 --- a/libs/server/Resp/CmdStrings.cs +++ 
b/libs/server/Resp/CmdStrings.cs @@ -440,6 +440,7 @@ static partial class CmdStrings public static ReadOnlySpan publish => "PUBLISH"u8; public static ReadOnlySpan spublish => "SPUBLISH"u8; public static ReadOnlySpan mtasks => "MTASKS"u8; + public static ReadOnlySpan reserve => "RESERVE"u8; public static ReadOnlySpan aofsync => "AOFSYNC"u8; public static ReadOnlySpan appendlog => "APPENDLOG"u8; public static ReadOnlySpan attach_sync => "ATTACH_SYNC"u8; diff --git a/libs/server/Resp/GarnetDatabaseSession.cs b/libs/server/Resp/GarnetDatabaseSession.cs index 0e52d40d9c1..1eed9e96553 100644 --- a/libs/server/Resp/GarnetDatabaseSession.cs +++ b/libs/server/Resp/GarnetDatabaseSession.cs @@ -8,13 +8,19 @@ namespace Garnet.server SpanByteAllocator>>, BasicContext>, - GenericAllocator>>>>; + GenericAllocator>>>, + BasicContext, + SpanByteAllocator>>>; using LockableGarnetApi = GarnetApi, SpanByteAllocator>>, LockableContext>, - GenericAllocator>>>>; + GenericAllocator>>>, + LockableContext, + SpanByteAllocator>>>; /// /// Represents a logical database session in Garnet diff --git a/libs/server/Resp/KeyAdminCommands.cs b/libs/server/Resp/KeyAdminCommands.cs index 812617a3a57..1e9e18efefe 100644 --- a/libs/server/Resp/KeyAdminCommands.cs +++ b/libs/server/Resp/KeyAdminCommands.cs @@ -99,8 +99,6 @@ bool NetworkRESTORE(ref TGarnetApi storageApi) var valArgSlice = scratchBufferBuilder.CreateArgSlice(val); - var sbKey = key.SpanByte; - parseState.InitializeWithArgument(valArgSlice); RawStringInput input; @@ -114,7 +112,7 @@ bool NetworkRESTORE(ref TGarnetApi storageApi) input = new RawStringInput(RespCommand.SETEXNX, ref parseState); } - var status = storageApi.SET_Conditional(ref sbKey, ref input); + var status = storageApi.SET_Conditional(key, ref input); if (status is GarnetStatus.NOTFOUND) { diff --git a/libs/server/Resp/LocalServerSession.cs b/libs/server/Resp/LocalServerSession.cs index b3283504041..3bf4a4ca1c5 100644 --- a/libs/server/Resp/LocalServerSession.cs +++ b/libs/server/Resp/LocalServerSession.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. 
using System; +using System.Diagnostics; using Microsoft.Extensions.Logging; using Tsavorite.core; @@ -12,7 +13,10 @@ namespace Garnet.server SpanByteAllocator>>, BasicContext>, - GenericAllocator>>>>; + GenericAllocator>>>, + BasicContext, + SpanByteAllocator>>>; /// /// Local server session @@ -47,8 +51,11 @@ public LocalServerSession(StoreWrapper storeWrapper) // Initialize session-local scratch buffer of size 64 bytes, used for constructing arguments in GarnetApi this.scratchBufferBuilder = new ScratchBufferBuilder(); + var dbRes = storeWrapper.TryGetOrAddDatabase(0, out var database, out _); + Debug.Assert(dbRes, "Should always be able to get DB 0"); + // Create storage session and API - this.storageSession = new StorageSession(storeWrapper, scratchBufferBuilder, sessionMetrics, LatencyMetrics, dbId: 0, logger); + this.storageSession = new StorageSession(storeWrapper, scratchBufferBuilder, sessionMetrics, LatencyMetrics, dbId: 0, database.VectorManager, logger); this.BasicGarnetApi = new BasicGarnetApi(storageSession, storageSession.basicContext, storageSession.objectStoreBasicContext); } diff --git a/libs/server/Resp/Parser/ParseUtils.cs b/libs/server/Resp/Parser/ParseUtils.cs index 14d6e0f5edc..02e9a2c41ca 100644 --- a/libs/server/Resp/Parser/ParseUtils.cs +++ b/libs/server/Resp/Parser/ParseUtils.cs @@ -130,6 +130,44 @@ public static bool TryReadDouble(ref ArgSlice slice, out double number, bool can return canBeInfinite && RespReadUtils.TryReadInfinity(sbNumber, out number); } + /// + /// Read a signed 32-bit float from a given ArgSlice. + /// + /// Source + /// Allow reading an infinity + /// + /// Parsed float + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static float ReadFloat(ref ArgSlice slice, bool canBeInfinite) + { + if (!TryReadFloat(ref slice, out var number, canBeInfinite)) + { + RespParsingException.ThrowNotANumber(slice.ptr, slice.length); + } + return number; + } + + /// + /// Try to read a signed 32-bit float from a given ArgSlice. + /// + /// Source + /// Result + /// Allow reading an infinity + /// + /// True if float parsed successfully + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool TryReadFloat(ref ArgSlice slice, out float number, bool canBeInfinite) + { + var sbNumber = slice.ReadOnlySpan; + if (Utf8Parser.TryParse(sbNumber, out number, out var bytesConsumed) && + bytesConsumed == sbNumber.Length) + return true; + + return canBeInfinite && RespReadUtils.TryReadInfinity(sbNumber, out number); + } + /// /// Read an ASCII string from a given ArgSlice.
/// diff --git a/libs/server/Resp/Parser/RespCommand.cs b/libs/server/Resp/Parser/RespCommand.cs index cc81121b1df..c0f0d906cdc 100644 --- a/libs/server/Resp/Parser/RespCommand.cs +++ b/libs/server/Resp/Parser/RespCommand.cs @@ -81,6 +81,15 @@ public enum RespCommand : ushort SUNION, TTL, TYPE, + VCARD, + VDIM, + VEMB, + VGETATTR, + VINFO, + VISMEMBER, + VLINKS, + VRANDMEMBER, + VSIM, WATCH, WATCHMS, WATCHOS, @@ -195,6 +204,9 @@ public enum RespCommand : ushort SUNIONSTORE, SWAPDB, UNLINK, + VADD, + VREM, + VSETATTR, ZADD, ZCOLLECT, ZDIFFSTORE, @@ -374,6 +386,7 @@ public enum RespCommand : ushort CLUSTER_SPUBLISH, CLUSTER_REPLICAS, CLUSTER_REPLICATE, + CLUSTER_RESERVE, CLUSTER_RESET, CLUSTER_SEND_CKPT_FILE_SEGMENT, CLUSTER_SEND_CKPT_METADATA, @@ -627,6 +640,12 @@ public static bool IsClusterSubCommand(this RespCommand cmd) bool inRange = test <= (RespCommand.CLUSTER_SYNC - RespCommand.CLUSTER_ADDSLOTS); return inRange; } + + /// + /// Returns true if this command can operate on a Vector Set. + /// + public static bool IsLegalOnVectorSet(this RespCommand cmd) + => cmd is RespCommand.DEL or RespCommand.TYPE or RespCommand.DEBUG or RespCommand.VADD or RespCommand.VCARD or RespCommand.VDIM or RespCommand.VEMB or RespCommand.VGETATTR or RespCommand.VINFO or server.RespCommand.VISMEMBER or RespCommand.VLINKS or RespCommand.VRANDMEMBER or RespCommand.VREM or RespCommand.VSETATTR or RespCommand.VSIM; } /// @@ -961,6 +980,29 @@ private RespCommand FastParseArrayCommand(ref int count, ref ReadOnlySpan } break; + case 'V': + if (*(ulong*)(ptr + 2) == MemoryMarshal.Read("\r\nVADD\r\n"u8)) + { + return RespCommand.VADD; + } + else if (*(ulong*)(ptr + 2) == MemoryMarshal.Read("\r\nVDIM\r\n"u8)) + { + return RespCommand.VDIM; + } + else if (*(ulong*)(ptr + 2) == MemoryMarshal.Read("\r\nVEMB\r\n"u8)) + { + return RespCommand.VEMB; + } + else if (*(ulong*)(ptr + 2) == MemoryMarshal.Read("\r\nVREM\r\n"u8)) + { + return RespCommand.VREM; + } + else if (*(ulong*)(ptr + 2) == MemoryMarshal.Read("\r\nVSIM\r\n"u8)) + { + return RespCommand.VSIM; + } + break; + case 'Z': if (*(ulong*)(ptr + 2) == MemoryMarshal.Read("\r\nZADD\r\n"u8)) { @@ -1141,6 +1183,17 @@ private RespCommand FastParseArrayCommand(ref int count, ref ReadOnlySpan } break; + case 'V': + if (*(ulong*)(ptr + 3) == MemoryMarshal.Read("\nVCARD\r\n"u8)) + { + return RespCommand.VCARD; + } + else if (*(ulong*)(ptr + 3) == MemoryMarshal.Read("\nVINFO\r\n"u8)) + { + return RespCommand.VINFO; + } + break; + case 'W': if (*(ulong*)(ptr + 3) == MemoryMarshal.Read("\nWATCH\r\n"u8)) { @@ -1335,6 +1388,13 @@ private RespCommand FastParseArrayCommand(ref int count, ref ReadOnlySpan } break; + case 'V': + if (*(ulong*)(ptr + 4) == MemoryMarshal.Read("VLINKS\r\n"u8)) + { + return RespCommand.VLINKS; + } + break; + case 'Z': if (*(ulong*)(ptr + 4) == MemoryMarshal.Read("ZCOUNT\r\n"u8)) { @@ -1510,6 +1570,14 @@ private RespCommand FastParseArrayCommand(ref int count, ref ReadOnlySpan { return RespCommand.SPUBLISH; } + else if (*(ulong*)(ptr + 4) == MemoryMarshal.Read("VGETATTR"u8) && *(ushort*)(ptr + 12) == MemoryMarshal.Read("\r\n"u8)) + { + return RespCommand.VGETATTR; + } + else if (*(ulong*)(ptr + 4) == MemoryMarshal.Read("VSETATTR"u8) && *(ushort*)(ptr + 12) == MemoryMarshal.Read("\r\n"u8)) + { + return RespCommand.VSETATTR; + } break; case 9: if (*(ulong*)(ptr + 4) == MemoryMarshal.Read("SUBSCRIB"u8) && *(uint*)(ptr + 11) == MemoryMarshal.Read("BE\r\n"u8)) @@ -1548,6 +1616,10 @@ private RespCommand FastParseArrayCommand(ref int count, ref ReadOnlySpan { 
return RespCommand.ZEXPIREAT; } + else if (*(ulong*)(ptr + 4) == MemoryMarshal.Read("VISMEMBE"u8) && *(uint*)(ptr + 11) == MemoryMarshal.Read("ER\r\n"u8)) + { + return RespCommand.VISMEMBER; + } break; case 10: if (*(ulong*)(ptr + 4) == MemoryMarshal.Read("SSUBSCRI"u8) && *(uint*)(ptr + 11) == MemoryMarshal.Read("BE\r\n"u8)) @@ -1684,6 +1756,10 @@ private RespCommand FastParseArrayCommand(ref int count, ref ReadOnlySpan { return RespCommand.ZEXPIRETIME; } + else if (*(ulong*)(ptr + 2) == MemoryMarshal.Read("1\r\nVRAND"u8) && *(ulong*)(ptr + 10) == MemoryMarshal.Read("MEMBER\r\n"u8)) + { + return RespCommand.VRANDMEMBER; + } break; case 12: @@ -2201,6 +2277,10 @@ private RespCommand SlowParseCommand(ReadOnlySpan command, ref int count, { return RespCommand.CLUSTER_MIGRATE; } + else if (subCommand.SequenceEqual(CmdStrings.reserve)) + { + return RespCommand.CLUSTER_RESERVE; + } else if (subCommand.SequenceEqual(CmdStrings.mtasks)) { return RespCommand.CLUSTER_MTASKS; diff --git a/libs/server/Resp/Parser/SessionParseState.cs b/libs/server/Resp/Parser/SessionParseState.cs index e0e523c7ea2..358b37b14fc 100644 --- a/libs/server/Resp/Parser/SessionParseState.cs +++ b/libs/server/Resp/Parser/SessionParseState.cs @@ -163,18 +163,19 @@ public void InitializeWithArguments(ArgSlice arg1, ArgSlice arg2, ArgSlice arg3, } /// - /// Initialize the parse state with a given set of arguments + /// Expand (if necessary) the capacity of the underlying argument buffer, preserving contents. /// - /// Set of arguments to initialize buffer with - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void InitializeWithArguments(ArgSlice[] args) + public void EnsureCapacity(int count) { - Initialize(args.Length); - - for (var i = 0; i < args.Length; i++) + if (count <= Count) { - *(bufferPtr + i) = args[i]; + return; } + + var oldBuffer = rootBuffer; + Initialize(count); + + oldBuffer?.AsSpan().CopyTo(rootBuffer); } /// @@ -432,6 +433,28 @@ public bool TryGetDouble(int i, out double value, bool canBeInfinite = true) return ParseUtils.TryReadDouble(ref Unsafe.AsRef(bufferPtr + i), out value, canBeInfinite); } + /// + /// Get float argument at the given index + /// + /// Parsed float + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public float GetFloat(int i, bool canBeInfinite = true) + { + Debug.Assert(i < Count); + return ParseUtils.ReadFloat(ref Unsafe.AsRef(bufferPtr + i), canBeInfinite); + } + + /// + /// Try to get float argument at the given index + /// + /// True if float parsed successfully + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public bool TryGetFloat(int i, out float value, bool canBeInfinite = true) + { + Debug.Assert(i < Count); + return ParseUtils.TryReadFloat(ref Unsafe.AsRef(bufferPtr + i), out value, canBeInfinite); + } + /// /// Get ASCII string argument at the given index /// diff --git a/libs/server/Resp/RespCommandDocs.cs b/libs/server/Resp/RespCommandDocs.cs index f6adceaecf0..b58578f7371 100644 --- a/libs/server/Resp/RespCommandDocs.cs +++ b/libs/server/Resp/RespCommandDocs.cs @@ -330,6 +330,8 @@ public enum RespCommandGroup : byte String, [Description("transactions")] Transactions, + [Description("vector")] + Vector } /// diff --git a/libs/server/Resp/RespCommandInfoFlags.cs b/libs/server/Resp/RespCommandInfoFlags.cs index e4f391a8613..bfe03845bf7 100644 --- a/libs/server/Resp/RespCommandInfoFlags.cs +++ b/libs/server/Resp/RespCommandInfoFlags.cs @@ -55,6 +55,8 @@ public enum RespCommandFlags Write = 1 << 19, [Description("allow_busy")] AllowBusy = 1 << 20, +
[Description("module")] + Module = 1 << 21, } /// @@ -110,6 +112,8 @@ public enum RespAclCategories Garnet = 1 << 21, [Description("custom")] Custom = 1 << 22, + [Description("vector")] + Vector = 1 << 23, [Description("all")] All = (Custom << 1) - 1, } diff --git a/libs/server/Resp/RespServerSession.cs b/libs/server/Resp/RespServerSession.cs index 854def816d9..8d8894e89b1 100644 --- a/libs/server/Resp/RespServerSession.cs +++ b/libs/server/Resp/RespServerSession.cs @@ -25,13 +25,19 @@ namespace Garnet.server SpanByteAllocator>>, BasicContext>, - GenericAllocator>>>>; + GenericAllocator>>>, + BasicContext, + SpanByteAllocator>>>; using LockableGarnetApi = GarnetApi, SpanByteAllocator>>, LockableContext>, - GenericAllocator>>>>; + GenericAllocator>>>, + LockableContext, + SpanByteAllocator>>>; /// /// RESP server session @@ -283,7 +289,8 @@ public RespServerSession( this.AuthenticateUser(Encoding.ASCII.GetBytes(this.storeWrapper.accessControlList.GetDefaultUserHandle().User.Name)); var cp = clusterProvider ?? storeWrapper.clusterProvider; - clusterSession = cp?.CreateClusterSession(txnManager, this._authenticator, this._userHandle, sessionMetrics, basicGarnetApi, networkSender, logger); + + clusterSession = cp?.CreateClusterSession(txnManager, this._authenticator, this._userHandle, sessionMetrics, basicGarnetApi, storageSession.basicContext, storageSession.vectorContext, networkSender, logger); clusterSession?.SetUserHandle(this._userHandle); sessionScriptCache?.SetUserHandle(this._userHandle); @@ -946,6 +953,20 @@ private bool ProcessArrayCommands(RespCommand cmd, ref TGarnetApi st RespCommand.SUNIONSTORE => SetUnionStore(ref storageApi), RespCommand.SDIFF => SetDiff(ref storageApi), RespCommand.SDIFFSTORE => SetDiffStore(ref storageApi), + // Vector Commands + RespCommand.VADD => NetworkVADD(ref storageApi), + RespCommand.VCARD => NetworkVCARD(ref storageApi), + RespCommand.VDIM => NetworkVDIM(ref storageApi), + RespCommand.VEMB => NetworkVEMB(ref storageApi), + RespCommand.VGETATTR => NetworkVGETATTR(ref storageApi), + RespCommand.VINFO => NetworkVINFO(ref storageApi), + RespCommand.VISMEMBER => NetworkVISMEMBER(ref storageApi), + RespCommand.VLINKS => NetworkVLINKS(ref storageApi), + RespCommand.VRANDMEMBER => NetworkVRANDMEMBER(ref storageApi), + RespCommand.VREM => NetworkVREM(ref storageApi), + RespCommand.VSETATTR => NetworkVSETATTR(ref storageApi), + RespCommand.VSIM => NetworkVSIM(ref storageApi), + // Everything else _ => ProcessOtherCommands(cmd, ref storageApi) }; return success; @@ -1332,7 +1353,7 @@ private void Send(byte* d) if ((int)(dcurr - d) > 0) { - // Debug.WriteLine("SEND: [" + Encoding.UTF8.GetString(new Span(d, (int)(dcurr - d))).Replace("\n", "|").Replace("\r", "!") + "]"); + //Debug.WriteLine("SEND: [" + Encoding.UTF8.GetString(new Span(d, (int)(dcurr - d))).Replace("\n", "|").Replace("\r", "!") + "]"); if (waitForAofBlocking) { var task = storeWrapper.WaitForCommitAsync(); @@ -1496,7 +1517,10 @@ private GarnetDatabaseSession TryGetOrSetDatabaseSession(int dbId, out bool succ /// New database session private GarnetDatabaseSession CreateDatabaseSession(int dbId) { - var dbStorageSession = new StorageSession(storeWrapper, scratchBufferBuilder, sessionMetrics, LatencyMetrics, dbId, logger, respProtocolVersion); + var dbRes = storeWrapper.TryGetOrAddDatabase(dbId, out var database, out _); + Debug.Assert(dbRes, "Should always find database if we're switching to it"); + + var dbStorageSession = new StorageSession(storeWrapper, scratchBufferBuilder, sessionMetrics, 
LatencyMetrics, dbId, database.VectorManager, logger, respProtocolVersion); var dbGarnetApi = new BasicGarnetApi(dbStorageSession, dbStorageSession.basicContext, dbStorageSession.objectStoreBasicContext); var dbLockableGarnetApi = new LockableGarnetApi(dbStorageSession, dbStorageSession.lockableContext, dbStorageSession.objectStoreLockableContext); diff --git a/libs/server/Resp/RespServerSessionSlotVerify.cs b/libs/server/Resp/RespServerSessionSlotVerify.cs index 9de8ee1c18d..39179c979f5 100644 --- a/libs/server/Resp/RespServerSessionSlotVerify.cs +++ b/libs/server/Resp/RespServerSessionSlotVerify.cs @@ -17,9 +17,10 @@ internal sealed unsafe partial class RespServerSession : ServerSessionBase /// Array of key ArgSlice /// Whether caller is going to perform a readonly or read/write operation /// Key count if different than keys array length + /// Whether the executing command performs a write against a Vector Set. /// True when ownership is verified, false otherwise - bool NetworkKeyArraySlotVerify(Span keys, bool readOnly, int count = -1) - => clusterSession != null && clusterSession.NetworkKeyArraySlotVerify(keys, readOnly, SessionAsking, ref dcurr, ref dend, count); + bool NetworkKeyArraySlotVerify(Span keys, bool readOnly, bool isVectorSetWriteCommand, int count = -1) + => clusterSession != null && clusterSession.NetworkKeyArraySlotVerify(keys, readOnly, SessionAsking, isVectorSetWriteCommand, ref dcurr, ref dend, count); bool CanServeSlot(RespCommand cmd) { @@ -43,6 +44,7 @@ bool CanServeSlot(RespCommand cmd) storeWrapper.clusterProvider.ExtractKeySpecs(commandInfo, cmd, ref parseState, ref csvi); csvi.readOnly = cmd.IsReadOnly(); csvi.sessionAsking = SessionAsking; + csvi.isVectorSetWriteCommand = cmd is RespCommand.VADD or RespCommand.VREM or RespCommand.VSETATTR; return !clusterSession.NetworkMultiKeySlotVerify(ref parseState, ref csvi, ref dcurr, ref dend); } } diff --git a/libs/server/Resp/Vector/DiskANNService.cs b/libs/server/Resp/Vector/DiskANNService.cs new file mode 100644 index 00000000000..8a178af5eff --- /dev/null +++ b/libs/server/Resp/Vector/DiskANNService.cs @@ -0,0 +1,319 @@ +using System; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace Garnet.server +{ + internal sealed unsafe class DiskANNService + { + // Term types. 
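+ // Presumed meanings, inferred from the constant names and the storage callbacks below: FullVector holds an element's raw embedding, NeighborList its graph adjacency data, QuantizedVector its quantized embedding, and Attributes its user-supplied attribute blob.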
+ internal const byte FullVector = 0; + private const byte NeighborList = 1; + private const byte QuantizedVector = 2; + internal const byte Attributes = 3; + + public nint CreateIndex( + ulong context, + uint dimensions, + uint reduceDims, + VectorQuantType quantType, + uint buildExplorationFactor, + uint numLinks, + delegate* unmanaged[Cdecl] readCallback, + delegate* unmanaged[Cdecl] writeCallback, + delegate* unmanaged[Cdecl] deleteCallback, + delegate* unmanaged[Cdecl] readModifyWriteCallback + ) + { + unsafe + { + return NativeDiskANNMethods.create_index(context, dimensions, reduceDims, quantType, buildExplorationFactor, numLinks, (nint)readCallback, (nint)writeCallback, (nint)deleteCallback, (nint)readModifyWriteCallback); + } + } + + public nint RecreateIndex( + ulong context, + uint dimensions, + uint reduceDims, + VectorQuantType quantType, + uint buildExplorationFactor, + uint numLinks, + delegate* unmanaged[Cdecl] readCallback, + delegate* unmanaged[Cdecl] writeCallback, + delegate* unmanaged[Cdecl] deleteCallback, + delegate* unmanaged[Cdecl] readModifyWriteCallback + ) + => CreateIndex(context, dimensions, reduceDims, quantType, buildExplorationFactor, numLinks, readCallback, writeCallback, deleteCallback, readModifyWriteCallback); + + public void DropIndex(ulong context, nint index) + { + NativeDiskANNMethods.drop_index(context, index); + } + + public bool Insert(ulong context, nint index, ReadOnlySpan id, VectorValueType vectorType, ReadOnlySpan vector, ReadOnlySpan attributes) + { + var id_data = Unsafe.AsPointer(ref MemoryMarshal.GetReference(id)); + var id_len = id.Length; + + var vector_data = Unsafe.AsPointer(ref MemoryMarshal.GetReference(vector)); + int vector_len; + + if (vectorType == VectorValueType.FP32) + { + vector_len = vector.Length / sizeof(float); + } + else if (vectorType == VectorValueType.XB8) + { + vector_len = vector.Length; + } + else + { + throw new NotImplementedException($"{vectorType}"); + } + + var attributes_data = Unsafe.AsPointer(ref MemoryMarshal.GetReference(attributes)); + var attributes_len = attributes.Length; + + return NativeDiskANNMethods.insert(context, index, (nint)id_data, (nuint)id_len, vectorType, (nint)vector_data, (nuint)vector_len, (nint)attributes_data, (nuint)attributes_len) == 1; + } + + public bool Remove(ulong context, nint index, ReadOnlySpan id) + { + var id_data = Unsafe.AsPointer(ref MemoryMarshal.GetReference(id)); + var id_len = id.Length; + + return NativeDiskANNMethods.remove(context, index, (nint)id_data, (nuint)id_len) == 1; + } + + public int SearchVector( + ulong context, + nint index, + VectorValueType vectorType, + ReadOnlySpan vector, + float delta, + int searchExplorationFactor, + ReadOnlySpan filter, + int maxFilteringEffort, + Span outputIds, + Span outputDistances, + out nint continuation + ) + { + var vector_data = Unsafe.AsPointer(ref MemoryMarshal.GetReference(vector)); + int vector_len; + + if (vectorType == VectorValueType.FP32) + { + vector_len = vector.Length / sizeof(float); + } + else if (vectorType == VectorValueType.XB8) + { + vector_len = vector.Length; + } + else + { + throw new NotImplementedException($"{vectorType}"); + } + + var filter_data = Unsafe.AsPointer(ref MemoryMarshal.GetReference(filter)); + var filter_len = filter.Length; + + var output_ids = Unsafe.AsPointer(ref MemoryMarshal.GetReference(outputIds)); + var output_ids_len = outputIds.Length; + + var output_distances = Unsafe.AsPointer(ref MemoryMarshal.GetReference(outputDistances)); + var output_distances_len = 
outputDistances.Length; + + + continuation = 0; + ref var continuationRef = ref continuation; + var continuationAddr = (nint)Unsafe.AsPointer(ref continuationRef); + + return NativeDiskANNMethods.search_vector( + context, + index, + vectorType, + (nint)vector_data, + (nuint)vector_len, + delta, + searchExplorationFactor, + (nint)filter_data, + (nuint)filter_len, + (nuint)maxFilteringEffort, + (nint)output_ids, + (nuint)output_ids_len, + (nint)output_distances, + (nuint)output_distances_len, + continuationAddr + ); + } + + public int SearchElement( + ulong context, + nint index, + ReadOnlySpan id, + float delta, + int searchExplorationFactor, + ReadOnlySpan filter, + int maxFilteringEffort, + Span outputIds, + Span outputDistances, + out nint continuation + ) + { + var id_data = Unsafe.AsPointer(ref MemoryMarshal.GetReference(id)); + var id_len = id.Length; + + var filter_data = Unsafe.AsPointer(ref MemoryMarshal.GetReference(filter)); + var filter_len = filter.Length; + + var output_ids = Unsafe.AsPointer(ref MemoryMarshal.GetReference(outputIds)); + var output_ids_len = outputIds.Length; + + var output_distances = Unsafe.AsPointer(ref MemoryMarshal.GetReference(outputDistances)); + var output_distances_len = outputDistances.Length; + + continuation = 0; + ref var continuationRef = ref continuation; + var continuationAddr = (nint)Unsafe.AsPointer(ref continuationRef); + + return NativeDiskANNMethods.search_element( + context, + index, + (nint)id_data, + (nuint)id_len, + delta, + searchExplorationFactor, + (nint)filter_data, + (nuint)filter_len, + (nuint)maxFilteringEffort, + (nint)output_ids, + (nuint)output_ids_len, + (nint)output_distances, + (nuint)output_distances_len, + continuationAddr + ); + } + + public int ContinueSearch(ulong context, nint index, nint continuation, Span outputIds, Span outputDistances, out nint newContinuation) + { + throw new NotImplementedException(); + } + + public bool TryGetEmbedding(ulong context, nint index, ReadOnlySpan id, Span dimensions) + { + throw new NotImplementedException(); + } + } + + public static partial class NativeDiskANNMethods + { + const string DISKANN_GARNET = "diskann_garnet"; + + [LibraryImport(DISKANN_GARNET)] + public static partial nint create_index( + ulong context, + uint dimensions, + uint reduceDims, + VectorQuantType quantType, + uint buildExplorationFactor, + uint numLinks, + nint readCallback, + nint writeCallback, + nint deleteCallback, + nint readModifyWriteCallback + ); + + [LibraryImport(DISKANN_GARNET)] + public static partial void drop_index( + ulong context, + nint index + ); + + [LibraryImport(DISKANN_GARNET)] + public static partial byte insert( + ulong context, + nint index, + nint id_data, + nuint id_len, + VectorValueType vector_value_type, + nint vector_data, + nuint vector_len, + nint attribute_data, + nuint attribute_len + ); + + [LibraryImport(DISKANN_GARNET)] + public static partial byte remove( + ulong context, + nint index, + nint id_data, + nuint id_len + ); + + [LibraryImport(DISKANN_GARNET)] + public static partial byte set_attribute( + ulong context, + nint index, + nint id_data, + nuint id_len, + nint attribute_data, + nuint attribute_len + ); + + [LibraryImport(DISKANN_GARNET)] + public static partial int search_vector( + ulong context, + nint index, + VectorValueType vector_value_type, + nint vector_data, + nuint vector_len, + float delta, + int search_exploration_factor, + nint filter_data, + nuint filter_len, + nuint max_filtering_effort, + nint output_ids, + nuint output_ids_len, + nint 
output_distances, + nuint output_distances_len, + nint continuation + ); + + [LibraryImport(DISKANN_GARNET)] + public static partial int search_element( + ulong context, + nint index, + nint id_data, + nuint id_len, + float delta, + int search_exploration_factor, + nint filter_data, + nuint filter_len, + nuint max_filtering_effort, + nint output_ids, + nuint output_ids_len, + nint output_distances, + nuint output_distances_len, + nint continuation + ); + + [LibraryImport(DISKANN_GARNET)] + public static partial int continue_search( + ulong context, + nint index, + nint continuation, + nint output_ids, + nuint output_ids_len, + nint output_distances, + nuint output_distances_len, + nint new_continuation + ); + + [LibraryImport(DISKANN_GARNET)] + public static partial ulong card( + ulong context, + nint index + ); + } +} \ No newline at end of file diff --git a/libs/server/Resp/Vector/RespServerSessionVectors.cs b/libs/server/Resp/Vector/RespServerSessionVectors.cs new file mode 100644 index 00000000000..dd44f1865f5 --- /dev/null +++ b/libs/server/Resp/Vector/RespServerSessionVectors.cs @@ -0,0 +1,1081 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. +using System; +using System.Buffers; +using System.Buffers.Binary; +using System.Runtime.InteropServices; +using Garnet.common; +using Tsavorite.core; + +namespace Garnet.server +{ + internal sealed unsafe partial class RespServerSession : ServerSessionBase + { + private bool NetworkVADD(ref TGarnetApi storageApi) + where TGarnetApi : IGarnetApi + { + // VADD key [REDUCE dim] (FP32 | XB8 | VALUES num) vector element [CAS] [NOQUANT | Q8 | BIN | XPREQ8] [EF build-exploration-factor] [SETATTR attributes] [M numlinks] + // + // XB8 is a non-Redis extension, stands for: eXtension Binary 8-bit values - encodes [0, 255] per dimension + // XPREQ8 is a non-Redis extension, stands for: eXtension PREcalculated Quantization 8-bit - requests no quantization on pre-calculated [0, 255] values + + const int MinM = 4; + const int MaxM = 4_096; + + if (!storageSession.vectorManager.IsEnabled) + { + return AbortWithErrorMessage("ERR Vector Set (preview) commands are not enabled"); + } + + // key FP32|VALUES vector element + if (parseState.Count < 4) + { + return AbortWithWrongNumberOfArguments("VADD"); + } + + ref var key = ref parseState.GetArgSliceByRef(0); + + var curIx = 1; + + var reduceDim = 0; + if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("REDUCE"u8)) + { + curIx++; + if (!parseState.TryGetInt(curIx, out var reduceDimValue) || reduceDimValue <= 0) + { + return AbortWithErrorMessage("REDUCE dimension must be > 0"u8); + } + + reduceDim = reduceDimValue; + curIx++; + } + + var valueType = VectorValueType.Invalid; + byte[] rentedValues = null; + Span values = stackalloc byte[64 * sizeof(float)]; + + try + { + if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("FP32"u8)) + { + curIx++; + if (curIx >= parseState.Count) + { + return AbortWithWrongNumberOfArguments("VADD"); + } + + var asBytes = parseState.GetArgSliceByRef(curIx).Span; + if ((asBytes.Length % sizeof(float)) != 0) + { + return AbortWithErrorMessage("ERR invalid vector specification"); + } + + curIx++; + valueType = VectorValueType.FP32; + values = asBytes; + } + else if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("VALUES"u8)) + { + curIx++; + if (curIx >= parseState.Count) + { + return AbortWithWrongNumberOfArguments("VADD"); + } + + if (!parseState.TryGetInt(curIx, out var 
valueCount) || valueCount <= 0) + { + return AbortWithErrorMessage("ERR invalid vector specification"); + } + curIx++; + + if (valueCount * sizeof(float) > values.Length) + { + values = rentedValues = ArrayPool.Shared.Rent(valueCount * sizeof(float)); + } + values = values[..(valueCount * sizeof(float))]; + + if (curIx + valueCount > parseState.Count) + { + return AbortWithWrongNumberOfArguments("VADD"); + } + + valueType = VectorValueType.FP32; + var floatValues = MemoryMarshal.Cast(values); + + for (var valueIx = 0; valueIx < valueCount; valueIx++) + { + if (!parseState.TryGetFloat(curIx, out floatValues[valueIx])) + { + return AbortWithErrorMessage("ERR invalid vector specification"); + } + + curIx++; + } + } + else if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("XB8"u8)) + { + curIx++; + if (curIx >= parseState.Count) + { + return AbortWithWrongNumberOfArguments("VADD"); + } + + var asBytes = parseState.GetArgSliceByRef(curIx).Span; + curIx++; + + valueType = VectorValueType.XB8; + values = asBytes; + } + else + { + return AbortWithErrorMessage("ERR invalid vector specification"); + } + + if (curIx >= parseState.Count) + { + return AbortWithWrongNumberOfArguments("VADD"); + } + + var element = parseState.GetArgSliceByRef(curIx); + curIx++; + + // Order for everything after element is unspecified + var cas = false; + VectorQuantType? quantType = null; + int? buildExplorationFactor = null; + ArgSlice? attributes = null; + int? numLinks = null; + + while (curIx < parseState.Count) + { + // REDUCE is illegal after values, no matter how specified + if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("REDUCE"u8)) + { + return AbortWithErrorMessage("ERR invalid option after element"); + } + + // Look for CAS + if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("CAS"u8)) + { + if (cas) + { + return AbortWithErrorMessage("CAS specified multiple times"); + } + + // We ignore CAS, just remember we saw it + cas = true; + curIx++; + + continue; + } + + // Look for quantizer specs + if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("NOQUANT"u8)) + { + if (quantType != null) + { + return AbortWithErrorMessage("Quantization specified multiple times"); + } + + quantType = VectorQuantType.NoQuant; + curIx++; + + continue; + } + else if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("Q8"u8)) + { + if (quantType != null) + { + return AbortWithErrorMessage("Quantization specified multiple times"); + } + + quantType = VectorQuantType.Q8; + curIx++; + + continue; + } + else if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("BIN"u8)) + { + if (quantType != null) + { + return AbortWithErrorMessage("Quantization specified multiple times"); + } + + quantType = VectorQuantType.Bin; + curIx++; + + continue; + } + else if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("XPREQ8"u8)) + { + if (quantType != null) + { + return AbortWithErrorMessage("Quantization specified multiple times"); + } + + quantType = VectorQuantType.XPreQ8; + curIx++; + + continue; + } + + // Look for build-exploration-factor + if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("EF"u8)) + { + if (buildExplorationFactor != null) + { + return AbortWithErrorMessage("EF specified multiple times"); + } + + curIx++; + + if (curIx >= parseState.Count) + { + return AbortWithErrorMessage("ERR invalid option after element"); + } + + if (!parseState.TryGetInt(curIx, out var buildExplorationFactorNonNull) ||
buildExplorationFactorNonNull <= 0) + { + return AbortWithErrorMessage("ERR invalid EF"); + } + + buildExplorationFactor = buildExplorationFactorNonNull; + curIx++; + continue; + } + + // Look for attributes + if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("SETATTR"u8)) + { + if (attributes != null) + { + return AbortWithErrorMessage("SETATTR specified multiple times"); + } + + curIx++; + if (curIx >= parseState.Count) + { + return AbortWithErrorMessage("ERR invalid option after element"); + } + + attributes = parseState.GetArgSliceByRef(curIx); + curIx++; + + // You might think we need to validate attributes, but Redis actually lets anything through + + continue; + } + + // Look for num links + if (parseState.GetArgSliceByRef(curIx).Span.EqualsUpperCaseSpanIgnoringCase("M"u8)) + { + if (numLinks != null) + { + return AbortWithErrorMessage("M specified multiple times"); + } + + curIx++; + if (curIx >= parseState.Count) + { + return AbortWithErrorMessage("ERR invalid option after element"); + } + + if (!parseState.TryGetInt(curIx, out var numLinksNonNull) || numLinksNonNull < MinM || numLinksNonNull > MaxM) + { + return AbortWithErrorMessage("ERR invalid M"); + } + + numLinks = numLinksNonNull; + curIx++; + + continue; + } + + // Didn't recognize this option, error out + return AbortWithErrorMessage("ERR invalid option after element"); + } + + // Default unspecified options + quantType ??= VectorQuantType.Q8; + buildExplorationFactor ??= 200; + attributes ??= default; + numLinks ??= 16; + + // We need to reject these HERE because validation during create_index is very awkward + GarnetStatus res; + VectorManagerResult result; + ReadOnlySpan customErrMsg; + if (quantType == VectorQuantType.XPreQ8 && reduceDim != 0) + { + result = VectorManagerResult.BadParams; + res = GarnetStatus.OK; + customErrMsg = default; + } + else + { + res = storageApi.VectorSetAdd(key, reduceDim, valueType, ArgSlice.FromPinnedSpan(values), element, quantType.Value, buildExplorationFactor.Value, attributes.Value, numLinks.Value, out result, out customErrMsg); + } + + if (res == GarnetStatus.OK) + { + if (result == VectorManagerResult.OK) + { + if (respProtocolVersion == 3) + { + while (!RespWriteUtils.TryWriteTrue(ref dcurr, dend)) + SendAndReset(); + } + else + { + while (!RespWriteUtils.TryWriteInt32(1, ref dcurr, dend)) + SendAndReset(); + } + } + else if (result == VectorManagerResult.Duplicate) + { + if (respProtocolVersion == 3) + { + while (!RespWriteUtils.TryWriteFalse(ref dcurr, dend)) + SendAndReset(); + } + else + { + while (!RespWriteUtils.TryWriteInt32(0, ref dcurr, dend)) + SendAndReset(); + } + } + else if (result == VectorManagerResult.BadParams) + { + if (customErrMsg.IsEmpty) + { + return AbortWithErrorMessage("ERR asked quantization mismatch with existing vector set"u8); + } + + return AbortWithErrorMessage(customErrMsg); + } + } + else + { + return AbortWithErrorMessage($"Unexpected GarnetStatus: {res}"); + } + + return true; + } + finally + { + if (rentedValues != null) + { + ArrayPool.Shared.Return(rentedValues); + } + } + } + + private bool NetworkVSIM(ref TGarnetApi storageApi) + where TGarnetApi : IGarnetApi + { + const int DefaultResultSetSize = 64; + const int DefaultIdSize = sizeof(ulong); + const int DefaultAttributeSize = 32; + + // VSIM key (ELE | FP32 | XB8 | VALUES num) (vector | element) [WITHSCORES] [WITHATTRIBS] [COUNT num] [EPSILON delta] [EF search-exploration - factor] [FILTER expression][FILTER-EF max - filtering - effort] [TRUTH][NOTHREAD] + // + // 
XB8 is a non-Redis extension, stands for: eXtension Binary 8-bit values - encodes [0, 255] per dimension + + if (!storageSession.vectorManager.IsEnabled) + { + return AbortWithErrorMessage("ERR Vector Set (preview) commands are not enabled"); + } + + if (parseState.Count < 3) + { + return AbortWithWrongNumberOfArguments("VSIM"); + } + + ref var key = ref parseState.GetArgSliceByRef(0); + var kind = parseState.GetArgSliceByRef(1); + + var curIx = 2; + + ArgSlice? element; + + VectorValueType valueType = VectorValueType.Invalid; + byte[] rentedValues = null; + try + { + Span values = stackalloc byte[64 * sizeof(float)]; + if (kind.Span.EqualsUpperCaseSpanIgnoringCase("ELE"u8)) + { + element = parseState.GetArgSliceByRef(curIx); + values = default; + curIx++; + } + else + { + element = default; + if (kind.Span.EqualsUpperCaseSpanIgnoringCase("FP32"u8)) + { + if (curIx >= parseState.Count) + { + return AbortWithWrongNumberOfArguments("VSIM"); + } + + var asBytes = parseState.GetArgSliceByRef(curIx).Span; + if ((asBytes.Length % sizeof(float)) != 0) + { + return AbortWithErrorMessage("FP32 values must be multiple of 4-bytes in size"); + } + + valueType = VectorValueType.FP32; + values = asBytes; + curIx++; + } + else if (kind.Span.EqualsUpperCaseSpanIgnoringCase("XB8"u8)) + { + if (curIx >= parseState.Count) + { + return AbortWithWrongNumberOfArguments("VSIM"); + } + + var asBytes = parseState.GetArgSliceByRef(curIx).Span; + + valueType = VectorValueType.XB8; + values = asBytes; + curIx++; + } + else if (kind.Span.EqualsUpperCaseSpanIgnoringCase("VALUES"u8)) + { + if (curIx >= parseState.Count) + { + return AbortWithWrongNumberOfArguments("VSIM"); + } + + if (!parseState.TryGetInt(curIx, out var valueCount) || valueCount <= 0) + { + return AbortWithErrorMessage("VALUES count must > 0"); + } + curIx++; + + if (valueCount * sizeof(float) > values.Length) + { + values = rentedValues = ArrayPool.Shared.Rent(valueCount * sizeof(float)); + } + values = values[..(valueCount * sizeof(float))]; + + if (curIx + valueCount > parseState.Count) + { + return AbortWithWrongNumberOfArguments("VSIM"); + } + + valueType = VectorValueType.FP32; + var floatValues = MemoryMarshal.Cast(values); + + for (var valueIx = 0; valueIx < valueCount; valueIx++) + { + if (!parseState.TryGetFloat(curIx, out floatValues[valueIx])) + { + return AbortWithErrorMessage("VALUES value must be valid float"); + } + + curIx++; + } + } + else + { + return AbortWithErrorMessage("VSIM expected ELE, FP32, or VALUES"); + } + } + + bool? withScores = null; + bool? withAttributes = null; + int? count = null; + float? delta = null; + int? searchExplorationFactor = null; + ArgSlice? filter = null; + int? 
maxFilteringEffort = null; + var truth = false; + var noThread = false; + + while (curIx < parseState.Count) + { + // Check for withScores + if (parseState.GetArgSliceByRef(curIx).ReadOnlySpan.EqualsUpperCaseSpanIgnoringCase("WITHSCORES"u8)) + { + if (withScores != null) + { + return AbortWithErrorMessage("WITHSCORES specified multiple times"); + } + + withScores = true; + curIx++; + continue; + } + + // Check for withAttributes + if (parseState.GetArgSliceByRef(curIx).ReadOnlySpan.EqualsUpperCaseSpanIgnoringCase("WITHATTRIBS"u8)) + { + if (withAttributes != null) + { + return AbortWithErrorMessage("WITHATTRIBS specified multiple times"); + } + + withAttributes = true; + curIx++; + continue; + } + + // Check for count + if (parseState.GetArgSliceByRef(curIx).ReadOnlySpan.EqualsUpperCaseSpanIgnoringCase("COUNT"u8)) + { + if (count != null) + { + return AbortWithErrorMessage("COUNT specified multiple times"); + } + + curIx++; + if (curIx >= parseState.Count) + { + return AbortWithWrongNumberOfArguments("VSIM"); + } + + if (!parseState.TryGetInt(curIx, out var countNonNull) || countNonNull < 0) + { + return AbortWithErrorMessage("COUNT must be integer >= 0"); + } + + count = countNonNull; + curIx++; + continue; + } + + // Check for delta + if (parseState.GetArgSliceByRef(curIx).ReadOnlySpan.EqualsUpperCaseSpanIgnoringCase("EPSILON"u8)) + { + if (delta != null) + { + return AbortWithErrorMessage("EPSILON specified multiple times"); + } + + curIx++; + if (curIx >= parseState.Count) + { + return AbortWithWrongNumberOfArguments("VSIM"); + } + + if (!parseState.TryGetFloat(curIx, out var deltaNonNull) || deltaNonNull <= 0) + { + return AbortWithErrorMessage("EPSILON must be float > 0"); + } + + delta = deltaNonNull; + curIx++; + continue; + } + + // Check for search exploration factor + if (parseState.GetArgSliceByRef(curIx).ReadOnlySpan.EqualsUpperCaseSpanIgnoringCase("EF"u8)) + { + if (searchExplorationFactor != null) + { + return AbortWithErrorMessage("EF specified multiple times"); + } + + curIx++; + if (curIx >= parseState.Count) + { + return AbortWithWrongNumberOfArguments("VSIM"); + } + + if (!parseState.TryGetInt(curIx, out var searchExplorationFactorNonNull) || searchExplorationFactorNonNull < 0) + { + return AbortWithErrorMessage("EF must be >= 0"); + } + + searchExplorationFactor = searchExplorationFactorNonNull; + curIx++; + continue; + } + + // Check for filter + if (parseState.GetArgSliceByRef(curIx).ReadOnlySpan.EqualsUpperCaseSpanIgnoringCase("FILTER"u8)) + { + if (filter != null) + { + return AbortWithErrorMessage("FILTER specified multiple times"); + } + + curIx++; + if (curIx >= parseState.Count) + { + return AbortWithWrongNumberOfArguments("VSIM"); + } + + filter = parseState.GetArgSliceByRef(curIx); + curIx++; + + // TODO: validate filter + + continue; + } + + // Check for max filtering effort + if (parseState.GetArgSliceByRef(curIx).ReadOnlySpan.EqualsUpperCaseSpanIgnoringCase("FILTER-EF"u8)) + { + if (maxFilteringEffort != null) + { + return AbortWithErrorMessage("FILTER-EF specified multiple times"); + } + + curIx++; + if (curIx >= parseState.Count) + { + return AbortWithWrongNumberOfArguments("VSIM"); + } + + if (!parseState.TryGetInt(curIx, out var maxFilteringEffortNonNull) || maxFilteringEffortNonNull < 0) + { + return AbortWithErrorMessage("FILTER-EF must be >= 0"); + } + + maxFilteringEffort = maxFilteringEffortNonNull; + curIx++; + continue; + } + + // Check for truth + if (parseState.GetArgSliceByRef(curIx).ReadOnlySpan.EqualsUpperCaseSpanIgnoringCase("TRUTH"u8)) + 
{ + if (truth) + { + return AbortWithErrorMessage("TRUTH specified multiple times"); + } + + // TODO: should we implement TRUTH? + truth = true; + curIx++; + continue; + } + + // Check for no thread + if (parseState.GetArgSliceByRef(curIx).ReadOnlySpan.EqualsUpperCaseSpanIgnoringCase("NOTHREAD"u8)) + { + if (noThread) + { + return AbortWithErrorMessage("NOTHREAD specified multiple times"); + } + + // We ignore NOTHREAD + noThread = true; + curIx++; + continue; + } + + // Didn't recognize this option, error out + return AbortWithErrorMessage("Unknown option"); + } + + // Default unspecified options + withScores ??= false; + withAttributes ??= false; + count ??= 10; + delta ??= 2f; + searchExplorationFactor ??= 100; + filter ??= default; + maxFilteringEffort ??= count.Value * 100; + + // TODO: these stackallocs are dangerous, need logic to avoid stack overflow + Span idSpace = stackalloc byte[(DefaultResultSetSize * DefaultIdSize) + (DefaultResultSetSize * sizeof(int))]; + Span distanceSpace = stackalloc float[DefaultResultSetSize]; + Span attributeSpace = withAttributes.Value ? stackalloc byte[(DefaultResultSetSize * DefaultAttributeSize) + (DefaultResultSetSize * sizeof(int))] : default; + + var idResult = SpanByteAndMemory.FromPinnedSpan(idSpace); + var distanceResult = SpanByteAndMemory.FromPinnedSpan(MemoryMarshal.Cast(distanceSpace)); + var attributeResult = SpanByteAndMemory.FromPinnedSpan(attributeSpace); + try + { + + GarnetStatus res; + VectorManagerResult vectorRes; + VectorIdFormat idFormat; + if (!element.HasValue) + { + res = storageApi.VectorSetValueSimilarity(key, valueType, ArgSlice.FromPinnedSpan(values), count.Value, delta.Value, searchExplorationFactor.Value, filter.Value, maxFilteringEffort.Value, withAttributes.Value, ref idResult, out idFormat, ref distanceResult, ref attributeResult, out vectorRes); + } + else + { + res = storageApi.VectorSetElementSimilarity(key, element.Value, count.Value, delta.Value, searchExplorationFactor.Value, filter.Value, maxFilteringEffort.Value, withAttributes.Value, ref idResult, out idFormat, ref distanceResult, ref attributeResult, out vectorRes); + } + + if (res == GarnetStatus.NOTFOUND) + { + // Vector Set does not exist + + while (!RespWriteUtils.TryWriteEmptyArray(ref dcurr, dend)) + SendAndReset(); + } + else if (res == GarnetStatus.OK) + { + if (vectorRes == VectorManagerResult.MissingElement) + { + while (!RespWriteUtils.TryWriteError("Element not in Vector Set"u8, ref dcurr, dend)) + SendAndReset(); + } + else if (vectorRes == VectorManagerResult.OK) + { + if (respProtocolVersion == 3) + { + // TODO: this is rather complicated, so punt for now + throw new NotImplementedException(); + } + else + { + var remainingIds = idResult.AsReadOnlySpan(); + var distancesSpan = MemoryMarshal.Cast(distanceResult.AsReadOnlySpan()); + var remaininingAttributes = withAttributes.Value ?
attributeResult.AsReadOnlySpan() : default; + + var arrayItemCount = distancesSpan.Length; + if (withScores.Value) + { + arrayItemCount += distancesSpan.Length; + } + if (withAttributes.Value) + { + arrayItemCount += distancesSpan.Length; + } + + while (!RespWriteUtils.TryWriteArrayLength(arrayItemCount, ref dcurr, dend)) + SendAndReset(); + + for (var resultIndex = 0; resultIndex < distancesSpan.Length; resultIndex++) + { + ReadOnlySpan elementData; + + if (idFormat == VectorIdFormat.I32LengthPrefixed) + { + if (remainingIds.Length < sizeof(int)) + { + throw new GarnetException($"Insufficient bytes for result id length at resultIndex={resultIndex}: {Convert.ToHexString(distanceResult.AsReadOnlySpan())}"); + } + + var elementLen = BinaryPrimitives.ReadInt32LittleEndian(remainingIds); + + if (remainingIds.Length < sizeof(int) + elementLen) + { + throw new GarnetException($"Insufficient bytes for result of length={elementLen} at resultIndex={resultIndex}: {Convert.ToHexString(distanceResult.AsReadOnlySpan())}"); + } + + elementData = remainingIds.Slice(sizeof(int), elementLen); + remainingIds = remainingIds[(sizeof(int) + elementLen)..]; + } + else if (idFormat == VectorIdFormat.FixedI32) + { + if (remainingIds.Length < sizeof(int)) + { + throw new GarnetException($"Insufficient bytes for result id length at resultIndex={resultIndex}: {Convert.ToHexString(distanceResult.AsReadOnlySpan())}"); + } + + elementData = remainingIds[..sizeof(int)]; + remainingIds = remainingIds[sizeof(int)..]; + } + else + { + throw new GarnetException($"Unexpected id format: {idFormat}"); + } + + while (!RespWriteUtils.TryWriteBulkString(elementData, ref dcurr, dend)) + SendAndReset(); + + if (withScores.Value) + { + var distance = distancesSpan[resultIndex]; + + while (!RespWriteUtils.TryWriteDoubleBulkString(distance, ref dcurr, dend)) + SendAndReset(); + } + + if (withAttributes.Value) + { + if (remaininingAttributes.Length < sizeof(int)) + { + throw new GarnetException($"Insufficient bytes for attribute length at resultIndex={resultIndex}: {Convert.ToHexString(attributeResult.AsReadOnlySpan())}"); + } + + var attrLen = BinaryPrimitives.ReadInt32LittleEndian(remaininingAttributes); + var attr = remaininingAttributes.Slice(sizeof(int), attrLen); + remaininingAttributes = remaininingAttributes[(sizeof(int) + attrLen)..]; + + while (!RespWriteUtils.TryWriteBulkString(attr, ref dcurr, dend)) + SendAndReset(); + } + } + } + } + else + { + throw new GarnetException($"Unexpected {nameof(VectorManagerResult)}: {vectorRes}"); + } + } + else + { + throw new GarnetException($"Unexpected {nameof(GarnetStatus)}: {res}"); + } + + return true; + } + finally + { + idResult.Memory?.Dispose(); + distanceResult.Memory?.Dispose(); + attributeResult.Memory?.Dispose(); + } + } + finally + { + if (rentedValues != null) + { + ArrayPool.Shared.Return(rentedValues); + } + } + } + + private bool NetworkVEMB(ref TGarnetApi storageApi) + where TGarnetApi : IGarnetApi + { + const int DefaultResultSetSize = 64; + + // VEMB key element [RAW] + + if (!storageSession.vectorManager.IsEnabled) + { + return AbortWithErrorMessage("ERR Vector Set (preview) commands are not enabled"); + } + + if (parseState.Count < 2 || parseState.Count > 3) + { + return AbortWithWrongNumberOfArguments("VEMB"); + } + + ref var key = ref parseState.GetArgSliceByRef(0); + var elem = parseState.GetArgSliceByRef(1); + + var raw = false; + if (parseState.Count == 3) + { + if (!parseState.GetArgSliceByRef(2).Span.EqualsUpperCaseSpanIgnoringCase("RAW"u8)) + { + return 
AbortWithErrorMessage("Unexpected option to VSIM"); + } + + raw = true; + } + + // TODO: what do we do here? + if (raw) + { + throw new NotImplementedException(); + } + + Span distanceSpace = stackalloc float[DefaultResultSetSize]; + + var distanceResult = SpanByteAndMemory.FromPinnedSpan(MemoryMarshal.Cast(distanceSpace)); + + try + { + var res = storageApi.VectorSetEmbedding(key, elem, ref distanceResult); + + if (res == GarnetStatus.OK) + { + var distanceSpan = MemoryMarshal.Cast(distanceResult.AsReadOnlySpan()); + + while (!RespWriteUtils.TryWriteArrayLength(distanceSpan.Length, ref dcurr, dend)) + SendAndReset(); + + for (var i = 0; i < distanceSpan.Length; i++) + { + while (!RespWriteUtils.TryWriteDoubleBulkString(distanceSpan[i], ref dcurr, dend)) + SendAndReset(); + } + } + else + { + while (!RespWriteUtils.TryWriteEmptyArray(ref dcurr, dend)) + SendAndReset(); + } + + return true; + } + finally + { + if (!distanceResult.IsSpanByte) + { + distanceResult.Memory.Dispose(); + } + } + } + + private bool NetworkVCARD(ref TGarnetApi storageApi) + where TGarnetApi : IGarnetApi + { + if (!storageSession.vectorManager.IsEnabled) + { + return AbortWithErrorMessage("ERR Vector Set (preview) commands are not enabled"); + } + + // TODO: implement! + + while (!RespWriteUtils.TryWriteDirect(CmdStrings.RESP_OK, ref dcurr, dend)) + SendAndReset(); + + return true; + } + + private bool NetworkVDIM(ref TGarnetApi storageApi) + where TGarnetApi : IGarnetApi + { + if (!storageSession.vectorManager.IsEnabled) + { + return AbortWithErrorMessage("ERR Vector Set (preview) commands are not enabled"); + } + + if (parseState.Count != 1) + return AbortWithWrongNumberOfArguments("VDIM"); + + var key = parseState.GetArgSliceByRef(0); + + var res = storageApi.VectorSetDimensions(key, out var dimensions); + + if (res == GarnetStatus.NOTFOUND) + { + while (!RespWriteUtils.TryWriteError("ERR Key not found"u8, ref dcurr, dend)) + SendAndReset(); + } + else if (res == GarnetStatus.WRONGTYPE) + { + while (!RespWriteUtils.TryWriteError("ERR Not a Vector Set"u8, ref dcurr, dend)) + SendAndReset(); + } + else + { + while (!RespWriteUtils.TryWriteInt32(dimensions, ref dcurr, dend)) + SendAndReset(); + } + + return true; + } + + private bool NetworkVGETATTR(ref TGarnetApi storageApi) + where TGarnetApi : IGarnetApi + { + if (!storageSession.vectorManager.IsEnabled) + { + return AbortWithErrorMessage("ERR Vector Set (preview) commands are not enabled"); + } + + // TODO: implement! + + while (!RespWriteUtils.TryWriteDirect(CmdStrings.RESP_OK, ref dcurr, dend)) + SendAndReset(); + + return true; + } + + private bool NetworkVINFO(ref TGarnetApi storageApi) + where TGarnetApi : IGarnetApi + { + if (!storageSession.vectorManager.IsEnabled) + { + return AbortWithErrorMessage("ERR Vector Set (preview) commands are not enabled"); + } + + // TODO: implement! + + while (!RespWriteUtils.TryWriteDirect(CmdStrings.RESP_OK, ref dcurr, dend)) + SendAndReset(); + + return true; + } + + private bool NetworkVISMEMBER(ref TGarnetApi storageApi) + where TGarnetApi : IGarnetApi + { + if (!storageSession.vectorManager.IsEnabled) + { + return AbortWithErrorMessage("ERR Vector Set (preview) commands are not enabled"); + } + + // TODO: implement! 
+ + while (!RespWriteUtils.TryWriteDirect(CmdStrings.RESP_OK, ref dcurr, dend)) + SendAndReset(); + + return true; + } + + private bool NetworkVLINKS(ref TGarnetApi storageApi) + where TGarnetApi : IGarnetApi + { + if (!storageSession.vectorManager.IsEnabled) + { + return AbortWithErrorMessage("ERR Vector Set (preview) commands are not enabled"); + } + + // TODO: implement! + + while (!RespWriteUtils.TryWriteDirect(CmdStrings.RESP_OK, ref dcurr, dend)) + SendAndReset(); + + return true; + } + + private bool NetworkVRANDMEMBER(ref TGarnetApi storageApi) + where TGarnetApi : IGarnetApi + { + if (!storageSession.vectorManager.IsEnabled) + { + return AbortWithErrorMessage("ERR Vector Set (preview) commands are not enabled"); + } + + // TODO: implement! + + while (!RespWriteUtils.TryWriteDirect(CmdStrings.RESP_OK, ref dcurr, dend)) + SendAndReset(); + + return true; + } + + private bool NetworkVREM(ref TGarnetApi storageApi) + where TGarnetApi : IGarnetApi + { + if (!storageSession.vectorManager.IsEnabled) + { + return AbortWithErrorMessage("ERR Vector Set (preview) commands are not enabled"); + } + + if (parseState.Count != 2) + return AbortWithWrongNumberOfArguments("VREM"); + + var key = parseState.GetArgSliceByRef(0); + var elem = parseState.GetArgSliceByRef(1); + + var res = storageApi.VectorSetRemove(key, elem); + + var resp = res == GarnetStatus.OK ? 1 : 0; + + while (!RespWriteUtils.TryWriteInt32(resp, ref dcurr, dend)) + SendAndReset(); + + return true; + } + + private bool NetworkVSETATTR(ref TGarnetApi storageApi) + where TGarnetApi : IGarnetApi + { + if (!storageSession.vectorManager.IsEnabled) + { + return AbortWithErrorMessage("ERR Vector Set (preview) commands are not enabled"); + } + + // TODO: implement! + + while (!RespWriteUtils.TryWriteDirect(CmdStrings.RESP_OK, ref dcurr, dend)) + SendAndReset(); + + return true; + } + } +} \ No newline at end of file diff --git a/libs/server/Resp/Vector/VectorManager.Callbacks.cs b/libs/server/Resp/Vector/VectorManager.Callbacks.cs new file mode 100644 index 00000000000..07b37e2a2db --- /dev/null +++ b/libs/server/Resp/Vector/VectorManager.Callbacks.cs @@ -0,0 +1,316 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System; +using System.Buffers; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using Tsavorite.core; + +namespace Garnet.server +{ + using MainStoreAllocator = SpanByteAllocator>; + using MainStoreFunctions = StoreFunctions; + + /// + /// Methods which calls back into to interact with Garnet. 
+ /// + public sealed partial class VectorManager + { + public unsafe struct VectorReadBatch : IReadArgBatch + { + public int Count { get; } + + private readonly ulong context; + private readonly SpanByte lengthPrefixedKeys; + + public readonly unsafe delegate* unmanaged[Cdecl, SuppressGCTransition] callback; + public readonly nint callbackContext; + + private int currentIndex; + + private int currentLen; + private byte* currentPtr; + + private bool hasPending; + + public VectorReadBatch(nint callback, nint callbackContext, ulong context, uint keyCount, SpanByte lengthPrefixedKeys) + { + this.context = context; + this.lengthPrefixedKeys = lengthPrefixedKeys; + + this.callback = (delegate* unmanaged[Cdecl, SuppressGCTransition])callback; + this.callbackContext = callbackContext; + + currentIndex = 0; + Count = (int)keyCount; + + currentPtr = this.lengthPrefixedKeys.ToPointerWithMetadata(); + currentLen = *(int*)currentPtr; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void AdvanceTo(int i) + { + Debug.Assert(i >= 0 && i < Count, "Trying to advance out of bounds"); + + if (i == currentIndex) + { + return; + } + + // Undo namespace mutation + *(int*)currentPtr = currentLen; + + // Most likely case, we're going one forward + if (i == (currentIndex + 1)) + { + currentPtr += currentLen + sizeof(int); // Skip length prefix too + + Debug.Assert(currentPtr < lengthPrefixedKeys.ToPointerWithMetadata() + lengthPrefixedKeys.Length, "About to access out of bounds data"); + + currentLen = *(int*)currentPtr; + + currentIndex = i; + + return; + } + + // Next most likely case, we're going back to the start + currentPtr = lengthPrefixedKeys.ToPointerWithMetadata(); + currentLen = *(int*)currentPtr; + currentIndex = 0; + + if (i == 0) + { + return; + } + + SlowPath(ref this, i); + + // For the case where we're not just scanning or rolling back to 0, just iterate + // + // This should basically never happen + [MethodImpl(MethodImplOptions.NoInlining)] + static void SlowPath(ref VectorReadBatch self, int i) + { + for (var subI = 1; subI <= i; subI++) + { + self.AdvanceTo(subI); + } + } + } + + /// + public void GetKey(int i, out SpanByte key) + { + Debug.Assert(i >= 0 && i < Count, "Trying to advance out of bounds"); + + AdvanceTo(i); + + key = SpanByte.FromPinnedPointer(currentPtr + 3, currentLen + 1); + key.MarkNamespace(); + key.SetNamespaceInPayload((byte)context); + } + + /// + public readonly void GetInput(int i, out VectorInput input) + { + Debug.Assert(i >= 0 && i < Count, "Trying to advance out of bounds"); + + input = default; + input.CallbackContext = callbackContext; + input.Callback = (nint)callback; + input.Index = i; + } + + /// + public readonly void GetOutput(int i, out SpanByte output) + { + Debug.Assert(i >= 0 && i < Count, "Trying to advance out of bounds"); + + // Don't care, won't be used + Unsafe.SkipInit(out output); + } + + /// + public readonly void SetOutput(int i, SpanByte output) + { + Debug.Assert(i >= 0 && i < Count, "Trying to advance out of bounds"); + } + + /// + public void SetStatus(int i, Status status) + { + Debug.Assert(i >= 0 && i < Count, "Trying to advance out of bounds"); + + hasPending |= status.IsPending; + } + + internal readonly void CompletePending(ref TContext objectContext) + where TContext : ITsavoriteContext + { + // Undo mutations + *(int*)currentPtr = currentLen; + + if (hasPending) + { + _ = objectContext.CompletePending(wait: true); + } + } + } + + private unsafe delegate* unmanaged[Cdecl] ReadCallbackPtr { get; } =
&ReadCallbackUnmanaged;
+ private unsafe delegate* unmanaged[Cdecl] WriteCallbackPtr { get; } = &WriteCallbackUnmanaged;
+ private unsafe delegate* unmanaged[Cdecl] DeleteCallbackPtr { get; } = &DeleteCallbackUnmanaged;
+ private unsafe delegate* unmanaged[Cdecl] ReadModifyWriteCallbackPtr { get; } = &ReadModifyWriteCallbackUnmanaged;
+
+ ///
+ /// Used to thread the active StorageSession across p/invoke and reverse p/invoke boundaries into DiskANN.
+ ///
+ /// Not the most elegant option, but works so long as DiskANN remains single-threaded.
+ ///
+ [ThreadStatic]
+ internal static StorageSession ActiveThreadSession;
+
+ [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
+ private static unsafe void ReadCallbackUnmanaged(
+ ulong context,
+ uint numKeys,
+ nint keysData,
+ nuint keysLength,
+ nint dataCallback,
+ nint dataCallbackContext
+ )
+ {
+ // dataCallback takes: index, dataCallbackContext, data pointer, data length, and returns nothing
+
+ var enumerable = new VectorReadBatch(dataCallback, dataCallbackContext, context, numKeys, SpanByte.FromPinnedPointer((byte*)keysData, (int)keysLength));
+
+ ref var ctx = ref ActiveThreadSession.vectorContext;
+
+ ctx.ReadWithPrefetch(ref enumerable);
+
+ enumerable.CompletePending(ref ctx);
+ }
+
+ [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
+ private static unsafe byte WriteCallbackUnmanaged(ulong context, nint keyData, nuint keyLength, nint writeData, nuint writeLength)
+ {
+ var keyWithNamespace = MarkDiskANNKeyWithNamespace(context, keyData, keyLength);
+
+ ref var ctx = ref ActiveThreadSession.vectorContext;
+ VectorInput input = default;
+ var valueSpan = SpanByte.FromPinnedPointer((byte*)writeData, (int)writeLength);
+ SpanByte outputSpan = default;
+
+ var status = ctx.Upsert(ref keyWithNamespace, ref input, ref valueSpan, ref outputSpan);
+ if (status.IsPending)
+ {
+ CompletePending(ref status, ref outputSpan, ref ctx);
+ }
+
+ return status.IsCompletedSuccessfully ? (byte)1 : default;
+ }
+
+ [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
+ private static unsafe byte DeleteCallbackUnmanaged(ulong context, nint keyData, nuint keyLength)
+ {
+ var keyWithNamespace = MarkDiskANNKeyWithNamespace(context, keyData, keyLength);
+
+ ref var ctx = ref ActiveThreadSession.vectorContext;
+
+ var status = ctx.Delete(ref keyWithNamespace);
+ Debug.Assert(!status.IsPending, "Deletes should never go async");
+
+ return status.IsCompletedSuccessfully && status.Found ? (byte)1 : default;
+ }
+
+ [UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
+ private static unsafe byte ReadModifyWriteCallbackUnmanaged(ulong context, nint keyData, nuint keyLength, nuint writeLength, nint dataCallback, nint dataCallbackContext)
+ {
+ var keyWithNamespace = MarkDiskANNKeyWithNamespace(context, keyData, keyLength);
+
+ ref var ctx = ref ActiveThreadSession.vectorContext;
+
+ VectorInput input = default;
+ input.Callback = dataCallback;
+ input.CallbackContext = dataCallbackContext;
+ input.WriteDesiredSize = (int)writeLength;
+
+ var status = ctx.RMW(ref keyWithNamespace, ref input);
+ if (status.IsPending)
+ {
+ SpanByte ignored = default;
+
+ CompletePending(ref status, ref ignored, ref ctx);
+ }
+
+ return status.IsCompletedSuccessfully ?
(byte)1 : default;
+ }
+
+ private static unsafe bool ReadSizeUnknown(ulong context, ReadOnlySpan<byte> key, ref SpanByteAndMemory value)
+ {
+ Span<byte> distinctKey = stackalloc byte[key.Length + 1];
+ var keyWithNamespace = SpanByte.FromPinnedSpan(distinctKey);
+ keyWithNamespace.MarkNamespace();
+ keyWithNamespace.SetNamespaceInPayload((byte)context);
+ key.CopyTo(keyWithNamespace.AsSpan());
+
+ ref var ctx = ref ActiveThreadSession.vectorContext;
+
+ tryAgain:
+ VectorInput input = new();
+ input.ReadDesiredSize = -1;
+ fixed (byte* ptr = value.AsSpan())
+ {
+ SpanByte asSpanByte = new(value.Length, (nint)ptr);
+
+ var status = ctx.Read(ref keyWithNamespace, ref input, ref asSpanByte);
+ if (status.IsPending)
+ {
+ CompletePending(ref status, ref asSpanByte, ref ctx);
+ }
+
+ if (!status.Found)
+ {
+ value.Length = 0;
+ return false;
+ }
+
+ if (input.ReadDesiredSize > asSpanByte.Length)
+ {
+ value.Memory?.Dispose();
+ var newAlloc = MemoryPool<byte>.Shared.Rent(input.ReadDesiredSize);
+ value = new(newAlloc, newAlloc.Memory.Length);
+ goto tryAgain;
+ }
+
+ value.Length = asSpanByte.Length;
+ return true;
+ }
+ }
+
+ ///
+ /// Get a SpanByte which covers (keyData, keyLength), but has a namespace component based on context.
+ ///
+ /// Attempts to do this in place.
+ ///
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static unsafe SpanByte MarkDiskANNKeyWithNamespace(ulong context, nint keyData, nuint keyLength)
+ {
+ // DiskANN guarantees we have 4-bytes worth of unused data right before the key
+ var keyPtr = (byte*)keyData;
+ var keyNamespaceByte = keyPtr - 1;
+
+ // TODO: if/when namespace can be > 4-bytes, we'll need to copy here
+
+ var keyWithNamespace = SpanByte.FromPinnedPointer(keyNamespaceByte, (int)(keyLength + 1));
+ keyWithNamespace.MarkNamespace();
+ keyWithNamespace.SetNamespaceInPayload((byte)context);
+
+ return keyWithNamespace;
+ }
+ }
+}
\ No newline at end of file
diff --git a/libs/server/Resp/Vector/VectorManager.Cleanup.cs b/libs/server/Resp/Vector/VectorManager.Cleanup.cs
new file mode 100644
index 00000000000..d630c8b49ce
--- /dev/null
+++ b/libs/server/Resp/Vector/VectorManager.Cleanup.cs
@@ -0,0 +1,164 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+using System;
+using System.Collections.Frozen;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Threading.Channels;
+using System.Threading.Tasks;
+using Garnet.common;
+using Garnet.networking;
+using Microsoft.Extensions.Logging;
+using Tsavorite.core;
+
+namespace Garnet.server
+{
+ using MainStoreAllocator = SpanByteAllocator>;
+ using MainStoreFunctions = StoreFunctions;
+
+ ///
+ /// Methods related to cleaning up data after a Vector Set is deleted.
+ ///
+ public sealed partial class VectorManager
+ {
+ ///
+ /// Used as part of scanning post-index-delete to clean up abandoned data.
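+ /// The scan visits every record in the store and deletes those whose key carries a namespace belonging to a context being cleaned up; all other records are skipped.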
+ /// + private sealed class PostDropCleanupFunctions : IScanIteratorFunctions + { + private readonly StorageSession storageSession; + private readonly FrozenSet contexts; + + public PostDropCleanupFunctions(StorageSession storageSession, HashSet contexts) + { + this.contexts = contexts.ToFrozenSet(); + this.storageSession = storageSession; + } + + public bool ConcurrentReader(ref SpanByte key, ref SpanByte value, RecordMetadata recordMetadata, long numberOfRecords, out CursorRecordResult cursorRecordResult) + => SingleReader(ref key, ref value, recordMetadata, numberOfRecords, out cursorRecordResult); + + public void OnException(Exception exception, long numberOfRecords) { } + public bool OnStart(long beginAddress, long endAddress) => true; + public void OnStop(bool completed, long numberOfRecords) { } + + public bool SingleReader(ref SpanByte key, ref SpanByte value, RecordMetadata recordMetadata, long numberOfRecords, out CursorRecordResult cursorRecordResult) + { + if (key.MetadataSize != 1) + { + // Not Vector Set, ignore + cursorRecordResult = CursorRecordResult.Skip; + return true; + } + + var ns = key.GetNamespaceInPayload(); + var pairedContext = (ulong)ns & ~(ContextStep - 1); + if (!contexts.Contains(pairedContext)) + { + // Vector Set, but not one we're scanning for + cursorRecordResult = CursorRecordResult.Skip; + return true; + } + + // Delete it + var status = storageSession.vectorContext.Delete(ref key, 0); + if (status.IsPending) + { + SpanByte ignored = default; + CompletePending(ref status, ref ignored, ref storageSession.vectorContext); + } + + cursorRecordResult = CursorRecordResult.Accept; + return true; + } + } + + private readonly Channel cleanupTaskChannel; + private readonly Task cleanupTask; + private readonly Func getCleanupSession; + + private async Task RunCleanupTaskAsync() + { + // Each drop index will queue a null object here + // We'll handle multiple at once if possible, but using a channel simplifies cancellation and dispose + await foreach (var ignored in cleanupTaskChannel.Reader.ReadAllAsync()) + { + try + { + HashSet needCleanup; + lock (this) + { + needCleanup = contextMetadata.GetNeedCleanup(); + } + + if (needCleanup == null) + { + // Previous run already got here, so bail + continue; + } + + // TODO: this doesn't work with non-RESP impls... which maybe we don't care about? + using var cleanupSession = (RespServerSession)getCleanupSession(); + if (cleanupSession.activeDbId != dbId && !cleanupSession.TrySwitchActiveDatabaseSession(dbId)) + { + throw new GarnetException($"Could not switch VectorManager cleanup session to {dbId}, initialization failed"); + } + + PostDropCleanupFunctions callbacks = new(cleanupSession.storageSession, needCleanup); + + ref var ctx = ref cleanupSession.storageSession.vectorContext; + + // Scan whole keyspace (sigh) and remove any associated data + // + // We don't really have a choice here, just do it + _ = ctx.Session.Iterate(ref callbacks); + + lock (this) + { + foreach (var cleanedUp in needCleanup) + { + contextMetadata.FinishedCleaningUp(cleanedUp); + } + } + + UpdateContextMetadata(ref ctx); + } + catch (Exception e) + { + logger?.LogError(e, "Failure during background cleanup of deleted vector sets, implies storage leak"); + } + } + } + + /// + /// After an index is dropped, called to start the process of removing ancillary data (elements, neighbor lists, attributes, etc.). 
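+ /// The deletion itself runs on a background task; this method only records which contexts need cleanup and wakes that task.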
+ ///
+ internal void CleanupDroppedIndex<TContext>(ref TContext ctx, ReadOnlySpan<byte> index)
+ where TContext : ITsavoriteContext
+ {
+ ReadIndex(index, out var context, out _, out _, out _, out _, out _, out _, out _);
+
+ CleanupDroppedIndex(ref ctx, context);
+ }
+
+ ///
+ /// After an index is dropped, called to start the process of removing ancillary data (elements, neighbor lists, attributes, etc.).
+ ///
+ internal void CleanupDroppedIndex<TContext>(ref TContext ctx, ulong context)
+ where TContext : ITsavoriteContext
+ {
+ lock (this)
+ {
+ contextMetadata.MarkCleaningUp(context);
+ }
+
+ UpdateContextMetadata(ref ctx);
+
+ // Wake up cleanup task
+ var writeRes = cleanupTaskChannel.Writer.TryWrite(null);
+ Debug.Assert(writeRes, "Request for cleanup failed, this should never happen");
+ }
+
+ }
+}
\ No newline at end of file
diff --git a/libs/server/Resp/Vector/VectorManager.ContextMetadata.cs b/libs/server/Resp/Vector/VectorManager.ContextMetadata.cs
new file mode 100644
index 00000000000..1e1f71ce3cc
--- /dev/null
+++ b/libs/server/Resp/Vector/VectorManager.ContextMetadata.cs
@@ -0,0 +1,458 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Numerics;
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
+using System.Text;
+using System.Threading;
+using Garnet.common;
+using Microsoft.Extensions.Logging;
+using Tsavorite.core;
+
+namespace Garnet.server
+{
+ using MainStoreAllocator = SpanByteAllocator>;
+ using MainStoreFunctions = StoreFunctions;
+
+ ///
+ /// Methods for managing ContextMetadata, which tracks process-wide
+ /// information about different contexts.
+ ///
+ /// ContextMetadata is persisted to the log when modified, but a copy is kept in memory for rapid access.
+ ///
+ public sealed partial class VectorManager
+ {
+ ///
+ /// Used for tracking which contexts are currently active.
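+ /// Each context block gets one bit in the inUse/cleaningUp/migrating bitmaps plus one hash slot entry, so (including the reserved block 0) at most 64 context blocks can exist at once.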
+ /// + [StructLayout(LayoutKind.Explicit, Size = Size)] + internal struct ContextMetadata + { + [InlineArray(64)] + private struct HashSlots + { + private ushort element0; + } + + internal const int Size = + (4 * sizeof(ulong)) + // Bitmaps + (64 * sizeof(ushort)); // HashSlots for assigned contexts + + [FieldOffset(0)] + public ulong Version; + + [FieldOffset(8)] + private ulong inUse; + + [FieldOffset(16)] + private ulong cleaningUp; + + [FieldOffset(24)] + private ulong migrating; + + [FieldOffset(32)] + private HashSlots slots; + + public readonly bool IsInUse(ulong context) + { + Debug.Assert(context > 0, "Context 0 is reserved, should never queried"); + Debug.Assert((context % ContextStep) == 0, "Should only consider whole block of context, not a sub-bit"); + Debug.Assert(context <= byte.MaxValue, "Context larger than expected"); + + var bitIx = context / ContextStep; + var mask = 1UL << (byte)bitIx; + + return (inUse & mask) != 0; + } + + public readonly bool IsMigrating(ulong context) + { + Debug.Assert(context > 0, "Context 0 is reserved, should never queried"); + Debug.Assert((context % ContextStep) == 0, "Should only consider whole block of context, not a sub-bit"); + Debug.Assert(context <= byte.MaxValue, "Context larger than expected"); + + var bitIx = context / ContextStep; + var mask = 1UL << (byte)bitIx; + + return (migrating & mask) != 0; + } + + public readonly HashSet GetNamespacesForHashSlots(HashSet hashSlots) + { + HashSet ret = null; + + var remaining = inUse; + while (remaining != 0) + { + var inUseIx = BitOperations.TrailingZeroCount(remaining); + var inUseMask = 1UL << inUseIx; + + remaining &= ~inUseMask; + + if ((cleaningUp & inUseMask) != 0) + { + // If something is being cleaned up, no reason to migrate it + continue; + } + + var hashSlot = slots[inUseIx]; + if (!hashSlots.Contains(hashSlot)) + { + // Active, but not a target + continue; + } + + ret ??= []; + + var nsStart = ContextStep * (ulong)inUseIx; + for (var i = 0U; i < ContextStep; i++) + { + _ = ret.Add(nsStart + i); + } + } + + return ret; + } + + public readonly ulong NextNotInUse() + { + var ignoringZero = inUse | 1; + + var bit = (ulong)BitOperations.TrailingZeroCount(~ignoringZero & (ulong)-(long)(~ignoringZero)); + + if (bit == 64) + { + throw new GarnetException("All possible Vector Sets allocated"); + } + + var ret = bit * ContextStep; + + return ret; + } + + public bool TryReserveForMigration(int count, out List reserved) + { + var ignoringZero = inUse | 1; + + var available = BitOperations.PopCount(~ignoringZero); + + if (available < count) + { + reserved = null; + return false; + } + + reserved = new(); + for (var i = 0; i < count; i++) + { + var ctx = NextNotInUse(); + reserved.Add(ctx); + + MarkInUse(ctx, ushort.MaxValue); // HashSlot isn't known yet, so use an invalid value + MarkMigrating(ctx); + } + + return true; + } + + public void MarkInUse(ulong context, ushort hashSlot) + { + Debug.Assert(context > 0, "Context 0 is reserved, should never queried"); + Debug.Assert((context % ContextStep) == 0, "Should only consider whole block of context, not a sub-bit"); + Debug.Assert(context <= byte.MaxValue, "Context larger than expected"); + + var bitIx = context / ContextStep; + var mask = 1UL << (byte)bitIx; + + Debug.Assert((inUse & mask) == 0, "About to mark context which is already in use"); + inUse |= mask; + + slots[(int)bitIx] = hashSlot; + + Version++; + } + + public void MarkMigrating(ulong context) + { + Debug.Assert(context > 0, "Context 0 is reserved, should never queried"); + 
Debug.Assert((context % ContextStep) == 0, "Should only consider whole block of context, not a sub-bit"); + Debug.Assert(context <= byte.MaxValue, "Context larger than expected"); + + var bitIx = context / ContextStep; + var mask = 1UL << (byte)bitIx; + + Debug.Assert((inUse & mask) != 0, "About to mark migrating a context which is not in use"); + Debug.Assert((migrating & mask) == 0, "About to mark migrating a context which is already migrating"); + migrating |= mask; + + Version++; + } + + public void MarkMigrationComplete(ulong context, ushort hashSlot) + { + Debug.Assert(context > 0, "Context 0 is reserved, should never queried"); + Debug.Assert((context % ContextStep) == 0, "Should only consider whole block of context, not a sub-bit"); + Debug.Assert(context <= byte.MaxValue, "Context larger than expected"); + + var bitIx = context / ContextStep; + var mask = 1UL << (byte)bitIx; + + Debug.Assert((inUse & mask) != 0, "Should already be in use"); + Debug.Assert((migrating & mask) != 0, "Should be migrating target"); + Debug.Assert(slots[(int)bitIx] == ushort.MaxValue, "Hash slot should not be known yet"); + + migrating &= ~mask; + + slots[(int)bitIx] = hashSlot; + + Version++; + } + + public void MarkCleaningUp(ulong context) + { + Debug.Assert(context > 0, "Context 0 is reserved, should never queried"); + Debug.Assert((context % ContextStep) == 0, "Should only consider whole block of context, not a sub-bit"); + Debug.Assert(context <= byte.MaxValue, "Context larger than expected"); + + var bitIx = context / ContextStep; + var mask = 1UL << (byte)bitIx; + + Debug.Assert((inUse & mask) != 0, "About to mark for cleanup when not actually in use"); + Debug.Assert((cleaningUp & mask) == 0, "About to mark for cleanup when already marked"); + cleaningUp |= mask; + + // If this slot were migrating, it isn't anymore + migrating &= ~mask; + + // Leave the slot around, we need it + + Version++; + } + + public void FinishedCleaningUp(ulong context) + { + Debug.Assert(context > 0, "Context 0 is reserved, should never queried"); + Debug.Assert((context % ContextStep) == 0, "Should only consider whole block of context, not a sub-bit"); + Debug.Assert(context <= byte.MaxValue, "Context larger than expected"); + + var bitIx = context / ContextStep; + var mask = 1UL << (byte)bitIx; + + Debug.Assert((inUse & mask) != 0, "Cleaned up context which isn't in use"); + Debug.Assert((cleaningUp & mask) != 0, "Cleaned up context not marked for it"); + cleaningUp &= ~mask; + inUse &= ~mask; + + slots[(int)bitIx] = 0; + + Version++; + } + + public readonly HashSet GetNeedCleanup() + { + if (cleaningUp == 0) + { + return null; + } + + var ret = new HashSet(); + + var remaining = cleaningUp; + while (remaining != 0UL) + { + var ix = BitOperations.TrailingZeroCount(remaining); + + _ = ret.Add((ulong)ix * ContextStep); + + remaining &= ~(1UL << (byte)ix); + } + + return ret; + } + + public readonly HashSet GetMigrating() + { + if (migrating == 0) + { + return null; + } + + var ret = new HashSet(); + + var remaining = migrating; + while (remaining != 0UL) + { + var ix = BitOperations.TrailingZeroCount(remaining); + + _ = ret.Add((ulong)ix * ContextStep); + + remaining &= ~(1UL << (byte)ix); + } + + return ret; + } + + /// + public override readonly string ToString() + { + // Just for debugging purposes + + var sb = new StringBuilder(); + sb.AppendLine(); + _ = sb.AppendLine($"Version: {Version}"); + var mask = 1UL; + var ix = 0; + while (mask != 0) + { + var isInUse = (inUse & mask) != 0; + var isMigrating = (migrating 
& mask) != 0; + var cleanup = (cleaningUp & mask) != 0; + + var hashSlot = this.slots[ix]; + + if (isInUse || isMigrating || cleanup) + { + var ctxStart = (ulong)ix * ContextStep; + var ctxEnd = ctxStart + ContextStep - 1; + + sb.AppendLine($"[{ctxStart:00}-{ctxEnd:00}): {(isInUse ? "in-use " : "")}{(isMigrating ? "migrating " : "")}{(cleanup ? "cleanup" : "")}"); + } + + mask <<= 1; + ix++; + } + + return sb.ToString(); + } + } + + private ContextMetadata contextMetadata; + + /// + /// Get a new unique context for a vector set. + /// + /// This value is guaranteed to not be shared by any other vector set in the store. + /// + private ulong NextVectorSetContext(ushort hashSlot) + { + var start = Stopwatch.GetTimestamp(); + + // TODO: This retry is no good, but will go away when namespaces >= 256 are possible + while (true) + { + // Lock isn't amazing, but _new_ vector set creation should be rare + // So just serializing it all is easier. + try + { + ulong nextFree; + lock (this) + { + nextFree = contextMetadata.NextNotInUse(); + + contextMetadata.MarkInUse(nextFree, hashSlot); + } + return nextFree; + } + catch (Exception e) + { + logger?.LogError(e, "NextContext not available, delaying and retrying"); + } + + if (Stopwatch.GetElapsedTime(start) < TimeSpan.FromSeconds(30)) + { + lock (this) + { + if (contextMetadata.GetNeedCleanup() == null) + { + throw new GarnetException("No available Vector Sets contexts to allocate, none scheduled for cleanup"); + } + } + + // Wait a little bit for cleanup to make progress + Thread.Sleep(1_000); + } + else + { + throw new GarnetException("No available Vector Sets contexts to allocate, timeout reached"); + } + } + } + + /// + /// Obtain some number of contexts for migrating Vector Sets. + /// + /// The return contexts are unavailable for other use, but are not yet "live" for visibility purposes. + /// + public bool TryReserveContextsForMigration(ref TContext ctx, int count, out List contexts) + where TContext : ITsavoriteContext + { + lock (this) + { + if (!contextMetadata.TryReserveForMigration(count, out contexts)) + { + contexts = null; + return false; + } + } + + UpdateContextMetadata(ref ctx); + + return true; + } + + /// + /// Called when an index creation succeeds to flush into the store. + /// + private void UpdateContextMetadata(ref TContext ctx) + where TContext : ITsavoriteContext + { + Span keySpan = stackalloc byte[1]; + Span dataSpan = stackalloc byte[ContextMetadata.Size]; + + lock (this) + { + MemoryMarshal.Cast(dataSpan)[0] = contextMetadata; + } + + var key = SpanByte.FromPinnedSpan(keySpan); + + key.MarkNamespace(); + key.SetNamespaceInPayload(0); + + VectorInput input = default; + input.Callback = 0; + input.WriteDesiredSize = ContextMetadata.Size; + unsafe + { + input.CallbackContext = (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(dataSpan)); + } + + var data = SpanByte.FromPinnedSpan(dataSpan); + + var status = ctx.RMW(ref key, ref input); + + if (status.IsPending) + { + SpanByte ignored = default; + CompletePending(ref status, ref ignored, ref ctx); + } + } + + /// + /// Find all namespaces in use by vector sets that are logically members of the given hash slots. + /// + /// Meant for use during migration. 
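+ /// Takes the metadata lock and consults the in-memory ContextMetadata; contexts already marked for cleanup are excluded from the result.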
+ /// + public HashSet GetNamespacesForHashSlots(HashSet hashSlots) + { + lock (this) + { + return contextMetadata.GetNamespacesForHashSlots(hashSlots); + } + } + } +} \ No newline at end of file diff --git a/libs/server/Resp/Vector/VectorManager.Index.cs b/libs/server/Resp/Vector/VectorManager.Index.cs new file mode 100644 index 00000000000..a57f3c02c56 --- /dev/null +++ b/libs/server/Resp/Vector/VectorManager.Index.cs @@ -0,0 +1,173 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using Garnet.common; +using Microsoft.Extensions.Logging; +using Tsavorite.core; + +namespace Garnet.server +{ + /// + /// Methods for managing , which is the information about an index created by DiskANN. + /// + /// is stored under the "visible" key in the log, and thus is the common entry point + /// for all operations. + /// + public sealed partial class VectorManager + { + [StructLayout(LayoutKind.Explicit, Size = Size)] + private struct Index + { + internal const int Size = 52; + + [FieldOffset(0)] + public ulong Context; + [FieldOffset(8)] + public ulong IndexPtr; + [FieldOffset(16)] + public uint Dimensions; + [FieldOffset(20)] + public uint ReduceDims; + [FieldOffset(24)] + public uint NumLinks; + [FieldOffset(28)] + public uint BuildExplorationFactor; + [FieldOffset(32)] + public VectorQuantType QuantType; + [FieldOffset(36)] + public Guid ProcessInstanceId; + } + + /// + /// Construct a new index, and stash enough data to recover it with . + /// + internal void CreateIndex( + uint dimensions, + uint reduceDims, + VectorQuantType quantType, + uint buildExplorationFactor, + uint numLinks, + ulong newContext, + nint newIndexPtr, + ref SpanByte indexValue) + { + AssertHaveStorageSession(); + + var indexSpan = indexValue.AsSpan(); + + Debug.Assert((newContext % 8) == 0 && newContext != 0, "Illegal context provided"); + Debug.Assert(Unsafe.SizeOf() == Index.Size, "Constant index size is incorrect"); + + if (indexSpan.Length != Index.Size) + { + logger?.LogCritical("Acquired space for vector set index does not match expectations, {Length} != {Size}", indexSpan.Length, Index.Size); + throw new GarnetException($"Acquired space for vector set index does not match expectations, {indexSpan.Length} != {Index.Size}"); + } + + ref var asIndex = ref Unsafe.As(ref MemoryMarshal.GetReference(indexSpan)); + asIndex.Context = newContext; + asIndex.Dimensions = dimensions; + asIndex.ReduceDims = reduceDims; + asIndex.QuantType = quantType; + asIndex.BuildExplorationFactor = buildExplorationFactor; + asIndex.NumLinks = numLinks; + asIndex.IndexPtr = (ulong)newIndexPtr; + asIndex.ProcessInstanceId = processInstanceId; + } + + /// + /// Recreate an index that was created by a prior instance of Garnet. + /// + /// This implies the index still has element data, but the pointer is garbage. 
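+ /// Only the index pointer and process instance id are overwritten; dimensions, quantization settings, and context are preserved from the stored record.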
+ /// + internal void RecreateIndex(nint newIndexPtr, ref SpanByte indexValue) + { + AssertHaveStorageSession(); + + var indexSpan = indexValue.AsSpan(); + + if (indexSpan.Length != Index.Size) + { + logger?.LogCritical("Acquired space for vector set index does not match expectations, {Length} != {Size}", indexSpan.Length, Index.Size); + throw new GarnetException($"Acquired space for vector set index does not match expectations, {indexSpan.Length} != {Index.Size}"); + } + + ReadIndex(indexSpan, out var context, out _, out _, out _, out _, out _, out _, out var indexProcessInstanceId); + Debug.Assert(processInstanceId != indexProcessInstanceId, "Shouldn't be recreating an index that matched our instance id"); + + ref var asIndex = ref Unsafe.As(ref MemoryMarshal.GetReference(indexSpan)); + asIndex.IndexPtr = (ulong)newIndexPtr; + asIndex.ProcessInstanceId = processInstanceId; + } + + /// + /// Drop an index previously constructed with . + /// + internal void DropIndex(ReadOnlySpan indexValue) + { + AssertHaveStorageSession(); + + ReadIndex(indexValue, out var context, out _, out _, out _, out _, out _, out var indexPtr, out var indexProcessInstanceId); + + if (indexProcessInstanceId != processInstanceId) + { + // We never actually spun this index up, so nothing to drop + return; + } + + Service.DropIndex(context, indexPtr); + } + + /// + /// Deconstruct index stored in the value under a Vector Set index key. + /// + public static void ReadIndex( + ReadOnlySpan indexValue, + out ulong context, + out uint dimensions, + out uint reduceDims, + out VectorQuantType quantType, + out uint buildExplorationFactor, + out uint numLinks, + out nint indexPtr, + out Guid processInstanceId + ) + { + Debug.Assert(indexValue.Length == Index.Size, $"Index size is incorrect ({indexValue.Length} != {Index.Size}), implies vector set index is probably corrupted"); + + ref var asIndex = ref Unsafe.As(ref MemoryMarshal.GetReference(indexValue)); + + context = asIndex.Context; + dimensions = asIndex.Dimensions; + reduceDims = asIndex.ReduceDims; + quantType = asIndex.QuantType; + buildExplorationFactor = asIndex.BuildExplorationFactor; + numLinks = asIndex.NumLinks; + indexPtr = (nint)asIndex.IndexPtr; + processInstanceId = asIndex.ProcessInstanceId; + + Debug.Assert((context % ContextStep) == 0, $"Context ({context}) not as expected (% 4 == {context % 4}), vector set index is probably corrupted"); + } + + /// + /// Update the context (which defines a range of namespaces) stored in a given index. + /// + /// Doing this also smashes the ProcessInstanceId, so the destination node won't + /// think it's already creating this index. + /// + public static void SetContextForMigration(Span indexValue, ulong newContext) + { + Debug.Assert(newContext != 0, "0 is special, should not be assigning to an index"); + Debug.Assert(indexValue.Length == Index.Size, $"Index size is incorrect ({indexValue.Length} != {Index.Size}), implies vector set index is probably corrupted"); + + ref var asIndex = ref Unsafe.As(ref MemoryMarshal.GetReference(indexValue)); + + asIndex.Context = newContext; + asIndex.ProcessInstanceId = MigratedInstanceId; + } + } +} \ No newline at end of file diff --git a/libs/server/Resp/Vector/VectorManager.Locking.cs b/libs/server/Resp/Vector/VectorManager.Locking.cs new file mode 100644 index 00000000000..9d601d696f1 --- /dev/null +++ b/libs/server/Resp/Vector/VectorManager.Locking.cs @@ -0,0 +1,447 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+
+using System;
+using System.Diagnostics;
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
+using Garnet.common;
+using Tsavorite.core;
+
+namespace Garnet.server
+{
+ ///
+ /// Methods managing locking around Vector Sets.
+ ///
+ /// Locking is bespoke because of the read-like nature of most Vector Set operations, and the re-entrancy implied by DiskANN callbacks.
+ ///
+ public sealed partial class VectorManager
+ {
+ ///
+ /// Used to scope a shared lock related to a Vector Set operation.
+ ///
+ /// Disposing this releases the lock and exits the storage session context on the current thread.
+ ///
+ internal readonly ref struct ReadVectorLock : IDisposable
+ {
+ private readonly ref readonly ReadOptimizedLock lockableCtx;
+ private readonly int lockToken;
+
+ internal ReadVectorLock(ref readonly ReadOptimizedLock lockableCtx, int lockToken)
+ {
+ this.lockToken = lockToken;
+ this.lockableCtx = ref lockableCtx;
+ }
+
+ ///
+ public void Dispose()
+ {
+ Debug.Assert(ActiveThreadSession != null, "Shouldn't exit context when not in one");
+ ActiveThreadSession = null;
+
+ if (Unsafe.IsNullRef(in lockableCtx))
+ {
+ return;
+ }
+
+ lockableCtx.ReleaseSharedLock(lockToken);
+ }
+ }
+
+ ///
+ /// Used to scope exclusive locks to exclusive Vector Set operations (delete, migrate, etc.).
+ ///
+ /// Disposing this releases the lock and exits the storage session context on the current thread.
+ ///
+ internal readonly ref struct ExclusiveVectorLock : IDisposable
+ {
+ private readonly ref readonly ReadOptimizedLock lockableCtx;
+ private readonly int lockToken;
+
+ internal ExclusiveVectorLock(ref readonly ReadOptimizedLock lockableCtx, int lockToken)
+ {
+ this.lockToken = lockToken;
+ this.lockableCtx = ref lockableCtx;
+ }
+
+ ///
+ public void Dispose()
+ {
+ Debug.Assert(ActiveThreadSession != null, "Shouldn't exit context when not in one");
+ ActiveThreadSession = null;
+
+ if (Unsafe.IsNullRef(in lockableCtx))
+ {
+ return;
+ }
+
+ lockableCtx.ReleaseExclusiveLock(lockToken);
+ }
+ }
+
+ private readonly ReadOptimizedLock vectorSetLocks;
+
+ ///
+ /// Returns true for indexes that were created via a previous instance of Garnet.
+ ///
+ /// Such indexes still have element data, but the index pointer to the DiskANN bits is invalid.
+ ///
+ internal bool NeedsRecreate(ReadOnlySpan<byte> indexConfig)
+ {
+ ReadIndex(indexConfig, out _, out _, out _, out _, out _, out _, out _, out var indexProcessInstanceId);
+
+ return indexProcessInstanceId != processInstanceId;
+ }
+
+ ///
+ /// Utility method that will read a vector set index out but not create one.
+ ///
+ /// It will however RECREATE one if needed.
+ ///
+ /// Returns a disposable that prevents the index from being deleted while undisposed.
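+ /// Reads run under a shared lock; the lock is only promoted to exclusive for an in-place RECREATE, after which it is released and the read retried so searches never run under the exclusive lock.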
+ /// + internal ReadVectorLock ReadVectorIndex(StorageSession storageSession, ref SpanByte key, ref RawStringInput input, scoped Span indexSpan, out GarnetStatus status) + { + Debug.Assert(indexSpan.Length == IndexSizeBytes, "Insufficient space for index"); + + Debug.Assert(ActiveThreadSession == null, "Shouldn't enter context when already in one"); + ActiveThreadSession = storageSession; + + var keyHash = storageSession.basicContext.GetKeyHash(ref key); + + var indexConfig = SpanByteAndMemory.FromPinnedSpan(indexSpan); + + var readCmd = input.header.cmd; + + while (true) + { + input.header.cmd = readCmd; + input.arg1 = 0; + + vectorSetLocks.AcquireSharedLock(keyHash, out var sharedLockToken); + + GarnetStatus readRes; + try + { + readRes = storageSession.Read_MainStore(ref key, ref input, ref indexConfig, ref storageSession.basicContext); + Debug.Assert(indexConfig.IsSpanByte, "Should never need to move index onto the heap"); + } + catch + { + vectorSetLocks.ReleaseSharedLock(sharedLockToken); + + throw; + } + + var needsRecreate = readRes == GarnetStatus.OK && NeedsRecreate(indexConfig.AsReadOnlySpan()); + + if (needsRecreate) + { + if (!vectorSetLocks.TryPromoteSharedLock(keyHash, sharedLockToken, out var exclusiveLockToken)) + { + // Release the SHARED lock if we can't promote and try again + vectorSetLocks.ReleaseSharedLock(sharedLockToken); + + continue; + } + + ReadIndex(indexSpan, out var indexContext, out var dims, out var reduceDims, out var quantType, out var buildExplorationFactor, out var numLinks, out _, out _); + + input.arg1 = RecreateIndexArg; + + nint newlyAllocatedIndex; + unsafe + { + newlyAllocatedIndex = Service.RecreateIndex(indexContext, dims, reduceDims, quantType, buildExplorationFactor, numLinks, ReadCallbackPtr, WriteCallbackPtr, DeleteCallbackPtr, ReadModifyWriteCallbackPtr); + } + + input.header.cmd = RespCommand.VADD; + input.arg1 = RecreateIndexArg; + + input.parseState.EnsureCapacity(11); + + // Save off for recreation + input.parseState.SetArgument(9, ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref indexContext, 1)))); // Strictly we don't _need_ this, but it keeps everything else aligned nicely + input.parseState.SetArgument(10, ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref newlyAllocatedIndex, 1)))); + + GarnetStatus writeRes; + try + { + try + { + writeRes = storageSession.RMW_MainStore(ref key, ref input, ref indexConfig, ref storageSession.basicContext); + + if (writeRes != GarnetStatus.OK) + { + // If we didn't write, drop index so we don't leak it + Service.DropIndex(indexContext, newlyAllocatedIndex); + } + } + catch + { + // Drop to avoid leak on error + Service.DropIndex(indexContext, newlyAllocatedIndex); + throw; + } + } + catch + { + vectorSetLocks.ReleaseExclusiveLock(exclusiveLockToken); + + throw; + } + + if (writeRes == GarnetStatus.OK) + { + // Try again so we don't hold an exclusive lock while performing a search + vectorSetLocks.ReleaseExclusiveLock(exclusiveLockToken); + continue; + } + else + { + status = writeRes; + vectorSetLocks.ReleaseExclusiveLock(exclusiveLockToken); + + return default; + } + } + else if (readRes != GarnetStatus.OK) + { + status = readRes; + vectorSetLocks.ReleaseSharedLock(sharedLockToken); + + return default; + } + + status = GarnetStatus.OK; + return new(in vectorSetLocks, sharedLockToken); + } + } + + /// + /// Utility method that will read vector set index out, create one if it doesn't exist, or RECREATE one if needed. 
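+ /// A brand new index is assigned a fresh context and associated with the key's hash slot at creation time, which is what makes later migration of the set possible.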
+ /// + /// Returns a disposable that prevents the index from being deleted while undisposed. + /// + internal ReadVectorLock ReadOrCreateVectorIndex( + StorageSession storageSession, + ref SpanByte key, + ref RawStringInput input, + scoped Span indexSpan, + out GarnetStatus status + ) + { + Debug.Assert(indexSpan.Length == IndexSizeBytes, "Insufficient space for index"); + + Debug.Assert(ActiveThreadSession == null, "Shouldn't enter context when already in one"); + ActiveThreadSession = storageSession; + + var keyHash = storageSession.basicContext.GetKeyHash(ref key); + + var indexConfig = SpanByteAndMemory.FromPinnedSpan(indexSpan); + + while (true) + { + input.arg1 = 0; + + vectorSetLocks.AcquireSharedLock(keyHash, out var sharedLockToken); + + GarnetStatus readRes; + try + { + readRes = storageSession.Read_MainStore(ref key, ref input, ref indexConfig, ref storageSession.basicContext); + Debug.Assert(indexConfig.IsSpanByte, "Should never need to move index onto the heap"); + } + catch + { + vectorSetLocks.ReleaseSharedLock(sharedLockToken); + + throw; + } + + var needsRecreate = readRes == GarnetStatus.OK && storageSession.vectorManager.NeedsRecreate(indexSpan); + if (readRes == GarnetStatus.NOTFOUND || needsRecreate) + { + if (!vectorSetLocks.TryPromoteSharedLock(keyHash, sharedLockToken, out var exclusiveLockToken)) + { + // Release the SHARED lock if we can't promote and try again + vectorSetLocks.ReleaseSharedLock(sharedLockToken); + + continue; + } + + ulong indexContext; + nint newlyAllocatedIndex; + if (needsRecreate) + { + ReadIndex(indexSpan, out indexContext, out var dims, out var reduceDims, out var quantType, out var buildExplorationFactor, out var numLinks, out _, out _); + + input.arg1 = RecreateIndexArg; + + unsafe + { + newlyAllocatedIndex = Service.RecreateIndex(indexContext, dims, reduceDims, quantType, buildExplorationFactor, numLinks, ReadCallbackPtr, WriteCallbackPtr, DeleteCallbackPtr, ReadModifyWriteCallbackPtr); + } + + input.parseState.EnsureCapacity(11); + + // Save off for recreation + input.parseState.SetArgument(9, ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref indexContext, 1)))); // Strictly we don't _need_ this, but it keeps everything else aligned nicely + input.parseState.SetArgument(10, ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref newlyAllocatedIndex, 1)))); + } + else + { + // Create a new index, grab a new context + + // We must associate the index with a hash slot at creation time to enable future migrations + // TODO: RENAME and friends need to also update this data + var slot = HashSlotUtils.HashSlot(ref key); + + indexContext = NextVectorSetContext(slot); + + var dims = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(0).Span); + var reduceDims = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(1).Span); + // ValueType is here, skipping during index creation + // Values is here, skipping during index creation + // Element is here, skipping during index creation + var quantizer = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(5).Span); + var buildExplorationFactor = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(6).Span); + // Attributes is here, skipping during index creation + var numLinks = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(8).Span); + + unsafe + { + newlyAllocatedIndex = Service.CreateIndex(indexContext, dims, reduceDims, quantizer, buildExplorationFactor, numLinks, ReadCallbackPtr, WriteCallbackPtr, DeleteCallbackPtr, ReadModifyWriteCallbackPtr); + } + 
+ input.parseState.EnsureCapacity(11); + + // Save off for insertion + input.parseState.SetArgument(9, ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref indexContext, 1)))); + input.parseState.SetArgument(10, ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref newlyAllocatedIndex, 1)))); + } + + GarnetStatus writeRes; + try + { + try + { + writeRes = storageSession.RMW_MainStore(ref key, ref input, ref indexConfig, ref storageSession.basicContext); + + if (writeRes != GarnetStatus.OK) + { + // Insertion failed, drop index + Service.DropIndex(indexContext, newlyAllocatedIndex); + + // If the failure was for a brand new index, free up the context too + if (!needsRecreate) + { + CleanupDroppedIndex(ref ActiveThreadSession.vectorContext, indexContext); + } + } + } + catch + { + if (newlyAllocatedIndex != 0) + { + // Drop to avoid a leak on error + Service.DropIndex(indexContext, newlyAllocatedIndex); + + // If the failure was for a brand new index, free up the context too + if (!needsRecreate) + { + CleanupDroppedIndex(ref ActiveThreadSession.vectorContext, indexContext); + } + } + + throw; + } + + if (!needsRecreate) + { + UpdateContextMetadata(ref storageSession.vectorContext); + } + } + catch + { + vectorSetLocks.ReleaseExclusiveLock(exclusiveLockToken); + + throw; + } + + if (writeRes == GarnetStatus.OK) + { + // Try again so we don't hold an exclusive lock while adding a vector (which might be time consuming) + vectorSetLocks.ReleaseExclusiveLock(exclusiveLockToken); + continue; + } + else + { + status = writeRes; + vectorSetLocks.ReleaseExclusiveLock(exclusiveLockToken); + + return default; + } + } + else if (readRes != GarnetStatus.OK) + { + vectorSetLocks.ReleaseSharedLock(sharedLockToken); + + status = readRes; + return default; + } + + status = GarnetStatus.OK; + return new(in vectorSetLocks, sharedLockToken); + } + } + + /// + /// Acquire exclusive lock over a given key. + /// + private ExclusiveVectorLock AcquireExclusiveLocks(StorageSession storageSession, ref SpanByte key) + { + var keyHash = storageSession.lockableContext.GetKeyHash(key); + + vectorSetLocks.AcquireExclusiveLock(keyHash, out var exclusiveLockToken); + + return new(in vectorSetLocks, exclusiveLockToken); + } + + /// + /// Utility method that will read vector set index out, and acquire exclusive locks to allow it to be deleted. 
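+ /// Unlike the read paths, this takes the exclusive lock up front and never downgrades it; the caller is expected to delete the index before disposing the returned lock.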
+ ///
+ internal ExclusiveVectorLock ReadForDeleteVectorIndex(StorageSession storageSession, ref SpanByte key, ref RawStringInput input, scoped Span<byte> indexSpan, out GarnetStatus status)
+ {
+ Debug.Assert(indexSpan.Length == IndexSizeBytes, "Insufficient space for index");
+
+ Debug.Assert(ActiveThreadSession == null, "Shouldn't enter context when already in one");
+ ActiveThreadSession = storageSession;
+
+ var indexConfig = SpanByteAndMemory.FromPinnedSpan(indexSpan);
+
+ // Get the index
+ var acquiredLock = AcquireExclusiveLocks(storageSession, ref key);
+ try
+ {
+ status = storageSession.Read_MainStore(ref key, ref input, ref indexConfig, ref storageSession.basicContext);
+ }
+ catch
+ {
+ acquiredLock.Dispose();
+
+ throw;
+ }
+
+ if (status != GarnetStatus.OK)
+ {
+ // This can happen if something else successfully deleted before we acquired the lock
+
+ acquiredLock.Dispose();
+ return default;
+ }
+
+ return acquiredLock;
+ }
+ }
+}
\ No newline at end of file
diff --git a/libs/server/Resp/Vector/VectorManager.Migration.cs b/libs/server/Resp/Vector/VectorManager.Migration.cs
new file mode 100644
index 00000000000..64ee9bf4b71
--- /dev/null
+++ b/libs/server/Resp/Vector/VectorManager.Migration.cs
@@ -0,0 +1,316 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Runtime.InteropServices;
+using Garnet.common;
+using Microsoft.Extensions.Logging;
+using Tsavorite.core;
+
+namespace Garnet.server
+{
+ using MainStoreAllocator = SpanByteAllocator>;
+ using MainStoreFunctions = StoreFunctions;
+
+ ///
+ /// Methods related to migrating Vector Sets between different primaries.
+ ///
+ /// This is bespoke because normal migration is key-based, but Vector Set migration has to move whole namespaces first.
+ ///
+ public sealed partial class VectorManager
+ {
+ // This is a V8 GUID based on the 'GARNET MIGRATION' ASCII string
+ // It cannot collide with processInstanceIds because it's v8
+ // It's unlikely other projects will select the value, so it's unlikely to collide with other v8s
+ // If it ends up in logs, its ASCII equivalent looks suspicious enough to lead back here
+ private static readonly Guid MigratedInstanceId = new("4e524147-5445-8d20-8947-524154494f4e");
+
+ ///
+ /// Called to handle a key in a namespace being received during a migration.
+ ///
+ /// These keys are what DiskANN stores, that is, they are "element" data.
+ ///
+ /// The index is handled specially by HandleMigratedIndexKey.
+ ///
+ public void HandleMigratedElementKey(
+ ref BasicContext basicCtx,
+ ref BasicContext vectorCtx,
+ ref SpanByte key,
+ ref SpanByte value
+ )
+ {
+ Debug.Assert(key.MetadataSize == 1, "Should have namespace if we're migrating a key");
+
+#if DEBUG
+ // Do some extra sanity checking in DEBUG builds
+ lock (this)
+ {
+ var ns = key.GetNamespaceInPayload();
+ var context = (ulong)(ns & ~(ContextStep - 1));
+ Debug.Assert(contextMetadata.IsInUse(context), "Shouldn't be migrating to an unused context");
+ Debug.Assert(contextMetadata.IsMigrating(context), "Shouldn't be migrating to context not marked for it");
+ Debug.Assert(!(contextMetadata.GetNeedCleanup()?.Contains(context) ??
false), "Shouldn't be migrating into context being deleted"); + } +#endif + + VectorInput input = default; + SpanByte outputSpan = default; + + var status = vectorCtx.Upsert(ref key, ref input, ref value, ref outputSpan); + if (status.IsPending) + { + CompletePending(ref status, ref outputSpan, ref vectorCtx); + } + + if (!status.IsCompletedSuccessfully) + { + throw new GarnetException("Failed to migrate key, this should fail migration"); + } + + ReplicateMigratedElementKey(ref basicCtx, ref key, ref value, logger); + + // Fake a write for post-migration replication + static void ReplicateMigratedElementKey(ref BasicContext basicCtx, ref SpanByte key, ref SpanByte value, ILogger logger) + { + RawStringInput input = default; + + input.header.cmd = RespCommand.VADD; + input.arg1 = MigrateElementKeyLogArg; + + input.parseState.InitializeWithArguments([ArgSlice.FromPinnedSpan(key.AsReadOnlySpanWithMetadata()), ArgSlice.FromPinnedSpan(value.AsReadOnlySpan())]); + + SpanByte dummyKey = default; + SpanByteAndMemory dummyOutput = default; + + var res = basicCtx.RMW(ref dummyKey, ref input, ref dummyOutput); + + if (res.IsPending) + { + CompletePending(ref res, ref dummyOutput, ref basicCtx); + } + + if (!res.IsCompletedSuccessfully) + { + logger?.LogCritical("Failed to inject replication write for migrated Vector Set key/value into log, result was {res}", res); + throw new GarnetException("Couldn't synthesize Vector Set write operation for key/value migration, data loss may occur"); + } + + // Helper to complete read/writes during vector set synthetic op goes async + static void CompletePending(ref Status status, ref SpanByteAndMemory output, ref BasicContext basicCtx) + { + _ = basicCtx.CompletePendingWithOutputs(out var completedOutputs, wait: true); + var more = completedOutputs.Next(); + Debug.Assert(more); + status = completedOutputs.Current.Status; + output = completedOutputs.Current.Output; + more = completedOutputs.Next(); + Debug.Assert(!more); + completedOutputs.Dispose(); + } + } + } + + /// + /// Called to handle a Vector Set key being received during a migration. These are "index" keys. + /// + /// This is the metadata stuff Garnet creates, DiskANN is not involved. + /// + /// Invoked after all the namespace data is moved via . 
+ ///
+ public void HandleMigratedIndexKey(
+ GarnetDatabase db,
+ StoreWrapper storeWrapper,
+ ref SpanByte key,
+ ref SpanByte value)
+ {
+ Debug.Assert(key.MetadataSize != 1, "Shouldn't have a namespace if we're migrating a Vector Set index");
+
+ RawStringInput input = default;
+ input.header.cmd = RespCommand.VADD;
+ input.arg1 = RecreateIndexArg;
+
+ ReadIndex(value.AsReadOnlySpan(), out var context, out var dimensions, out var reduceDims, out var quantType, out var buildExplorationFactor, out var numLinks, out _, out var processInstanceId);
+
+ Debug.Assert(processInstanceId == MigratedInstanceId, "Shouldn't receive a real process instance id during a migration");
+
+ // Extra validation in DEBUG
+#if DEBUG
+ lock (this)
+ {
+ Debug.Assert(contextMetadata.IsInUse(context), "Context should be assigned if we're migrating");
+ Debug.Assert(contextMetadata.IsMigrating(context), "Context should be marked migrating if we're moving an index key in");
+ }
+#endif
+
+ // Spin up a new StorageSession if we don't have one
+ StorageSession newStorageSession;
+ if (ActiveThreadSession == null)
+ {
+ Debug.Assert(db != null, "Must have DB if session is not already set");
+ Debug.Assert(storeWrapper != null, "Must have StoreWrapper if session is not already set");
+
+ ActiveThreadSession = newStorageSession = new StorageSession(storeWrapper, new(), null, null, db.Id, this, this.logger);
+ }
+ else
+ {
+ newStorageSession = null;
+ }
+
+ try
+ {
+ // Prepare as a pseudo-VADD
+ var dimsArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast<uint, byte>(MemoryMarshal.CreateSpan(ref dimensions, 1)));
+ var reduceDimsArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast<uint, byte>(MemoryMarshal.CreateSpan(ref reduceDims, 1)));
+ ArgSlice valueTypeArg = default;
+ ArgSlice valuesArg = default;
+ ArgSlice elementArg = default;
+ var quantizerArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast<VectorQuantType, byte>(MemoryMarshal.CreateSpan(ref quantType, 1)));
+ var buildExplorationFactorArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast<uint, byte>(MemoryMarshal.CreateSpan(ref buildExplorationFactor, 1)));
+ ArgSlice attributesArg = default;
+ var numLinksArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast<uint, byte>(MemoryMarshal.CreateSpan(ref numLinks, 1)));
+
+ nint newlyAllocatedIndex;
+ unsafe
+ {
+ newlyAllocatedIndex = Service.RecreateIndex(context, dimensions, reduceDims, quantType, buildExplorationFactor, numLinks, ReadCallbackPtr, WriteCallbackPtr, DeleteCallbackPtr, ReadModifyWriteCallbackPtr);
+ }
+
+ var ctxArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast<ulong, byte>(MemoryMarshal.CreateSpan(ref context, 1)));
+ var indexArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast<nint, byte>(MemoryMarshal.CreateSpan(ref newlyAllocatedIndex, 1)));
+
+ input.parseState.InitializeWithArguments([dimsArg, reduceDimsArg, valueTypeArg, valuesArg, elementArg, quantizerArg, buildExplorationFactorArg, attributesArg, numLinksArg, ctxArg, indexArg]);
+
+ Span<byte> indexSpan = stackalloc byte[Index.Size];
+ var indexConfig = SpanByteAndMemory.FromPinnedSpan(indexSpan);
+
+ // Exclusive lock to prevent other modification of this key
+
+ using (AcquireExclusiveLocks(ActiveThreadSession, ref key))
+ {
+ // Perform the write
+ var writeRes = ActiveThreadSession.RMW_MainStore(ref key, ref input, ref indexConfig, ref ActiveThreadSession.basicContext);
+ if (writeRes != GarnetStatus.OK)
+ {
+ Service.DropIndex(context, newlyAllocatedIndex);
+ throw new GarnetException("Failed to import migrated Vector Set index, aborting migration");
+ }
+
+ var hashSlot = HashSlotUtils.HashSlot(ref key);
+
+ lock (this)
+ {
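+ // The index key is the first moment the true hash slot is known; recording it also clears the migrating bit for this context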
contextMetadata.MarkMigrationComplete(context, hashSlot); + } + + UpdateContextMetadata(ref ActiveThreadSession.vectorContext); + + // For REPLICAs which are following, we need to fake up a write + ReplicateMigratedIndexKey(ref ActiveThreadSession.basicContext, ref key, ref value, context, logger); + } + } + finally + { + ActiveThreadSession = null; + + // If we spun up a new storage session, dispose it + newStorageSession?.Dispose(); + } + + // Fake a write for post-migration replication + static void ReplicateMigratedIndexKey( + ref BasicContext basicCtx, + ref SpanByte key, + ref SpanByte value, + ulong context, + ILogger logger) + { + RawStringInput input = default; + + input.header.cmd = RespCommand.VADD; + input.arg1 = MigrateIndexKeyLogArg; + + var contextArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref context, 1))); + + input.parseState.InitializeWithArguments([ArgSlice.FromPinnedSpan(key.AsReadOnlySpanWithMetadata()), ArgSlice.FromPinnedSpan(value.AsReadOnlySpan()), contextArg]); + + SpanByte dummyKey = default; + SpanByteAndMemory dummyOutput = default; + + var res = basicCtx.RMW(ref dummyKey, ref input, ref dummyOutput); + + if (res.IsPending) + { + CompletePending(ref res, ref dummyOutput, ref basicCtx); + } + + if (!res.IsCompletedSuccessfully) + { + logger?.LogCritical("Failed to inject replication write for migrated Vector Set index into log, result was {res}", res); + throw new GarnetException("Couldn't synthesize Vector Set write operation for index migration, data loss may occur"); + } + + // Helper to complete read/writes during vector set synthetic op goes async + static void CompletePending(ref Status status, ref SpanByteAndMemory output, ref BasicContext basicCtx) + { + _ = basicCtx.CompletePendingWithOutputs(out var completedOutputs, wait: true); + var more = completedOutputs.Next(); + Debug.Assert(more); + status = completedOutputs.Current.Status; + output = completedOutputs.Current.Output; + more = completedOutputs.Next(); + Debug.Assert(!more); + completedOutputs.Dispose(); + } + } + } + + /// + /// Find namespaces used by the given keys, IFF they are Vector Sets. They may (and often will) not be. + /// + /// Meant for use during migration. 
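+ /// Keys that are not Vector Set indexes are skipped; for each index key found, every namespace in its context block is added to the result and the raw index bytes are stashed in vectorSetKeys.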
+ ///
+ public unsafe HashSet<ulong> GetNamespacesForKeys(StoreWrapper storeWrapper, IEnumerable<byte[]> keys, Dictionary<byte[], byte[]> vectorSetKeys)
+ {
+ // TODO: Ideally we wouldn't make a new session for this, but it's fine for now
+ using var storageSession = new StorageSession(storeWrapper, new(), null, null, storeWrapper.DefaultDatabase.Id, this, logger);
+
+ HashSet<ulong> namespaces = null;
+
+ Span<byte> indexSpan = stackalloc byte[Index.Size];
+
+ foreach (var key in keys)
+ {
+ fixed (byte* keyPtr = key)
+ {
+ var keySpan = SpanByte.FromPinnedPointer(keyPtr, key.Length);
+
+ // Dummy command, we just need something Vector Set-y
+ RawStringInput input = default;
+ input.header.cmd = RespCommand.VSIM;
+
+ using (ReadVectorIndex(storageSession, ref keySpan, ref input, indexSpan, out var status))
+ {
+ if (status != GarnetStatus.OK)
+ {
+ continue;
+ }
+
+ namespaces ??= [];
+
+ ReadIndex(indexSpan, out var context, out _, out _, out _, out _, out _, out _, out _);
+ for (var i = 0UL; i < ContextStep; i++)
+ {
+ _ = namespaces.Add(context + i);
+ }
+
+ vectorSetKeys[key] = indexSpan.ToArray();
+ }
+ }
+ }
+
+ return namespaces;
+ }
+ }
+}
\ No newline at end of file
diff --git a/libs/server/Resp/Vector/VectorManager.Replication.cs b/libs/server/Resp/Vector/VectorManager.Replication.cs
new file mode 100644
index 00000000000..04e9bb9b82f
--- /dev/null
+++ b/libs/server/Resp/Vector/VectorManager.Replication.cs
@@ -0,0 +1,541 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+using System;
+using System.Buffers;
+using System.Diagnostics;
+using System.Runtime.InteropServices;
+using System.Text;
+using System.Threading;
+using System.Threading.Channels;
+using System.Threading.Tasks;
+using Garnet.common;
+using Microsoft.Extensions.Logging;
+using Tsavorite.core;
+
+namespace Garnet.server
+{
+ using MainStoreAllocator = SpanByteAllocator>;
+ using MainStoreFunctions = StoreFunctions;
+
+ ///
+ /// Methods for managing the replication of Vector Sets from primaries to replicas.
+ ///
+ /// This is very bespoke because Vector Set operations are phrased as reads for most things, which
+ /// bypasses Garnet's usual replication logic.
+ ///
+ public sealed partial class VectorManager
+ {
+ ///
+ /// Represents a copy of a VADD being replayed during replication.
+ ///
+ private readonly record struct VADDReplicationState(Memory<byte> Key, uint Dims, uint ReduceDims, VectorValueType ValueType, Memory<byte> Values, Memory<byte> Element, VectorQuantType Quantizer, uint BuildExplorationFactor, Memory<byte> Attributes, uint NumLinks)
+ {
+ }
+
+ private int replicationReplayStarted;
+ private long replicationReplayPendingVAdds;
+ private readonly ManualResetEventSlim replicationBlockEvent;
+ private readonly Channel<VADDReplicationState> replicationReplayChannel;
+ private readonly Task[] replicationReplayTasks;
+
+ ///
+ /// For replication purposes, we need a write against the main log.
+ ///
+ /// But we don't actually want to do the (expensive) vector ops as part of a write.
+ ///
+ /// So this fakes up a modify operation that we can then intercept as part of replication.
+ ///
+ /// This is the Primary part; on a Replica, HandleVectorSetAddReplication runs.
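+ /// The synthetic RMW targets the key under namespace 0, so it rides the ordinary main-log replication machinery without touching any element data.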
+ ///
+ internal void ReplicateVectorSetAdd<TContext>(ref SpanByte key, ref RawStringInput input, ref TContext context)
+ where TContext : ITsavoriteContext
+ {
+ Debug.Assert(input.header.cmd == RespCommand.VADD, "Shouldn't be called with anything but VADD inputs");
+
+ var inputCopy = input;
+ inputCopy.arg1 = VADDAppendLogArg;
+
+ Span<byte> keyWithNamespaceBytes = stackalloc byte[key.Length + 1];
+ var keyWithNamespace = SpanByte.FromPinnedSpan(keyWithNamespaceBytes);
+ keyWithNamespace.MarkNamespace();
+ keyWithNamespace.SetNamespaceInPayload(0);
+ key.AsReadOnlySpan().CopyTo(keyWithNamespace.AsSpan());
+
+ Span<byte> dummyBytes = stackalloc byte[4];
+ var dummy = SpanByteAndMemory.FromPinnedSpan(dummyBytes);
+
+ var res = context.RMW(ref keyWithNamespace, ref inputCopy, ref dummy);
+
+ if (res.IsPending)
+ {
+ CompletePending(ref res, ref dummy, ref context);
+ }
+
+ if (!res.IsCompletedSuccessfully)
+ {
+ logger?.LogCritical("Failed to inject replication write for VADD into log, result was {res}", res);
+ throw new GarnetException("Couldn't synthesize Vector Set add operation for replication, data loss will occur");
+ }
+
+ // Helper to complete read/writes when a vector set synthetic op goes async
+ static void CompletePending(ref Status status, ref SpanByteAndMemory output, ref TContext context)
+ {
+ _ = context.CompletePendingWithOutputs(out var completedOutputs, wait: true);
+ var more = completedOutputs.Next();
+ Debug.Assert(more);
+ status = completedOutputs.Current.Status;
+ output = completedOutputs.Current.Output;
+ more = completedOutputs.Next();
+ Debug.Assert(!more);
+ completedOutputs.Dispose();
+ }
+ }
+
+ ///
+ /// For replication purposes, we need a write against the main log.
+ ///
+ /// But we don't actually want to do the (expensive) vector ops as part of a write.
+ ///
+ /// So this fakes up a modify operation that we can then intercept as part of replication.
+ ///
+ /// This is the Primary part; on a Replica, the corresponding replay handler runs.
+ ///
+ internal void ReplicateVectorSetRemove<TContext>(ref SpanByte key, ref SpanByte element, ref RawStringInput input, ref TContext context)
+ where TContext : ITsavoriteContext
+ {
+ Debug.Assert(input.header.cmd == RespCommand.VREM, "Shouldn't be called with anything but VREM inputs");
+
+ var inputCopy = input;
+ inputCopy.arg1 = VREMAppendLogArg;
+
+ Span<byte> keyWithNamespaceBytes = stackalloc byte[key.Length + 1];
+ var keyWithNamespace = SpanByte.FromPinnedSpan(keyWithNamespaceBytes);
+ keyWithNamespace.MarkNamespace();
+ keyWithNamespace.SetNamespaceInPayload(0);
+ key.AsReadOnlySpan().CopyTo(keyWithNamespace.AsSpan());
+
+ Span<byte> dummyBytes = stackalloc byte[4];
+ var dummy = SpanByteAndMemory.FromPinnedSpan(dummyBytes);
+
+ inputCopy.parseState.InitializeWithArgument(ArgSlice.FromPinnedSpan(element.AsReadOnlySpan()));
+
+ var res = context.RMW(ref keyWithNamespace, ref inputCopy, ref dummy);
+
+ if (res.IsPending)
+ {
+ CompletePending(ref res, ref dummy, ref context);
+ }
+
+ if (!res.IsCompletedSuccessfully)
+ {
+ logger?.LogCritical("Failed to inject replication write for VREM into log, result was {res}", res);
+ throw new GarnetException("Couldn't synthesize Vector Set remove operation for replication, data loss will occur");
+ }
+
+ // Helper to complete read/writes when a vector set synthetic op goes async
+ static void CompletePending(ref Status status, ref SpanByteAndMemory output, ref TContext context)
+ {
+ _ = context.CompletePendingWithOutputs(out var completedOutputs, wait: true);
+ var more = completedOutputs.Next();
+ Debug.Assert(more);
+ status = completedOutputs.Current.Status;
+ output = completedOutputs.Current.Output;
+ more = completedOutputs.Next();
+ Debug.Assert(!more);
+ completedOutputs.Dispose();
+ }
+ }
+
+ ///
+ /// After an index is dropped, called to clean up state injected by ReplicateVectorSetAdd.
+ ///
+ /// Amounts to deleting a synthetic key in namespace 0.
+ ///
+ internal void DropVectorSetReplicationKey<TContext>(SpanByte key, ref TContext context)
+ where TContext : ITsavoriteContext
+ {
+ Span<byte> keyWithNamespaceBytes = stackalloc byte[key.Length + 1];
+ var keyWithNamespace = SpanByte.FromPinnedSpan(keyWithNamespaceBytes);
+ keyWithNamespace.MarkNamespace();
+ keyWithNamespace.SetNamespaceInPayload(0);
+ key.AsReadOnlySpan().CopyTo(keyWithNamespace.AsSpan());
+
+ Span<byte> dummyBytes = stackalloc byte[4];
+ var dummy = SpanByteAndMemory.FromPinnedSpan(dummyBytes);
+
+ var res = context.Delete(ref keyWithNamespace);
+
+ if (res.IsPending)
+ {
+ CompletePending(ref res, ref context);
+ }
+
+ if (!res.IsCompletedSuccessfully)
+ {
+ throw new GarnetException("Couldn't synthesize Vector Set delete operation for replication, data loss will occur");
+ }
+
+ // Helper to complete read/writes when a vector set synthetic op goes async
+ static void CompletePending(ref Status status, ref TContext context)
+ {
+ _ = context.CompletePendingWithOutputs(out var completedOutputs, wait: true);
+ var more = completedOutputs.Next();
+ Debug.Assert(more);
+ status = completedOutputs.Current.Status;
+ more = completedOutputs.Next();
+ Debug.Assert(!more);
+ completedOutputs.Dispose();
+ }
+ }
+
+ ///
+ /// Vector Set adds are phrased as reads (once the index is created), so they require special handling.
+ ///
+ /// Operations that are faked up by ReplicateVectorSetAdd running on the Primary get diverted here on a Replica.
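+ /// VADD replay is asynchronous: arguments are copied to the heap and queued to a channel drained by dedicated replay tasks, with a pending-operation count so replication can wait for quiescence.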
+
+        ///
+        /// Vector Set adds are phrased as reads (once the index is created), so they require special handling.
+        ///
+        /// Operations that are faked up by ReplicateVectorSetAdd running on the Primary get diverted here on a Replica.
+        ///
+        internal void HandleVectorSetAddReplication(StorageSession currentSession, Func obtainServerSession, ref SpanByte keyWithNamespace, ref RawStringInput input)
+        {
+            if (input.arg1 == MigrateElementKeyLogArg)
+            {
+                // These are special, injected by a PRIMARY applying migration operations
+                // These get replayed on REPLICAs typically, though role changes might still cause these
+                // to get replayed on now-primary nodes
+
+                var key = input.parseState.GetArgSliceByRef(0).SpanByte;
+                var value = input.parseState.GetArgSliceByRef(1).SpanByte;
+
+                // TODO: Namespace is present, but not actually transmitted
+                // This presumably becomes unnecessary in Store v2
+                key.MarkNamespace();
+
+                var ns = key.GetNamespaceInPayload();
+
+                // REPLICAs wouldn't have seen a reservation message, so allocate this on demand
+                var ctx = ns & ~(ContextStep - 1);
+                if (!contextMetadata.IsMigrating(ctx))
+                {
+                    var needsUpdate = false;
+
+                    lock (this)
+                    {
+                        if (!contextMetadata.IsMigrating(ctx))
+                        {
+                            contextMetadata.MarkInUse(ctx, ushort.MaxValue);
+                            contextMetadata.MarkMigrating(ctx);
+
+                            needsUpdate = true;
+                        }
+                    }
+
+                    if (needsUpdate)
+                    {
+                        UpdateContextMetadata(ref currentSession.vectorContext);
+                    }
+                }
+
+                HandleMigratedElementKey(ref currentSession.basicContext, ref currentSession.vectorContext, ref key, ref value);
+                return;
+            }
+            else if (input.arg1 == MigrateIndexKeyLogArg)
+            {
+                // These are also injected by a PRIMARY applying migration operations
+
+                var key = input.parseState.GetArgSliceByRef(0).SpanByte;
+                var value = input.parseState.GetArgSliceByRef(1).SpanByte;
+                var context = MemoryMarshal.Cast(input.parseState.GetArgSliceByRef(2).Span)[0];
+
+                // Most of the time a replica will have seen an element moving before now,
+                // but if you migrate an EMPTY Vector Set that is not necessarily true
+                //
+                // So force reservation now
+                if (!contextMetadata.IsMigrating(context))
+                {
+                    var needsUpdate = false;
+
+                    lock (this)
+                    {
+                        if (!contextMetadata.IsMigrating(context))
+                        {
+                            contextMetadata.MarkInUse(context, ushort.MaxValue);
+                            contextMetadata.MarkMigrating(context);
+
+                            needsUpdate = true;
+                        }
+                    }
+
+                    if (needsUpdate)
+                    {
+                        UpdateContextMetadata(ref currentSession.vectorContext);
+                    }
+                }
+
+                ActiveThreadSession = currentSession;
+                try
+                {
+                    HandleMigratedIndexKey(null, null, ref key, ref value);
+                }
+                finally
+                {
+                    ActiveThreadSession = null;
+                }
+                return;
+            }
+
+            Debug.Assert(input.arg1 == VADDAppendLogArg, "Unexpected operation during replication");
+
+            // Undo the mangling that got replication going
+            var inputCopy = input;
+            inputCopy.arg1 = default;
+            var keyBytesArr = ArrayPool.Shared.Rent(keyWithNamespace.Length - 1);
+            var keyBytes = keyBytesArr.AsMemory()[..(keyWithNamespace.Length - 1)];
+
+            keyWithNamespace.AsReadOnlySpan().CopyTo(keyBytes.Span);
+
+            var dims = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(0).Span);
+            var reduceDims = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(1).Span);
+            var valueType = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(2).Span);
+            var values = input.parseState.GetArgSliceByRef(3).Span;
+            var element = input.parseState.GetArgSliceByRef(4).Span;
+            var quantizer = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(5).Span);
+            var buildExplorationFactor = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(6).Span);
+            var attributes = input.parseState.GetArgSliceByRef(7).Span;
+            var numLinks = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(8).Span);
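+
+            // Layout of the synthetic VADD parse state unpacked above (one slot per argument, in order):
+            //   [0] dims            [1] reduceDims        [2] valueType
+            //   [3] values          [4] element           [5] quantizer
+            //   [6] buildExplorationFactor                [7] attributes
+            //   [8] numLinks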
+            // We have to make copies (and they need to be on the heap) to pass to background tasks
+            var valuesBytes = ArrayPool.Shared.Rent(values.Length).AsMemory()[..values.Length];
+            values.CopyTo(valuesBytes.Span);
+
+            var elementBytes = ArrayPool.Shared.Rent(element.Length).AsMemory()[..element.Length];
+            element.CopyTo(elementBytes.Span);
+
+            var attributesBytes = ArrayPool.Shared.Rent(attributes.Length).AsMemory()[..attributes.Length];
+            attributes.CopyTo(attributesBytes.Span);
+
+            // Spin up replication replay tasks on first use
+            if (replicationReplayStarted == 0)
+            {
+                if (Interlocked.CompareExchange(ref replicationReplayStarted, 1, 0) == 0)
+                {
+                    StartReplicationReplayTasks(this, obtainServerSession);
+                }
+            }
+
+            // We need a running count of pending VADDs so WaitForVectorOperationsToComplete can work
+            _ = Interlocked.Increment(ref replicationReplayPendingVAdds);
+            replicationBlockEvent.Reset();
+            var queued = replicationReplayChannel.Writer.TryWrite(new(keyBytes, dims, reduceDims, valueType, valuesBytes, elementBytes, quantizer, buildExplorationFactor, attributesBytes, numLinks));
+            if (!queued)
+            {
+                // Can occur if we're being Disposed
+                var pending = Interlocked.Decrement(ref replicationReplayPendingVAdds);
+                if (pending == 0)
+                {
+                    replicationBlockEvent.Set();
+                }
+            }
+
+            static void StartReplicationReplayTasks(VectorManager self, Func obtainServerSession)
+            {
+                self.logger?.LogInformation("Starting {numTasks} replication tasks for VADDs", self.replicationReplayTasks.Length);
+
+                for (var i = 0; i < self.replicationReplayTasks.Length; i++)
+                {
+                    // Allocate the session outside of the task so we fail "nicely" if something goes wrong while acquiring it
+                    var allocatedSession = obtainServerSession();
+                    if (allocatedSession.activeDbId != self.dbId && !allocatedSession.TrySwitchActiveDatabaseSession(self.dbId))
+                    {
+                        allocatedSession.Dispose();
+                        throw new GarnetException($"Could not switch replication replay session to {self.dbId}, replication will fail");
+                    }
+
+                    // Unwrap so the stored task tracks the async replay loop itself, not just the
+                    // scheduling of its first step
+                    self.replicationReplayTasks[i] = Task.Factory.StartNew(
+                        async () =>
+                        {
+                            try
+                            {
+                                using (allocatedSession)
+                                {
+                                    var reader = self.replicationReplayChannel.Reader;
+
+                                    SessionParseState reusableParseState = default;
+                                    reusableParseState.Initialize(11);
+
+                                    await foreach (var entry in reader.ReadAllAsync())
+                                    {
+                                        try
+                                        {
+                                            try
+                                            {
+                                                ApplyVectorSetAdd(self, allocatedSession.storageSession, entry, ref reusableParseState);
+                                            }
+                                            finally
+                                            {
+                                                var pending = Interlocked.Decrement(ref self.replicationReplayPendingVAdds);
+                                                Debug.Assert(pending >= 0, "Pending VADD ops has fallen below 0 after processing op");
+
+                                                if (pending == 0)
+                                                {
+                                                    self.replicationBlockEvent.Set();
+                                                }
+                                            }
+                                        }
+                                        catch
+                                        {
+                                            self.logger?.LogCritical(
+                                                "Faulting ApplyVectorSetAdd ({key}, {dims}, {reducedDims}, {valueType}, 0x{values}, 0x{element}, {quantizer}, {bef}, {attributes}, {numLinks})",
+                                                Encoding.UTF8.GetString(entry.Key.Span),
+                                                entry.Dims,
+                                                entry.ReduceDims,
+                                                entry.ValueType,
+                                                Convert.ToBase64String(entry.Values.Span),
+                                                Convert.ToBase64String(entry.Element.Span),
+                                                entry.Quantizer,
+                                                entry.BuildExplorationFactor,
+                                                Encoding.UTF8.GetString(entry.Attributes.Span),
+                                                entry.NumLinks
+                                            );
+
+                                            throw;
+                                        }
+                                    }
+                                }
+                            }
+                            catch (Exception e)
+                            {
+                                self.logger?.LogCritical(e, "Unexpected abort of replication replay task");
+                                throw;
+                            }
+                        }
+                    ).Unwrap();
+                }
+            }
+
+            // Actually apply a replicated VADD
+            static unsafe void ApplyVectorSetAdd(VectorManager self, StorageSession storageSession, VADDReplicationState state, ref SessionParseState reusableParseState)
+            {
+                ref var context = ref storageSession.basicContext;
+
+                var (keyBytes, dims, reduceDims, 
valueType, valuesBytes, elementBytes, quantizer, buildExplorationFactor, attributesBytes, numLinks) = state; + try + { + Span indexSpan = stackalloc byte[IndexSizeBytes]; + + fixed (byte* keyPtr = keyBytes.Span) + fixed (byte* valuesPtr = valuesBytes.Span) + fixed (byte* elementPtr = elementBytes.Span) + fixed (byte* attributesPtr = attributesBytes.Span) + { + var key = SpanByte.FromPinnedPointer(keyPtr, keyBytes.Length); + var values = SpanByte.FromPinnedPointer(valuesPtr, valuesBytes.Length); + var element = SpanByte.FromPinnedPointer(elementPtr, elementBytes.Length); + var attributes = SpanByte.FromPinnedPointer(attributesPtr, attributesBytes.Length); + + var indexBytes = stackalloc byte[IndexSizeBytes]; + SpanByteAndMemory indexConfig = new(indexBytes, IndexSizeBytes); + + var dimsArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref dims, 1))); + var reduceDimsArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref reduceDims, 1))); + var valueTypeArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref valueType, 1))); + var valuesArg = ArgSlice.FromPinnedSpan(values.AsReadOnlySpan()); + var elementArg = ArgSlice.FromPinnedSpan(element.AsReadOnlySpan()); + var quantizerArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref quantizer, 1))); + var buildExplorationFactorArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref buildExplorationFactor, 1))); + var attributesArg = ArgSlice.FromPinnedSpan(attributes.AsReadOnlySpan()); + var numLinksArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref numLinks, 1))); + + reusableParseState.InitializeWithArguments([dimsArg, reduceDimsArg, valueTypeArg, valuesArg, elementArg, quantizerArg, buildExplorationFactorArg, attributesArg, numLinksArg]); + + var input = new RawStringInput(RespCommand.VADD, ref reusableParseState); + + // Equivalent to VectorStoreOps.VectorSetAdd + // + // We still need locking here because the replays may proceed in parallel + + using (self.ReadOrCreateVectorIndex(storageSession, ref key, ref input, indexSpan, out var status)) + { + Debug.Assert(status == GarnetStatus.OK, "Replication should only occur when an add is successful, so index must exist"); + + var addRes = self.TryAdd(indexSpan, element.AsReadOnlySpan(), valueType, values.AsReadOnlySpan(), attributes.AsReadOnlySpan(), reduceDims, quantizer, buildExplorationFactor, numLinks, out _); + + if (addRes != VectorManagerResult.OK) + { + throw new GarnetException("Failed to add to vector set index during AOF sync, this should never happen but will cause data loss if it does"); + } + } + } + } + finally + { + if (MemoryMarshal.TryGetArray(keyBytes, out var toFree)) + { + ArrayPool.Shared.Return(toFree.Array); + } + + if (MemoryMarshal.TryGetArray(valuesBytes, out toFree)) + { + ArrayPool.Shared.Return(toFree.Array); + } + + if (MemoryMarshal.TryGetArray(elementBytes, out toFree)) + { + ArrayPool.Shared.Return(toFree.Array); + } + + if (MemoryMarshal.TryGetArray(attributesBytes, out toFree)) + { + ArrayPool.Shared.Return(toFree.Array); + } + } + } + } + + /// + /// Vector Set removes are phrased as reads (once the index is created), so they require special handling. + /// + /// Operations that are faked up by running on the Primary get diverted here on a Replica. 
+        ///
+        internal void HandleVectorSetRemoveReplication(StorageSession storageSession, ref SpanByte key, ref RawStringInput input)
+        {
+            Span indexSpan = stackalloc byte[IndexSizeBytes];
+            var element = input.parseState.GetArgSliceByRef(0);
+
+            // Replication adds a (0) namespace - remove it
+            Span keyWithoutNamespaceSpan = stackalloc byte[key.Length - 1];
+            key.AsReadOnlySpan().CopyTo(keyWithoutNamespaceSpan);
+            var keyWithoutNamespace = SpanByte.FromPinnedSpan(keyWithoutNamespaceSpan);
+
+            var inputCopy = input;
+            inputCopy.arg1 = default;
+
+            using (ReadVectorIndex(storageSession, ref keyWithoutNamespace, ref inputCopy, indexSpan, out var status))
+            {
+                Debug.Assert(status == GarnetStatus.OK, "Replication should only occur when a remove is successful, so index must exist");
+
+                var remRes = TryRemove(indexSpan, element.ReadOnlySpan);
+
+                if (remRes != VectorManagerResult.OK)
+                {
+                    throw new GarnetException("Failed to remove from vector set index during AOF sync, this should never happen but will cause data loss if it does");
+                }
+            }
+        }
+
+        ///
+        /// Wait until all ops passed to HandleVectorSetAddReplication have completed.
+        ///
+        public void WaitForVectorOperationsToComplete()
+        {
+            try
+            {
+                replicationBlockEvent.Wait();
+            }
+            catch (ObjectDisposedException)
+            {
+                // This is possible during dispose
+                //
+                // Dispose already takes pains to drain everything before disposing, so this is safe to ignore
+            }
+        }
+    }
+}
\ No newline at end of file
diff --git a/libs/server/Resp/Vector/VectorManager.cs b/libs/server/Resp/Vector/VectorManager.cs
new file mode 100644
index 00000000000..6cd4b3792be
--- /dev/null
+++ b/libs/server/Resp/Vector/VectorManager.cs
@@ -0,0 +1,743 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+using System;
+using System.Buffers;
+using System.Buffers.Binary;
+using System.Diagnostics;
+using System.Runtime.InteropServices;
+using System.Text;
+using System.Threading.Channels;
+using System.Threading.Tasks;
+using Garnet.common;
+using Garnet.networking;
+using Microsoft.Extensions.Logging;
+using Tsavorite.core;
+
+namespace Garnet.server
+{
+    using MainStoreAllocator = SpanByteAllocator>;
+    using MainStoreFunctions = StoreFunctions;
+
+    public enum VectorManagerResult
+    {
+        Invalid = 0,
+
+        OK,
+        BadParams,
+        Duplicate,
+        MissingElement,
+    }
+
+    ///
+    /// Methods for managing an implementation of various vector operations.
+    ///
+    public sealed partial class VectorManager : IDisposable
+    {
+        // MUST BE A POWER OF 2
+        public const ulong ContextStep = 8;
+
+        internal const int IndexSizeBytes = Index.Size;
+        internal const long VADDAppendLogArg = long.MinValue;
+        internal const long DeleteAfterDropArg = VADDAppendLogArg + 1;
+        internal const long RecreateIndexArg = DeleteAfterDropArg + 1;
+        internal const long VREMAppendLogArg = RecreateIndexArg + 1;
+        internal const long MigrateElementKeyLogArg = VREMAppendLogArg + 1;
+        internal const long MigrateIndexKeyLogArg = MigrateElementKeyLogArg + 1;
+
+        ///
+        /// The size of an id is assumed to be at least 4 bytes plus a length prefix.
+        ///
+        private const int MinimumSpacePerId = sizeof(int) + 4;
+
+        ///
+        /// The process-wide instance of DiskANN.
+        ///
+        /// We only need the one, even if we have multiple DBs, because all context is provided by DiskANN instances and Garnet storage.
+        ///
+        private DiskANNService Service { get; } = new DiskANNService();
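+
+        // Illustrative note: because ContextStep is a power of two, the namespaces belonging to one
+        // context form an aligned block of ContextStep ids, and the owning context can be recovered by
+        // masking off the low bits - the same expression used when handling migrated element keys
+        // during replication. A hypothetical helper (not part of the class) would be:
+        //
+        //   internal static ulong ContextFromNamespace(ulong ns) => ns & ~(ContextStep - 1);
+        //
+        // e.g. with ContextStep == 8, namespaces 8..15 all map back to context 8.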
+
+        ///
+        /// Whether or not Vector Set preview is enabled.
+        ///
+        /// TODO: This goes away once we're stable.
+        ///
+        public bool IsEnabled { get; }
+
+        ///
+        /// Unique id for this VectorManager.
+        ///
+        /// Is used to determine whether an Index is backed by a DiskANN index that was created in this process.
+        ///
+        private readonly Guid processInstanceId = Guid.NewGuid();
+
+        private readonly ILogger logger;
+
+        private readonly int dbId;
+
+        public VectorManager(bool enabled, int dbId, Func getCleanupSession, ILoggerFactory loggerFactory)
+        {
+            this.dbId = dbId;
+
+            IsEnabled = enabled;
+
+            // Include DB and id so we correlate to what's actually stored in the log
+            logger = loggerFactory?.CreateLogger($"{nameof(VectorManager)}:{dbId}:{processInstanceId}");
+
+            replicationBlockEvent = new(true);
+            replicationReplayChannel = Channel.CreateUnbounded(new() { SingleWriter = true, SingleReader = false, AllowSynchronousContinuations = false });
+
+            // TODO: Pull this off a config or something
+            replicationReplayTasks = new Task[Environment.ProcessorCount];
+            for (var i = 0; i < replicationReplayTasks.Length; i++)
+            {
+                replicationReplayTasks[i] = Task.CompletedTask;
+            }
+
+            // TODO: Probably configurable?
+            // For now, just number of processors
+            vectorSetLocks = new(Environment.ProcessorCount);
+
+            this.getCleanupSession = getCleanupSession;
+            cleanupTaskChannel = Channel.CreateUnbounded(new() { SingleWriter = false, SingleReader = true, AllowSynchronousContinuations = false });
+            cleanupTask = RunCleanupTaskAsync();
+
+            logger?.LogInformation("Created VectorManager");
+        }
+
+        ///
+        /// Load state necessary for VectorManager from main store.
+        ///
+        public void Initialize()
+        {
+            using var session = (RespServerSession)getCleanupSession();
+            if (session.activeDbId != dbId && !session.TrySwitchActiveDatabaseSession(dbId))
+            {
+                throw new GarnetException($"Could not switch VectorManager cleanup session to {dbId}, initialization failed");
+            }
+
+            Span keySpan = stackalloc byte[1];
+            Span dataSpan = stackalloc byte[ContextMetadata.Size];
+
+            var key = SpanByte.FromPinnedSpan(keySpan);
+
+            key.MarkNamespace();
+            key.SetNamespaceInPayload(0);
+
+            var data = SpanByte.FromPinnedSpan(dataSpan);
+
+            ref var ctx = ref session.storageSession.vectorContext;
+
+            var status = ctx.Read(ref key, ref data);
+
+            if (status.IsPending)
+            {
+                SpanByte ignored = default;
+                CompletePending(ref status, ref ignored, ref ctx);
+            }
+
+            // Can be not found if we've never spun up a Vector Set
+            if (status.Found)
+            {
+                lock (this)
+                {
+                    contextMetadata = MemoryMarshal.Cast(dataSpan)[0];
+                }
+            }
+
+            // If we come up and contexts are marked for migration, that means the migration FAILED
+            // and we'd like those contexts back ASAP
+            lock (this)
+            {
+                var abandonedMigrations = contextMetadata.GetMigrating();
+
+                if (abandonedMigrations != null)
+                {
+                    foreach (var abandoned in abandonedMigrations)
+                    {
+                        contextMetadata.MarkMigrationComplete(abandoned, ushort.MaxValue);
+                        contextMetadata.MarkCleaningUp(abandoned);
+                    }
+
+                    UpdateContextMetadata(ref ctx);
+                }
+            }
+
+            // Resume any cleanups we didn't complete before recovery
+            _ = cleanupTaskChannel.Writer.TryWrite(null);
+        }
+
+        ///
+        public void Dispose()
+        {
+            // We must drain all of these before disposing, otherwise we'll leave replicationBlockEvent unset
+            replicationReplayChannel.Writer.Complete();
+            replicationReplayChannel.Reader.Completion.Wait();
+
+            Task.WhenAll(replicationReplayTasks).Wait();
+
+            replicationBlockEvent.Dispose();
+
+            // Wait for any in progress cleanup to finish
+            cleanupTaskChannel.Writer.Complete();
+            cleanupTaskChannel.Reader.Completion.Wait();
+            cleanupTask.Wait();
+        }
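+
+        // Hypothetical usage sketch of the lifecycle above (the session factory and logger factory
+        // names are illustrative):
+        //
+        //   var manager = new VectorManager(enabled: true, dbId: 0, getCleanupSession, loggerFactory);
+        //   manager.Initialize();   // loads ContextMetadata and resumes abandoned cleanups
+        //   ...                     // serve vector set operations
+        //   manager.Dispose();      // drains replay and cleanup channels before tearing down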
+        private static void CompletePending(ref Status status, ref SpanByte output, ref TContext ctx)
+            where TContext : ITsavoriteContext
+        {
+            _ = ctx.CompletePendingWithOutputs(out var completedOutputs, wait: true);
+            var more = completedOutputs.Next();
+            Debug.Assert(more);
+            status = completedOutputs.Current.Status;
+            output = completedOutputs.Current.Output;
+            Debug.Assert(!completedOutputs.Next());
+            completedOutputs.Dispose();
+        }
+
+        ///
+        /// Add a vector to the vector set encoded by the given index value.
+        ///
+        /// Assumes that the index is locked in the Tsavorite store.
+        ///
+        /// Result of the operation.
+        internal VectorManagerResult TryAdd(
+            scoped ReadOnlySpan indexValue,
+            ReadOnlySpan element,
+            VectorValueType valueType,
+            ReadOnlySpan values,
+            ReadOnlySpan attributes,
+            uint providedReduceDims,
+            VectorQuantType providedQuantType,
+            uint providedBuildExplorationFactor,
+            uint providedNumLinks,
+            out ReadOnlySpan errorMsg
+        )
+        {
+            AssertHaveStorageSession();
+
+            errorMsg = default;
+
+            ReadIndex(indexValue, out var context, out var dimensions, out var reduceDims, out var quantType, out var buildExplorationFactor, out var numLinks, out var indexPtr, out _);
+
+            var valueDims = CalculateValueDimensions(valueType, values);
+
+            if (dimensions != valueDims)
+            {
+                // Matching Redis behavior
+                errorMsg = Encoding.ASCII.GetBytes($"ERR Vector dimension mismatch - got {valueDims} but set has {dimensions}");
+                return VectorManagerResult.BadParams;
+            }
+
+            if (providedReduceDims == 0 && reduceDims != 0)
+            {
+                // Matching Redis behavior, which is definitely a bit weird here
+                errorMsg = Encoding.ASCII.GetBytes($"ERR Vector dimension mismatch - got {valueDims} but set has {reduceDims}");
+                return VectorManagerResult.BadParams;
+            }
+            else if (providedReduceDims != 0 && providedReduceDims != reduceDims)
+            {
+                return VectorManagerResult.BadParams;
+            }
+
+            if (providedQuantType != VectorQuantType.Invalid && providedQuantType != quantType)
+            {
+                return VectorManagerResult.BadParams;
+            }
+
+            if (providedNumLinks != numLinks)
+            {
+                // Matching Redis behavior
+                errorMsg = "ERR asked M value mismatch with existing vector set"u8;
+                return VectorManagerResult.BadParams;
+            }
+
+            if (quantType == VectorQuantType.XPreQ8 && element.Length != sizeof(uint))
+            {
+                errorMsg = "ERR XPREQ8 requires 4-byte element ids"u8;
+                return VectorManagerResult.BadParams;
+            }
+
+            var insert =
+                Service.Insert(
+                    context,
+                    indexPtr,
+                    element,
+                    valueType,
+                    values,
+                    attributes
+                );
+
+            if (insert)
+            {
+                return VectorManagerResult.OK;
+            }
+
+            return VectorManagerResult.Duplicate;
+        }
+
+        ///
+        /// Try to remove a vector (and associated attributes) from a Vector Set, as identified by element key.
+        ///
+        internal VectorManagerResult TryRemove(ReadOnlySpan indexValue, ReadOnlySpan element)
+        {
+            AssertHaveStorageSession();
+
+            ReadIndex(indexValue, out var context, out _, out _, out var quantType, out _, out _, out var indexPtr, out _);
+
+            if (quantType == VectorQuantType.XPreQ8 && element.Length != sizeof(int))
+            {
+                // We know this element isn't present because of other validation constraints, bail
+                return VectorManagerResult.MissingElement;
+            }
+
+            var del = Service.Remove(context, indexPtr, element);
+
+            return del ? VectorManagerResult.OK : VectorManagerResult.MissingElement;
+        }
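+
+        // Worked example of the dimension check in TryAdd: with VectorValueType.FP32 each component is
+        // 4 bytes, so a 384-dimension set expects values.Length == 1536; with XB8 each component is a
+        // single byte, so the same set expects values.Length == 384 (see CalculateValueDimensions).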
+
+        ///
+        /// Deletion of a Vector Set needs special handling.
+        ///
+        /// This is called by DEL and UNLINK after a naive delete fails for us to _try_ and delete a Vector Set.
+        ///
+        internal Status TryDeleteVectorSet(StorageSession storageSession, ref SpanByte key)
+        {
+            storageSession.parseState.InitializeWithArgument(ArgSlice.FromPinnedSpan(key.AsReadOnlySpan()));
+
+            var input = new RawStringInput(RespCommand.VADD, ref storageSession.parseState);
+
+            Span indexSpan = stackalloc byte[IndexSizeBytes];
+
+            using (ReadForDeleteVectorIndex(storageSession, ref key, ref input, indexSpan, out var status))
+            {
+                if (status != GarnetStatus.OK)
+                {
+                    // This can happen if something else successfully deleted it before we acquired the lock
+                    return Status.CreateNotFound();
+                }
+
+                DropIndex(indexSpan);
+
+                // Update the index to be delete-able
+                var updateToDroppableVectorSet = new RawStringInput();
+                updateToDroppableVectorSet.arg1 = DeleteAfterDropArg;
+                updateToDroppableVectorSet.header.cmd = RespCommand.VADD;
+
+                var update = storageSession.basicContext.RMW(ref key, ref updateToDroppableVectorSet);
+                if (!update.IsCompletedSuccessfully)
+                {
+                    throw new GarnetException("Failed to make Vector Set delete-able, this should never happen but will leave vector sets corrupted");
+                }
+
+                // Actually delete the value
+                var del = storageSession.basicContext.Delete(ref key);
+                if (!del.IsCompletedSuccessfully)
+                {
+                    throw new GarnetException("Failed to delete dropped Vector Set, this should never happen but will leave vector sets corrupted");
+                }
+
+                // Cleanup incidental additional state
+                DropVectorSetReplicationKey(key, ref storageSession.basicContext);
+
+                CleanupDroppedIndex(ref storageSession.vectorContext, indexSpan);
+
+                return Status.CreateFound();
+            }
+        }
+
+        ///
+        /// Perform a similarity search given a vector to compare against.
+        ///
+        internal VectorManagerResult ValueSimilarity(
+            ReadOnlySpan indexValue,
+            VectorValueType valueType,
+            ReadOnlySpan values,
+            int count,
+            float delta,
+            int searchExplorationFactor,
+            ReadOnlySpan filter,
+            int maxFilteringEffort,
+            bool includeAttributes,
+            ref SpanByteAndMemory outputIds,
+            out VectorIdFormat outputIdFormat,
+            ref SpanByteAndMemory outputDistances,
+            ref SpanByteAndMemory outputAttributes
+        )
+        {
+            AssertHaveStorageSession();
+
+            ReadIndex(indexValue, out var context, out var dimensions, out var reduceDims, out var quantType, out var buildExplorationFactor, out var numLinks, out var indexPtr, out _);
+
+            var valueDims = CalculateValueDimensions(valueType, values);
+            if (dimensions != valueDims)
+            {
+                outputIdFormat = VectorIdFormat.Invalid;
+                return VectorManagerResult.BadParams;
+            }
+
+            // No point in asking for more data than the effort we'll put in
+            if (count > searchExplorationFactor)
+            {
+                count = searchExplorationFactor;
+            }
+
+            // Make sure there is enough space in distances for the requested count
+            if (count * sizeof(float) > outputDistances.Length)
+            {
+                if (!outputDistances.IsSpanByte)
+                {
+                    outputDistances.Memory.Dispose();
+                }
+
+                outputDistances = new SpanByteAndMemory(MemoryPool.Shared.Rent(count * sizeof(float)));
+            }
+
+            // Indicate requested # of matches
+            outputDistances.Length = count * sizeof(float);
+
+            // If we're fairly sure the ids won't fit, go ahead and grab more memory now
+            //
+            // If we're still wrong, we'll end up using continuation callbacks which have more overhead
+            if (count * MinimumSpacePerId > outputIds.Length)
+            {
+                if (!outputIds.IsSpanByte)
+                {
+                    outputIds.Memory.Dispose();
+                }
+
+                outputIds = new SpanByteAndMemory(MemoryPool.Shared.Rent(count * MinimumSpacePerId));
+            }
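+
+            // Sizing note for the check above: MinimumSpacePerId is sizeof(int) + 4 - a 4-byte length
+            // prefix plus an assumed minimum 4-byte id - so count * MinimumSpacePerId is the smallest
+            // buffer that can hold count length-prefixed ids without continuation callbacks.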
+            var found =
+                Service.SearchVector(
+                    context,
+                    indexPtr,
+                    valueType,
+                    values,
+                    delta,
+                    searchExplorationFactor,
+                    filter,
+                    maxFilteringEffort,
+                    outputIds.AsSpan(),
+                    MemoryMarshal.Cast(outputDistances.AsSpan()),
+                    out var continuation
+                );
+
+            if (found < 0)
+            {
+                logger?.LogWarning("Error-indicating response from vector service: {found}", found);
+                outputIdFormat = VectorIdFormat.Invalid;
+                return VectorManagerResult.BadParams;
+            }
+
+            if (includeAttributes)
+            {
+                FetchVectorElementAttributes(context, found, outputIds, ref outputAttributes);
+            }
+
+            if (continuation != 0)
+            {
+                // TODO: paged results!
+                throw new NotImplementedException();
+            }
+
+            outputDistances.Length = sizeof(float) * found;
+
+            // Default assumption is length prefixed
+            outputIdFormat = VectorIdFormat.I32LengthPrefixed;
+
+            if (quantType == VectorQuantType.XPreQ8)
+            {
+                // This special case is intended to force 4-byte fixed ids (FixedI32), but that is
+                // disabled for now, so ids remain length prefixed
+                //outputIdFormat = VectorIdFormat.FixedI32;
+                outputIdFormat = VectorIdFormat.I32LengthPrefixed;
+            }
+
+            return VectorManagerResult.OK;
+        }
+
+        ///
+        /// Perform a similarity search given an element to compare against.
+        ///
+        internal VectorManagerResult ElementSimilarity(
+            ReadOnlySpan indexValue,
+            ReadOnlySpan element,
+            int count,
+            float delta,
+            int searchExplorationFactor,
+            ReadOnlySpan filter,
+            int maxFilteringEffort,
+            bool includeAttributes,
+            ref SpanByteAndMemory outputIds,
+            out VectorIdFormat outputIdFormat,
+            ref SpanByteAndMemory outputDistances,
+            ref SpanByteAndMemory outputAttributes
+        )
+        {
+            AssertHaveStorageSession();
+
+            ReadIndex(indexValue, out var context, out _, out _, out var quantType, out _, out _, out var indexPtr, out _);
+
+            // No point in asking for more data than the effort we'll put in
+            if (count > searchExplorationFactor)
+            {
+                count = searchExplorationFactor;
+            }
+
+            // Make sure there is enough space in distances for the requested count
+            if (count * sizeof(float) > outputDistances.Length)
+            {
+                if (!outputDistances.IsSpanByte)
+                {
+                    outputDistances.Memory.Dispose();
+                }
+
+                outputDistances = new SpanByteAndMemory(MemoryPool.Shared.Rent(count * sizeof(float)));
+            }
+
+            // Indicate requested # of matches
+            outputDistances.Length = count * sizeof(float);
+
+            // If we're fairly sure the ids won't fit, go ahead and grab more memory now
+            //
+            // If we're still wrong, we'll end up using continuation callbacks which have more overhead
+            if (count * MinimumSpacePerId > outputIds.Length)
+            {
+                if (!outputIds.IsSpanByte)
+                {
+                    outputIds.Memory.Dispose();
+                }
+
+                outputIds = new SpanByteAndMemory(MemoryPool.Shared.Rent(count * MinimumSpacePerId));
+            }
+
+            var found =
+                Service.SearchElement(
+                    context,
+                    indexPtr,
+                    element,
+                    delta,
+                    searchExplorationFactor,
+                    filter,
+                    maxFilteringEffort,
+                    outputIds.AsSpan(),
+                    MemoryMarshal.Cast(outputDistances.AsSpan()),
+                    out var continuation
+                );
+
+            if (found < 0)
+            {
+                logger?.LogWarning("Error-indicating response from vector service: {found}", found);
+                outputIdFormat = VectorIdFormat.Invalid;
+                return VectorManagerResult.BadParams;
+            }
+
+            if (includeAttributes)
+            {
+                FetchVectorElementAttributes(context, found, outputIds, ref outputAttributes);
+            }
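+
+            // Output layout note: when includeAttributes is set, FetchVectorElementAttributes packs one
+            // entry per returned id into outputAttributes as [int32 length][attribute bytes].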
+            if (continuation != 0)
+            {
+                // TODO: paged results!
+                throw new NotImplementedException();
+            }
+
+            outputDistances.Length = sizeof(float) * found;
+
+            // Default assumption is length prefixed
+            outputIdFormat = VectorIdFormat.I32LengthPrefixed;
+
+            if (quantType == VectorQuantType.XPreQ8)
+            {
+                // This special case is intended to force 4-byte fixed ids (FixedI32), but that is
+                // disabled for now, so ids remain length prefixed
+                //outputIdFormat = VectorIdFormat.FixedI32;
+                outputIdFormat = VectorIdFormat.I32LengthPrefixed;
+            }
+
+            return VectorManagerResult.OK;
+        }
+
+        ///
+        /// Fetch attributes for a given set of element ids.
+        ///
+        /// This must only be called while holding locks which prevent the Vector Set from being dropped.
+        ///
+        private void FetchVectorElementAttributes(ulong context, int numIds, SpanByteAndMemory ids, ref SpanByteAndMemory attributes)
+        {
+            var remainingIds = ids.AsReadOnlySpan();
+
+            GCHandle idPin = default;
+            byte[] idWithNamespaceArr = null;
+
+            var attributesNextIx = 0;
+
+            Span attributeFull = stackalloc byte[32];
+            var attributeMem = SpanByteAndMemory.FromPinnedSpan(attributeFull);
+
+            try
+            {
+                Span idWithNamespace = stackalloc byte[128];
+
+                // TODO: we could scatter/gather this like MGET - doesn't matter when everything is in memory,
+                // but if anything is on disk it'd help perf
+                for (var i = 0; i < numIds; i++)
+                {
+                    var idLen = BinaryPrimitives.ReadInt32LittleEndian(remainingIds);
+                    if (idLen + sizeof(int) > remainingIds.Length)
+                    {
+                        throw new GarnetException($"Malformed ids, {idLen} + {sizeof(int)} > {remainingIds.Length}");
+                    }
+
+                    var id = remainingIds.Slice(sizeof(int), idLen);
+
+                    // Make sure we've got enough space to query the element
+                    if (id.Length + 1 > idWithNamespace.Length)
+                    {
+                        if (idWithNamespaceArr != null)
+                        {
+                            idPin.Free();
+                            ArrayPool.Shared.Return(idWithNamespaceArr);
+                        }
+
+                        idWithNamespaceArr = ArrayPool.Shared.Rent(id.Length + 1);
+                        idPin = GCHandle.Alloc(idWithNamespaceArr, GCHandleType.Pinned);
+                        idWithNamespace = idWithNamespaceArr;
+                    }
+
+                    if (attributeMem.Memory != null)
+                    {
+                        attributeMem.Length = attributeMem.Memory.Memory.Length;
+                    }
+                    else
+                    {
+                        attributeMem.Length = attributeMem.SpanByte.Length;
+                    }
+
+                    var found = ReadSizeUnknown(context | DiskANNService.Attributes, id, ref attributeMem);
+
+                    // Copy attribute into output buffer, length prefixed, resizing as necessary
+                    var neededSpace = 4 + (found ? attributeMem.Length : 0);
+
+                    var destSpan = attributes.AsSpan()[attributesNextIx..];
+                    if (destSpan.Length < neededSpace)
+                    {
+                        var newAttrArr = MemoryPool.Shared.Rent(attributes.Length + neededSpace);
+                        attributes.AsReadOnlySpan().CopyTo(newAttrArr.Memory.Span);
+
+                        attributes.Memory?.Dispose();
+
+                        attributes = new SpanByteAndMemory(newAttrArr, newAttrArr.Memory.Length);
+                        destSpan = attributes.AsSpan()[attributesNextIx..];
+                    }
+
+                    BinaryPrimitives.WriteInt32LittleEndian(destSpan, attributeMem.Length);
+                    attributeMem.AsReadOnlySpan().CopyTo(destSpan[sizeof(int)..]);
+
+                    attributesNextIx += neededSpace;
+
+                    remainingIds = remainingIds[(sizeof(int) + idLen)..];
+                }
+
+                attributes.Length = attributesNextIx;
+            }
+            finally
+            {
+                if (idWithNamespaceArr != null)
+                {
+                    idPin.Free();
+                    ArrayPool.Shared.Return(idWithNamespaceArr);
+                }
+
+                attributeMem.Memory?.Dispose();
+            }
+        }
+
+        ///
+        /// Try to read the associated embedding for an element out of a Vector Set.
+ /// + internal bool TryGetEmbedding(ReadOnlySpan indexValue, ReadOnlySpan element, ref SpanByteAndMemory outputDistances) + { + AssertHaveStorageSession(); + + ReadIndex(indexValue, out var context, out var dimensions, out _, out _, out _, out _, out var indexPtr, out _); + + // Make sure enough space in distances for requested count + if (dimensions * sizeof(float) > outputDistances.Length) + { + if (!outputDistances.IsSpanByte) + { + outputDistances.Memory.Dispose(); + } + + outputDistances = new SpanByteAndMemory(MemoryPool.Shared.Rent((int)dimensions * sizeof(float)), (int)dimensions * sizeof(float)); + } + else + { + outputDistances.Length = (int)dimensions * sizeof(float); + } + + Span asBytesSpan = stackalloc byte[(int)dimensions]; + var asBytes = SpanByteAndMemory.FromPinnedSpan(asBytesSpan); + try + { + if (!ReadSizeUnknown(context | DiskANNService.FullVector, element, ref asBytes)) + { + return false; + } + + var from = asBytes.AsReadOnlySpan(); + var into = MemoryMarshal.Cast(outputDistances.AsSpan()); + + for (var i = 0; i < asBytes.Length; i++) + { + into[i] = from[i]; + } + + return true; + } + finally + { + asBytes.Memory?.Dispose(); + } + + // TODO: DiskANN will need to do this long term, since different quantizers may behave differently + + //return + // Service.TryGetEmbedding( + // context, + // indexPtr, + // element, + // MemoryMarshal.Cast(outputDistances.AsSpan()) + // ); + } + + /// + /// Determine the dimensions of a vector given its and its raw data. + /// + internal static uint CalculateValueDimensions(VectorValueType valueType, ReadOnlySpan values) + { + if (valueType == VectorValueType.FP32) + { + return (uint)(values.Length / sizeof(float)); + } + else if (valueType == VectorValueType.XB8) + { + return (uint)(values.Length); + } + else + { + throw new NotImplementedException($"{valueType}"); + } + } + + [Conditional("DEBUG")] + private static void AssertHaveStorageSession() + { + Debug.Assert(ActiveThreadSession != null, "Should have StorageSession by now"); + } + } +} \ No newline at end of file diff --git a/libs/server/Servers/GarnetServerOptions.cs b/libs/server/Servers/GarnetServerOptions.cs index e0d344e87fb..4b6ee6b2797 100644 --- a/libs/server/Servers/GarnetServerOptions.cs +++ b/libs/server/Servers/GarnetServerOptions.cs @@ -532,6 +532,13 @@ public class GarnetServerOptions : ServerOptions /// public bool ClusterReplicaResumeWithData = false; + /// + /// If true, enable Vector Set commands. + /// + /// This is a preview feature, subject to substantial change, and should not be relied upon. + /// + public bool EnableVectorSetPreview = false; + /// /// Get the directory name for database checkpoints /// diff --git a/libs/server/Storage/Functions/FunctionsState.cs b/libs/server/Storage/Functions/FunctionsState.cs index 4ef24a38260..32eddffbe4e 100644 --- a/libs/server/Storage/Functions/FunctionsState.cs +++ b/libs/server/Storage/Functions/FunctionsState.cs @@ -22,11 +22,12 @@ internal sealed class FunctionsState public EtagState etagState; public byte respProtocolVersion; public bool StoredProcMode; + public readonly VectorManager vectorManager; internal ReadOnlySpan nilResp => respProtocolVersion >= 3 ? 
CmdStrings.RESP3_NULL_REPLY : CmdStrings.RESP_ERRNOTFOUND; public FunctionsState(TsavoriteLog appendOnlyFile, WatchVersionMap watchVersionMap, CustomCommandManager customCommandManager, - MemoryPool memoryPool, CacheSizeTracker objectStoreSizeTracker, GarnetObjectSerializer garnetObjectSerializer, + MemoryPool memoryPool, CacheSizeTracker objectStoreSizeTracker, GarnetObjectSerializer garnetObjectSerializer, VectorManager vectorManager, byte respProtocolVersion = ServerOptions.DEFAULT_RESP_VERSION) { this.appendOnlyFile = appendOnlyFile; @@ -36,6 +37,7 @@ public FunctionsState(TsavoriteLog appendOnlyFile, WatchVersionMap watchVersionM this.objectStoreSizeTracker = objectStoreSizeTracker; this.garnetObjectSerializer = garnetObjectSerializer; this.etagState = new EtagState(); + this.vectorManager = vectorManager; this.respProtocolVersion = respProtocolVersion; } diff --git a/libs/server/Storage/Functions/MainStore/DeleteMethods.cs b/libs/server/Storage/Functions/MainStore/DeleteMethods.cs index 6c055bd3682..d265ea20819 100644 --- a/libs/server/Storage/Functions/MainStore/DeleteMethods.cs +++ b/libs/server/Storage/Functions/MainStore/DeleteMethods.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. +using System; using Tsavorite.core; namespace Garnet.server @@ -13,6 +14,15 @@ namespace Garnet.server /// public bool SingleDeleter(ref SpanByte key, ref SpanByte value, ref DeleteInfo deleteInfo, ref RecordInfo recordInfo) { + if (recordInfo.VectorSet && value.AsReadOnlySpan().ContainsAnyExcept((byte)0)) + { + // Implies this is a vector set, needs special handling + // + // Will call back in after a drop with an all 0 value + deleteInfo.Action = DeleteAction.CancelOperation; + return false; + } + recordInfo.ClearHasETag(); functionsState.watchVersionMap.IncrementVersion(deleteInfo.KeyHash); return true; @@ -28,6 +38,15 @@ public void PostSingleDeleter(ref SpanByte key, ref DeleteInfo deleteInfo) /// public bool ConcurrentDeleter(ref SpanByte key, ref SpanByte value, ref DeleteInfo deleteInfo, ref RecordInfo recordInfo) { + if (recordInfo.VectorSet && value.AsReadOnlySpan().ContainsAnyExcept((byte)0)) + { + // Implies this is a vector set, needs special handling + // + // Will call back in after a drop with an all 0 value + deleteInfo.Action = DeleteAction.CancelOperation; + return false; + } + recordInfo.ClearHasETag(); if (!deleteInfo.RecordInfo.Modified) functionsState.watchVersionMap.IncrementVersion(deleteInfo.KeyHash); diff --git a/libs/server/Storage/Functions/MainStore/PrivateMethods.cs b/libs/server/Storage/Functions/MainStore/PrivateMethods.cs index 371e8f0825b..499d9aac0b3 100644 --- a/libs/server/Storage/Functions/MainStore/PrivateMethods.cs +++ b/libs/server/Storage/Functions/MainStore/PrivateMethods.cs @@ -118,6 +118,11 @@ void CopyRespToWithInput(ref RawStringInput input, ref SpanByte value, ref SpanB value.CopyTo(dst.Memory.Memory.Span); break; + case RespCommand.VADD: + case RespCommand.VSIM: + case RespCommand.VEMB: + case RespCommand.VREM: + case RespCommand.VDIM: case RespCommand.GET: // Get value without RESP header; exclude expiration if (value.LengthWithoutMetadata <= dst.Length) @@ -242,12 +247,12 @@ void CopyRespToWithInput(ref RawStringInput input, ref SpanByte value, ref SpanB throw new GarnetException($"Not enough space in {input.header.cmd} buffer"); case RespCommand.TTL: - var ttlValue = ConvertUtils.SecondsFromDiffUtcNowTicks(value.MetadataSize > 0 ? 
value.ExtraMetadata : -1); + var ttlValue = ConvertUtils.SecondsFromDiffUtcNowTicks(value.MetadataSize == 8 ? value.ExtraMetadata : -1); CopyRespNumber(ttlValue, ref dst); return; case RespCommand.PTTL: - var pttlValue = ConvertUtils.MillisecondsFromDiffUtcNowTicks(value.MetadataSize > 0 ? value.ExtraMetadata : -1); + var pttlValue = ConvertUtils.MillisecondsFromDiffUtcNowTicks(value.MetadataSize == 8 ? value.ExtraMetadata : -1); CopyRespNumber(pttlValue, ref dst); return; @@ -260,12 +265,12 @@ void CopyRespToWithInput(ref RawStringInput input, ref SpanByte value, ref SpanB CopyRespTo(ref value, ref dst, start + functionsState.etagState.etagSkippedStart, end + functionsState.etagState.etagSkippedStart); return; case RespCommand.EXPIRETIME: - var expireTime = ConvertUtils.UnixTimeInSecondsFromTicks(value.MetadataSize > 0 ? value.ExtraMetadata : -1); + var expireTime = ConvertUtils.UnixTimeInSecondsFromTicks(value.MetadataSize == 8 ? value.ExtraMetadata : -1); CopyRespNumber(expireTime, ref dst); return; case RespCommand.PEXPIRETIME: - var pexpireTime = ConvertUtils.UnixTimeInMillisecondsFromTicks(value.MetadataSize > 0 ? value.ExtraMetadata : -1); + var pexpireTime = ConvertUtils.UnixTimeInMillisecondsFromTicks(value.MetadataSize == 8 ? value.ExtraMetadata : -1); CopyRespNumber(pexpireTime, ref dst); return; @@ -638,6 +643,27 @@ void CopyDefaultResp(ReadOnlySpan resp, ref SpanByteAndMemory dst) resp.CopyTo(dst.Memory.Memory.Span); } + void CopyRespError(ReadOnlySpan errMsg, ref SpanByteAndMemory dst) + { + if (errMsg.Length + 3 < dst.SpanByte.Length) + { + var into = dst.SpanByte.AsSpan(); + + into[0] = (byte)'-'; + errMsg.CopyTo(into[1..]); + "\r\n"u8.CopyTo(into[(1 + errMsg.Length)..]); + dst.SpanByte.Length = errMsg.Length + 3; + return; + } + + dst.ConvertToHeap(); + dst.Length = errMsg.Length + 3; + dst.Memory = functionsState.memoryPool.Rent(errMsg.Length + 3); + dst.Memory.Memory.Span[0] = (byte)'-'; + errMsg.CopyTo(dst.Memory.Memory.Span[1..]); + "\r\n"u8.CopyTo(dst.Memory.Memory.Span[(1 + errMsg.Length)..]); + } + void CopyRespNumber(long number, ref SpanByteAndMemory dst) { byte* curr = dst.SpanByte.ToPointer(); @@ -729,6 +755,11 @@ void WriteLogUpsert(ref SpanByte key, ref RawStringInput input, ref SpanByte val { if (functionsState.StoredProcMode) return; + if (input.header.cmd == RespCommand.VADD && input.arg1 is not (VectorManager.VADDAppendLogArg or VectorManager.MigrateElementKeyLogArg or VectorManager.MigrateIndexKeyLogArg)) + { + return; + } + // We need this check because when we ingest records from the primary // if the input is zero then input overlaps with value so any update to RespInputHeader->flags // will incorrectly modify the total length of value. 
@@ -749,6 +780,12 @@ void WriteLogUpsert(ref SpanByte key, ref RawStringInput input, ref SpanByte val void WriteLogRMW(ref SpanByte key, ref RawStringInput input, long version, int sessionId) { if (functionsState.StoredProcMode) return; + + if (input.header.cmd == RespCommand.VADD && input.arg1 is not (VectorManager.VADDAppendLogArg or VectorManager.MigrateElementKeyLogArg or VectorManager.MigrateIndexKeyLogArg)) + { + return; + } + input.header.flags |= RespInputFlags.Deterministic; functionsState.appendOnlyFile.Enqueue( diff --git a/libs/server/Storage/Functions/MainStore/RMWMethods.cs b/libs/server/Storage/Functions/MainStore/RMWMethods.cs index 2e6453c20fb..df0426d25c8 100644 --- a/libs/server/Storage/Functions/MainStore/RMWMethods.cs +++ b/libs/server/Storage/Functions/MainStore/RMWMethods.cs @@ -3,6 +3,7 @@ using System; using System.Diagnostics; +using System.Runtime.InteropServices; using Garnet.common; using Tsavorite.core; @@ -242,6 +243,38 @@ public bool InitialUpdater(ref SpanByte key, ref RawStringInput input, ref SpanB var incrByFloat = BitConverter.Int64BitsToDouble(input.arg1); CopyUpdateNumber(incrByFloat, ref value, ref output); break; + + case RespCommand.VADD: + { + if (input.arg1 is VectorManager.VADDAppendLogArg or VectorManager.MigrateElementKeyLogArg or VectorManager.MigrateIndexKeyLogArg) + { + // Synthetic op, do nothing + break; + } + + var dims = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(0).Span); + var reduceDims = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(1).Span); + // ValueType is here, skipping during index creation + // Values is here, skipping during index creation + // Element is here, skipping during index creation + var quantizer = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(5).Span); + var buildExplorationFactor = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(6).Span); + // Attributes is here, skipping during index creation + var numLinks = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(8).Span); + + // Pre-allocated by caller because DiskANN needs to be able to call into Garnet as part of create_index + // and thus we can't call into it from session functions + var context = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(9).Span); + var index = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(10).Span); + + recordInfo.VectorSet = true; + + functionsState.vectorManager.CreateIndex(dims, reduceDims, quantizer, buildExplorationFactor, numLinks, context, index, ref value); + } + break; + case RespCommand.VREM: + Debug.Assert(input.arg1 == VectorManager.VREMAppendLogArg, "Should only see VREM writes as part of replication"); + break; default: if (input.header.cmd > RespCommandExtensions.LastValidCommand) { @@ -327,7 +360,7 @@ private IPUResult InPlaceUpdaterWorker(ref SpanByte key, ref RawStringInput inpu { RespCommand cmd = input.header.cmd; // Expired data - if (value.MetadataSize > 0 && input.header.CheckExpiry(value.ExtraMetadata)) + if (value.MetadataSize == 8 && input.header.CheckExpiry(value.ExtraMetadata)) { rmwInfo.Action = cmd is RespCommand.DELIFEXPIM ? 
RMWAction.ExpireAndStop : RMWAction.ExpireAndResume;
                 recordInfo.ClearHasETag();
@@ -583,7 +616,7 @@ private IPUResult InPlaceUpdaterWorker(ref SpanByte key, ref RawStringInput inpu
                     break;
 
                 case RespCommand.EXPIRE:
-                    var expiryExists = value.MetadataSize > 0;
+                    var expiryExists = value.MetadataSize == 8;
 
                     var expirationWithOption = new ExpirationWithOption(input.arg1);
@@ -593,7 +626,7 @@ private IPUResult InPlaceUpdaterWorker(ref SpanByte key, ref RawStringInput inpu
                     return EvaluateExpireInPlace(expirationWithOption.ExpireOption, expiryExists, expirationWithOption.ExpirationTimeInTicks, ref value, ref output);
 
                 case RespCommand.PERSIST:
-                    if (value.MetadataSize != 0)
+                    if (value.MetadataSize == 8)
                     {
                         rmwInfo.ClearExtraValueLength(ref recordInfo, ref value, value.TotalSize);
                         value.AsSpan().CopyTo(value.AsSpanWithMetadata());
@@ -752,7 +785,7 @@ private IPUResult InPlaceUpdaterWorker(ref SpanByte key, ref RawStringInput inpu
                     var _output = new SpanByteAndMemory(SpanByte.FromPinnedPointer(pbOutput, ObjectOutputHeader.Size));
                     var newExpiry = input.arg1;
-                    return EvaluateExpireInPlace(ExpireOption.None, expiryExists: value.MetadataSize > 0, newExpiry, ref value, ref _output);
+                    return EvaluateExpireInPlace(ExpireOption.None, expiryExists: value.MetadataSize == 8, newExpiry, ref value, ref _output);
                 }
 
                 if (input.parseState.Count > 0)
@@ -794,6 +827,38 @@ private IPUResult InPlaceUpdaterWorker(ref SpanByte key, ref RawStringInput inpu
                     // this is the case where it isn't expired
                     shouldUpdateEtag = false;
                     break;
+                case RespCommand.VADD:
+                    // Adding to an existing VectorSet is modeled as a read operation
+                    //
+                    // However, we do synthesize some (pointless) writes to implement replication
+                    // and a "make me delete-able" update during drop.
+                    //
+                    // Another "not quite write" is the recreate-index write operation
+                    // that occurs if we're adding to an index that was restored from disk
+                    // or from a primary node.
+
+                    // Handle "make me delete-able"
+                    if (input.arg1 == VectorManager.DeleteAfterDropArg)
+                    {
+                        value.AsSpan().Clear();
+                    }
+                    else if (input.arg1 == VectorManager.RecreateIndexArg)
+                    {
+                        var newIndexPtr = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(10).Span);
+
+                        functionsState.vectorManager.RecreateIndex(newIndexPtr, ref value);
+                    }
+
+                    // Ignore everything else
+                    return IPUResult.Succeeded;
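+
+                    // End-to-end flow for the DeleteAfterDropArg handling above, for reference:
+                    // TryDeleteVectorSet drops the DiskANN index, issues this RMW to zero the value, and
+                    // only then calls Delete - SingleDeleter/ConcurrentDeleter cancel deletes of
+                    // VectorSet records whose value still contains non-zero bytes.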
+                case RespCommand.VREM:
+                    // Removing from a VectorSet is modeled as a read operation
+                    //
+                    // However, we do synthesize some (pointless) writes to implement replication
+                    // in a similar manner to VADD.
+
+                    Debug.Assert(input.arg1 == VectorManager.VREMAppendLogArg, "VREM in place update should only happen for replication");
+
+                    // Ignore everything else
+                    return IPUResult.Succeeded;
                 default:
                     if (cmd > RespCommandExtensions.LastValidCommand)
                     {
@@ -877,7 +942,7 @@ public bool NeedCopyUpdate(ref SpanByte key, ref RawStringInput input, ref SpanB
             switch (input.header.cmd)
             {
                 case RespCommand.DELIFEXPIM:
-                    if (oldValue.MetadataSize > 0 && input.header.CheckExpiry(oldValue.ExtraMetadata))
+                    if (oldValue.MetadataSize == 8 && input.header.CheckExpiry(oldValue.ExtraMetadata))
                     {
                         rmwInfo.Action = RMWAction.ExpireAndStop;
                     }
@@ -940,7 +1005,7 @@ public bool NeedCopyUpdate(ref SpanByte key, ref RawStringInput input, ref SpanB
                 case RespCommand.SETEXNX:
                     // Expired data, return false immediately
                     // ExpireAndResume ensures that we set as new value, since it does not exist
-                    if (oldValue.MetadataSize > 0 && input.header.CheckExpiry(oldValue.ExtraMetadata))
+                    if (oldValue.MetadataSize == 8 && input.header.CheckExpiry(oldValue.ExtraMetadata))
                    {
                         rmwInfo.Action = RMWAction.ExpireAndResume;
                         rmwInfo.RecordInfo.ClearHasETag();
@@ -968,7 +1033,7 @@ public bool NeedCopyUpdate(ref SpanByte key, ref RawStringInput input, ref SpanB
                 case RespCommand.SETEXXX:
                     // Expired data, return false immediately so we do not set, since it does not exist
                     // ExpireAndStop ensures that caller sees a NOTFOUND status
-                    if (oldValue.MetadataSize > 0 && input.header.CheckExpiry(oldValue.ExtraMetadata))
+                    if (oldValue.MetadataSize == 8 && input.header.CheckExpiry(oldValue.ExtraMetadata))
                     {
                         rmwInfo.RecordInfo.ClearHasETag();
                         rmwInfo.Action = RMWAction.ExpireAndStop;
@@ -1009,7 +1074,7 @@ public bool NeedCopyUpdate(ref SpanByte key, ref RawStringInput input, ref SpanB
         public bool CopyUpdater(ref SpanByte key, ref RawStringInput input, ref SpanByte oldValue, ref SpanByte newValue, ref SpanByteAndMemory output, ref RMWInfo rmwInfo, ref RecordInfo recordInfo)
         {
             // Expired data
-            if (oldValue.MetadataSize > 0 && input.header.CheckExpiry(oldValue.ExtraMetadata))
+            if (oldValue.MetadataSize == 8 && input.header.CheckExpiry(oldValue.ExtraMetadata))
            {
                 recordInfo.ClearHasETag();
                 rmwInfo.Action = RMWAction.ExpireAndResume;
@@ -1171,7 +1236,7 @@ public bool CopyUpdater(ref SpanByte key, ref RawStringInput input, ref SpanByte
 
                 case RespCommand.EXPIRE:
                     shouldUpdateEtag = false;
-                    var expiryExists = oldValue.MetadataSize > 0;
+                    var expiryExists = oldValue.MetadataSize == 8;
 
                     var expirationWithOption = new ExpirationWithOption(input.arg1);
@@ -1181,7 +1246,7 @@ public bool CopyUpdater(ref SpanByte key, ref RawStringInput input, ref SpanByte
                 case RespCommand.PERSIST:
                     shouldUpdateEtag = false;
                     oldValue.AsReadOnlySpan().CopyTo(newValue.AsSpan());
-                    if (oldValue.MetadataSize != 0)
+                    if (oldValue.MetadataSize == 8)
                     {
                         newValue.AsSpan().CopyTo(newValue.AsSpanWithMetadata());
                         newValue.ShrinkSerializedLength(newValue.Length - newValue.MetadataSize);
@@ -1306,7 +1371,7 @@ public bool CopyUpdater(ref SpanByte key, ref RawStringInput input, ref SpanByte
                     byte* pbOutput = stackalloc byte[ObjectOutputHeader.Size];
                     var _output = new SpanByteAndMemory(SpanByte.FromPinnedPointer(pbOutput, ObjectOutputHeader.Size));
                     var newExpiry = input.arg1;
-                    EvaluateExpireCopyUpdate(ExpireOption.None, expiryExists: oldValue.MetadataSize > 0, newExpiry, ref oldValue, ref newValue, ref _output);
+                    EvaluateExpireCopyUpdate(ExpireOption.None, expiryExists: oldValue.MetadataSize == 8, newExpiry, ref oldValue, ref newValue, ref _output);
                 }
 
                 oldValue.AsReadOnlySpan().CopyTo(newValue.AsSpan());
@@ -1337,6 
+1402,27 @@ public bool CopyUpdater(ref SpanByte key, ref RawStringInput input, ref SpanByte CopyValueLengthToOutput(ref newValue, ref output, functionsState.etagState.etagSkippedStart); break; + case RespCommand.VADD: + // Handle "make me delete-able" + if (input.arg1 == VectorManager.DeleteAfterDropArg) + { + newValue.AsSpan().Clear(); + } + else if (input.arg1 == VectorManager.RecreateIndexArg) + { + var newIndexPtr = MemoryMarshal.Read(input.parseState.GetArgSliceByRef(10).Span); + + oldValue.CopyTo(ref newValue); + + functionsState.vectorManager.RecreateIndex(newIndexPtr, ref newValue); + } + + break; + + case RespCommand.VREM: + Debug.Assert(input.arg1 == VectorManager.VREMAppendLogArg, "Unexpected CopyUpdater call on VREM key"); + break; + default: if (input.header.cmd > RespCommandExtensions.LastValidCommand) { diff --git a/libs/server/Storage/Functions/MainStore/ReadMethods.cs b/libs/server/Storage/Functions/MainStore/ReadMethods.cs index d23e5af89dd..2a953bc5731 100644 --- a/libs/server/Storage/Functions/MainStore/ReadMethods.cs +++ b/libs/server/Storage/Functions/MainStore/ReadMethods.cs @@ -17,7 +17,7 @@ public bool SingleReader( ref SpanByte key, ref RawStringInput input, ref SpanByte value, ref SpanByteAndMemory dst, ref ReadInfo readInfo) { - if (value.MetadataSize != 0 && CheckExpiry(ref value)) + if (value.MetadataSize == 8 && CheckExpiry(ref value)) { readInfo.RecordInfo.ClearHasETag(); return false; @@ -25,6 +25,22 @@ public bool SingleReader( var cmd = input.header.cmd; + // Vector sets are reachable (key not mangled) and hidden. + // So we can use that to detect type mismatches. + if (readInfo.RecordInfo.VectorSet && !cmd.IsLegalOnVectorSet()) + { + // Attempted an illegal op on a VectorSet + CopyRespError(CmdStrings.RESP_ERR_WRONG_TYPE, ref dst); + readInfo.Action = ReadAction.CancelOperation; + return true; + } + else if (!readInfo.RecordInfo.VectorSet && cmd.IsLegalOnVectorSet()) + { + // Attempted a vector set op on a non-VectorSet + readInfo.Action = ReadAction.CancelOperation; + return false; + } + if (cmd == RespCommand.GETIFNOTMATCH) { if (handleGetIfNotMatch(ref input, ref value, ref dst, ref readInfo)) @@ -87,7 +103,7 @@ public bool ConcurrentReader( ref SpanByte key, ref RawStringInput input, ref SpanByte value, ref SpanByteAndMemory dst, ref ReadInfo readInfo, ref RecordInfo recordInfo) { - if (value.MetadataSize != 0 && CheckExpiry(ref value)) + if (value.MetadataSize == 8 && CheckExpiry(ref value)) { recordInfo.ClearHasETag(); return false; @@ -95,6 +111,22 @@ public bool ConcurrentReader( var cmd = input.header.cmd; + // Vector sets are reachable (key not mangled) and hidden. + // So we can use that to detect type mismatches. 
+ if (recordInfo.VectorSet && !cmd.IsLegalOnVectorSet()) + { + // Attempted an illegal op on a VectorSet + CopyRespError(CmdStrings.RESP_ERR_WRONG_TYPE, ref dst); + readInfo.Action = ReadAction.CancelOperation; + return true; + } + else if (!recordInfo.VectorSet && cmd.IsLegalOnVectorSet()) + { + // Attempted a vector set op on a non-VectorSet + readInfo.Action = ReadAction.CancelOperation; + return false; + } + if (cmd == RespCommand.GETIFNOTMATCH) { if (handleGetIfNotMatch(ref input, ref value, ref dst, ref readInfo)) @@ -137,7 +169,6 @@ public bool ConcurrentReader( return true; } - if (cmd == RespCommand.NONE) CopyRespTo(ref value, ref dst, functionsState.etagState.etagSkippedStart, functionsState.etagState.etagAccountedLength); else diff --git a/libs/server/Storage/Functions/MainStore/VarLenInputMethods.cs b/libs/server/Storage/Functions/MainStore/VarLenInputMethods.cs index adc5b124249..0130dcbe389 100644 --- a/libs/server/Storage/Functions/MainStore/VarLenInputMethods.cs +++ b/libs/server/Storage/Functions/MainStore/VarLenInputMethods.cs @@ -113,6 +113,9 @@ public int GetRMWInitialValueLength(ref RawStringInput input) ndigits = NumUtils.CountCharsInDouble(incrByFloat, out var _, out var _, out var _); return sizeof(int) + ndigits; + case RespCommand.VADD: + return sizeof(int) + VectorManager.IndexSizeBytes; + default: if (cmd > RespCommandExtensions.LastValidCommand) { @@ -236,6 +239,9 @@ public int GetRMWModifiedValueLength(ref SpanByte t, ref RawStringInput input) // Min allocation (only metadata) needed since this is going to be used for tombstoning anyway. return sizeof(int); + case RespCommand.VADD: + return t.Length; + default: if (cmd > RespCommandExtensions.LastValidCommand) { diff --git a/libs/server/Storage/Functions/MainStore/VectorSessionFunctions.cs b/libs/server/Storage/Functions/MainStore/VectorSessionFunctions.cs new file mode 100644 index 00000000000..ddad8151f95 --- /dev/null +++ b/libs/server/Storage/Functions/MainStore/VectorSessionFunctions.cs @@ -0,0 +1,272 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System.Diagnostics; +using System.Runtime.InteropServices; +using Tsavorite.core; + +namespace Garnet.server +{ + /// + /// Functions for operating against the Main Store, but for data stored as part of a Vector Set operation - not a RESP command. 
+ /// + public readonly struct VectorSessionFunctions : ISessionFunctions + { + private readonly FunctionsState functionsState; + + /// + /// Constructor + /// + internal VectorSessionFunctions(FunctionsState functionsState) + { + this.functionsState = functionsState; + } + + #region Deletes + /// + public bool SingleDeleter(ref SpanByte key, ref SpanByte value, ref DeleteInfo deleteInfo, ref RecordInfo recordInfo) + { + recordInfo.ClearHasETag(); + functionsState.watchVersionMap.IncrementVersion(deleteInfo.KeyHash); + return true; + } + /// + public bool ConcurrentDeleter(ref SpanByte key, ref SpanByte value, ref DeleteInfo deleteInfo, ref RecordInfo recordInfo) + { + recordInfo.ClearHasETag(); + if (!deleteInfo.RecordInfo.Modified) + functionsState.watchVersionMap.IncrementVersion(deleteInfo.KeyHash); + return true; + } + /// + public void PostSingleDeleter(ref SpanByte key, ref DeleteInfo deleteInfo) { } + #endregion + + #region Reads + /// + public bool SingleReader(ref SpanByte key, ref VectorInput input, ref SpanByte value, ref SpanByte dst, ref ReadInfo readInfo) + { + Debug.Assert(key.MetadataSize == 1, "Should never read a non-namespaced value with VectorSessionFunctions"); + + unsafe + { + if (input.Callback != 0) + { + var callback = (delegate* unmanaged[Cdecl, SuppressGCTransition])input.Callback; + + callback(input.Index, input.CallbackContext, (nint)value.ToPointer(), (nuint)value.Length); + return true; + } + } + + if (input.ReadDesiredSize > 0) + { + Debug.Assert(dst.Length >= value.Length, "Should always have space for vector point reads"); + + dst.Length = value.Length; + value.AsReadOnlySpan(functionsState.etagState.etagSkippedStart).CopyTo(dst.AsSpan()); + } + else + { + input.ReadDesiredSize = value.Length; + if (dst.Length >= value.Length) + { + value.AsReadOnlySpan(functionsState.etagState.etagSkippedStart).CopyTo(dst.AsSpan()); + dst.Length = value.Length; + } + } + + return true; + } + /// + public bool ConcurrentReader(ref SpanByte key, ref VectorInput input, ref SpanByte value, ref SpanByte dst, ref ReadInfo readInfo, ref RecordInfo recordInfo) + => SingleReader(ref key, ref input, ref value, ref dst, ref readInfo); + + /// + public void ReadCompletionCallback(ref SpanByte key, ref VectorInput input, ref SpanByte output, long ctx, Status status, RecordMetadata recordMetadata) + { + } + #endregion + + #region Initial Values + /// + public bool NeedInitialUpdate(ref SpanByte key, ref VectorInput input, ref SpanByte output, ref RMWInfo rmwInfo) + { + // Only needed when updating ContextMetadata via RMW or the DiskANN RMW callback, both of which set WriteDesiredSize + return input.WriteDesiredSize > 0; + } + /// + public bool InitialUpdater(ref SpanByte key, ref VectorInput input, ref SpanByte value, ref SpanByte output, ref RMWInfo rmwInfo, ref RecordInfo recordInfo) + { + if (input.Callback == 0) + { + Debug.Assert(key.LengthWithoutMetadata == 0 && key.GetNamespaceInPayload() == 0, "Should only be updating ContextMetadata"); + + SpanByte newMetadataValue; + unsafe + { + newMetadataValue = SpanByte.FromPinnedPointer((byte*)input.CallbackContext, VectorManager.ContextMetadata.Size); + } + + return SpanByteFunctions.DoSafeCopy(ref newMetadataValue, ref value, ref rmwInfo, ref recordInfo); + } + else + { + Debug.Assert(input.WriteDesiredSize <= value.LengthWithoutMetadata, "Insufficient space for initial update, this should never happen"); + + rmwInfo.ClearExtraValueLength(ref recordInfo, ref value, value.TotalSize); + + // Must explicitly 0 before passing if we're 
doing an initial update + value.AsSpan().Clear(); + + unsafe + { + // Callback takes: dataCallbackContext, dataPtr, dataLength + var callback = (delegate* unmanaged[Cdecl, SuppressGCTransition])input.Callback; + callback(input.CallbackContext, (nint)value.ToPointer(), (nuint)input.WriteDesiredSize); + + value.ShrinkSerializedLength(input.WriteDesiredSize); + value.Length = input.WriteDesiredSize; + } + + return true; + } + } + /// + public void PostInitialUpdater(ref SpanByte key, ref VectorInput input, ref SpanByte value, ref SpanByte output, ref RMWInfo rmwInfo) { } + #endregion + + #region Writes + /// + public bool SingleWriter(ref SpanByte key, ref VectorInput input, ref SpanByte src, ref SpanByte dst, ref SpanByte output, ref UpsertInfo upsertInfo, WriteReason reason, ref RecordInfo recordInfo) + => ConcurrentWriter(ref key, ref input, ref src, ref dst, ref output, ref upsertInfo, ref recordInfo); + + /// + public void PostSingleWriter(ref SpanByte key, ref VectorInput input, ref SpanByte src, ref SpanByte dst, ref SpanByte output, ref UpsertInfo upsertInfo, WriteReason reason) { } + /// + public bool ConcurrentWriter(ref SpanByte key, ref VectorInput input, ref SpanByte src, ref SpanByte dst, ref SpanByte output, ref UpsertInfo upsertInfo, ref RecordInfo recordInfo) + => SpanByteFunctions.DoSafeCopy(ref src, ref dst, ref upsertInfo, ref recordInfo, 0); + + #endregion + + #region RMW + /// + public int GetRMWInitialValueLength(ref VectorInput input) + => sizeof(byte) + sizeof(int) + input.WriteDesiredSize; + /// + public int GetRMWModifiedValueLength(ref SpanByte value, ref VectorInput input) + => sizeof(byte) + sizeof(int) + input.WriteDesiredSize; + + /// + public int GetUpsertValueLength(ref SpanByte value, ref VectorInput input) + => sizeof(byte) + sizeof(int) + value.Length; + + /// + public bool InPlaceUpdater(ref SpanByte key, ref VectorInput input, ref SpanByte value, ref SpanByte output, ref RMWInfo rmwInfo, ref RecordInfo recordInfo) + { + if (input.Callback == 0) + { + // We're doing a Metadata update + + Debug.Assert(key.GetNamespaceInPayload() == 0 && key.LengthWithoutMetadata == 0, "Should be special context key"); + Debug.Assert(value.LengthWithoutMetadata == VectorManager.ContextMetadata.Size, "Should be ContextMetadata"); + Debug.Assert(input.CallbackContext != 0, "Should have data on VectorInput"); + + ref readonly var oldMetadata = ref MemoryMarshal.Cast(value.AsReadOnlySpan())[0]; + + SpanByte newMetadataValue; + unsafe + { + newMetadataValue = SpanByte.FromPinnedPointer((byte*)input.CallbackContext, VectorManager.ContextMetadata.Size); + } + + ref readonly var newMetadata = ref MemoryMarshal.Cast(newMetadataValue.AsReadOnlySpan())[0]; + + if (newMetadata.Version < oldMetadata.Version) + { + rmwInfo.Action = RMWAction.CancelOperation; + return false; + } + + return SpanByteFunctions.DoSafeCopy(ref newMetadataValue, ref value, ref rmwInfo, ref recordInfo); + } + else + { + Debug.Assert(input.WriteDesiredSize <= value.LengthWithoutMetadata, "Insufficient space for inplace update, this should never happen"); + + unsafe + { + // Callback takes: dataCallbackContext, dataPtr, dataLength + var callback = (delegate* unmanaged[Cdecl, SuppressGCTransition])input.Callback; + callback(input.CallbackContext, (nint)value.ToPointer(), (nuint)input.WriteDesiredSize); + } + + return true; + } + } + + /// + public bool NeedCopyUpdate(ref SpanByte key, ref VectorInput input, ref SpanByte oldValue, ref SpanByte output, ref RMWInfo rmwInfo) + => input.WriteDesiredSize > 0; + + /// + 
public bool CopyUpdater(ref SpanByte key, ref VectorInput input, ref SpanByte oldValue, ref SpanByte newValue, ref SpanByte output, ref RMWInfo rmwInfo, ref RecordInfo recordInfo) + { + if (input.Callback == 0) + { + // We're doing a Metadata update + + Debug.Assert(key.GetNamespaceInPayload() == 0 && key.LengthWithoutMetadata == 0, "Should be special context key"); + Debug.Assert(oldValue.LengthWithoutMetadata == VectorManager.ContextMetadata.Size, "Should be ContextMetadata"); + Debug.Assert(newValue.LengthWithoutMetadata == VectorManager.ContextMetadata.Size, "Should be ContextMetadata"); + Debug.Assert(input.CallbackContext != 0, "Should have data on VectorInput"); + + ref readonly var oldMetadata = ref MemoryMarshal.Cast(oldValue.AsReadOnlySpan())[0]; + + SpanByte newMetadataValue; + unsafe + { + newMetadataValue = SpanByte.FromPinnedPointer((byte*)input.CallbackContext, VectorManager.ContextMetadata.Size); + } + + ref readonly var newMetadata = ref MemoryMarshal.Cast(newMetadataValue.AsReadOnlySpan())[0]; + + if (newMetadata.Version < oldMetadata.Version) + { + rmwInfo.Action = RMWAction.CancelOperation; + return false; + } + + return SpanByteFunctions.DoSafeCopy(ref newMetadataValue, ref newValue, ref rmwInfo, ref recordInfo); + } + else + { + Debug.Assert(input.WriteDesiredSize <= newValue.LengthWithoutMetadata, "Insufficient space for copy update, this should never happen"); + Debug.Assert(input.WriteDesiredSize <= oldValue.LengthWithoutMetadata, "Insufficient space for copy update, this should never happen"); + + oldValue.AsReadOnlySpan().CopyTo(newValue.AsSpan()); + + unsafe + { + // Callback takes: dataCallbackContext, dataPtr, dataLength + var callback = (delegate* unmanaged[Cdecl, SuppressGCTransition])input.Callback; + callback(input.CallbackContext, (nint)newValue.ToPointer(), (nuint)input.WriteDesiredSize); + } + + return true; + } + } + + /// + public bool PostCopyUpdater(ref SpanByte key, ref VectorInput input, ref SpanByte oldValue, ref SpanByte newValue, ref SpanByte output, ref RMWInfo rmwInfo) + => true; + /// + public void RMWCompletionCallback(ref SpanByte key, ref VectorInput input, ref SpanByte output, long ctx, Status status, RecordMetadata recordMetadata) { } + #endregion + + #region Utilities + /// + public void ConvertOutputToHeap(ref VectorInput input, ref SpanByte output) { } + #endregion + } +} \ No newline at end of file diff --git a/libs/server/Storage/Session/Common/ArrayKeyIterationFunctions.cs b/libs/server/Storage/Session/Common/ArrayKeyIterationFunctions.cs index b4cb3c530de..319f440ff9a 100644 --- a/libs/server/Storage/Session/Common/ArrayKeyIterationFunctions.cs +++ b/libs/server/Storage/Session/Common/ArrayKeyIterationFunctions.cs @@ -258,7 +258,7 @@ protected override bool DeleteIfExpiredInMemory(ref byte[] key, ref IGarnetObjec internal sealed class MainStoreExpiredKeyDeletionScan : ExpiredKeysBase { - protected override bool IsExpired(ref SpanByte value) => value.MetadataSize > 0 && MainSessionFunctions.CheckExpiry(ref value); + protected override bool IsExpired(ref SpanByte value) => value.MetadataSize == 8 && MainSessionFunctions.CheckExpiry(ref value); protected override bool DeleteIfExpiredInMemory(ref SpanByte key, ref SpanByte value, RecordMetadata recordMetadata) { var input = new RawStringInput(RespCommand.DELIFEXPIM); @@ -323,8 +323,15 @@ public bool SingleReader(ref SpanByte key, ref SpanByte value, RecordMetadata re public bool ConcurrentReader(ref SpanByte key, ref SpanByte value, RecordMetadata recordMetadata, long 
numberOfRecords, out CursorRecordResult cursorRecordResult) { + // TODO: A better check for "is probably a vector key" + if (key.MetadataSize == 1) + { + cursorRecordResult = CursorRecordResult.Skip; + return true; + } + + if ((info.patternB != null && !GlobUtils.Match(info.patternB, info.patternLength, key.ToPointer(), key.Length, true)) - || (value.MetadataSize != 0 && MainSessionFunctions.CheckExpiry(ref value))) + || (value.MetadataSize == 8 && MainSessionFunctions.CheckExpiry(ref value))) { cursorRecordResult = CursorRecordResult.Skip; } @@ -410,7 +417,14 @@ internal sealed class MainStoreGetDBSize : IScanIteratorFunctions(ref SpanByte key, ref RawStringInput input, re incr_session_found(); return GarnetStatus.OK; } + else if (status.IsCanceled) + { + return GarnetStatus.WRONGTYPE; + } else { incr_session_notfound(); @@ -589,6 +593,12 @@ public GarnetStatus DELETE(ref SpanByte key, StoreType if (storeType == StoreType.Main || storeType == StoreType.All) { var status = context.Delete(ref key); + if (status.IsCanceled) + { + // Might be a vector set + status = vectorManager.TryDeleteVectorSet(this, ref key); + } + Debug.Assert(!status.IsPending); if (status.Found) found = true; } @@ -600,10 +610,11 @@ public GarnetStatus DELETE(ref SpanByte key, StoreType Debug.Assert(!status.IsPending); if (status.Found) found = true; } + return found ? GarnetStatus.OK : GarnetStatus.NOTFOUND; } - public GarnetStatus DELETE(byte[] key, StoreType storeType, ref TContext context, ref TObjectContext objectContext) + public unsafe GarnetStatus DELETE(byte[] key, StoreType storeType, ref TContext context, ref TObjectContext objectContext) where TContext : ITsavoriteContext where TObjectContext : ITsavoriteContext { @@ -612,6 +623,18 @@ public GarnetStatus DELETE(byte[] key, StoreType store if ((storeType == StoreType.Object || storeType == StoreType.All) && !objectStoreBasicContext.IsNull) { var status = objectContext.Delete(key); + if (status.IsCanceled) + { + // Might be a vector set + fixed (byte* keyPtr = key) + { + SpanByte keySpan = new(key.Length, (nint)keyPtr); + status = vectorManager.TryDeleteVectorSet(this, ref keySpan); + } + + if (status.Found) found = true; + } + Debug.Assert(!status.IsPending); if (status.Found) found = true; } diff --git a/libs/server/Storage/Session/MainStore/VectorStoreOps.cs b/libs/server/Storage/Session/MainStore/VectorStoreOps.cs new file mode 100644 index 00000000000..c35dc0d2c5e --- /dev/null +++ b/libs/server/Storage/Session/MainStore/VectorStoreOps.cs @@ -0,0 +1,276 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using Tsavorite.core; + +namespace Garnet.server +{ + /// + /// Supported quantizations of vector data. + /// + /// This controls the mapping of vector elements to how they're actually stored. + /// + public enum VectorQuantType + { + Invalid = 0, + + // Redis quantizations + + /// + /// Provided and stored as floats (FP32). + /// + NoQuant, + /// + /// Provided as FP32, stored as binary (1 bit). + /// + Bin, + /// + /// Provided as FP32, stored as bytes (8 bits). + /// + Q8, + + // Extended quantizations + + /// + /// Provided and stored as bytes (8 bits). + /// + XPreQ8, + } + + /// + /// Supported formats for Vector value data. + /// + public enum VectorValueType : int + { + Invalid = 0, + + // Redis formats + + /// + /// Floats (FP32). + /// + FP32, + + // Extended formats + + /// + /// Bytes (8 bit).
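// Editor's note: a worked sketch of why "MetadataSize == 1" is used above as the "is probably a
// vector key" test, and "MetadataSize == 8" as the expirable-record test. Per the SpanByte changes
// later in this diff, bit 30 of the length word marks 8-byte extra metadata (expiration) and bit 29
// marks the 1-byte namespace carried by vector-set records; constants below mirror that diff.
using System;

static class MetadataSizeDemo
{
    const int ExtraMetadataBitMask = 1 << 30; // 8-byte metadata, e.g. expiration
    const int NamespaceBitMask     = 1 << 29; // 1-byte namespace, used by vector-set records

    // Same arithmetic as SpanByte.MetadataSize in this diff:
    // bit 30 shifted by (30 - 3) contributes 8, bit 29 shifted by 29 contributes 1.
    static int MetadataSize(int length)
        => ((length & ExtraMetadataBitMask) >> (30 - 3)) + ((length & NamespaceBitMask) >> 29);

    static void Main()
    {
        Console.WriteLine(MetadataSize(100));                         // 0: plain record
        Console.WriteLine(MetadataSize(100 | NamespaceBitMask));      // 1: namespaced (vector) record
        Console.WriteLine(MetadataSize(100 | ExtraMetadataBitMask));  // 8: record with expiration
    }
}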
+ /// + XB8, + } + + /// + /// How result ids are formatted in responses from DiskANN. + /// + public enum VectorIdFormat : int + { + Invalid = 0, + + /// + /// Has 4 bytes of unsigned length before the data. + /// + I32LengthPrefixed, + + /// + /// Ids are actually 4-byte ints, no prefix. + /// + FixedI32 + } + + /// + /// Implementation of Vector Set operations. + /// + sealed partial class StorageSession : IDisposable + { + /// + /// Implement Vector Set Add - this may also create a Vector Set if one does not already exist. + /// + [SkipLocalsInit] + public unsafe GarnetStatus VectorSetAdd(SpanByte key, int reduceDims, VectorValueType valueType, ArgSlice values, ArgSlice element, VectorQuantType quantizer, int buildExplorationFactor, ArgSlice attributes, int numLinks, out VectorManagerResult result, out ReadOnlySpan errorMsg) + { + var dims = VectorManager.CalculateValueDimensions(valueType, values.ReadOnlySpan); + + var dimsArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref dims, 1))); + var reduceDimsArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref reduceDims, 1))); + var valueTypeArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref valueType, 1))); + var valuesArg = values; + var elementArg = element; + var quantizerArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref quantizer, 1))); + var buildExplorationFactorArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref buildExplorationFactor, 1))); + var attributesArg = attributes; + var numLinksArg = ArgSlice.FromPinnedSpan(MemoryMarshal.Cast(MemoryMarshal.CreateSpan(ref numLinks, 1))); + + parseState.InitializeWithArguments([dimsArg, reduceDimsArg, valueTypeArg, valuesArg, elementArg, quantizerArg, buildExplorationFactorArg, attributesArg, numLinksArg]); + + var input = new RawStringInput(RespCommand.VADD, ref parseState); + + Span indexSpan = stackalloc byte[VectorManager.IndexSizeBytes]; + + using (vectorManager.ReadOrCreateVectorIndex(this, ref key, ref input, indexSpan, out var status)) + { + if (status != GarnetStatus.OK) + { + result = VectorManagerResult.Invalid; + errorMsg = default; + return status; + } + + // After a successful read we add the vector while holding a shared lock + // That lock prevents deletion, but everything else can proceed in parallel + result = vectorManager.TryAdd(indexSpan, element.ReadOnlySpan, valueType, values.ReadOnlySpan, attributes.ReadOnlySpan, (uint)reduceDims, quantizer, (uint)buildExplorationFactor, (uint)numLinks, out errorMsg); + + if (result == VectorManagerResult.OK) + { + // On successful addition, we need to manually replicate the write + vectorManager.ReplicateVectorSetAdd(ref key, ref input, ref basicContext); + } + + return GarnetStatus.OK; + } + } + + /// + /// Implement Vector Set Remove - returns not found if the element is not present, or the vector set does not exist. 
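// Editor's note: VectorSetAdd above routes scalar arguments (dims, quantizer, numLinks, ...)
// through the ArgSlice pipeline by viewing each stack local as a byte span. A minimal sketch of
// that trick in isolation; the span is only valid while the local is in scope, and the byte
// values shown assume a little-endian machine.
using System;
using System.Runtime.InteropServices;

static class SpanOverLocalDemo
{
    static void Main()
    {
        var reduceDims = 128;

        // View the single int as a 4-byte span without copying.
        Span<byte> asBytes = MemoryMarshal.Cast<int, byte>(MemoryMarshal.CreateSpan(ref reduceDims, 1));

        Console.WriteLine(asBytes.Length);                 // 4
        Console.WriteLine(BitConverter.ToInt32(asBytes));  // 128

        // Writes through the span are visible in the original local.
        asBytes[0] = 0;
        Console.WriteLine(reduceDims);                     // 0
    }
}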
+ /// + [SkipLocalsInit] + public unsafe GarnetStatus VectorSetRemove(SpanByte key, SpanByte element) + { + var input = new RawStringInput(RespCommand.VREM, ref parseState); + + Span indexSpan = stackalloc byte[VectorManager.IndexSizeBytes]; + + using (vectorManager.ReadVectorIndex(this, ref key, ref input, indexSpan, out var status)) + { + if (status != GarnetStatus.OK) + { + return status; + } + + // After a successful read we remove the vector while holding a shared lock + // That lock prevents deletion, but everything else can proceed in parallel + var res = vectorManager.TryRemove(indexSpan, element.AsReadOnlySpan()); + + if (res == VectorManagerResult.OK) + { + // On successful removal, we need to manually replicate the write + vectorManager.ReplicateVectorSetRemove(ref key, ref element, ref input, ref basicContext); + + return GarnetStatus.OK; + } + + return GarnetStatus.NOTFOUND; + } + } + + /// + /// Perform a similarity search on an existing Vector Set given a vector as a bunch of floats. + /// + [SkipLocalsInit] + public unsafe GarnetStatus VectorSetValueSimilarity(SpanByte key, VectorValueType valueType, ArgSlice values, int count, float delta, int searchExplorationFactor, ReadOnlySpan filter, int maxFilteringEffort, bool includeAttributes, ref SpanByteAndMemory outputIds, out VectorIdFormat outputIdFormat, ref SpanByteAndMemory outputDistances, ref SpanByteAndMemory outputAttributes, out VectorManagerResult result) + { + parseState.InitializeWithArgument(ArgSlice.FromPinnedSpan(key.AsReadOnlySpan())); + + // Get the index + var input = new RawStringInput(RespCommand.VSIM, ref parseState); + + Span indexSpan = stackalloc byte[VectorManager.IndexSizeBytes]; + + using (vectorManager.ReadVectorIndex(this, ref key, ref input, indexSpan, out var status)) + { + if (status != GarnetStatus.OK) + { + result = VectorManagerResult.Invalid; + outputIdFormat = VectorIdFormat.Invalid; + return status; + } + + result = vectorManager.ValueSimilarity(indexSpan, valueType, values.ReadOnlySpan, count, delta, searchExplorationFactor, filter, maxFilteringEffort, includeAttributes, ref outputIds, out outputIdFormat, ref outputDistances, ref outputAttributes); + + return GarnetStatus.OK; + } + } + + /// + /// Perform a similarity search on an existing Vector Set given an element that is already in the Vector Set. 
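// Editor's note: the Vector Set operations above share one shape: read the index under a shared
// lock scoped by a using block, operate, and replicate only on success. A sketch of that pattern
// with a ReaderWriterLockSlim standing in for vectorManager's internal locking (an assumption;
// the real lock type is not shown in this diff).
using System;
using System.Threading;

sealed class IndexLease : IDisposable
{
    private readonly ReaderWriterLockSlim _lock;

    public IndexLease(ReaderWriterLockSlim l)
    {
        _lock = l;
        _lock.EnterReadLock(); // shared: blocks deletion, allows other index operations in parallel
    }

    public void Dispose() => _lock.ExitReadLock();
}

static class IndexLeaseUsage
{
    static readonly ReaderWriterLockSlim IndexLock = new();

    static bool TryOperate(Func<bool> operation, Action replicate)
    {
        using (new IndexLease(IndexLock))
        {
            if (!operation())
                return false;

            replicate(); // mirror the manual ReplicateVectorSet* calls: replicate only on success
            return true;
        }
    }
}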
+ /// + [SkipLocalsInit] + public unsafe GarnetStatus VectorSetElementSimilarity(SpanByte key, ReadOnlySpan element, int count, float delta, int searchExplorationFactor, ReadOnlySpan filter, int maxFilteringEffort, bool includeAttributes, ref SpanByteAndMemory outputIds, out VectorIdFormat outputIdFormat, ref SpanByteAndMemory outputDistances, ref SpanByteAndMemory outputAttributes, out VectorManagerResult result) + { + parseState.InitializeWithArgument(ArgSlice.FromPinnedSpan(key.AsReadOnlySpan())); + + var input = new RawStringInput(RespCommand.VSIM, ref parseState); + + Span indexSpan = stackalloc byte[VectorManager.IndexSizeBytes]; + + using (vectorManager.ReadVectorIndex(this, ref key, ref input, indexSpan, out var status)) + { + if (status != GarnetStatus.OK) + { + result = VectorManagerResult.Invalid; + outputIdFormat = VectorIdFormat.Invalid; + return status; + } + + result = vectorManager.ElementSimilarity(indexSpan, element, count, delta, searchExplorationFactor, filter, maxFilteringEffort, includeAttributes, ref outputIds, out outputIdFormat, ref outputDistances, ref outputAttributes); + return GarnetStatus.OK; + } + } + + /// + /// Get the approximate vector associated with an element, after (approximately) reversing any transformation. + /// + [SkipLocalsInit] + public unsafe GarnetStatus VectorSetEmbedding(SpanByte key, ReadOnlySpan element, ref SpanByteAndMemory outputDistances) + { + parseState.InitializeWithArgument(ArgSlice.FromPinnedSpan(key.AsReadOnlySpan())); + + var input = new RawStringInput(RespCommand.VEMB, ref parseState); + + Span indexSpan = stackalloc byte[VectorManager.IndexSizeBytes]; + + using (vectorManager.ReadVectorIndex(this, ref key, ref input, indexSpan, out var status)) + { + if (status != GarnetStatus.OK) + { + return status; + } + + if (!vectorManager.TryGetEmbedding(indexSpan, element, ref outputDistances)) + { + return GarnetStatus.NOTFOUND; + } + + return GarnetStatus.OK; + } + } + + [SkipLocalsInit] + internal unsafe GarnetStatus VectorSetDimensions(SpanByte key, out int dimensions) + { + parseState.InitializeWithArgument(ArgSlice.FromPinnedSpan(key.AsReadOnlySpan())); + + var input = new RawStringInput(RespCommand.VDIM, ref parseState); + + Span indexSpan = stackalloc byte[VectorManager.IndexSizeBytes]; + + using (vectorManager.ReadVectorIndex(this, ref key, ref input, indexSpan, out var status)) + { + if (status != GarnetStatus.OK) + { + dimensions = 0; + return status; + } + + // After a successful read we extract metadata + VectorManager.ReadIndex(indexSpan, out _, out var dimensionsUS, out var reducedDimensionsUS, out _, out _, out _, out _, out _); + + dimensions = (int)(reducedDimensionsUS == 0 ? 
dimensionsUS : reducedDimensionsUS); + + return GarnetStatus.OK; + } + } + } +} \ No newline at end of file diff --git a/libs/server/Storage/Session/ObjectStore/Common.cs b/libs/server/Storage/Session/ObjectStore/Common.cs index b8ebf286995..5e5a69ad82e 100644 --- a/libs/server/Storage/Session/ObjectStore/Common.cs +++ b/libs/server/Storage/Session/ObjectStore/Common.cs @@ -783,6 +783,41 @@ unsafe GarnetStatus ReadObjectStoreOperation(byte[] key, ref Obj return GarnetStatus.NOTFOUND; } + /// + /// Gets the value of the key stored in the Object Store + /// + unsafe GarnetStatus ReadObjectStoreOperationWithObject(byte[] key, ref ObjectInput input, out ObjectOutputHeader output, out IGarnetObject garnetObject, ref TObjectContext objectStoreContext) + where TObjectContext : ITsavoriteContext + { + if (objectStoreContext.Session is null) + ThrowObjectStoreUninitializedException(); + + var _output = new GarnetObjectStoreOutput(); + + // Perform Read on object store + var status = objectStoreContext.Read(ref key, ref input, ref _output); + + if (status.IsPending) + CompletePendingForObjectStoreSession(ref status, ref _output, ref objectStoreContext); + + output = _output.Header; + + if (_output.HasWrongType) + { + garnetObject = null; + return GarnetStatus.WRONGTYPE; + } + + if (status.Found && (!status.Record.Created && !status.Record.CopyUpdated && !status.Record.InPlaceUpdated)) + { + garnetObject = _output.GarnetObject; + return GarnetStatus.OK; + } + + garnetObject = null; + return GarnetStatus.NOTFOUND; + } + /// /// Iterates members of a collection object using a cursor, /// a match pattern and count parameters diff --git a/libs/server/Storage/Session/StorageSession.cs b/libs/server/Storage/Session/StorageSession.cs index 22edec64896..0ff9717d3fb 100644 --- a/libs/server/Storage/Session/StorageSession.cs +++ b/libs/server/Storage/Session/StorageSession.cs @@ -42,6 +42,12 @@ sealed partial class StorageSession : IDisposable public BasicContext objectStoreBasicContext; public LockableContext objectStoreLockableContext; + /// + /// Session Contexts for vector ops against the main store + /// + public BasicContext vectorContext; + public LockableContext vectorLockableContext; + public readonly ScratchBufferBuilder scratchBufferBuilder; public readonly FunctionsState functionsState; @@ -55,11 +61,14 @@ sealed partial class StorageSession : IDisposable public readonly int ObjectScanCountLimit; + public readonly VectorManager vectorManager; + public StorageSession(StoreWrapper storeWrapper, ScratchBufferBuilder scratchBufferBuilder, GarnetSessionMetrics sessionMetrics, GarnetLatencyMetricsSession LatencyMetrics, int dbId, + VectorManager vectorManager, ILogger logger = null, byte respProtocolVersion = ServerOptions.DEFAULT_RESP_VERSION) { @@ -68,6 +77,7 @@ public StorageSession(StoreWrapper storeWrapper, this.scratchBufferBuilder = scratchBufferBuilder; this.logger = logger; this.itemBroker = storeWrapper.itemBroker; + this.vectorManager = vectorManager; parseState.Initialize(); functionsState = storeWrapper.CreateFunctionsState(dbId, respProtocolVersion); @@ -83,6 +93,9 @@ public StorageSession(StoreWrapper storeWrapper, var objectStoreFunctions = new ObjectSessionFunctions(functionsState); var objectStoreSession = db.ObjectStore?.NewSession(objectStoreFunctions); + var vectorFunctions = new VectorSessionFunctions(functionsState); + var vectorSession = db.MainStore.NewSession(vectorFunctions); + basicContext = session.BasicContext; lockableContext = session.LockableContext; if (objectStoreSession
!= null) @@ -90,6 +103,8 @@ public StorageSession(StoreWrapper storeWrapper, objectStoreBasicContext = objectStoreSession.BasicContext; objectStoreLockableContext = objectStoreSession.LockableContext; } + vectorContext = vectorSession.BasicContext; + vectorLockableContext = vectorSession.LockableContext; HeadAddress = db.MainStore.Log.HeadAddress; ObjectScanCountLimit = storeWrapper.serverOptions.ObjectScanCountLimit; diff --git a/libs/server/StoreWrapper.cs b/libs/server/StoreWrapper.cs index c2b7418292c..9398af1b34f 100644 --- a/libs/server/StoreWrapper.cs +++ b/libs/server/StoreWrapper.cs @@ -355,7 +355,7 @@ internal void Recover() { RecoverCheckpoint(); RecoverAOF(); - ReplayAOF(); + _ = ReplayAOF(); } } } @@ -843,6 +843,13 @@ public bool HasKeysInSlots(List slots) while (!hasKeyInSlots && iter.GetNext(out RecordInfo record)) { ref var key = ref iter.GetKey(); + + // TODO: better way to ignore vector set elements + if (key.MetadataSize == 1) + { + continue; + } + ushort hashSlotForKey = HashSlotUtils.HashSlot(ref key); if (slots.Contains(hashSlotForKey)) { diff --git a/libs/server/Transaction/TransactionManager.cs b/libs/server/Transaction/TransactionManager.cs index adfe7975ab6..811988967d4 100644 --- a/libs/server/Transaction/TransactionManager.cs +++ b/libs/server/Transaction/TransactionManager.cs @@ -15,13 +15,19 @@ namespace Garnet.server SpanByteAllocator>>, BasicContext>, - GenericAllocator>>>>; + GenericAllocator>>>, + BasicContext, + SpanByteAllocator>>>; using LockableGarnetApi = GarnetApi, SpanByteAllocator>>, LockableContext>, - GenericAllocator>>>>; + GenericAllocator>>>, + LockableContext, + SpanByteAllocator>>>; using MainStoreAllocator = SpanByteAllocator>; using MainStoreFunctions = StoreFunctions; diff --git a/libs/server/Transaction/TxnKeyManager.cs b/libs/server/Transaction/TxnKeyManager.cs index f8089664799..96607c5b0ca 100644 --- a/libs/server/Transaction/TxnKeyManager.cs +++ b/libs/server/Transaction/TxnKeyManager.cs @@ -48,7 +48,7 @@ public unsafe void VerifyKeyOwnership(ArgSlice key, LockType type) if (!clusterEnabled) return; var readOnly = type == LockType.Shared; - if (!respSession.clusterSession.NetworkIterativeSlotVerify(key, readOnly, respSession.SessionAsking)) + if (!respSession.clusterSession.NetworkIterativeSlotVerify(key, readOnly, respSession.SessionAsking, isVectorSetWriteCommand: false)) // TODO: Is it ok to ignore Vector Set-y-ness of the key? 
{ this.state = TxnState.Aborted; } diff --git a/libs/server/Transaction/TxnRespCommands.cs b/libs/server/Transaction/TxnRespCommands.cs index e51eaf8612c..909aafe1a6b 100644 --- a/libs/server/Transaction/TxnRespCommands.cs +++ b/libs/server/Transaction/TxnRespCommands.cs @@ -60,7 +60,7 @@ private bool NetworkEXEC() endReadHead = txnManager.txnStartHead; txnManager.GetKeysForValidation(recvBufferPtr, out var keys, out int keyCount, out bool readOnly); - if (NetworkKeyArraySlotVerify(keys, readOnly, keyCount)) + if (NetworkKeyArraySlotVerify(keys, readOnly, isVectorSetWriteCommand: false, keyCount)) // TODO: We should actually verify if commands contained are Vector Set writes { logger?.LogWarning("Failed CheckClusterTxnKeys"); txnManager.Reset(false); diff --git a/libs/storage/Tsavorite/cs/src/core/Index/Common/RecordInfo.cs b/libs/storage/Tsavorite/cs/src/core/Index/Common/RecordInfo.cs index 5d82c473f53..180dfbb0259 100644 --- a/libs/storage/Tsavorite/cs/src/core/Index/Common/RecordInfo.cs +++ b/libs/storage/Tsavorite/cs/src/core/Index/Common/RecordInfo.cs @@ -11,7 +11,7 @@ namespace Tsavorite.core { // RecordInfo layout (64 bits total): - // [Unused1][Modified][InNewVersion][Filler][Dirty][ETag][Sealed][Valid][Tombstone][LLLLLLL] [RAAAAAAA] [AAAAAAAA] [AAAAAAAA] [AAAAAAAA] [AAAAAAAA] [AAAAAAAA] + // [VectorSet][Modified][InNewVersion][Filler][Dirty][ETag][Sealed][Valid][Tombstone][LLLLLLL] [RAAAAAAA] [AAAAAAAA] [AAAAAAAA] [AAAAAAAA] [AAAAAAAA] [AAAAAAAA] // where L = leftover, R = readcache, A = address [StructLayout(LayoutKind.Explicit, Size = 8)] public struct RecordInfo @@ -35,7 +35,7 @@ public struct RecordInfo const int kFillerBitOffset = kDirtyBitOffset + 1; const int kInNewVersionBitOffset = kFillerBitOffset + 1; const int kModifiedBitOffset = kInNewVersionBitOffset + 1; - const int kUnused1BitOffset = kModifiedBitOffset + 1; + const int kVectorSetBitOffset = kModifiedBitOffset + 1; const long kTombstoneBitMask = 1L << kTombstoneBitOffset; const long kValidBitMask = 1L << kValidBitOffset; @@ -45,7 +45,7 @@ public struct RecordInfo const long kFillerBitMask = 1L << kFillerBitOffset; const long kInNewVersionBitMask = 1L << kInNewVersionBitOffset; const long kModifiedBitMask = 1L << kModifiedBitOffset; - const long kUnused1BitMask = 1L << kUnused1BitOffset; + const long kVectorSetBitMask = 1L << kVectorSetBitOffset; [FieldOffset(0)] private long word; @@ -269,10 +269,10 @@ public long PreviousAddress [MethodImpl(MethodImplOptions.AggressiveInlining)] public static int GetLength() => kTotalSizeInBytes; - internal bool Unused1 + public bool VectorSet { - readonly get => (word & kUnused1BitMask) != 0; - set => word = value ? word | kUnused1BitMask : word & ~kUnused1BitMask; + readonly get => (word & kVectorSetBitMask) != 0; + set => word = value ? word | kVectorSetBitMask : word & ~kVectorSetBitMask; } public bool ETag @@ -289,7 +289,7 @@ public override readonly string ToString() var paRC = IsReadCache(PreviousAddress) ? "(rc)" : string.Empty; static string bstr(bool value) => value ? 
"T" : "F"; return $"prev {AbsoluteAddress(PreviousAddress)}{paRC}, valid {bstr(Valid)}, tomb {bstr(Tombstone)}, seal {bstr(IsSealed)}," - + $" mod {bstr(Modified)}, dirty {bstr(Dirty)}, fill {bstr(HasFiller)}, etag {bstr(ETag)}, Un1 {bstr(Unused1)}"; + + $" mod {bstr(Modified)}, dirty {bstr(Dirty)}, fill {bstr(HasFiller)}, etag {bstr(ETag)}, vset {bstr(VectorSet)}"; } } } \ No newline at end of file diff --git a/libs/storage/Tsavorite/cs/src/core/Index/Tsavorite/Implementation/InternalDelete.cs b/libs/storage/Tsavorite/cs/src/core/Index/Tsavorite/Implementation/InternalDelete.cs index 218b8c0b822..d63a3dac1e6 100644 --- a/libs/storage/Tsavorite/cs/src/core/Index/Tsavorite/Implementation/InternalDelete.cs +++ b/libs/storage/Tsavorite/cs/src/core/Index/Tsavorite/Implementation/InternalDelete.cs @@ -225,6 +225,7 @@ private OperationStatus CreateNewRecordDelete public static Status CreatePending() => new(StatusCode.Pending); + /// + /// Create a Status value. + /// + public static Status CreateNotFound() => new(StatusCode.NotFound); + /// /// Whether a Read or RMW found the key /// diff --git a/libs/storage/Tsavorite/cs/src/core/VarLen/SpanByte.cs b/libs/storage/Tsavorite/cs/src/core/VarLen/SpanByte.cs index 5d46bdb8ec4..d3058c49f5a 100644 --- a/libs/storage/Tsavorite/cs/src/core/VarLen/SpanByte.cs +++ b/libs/storage/Tsavorite/cs/src/core/VarLen/SpanByte.cs @@ -25,8 +25,10 @@ public unsafe struct SpanByte private const int UnserializedBitMask = 1 << 31; // Byte #30 is used to denote extra metadata present (1) or absent (0) in payload private const int ExtraMetadataBitMask = 1 << 30; + // Bit #29 used to denote if a namespace is present in payload + private const int NamespaceBitMask = 1 << 29; // Mask for header - private const int HeaderMask = 0x3 << 30; + private const int HeaderMask = UnserializedBitMask | ExtraMetadataBitMask | NamespaceBitMask; /// /// Length of the payload @@ -93,9 +95,9 @@ public int Length public readonly int TotalSize => sizeof(int) + Length; /// - /// Size of metadata header, if any (returns 0 or 8) + /// Size of metadata header, if any (returns 0, 1, 8, or 9) /// - public readonly int MetadataSize => (length & ExtraMetadataBitMask) >> (30 - 3); + public readonly int MetadataSize => ((length & ExtraMetadataBitMask) >> (30 - 3)) + ((length & NamespaceBitMask) >> 29); /// /// Create a around a given pointer and given @@ -144,6 +146,7 @@ public long ExtraMetadata public void MarkExtraMetadata() { Debug.Assert(Length >= 8); + Debug.Assert((length & NamespaceBitMask) == 0, "Don't use both extension for now"); length |= ExtraMetadataBitMask; } @@ -153,6 +156,23 @@ public void MarkExtraMetadata() [MethodImpl(MethodImplOptions.AggressiveInlining)] public void UnmarkExtraMetadata() => length &= ~ExtraMetadataBitMask; + /// + /// Mark as having 1-byte namespace in header of payload + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void MarkNamespace() + { + Debug.Assert(Length >= 1); + Debug.Assert((length & ExtraMetadataBitMask) == 0, "Don't use both extension for now"); + length |= NamespaceBitMask; + } + + /// + /// Unmark as having 1-byte namespace in header of payload + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void UnmarkNamespace() => length &= ~NamespaceBitMask; + /// /// Check or set struct as invalid /// @@ -526,6 +546,18 @@ public void CopyTo(byte* destination) [MethodImpl(MethodImplOptions.AggressiveInlining)] public void SetEtagInPayload(long etag) => *(long*)this.ToPointer() = etag; + /// + /// Gets a namespace from the 
payload of the SpanByte, caller should make sure the SpanByte has a namespace for the record by checking RecordInfo + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public byte GetNamespaceInPayload() => *(byte*)this.ToPointerWithMetadata(); + + /// + /// Sets a namespace in the payload of the SpanByte, caller should make sure the SpanByte has a namespace for the record by checking RecordInfo + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void SetNamespaceInPayload(byte ns) => *(byte*)this.ToPointerWithMetadata() = ns; + + /// public override string ToString() { diff --git a/libs/storage/Tsavorite/cs/src/core/VarLen/SpanByteAndMemory.cs b/libs/storage/Tsavorite/cs/src/core/VarLen/SpanByteAndMemory.cs index 6e8460c2662..cf6a1c5c9d0 100644 --- a/libs/storage/Tsavorite/cs/src/core/VarLen/SpanByteAndMemory.cs +++ b/libs/storage/Tsavorite/cs/src/core/VarLen/SpanByteAndMemory.cs @@ -83,6 +83,12 @@ public SpanByteAndMemory(IMemoryOwner memory, int length) [MethodImpl(MethodImplOptions.AggressiveInlining)] public ReadOnlySpan AsReadOnlySpan() => IsSpanByte ? SpanByte.AsReadOnlySpan() : Memory.Memory.Span.Slice(0, Length); + /// + /// As a span of the contained data. Use this when you haven't tested . + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Span AsSpan() => IsSpanByte ? SpanByte.AsSpan() : Memory.Memory.Span.Slice(0, Length); + /// /// As a span of the contained data. Use this when you have already tested . /// diff --git a/playground/CommandInfoUpdater/GarnetCommandsInfo.json b/playground/CommandInfoUpdater/GarnetCommandsInfo.json index 52786d649d8..afb17f2c2e5 100644 --- a/playground/CommandInfoUpdater/GarnetCommandsInfo.json +++ b/playground/CommandInfoUpdater/GarnetCommandsInfo.json @@ -215,6 +215,19 @@ "KeySpecifications": null, "SubCommands": null }, + { + "Command": "CLUSTER_RESERVE", + "Name": "CLUSTER|RESERVE", + "IsInternal": true, + "Arity": 4, + "Flags": "Admin, NoScript, NoMulti", + "FirstKey": 0, + "LastKey": 0, + "Step": 0, + "AclCategories": "Admin, Dangerous, Garnet", + "KeySpecifications": null, + "SubCommands": null + }, { "Command": "CLUSTER_MTASKS", "Name": "CLUSTER|MTASKS", diff --git a/playground/CommandInfoUpdater/SupportedCommand.cs b/playground/CommandInfoUpdater/SupportedCommand.cs index a1a61b79234..20b2ed74a77 100644 --- a/playground/CommandInfoUpdater/SupportedCommand.cs +++ b/playground/CommandInfoUpdater/SupportedCommand.cs @@ -93,6 +93,7 @@ public class SupportedCommand new("CLUSTER|REPLICAS", RespCommand.CLUSTER_REPLICAS), new("CLUSTER|REPLICATE", RespCommand.CLUSTER_REPLICATE), new("CLUSTER|RESET", RespCommand.CLUSTER_RESET), + new("CLUSTER|RESERVE", RespCommand.CLUSTER_RESERVE), new("CLUSTER|SEND_CKPT_FILE_SEGMENT", RespCommand.CLUSTER_SEND_CKPT_FILE_SEGMENT), new("CLUSTER|SEND_CKPT_METADATA", RespCommand.CLUSTER_SEND_CKPT_METADATA), new("CLUSTER|SET-CONFIG-EPOCH", RespCommand.CLUSTER_SETCONFIGEPOCH), diff --git a/test/Garnet.test.cluster/ClusterTestContext.cs b/test/Garnet.test.cluster/ClusterTestContext.cs index e0a20561726..a2c9843f96c 100644 --- a/test/Garnet.test.cluster/ClusterTestContext.cs +++ b/test/Garnet.test.cluster/ClusterTestContext.cs @@ -117,7 +117,6 @@ public void RestartNode(int nodeIndex) nodes[nodeIndex].Start(); } - public void TearDown() { cts.Cancel(); diff --git a/test/Garnet.test.cluster/ClusterTestUtils.cs b/test/Garnet.test.cluster/ClusterTestUtils.cs index 1571a8881c9..50b9edd6603 100644 --- +++
b/test/Garnet.test.cluster/ClusterTestUtils.cs @@ -8,12 +8,14 @@ using System.Linq; using System.Net; using System.Net.Security; +using System.Runtime.CompilerServices; using System.Security.Cryptography.X509Certificates; using System.Text; using System.Threading; using System.Threading.Tasks; using Garnet.client; using Garnet.common; +using Garnet.server; using Garnet.server.TLS; using GarnetClusterManagement; using Microsoft.Extensions.Logging; @@ -1845,12 +1847,22 @@ public int MigrateTasks(IPEndPoint endPoint, ILogger logger) } } - public void WaitForMigrationCleanup(int nodeIndex, ILogger logger = null) - => WaitForMigrationCleanup(endpoints[nodeIndex].ToIPEndPoint(), logger); + public void WaitForMigrationCleanup(int nodeIndex, ILogger logger = null, CancellationToken cancellationToken = default) + => WaitForMigrationCleanup(endpoints[nodeIndex].ToIPEndPoint(), logger, cancellationToken); - public void WaitForMigrationCleanup(IPEndPoint endPoint, ILogger logger) + public void WaitForMigrationCleanup(IPEndPoint endPoint, ILogger logger, CancellationToken cancellationToken = default) { - while (MigrateTasks(endPoint, logger) > 0) { BackOff(cancellationToken: context.cts.Token); } + CancellationToken backoffToken; + if (cancellationToken.CanBeCanceled) + { + backoffToken = cancellationToken; + } + else + { + backoffToken = context.cts.Token; + } + + while (MigrateTasks(endPoint, logger) > 0) { BackOff(cancellationToken: backoffToken); } } public void WaitForMigrationCleanup(ILogger logger) @@ -2895,11 +2907,29 @@ public void WaitForReplicaAofSync(int primaryIndex, int secondaryIndex, ILogger primaryReplicationOffset = GetReplicationOffset(primaryIndex, logger); secondaryReplicationOffset1 = GetReplicationOffset(secondaryIndex, logger); if (primaryReplicationOffset == secondaryReplicationOffset1) + { + var storeWrapper = GetStoreWrapper(this.context.nodes[secondaryIndex]); + var dbManager = GetDatabaseManager(storeWrapper); + + dbManager.DefaultDatabase.VectorManager.WaitForVectorOperationsToComplete(); + break; + } var primaryMainStoreVersion = context.clusterTestUtils.GetStoreCurrentVersion(primaryIndex, isMainStore: true, logger); var replicaMainStoreVersion = context.clusterTestUtils.GetStoreCurrentVersion(secondaryIndex, isMainStore: true, logger); - BackOff(cancellationToken: context.cts.Token, msg: $"[{endpoints[primaryIndex]}]: {primaryMainStoreVersion},{primaryReplicationOffset} != [{endpoints[secondaryIndex]}]: {replicaMainStoreVersion},{secondaryReplicationOffset1}"); + + CancellationToken backoffToken; + if (cancellation.CanBeCanceled) + { + backoffToken = cancellation; + } + else + { + backoffToken = context.cts.Token; + } + + BackOff(cancellationToken: backoffToken, msg: $"[{endpoints[primaryIndex]}]: {primaryMainStoreVersion},{primaryReplicationOffset} != [{endpoints[secondaryIndex]}]: {replicaMainStoreVersion},{secondaryReplicationOffset1}"); } logger?.LogInformation("[{primaryEndpoint}]{primaryReplicationOffset} ?? 
[{secondaryEndpoint}]{secondaryReplicationOffset1}", endpoints[primaryIndex], primaryReplicationOffset, endpoints[secondaryIndex], secondaryReplicationOffset1); } @@ -3162,5 +3192,11 @@ public int DBSize(IPEndPoint endPoint, ILogger logger = null) return -1; } } + + [UnsafeAccessor(UnsafeAccessorKind.Field, Name = "storeWrapper")] + private static extern ref StoreWrapper GetStoreWrapper(GarnetServer server); + + [UnsafeAccessor(UnsafeAccessorKind.Field, Name = "databaseManager")] + private static extern ref IDatabaseManager GetDatabaseManager(StoreWrapper server); } } \ No newline at end of file diff --git a/test/Garnet.test.cluster/RedirectTests/TestClusterProc.cs b/test/Garnet.test.cluster/RedirectTests/TestClusterProc.cs index e7a0607cfd2..9d793d0f952 100644 --- a/test/Garnet.test.cluster/RedirectTests/TestClusterProc.cs +++ b/test/Garnet.test.cluster/RedirectTests/TestClusterProc.cs @@ -115,13 +115,13 @@ public override void Main(TGarnetApi api, ref CustomProcedureInput p { var offset = 0; var getA = GetNextArg(ref procInput, ref offset); - var setB = GetNextArg(ref procInput, ref offset).SpanByte; - var setC = GetNextArg(ref procInput, ref offset).SpanByte; + var setB = GetNextArg(ref procInput, ref offset); + var setC = GetNextArg(ref procInput, ref offset); _ = api.GET(getA, out _); - var status = api.SET(ref setB, ref setB); + var status = api.SET(setB, setB); ClassicAssert.AreEqual(GarnetStatus.OK, status); - status = api.SET(ref setC, ref setC); + status = api.SET(setC, setC); ClassicAssert.AreEqual(GarnetStatus.OK, status); WriteSimpleString(ref output, "SUCCESS"); } diff --git a/test/Garnet.test.cluster/VectorSets/ClusterVectorSetTests.cs b/test/Garnet.test.cluster/VectorSets/ClusterVectorSetTests.cs new file mode 100644 index 00000000000..1465dba69f2 --- /dev/null +++ b/test/Garnet.test.cluster/VectorSets/ClusterVectorSetTests.cs @@ -0,0 +1,2013 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license.
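// Editor's note: the GetStoreWrapper/GetDatabaseManager helpers above use [UnsafeAccessor]
// (.NET 8+) to reach private fields without reflection overhead. A self-contained sketch of the
// mechanism; Widget and its private field are invented for illustration.
using System;
using System.Runtime.CompilerServices;

public sealed class Widget
{
    private int count = 41; // private: normally inaccessible outside the class
}

public static class WidgetAccessor
{
    // The runtime binds this directly to Widget's private "count" field; no reflection at call time.
    [UnsafeAccessor(UnsafeAccessorKind.Field, Name = "count")]
    private static extern ref int GetCount(Widget w);

    public static void Main()
    {
        var w = new Widget();
        GetCount(w)++;                  // mutate through the returned ref
        Console.WriteLine(GetCount(w)); // 42
    }
}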
+ +using System; +using System.Buffers.Binary; +using System.Collections.Concurrent; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Linq; +using System.Net; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Garnet.server; +using Microsoft.Extensions.Logging; +using NUnit.Framework; +using NUnit.Framework.Legacy; +using StackExchange.Redis; + +namespace Garnet.test.cluster +{ + [TestFixture, NonParallelizable] + public class ClusterVectorSetTests + { + private sealed class StringAndByteArrayComparer : IEqualityComparer<(string Key, byte[] Elem)> + { + public static readonly StringAndByteArrayComparer Instance = new(); + + private StringAndByteArrayComparer() { } + + public bool Equals((string Key, byte[] Elem) x, (string Key, byte[] Elem) y) + => x.Key.Equals(y.Key) && x.Elem.SequenceEqual(y.Elem); + + public int GetHashCode([DisallowNull] (string Key, byte[] Elem) obj) + { + HashCode code = default; + code.Add(obj.Key); + code.AddBytes(obj.Elem); + + return code.ToHashCode(); + } + } + + private sealed class CaptureLogWriter(TextWriter passThrough) : TextWriter + { + public bool capture; + public readonly StringBuilder buffer = new(); + + public override Encoding Encoding + => passThrough.Encoding; + + public override void Write(string value) + { + passThrough.Write(value); + + if (capture) + { + lock (buffer) + { + _ = buffer.Append(value); + } + } + } + } + + private const int DefaultShards = 2; + private const int HighReplicationShards = 6; + private const int DefaultMultiPrimaryShards = 4; + + private static readonly Dictionary MonitorTests = new() + { + [nameof(MigrateVectorStressAsync)] = LogLevel.Debug, + }; + + + private ClusterTestContext context; + + private CaptureLogWriter captureLogWriter; + + [SetUp] + public virtual void Setup() + { + captureLogWriter = new(TestContext.Progress); + + context = new ClusterTestContext(); + context.logTextWriter = captureLogWriter; + context.Setup(MonitorTests); + } + + [TearDown] + public virtual void TearDown() + { + context?.TearDown(); + } + + [Test] + [TestCase("XB8", "XPREQ8")] + [TestCase("XB8", "Q8")] + [TestCase("XB8", "BIN")] + [TestCase("XB8", "NOQUANT")] + [TestCase("FP32", "XPREQ8")] + [TestCase("FP32", "Q8")] + [TestCase("FP32", "BIN")] + [TestCase("FP32", "NOQUANT")] + public void BasicVADDReplicates(string vectorFormat, string quantizer) + { + // TODO: also test VALUES format? 
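// Editor's note: a minimal sketch motivating StringAndByteArrayComparer above. Tuples containing
// byte[] fall back to reference equality for the array, so content-based lookups need a custom
// comparer; this mirrors the test's comparer in miniature (HashCode.AddBytes requires .NET 6+).
using System;
using System.Collections.Generic;

sealed class PairComparer : IEqualityComparer<(string Key, byte[] Elem)>
{
    public bool Equals((string Key, byte[] Elem) x, (string Key, byte[] Elem) y)
        => x.Key == y.Key && x.Elem.AsSpan().SequenceEqual(y.Elem);

    public int GetHashCode((string Key, byte[] Elem) obj)
    {
        HashCode h = default;
        h.Add(obj.Key);
        h.AddBytes(obj.Elem); // hashes array contents, not the reference
        return h.ToHashCode();
    }
}

static class PairComparerDemo
{
    static void Main()
    {
        var set = new HashSet<(string, byte[])>(new PairComparer());
        _ = set.Add(("vset", new byte[] { 0, 0, 0, 1 }));

        // True with the comparer; false with default tuple equality.
        Console.WriteLine(set.Contains(("vset", new byte[] { 0, 0, 0, 1 })));
    }
}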
+ + const int PrimaryIndex = 0; + const int SecondaryIndex = 1; + + ClassicAssert.IsTrue(Enum.TryParse(vectorFormat, ignoreCase: true, out var vectorFormatParsed)); + ClassicAssert.IsTrue(Enum.TryParse(quantizer, ignoreCase: true, out var quantTypeParsed)); + + context.CreateInstances(DefaultShards, useTLS: true, enableAOF: true); + context.CreateConnection(useTLS: true); + _ = context.clusterTestUtils.SimpleSetupCluster(primary_count: 1, replica_count: 1, logger: context.logger); + + var primary = (IPEndPoint)context.endpoints[PrimaryIndex]; + var secondary = (IPEndPoint)context.endpoints[SecondaryIndex]; + + ClassicAssert.AreEqual("master", context.clusterTestUtils.RoleCommand(primary).Value); + ClassicAssert.AreEqual("slave", context.clusterTestUtils.RoleCommand(secondary).Value); + + byte[] vectorAddData; + if (vectorFormatParsed == VectorValueType.XB8) + { + vectorAddData = new byte[75]; + vectorAddData[0] = 1; + for (var i = 1; i < vectorAddData.Length; i++) + { + vectorAddData[i] = (byte)(vectorAddData[i - 1] + 1); + } + } + else if (vectorFormatParsed == VectorValueType.FP32) + { + var floats = new float[75]; + floats[0] = 1; + for (var i = 1; i < floats.Length; i++) + { + floats[i] = floats[i - 1] + 1; + } + + vectorAddData = MemoryMarshal.Cast(floats).ToArray(); + } + else + { + ClassicAssert.Fail("Unexpected vector format"); + return; + } + + var addRes = (int)context.clusterTestUtils.Execute(primary, "VADD", ["foo", vectorFormat, vectorAddData, new byte[] { 0, 0, 0, 0 }, quantizer]); + ClassicAssert.AreEqual(1, addRes); + + byte[] vectorSimData; + if (vectorFormatParsed == VectorValueType.XB8) + { + vectorSimData = new byte[75]; + vectorSimData[0] = 2; + for (var i = 1; i < vectorSimData.Length; i++) + { + vectorSimData[i] = (byte)(vectorSimData[i - 1] + 1); + } + } + else if (vectorFormatParsed == VectorValueType.FP32) + { + var floats = new float[75]; + floats[0] = 2; + for (var i = 1; i < floats.Length; i++) + { + floats[i] = floats[i - 1] + 1; + } + + vectorSimData = MemoryMarshal.Cast(floats).ToArray(); + } + else + { + ClassicAssert.Fail("Unexpected vector format"); + return; + } + + var simRes = (byte[][])context.clusterTestUtils.Execute(primary, "VSIM", ["foo", vectorFormat, vectorSimData]); + ClassicAssert.IsTrue(simRes.Length > 0); + + context.clusterTestUtils.WaitForReplicaAofSync(PrimaryIndex, SecondaryIndex); + + var readonlyOnReplica = (string)context.clusterTestUtils.Execute(secondary, "READONLY", []); + ClassicAssert.AreEqual("OK", readonlyOnReplica); + + var simOnReplica = context.clusterTestUtils.Execute(secondary, "VSIM", ["foo", vectorFormat, vectorSimData]); + ClassicAssert.IsTrue(simOnReplica.Length > 0); + } + + [Test] + [TestCase(false)] + [TestCase(true)] + public async Task ConcurrentVADDReplicatedVSimsAsync(bool withAttributes) + { + const int PrimaryIndex = 0; + const int SecondaryIndex = 1; + const int Vectors = 2_000; + const string Key = nameof(ConcurrentVADDReplicatedVSimsAsync); + + context.CreateInstances(DefaultShards, useTLS: true, enableAOF: true); + context.CreateConnection(useTLS: true); + _ = context.clusterTestUtils.SimpleSetupCluster(primary_count: 1, replica_count: 1, logger: context.logger); + + var primary = (IPEndPoint)context.endpoints[PrimaryIndex]; + var secondary = (IPEndPoint)context.endpoints[SecondaryIndex]; + + ClassicAssert.AreEqual("master", context.clusterTestUtils.RoleCommand(primary).Value); + ClassicAssert.AreEqual("slave", context.clusterTestUtils.RoleCommand(secondary).Value); + + // Build some repeatably random data 
for inserts + var vectors = new byte[Vectors][]; + { + var r = new Random(2025_09_15_00); + + for (var i = 0; i < vectors.Length; i++) + { + vectors[i] = new byte[75]; + r.NextBytes(vectors[i]); + } + } + + using var sync = new SemaphoreSlim(2); + + var writeTask = + Task.Run( + async () => + { + await sync.WaitAsync(); + + var key = new byte[4]; + for (var i = 0; i < vectors.Length; i++) + { + BinaryPrimitives.WriteInt32LittleEndian(key, i); + var val = vectors[i]; + int addRes; + if (withAttributes) + { + addRes = (int)context.clusterTestUtils.Execute(primary, "VADD", [Key, "XB8", val, key, "XPREQ8", "SETATTR", $"{{ \"id\": {i} }}"]); + } + else + { + addRes = (int)context.clusterTestUtils.Execute(primary, "VADD", [Key, "XB8", val, key, "XPREQ8"]); + } + ClassicAssert.AreEqual(1, addRes); + } + } + ); + + using var cts = new CancellationTokenSource(); + + var readTask = + Task.Run( + async () => + { + var r = new Random(2025_09_15_01); + + var readonlyOnReplica = (string)context.clusterTestUtils.Execute(secondary, "READONLY", []); + ClassicAssert.AreEqual("OK", readonlyOnReplica); + + await sync.WaitAsync(); + + var nonZeroReturns = 0; + var gotAttrs = 0; + + while (!cts.Token.IsCancellationRequested) + { + var val = vectors[r.Next(vectors.Length)]; + + if (withAttributes) + { + var readRes = (byte[][])context.clusterTestUtils.Execute(secondary, "VSIM", [Key, "XB8", val, "WITHATTRIBS"]); + if (readRes.Length > 0) + { + nonZeroReturns++; + } + + for (var i = 0; i < readRes.Length; i += 2) + { + var id = readRes[i]; + var attr = readRes[i + 1]; + + var asInt = BinaryPrimitives.ReadInt32LittleEndian(id); + + var actualAttr = Encoding.UTF8.GetString(attr); + var expectedAttr = $"{{ \"id\": {asInt} }}"; + + ClassicAssert.AreEqual(expectedAttr, actualAttr); + + gotAttrs++; + } + } + else + { + var readRes = (byte[][])context.clusterTestUtils.Execute(secondary, "VSIM", [Key, "XB8", val]); + if (readRes.Length > 0) + { + nonZeroReturns++; + } + } + } + + return (nonZeroReturns, gotAttrs); + } + ); + + _ = sync.Release(2); + await writeTask; + + context.clusterTestUtils.WaitForReplicaAofSync(PrimaryIndex, SecondaryIndex); + + cts.CancelAfter(TimeSpan.FromSeconds(1)); + + var (searchesWithNonZeroResults, searchesWithAttrs) = await readTask; + + ClassicAssert.IsTrue(searchesWithNonZeroResults > 0); + + if (withAttributes) + { + ClassicAssert.IsTrue(searchesWithAttrs > 0); + } + + // Validate all nodes have same vector embeddings + { + var idBytes = new byte[4]; + for (var id = 0; id < vectors.Length; id++) + { + BinaryPrimitives.WriteInt32LittleEndian(idBytes, id); + var expected = vectors[id]; + + var fromPrimary = (string[])context.clusterTestUtils.Execute(primary, "VEMB", [Key, idBytes]); + var fromSecondary = (string[])context.clusterTestUtils.Execute(secondary, "VEMB", [Key, idBytes]); + + ClassicAssert.AreEqual(expected.Length, fromPrimary.Length); + ClassicAssert.AreEqual(expected.Length, fromSecondary.Length); + + for (var i = 0; i < expected.Length; i++) + { + var p = (byte)float.Parse(fromPrimary[i]); + var s = (byte)float.Parse(fromSecondary[i]); + + ClassicAssert.AreEqual(expected[i], p); + ClassicAssert.AreEqual(expected[i], s); + } + } + } + } + + [Test] + public void RepeatedCreateDelete() + { + const int PrimaryIndex = 0; + const int SecondaryIndex = 1; + + context.CreateInstances(DefaultShards, useTLS: true, enableAOF: true); + context.CreateConnection(useTLS: true); + _ = context.clusterTestUtils.SimpleSetupCluster(primary_count: 1, replica_count: 1, logger: context.logger); + + 
var primary = (IPEndPoint)context.endpoints[PrimaryIndex]; + var secondary = (IPEndPoint)context.endpoints[SecondaryIndex]; + + ClassicAssert.AreEqual("master", context.clusterTestUtils.RoleCommand(primary).Value); + ClassicAssert.AreEqual("slave", context.clusterTestUtils.RoleCommand(secondary).Value); + + var bytes1 = new byte[75]; + bytes1[0] = 1; + for (var j = 1; j < bytes1.Length; j++) + { + bytes1[j] = (byte)(bytes1[j - 1] + 1); + } + + var bytes2 = new byte[75]; + bytes2[0] = 5; + for (var j = 1; j < bytes2.Length; j++) + { + bytes2[j] = (byte)(bytes2[j - 1] + 1); + } + + var bytes3 = new byte[75]; + bytes3[0] = 10; + for (var j = 1; j < bytes3.Length; j++) + { + bytes3[j] = (byte)(bytes3[j - 1] + 1); + } + + var key0 = new byte[4]; + key0[0] = 1; + var key1 = new byte[4]; + key1[0] = 2; + + for (var i = 0; i < 100; i++) + { + var delRes = (int)context.clusterTestUtils.Execute(primary, "DEL", ["foo"]); + + if (i != 0) + { + ClassicAssert.AreEqual(1, delRes); + } + else + { + ClassicAssert.AreEqual(0, delRes); + } + + var addRes1 = (int)context.clusterTestUtils.Execute(primary, "VADD", ["foo", "XB8", bytes1, key0, "XPREQ8"]); + ClassicAssert.AreEqual(1, addRes1); + + var addRes2 = (int)context.clusterTestUtils.Execute(primary, "VADD", ["foo", "XB8", bytes2, key1, "XPREQ8"]); + ClassicAssert.AreEqual(1, addRes2); + + var readPrimaryExc = (string)context.clusterTestUtils.Execute(primary, "GET", ["foo"]); + ClassicAssert.IsTrue(readPrimaryExc.StartsWith("WRONGTYPE ")); + + var queryPrimary = (byte[][])context.clusterTestUtils.Execute(primary, "VSIM", ["foo", "XB8", bytes3]); + ClassicAssert.AreEqual(2, queryPrimary.Length); + + _ = context.clusterTestUtils.Execute(secondary, "READONLY", []); + + // The vector set has either replicated, or not + // If so - we get WRONGTYPE + // If not - we get a null + var readSecondary = (string)context.clusterTestUtils.Execute(secondary, "GET", ["foo"]); + ClassicAssert.IsTrue(readSecondary is null || readSecondary.StartsWith("WRONGTYPE ")); + + context.clusterTestUtils.WaitForReplicaAofSync(PrimaryIndex, SecondaryIndex); + + var querySecondary = (byte[][])context.clusterTestUtils.Execute(secondary, "VSIM", ["foo", "XB8", bytes3]); + ClassicAssert.IsTrue(querySecondary.Length >= 1); + + for (var j = 0; j < querySecondary.Length; j++) + { + var expected = + querySecondary[j].AsSpan().SequenceEqual(key0) || + querySecondary[j].AsSpan().SequenceEqual(key1); + + ClassicAssert.IsTrue(expected); + } + + Incr(key0); + Incr(key1); + } + + static void Incr(byte[] k) + { + var ix = k.Length - 1; + while (true) + { + k[ix]++; + if (k[ix] == 0) + { + ix--; + } + else + { + break; + } + } + } + } + + [Test] + public async Task MultipleReplicasWithVectorSetsAsync() + { + const int PrimaryIndex = 0; + const int SecondaryStartIndex = 1; + const int SecondaryEndIndex = 5; + const int Vectors = 2_000; + const string Key = nameof(MultipleReplicasWithVectorSetsAsync); + + context.CreateInstances(HighReplicationShards, useTLS: true, enableAOF: true); + context.CreateConnection(useTLS: true); + _ = context.clusterTestUtils.SimpleSetupCluster(primary_count: 1, replica_count: 5, logger: context.logger); + + var primary = (IPEndPoint)context.endpoints[PrimaryIndex]; + var secondaries = new IPEndPoint[SecondaryEndIndex - SecondaryStartIndex + 1]; + for (var i = SecondaryStartIndex; i <= SecondaryEndIndex; i++) + { + secondaries[i - SecondaryStartIndex] = (IPEndPoint)context.endpoints[i]; + } + + ClassicAssert.AreEqual("master", 
context.clusterTestUtils.RoleCommand(primary).Value); + + foreach (var secondary in secondaries) + { + ClassicAssert.AreEqual("slave", context.clusterTestUtils.RoleCommand(secondary).Value); + } + + // Build some repeatably random data for inserts + var vectors = new byte[Vectors][]; + { + var r = new Random(2025_09_23_00); + + for (var i = 0; i < vectors.Length; i++) + { + vectors[i] = new byte[75]; + r.NextBytes(vectors[i]); + } + } + + using var sync = new SemaphoreSlim(2); + + var writeTask = + Task.Run( + async () => + { + await sync.WaitAsync(); + + var key = new byte[4]; + for (var i = 0; i < vectors.Length; i++) + { + BinaryPrimitives.WriteInt32LittleEndian(key, i); + var val = vectors[i]; + var addRes = (int)context.clusterTestUtils.Execute(primary, "VADD", [Key, "XB8", val, key, "XPREQ8"]); + ClassicAssert.AreEqual(1, addRes); + } + } + ); + + using var cts = new CancellationTokenSource(); + + var readTasks = new Task[secondaries.Length]; + + for (var i = 0; i < secondaries.Length; i++) + { + var secondary = secondaries[i]; + var readTask = + Task.Run( + async () => + { + var r = new Random(2025_09_23_01); + + var readonlyOnReplica = (string)context.clusterTestUtils.Execute(secondary, "READONLY", []); + ClassicAssert.AreEqual("OK", readonlyOnReplica); + + await sync.WaitAsync(); + + var nonZeroReturns = 0; + + while (!cts.Token.IsCancellationRequested) + { + var val = vectors[r.Next(vectors.Length)]; + + var readRes = (byte[][])context.clusterTestUtils.Execute(secondary, "VSIM", [Key, "XB8", val]); + if (readRes.Length > 0) + { + nonZeroReturns++; + } + } + + return nonZeroReturns; + } + ); + + readTasks[i] = readTask; + } + + _ = sync.Release(secondaries.Length + 1); + await writeTask; + + for (var secondaryIndex = SecondaryStartIndex; secondaryIndex <= SecondaryEndIndex; secondaryIndex++) + { + context.clusterTestUtils.WaitForReplicaAofSync(PrimaryIndex, secondaryIndex); + } + + cts.CancelAfter(TimeSpan.FromSeconds(1)); + + var searchesWithNonZeroResults = await Task.WhenAll(readTasks); + + ClassicAssert.IsTrue(searchesWithNonZeroResults.All(static x => x > 0)); + + + // Validate all nodes have same vector embeddings + { + var idBytes = new byte[4]; + for (var id = 0; id < vectors.Length; id++) + { + BinaryPrimitives.WriteInt32LittleEndian(idBytes, id); + var expected = vectors[id]; + + var fromPrimary = (string[])context.clusterTestUtils.Execute(primary, "VEMB", [Key, idBytes]); + + ClassicAssert.AreEqual(expected.Length, fromPrimary.Length); + + for (var i = 0; i < expected.Length; i++) + { + var p = (byte)float.Parse(fromPrimary[i]); + ClassicAssert.AreEqual(expected[i], p); + } + + for (var secondaryIx = 0; secondaryIx < secondaries.Length; secondaryIx++) + { + var secondary = secondaries[secondaryIx]; + var fromSecondary = (string[])context.clusterTestUtils.Execute(secondary, "VEMB", [Key, idBytes]); + + ClassicAssert.AreEqual(expected.Length, fromSecondary.Length); + + for (var i = 0; i < expected.Length; i++) + { + var s = (byte)float.Parse(fromSecondary[i]); + ClassicAssert.AreEqual(expected[i], s); + } + } + } + } + } + + [Test] + public async Task MultipleReplicasWithVectorSetsAndDeletesAsync() + { + const int PrimaryIndex = 0; + const int SecondaryStartIndex = 1; + const int SecondaryEndIndex = 5; + const int Vectors = 2_000; + const int Deletes = Vectors / 10; + const string Key = nameof(MultipleReplicasWithVectorSetsAndDeletesAsync); + + context.CreateInstances(HighReplicationShards, useTLS: true, enableAOF: true); + context.CreateConnection(useTLS: true); + _ = 
context.clusterTestUtils.SimpleSetupCluster(primary_count: 1, replica_count: 5, logger: context.logger); + + var primary = (IPEndPoint)context.endpoints[PrimaryIndex]; + var secondaries = new IPEndPoint[SecondaryEndIndex - SecondaryStartIndex + 1]; + for (var i = SecondaryStartIndex; i <= SecondaryEndIndex; i++) + { + secondaries[i - SecondaryStartIndex] = (IPEndPoint)context.endpoints[i]; + } + + ClassicAssert.AreEqual("master", context.clusterTestUtils.RoleCommand(primary).Value); + + foreach (var secondary in secondaries) + { + ClassicAssert.AreEqual("slave", context.clusterTestUtils.RoleCommand(secondary).Value); + } + + // Build some repeatably random data for inserts + var vectors = new byte[Vectors][]; + var toDeleteVectors = new HashSet(); + var pendingRemove = new List(); + { + var r = new Random(2025_10_20_00); + + for (var i = 0; i < vectors.Length; i++) + { + vectors[i] = new byte[75]; + r.NextBytes(vectors[i]); + } + + while (toDeleteVectors.Count < Deletes) + { + _ = toDeleteVectors.Add(r.Next(vectors.Length)); + } + + pendingRemove.AddRange(toDeleteVectors); + } + + using var sync = new SemaphoreSlim(2); + + var writeTask = + Task.Run( + async () => + { + await sync.WaitAsync(); + + var key = new byte[4]; + for (var i = 0; i < vectors.Length; i++) + { + BinaryPrimitives.WriteInt32LittleEndian(key, i); + var val = vectors[i]; + var addRes = (int)context.clusterTestUtils.Execute(primary, "VADD", [Key, "XB8", val, key, "XPREQ8"]); + ClassicAssert.AreEqual(1, addRes); + } + } + ); + + var deleteTask = + Task.Run( + async () => + { + await sync.WaitAsync(); + + var key = new byte[4]; + + while (pendingRemove.Count > 0) + { + var i = Random.Shared.Next(pendingRemove.Count); + var id = pendingRemove[i]; + + BinaryPrimitives.WriteInt32LittleEndian(key, id); + var remRes = (int)context.clusterTestUtils.Execute(primary, "VREM", [Key, key]); + if (remRes == 1) + { + pendingRemove.RemoveAt(i); + } + } + } + ); + + using var cts = new CancellationTokenSource(); + + var readTasks = new Task[secondaries.Length]; + + for (var i = 0; i < secondaries.Length; i++) + { + var secondary = secondaries[i]; + var readTask = + Task.Run( + async () => + { + var r = new Random(2025_09_23_01); + + var readonlyOnReplica = (string)context.clusterTestUtils.Execute(secondary, "READONLY", []); + ClassicAssert.AreEqual("OK", readonlyOnReplica); + + await sync.WaitAsync(); + + var nonZeroReturns = 0; + + while (!cts.Token.IsCancellationRequested) + { + var val = vectors[r.Next(vectors.Length)]; + + var readRes = (byte[][])context.clusterTestUtils.Execute(secondary, "VSIM", [Key, "XB8", val]); + if (readRes.Length > 0) + { + nonZeroReturns++; + } + } + + return nonZeroReturns; + } + ); + + readTasks[i] = readTask; + } + + _ = sync.Release(secondaries.Length + 2); + await writeTask; + await deleteTask; + + for (var secondaryIndex = SecondaryStartIndex; secondaryIndex <= SecondaryEndIndex; secondaryIndex++) + { + context.clusterTestUtils.WaitForReplicaAofSync(PrimaryIndex, secondaryIndex); + } + + cts.CancelAfter(TimeSpan.FromSeconds(1)); + + var searchesWithNonZeroResults = await Task.WhenAll(readTasks); + + ClassicAssert.IsTrue(searchesWithNonZeroResults.All(static x => x > 0)); + + // Validate all nodes have same vector embeddings + { + var idBytes = new byte[4]; + for (var id = 0; id < vectors.Length; id++) + { + BinaryPrimitives.WriteInt32LittleEndian(idBytes, id); + var expected = vectors[id]; + + var fromPrimary = (string[])context.clusterTestUtils.Execute(primary, "VEMB", [Key, idBytes]); + + var 
shouldBePresent = !toDeleteVectors.Contains(id); + if (shouldBePresent) + { + ClassicAssert.AreEqual(expected.Length, fromPrimary.Length); + + for (var i = 0; i < expected.Length; i++) + { + var p = (byte)float.Parse(fromPrimary[i]); + ClassicAssert.AreEqual(expected[i], p); + } + } + else + { + ClassicAssert.IsEmpty(fromPrimary); + } + + for (var secondaryIx = 0; secondaryIx < secondaries.Length; secondaryIx++) + { + var secondary = secondaries[secondaryIx]; + var fromSecondary = (string[])context.clusterTestUtils.Execute(secondary, "VEMB", [Key, idBytes]); + + if (shouldBePresent) + { + ClassicAssert.AreEqual(expected.Length, fromSecondary.Length); + + for (var i = 0; i < expected.Length; i++) + { + var s = (byte)float.Parse(fromSecondary[i]); + ClassicAssert.AreEqual(expected[i], s); + } + } + else + { + ClassicAssert.IsEmpty(fromSecondary); + } + } + } + } + } + + [Test] + public void VectorSetMigrateSingleBySlot() + { + // Test migrating a single slot with a vector set of one element in it + + const int Primary0Index = 0; + const int Primary1Index = 1; + const int Secondary0Index = 2; + const int Secondary1Index = 3; + + context.CreateInstances(DefaultMultiPrimaryShards, useTLS: true, enableAOF: true); + context.CreateConnection(useTLS: true); + _ = context.clusterTestUtils.SimpleSetupCluster(primary_count: DefaultMultiPrimaryShards / 2, replica_count: 1, logger: context.logger); + + var primary0 = (IPEndPoint)context.endpoints[Primary0Index]; + var primary1 = (IPEndPoint)context.endpoints[Primary1Index]; + var secondary0 = (IPEndPoint)context.endpoints[Secondary0Index]; + var secondary1 = (IPEndPoint)context.endpoints[Secondary1Index]; + + ClassicAssert.AreEqual("master", context.clusterTestUtils.RoleCommand(primary0).Value); + ClassicAssert.AreEqual("master", context.clusterTestUtils.RoleCommand(primary1).Value); + ClassicAssert.AreEqual("slave", context.clusterTestUtils.RoleCommand(secondary0).Value); + ClassicAssert.AreEqual("slave", context.clusterTestUtils.RoleCommand(secondary1).Value); + + var primary0Id = context.clusterTestUtils.ClusterMyId(primary0); + var primary1Id = context.clusterTestUtils.ClusterMyId(primary1); + + var slots = context.clusterTestUtils.ClusterSlots(primary0); + + string primary0Key; + int primary0HashSlot; + { + var ix = 0; + + while (true) + { + primary0Key = $"{nameof(VectorSetMigrateSingleBySlot)}_{ix}"; + primary0HashSlot = context.clusterTestUtils.HashSlot(primary0Key); + + if (slots.Any(x => x.nnInfo.Any(y => y.nodeid == primary0Id) && primary0HashSlot >= x.startSlot && primary0HashSlot <= x.endSlot)) + { + break; + } + + ix++; + } + } + + // Setup simple vector set on Primary0 in some hash slot + + var vectorData = Enumerable.Range(0, 75).Select(static x => (byte)x).ToArray(); + var vectorSimData = Enumerable.Range(0, 75).Select(static x => (byte)(x * 2)).ToArray(); + + var add0Res = (int)context.clusterTestUtils.Execute(primary0, "VADD", [primary0Key, "XB8", vectorData, new byte[] { 0, 0, 0, 0 }, "XPREQ8", "SETATTR", "{\"hello\": \"world\"}"], flags: CommandFlags.NoRedirect); + ClassicAssert.AreEqual(1, add0Res); + + var sim0Res = (byte[][])context.clusterTestUtils.Execute(primary0, "VSIM", [primary0Key, "XB8", vectorSimData, "WITHSCORES", "WITHATTRIBS"], flags: CommandFlags.NoRedirect); + ClassicAssert.AreEqual(3, sim0Res.Length); + ClassicAssert.IsTrue(new byte[] { 0, 0, 0, 0 }.SequenceEqual(sim0Res[0])); + ClassicAssert.IsFalse(float.IsNaN(float.Parse(Encoding.ASCII.GetString(sim0Res[1])))); + ClassicAssert.IsTrue("{\"hello\": 
\"world\"}"u8.SequenceEqual(sim0Res[2])); + + context.clusterTestUtils.WaitForReplicaAofSync(Primary0Index, Secondary0Index); + + var readonlyOnReplica0 = (string)context.clusterTestUtils.Execute(secondary0, "READONLY", [], flags: CommandFlags.NoRedirect); + ClassicAssert.AreEqual("OK", readonlyOnReplica0); + + var simOnReplica0 = (byte[][])context.clusterTestUtils.Execute(secondary0, "VSIM", [primary0Key, "XB8", vectorSimData, "WITHSCORES", "WITHATTRIBS"], flags: CommandFlags.NoRedirect); + ClassicAssert.IsTrue(simOnReplica0.Length > 0); + for (var i = 0; i < sim0Res.Length; i++) + { + ClassicAssert.IsTrue(sim0Res[i].AsSpan().SequenceEqual(simOnReplica0[i])); + } + + // Move to other primary + + context.clusterTestUtils.MigrateSlots(primary0, primary1, [primary0HashSlot]); + context.clusterTestUtils.WaitForMigrationCleanup(Primary0Index); + context.clusterTestUtils.WaitForMigrationCleanup(Primary1Index); + + context.clusterTestUtils.WaitForReplicaAofSync(Primary0Index, Secondary0Index); + context.clusterTestUtils.WaitForReplicaAofSync(Primary1Index, Secondary1Index); + + var curPrimary0Slots = context.clusterTestUtils.GetOwnedSlotsFromNode(primary0, context.logger); + var curPrimary1Slots = context.clusterTestUtils.GetOwnedSlotsFromNode(primary1, context.logger); + + ClassicAssert.IsFalse(curPrimary0Slots.Contains(primary0HashSlot)); + ClassicAssert.IsTrue(curPrimary1Slots.Contains(primary0HashSlot)); + + // Check available on other primary & secondary + + var sim1Res = (byte[][])context.clusterTestUtils.Execute(primary1, "VSIM", [primary0Key, "XB8", vectorSimData, "WITHSCORES", "WITHATTRIBS"], flags: CommandFlags.NoRedirect); + ClassicAssert.IsTrue(sim1Res.Length > 0); + for (var i = 0; i < sim0Res.Length; i++) + { + ClassicAssert.IsTrue(sim0Res[i].AsSpan().SequenceEqual(sim1Res[i])); + } + + var readonlyOnReplica1 = (string)context.clusterTestUtils.Execute(secondary1, "READONLY", [], flags: CommandFlags.NoRedirect); + ClassicAssert.AreEqual("OK", readonlyOnReplica1); + + var simOnReplica1 = (byte[][])context.clusterTestUtils.Execute(secondary1, "VSIM", [primary0Key, "XB8", vectorSimData, "WITHSCORES", "WITHATTRIBS"], flags: CommandFlags.NoRedirect); + ClassicAssert.IsTrue(simOnReplica1.Length > 0); + for (var i = 0; i < sim0Res.Length; i++) + { + ClassicAssert.IsTrue(sim0Res[i].AsSpan().SequenceEqual(simOnReplica0[i])); + } + + // Check no longer available on old primary or secondary + var exc0 = (string)context.clusterTestUtils.Execute(primary0, "VSIM", [primary0Key, "XB8", vectorSimData, "WITHSCORES", "WITHATTRIBS"], flags: CommandFlags.NoRedirect); + ClassicAssert.IsTrue(exc0.StartsWith("Key has MOVED to ")); + + var start = Stopwatch.GetTimestamp(); + + var success = false; + while (Stopwatch.GetElapsedTime(start) < TimeSpan.FromSeconds(5)) + { + try + { + var exc1 = (string)context.clusterTestUtils.Execute(secondary0, "VSIM", [primary0Key, "XB8", vectorSimData, "WITHSCORES", "WITHATTRIBS"], flags: CommandFlags.NoRedirect); + ClassicAssert.IsTrue(exc1.StartsWith("Key has MOVED to ")); + success = true; + break; + } + catch + { + // Secondary can still have the key for a bit + Thread.Sleep(100); + } + } + + ClassicAssert.IsTrue(success, "Original replica still has Vector Set long after primary has completed"); + } + + [Test] + public void VectorSetMigrateByKeys() + { + // Based on : ClusterSimpleMigrateKeys test + + const int ShardCount = 3; + const int KeyCount = 10; + + context.CreateInstances(ShardCount, useTLS: true, enableAOF: true); + context.CreateConnection(useTLS: true); + _ 
= context.clusterTestUtils.SimpleSetupCluster(logger: context.logger); + + var otherNodeIndex = 0; + var sourceNodeIndex = 1; + var targetNodeIndex = 2; + var sourceNodeId = context.clusterTestUtils.GetNodeIdFromNode(sourceNodeIndex, context.logger); + var targetNodeId = context.clusterTestUtils.GetNodeIdFromNode(targetNodeIndex, context.logger); + + var key = Encoding.ASCII.GetBytes("{abc}a"); + List keys = []; + List<(byte[] Key, byte[] Data)> vectors = []; + List attributes = []; + + var _workingSlot = ClusterTestUtils.HashSlot(key); + ClassicAssert.AreEqual(7638, _workingSlot); + + Random rand = new(2025_11_04_00); + + for (var i = 0; i < KeyCount; i++) + { + var newKey = new byte[key.Length]; + Array.Copy(key, 0, newKey, 0, key.Length); + newKey[^1] = (byte)(newKey[^1] + i); + ClassicAssert.AreEqual(_workingSlot, ClusterTestUtils.HashSlot(newKey)); + + var elem = new byte[4]; + rand.NextBytes(elem); + + var data = new byte[75]; + rand.NextBytes(data); + + var attrs = new byte[16]; + rand.NextBytes(attrs); + + var addRes = (int)context.clusterTestUtils.Execute(context.clusterTestUtils.GetEndPoint(sourceNodeIndex), "VADD", [newKey, "XB8", data, elem, "XPREQ8", "SETATTR", attrs]); + ClassicAssert.AreEqual(1, addRes); + + keys.Add(newKey); + vectors.Add((elem, data)); + attributes.Add(attrs); + } + + // Start migration + var respImport = context.clusterTestUtils.SetSlot(targetNodeIndex, _workingSlot, "IMPORTING", sourceNodeId, logger: context.logger); + ClassicAssert.AreEqual(respImport, "OK"); + + var respMigrate = context.clusterTestUtils.SetSlot(sourceNodeIndex, _workingSlot, "MIGRATING", targetNodeId, logger: context.logger); + ClassicAssert.AreEqual(respMigrate, "OK"); + + // Check key count + var countKeys = context.clusterTestUtils.CountKeysInSlot(sourceNodeIndex, _workingSlot, context.logger); + ClassicAssert.AreEqual(countKeys, KeyCount); + + // Enumerate keys in slots + var keysInSlot = context.clusterTestUtils.GetKeysInSlot(sourceNodeIndex, _workingSlot, countKeys, context.logger); + ClassicAssert.AreEqual(keys, keysInSlot); + + // Migrate keys, but in a random-ish order so context reservation gets stressed + var toMigrate = keysInSlot.ToList(); + while (toMigrate.Count > 0) + { + var migrateSingleIx = rand.Next(toMigrate.Count); + var migrateKey = toMigrate[migrateSingleIx]; + context.clusterTestUtils.MigrateKeys(context.clusterTestUtils.GetEndPoint(sourceNodeIndex), context.clusterTestUtils.GetEndPoint(targetNodeIndex), [migrateKey], context.logger); + + toMigrate.RemoveAt(migrateSingleIx); + } + + // Finish migration + var respNodeTarget = context.clusterTestUtils.SetSlot(targetNodeIndex, _workingSlot, "NODE", targetNodeId, logger: context.logger); + ClassicAssert.AreEqual(respNodeTarget, "OK"); + context.clusterTestUtils.BumpEpoch(targetNodeIndex, waitForSync: true, logger: context.logger); + + var respNodeSource = context.clusterTestUtils.SetSlot(sourceNodeIndex, _workingSlot, "NODE", targetNodeId, logger: context.logger); + ClassicAssert.AreEqual(respNodeSource, "OK"); + context.clusterTestUtils.BumpEpoch(sourceNodeIndex, waitForSync: true, logger: context.logger); + // End Migration + + // Check config + var targetConfigEpochFromTarget = context.clusterTestUtils.GetConfigEpochOfNodeFromNodeIndex(targetNodeIndex, targetNodeId, context.logger); + var targetConfigEpochFromSource = context.clusterTestUtils.GetConfigEpochOfNodeFromNodeIndex(sourceNodeIndex, targetNodeId, context.logger); + var targetConfigEpochFromOther = 
context.clusterTestUtils.GetConfigEpochOfNodeFromNodeIndex(otherNodeIndex, targetNodeId, context.logger); + + while (targetConfigEpochFromOther != targetConfigEpochFromTarget || targetConfigEpochFromSource != targetConfigEpochFromTarget) + { + _ = Thread.Yield(); + targetConfigEpochFromTarget = context.clusterTestUtils.GetConfigEpochOfNodeFromNodeIndex(targetNodeIndex, targetNodeId, context.logger); + targetConfigEpochFromSource = context.clusterTestUtils.GetConfigEpochOfNodeFromNodeIndex(sourceNodeIndex, targetNodeId, context.logger); + targetConfigEpochFromOther = context.clusterTestUtils.GetConfigEpochOfNodeFromNodeIndex(otherNodeIndex, targetNodeId, context.logger); + } + ClassicAssert.AreEqual(targetConfigEpochFromTarget, targetConfigEpochFromOther); + ClassicAssert.AreEqual(targetConfigEpochFromTarget, targetConfigEpochFromSource); + + // Check migration in progress + foreach (var _key in keys) + { + var resp = context.clusterTestUtils.GetKey(otherNodeIndex, _key, out var slot, out var endpoint, out var responseState, logger: context.logger); + while (endpoint.Port != context.clusterTestUtils.GetEndPoint(targetNodeIndex).Port && responseState != ResponseState.OK) + { + resp = context.clusterTestUtils.GetKey(otherNodeIndex, _key, out slot, out endpoint, out responseState, logger: context.logger); + } + ClassicAssert.AreEqual(resp, "MOVED"); + ClassicAssert.AreEqual(_workingSlot, slot); + ClassicAssert.AreEqual(context.clusterTestUtils.GetEndPoint(targetNodeIndex), endpoint); + } + + // Finish migration + context.clusterTestUtils.WaitForMigrationCleanup(context.logger); + + // Validate vector sets coherent + for (var i = 0; i < keys.Count; i++) + { + var _key = keys[i]; + var (elem, data) = vectors[i]; + var attrs = attributes[i]; + + var res = (byte[][])context.clusterTestUtils.Execute(context.clusterTestUtils.GetEndPoint(targetNodeIndex), "VSIM", [_key, "XB8", data, "WITHATTRIBS"]); + ClassicAssert.AreEqual(2, res.Length); + ClassicAssert.IsTrue(res[0].SequenceEqual(elem)); + ClassicAssert.IsTrue(res[1].SequenceEqual(attrs)); + } + } + + [Test] + public void VectorSetMigrateManyBySlot() + { + // Test migrating several vector sets from one primary to another primary, which already has vector sets of its own + + const int Primary0Index = 0; + const int Primary1Index = 1; + const int Secondary0Index = 2; + const int Secondary1Index = 3; + + const int VectorSetsPerPrimary = 8; + + context.CreateInstances(DefaultMultiPrimaryShards, useTLS: true, enableAOF: true); + context.CreateConnection(useTLS: true); + _ = context.clusterTestUtils.SimpleSetupCluster(primary_count: DefaultMultiPrimaryShards / 2, replica_count: 1, logger: context.logger); + + var primary0 = (IPEndPoint)context.endpoints[Primary0Index]; + var primary1 = (IPEndPoint)context.endpoints[Primary1Index]; + var secondary0 = (IPEndPoint)context.endpoints[Secondary0Index]; + var secondary1 = (IPEndPoint)context.endpoints[Secondary1Index]; + + ClassicAssert.AreEqual("master", context.clusterTestUtils.RoleCommand(primary0).Value); + ClassicAssert.AreEqual("master", context.clusterTestUtils.RoleCommand(primary1).Value); + ClassicAssert.AreEqual("slave", context.clusterTestUtils.RoleCommand(secondary0).Value); + ClassicAssert.AreEqual("slave", context.clusterTestUtils.RoleCommand(secondary1).Value); + + var primary0Id = context.clusterTestUtils.ClusterMyId(primary0); + var primary1Id = context.clusterTestUtils.ClusterMyId(primary1); + + var slots = context.clusterTestUtils.ClusterSlots(primary0); + + List<(string Key, ushort 
HashSlot, byte[] Element, byte[] Data, byte[] Attr)> primary0Keys = []; + List<(string Key, ushort HashSlot, byte[] Element, byte[] Data, byte[] Attr)> primary1Keys = []; + + { + var ix = 0; + + while (primary0Keys.Count < VectorSetsPerPrimary || primary1Keys.Count < VectorSetsPerPrimary) + { + var key = $"{nameof(VectorSetMigrateManyBySlot)}_{ix}"; + var hashSlot = context.clusterTestUtils.HashSlot(key); + + var isOnPrimary0 = slots.Any(x => x.nnInfo.Any(y => y.nodeid == primary0Id) && hashSlot >= x.startSlot && hashSlot <= x.endSlot); + var isOnPrimary1 = slots.Any(x => x.nnInfo.Any(y => y.nodeid == primary1Id) && hashSlot >= x.startSlot && hashSlot <= x.endSlot); + + if (isOnPrimary0 && primary0Keys.Count < VectorSetsPerPrimary) + { + var elem = new byte[4]; + var data = new byte[75]; + var attr = new byte[10]; + Random.Shared.NextBytes(elem); + Random.Shared.NextBytes(data); + Random.Shared.NextBytes(attr); + + primary0Keys.Add((key, (ushort)hashSlot, elem, data, attr)); + } + + if (isOnPrimary1 && primary1Keys.Count < VectorSetsPerPrimary) + { + var elem = new byte[4]; + var data = new byte[75]; + var attr = new byte[10]; + Random.Shared.NextBytes(elem); + Random.Shared.NextBytes(data); + Random.Shared.NextBytes(attr); + + primary1Keys.Add((key, (ushort)hashSlot, elem, data, attr)); + } + + ix++; + } + } + + // Setup vectors on the primaries + foreach (var (key, _, elem, data, attr) in primary0Keys) + { + var add0Res = (int)context.clusterTestUtils.Execute(primary0, "VADD", [key, "XB8", data, elem, "XPREQ8", "SETATTR", attr], flags: CommandFlags.NoRedirect); + ClassicAssert.AreEqual(1, add0Res); + } + + foreach (var (key, _, elem, data, attr) in primary1Keys) + { + var add1Res = (int)context.clusterTestUtils.Execute(primary1, "VADD", [key, "XB8", data, elem, "XPREQ8", "SETATTR", attr], flags: CommandFlags.NoRedirect); + ClassicAssert.AreEqual(1, add1Res); + } + + // Query expected results + Dictionary<(string Key, byte[] Data), (byte[] Elem, byte[] Attr, float Score)> expected = new(StringAndByteArrayComparer.Instance); + + foreach (var (key, _, _, data, _) in primary0Keys) + { + var sim0Res = (byte[][])context.clusterTestUtils.Execute(primary0, "VSIM", [key, "XB8", data, "WITHSCORES", "WITHATTRIBS"], flags: CommandFlags.NoRedirect); + ClassicAssert.AreEqual(3, sim0Res.Length); + expected.Add((key, data), (sim0Res[0], sim0Res[2], float.Parse(Encoding.ASCII.GetString(sim0Res[1])))); + } + + foreach (var (key, _, _, data, _) in primary1Keys) + { + var sim1Res = (byte[][])context.clusterTestUtils.Execute(primary1, "VSIM", [key, "XB8", data, "WITHSCORES", "WITHATTRIBS"], flags: CommandFlags.NoRedirect); + ClassicAssert.AreEqual(3, sim1Res.Length); + expected.Add((key, data), (sim1Res[0], sim1Res[2], float.Parse(Encoding.ASCII.GetString(sim1Res[1])))); + } + + context.clusterTestUtils.WaitForReplicaAofSync(Primary0Index, Secondary0Index); + + // Move from primary0 to primary1 + var migratedHashSlots = primary0Keys.Select(static t => t.HashSlot).Distinct().Select(static s => (int)s).ToList(); + + context.clusterTestUtils.MigrateSlots(primary0, primary1, migratedHashSlots); + context.clusterTestUtils.WaitForMigrationCleanup(Primary0Index); + context.clusterTestUtils.WaitForMigrationCleanup(Primary1Index); + + context.clusterTestUtils.WaitForReplicaAofSync(Primary0Index, Secondary0Index); + context.clusterTestUtils.WaitForReplicaAofSync(Primary1Index, Secondary1Index); + + var curPrimary0Slots = context.clusterTestUtils.GetOwnedSlotsFromNode(primary0, context.logger); + var curPrimary1Slots = 
context.clusterTestUtils.GetOwnedSlotsFromNode(primary1, context.logger); + + foreach (var hashSlot in migratedHashSlots) + { + ClassicAssert.IsFalse(curPrimary0Slots.Contains(hashSlot)); + ClassicAssert.IsTrue(curPrimary1Slots.Contains(hashSlot)); + } + + // Check available on other primary + foreach (var (key, _, _, data, _) in primary0Keys.Concat(primary1Keys)) + { + var migrateSimRes = (byte[][])context.clusterTestUtils.Execute(primary1, "VSIM", [key, "XB8", data, "WITHSCORES", "WITHATTRIBS"], flags: CommandFlags.NoRedirect); + ClassicAssert.AreEqual(3, migrateSimRes.Length); + + var (elem, attr, score) = expected[(key, data)]; + + ClassicAssert.IsTrue(elem.SequenceEqual(migrateSimRes[0])); + ClassicAssert.AreEqual(score, float.Parse(Encoding.ASCII.GetString(migrateSimRes[1]))); + ClassicAssert.IsTrue(attr.SequenceEqual(migrateSimRes[2])); + } + + // Check no longer available on old primary or secondary + foreach (var (key, _, _, data, _) in primary0Keys.Concat(primary1Keys)) + { + var exc0 = (string)context.clusterTestUtils.Execute(primary0, "VSIM", [key, "XB8", data, "WITHSCORES", "WITHATTRIBS"], flags: CommandFlags.NoRedirect); + ClassicAssert.IsTrue(exc0.StartsWith("Key has MOVED to ")); + } + + var start = Stopwatch.GetTimestamp(); + + var success = false; + while (Stopwatch.GetElapsedTime(start) < TimeSpan.FromSeconds(5)) + { + try + { + var migrationNotFinished = false; + foreach (var (key, _, _, data, _) in primary0Keys.Concat(primary1Keys)) + { + var exc1 = (string)context.clusterTestUtils.Execute(secondary0, "VSIM", [key, "XB8", data, "WITHSCORES", "WITHATTRIBS"], flags: CommandFlags.NoRedirect); + if (!exc1.StartsWith("Key has MOVED to ")) + { + migrationNotFinished = true; + break; + } + } + + if (migrationNotFinished) + { + continue; + } + + success = true; + break; + } + catch + { + // Secondary can still have the key for a bit + Thread.Sleep(100); + } + } + + ClassicAssert.IsTrue(success, "Original replica still has Vector Set long after primary has completed"); + + // Check available on new secondary + var readonlyOnReplica1 = (string)context.clusterTestUtils.Execute(secondary1, "READONLY", [], flags: CommandFlags.NoRedirect); + ClassicAssert.AreEqual("OK", readonlyOnReplica1); + + start = Stopwatch.GetTimestamp(); + + success = false; + + while (Stopwatch.GetElapsedTime(start) < TimeSpan.FromSeconds(5)) + { + success = true; + + foreach (var (key, _, _, data, _) in primary0Keys.Concat(primary1Keys)) + { + var migrateSimRes = (byte[][])context.clusterTestUtils.Execute(secondary1, "VSIM", [key, "XB8", data, "WITHSCORES", "WITHATTRIBS"], flags: CommandFlags.NoRedirect); + + if (migrateSimRes.Length == 1 && Encoding.UTF8.GetString(migrateSimRes[0]).StartsWith("Key has MOVED to ")) + { + success = false; + break; + } + + ClassicAssert.AreEqual(3, migrateSimRes.Length); + + var (elem, attr, score) = expected[(key, data)]; + + ClassicAssert.IsTrue(elem.SequenceEqual(migrateSimRes[0])); + ClassicAssert.AreEqual(score, float.Parse(Encoding.ASCII.GetString(migrateSimRes[1]))); + ClassicAssert.IsTrue(attr.SequenceEqual(migrateSimRes[2])); + } + + if (success) + { + break; + } + } + + ClassicAssert.IsTrue(success, "New replica hasn't replicated Vector Set long after primary has received data"); + } + + [Test] + public async Task MigrateVectorSetWhileModifyingAsync() + { + // Test migrating a single slot with a vector set while modifying it + + const int Primary0Index = 0; + const int Primary1Index = 1; + const int Secondary0Index = 2; + const int Secondary1Index = 3; + + 
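// Cluster layout assumed by the index constants above: two primaries (0 and 1), each with one replica (2 and 3), as set up by SimpleSetupCluster below. + 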
context.CreateInstances(DefaultMultiPrimaryShards, useTLS: true, enableAOF: true, OnDemandCheckpoint: true, EnableIncrementalSnapshots: true); + context.CreateConnection(useTLS: true); + _ = context.clusterTestUtils.SimpleSetupCluster(primary_count: DefaultMultiPrimaryShards / 2, replica_count: 1, logger: context.logger); + + var primary0 = (IPEndPoint)context.endpoints[Primary0Index]; + var primary1 = (IPEndPoint)context.endpoints[Primary1Index]; + var secondary0 = (IPEndPoint)context.endpoints[Secondary0Index]; + var secondary1 = (IPEndPoint)context.endpoints[Secondary1Index]; + + ClassicAssert.AreEqual("master", context.clusterTestUtils.RoleCommand(primary0).Value); + ClassicAssert.AreEqual("master", context.clusterTestUtils.RoleCommand(primary1).Value); + ClassicAssert.AreEqual("slave", context.clusterTestUtils.RoleCommand(secondary0).Value); + ClassicAssert.AreEqual("slave", context.clusterTestUtils.RoleCommand(secondary1).Value); + + var primary0Id = context.clusterTestUtils.ClusterMyId(primary0); + var primary1Id = context.clusterTestUtils.ClusterMyId(primary1); + + var slots = context.clusterTestUtils.ClusterSlots(primary0); + + string primary0Key; + int primary0HashSlot; + { + var ix = 0; + + while (true) + { + primary0Key = $"{nameof(MigrateVectorSetWhileModifyingAsync)}_{ix}"; + primary0HashSlot = context.clusterTestUtils.HashSlot(primary0Key); + + if (slots.Any(x => x.nnInfo.Any(y => y.nodeid == primary0Id) && primary0HashSlot >= x.startSlot && primary0HashSlot <= x.endSlot)) + { + break; + } + + ix++; + } + } + + // Start writing to this Vector Set + using var cts = new CancellationTokenSource(); + + var added = new ConcurrentBag<(byte[] Elem, byte[] Data, byte[] Attr)>(); + + var writeTask = + Task.Run( + async () => + { + // Force async + await Task.Yield(); + + using var readWriteCon = ConnectionMultiplexer.Connect(context.clusterTestUtils.GetRedisConfig(context.endpoints)); + var readWriteDb = readWriteCon.GetDatabase(); + + var ix = 0; + + var elem = new byte[4]; + var data = new byte[75]; + var attr = new byte[100]; + + BinaryPrimitives.WriteInt32LittleEndian(elem, ix); + Random.Shared.NextBytes(data); + Random.Shared.NextBytes(attr); + + while (!cts.IsCancellationRequested) + { + if (TestUtils.IsRunningAsGitHubAction) + { + // Throw some delay in when running as a GitHub Action to work around the weak drives those VMs have + await Task.Delay(1); + } + + // This should follow redirects, so migration shouldn't cause any failures + try + { + var addRes = (int)readWriteDb.Execute("VADD", [new RedisKey(primary0Key), "XB8", data, elem, "XPREQ8", "SETATTR", attr]); + ClassicAssert.AreEqual(1, addRes); + } + catch (RedisServerException exc) + { + if (exc.Message.StartsWith("MOVED ")) + { + continue; + } + + throw; + } + + added.Add((elem.ToArray(), data.ToArray(), attr.ToArray())); + + ix++; + BinaryPrimitives.WriteInt32LittleEndian(elem, ix); + Random.Shared.NextBytes(data); + Random.Shared.NextBytes(attr); + } + } + ); + + await Task.Delay(1_000); + + var lenPreMigration = added.Count; + ClassicAssert.IsTrue(lenPreMigration > 0, "Should have seen some writes pre-migration"); + + // Move to other primary + using (var migrateToken = new CancellationTokenSource()) + { + migrateToken.CancelAfter(30_000); + + context.clusterTestUtils.MigrateSlots(primary0, primary1, [primary0HashSlot]); + context.clusterTestUtils.WaitForMigrationCleanup(Primary0Index, cancellationToken: migrateToken.Token); + context.clusterTestUtils.WaitForMigrationCleanup(Primary1Index, cancellationToken: 
migrateToken.Token); + } + + using (var replicationToken = new CancellationTokenSource()) + { + replicationToken.CancelAfter(30_000); + + context.clusterTestUtils.WaitForReplicaAofSync(Primary0Index, Secondary0Index, cancellation: replicationToken.Token); + context.clusterTestUtils.WaitForReplicaAofSync(Primary1Index, Secondary1Index, cancellation: replicationToken.Token); + } + + var curPrimary0Slots = context.clusterTestUtils.GetOwnedSlotsFromNode(primary0, context.logger); + var curPrimary1Slots = context.clusterTestUtils.GetOwnedSlotsFromNode(primary1, context.logger); + + ClassicAssert.IsFalse(curPrimary0Slots.Contains(primary0HashSlot)); + ClassicAssert.IsTrue(curPrimary1Slots.Contains(primary0HashSlot)); + + var lenPrePause = added.Count; + await Task.Delay(5_000); + var lenPostPause = added.Count; + + ClassicAssert.IsTrue(lenPostPause > lenPrePause, "Writes after migration did not resume"); + + // Stop Writes and wait for replication to catch up + cts.Cancel(); + await writeTask; + + var addedLookup = added.ToFrozenDictionary(static t => t.Elem, t => t, ByteArrayComparer.Instance); + + context.clusterTestUtils.WaitForReplicaAofSync(Primary0Index, Secondary0Index); + context.clusterTestUtils.WaitForReplicaAofSync(Primary1Index, Secondary1Index); + + // Check available on other primary & secondary + + foreach (var (_, data, _) in added) + { + var sim1Res = (byte[][])context.clusterTestUtils.Execute(primary1, "VSIM", [primary0Key, "XB8", data, "WITHSCORES", "WITHATTRIBS", "COUNT", "1"], flags: CommandFlags.NoRedirect); + ClassicAssert.AreEqual(3, sim1Res.Length); + + // No guarantee we'll get the exact same element, but we should always get _a_ result and the correct associated attribute + var resElem = sim1Res[0]; + var resAttr = sim1Res[2]; + var expectedAttr = addedLookup[resElem].Attr; + ClassicAssert.IsTrue(resAttr.SequenceEqual(expectedAttr)); + } + + var readonlyOnReplica1 = (string)context.clusterTestUtils.Execute(secondary1, "READONLY", [], flags: CommandFlags.NoRedirect); + ClassicAssert.AreEqual("OK", readonlyOnReplica1); + + foreach (var (elem, data, attr) in added) + { + var simOnReplica1Res = (byte[][])context.clusterTestUtils.Execute(secondary1, "VSIM", [primary0Key, "XB8", data, "WITHSCORES", "WITHATTRIBS", "COUNT", "1"], flags: CommandFlags.NoRedirect); + + // No guarantee we'll get the exact same element, but we should always get _a_ result and the correct associated attribute + var resElem = simOnReplica1Res[0]; + var resAttr = simOnReplica1Res[2]; + var expectedAttr = addedLookup[resElem].Attr; + ClassicAssert.IsTrue(resAttr.SequenceEqual(expectedAttr)); + } + } + + [Test] + public void MigrateVectorSetBack() + { + const int Primary0Index = 0; + const int Primary1Index = 1; + + context.CreateInstances(DefaultShards, useTLS: true, enableAOF: true); + context.CreateConnection(useTLS: true); + _ = context.clusterTestUtils.SimpleSetupCluster(primary_count: DefaultShards, replica_count: 0, logger: context.logger); + + var primary0 = (IPEndPoint)context.endpoints[Primary0Index]; + var primary1 = (IPEndPoint)context.endpoints[Primary1Index]; + + ClassicAssert.AreEqual("master", context.clusterTestUtils.RoleCommand(primary0).Value); + ClassicAssert.AreEqual("master", context.clusterTestUtils.RoleCommand(primary1).Value); + + var primary0Id = context.clusterTestUtils.ClusterMyId(primary0); + var primary1Id = context.clusterTestUtils.ClusterMyId(primary1); + + var slots = context.clusterTestUtils.ClusterSlots(primary0); + + string vectorSetKey; + int vectorSetKeySlot; + { + 
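// Probe candidate key names until one hashes to a slot owned by primary0, per the CLUSTER SLOTS snapshot taken above. + 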
var ix = 0; + + while (true) + { + vectorSetKey = $"{nameof(MigrateVectorSetBack)}_{ix}"; + vectorSetKeySlot = context.clusterTestUtils.HashSlot(vectorSetKey); + + var isPrimary0Slot = slots.Any(x => x.nnInfo.Any(y => y.nodeid == primary0Id) && vectorSetKeySlot >= x.startSlot && vectorSetKeySlot <= x.endSlot); + if (isPrimary0Slot) + { + break; + } + + ix++; + } + } + + using var readWriteCon = ConnectionMultiplexer.Connect(context.clusterTestUtils.GetRedisConfig(context.endpoints)); + var readWriteDB = readWriteCon.GetDatabase(); + + var data0 = Enumerable.Range(0, 75).Select(static x => (byte)x).ToArray(); + byte[] elem0 = [1, 2, 3, 0]; + var attr0 = "hello world"u8.ToArray(); + + var add0Res = (int)readWriteDB.Execute("VADD", [new RedisKey(vectorSetKey), "XB8", data0, elem0, "XPREQ8", "SETATTR", attr0]); + ClassicAssert.AreEqual(1, add0Res); + + // Migrate 0 -> 1 + context.logger?.LogInformation("Starting 0 -> 1 migration of {slot}", vectorSetKeySlot); + { + using (var migrateToken = new CancellationTokenSource()) + { + migrateToken.CancelAfter(30_000); + + context.clusterTestUtils.MigrateSlots(primary0, primary1, [vectorSetKeySlot]); + context.clusterTestUtils.WaitForMigrationCleanup(Primary0Index, cancellationToken: migrateToken.Token); + context.clusterTestUtils.WaitForMigrationCleanup(Primary1Index, cancellationToken: migrateToken.Token); + } + + var nodePropSuccess = false; + var start = Stopwatch.GetTimestamp(); + while (Stopwatch.GetElapsedTime(start) < TimeSpan.FromSeconds(5)) + { + var curPrimary0Slots = context.clusterTestUtils.GetOwnedSlotsFromNode(primary0, context.logger); + var curPrimary1Slots = context.clusterTestUtils.GetOwnedSlotsFromNode(primary1, context.logger); + + var movedOffPrimary0 = !curPrimary0Slots.Contains(vectorSetKeySlot); + var movedOntoPrimary1 = curPrimary1Slots.Contains(vectorSetKeySlot); + + if (movedOffPrimary0 && movedOntoPrimary1) + { + nodePropSuccess = true; + break; + } + } + + ClassicAssert.IsTrue(nodePropSuccess, "Node propagation after 0 -> 1 migration took too long"); + } + + // Confirm still valid to add, with client side routing + var data1 = Enumerable.Range(0, 75).Select(static x => (byte)(x * 2)).ToArray(); + byte[] elem1 = [4, 5, 6, 7]; + var attr1 = "fizz buzz"u8.ToArray(); + + var add1Res = (int)readWriteDB.Execute("VADD", [new RedisKey(vectorSetKey), "XB8", data1, elem1, "XPREQ8", "SETATTR", attr1]); + ClassicAssert.AreEqual(1, add1Res); + + // Migrate 1 -> 0 + context.logger?.LogInformation("Starting 1 -> 0 migration of {slot}", vectorSetKeySlot); + { + using (var migrateToken = new CancellationTokenSource()) + { + migrateToken.CancelAfter(30_000); + + context.clusterTestUtils.MigrateSlots(primary1, primary0, [vectorSetKeySlot]); + context.clusterTestUtils.WaitForMigrationCleanup(Primary0Index, cancellationToken: migrateToken.Token); + context.clusterTestUtils.WaitForMigrationCleanup(Primary1Index, cancellationToken: migrateToken.Token); + } + + var nodePropSuccess = false; + var start = Stopwatch.GetTimestamp(); + while (Stopwatch.GetElapsedTime(start) < TimeSpan.FromSeconds(5)) + { + var curPrimary0Slots = context.clusterTestUtils.GetOwnedSlotsFromNode(primary0, context.logger); + var curPrimary1Slots = context.clusterTestUtils.GetOwnedSlotsFromNode(primary1, context.logger); + + var movedOntoPrimary0 = curPrimary0Slots.Contains(vectorSetKeySlot); + var movedOffPrimary1 = !curPrimary1Slots.Contains(vectorSetKeySlot); + + if (movedOntoPrimary0 && movedOffPrimary1) + { + nodePropSuccess = true; + break; + } + } + + 
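// Slot ownership is polled above rather than asserted immediately, since the reassignment propagates between nodes via gossip and can lag the MIGRATE call. + 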
ClassicAssert.IsTrue(nodePropSuccess, "Node propagation after 1 -> 0 migration took too long"); + } + + // Confirm still valid to add, with client side routing + var data2 = Enumerable.Range(0, 75).Select(static x => (byte)(x * 3)).ToArray(); + byte[] elem2 = [8, 9, 10, 11]; + var attr2 = "foo bar"u8.ToArray(); + + var add2Res = (int)readWriteDB.Execute("VADD", [new RedisKey(vectorSetKey), "XB8", data2, elem2, "XPREQ8", "SETATTR", attr2]); + ClassicAssert.AreEqual(1, add2Res); + + // Confirm no data loss + var emb0 = ((string[])readWriteDB.Execute("VEMB", [new RedisKey(vectorSetKey), elem0])).Select(static x => (byte)float.Parse(x)).ToArray(); + var emb1 = ((string[])readWriteDB.Execute("VEMB", [new RedisKey(vectorSetKey), elem1])).Select(static x => (byte)float.Parse(x)).ToArray(); + var emb2 = ((string[])readWriteDB.Execute("VEMB", [new RedisKey(vectorSetKey), elem2])).Select(static x => (byte)float.Parse(x)).ToArray(); + ClassicAssert.IsTrue(data0.SequenceEqual(emb0)); + ClassicAssert.IsTrue(data1.SequenceEqual(emb1)); + ClassicAssert.IsTrue(data2.SequenceEqual(emb2)); + } + + [Test] + public async Task MigrateVectorStressAsync() + { + // Move vector sets back and forth between replicas, making sure we don't drop data + // Keeps reads and writes going continuously + + const int Primary0Index = 0; + const int Primary1Index = 1; + const int Secondary0Index = 2; + const int Secondary1Index = 3; + + const int VectorSetsPerPrimary = 2; + + var gossipFaultsAtTestStart = 0; + + captureLogWriter.capture = true; + + try + { + context.CreateInstances(DefaultMultiPrimaryShards, useTLS: true, enableAOF: true); + context.CreateConnection(useTLS: true); + _ = context.clusterTestUtils.SimpleSetupCluster(primary_count: DefaultMultiPrimaryShards / 2, replica_count: 1, logger: context.logger); + + var primary0 = (IPEndPoint)context.endpoints[Primary0Index]; + var primary1 = (IPEndPoint)context.endpoints[Primary1Index]; + var secondary0 = (IPEndPoint)context.endpoints[Secondary0Index]; + var secondary1 = (IPEndPoint)context.endpoints[Secondary1Index]; + + ClassicAssert.AreEqual("master", context.clusterTestUtils.RoleCommand(primary0).Value); + ClassicAssert.AreEqual("master", context.clusterTestUtils.RoleCommand(primary1).Value); + ClassicAssert.AreEqual("slave", context.clusterTestUtils.RoleCommand(secondary0).Value); + ClassicAssert.AreEqual("slave", context.clusterTestUtils.RoleCommand(secondary1).Value); + + var primary0Id = context.clusterTestUtils.ClusterMyId(primary0); + var primary1Id = context.clusterTestUtils.ClusterMyId(primary1); + + var slots = context.clusterTestUtils.ClusterSlots(primary0); + + var vectorSetKeys = new List<(string Key, ushort HashSlot)>(); + + { + var ix = 0; + + var numP0 = 0; + var numP1 = 0; + + while (numP0 < VectorSetsPerPrimary || numP1 < VectorSetsPerPrimary) + { + var key = $"{nameof(MigrateVectorStressAsync)}_{ix}"; + var slot = context.clusterTestUtils.HashSlot(key); + + var isPrimary0Slot = slots.Any(x => x.nnInfo.Any(y => y.nodeid == primary0Id) && slot >= x.startSlot && slot <= x.endSlot); + + if (isPrimary0Slot) + { + if (numP0 < VectorSetsPerPrimary) + { + vectorSetKeys.Add((key, (ushort)slot)); + numP0++; + } + } + else + { + if (numP1 < VectorSetsPerPrimary) + { + vectorSetKeys.Add((key, (ushort)slot)); + numP1++; + } + } + + ix++; + } + } + + // Remember how cluster looked right after it was stable + gossipFaultsAtTestStart = CountGossipFaults(captureLogWriter); + + // Start writing to this Vector Set + using var writeCancel = new 
CancellationTokenSource(); + + using var readWriteCon = ConnectionMultiplexer.Connect(context.clusterTestUtils.GetRedisConfig(context.endpoints)); + var readWriteDB = readWriteCon.GetDatabase(); + + var writeTasks = new Task[vectorSetKeys.Count]; + var writeResults = new ConcurrentBag<(byte[] Elem, byte[] Data, byte[] Attr, DateTime InsertionTime)>[vectorSetKeys.Count]; + + var mostRecentWrite = 0L; + + for (var i = 0; i < vectorSetKeys.Count; i++) + { + var (key, _) = vectorSetKeys[i]; + var written = writeResults[i] = new(); + + writeTasks[i] = + Task.Run( + async () => + { + // Force async + await Task.Yield(); + + var ix = 0; + + while (!writeCancel.IsCancellationRequested) + { + var elem = new byte[4]; + BinaryPrimitives.WriteInt32LittleEndian(elem, ix); + + var data = new byte[75]; + Random.Shared.NextBytes(data); + + var attr = new byte[100]; + Random.Shared.NextBytes(attr); + + while (true) + { + try + { + var addRes = (int)readWriteDB.Execute("VADD", [new RedisKey(key), "XB8", data, elem, "XPREQ8", "SETATTR", attr]); + ClassicAssert.AreEqual(1, addRes); + break; + } + catch (RedisServerException exc) + { + if (exc.Message.StartsWith("MOVED ")) + { + // This is fine, just try again if we're not cancelled + if (writeCancel.IsCancellationRequested) + { + return; + } + + continue; + } + + throw; + } + } + + var now = DateTime.UtcNow; + written.Add((elem, data, attr, now)); + + var mostRecentCopy = mostRecentWrite; + while (mostRecentCopy < now.Ticks) + { + var currentMostRecent = Interlocked.CompareExchange(ref mostRecentWrite, now.Ticks, mostRecentCopy); + if (currentMostRecent == mostRecentCopy) + { + break; + } + mostRecentCopy = currentMostRecent; + } + + ix++; + } + } + ); + } + + using var readCancel = new CancellationTokenSource(); + + var readTasks = new Task[vectorSetKeys.Count]; + for (var i = 0; i < vectorSetKeys.Count; i++) + { + var (key, _) = vectorSetKeys[i]; + var written = writeResults[i]; + readTasks[i] = + Task.Run( + async () => + { + await Task.Yield(); + + var successfulReads = 0; + + while (!readCancel.IsCancellationRequested) + { + var r = written.Count; + if (r == 0) + { + await Task.Delay(10); + continue; + } + + var (elem, data, _, _) = written.ToList()[Random.Shared.Next(r)]; + + var emb = (string[])readWriteDB.Execute("VEMB", [new RedisKey(key), elem]); + + // If we got data, make sure it's coherent + ClassicAssert.AreEqual(data.Length, emb.Length); + + for (var i = 0; i < data.Length; i++) + { + ClassicAssert.AreEqual(data[i], (byte)float.Parse(emb[i])); + } + + successfulReads++; + } + + return successfulReads; + } + ); + } + + await Task.Delay(1_000); + + ClassicAssert.IsTrue(writeResults.All(static r => !r.IsEmpty), "Should have seen some writes pre-migration"); + + // Task to flip back and forth between primaries + using var migrateCancel = new CancellationTokenSource(); + + var migrateTask = + Task.Run( + async () => + { + var hashSlotsOnP0 = new List(); + var hashSlotsOnP1 = new List(); + foreach (var (_, slot) in vectorSetKeys) + { + var isPrimary0Slot = slots.Any(x => x.nnInfo.Any(y => y.nodeid == primary0Id) && slot >= x.startSlot && slot <= x.endSlot); + if (isPrimary0Slot) + { + if (!hashSlotsOnP0.Contains(slot)) + { + hashSlotsOnP0.Add(slot); + } + } + else + { + if (!hashSlotsOnP1.Contains(slot)) + { + hashSlotsOnP1.Add(slot); + } + } + } + + var migrationTimes = new List(); + + var mostRecentMigration = 0L; + + while (!migrateCancel.IsCancellationRequested) + { + await Task.Delay(100); + + // Don't start another migration until we get at 
least one successful write + if (Interlocked.CompareExchange(ref mostRecentWrite, 0, 0) < mostRecentMigration) + { + continue; + } + + // Move 0 -> 1 + if (hashSlotsOnP0.Count > 0) + { + context.logger?.LogInformation("Starting 0 -> 1 migration of {slots}", string.Join(", ", hashSlotsOnP0)); + using (var migrateToken = new CancellationTokenSource()) + { + migrateToken.CancelAfter(30_000); + + context.clusterTestUtils.MigrateSlots(primary0, primary1, hashSlotsOnP0); + context.clusterTestUtils.WaitForMigrationCleanup(Primary0Index, cancellationToken: migrateToken.Token); + context.clusterTestUtils.WaitForMigrationCleanup(Primary1Index, cancellationToken: migrateToken.Token); + } + + var nodePropSuccess = false; + var start = Stopwatch.GetTimestamp(); + while (Stopwatch.GetElapsedTime(start) < TimeSpan.FromSeconds(5)) + { + var curPrimary0Slots = context.clusterTestUtils.GetOwnedSlotsFromNode(primary0, context.logger); + var curPrimary1Slots = context.clusterTestUtils.GetOwnedSlotsFromNode(primary1, context.logger); + + var movedOffPrimary0 = !curPrimary0Slots.Any(h => hashSlotsOnP0.Contains(h)); + var movedOntoPrimary1 = hashSlotsOnP0.All(h => curPrimary1Slots.Contains(h)); + + if (movedOffPrimary0 && movedOntoPrimary1) + { + nodePropSuccess = true; + break; + } + } + + ClassicAssert.IsTrue(nodePropSuccess, "Node propagation after 0 -> 1 migration took too long"); + } + + // Move 1 -> 0 + if (hashSlotsOnP1.Count > 0) + { + context.logger?.LogInformation("Starting 1 -> 0 migration of {slots}", string.Join(", ", hashSlotsOnP1)); + using (var migrateToken = new CancellationTokenSource()) + { + migrateToken.CancelAfter(30_000); + + context.clusterTestUtils.MigrateSlots(primary1, primary0, hashSlotsOnP1); + context.clusterTestUtils.WaitForMigrationCleanup(Primary1Index, cancellationToken: migrateToken.Token); + context.clusterTestUtils.WaitForMigrationCleanup(Primary0Index, cancellationToken: migrateToken.Token); + } + + var nodePropSuccess = false; + var start = Stopwatch.GetTimestamp(); + while (Stopwatch.GetElapsedTime(start) < TimeSpan.FromSeconds(5)) + { + var curPrimary0Slots = context.clusterTestUtils.GetOwnedSlotsFromNode(primary0, context.logger); + var curPrimary1Slots = context.clusterTestUtils.GetOwnedSlotsFromNode(primary1, context.logger); + + var movedOffPrimary1 = !curPrimary1Slots.Any(h => hashSlotsOnP1.Contains(h)); + var movedOntoPrimary0 = hashSlotsOnP1.All(h => curPrimary0Slots.Contains(h)); + + if (movedOffPrimary1 && movedOntoPrimary0) + { + nodePropSuccess = true; + break; + } + } + + ClassicAssert.IsTrue(nodePropSuccess, "Node propagation after 1 -> 0 migration took too long"); + } + + // Remember for next iteration + var now = DateTime.UtcNow; + mostRecentMigration = now.Ticks; + migrationTimes.Add(now); + + // Flip around assignment for next pass + (hashSlotsOnP0, hashSlotsOnP1) = (hashSlotsOnP1, hashSlotsOnP0); + } + + return migrationTimes; + } + ); + + await Task.Delay(10_000); + + migrateCancel.Cancel(); + var migrationTimes = await migrateTask; + + ClassicAssert.IsTrue(migrationTimes.Count > 2, "Should have moved back and forth at least twice"); + + writeCancel.Cancel(); + await Task.WhenAll(writeTasks); + + readCancel.Cancel(); + var readResults = await Task.WhenAll(readTasks); + ClassicAssert.IsTrue(readResults.All(static r => r > 0), "Should have successful reads on all Vector Sets"); + + // Check that everything written survived all the migrations + { + var curPrimary0Slots = context.clusterTestUtils.GetOwnedSlotsFromNode(primary0, context.logger); + var 
curPrimary1Slots = context.clusterTestUtils.GetOwnedSlotsFromNode(primary1, context.logger); + + for (var i = 0; i < vectorSetKeys.Count; i++) + { + var (key, slot) = vectorSetKeys[i]; + + var isOnPrimary0 = curPrimary0Slots.Contains(slot); + var isOnPrimary1 = curPrimary1Slots.Contains(slot); + + ClassicAssert.IsTrue(isOnPrimary0 || isOnPrimary1, "Hash slot not found on either node"); + ClassicAssert.IsFalse(isOnPrimary0 && isOnPrimary1, "Hash slot found on both nodes"); + + var endpoint = isOnPrimary0 ? primary0 : primary1; + + foreach (var (elem, data, attr, _) in writeResults[i]) + { + var actualData = (string[])context.clusterTestUtils.Execute(endpoint, "VEMB", [key, elem]); + + for (var j = 0; j < data.Length; j++) + { + ClassicAssert.AreEqual(data[j], (byte)float.Parse(actualData[j])); + } + } + } + } + + } + catch (Exception exc) + { + var gossipFaultsAtEnd = CountGossipFaults(captureLogWriter); + + if (gossipFaultsAtTestStart != gossipFaultsAtEnd) + { + // The cluster broke in some way, so data loss is _expected_ + ClassicAssert.Inconclusive($"Gossip fault led to data loss, Vector Set migration is (probably) not to blame: {exc.Message}"); + } + + // Anything else, let it propagate up + throw; + } + + static int CountGossipFaults(CaptureLogWriter captureLogWriter) + { + var capturedLog = captureLogWriter.buffer.ToString(); + + // These kinds of errors happen from stressing migration independent of Vector Sets + // + // TODO: These ought to be fixed outside of Vector Set work + var faultRound = capturedLog.Split("^GOSSIP round faulted^").Length - 1; + var faultResponse = capturedLog.Split("^GOSSIP faulted processing response^").Length - 1; + var faultMergeMap = capturedLog.Split("ClusterConfig.MergeSlotMap(").Length - 1; + + return faultRound + faultResponse + faultMergeMap; + } + } + } +} \ No newline at end of file diff --git a/test/Garnet.test/DiskANNServiceTests.cs b/test/Garnet.test/DiskANNServiceTests.cs new file mode 100644 index 00000000000..ed347ba0f1a --- /dev/null +++ b/test/Garnet.test/DiskANNServiceTests.cs @@ -0,0 +1,338 @@ +using System; +using System.Buffers.Binary; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using Garnet.server; +using NUnit.Framework; +using NUnit.Framework.Legacy; +using StackExchange.Redis; + +namespace Garnet.test +{ + [TestFixture] + public class DiskANNServiceTests + { + private delegate void ReadCallbackDelegate(ulong context, uint numKeys, nint keysData, nuint keysLength, nint dataCallback, nint dataCallbackContext); + private delegate byte WriteCallbackDelegate(ulong context, nint keyData, nuint keyLength, nint writeData, nuint writeLength); + private delegate byte DeleteCallbackDelegate(ulong context, nint keyData, nuint keyLength); + private delegate byte ReadModifyWriteCallbackDelegate(ulong context, nint keyData, nuint keyLength, nuint writeLength, nint dataCallback, nint dataCallbackContext); + + private sealed class ContextAndKeyComparer : IEqualityComparer<(ulong Context, byte[] Data)> + { + public bool Equals((ulong Context, byte[] Data) x, (ulong Context, byte[] Data) y) + => x.Context == y.Context && x.Data.AsSpan().SequenceEqual(y.Data); + public int GetHashCode([DisallowNull] (ulong Context, byte[] Data) obj) + { + HashCode hash = default; + hash.Add(obj.Context); + hash.AddBytes(obj.Data); + + return hash.ToHashCode(); + } + } + + GarnetServer server; + + [SetUp] + 
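// Each test runs against a fresh low-memory server rooted in a per-test directory; TearDown below disposes it and deletes the directory. + 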
public void Setup() + { + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait: true); + server = TestUtils.CreateGarnetServer(TestUtils.MethodTestDir, lowMemory: true); + server.Start(); + } + + [TearDown] + public void TearDown() + { + server.Dispose(); + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); + } + + [Test] + public void VADD() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + var res1 = db.Execute("VADD", ["foo", "VALUES", "4", "1.0", "1.0", "1.0", "1.0", new byte[] { 1, 0, 0, 0 }, "EF", "128", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res1); + + var res2 = db.Execute("VADD", ["foo", "VALUES", "4", "2.0", "2.0", "2.0", "2.0", new byte[] { 2, 0, 0, 0 }, "EF", "128", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res2); + } + + [Test] + public void VSIM() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + var res1 = db.Execute("VADD", ["foo", "VALUES", "4", "1.0", "1.0", "1.0", "1.0", new byte[] { 1, 0, 0, 0 }, "EF", "128", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res1); + + var res2 = db.Execute("VADD", ["foo", "VALUES", "4", "2.0", "2.0", "2.0", "2.0", new byte[] { 2, 0, 0, 0 }, "EF", "128", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res2); + + var res3 = (byte[][])db.Execute("VSIM", ["foo", "VALUES", "4", "0.0", "0.0", "0.0", "0.0", "COUNT", "5", "EF", "128"]); + ClassicAssert.AreEqual(2, res3.Length); + ClassicAssert.IsTrue(res3.Any(static x => x.SequenceEqual(new byte[] { 1, 0, 0, 0 }))); + ClassicAssert.IsTrue(res3.Any(static x => x.SequenceEqual(new byte[] { 2, 0, 0, 0 }))); + + var res4 = (byte[][])db.Execute("VSIM", ["foo", "ELE", new byte[] { 1, 0, 0, 0 }, "COUNT", "5", "EF", "128"]); + ClassicAssert.AreEqual(2, res4.Length); + ClassicAssert.IsTrue(res4.Any(static x => x.SequenceEqual(new byte[] { 1, 0, 0, 0 }))); + ClassicAssert.IsTrue(res4.Any(static x => x.SequenceEqual(new byte[] { 2, 0, 0, 0 }))); + } + + [Test] + public void Recreate() + { + const ulong Context = 8; + + ConcurrentDictionary<(ulong Context, byte[] Key), byte[]> data = new(new ContextAndKeyComparer()); + + unsafe void ReadCallback( + ulong context, + uint numKeys, + nint keysData, + nuint keysLength, + nint dataCallback, + nint dataCallbackContext + ) + { + var keyDataSpan = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.AsRef((byte*)keysData), (int)keysLength); + + var remainingKeyDataSpan = keyDataSpan; + var dataCallbackDel = (delegate* unmanaged[Cdecl, SuppressGCTransition])dataCallback; + + for (var index = 0; index < numKeys; index++) + { + var keyLen = BinaryPrimitives.ReadInt32LittleEndian(remainingKeyDataSpan); + var keyData = remainingKeyDataSpan.Slice(sizeof(int), keyLen); + + remainingKeyDataSpan = remainingKeyDataSpan[(sizeof(int) + keyLen)..]; + + var lookup = (context, keyData.ToArray()); + if (data.TryGetValue(lookup, out var res)) + { + fixed (byte* resPtr = res) + { + dataCallbackDel(index, dataCallbackContext, (nint)resPtr, (nuint)res.Length); + } + } + } + } + + unsafe byte WriteCallback(ulong context, nint keyData, nuint keyLength, nint writeData, nuint writeLength) + { + var keyDataSpan = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.AsRef((byte*)keyData), (int)keyLength); + var writeDataSpan = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.AsRef((byte*)writeData), (int)writeLength); + + var lookup = (context, keyDataSpan.ToArray()); + + data[lookup] = writeDataSpan.ToArray(); + + return 1; + } + + unsafe byte DeleteCallback(ulong 
context, nint keyData, nuint keyLength) + { + var keyDataSpan = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.AsRef((byte*)keyData), (int)keyLength); + + var lookup = (context, keyDataSpan.ToArray()); + + if (data.TryRemove(lookup, out _)) + { + return 1; + } + + return 0; + } + + unsafe byte ReadModifyWriteCallback(ulong context, nint keyData, nuint keyLength, nuint writeLength, nint callback, nint callbackContext) + { + var keyDataSpan = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.AsRef((byte*)keyData), (int)keyLength); + + var lookup = (context, keyDataSpan.ToArray()); + + var callbackDel = (delegate* unmanaged[Cdecl, SuppressGCTransition])callback; + + _ = data.AddOrUpdate( + lookup, + key => + { + var ret = new byte[writeLength]; + fixed (byte* retPtr = ret) + { + callbackDel(callbackContext, (nint)retPtr, (nuint)ret.Length); + } + + return ret; + }, + (key, old) => + { + // Garnet guarantees no concurrent RMWs update the same value, but ConcurrentDictionary doesn't; so use a lock + lock (old) + { + fixed (byte* oldPtr = old) + { + callbackDel(callbackContext, (nint)oldPtr, (nuint)old.Length); + } + + return old; + } + } + ); + + return 1; + } + + ReadCallbackDelegate readDel = ReadCallback; + WriteCallbackDelegate writeDel = WriteCallback; + DeleteCallbackDelegate deleteDel = DeleteCallback; + ReadModifyWriteCallbackDelegate rmwDel = ReadModifyWriteCallback; + + var readFuncPtr = Marshal.GetFunctionPointerForDelegate(readDel); + var writeFuncPtr = Marshal.GetFunctionPointerForDelegate(writeDel); + var deleteFuncPtr = Marshal.GetFunctionPointerForDelegate(deleteDel); + var rmwFuncPtr = Marshal.GetFunctionPointerForDelegate(rmwDel); + + var rawIndex = NativeDiskANNMethods.create_index(Context, 75, 0, VectorQuantType.XPreQ8, 10, 10, readFuncPtr, writeFuncPtr, deleteFuncPtr, rmwFuncPtr); + + Span id = [0, 1, 2, 3]; + Span elem = Enumerable.Range(0, 75).Select(static x => (byte)x).ToArray(); + Span attr = []; + + // Insert + unsafe + { + var insertRes = NativeDiskANNMethods.insert(Context, rawIndex, (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(id)), (nuint)id.Length, VectorValueType.XB8, (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(elem)), (nuint)elem.Length, (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(attr)), (nuint)attr.Length); + ClassicAssert.AreEqual(1, insertRes); + } + + Span filter = []; + + // Search + unsafe + { + Span outputIds = stackalloc byte[1024]; + Span outputDistances = stackalloc float[64]; + + nint continuation = 0; + + var numRes = + NativeDiskANNMethods.search_vector( + Context, rawIndex, + VectorValueType.XB8, (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(elem)), (nuint)elem.Length, + 1f, outputDistances.Length, // SearchExplorationFactor must >= Count + (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(filter)), (nuint)filter.Length, + 0, + (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(outputIds)), (nuint)outputIds.Length, + (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(outputDistances)), (nuint)outputDistances.Length, + (nint)Unsafe.AsPointer(ref continuation) + ); + ClassicAssert.AreEqual(1, numRes); + + var firstResLen = BinaryPrimitives.ReadInt32LittleEndian(outputIds); + var firstRes = outputIds.Slice(sizeof(int), firstResLen); + ClassicAssert.IsTrue(firstRes.SequenceEqual(id)); + } + + // Drop does not clean up data, so use it to simulate a process stop and recreate + { + NativeDiskANNMethods.drop_index(Context, rawIndex); + + rawIndex = NativeDiskANNMethods.create_index(Context, 75, 0, VectorQuantType.XPreQ8, 
10, 10, readFuncPtr, writeFuncPtr, deleteFuncPtr, rmwFuncPtr); + } + + // Search value + unsafe + { + Span outputIds = stackalloc byte[1024]; + Span outputDistances = stackalloc float[64]; + + nint continuation = 0; + + var numRes = + NativeDiskANNMethods.search_vector( + Context, rawIndex, + VectorValueType.XB8, (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(elem)), (nuint)elem.Length, + 1f, outputDistances.Length, // SearchExplorationFactor must >= Count + (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(filter)), (nuint)filter.Length, + 0, + (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(outputIds)), (nuint)outputIds.Length, + (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(outputDistances)), (nuint)outputDistances.Length, + (nint)Unsafe.AsPointer(ref continuation) + ); + ClassicAssert.AreEqual(1, numRes); + + var firstResLen = BinaryPrimitives.ReadInt32LittleEndian(outputIds); + var firstRes = outputIds.Slice(sizeof(int), firstResLen); + ClassicAssert.IsTrue(firstRes.SequenceEqual(id)); + } + + // Search element + unsafe + { + Span outputIds = stackalloc byte[1024]; + Span outputDistances = stackalloc float[64]; + + nint continuation = 0; + + var numRes = + NativeDiskANNMethods.search_element( + Context, rawIndex, + (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(id)), (nuint)id.Length, + 1f, outputDistances.Length, // SearchExplorationFactor must >= Count + (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(filter)), (nuint)filter.Length, + 0, + (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(outputIds)), (nuint)outputIds.Length, + (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(outputDistances)), (nuint)outputDistances.Length, + (nint)Unsafe.AsPointer(ref continuation) + ); + ClassicAssert.AreEqual(1, numRes); + + var firstResLen = BinaryPrimitives.ReadInt32LittleEndian(outputIds); + var firstRes = outputIds.Slice(sizeof(int), firstResLen); + ClassicAssert.IsTrue(firstRes.SequenceEqual(id)); + } + + // Remove + unsafe + { + var numRes = + NativeDiskANNMethods.remove( + Context, rawIndex, + (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(id)), (nuint)id.Length + ); + ClassicAssert.AreEqual(1, numRes); + } + + // Insert + unsafe + { + Span id2 = [4, 5, 6, 7]; + Span elem2 = Enumerable.Range(0, 75).Select(static x => (byte)(x * 2)).ToArray(); + ReadOnlySpan attr2 = "{\"foo\": \"bar\"}"u8; + + var insertRes = NativeDiskANNMethods.insert( + Context, rawIndex, + (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(id2)), (nuint)id2.Length, + VectorValueType.XB8, (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(elem2)), (nuint)elem2.Length, + (nint)Unsafe.AsPointer(ref MemoryMarshal.GetReference(attr2)), (nuint)attr2.Length + ); + ClassicAssert.AreEqual(1, insertRes); + } + + GC.KeepAlive(deleteDel); + GC.KeepAlive(writeDel); + GC.KeepAlive(readDel); + GC.KeepAlive(rmwDel); + } + } +} \ No newline at end of file diff --git a/test/Garnet.test/GarnetServerConfigTests.cs b/test/Garnet.test/GarnetServerConfigTests.cs index fb85e063925..f7fd3305bd5 100644 --- a/test/Garnet.test/GarnetServerConfigTests.cs +++ b/test/Garnet.test/GarnetServerConfigTests.cs @@ -938,6 +938,63 @@ public void ClusterReplicaResumeWithData() } } + [Test] + public void EnableVectorSetPreview() + { + // Command line args + { + // Default accepted + { + var args = Array.Empty(); + var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out _, out _, out _); + ClassicAssert.IsTrue(parseSuccessful); + 
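// With no arguments supplied, the preview flag must default to off + 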
ClassicAssert.IsFalse(options.EnableVectorSetPreview); + } + + // Switch is accepted + { + var args = new[] { "--enable-vector-set-preview" }; + var parseSuccessful = ServerSettingsManager.TryParseCommandLineArguments(args, out var options, out _, out _, out _); + ClassicAssert.IsTrue(parseSuccessful); + ClassicAssert.IsTrue(options.EnableVectorSetPreview); + } + } + + // JSON args + { + // Default accepted + { + const string JSON = @"{ }"; + var parseSuccessful = TryParseGarnetConfOptions(JSON, out var options, out var invalidOptions, out var exitGracefully); + ClassicAssert.IsTrue(parseSuccessful); + ClassicAssert.IsFalse(options.EnableVectorSetPreview); + } + + // False is accepted + { + const string JSON = @"{ ""EnableVectorSetPreview"": false }"; + var parseSuccessful = TryParseGarnetConfOptions(JSON, out var options, out var invalidOptions, out var exitGracefully); + ClassicAssert.IsTrue(parseSuccessful); + ClassicAssert.IsFalse(options.EnableVectorSetPreview); + } + + // True is accepted + { + const string JSON = @"{ ""EnableVectorSetPreview"": true }"; + var parseSuccessful = TryParseGarnetConfOptions(JSON, out var options, out var invalidOptions, out var exitGracefully); + ClassicAssert.IsTrue(parseSuccessful); + ClassicAssert.IsTrue(options.EnableVectorSetPreview); + } + + // Invalid rejected + { + const string JSON = @"{ ""EnableVectorSetPreview"": ""foo"" }"; + var parseSuccessful = TryParseGarnetConfOptions(JSON, out var options, out var invalidOptions, out var exitGracefully); + ClassicAssert.IsFalse(parseSuccessful); + } + } + } + /// /// Import a garnet.conf file with the given contents /// diff --git a/test/Garnet.test/ReadOptimizedLockTests.cs b/test/Garnet.test/ReadOptimizedLockTests.cs new file mode 100644 index 00000000000..6451151938d --- /dev/null +++ b/test/Garnet.test/ReadOptimizedLockTests.cs @@ -0,0 +1,282 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
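+// Exercises ReadOptimizedLock: shared/exclusive locks keyed by hash, cache-line-spaced index striping, and promotion under contention.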
+ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using Garnet.common; +using NUnit.Framework; +using NUnit.Framework.Legacy; + +namespace Garnet.test +{ + public class ReadOptimizedLockTests + { + [TestCase(123)] + [TestCase(0)] + [TestCase(1)] + [TestCase(-1)] + [TestCase(int.MaxValue)] + [TestCase(int.MinValue)] + public void BasicLocks(int hash) + { + var lockContext = new ReadOptimizedLock(16); + + var gotShared0 = lockContext.TryAcquireSharedLock(hash, out var sharedToken0); + ClassicAssert.IsTrue(gotShared0); + + var gotShared1 = lockContext.TryAcquireSharedLock(hash, out var sharedToken1); + ClassicAssert.IsTrue(gotShared1); + + var gotExclusive = lockContext.TryAcquireExclusiveLock(hash, out _); + ClassicAssert.IsFalse(gotExclusive); + + lockContext.ReleaseSharedLock(sharedToken0); + lockContext.ReleaseSharedLock(sharedToken1); + + var gotExclusiveAgain = lockContext.TryAcquireExclusiveLock(hash, out var exclusiveToken); + ClassicAssert.IsTrue(gotExclusiveAgain); + + var gotSharedAgain = lockContext.TryAcquireSharedLock(hash, out _); + ClassicAssert.IsFalse(gotSharedAgain); + + lockContext.ReleaseExclusiveLock(exclusiveToken); + } + + [Test] + public void IndexCalculations() + { + const int Iters = 10_000; + + var lockContext = new ReadOptimizedLock(16); + + var rand = new Random(2025_11_17_00); + + var offsets = new HashSet(); + + for (var i = 0; i < Iters; i++) + { + offsets.Clear(); + + // Bunch of random hashes, including negative ones, to prove reasonable calculations + var hash = (int)rand.NextInt64(); + + var hintBase = (int)rand.NextInt64(); + + for (var j = 0; j < Environment.ProcessorCount; j++) + { + var offset = lockContext.CalculateIndex(hash, hintBase + j); + ClassicAssert.True(offsets.Add(offset)); + } + + foreach (var offset in offsets) + { + var tooClose = offsets.Except([offset]).Where(x => Math.Abs(x - offset) < ReadOptimizedLock.CacheLineSizeBytes / sizeof(int)); + ClassicAssert.IsEmpty(tooClose); + } + } + } + + [TestCase(1)] + [TestCase(4)] + [TestCase(16)] + [TestCase(64)] + [TestCase(128)] + public void Threaded(int hashCount) + { + // Guard some number of distinct value "slots" (defined by hashes) + // + // Runs threads which (randomly) either read values, write values, or read (then promote) and write. + // + // Reads check for correctness. + // Writes are done "plain" with no other locking or coherency enforcement. 
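+            // If mutual exclusion ever fails, some reader (or the final check) observes a "torn" slot
+            // whose longs disagree, and the equality asserts below fail.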
+ + const int Iters = 100_000; + const int LongsPerSlot = 4; + + var lockContext = new ReadOptimizedLock(Math.Min(Math.Max(hashCount / 2, 1), Environment.ProcessorCount)); + + var threads = new Thread[Math.Max(Environment.ProcessorCount, 4)]; + + using var threadStart = new SemaphoreSlim(0, threads.Length); + + var globalRandom = new Random(2025_11_17_01); + + var hashes = new int[hashCount]; + for (var i = 0; i < hashes.Length; i++) + { + var nextHash = (int)globalRandom.NextInt64(); + if (hashes.AsSpan()[..i].Contains(nextHash)) + { + i--; + continue; + } + hashes[i] = nextHash; + } + + var values = new long[hashes.Length][]; + for (var i = 0; i < values.Length; i++) + { + values[i] = new long[LongsPerSlot]; + } + + // Spin up a bunch of mutators + for (var i = 0; i < threads.Length; i++) + { + var threadRandom = new Random(2025_11_17_01 + ((i + 1) * 100_000)); + + threads[i] = + new( + () => + { + threadStart.Wait(); + + for (var j = 0; j < Iters; j++) + { + var hashIx = threadRandom.Next(hashes.Length); + var hash = hashes[hashIx]; + + switch (threadRandom.Next(5)) + { + // Try: Read and verify + case 0: + { + if (lockContext.TryAcquireSharedLock(hash, out var sharedLockToken)) + { + var sub = values[hashIx]; + for (var k = 1; k < sub.Length; k++) + { + ClassicAssert.AreEqual(sub[0], sub[k]); + } + + lockContext.ReleaseSharedLock(sharedLockToken); + } + else + { + j--; + } + } + break; + + // Try: Lock, modify + case 1: + { + if (lockContext.TryAcquireExclusiveLock(hash, out var exclusiveLockToken)) + { + var sub = values[hashIx]; + var newValue = threadRandom.NextInt64(); + for (var k = 0; k < sub.Length; k++) + { + sub[k] = newValue; + } + + lockContext.ReleaseExclusiveLock(exclusiveLockToken); + } + else + { + j--; + } + } + break; + + // Demand: Read and verify + case 2: + { + lockContext.AcquireSharedLock(hash, out var sharedLockToken); + var sub = values[hashIx]; + for (var k = 1; k < sub.Length; k++) + { + ClassicAssert.AreEqual(sub[0], sub[k]); + } + + lockContext.ReleaseSharedLock(sharedLockToken); + } + + break; + + // Demand: Lock, modify + case 3: + { + lockContext.AcquireExclusiveLock(hash, out var exclusiveLockToken); + var sub = values[hashIx]; + var newValue = threadRandom.NextInt64(); + for (var k = 0; k < sub.Length; k++) + { + sub[k] = newValue; + } + + lockContext.ReleaseExclusiveLock(exclusiveLockToken); + } + + break; + + // Try: Read, verify, promote, modify + case 4: + { + if (lockContext.TryAcquireSharedLock(hash, out var sharedLockToken)) + { + var sub = values[hashIx]; + for (var k = 1; k < sub.Length; k++) + { + ClassicAssert.AreEqual(sub[0], sub[k]); + } + + if (lockContext.TryPromoteSharedLock(hash, sharedLockToken, out var exclusiveLockToken)) + { + var newValue = threadRandom.NextInt64(); + for (var k = 0; k < sub.Length; k++) + { + sub[k] = newValue; + } + + lockContext.ReleaseExclusiveLock(exclusiveLockToken); + } + else + { + lockContext.ReleaseSharedLock(sharedLockToken); + + j--; + } + } + else + { + j--; + } + } + + break; + + // There is no Demand version of Promote because that is not safe in general + + default: throw new InvalidOperationException($"Unexpected op"); + } + } + } + ) + { + Name = $"{nameof(Threaded)} #{i}" + }; + threads[i].Start(); + } + + // Let threads run + _ = threadStart.Release(threads.Length); + + // Wait for threads to finish + foreach (var thread in threads) + { + thread.Join(); + } + + // Validate correctness of final state + foreach (var vals in values) + { + for (var k = 1; k < vals.Length; k++) + { + 
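+                    // Whichever exclusive writer touched a slot last must have written it in full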
ClassicAssert.AreEqual(vals[0], vals[k]);
+                }
+            }
+        }
+    }
+}
\ No newline at end of file
diff --git a/test/Garnet.test/Resp/ACL/RespCommandTests.cs b/test/Garnet.test/Resp/ACL/RespCommandTests.cs
index 3b02ac68853..ac5b1b79190 100644
--- a/test/Garnet.test/Resp/ACL/RespCommandTests.cs
+++ b/test/Garnet.test/Resp/ACL/RespCommandTests.cs
@@ -2033,6 +2033,35 @@ static async Task DoClusterReplicateAsync(GarnetClient client)
             }
         }
 
+        [Test]
+        public async Task ClusterReserveACLsAsync()
+        {
+            // All cluster commands "succeed" here by throwing, because clustering is disabled
+
+            await CheckCommandsAsync(
+                "CLUSTER RESERVE",
+                [DoClusterReserveAsync]
+            );
+
+            static async Task DoClusterReserveAsync(GarnetClient client)
+            {
+                try
+                {
+                    await client.ExecuteForStringResultAsync("CLUSTER", ["RESERVE", "VECTOR_SET_CONTEXTS", "16"]);
+                    Assert.Fail("Shouldn't be reachable, cluster isn't enabled");
+                }
+                catch (Exception e)
+                {
+                    if (e.Message == "ERR This instance has cluster support disabled")
+                    {
+                        return;
+                    }
+
+                    throw;
+                }
+            }
+        }
+
         [Test]
         public async Task ClusterResetACLsAsync()
         {
@@ -7484,6 +7513,209 @@ static async Task DoUnwatchAsync(GarnetClient client)
             }
         }
 
+        [Test]
+        public async Task VAddACLsAsync()
+        {
+            await CheckCommandsAsync(
+                "VADD",
+                [DoVAddAsync]
+            );
+
+            static async Task DoVAddAsync(GarnetClient client)
+            {
+                var elem = Encoding.ASCII.GetString("\x0\x1\x2\x3"u8);
+
+                long val = await client.ExecuteForLongResultAsync("VADD", ["foo", "REDUCE", "50", "VALUES", "4", "1.0", "2.0", "3.0", "4.0", elem, "CAS", "Q8", "EF", "16", "SETATTR", "{ 'hello': 'world' }", "M", "32"]);
+                ClassicAssert.AreEqual(1, val);
+            }
+        }
+
+        [Test]
+        public async Task VCardACLsAsync()
+        {
+            await CheckCommandsAsync(
+                "VCARD",
+                [DoVCardAsync]
+            );
+
+            static async Task DoVCardAsync(GarnetClient client)
+            {
+                // TODO: this is a placeholder implementation
+
+                string val = await client.ExecuteForStringResultAsync("VCARD", ["foo"]);
+                ClassicAssert.AreEqual("OK", val);
+            }
+        }
+
+        [Test]
+        public async Task VDimACLsAsync()
+        {
+            await CheckCommandsAsync(
+                "VDIM",
+                [DoVDimAsync]
+            );
+
+            static async Task DoVDimAsync(GarnetClient client)
+            {
+                try
+                {
+                    _ = await client.ExecuteForStringResultAsync("VDIM", ["foo"]);
+                    ClassicAssert.Fail("Shouldn't be reachable");
+                }
+                catch (Exception e) when (e.Message.Equals("ERR Key not found"))
+                {
+                    // Expected
+                }
+            }
+        }
+
+        [Test]
+        public async Task VEmbACLsAsync()
+        {
+            await CheckCommandsAsync(
+                "VEMB",
+                [DoVEmbAsync]
+            );
+
+            static async Task DoVEmbAsync(GarnetClient client)
+            {
+                string[] val = await client.ExecuteForStringArrayResultAsync("VEMB", ["foo", "bar"]);
+                ClassicAssert.AreEqual(0, val.Length);
+            }
+        }
+
+        [Test]
+        public async Task VGetAttrACLsAsync()
+        {
+            await CheckCommandsAsync(
+                "VGETATTR",
+                [DoVGetAttrAsync]
+            );
+
+            static async Task DoVGetAttrAsync(GarnetClient client)
+            {
+                // TODO: this is a placeholder implementation
+
+                string val = await client.ExecuteForStringResultAsync("VGETATTR", ["foo"]);
+                ClassicAssert.AreEqual("OK", val);
+            }
+        }
+
+        [Test]
+        public async Task VInfoACLsAsync()
+        {
+            await CheckCommandsAsync(
+                "VINFO",
+                [DoVInfoAsync]
+            );
+
+            static async Task DoVInfoAsync(GarnetClient client)
+            {
+                // TODO: this is a placeholder implementation
+
+                string val = await client.ExecuteForStringResultAsync("VINFO", ["foo"]);
+                ClassicAssert.AreEqual("OK", val);
+            }
+        }
+
+        [Test]
+        public async Task VIsMemberACLsAsync()
+        {
+            await CheckCommandsAsync(
+                "VISMEMBER",
+                [DoVIsMemberAsync]
+            );
+
+            static async Task
DoVIsMemberAsync(GarnetClient client) + { + // TODO: this is a placeholder implementation + + string val = await client.ExecuteForStringResultAsync("VISMEMBER", ["foo"]); + ClassicAssert.AreEqual("OK", val); + } + } + + [Test] + public async Task VLinksACLsAsync() + { + await CheckCommandsAsync( + "VLINKS", + [DoVLinksAsync] + ); + + static async Task DoVLinksAsync(GarnetClient client) + { + // TODO: this is a placeholder implementation + + string val = await client.ExecuteForStringResultAsync("VLINKS", ["foo"]); + ClassicAssert.AreEqual("OK", val); + } + } + + [Test] + public async Task VRandMemberACLsAsync() + { + await CheckCommandsAsync( + "VRANDMEMBER", + [DoVRandMemberAsync] + ); + + static async Task DoVRandMemberAsync(GarnetClient client) + { + // TODO: this is a placeholder implementation + + string val = await client.ExecuteForStringResultAsync("VRANDMEMBER", ["foo"]); + ClassicAssert.AreEqual("OK", val); + } + } + + [Test] + public async Task VRemACLsAsync() + { + await CheckCommandsAsync( + "VREM", + [DoVRemAsync] + ); + + static async Task DoVRemAsync(GarnetClient client) + { + long val = await client.ExecuteForLongResultAsync("VREM", ["foo", Encoding.UTF8.GetString("\0\0\0\0"u8)]); + ClassicAssert.AreEqual(0, val); + } + } + + [Test] + public async Task VSetAttrACLsAsync() + { + await CheckCommandsAsync( + "VSETATTR", + [DoVSetAttrAsync] + ); + + static async Task DoVSetAttrAsync(GarnetClient client) + { + // TODO: this is a placeholder implementation + + string val = await client.ExecuteForStringResultAsync("VSETATTR", ["foo"]); + ClassicAssert.AreEqual("OK", val); + } + } + + [Test] + public async Task VSimACLsAsync() + { + await CheckCommandsAsync( + "VSIM", + [DoVSimAsync] + ); + + static async Task DoVSimAsync(GarnetClient client) + { + string[] val = await client.ExecuteForStringArrayResultAsync("VSIM", ["foo", "ELE", "bar"]); + ClassicAssert.AreEqual(0, val.Length); + } + } + /// /// Take a command (or subcommand, with a space) and check that adding and removing /// command, subcommand, and categories ACLs behaves as expected. diff --git a/test/Garnet.test/RespCustomCommandTests.cs b/test/Garnet.test/RespCustomCommandTests.cs index a3b3d372be2..0f2062b2b1d 100644 --- a/test/Garnet.test/RespCustomCommandTests.cs +++ b/test/Garnet.test/RespCustomCommandTests.cs @@ -210,8 +210,7 @@ public override unsafe void Main(TGarnetApi garnetApi, ref CustomPro ArgSlice valForKey1 = new ArgSlice(valuePtr, valueToMessWith.Count); input.parseState.InitializeWithArgument(valForKey1); // since we are setting with retain to etag, this change should be reflected in an etag update - SpanByte sameKeyToUse = key.SpanByte; - garnetApi.SET_Conditional(ref sameKeyToUse, ref input); + garnetApi.SET_Conditional(key, ref input); } diff --git a/test/Garnet.test/RespSortedSetTests.cs b/test/Garnet.test/RespSortedSetTests.cs index 21d094df7ed..3938910f574 100644 --- a/test/Garnet.test/RespSortedSetTests.cs +++ b/test/Garnet.test/RespSortedSetTests.cs @@ -24,7 +24,10 @@ namespace Garnet.test SpanByteAllocator>>, BasicContext>, - GenericAllocator>>>>; + GenericAllocator>>>, + BasicContext, + SpanByteAllocator>>>; [TestFixture] public class RespSortedSetTests diff --git a/test/Garnet.test/RespVectorSetTests.cs b/test/Garnet.test/RespVectorSetTests.cs new file mode 100644 index 00000000000..1659f0c98ef --- /dev/null +++ b/test/Garnet.test/RespVectorSetTests.cs @@ -0,0 +1,1284 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
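+
+// End-to-end RESP tests for the Vector Set preview commands. The VADD shape exercised throughout:
+//   VADD key [REDUCE dim] (VALUES num v1..vN | FP32 blob | XB8 blob) element [CAS] [NOQUANT | Q8 | XPREQ8] [EF n] [SETATTR attr] [M n]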
+ +using System; +using System.Buffers; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; +using Garnet.server; +using NUnit.Framework; +using NUnit.Framework.Legacy; +using StackExchange.Redis; +using Tsavorite.core; + +namespace Garnet.test +{ + [TestFixture] + public class RespVectorSetTests + { + GarnetServer server; + + [SetUp] + public void Setup() + { + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait: true); + server = TestUtils.CreateGarnetServer(TestUtils.MethodTestDir, enableAOF: true); + + server.Start(); + } + + [TearDown] + public void TearDown() + { + server.Dispose(); + TestUtils.DeleteDirectory(TestUtils.MethodTestDir); + } + + [Test] + public void DisabledWithFeatureFlag() + { + // Restart with Vector Sets disabled + TearDown(); + + TestUtils.DeleteDirectory(TestUtils.MethodTestDir, wait: true); + server = TestUtils.CreateGarnetServer(TestUtils.MethodTestDir, enableAOF: true, enableVectorSetPreview: false); + + server.Start(); + + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + ReadOnlySpan vectorSetCommands = [RespCommand.VADD, RespCommand.VCARD, RespCommand.VDIM, RespCommand.VEMB, RespCommand.VGETATTR, RespCommand.VINFO, RespCommand.VISMEMBER, RespCommand.VLINKS, RespCommand.VRANDMEMBER, RespCommand.VREM, RespCommand.VSETATTR, RespCommand.VSIM]; + foreach (var cmd in vectorSetCommands) + { + // Should all fault before any validation + var exc = ClassicAssert.Throws(() => db.Execute(cmd.ToString())); + ClassicAssert.AreEqual("ERR Vector Set (preview) commands are not enabled", exc.Message); + } + + } + + [Test] + public void VADD() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + // VALUES + var res1 = db.Execute("VADD", ["foo", "REDUCE", "50", "VALUES", "75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 0 }, "CAS", "Q8", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res1); + + var res2 = db.Execute("VADD", ["foo", "REDUCE", "50", "VALUES", "75", "100.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 1, 0, 0, 0 }, "CAS", "Q8", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res2); + + var float3 = new float[75]; + float3[0] = 5f; + for (var i = 1; i < float3.Length; i++) + { + float3[i] = float3[i - 1] + 1; + } + + // FP32 + var res3 = db.Execute("VADD", ["foo", "REDUCE", "50", "FP32", MemoryMarshal.Cast(float3).ToArray(), new byte[] { 2, 0, 0, 0 }, "CAS", "Q8", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, 
(int)res3); + + var byte4 = new byte[75]; + byte4[0] = 9; + for (var i = 1; i < byte4.Length; i++) + { + byte4[i] = (byte)(byte4[i - 1] + 1); + } + + // XB8 + var res4 = db.Execute("VADD", ["foo", "REDUCE", "50", "XB8", byte4, new byte[] { 3, 0, 0, 0 }, "CAS", "Q8", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res4); + + // TODO: exact duplicates - what does Redis do? + + // Add without specifying reductions after first vector + var res5 = db.Execute("VADD", ["fizz", "REDUCE", "50", "VALUES", "75", "150.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 0 }, "CAS", "Q8", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res5); + + var exc1 = ClassicAssert.Throws(() => db.Execute("VADD", ["fizz", "VALUES", "4", "5.0", "6.0", "7.0", "8.0", new byte[] { 0, 0, 0, 1 }, "CAS", "Q8", "EF", "16", "M", "32"])); + ClassicAssert.AreEqual("ERR Vector dimension mismatch - got 4 but set has 75", exc1.Message); + + // Add without specifying quantization after first vector + var res6 = db.Execute("VADD", ["fizz", "REDUCE", "50", "VALUES", "75", "160.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 2 }, "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res6); + + // Add without specifying EF after first vector + var res7 = db.Execute("VADD", ["fizz", "REDUCE", "50", "VALUES", "75", "170.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 3 }, "CAS", "Q8", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res7); + + // Add without specifying M after first vector + var exc2 = ClassicAssert.Throws(() => db.Execute("VADD", ["fizz", "REDUCE", "50", "VALUES", "75", "180.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", 
"1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 4 }, "CAS", "Q8", "EF", "16"])); + ClassicAssert.AreEqual("ERR asked M value mismatch with existing vector set", exc2.Message); + + // Mismatch vector size for projection + var exc3 = ClassicAssert.Throws(() => db.Execute("VADD", ["fizz", "REDUCE", "50", "VALUES", "5", "1.0", "2.0", "3.0", "4.0", "5.0", new byte[] { 0, 0, 0, 0 }, "CAS", "Q8", "EF", "16", "M", "32"])); + ClassicAssert.AreEqual("ERR Vector dimension mismatch - got 5 but set has 75", exc3.Message); + } + + [Test] + public void VADDXPREQB8() + { + // Extra validation is required for this extension quantifier + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + // REDUCE not allowed + var exc1 = ClassicAssert.Throws(() => db.Execute("VADD", ["fizz", "REDUCE", "2", "VALUES", "4", "1.0", "2.0", "3.0", "4.0", new byte[] { 0, 0, 0, 0 }, "XPREQ8"])); + ClassicAssert.AreEqual("ERR asked quantization mismatch with existing vector set", exc1.Message); + + // Create a vector set + var res1 = db.Execute("VADD", ["fizz", "VALUES", "75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 0 }, "XPREQ8"]); + ClassicAssert.AreEqual(1, (int)res1); + + // Element name too short + var exc2 = ClassicAssert.Throws(() => db.Execute("VADD", ["fizz", "VALUES", "4", "1.0", "2.0", "3.0", "4.0", new byte[] { 0 }, "XPREQ8"])); + ClassicAssert.AreEqual("ERR Vector dimension mismatch - got 4 but set has 75", exc2.Message); + + // Element name too long + var exc3 = ClassicAssert.Throws(() => db.Execute("VADD", ["fizz", "VALUES", "75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 1, 2, 3, 4, }, "XPREQ8"])); + ClassicAssert.AreEqual("ERR XPREQ8 requires 4-byte element ids", exc3.Message); + } + + [Test] + public void VADDErrors() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + var vectorSetKey = $"{nameof(VADDErrors)}_{Guid.NewGuid()}"; + + // Bad arity + var exc1 = ClassicAssert.Throws(() => db.Execute("VADD")); + ClassicAssert.AreEqual("ERR wrong number of arguments for 'VADD' command", exc1.Message); + var exc2 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey])); + ClassicAssert.AreEqual("ERR wrong number of arguments for 'VADD' command", exc2.Message); + var exc3 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "FP32"])); + ClassicAssert.AreEqual("ERR wrong number of arguments for 'VADD' command", exc3.Message); + var exc4 = ClassicAssert.Throws(() => 
db.Execute("VADD", [vectorSetKey, "VALUES"])); + ClassicAssert.AreEqual("ERR wrong number of arguments for 'VADD' command", exc4.Message); + var exc5 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1"])); + ClassicAssert.AreEqual("ERR wrong number of arguments for 'VADD' command", exc5.Message); + var exc6 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1", "1.0"])); + ClassicAssert.AreEqual("ERR wrong number of arguments for 'VADD' command", exc6.Message); + + // Reduce after vector + var exc7 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "2", "1.0", "2.0", "bar", "REDUCE", "1"])); + ClassicAssert.AreEqual("ERR invalid option after element", exc7.Message); + + // Duplicate flags + // TODO: Redis doesn't error on these which seems... wrong, confirm with them + //var exc8 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1", "2.0", "bar", "CAS", "CAS"])); + //var exc9 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1", "2.0", "bar", "NOQUANT", "Q8"])); + //var exc10 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1", "2.0", "bar", "EF", "1", "EF", "1"])); + //var exc11 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1", "2.0", "bar", "SETATTR", "abc", "SETATTR", "abc"])); + //var exc12 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1", "2.0", "bar", "M", "5", "M", "5"])); + + // M out of range (Redis imposes M >= 4 and m <= 4096 + var exc13 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1", "2.0", "bar", "M", "1"])); + ClassicAssert.AreEqual("ERR invalid M", exc13.Message); + var exc14 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1", "2.0", "bar", "M", "10000"])); + ClassicAssert.AreEqual("ERR invalid M", exc14.Message); + + // Missing/bad option value + var exc20 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1", "2.0", "bar", "EF"])); + ClassicAssert.AreEqual("ERR invalid option after element", exc20.Message); + var exc21 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1", "2.0", "bar", "EF", "0"])); + ClassicAssert.AreEqual("ERR invalid EF", exc21.Message); + var exc22 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1", "2.0", "bar", "SETATTR"])); + ClassicAssert.AreEqual("ERR invalid option after element", exc22.Message); + var exc23 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1", "2.0", "bar", "M"])); + ClassicAssert.AreEqual("ERR invalid option after element", exc23.Message); + var exc24 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "2", "2.0", "bar"])); + ClassicAssert.AreEqual("ERR invalid vector specification", exc24.Message); + var exc25 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "0", "bar"])); + ClassicAssert.AreEqual("ERR invalid vector specification", exc25.Message); + var exc26 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1", "fizz", "bar"])); + ClassicAssert.AreEqual("ERR invalid vector specification", exc26.Message); + + // Unknown option + var exc27 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "1", "2.0", "bar", "FOO"])); + ClassicAssert.AreEqual("ERR invalid option after element", exc27.Message); + + // Malformed FP32 + var binary = new float[] { 1, 2, 3 }; + var blob = 
MemoryMarshal.Cast(binary)[..^1].ToArray(); + var exc15 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "FP32", blob, "bar"])); + ClassicAssert.AreEqual("ERR invalid vector specification", exc15.Message); + + // Mismatch after creating a vector set + _ = db.KeyDelete(vectorSetKey); + + _ = db.Execute("VADD", [vectorSetKey, "VALUES", "75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 1, 0 }, "NOQUANT", "EF", "6", "M", "10"]); + + var exc16 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "2", "1.0", "2.0", "fizz", "NOQUANT", "EF", "6", "M", "10"])); + ClassicAssert.AreEqual("ERR Vector dimension mismatch - got 2 but set has 75", exc16.Message); + var exc17 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "fizz", "Q8", "EF", "6", "M", "10"])); + ClassicAssert.AreEqual("ERR asked quantization mismatch with existing vector set", exc17.Message); + var exc18 = ClassicAssert.Throws(() => db.Execute("VADD", [vectorSetKey, "VALUES", "75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "fizz", "NOQUANT", "EF", "12", "M", "20"])); + ClassicAssert.AreEqual("ERR asked M value mismatch with existing vector set", exc18.Message); + + // TODO: Redis doesn't appear to validate attributes... 
so that's weird + } + + [Test] + public void VEMB() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + var res1 = db.Execute("VADD", ["foo", "REDUCE", "50", "VALUES", "75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 0 }, "CAS", "Q8", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res1); + + var res2 = (string[])db.Execute("VEMB", ["foo", new byte[] { 0, 0, 0, 0 }]); + ClassicAssert.AreEqual(75, res2.Length); + for (var i = 0; i < 75; i += 4) + { + ClassicAssert.AreEqual(float.Parse("1.0"), float.Parse(res2[i + 0])); + if (i + 1 < res2.Length) + { + ClassicAssert.AreEqual(float.Parse("2.0"), float.Parse(res2[i + 1])); + } + + if (i + 2 < res2.Length) + { + ClassicAssert.AreEqual(float.Parse("3.0"), float.Parse(res2[i + 2])); + } + + if (i + 3 < res2.Length) + { + ClassicAssert.AreEqual(float.Parse("4.0"), float.Parse(res2[i + 3])); + } + } + + var res3 = (string[])db.Execute("VEMB", ["foo", new byte[] { 0, 0, 0, 1 }]); + ClassicAssert.AreEqual(0, res3.Length); + } + + [Test] + public void VectorSetOpacity() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + var res1 = db.Execute("VADD", ["foo", "REDUCE", "50", "VALUES", "75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 0 }, "CAS", "Q8", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res1); + + var res2 = ClassicAssert.Throws(() => db.StringGet("foo")); + ClassicAssert.True(res2.Message.Contains("WRONGTYPE")); + } + + [Test] + public void VectorElementOpacity() + { + // Check that we can't touch an element with GET despite it also being in the main store + + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + var res1 = db.Execute("VADD", ["foo", "REDUCE", "50", "VALUES", "75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 0 }, "CAS", "Q8", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res1); + + var res2 = (string)db.StringGet(new byte[] { 0, 0, 0, 0 
}); + ClassicAssert.IsNull(res2); + + var res3 = db.KeyDelete(new byte[] { 0, 0, 0, 0 }); + ClassicAssert.IsFalse(res3); + + var res4 = db.StringSet(new byte[] { 0, 0, 0, 0 }, "def", when: When.NotExists); + ClassicAssert.IsTrue(res4); + + Span buffer = stackalloc byte[128]; + + // Check we haven't messed up the element + var res7 = (string[])db.Execute("VEMB", ["foo", new byte[] { 0, 0, 0, 0 }]); + ClassicAssert.AreEqual(75, res7.Length); + for (var i = 0; i < res7.Length; i++) + { + var expected = + (i % 4) switch + { + 0 => float.Parse("1.0"), + 1 => float.Parse("2.0"), + 2 => float.Parse("3.0"), + 3 => float.Parse("4.0"), + _ => throw new InvalidOperationException(), + }; + + ClassicAssert.AreEqual(expected, float.Parse(res7[i])); + } + } + + [Test] + public void VSIM() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + var res1 = db.Execute("VADD", ["foo", "REDUCE", "50", "VALUES", "75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 0 }, "CAS", "Q8", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res1); + + var res2 = db.Execute("VADD", ["foo", "REDUCE", "50", "VALUES", "75", "100.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 1 }, "CAS", "Q8", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res2); + + var res3 = (byte[][])db.Execute("VSIM", ["foo", "VALUES", "75", "110.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "COUNT", "5", "EPSILON", "1.0", "EF", "40"]); + ClassicAssert.AreEqual(2, res3.Length); + ClassicAssert.IsTrue(res3.Any(static x => x.SequenceEqual(new byte[] { 0, 0, 0, 0 }))); + ClassicAssert.IsTrue(res3.Any(static x => x.SequenceEqual(new byte[] { 0, 0, 0, 1 }))); + + var res4 = (byte[][])db.Execute("VSIM", ["foo", "ELE", new byte[] { 0, 0, 0, 0 }, "COUNT", "5", "EPSILON", "1.0", "EF", "40"]); + ClassicAssert.AreEqual(2, res4.Length); + ClassicAssert.IsTrue(res4.Any(static x => x.SequenceEqual(new byte[] { 0, 0, 0, 0 }))); + ClassicAssert.IsTrue(res4.Any(static x => x.SequenceEqual(new byte[] { 0, 0, 0, 1 }))); + + // FP32 + var float5 = new 
float[75]; + float5[0] = 3; + for (var i = 1; i < float5.Length; i++) + { + float5[i] = float5[i - 1] + 0.1f; + } + var res5 = (byte[][])db.Execute("VSIM", ["foo", "FP32", MemoryMarshal.Cast(float5).ToArray(), "COUNT", "5", "EPSILON", "1.0", "EF", "40"]); + ClassicAssert.AreEqual(2, res5.Length); + ClassicAssert.IsTrue(res5.Any(static x => x.SequenceEqual(new byte[] { 0, 0, 0, 0 }))); + ClassicAssert.IsTrue(res5.Any(static x => x.SequenceEqual(new byte[] { 0, 0, 0, 1 }))); + + // XB8 + var byte6 = new byte[75]; + byte6[0] = 10; + for (var i = 1; i < byte6.Length; i++) + { + byte6[i] = (byte)(byte6[i - 1] + 1); + } + var res6 = (byte[][])db.Execute("VSIM", ["foo", "XB8", byte6, "COUNT", "5", "EPSILON", "1.0", "EF", "40"]); + ClassicAssert.AreEqual(2, res6.Length); + ClassicAssert.IsTrue(res6.Any(static x => x.SequenceEqual(new byte[] { 0, 0, 0, 0 }))); + ClassicAssert.IsTrue(res6.Any(static x => x.SequenceEqual(new byte[] { 0, 0, 0, 1 }))); + + // COUNT > EF + var byte7 = new byte[75]; + byte7[0] = 20; + for (var i = 1; i < byte7.Length; i++) + { + byte7[i] = (byte)(byte7[i - 1] + 1); + } + var res7 = (byte[][])db.Execute("VSIM", ["foo", "XB8", byte7, "COUNT", "100", "EPSILON", "1.0", "EF", "40"]); + ClassicAssert.AreEqual(2, res7.Length); + ClassicAssert.IsTrue(res7.Any(static x => x.SequenceEqual(new byte[] { 0, 0, 0, 0 }))); + ClassicAssert.IsTrue(res7.Any(static x => x.SequenceEqual(new byte[] { 0, 0, 0, 1 }))); + + // WITHSCORES + var res8 = (byte[][])db.Execute("VSIM", ["foo", "XB8", byte7, "COUNT", "100", "EPSILON", "1.0", "EF", "40", "WITHSCORES"]); + ClassicAssert.AreEqual(4, res8.Length); + ClassicAssert.IsTrue(res8.Where(static (x, ix) => (ix % 2) == 0).Any(static x => x.SequenceEqual(new byte[] { 0, 0, 0, 0 }))); + ClassicAssert.IsTrue(res8.Where(static (x, ix) => (ix % 2) == 0).Any(static x => x.SequenceEqual(new byte[] { 0, 0, 0, 1 }))); + ClassicAssert.IsFalse(double.IsNaN(double.Parse(Encoding.UTF8.GetString(res8[1])))); + ClassicAssert.IsFalse(double.IsNaN(double.Parse(Encoding.UTF8.GetString(res8[3])))); + } + + [Test] + public void VSIMWithAttribs() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + var res1 = db.Execute("VADD", ["foo", "REDUCE", "50", "VALUES", "75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 0 }, "CAS", "Q8", "EF", "16", "M", "32", "SETATTR", "hello world"]); + ClassicAssert.AreEqual(1, (int)res1); + + var res2 = db.Execute("VADD", ["foo", "REDUCE", "50", "VALUES", "75", "100.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new 
byte[] { 0, 0, 0, 1 }, "CAS", "Q8", "EF", "16", "M", "32", "SETATTR", "fizz buzz"]); + ClassicAssert.AreEqual(1, (int)res2); + + // Equivalent to no attribute + var res3 = db.Execute("VADD", ["foo", "REDUCE", "50", "VALUES", "75", "110.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 2 }, "CAS", "Q8", "EF", "16", "M", "32", "SETATTR", ""]); + ClassicAssert.AreEqual(1, (int)res3); + + // Actually no attribute + var res4 = db.Execute("VADD", ["foo", "REDUCE", "50", "VALUES", "75", "120.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 3 }, "CAS", "Q8", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res4); + + // Very long attribute + var bigAttr = Enumerable.Repeat((byte)'a', 1_024).ToArray(); + var res5 = db.Execute("VADD", ["foo", "REDUCE", "50", "VALUES", "75", "130.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 4 }, "CAS", "Q8", "EF", "16", "M", "32", "SETATTR", bigAttr]); + ClassicAssert.AreEqual(1, (int)res5); + + var res6 = (byte[][])db.Execute("VSIM", ["foo", "VALUES", "75", "140.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "COUNT", "5", "EPSILON", "1.0", "EF", "40", "WITHATTRIBS"]); + ClassicAssert.AreEqual(10, res6.Length); + for (var i = 0; i < res6.Length; i += 2) + { + var id = res6[i]; + var attr = res6[i + 1]; + + if (id.SequenceEqual(new byte[] { 0, 0, 0, 0 })) + { + ClassicAssert.True(attr.SequenceEqual("hello world"u8.ToArray())); + } + else if (id.SequenceEqual(new byte[] { 0, 0, 0, 1 })) + { + ClassicAssert.True(attr.SequenceEqual("fizz buzz"u8.ToArray())); + } + else if (id.SequenceEqual(new byte[] { 0, 0, 0, 2 })) + { + 
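+                    // SETATTR with an empty string stored no attribute, so an empty payload comes back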
ClassicAssert.AreEqual(0, attr.Length); + } + else if (id.SequenceEqual(new byte[] { 0, 0, 0, 3 })) + { + ClassicAssert.AreEqual(0, attr.Length); + } + else if (id.SequenceEqual(new byte[] { 0, 0, 0, 4 })) + { + ClassicAssert.True(bigAttr.SequenceEqual(attr)); + } + else + { + ClassicAssert.Fail("Unexpected id"); + } + } + + // WITHSCORES + var res7 = (byte[][])db.Execute("VSIM", ["foo", "VALUES", "75", "140.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "COUNT", "5", "EPSILON", "1.0", "EF", "40", "WITHATTRIBS", "WITHSCORES"]); + ClassicAssert.AreEqual(15, res7.Length); + for (var i = 0; i < res7.Length; i += 3) + { + var id = res7[i]; + var score = double.Parse(Encoding.UTF8.GetString(res7[i + 1])); + var attr = res7[i + 2]; + + ClassicAssert.IsFalse(double.IsNaN(score)); + + if (id.SequenceEqual(new byte[] { 0, 0, 0, 0 })) + { + ClassicAssert.True(attr.SequenceEqual("hello world"u8.ToArray())); + } + else if (id.SequenceEqual(new byte[] { 0, 0, 0, 1 })) + { + ClassicAssert.True(attr.SequenceEqual("fizz buzz"u8.ToArray())); + } + else if (id.SequenceEqual(new byte[] { 0, 0, 0, 2 })) + { + ClassicAssert.AreEqual(0, attr.Length); + } + else if (id.SequenceEqual(new byte[] { 0, 0, 0, 3 })) + { + ClassicAssert.AreEqual(0, attr.Length); + } + else if (id.SequenceEqual(new byte[] { 0, 0, 0, 4 })) + { + ClassicAssert.True(bigAttr.SequenceEqual(attr)); + } + else + { + ClassicAssert.Fail("Unexpected id"); + } + } + } + + + [Test] + public void VDIM() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + var res1 = db.Execute("VADD", ["foo", "REDUCE", "3", "VALUES", "75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 0 }, "CAS", "Q8", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res1); + + var res2 = db.Execute("VDIM", "foo"); + ClassicAssert.AreEqual(3, (int)res2); + + var res3 = db.Execute("VADD", ["bar", "VALUES", "75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 0 }, "CAS", "Q8", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res3); + + var res4 = 
db.Execute("VDIM", "bar"); + ClassicAssert.AreEqual(75, (int)res4); + + var exc1 = ClassicAssert.Throws(() => db.Execute("VDIM", "fizz")); + ClassicAssert.IsTrue(exc1.Message.Contains("Key not found")); + + // TODO: Add WRONGTYPE behavior check once implemented + } + + [Test] + public void DeleteVectorSet() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + var res1 = db.Execute("VADD", ["foo", "REDUCE", "3", "VALUES", "75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 0 }, "CAS", "Q8", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res1); + + var res2 = db.KeyDelete("foo"); + ClassicAssert.IsTrue(res2); + + var res3 = db.Execute("VADD", ["fizz", "REDUCE", "3", "VALUES", "75", "100.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 0 }, "CAS", "Q8", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res3); + + var res4 = db.StringSet("buzz", "abc"); + ClassicAssert.IsTrue(res4); + + var res5 = db.KeyDelete(["fizz", "buzz"]); + ClassicAssert.AreEqual(2, res5); + } + + [Test] + public void RepeatedVectorSetDeletes() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + var bytes1 = new byte[75]; + var bytes2 = new byte[75]; + var bytes3 = new byte[75]; + bytes1[0] = 1; + bytes2[0] = 75; + bytes3[0] = 128; + for (var i = 1; i < bytes1.Length; i++) + { + bytes1[i] = (byte)(bytes1[i - 1] + 1); + bytes2[i] = (byte)(bytes2[i - 1] + 1); + bytes3[i] = (byte)(bytes3[i - 1] + 1); + } + + for (var i = 0; i < 1_000; i++) + { + var delRes = (int)db.Execute("DEL", ["foo"]); + + if (i != 0) + { + ClassicAssert.AreEqual(1, delRes); + } + else + { + ClassicAssert.AreEqual(0, delRes); + } + + var addRes1 = (int)db.Execute("VADD", ["foo", "XB8", bytes1, new byte[] { 0, 0, 0, 0 }, "XPREQ8"]); + ClassicAssert.AreEqual(1, addRes1); + + var addRes2 = (int)db.Execute("VADD", ["foo", "XB8", bytes2, new byte[] { 0, 0, 0, 1 }, "XPREQ8"]); + ClassicAssert.AreEqual(1, addRes2); + + var readExc = ClassicAssert.Throws(() => db.Execute("GET", ["foo"])); + ClassicAssert.IsTrue(readExc.Message.Equals("WRONGTYPE Operation against a key holding the wrong kind of value."), $"In iteration: {i}"); + + var query = (byte[][])db.Execute("VSIM", ["foo", "XB8", bytes3]); + + if (query is null) + { + try + { + var res = db.Execute("FOO"); + Console.WriteLine($"After unexpected null, got: {res}"); + } + catch { } + } + else if (query.Length != 2) + { + Console.WriteLine($"Wrong length {query.Length} != 2 response was"); + 
for (var j = 0; j < query.Length; j++) + { + var txt = Encoding.UTF8.GetString(query[j]); + Console.WriteLine("---"); + Console.WriteLine(txt); + } + } + + ClassicAssert.AreEqual(2, query.Length, $"In iteration: {i}"); + } + } + + [Test] + public unsafe void VectorReadBatchVariants() + { + // Single key, 4 byte keys + { + VectorInput input = default; + input.Callback = 5678; + input.CallbackContext = 9012; + + var data = new int[] { 4, 1234 }; + fixed (int* dataPtr = data) + { + var keyData = SpanByte.FromPinnedPointer((byte*)dataPtr, data.Length * sizeof(int)); + var batch = new VectorManager.VectorReadBatch(input.Callback, input.CallbackContext, 64, 1, keyData); + + var iters = 0; + for (var i = 0; i < batch.Count; i++) + { + iters++; + + // Validate Input + batch.GetInput(i, out var inputCopy); + ClassicAssert.AreEqual((nint)input.Callback, (nint)inputCopy.Callback); + ClassicAssert.AreEqual(input.CallbackContext, inputCopy.CallbackContext); + ClassicAssert.AreEqual(i, inputCopy.Index); + + // Validate key + batch.GetKey(i, out var keyCopy); + ClassicAssert.AreEqual(64, keyCopy.GetNamespaceInPayload()); + ClassicAssert.IsTrue(keyCopy.AsReadOnlySpan().SequenceEqual(MemoryMarshal.Cast(data.AsSpan().Slice(1, 1)))); + + // Validate output doesn't throw + batch.GetOutput(i, out _); + } + + ClassicAssert.AreEqual(1, iters); + } + } + + // Multiple keys, 4 byte keys + { + VectorInput input = default; + input.Callback = 5678; + input.CallbackContext = 9012; + + var data = new int[] { 4, 1234, 4, 5678, 4, 0123, 4, 9999, 4, 0000, 4, int.MaxValue, 4, int.MinValue }; + fixed (int* dataPtr = data) + { + var keyData = SpanByte.FromPinnedPointer((byte*)dataPtr, data.Length * sizeof(int)); + var batch = new VectorManager.VectorReadBatch(input.Callback, input.CallbackContext, 32, 7, keyData); + + var iters = 0; + for (var i = 0; i < batch.Count; i++) + { + iters++; + + // Validate Input + batch.GetInput(i, out var inputCopy); + ClassicAssert.AreEqual((nint)input.Callback, (nint)inputCopy.Callback); + ClassicAssert.AreEqual(input.CallbackContext, inputCopy.CallbackContext); + ClassicAssert.AreEqual(i, inputCopy.Index); + + // Validate key + batch.GetKey(i, out var keyCopy); + ClassicAssert.AreEqual(32, keyCopy.GetNamespaceInPayload()); + + var offset = i * 2 + 1; + var keyCopyData = keyCopy.AsReadOnlySpan(); + var expectedData = MemoryMarshal.Cast(data.AsSpan().Slice(offset, 1)); + ClassicAssert.IsTrue(keyCopyData.SequenceEqual(expectedData)); + + // Validate output doesn't throw + batch.GetOutput(i, out _); + } + + ClassicAssert.AreEqual(7, iters); + } + } + + // Multiple keys, 4 byte keys, random order + { + VectorInput input = default; + input.Callback = 5678; + input.CallbackContext = 9012; + + var data = new int[] { 4, 1234, 4, 5678, 4, 0123, 4, 9999, 4, 0000, 4, int.MaxValue, 4, int.MinValue }; + fixed (int* dataPtr = data) + { + var keyData = SpanByte.FromPinnedPointer((byte*)dataPtr, data.Length * sizeof(int)); + var batch = new VectorManager.VectorReadBatch(input.Callback, input.CallbackContext, 16, 7, keyData); + + var rand = new Random(2025_10_06_00); + + for (var j = 0; j < 1_000; j++) + { + var i = rand.Next(batch.Count); + + // Validate Input + batch.GetInput(i, out var inputCopy); + ClassicAssert.AreEqual((nint)input.Callback, (nint)inputCopy.Callback); + ClassicAssert.AreEqual(input.CallbackContext, inputCopy.CallbackContext); + ClassicAssert.AreEqual(i, inputCopy.Index); + + // Validate key + batch.GetKey(i, out var keyCopy); + ClassicAssert.AreEqual(16, keyCopy.GetNamespaceInPayload()); 
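+                        // data holds (int length, int key) pairs; i * 2 + 1 indexes the key value, skipping its length prefix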
+ + var offset = i * 2 + 1; + var keyCopyData = keyCopy.AsReadOnlySpan(); + var expectedData = MemoryMarshal.Cast(data.AsSpan().Slice(offset, 1)); + ClassicAssert.IsTrue(keyCopyData.SequenceEqual(expectedData)); + + // Validate output doesn't throw + batch.GetOutput(i, out _); + } + } + } + + // Single key, variable length + { + VectorInput input = default; + input.Callback = 5678; + input.CallbackContext = 9012; + + var key0 = "hello"u8.ToArray(); + var data = + MemoryMarshal.Cast([key0.Length]) + .ToArray() + .Concat(key0) + .ToArray(); + fixed (byte* dataPtr = data) + { + var keyData = SpanByte.FromPinnedPointer((byte*)dataPtr, data.Length); + var batch = new VectorManager.VectorReadBatch(input.Callback, input.CallbackContext, 8, 1, keyData); + + var iters = 0; + for (var i = 0; i < batch.Count; i++) + { + iters++; + + // Validate Input + batch.GetInput(i, out var inputCopy); + ClassicAssert.AreEqual((nint)input.Callback, (nint)inputCopy.Callback); + ClassicAssert.AreEqual(input.CallbackContext, inputCopy.CallbackContext); + ClassicAssert.AreEqual(i, inputCopy.Index); + + // Validate key + var expectedLength = + i switch + { + 0 => key0.Length, + _ => throw new InvalidOperationException("Unexpected index"), + }; + var expectedStart = + i switch + { + 0 => 0 + 1 * sizeof(int), + _ => throw new InvalidOperationException("Unexpected index"), + }; + + batch.GetKey(i, out var keyCopy); + ClassicAssert.AreEqual(8, keyCopy.GetNamespaceInPayload()); + var keyCopyData = keyCopy.AsReadOnlySpan(); + var expectedData = data.AsSpan().Slice(expectedStart, expectedLength); + ClassicAssert.IsTrue(expectedData.SequenceEqual(keyCopyData)); + + // Validate output doesn't throw + batch.GetOutput(i, out _); + } + + ClassicAssert.AreEqual(1, iters); + } + } + + // Multiple keys, variable length + { + VectorInput input = default; + input.Callback = 5678; + input.CallbackContext = 9012; + + var key0 = "hello"u8.ToArray(); + var key1 = "fizz"u8.ToArray(); + var key2 = "the quick brown fox jumps over the lazy dog"u8.ToArray(); + var key3 = "CF29E323-E376-4BC4-AB63-FCFD371EB445"u8.ToArray(); + var key4 = Array.Empty(); + var key5 = new byte[] { 1 }; + var key6 = new byte[] { 2, 3 }; + var key7 = new byte[] { 4, 5, 6 }; + var data = + MemoryMarshal.Cast([key0.Length]) + .ToArray() + .Concat(key0) + .Concat( + MemoryMarshal.Cast([key1.Length]).ToArray() + ) + .Concat( + key1 + ) + .Concat( + MemoryMarshal.Cast([key2.Length]).ToArray() + ) + .Concat( + key2 + ) + .Concat( + MemoryMarshal.Cast([key3.Length]).ToArray() + ) + .Concat( + key3 + ) + .Concat( + MemoryMarshal.Cast([key4.Length]).ToArray() + ) + .Concat( + key4 + ) + .Concat( + MemoryMarshal.Cast([key5.Length]).ToArray() + ) + .Concat( + key5 + ) + .Concat( + MemoryMarshal.Cast([key6.Length]).ToArray() + ) + .Concat( + key6 + ) + .Concat( + MemoryMarshal.Cast([key7.Length]).ToArray() + ) + .Concat( + key7 + ) + .ToArray(); + fixed (byte* dataPtr = data) + { + var keyData = SpanByte.FromPinnedPointer((byte*)dataPtr, data.Length); + var batch = new VectorManager.VectorReadBatch(input.Callback, input.CallbackContext, 4, 8, keyData); + + var iters = 0; + for (var i = 0; i < batch.Count; i++) + { + iters++; + + // Validate Input + batch.GetInput(i, out var inputCopy); + ClassicAssert.AreEqual((nint)input.Callback, (nint)inputCopy.Callback); + ClassicAssert.AreEqual(input.CallbackContext, inputCopy.CallbackContext); + ClassicAssert.AreEqual(i, inputCopy.Index); + + // Validate key + var expectedLength = + i switch + { + 0 => key0.Length, + 1 => key1.Length, + 2 
=> key2.Length, + 3 => key3.Length, + 4 => key4.Length, + 5 => key5.Length, + 6 => key6.Length, + 7 => key7.Length, + _ => throw new InvalidOperationException("Unexpected index"), + }; + var expectedStart = + i switch + { + 0 => 0 + 1 * sizeof(int), + 1 => key0.Length + 2 * sizeof(int), + 2 => key0.Length + key1.Length + 3 * sizeof(int), + 3 => key0.Length + key1.Length + key2.Length + 4 * sizeof(int), + 4 => key0.Length + key1.Length + key2.Length + key3.Length + 5 * sizeof(int), + 5 => key0.Length + key1.Length + key2.Length + key3.Length + key4.Length + 6 * sizeof(int), + 6 => key0.Length + key1.Length + key2.Length + key3.Length + key4.Length + key5.Length + 7 * sizeof(int), + 7 => key0.Length + key1.Length + key2.Length + key3.Length + key4.Length + key5.Length + key6.Length + 8 * sizeof(int), + _ => throw new InvalidOperationException("Unexpected index"), + }; + + batch.GetKey(i, out var keyCopy); + ClassicAssert.AreEqual(4, keyCopy.GetNamespaceInPayload()); + var keyCopyData = keyCopy.AsReadOnlySpan(); + var expectedData = data.AsSpan().Slice(expectedStart, expectedLength); + ClassicAssert.IsTrue(expectedData.SequenceEqual(keyCopyData)); + + // Validate output doesn't throw + batch.GetOutput(i, out _); + } + + ClassicAssert.AreEqual(8, iters); + } + } + + // Multiple keys, variable length, random access + { + VectorInput input = default; + input.Callback = 5678; + input.CallbackContext = 9012; + + var key0 = "hello"u8.ToArray(); + var key1 = "fizz"u8.ToArray(); + var key2 = "the quick brown fox jumps over the lazy dog"u8.ToArray(); + var key3 = "CF29E323-E376-4BC4-AB63-FCFD371EB445"u8.ToArray(); + var key4 = Array.Empty(); + var key5 = new byte[] { 1 }; + var key6 = new byte[] { 2, 3 }; + var key7 = new byte[] { 4, 5, 6 }; + var data = + MemoryMarshal.Cast([key0.Length]) + .ToArray() + .Concat(key0) + .Concat( + MemoryMarshal.Cast([key1.Length]).ToArray() + ) + .Concat( + key1 + ) + .Concat( + MemoryMarshal.Cast([key2.Length]).ToArray() + ) + .Concat( + key2 + ) + .Concat( + MemoryMarshal.Cast([key3.Length]).ToArray() + ) + .Concat( + key3 + ) + .Concat( + MemoryMarshal.Cast([key4.Length]).ToArray() + ) + .Concat( + key4 + ) + .Concat( + MemoryMarshal.Cast([key5.Length]).ToArray() + ) + .Concat( + key5 + ) + .Concat( + MemoryMarshal.Cast([key6.Length]).ToArray() + ) + .Concat( + key6 + ) + .Concat( + MemoryMarshal.Cast([key7.Length]).ToArray() + ) + .Concat( + key7 + ) + .ToArray(); + fixed (byte* dataPtr = data) + { + var keyData = SpanByte.FromPinnedPointer((byte*)dataPtr, data.Length); + var batch = new VectorManager.VectorReadBatch(input.Callback, input.CallbackContext, 4, 8, keyData); + + var rand = new Random(2025_10_06_01); + + for (var j = 0; j < 1_000; j++) + { + var i = rand.Next(batch.Count); + + // Validate Input + batch.GetInput(i, out var inputCopy); + ClassicAssert.AreEqual((nint)input.Callback, (nint)inputCopy.Callback); + ClassicAssert.AreEqual(input.CallbackContext, inputCopy.CallbackContext); + ClassicAssert.AreEqual(i, inputCopy.Index); + + // Validate key + var expectedLength = + i switch + { + 0 => key0.Length, + 1 => key1.Length, + 2 => key2.Length, + 3 => key3.Length, + 4 => key4.Length, + 5 => key5.Length, + 6 => key6.Length, + 7 => key7.Length, + _ => throw new InvalidOperationException("Unexpected index"), + }; + var expectedStart = + i switch + { + 0 => 0 + 1 * sizeof(int), + 1 => key0.Length + 2 * sizeof(int), + 2 => key0.Length + key1.Length + 3 * sizeof(int), + 3 => key0.Length + key1.Length + key2.Length + 4 * sizeof(int), + 4 => key0.Length + 
key1.Length + key2.Length + key3.Length + 5 * sizeof(int), + 5 => key0.Length + key1.Length + key2.Length + key3.Length + key4.Length + 6 * sizeof(int), + 6 => key0.Length + key1.Length + key2.Length + key3.Length + key4.Length + key5.Length + 7 * sizeof(int), + 7 => key0.Length + key1.Length + key2.Length + key3.Length + key4.Length + key5.Length + key6.Length + 8 * sizeof(int), + _ => throw new InvalidOperationException("Unexpected index"), + }; + + batch.GetKey(i, out var keyCopy); + ClassicAssert.AreEqual(4, keyCopy.GetNamespaceInPayload()); + var keyCopyData = keyCopy.AsReadOnlySpan(); + var expectedData = data.AsSpan().Slice(expectedStart, expectedLength); + ClassicAssert.IsTrue(expectedData.SequenceEqual(keyCopyData)); + + // Validate output doesn't throw + batch.GetOutput(i, out _); + } + } + } + } + + [Test] + public void RecreateIndexesOnRestore() + { + var addData1 = Enumerable.Range(0, 75).Select(static x => (byte)x).ToArray(); + var addData2 = Enumerable.Range(0, 75).Select(static x => (byte)(x * 2)).ToArray(); + var queryData = addData1.ToArray(); + queryData[0]++; + + // VADD + { + using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig(allowAdmin: true))) + { + var s = redis.GetServers()[0]; + var db = redis.GetDatabase(0); + + _ = db.KeyDelete("foo"); + + var res1 = db.Execute("VADD", ["foo", "XB8", addData1, new byte[] { 0, 0, 0, 0 }, "CAS", "Q8", "EF", "16", "M", "32", "SETATTR", "hello world"]); + ClassicAssert.AreEqual(1, (int)res1); + +#pragma warning disable CS0618 // Intentionally doing bad things + s.Save(SaveType.ForegroundSave); +#pragma warning restore CS0618 + + var commit = server.Store.WaitForCommit(); + ClassicAssert.IsTrue(commit); + server.Dispose(deleteDir: false); + + server = TestUtils.CreateGarnetServer(TestUtils.MethodTestDir, tryRecover: true, enableAOF: true); + server.Start(); + } + + using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig(allowAdmin: true))) + { + var db = redis.GetDatabase(0); + + var res2 = db.Execute("VADD", ["foo", "XB8", addData2, new byte[] { 0, 0, 0, 1 }, "CAS", "Q8", "EF", "16", "M", "32", "SETATTR", "fizz buzz"]); + ClassicAssert.AreEqual(1, (int)res2); + } + } + + // VSIM with vector + { + byte[][] expectedVSimResult; + using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig(allowAdmin: true))) + { + var s = redis.GetServers()[0]; + var db = redis.GetDatabase(0); + + _ = db.KeyDelete("foo"); + + var res1 = db.Execute("VADD", ["foo", "XB8", addData1, new byte[] { 0, 0, 0, 0 }, "CAS", "Q8", "EF", "16", "M", "32", "SETATTR", "hello world"]); + ClassicAssert.AreEqual(1, (int)res1); + + expectedVSimResult = (byte[][])db.Execute("VSIM", ["foo", "XB8", queryData]); + ClassicAssert.AreEqual(1, expectedVSimResult.Length); +#pragma warning disable CS0618 // Intentionally doing bad things + s.Save(SaveType.ForegroundSave); +#pragma warning restore CS0618 + + var commit = server.Store.WaitForCommit(); + ClassicAssert.IsTrue(commit); + server.Dispose(deleteDir: false); + + server = TestUtils.CreateGarnetServer(TestUtils.MethodTestDir, tryRecover: true, enableAOF: true); + server.Start(); + } + + using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig(allowAdmin: true))) + { + var db = redis.GetDatabase(0); + + var res2 = (byte[][])db.Execute("VSIM", ["foo", "XB8", queryData]); + ClassicAssert.AreEqual(expectedVSimResult.Length, res2.Length); + for (var i = 0; i < res2.Length; i++) + { + ClassicAssert.IsTrue(expectedVSimResult[i].AsSpan().SequenceEqual(res2[i])); + } + } + } + + // 
VSIM with element + { + byte[][] expectedVSimResult; + using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig(allowAdmin: true))) + { + var s = redis.GetServers()[0]; + var db = redis.GetDatabase(0); + + _ = db.KeyDelete("foo"); + + var res1 = db.Execute("VADD", ["foo", "XB8", addData1, new byte[] { 0, 0, 0, 0 }, "CAS", "Q8", "EF", "16", "M", "32", "SETATTR", "hello world"]); + ClassicAssert.AreEqual(1, (int)res1); + + var res2 = db.Execute("VADD", ["foo", "XB8", addData2, new byte[] { 0, 0, 0, 1 }, "CAS", "Q8", "EF", "16", "M", "32", "SETATTR", "hello world"]); + ClassicAssert.AreEqual(1, (int)res1); + + expectedVSimResult = (byte[][])db.Execute("VSIM", ["foo", "ELE", new byte[] { 0, 0, 0, 0 }]); + ClassicAssert.AreEqual(2, expectedVSimResult.Length); +#pragma warning disable CS0618 // Intentionally doing bad things + s.Save(SaveType.ForegroundSave); +#pragma warning restore CS0618 + + var commit = server.Store.WaitForCommit(); + ClassicAssert.IsTrue(commit); + server.Dispose(deleteDir: false); + + server = TestUtils.CreateGarnetServer(TestUtils.MethodTestDir, tryRecover: true, enableAOF: true); + server.Start(); + } + + using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig(allowAdmin: true))) + { + var db = redis.GetDatabase(0); + + var res2 = (byte[][])db.Execute("VSIM", ["foo", "ELE", new byte[] { 0, 0, 0, 0 }]); + ClassicAssert.AreEqual(expectedVSimResult.Length, res2.Length); + for (var i = 0; i < res2.Length; i++) + { + ClassicAssert.IsTrue(expectedVSimResult[i].AsSpan().SequenceEqual(res2[i])); + } + } + } + + // VDIM + { + using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig(allowAdmin: true))) + { + var s = redis.GetServers()[0]; + var db = redis.GetDatabase(0); + + _ = db.KeyDelete("foo"); + + var res1 = db.Execute("VADD", ["foo", "XB8", addData1, new byte[] { 0, 0, 0, 0 }, "CAS", "Q8", "EF", "16", "M", "32", "SETATTR", "hello world"]); + ClassicAssert.AreEqual(1, (int)res1); + +#pragma warning disable CS0618 // Intentionally doing bad things + s.Save(SaveType.ForegroundSave); +#pragma warning restore CS0618 + + var commit = server.Store.WaitForCommit(); + ClassicAssert.IsTrue(commit); + server.Dispose(deleteDir: false); + + server = TestUtils.CreateGarnetServer(TestUtils.MethodTestDir, tryRecover: true, enableAOF: true); + server.Start(); + } + + using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig(allowAdmin: true))) + { + var db = redis.GetDatabase(0); + + var res2 = (int)db.Execute("VDIM", ["foo"]); + ClassicAssert.AreEqual(addData1.Length, res2); + } + } + + // VEMB + { + using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig(allowAdmin: true))) + { + var s = redis.GetServers()[0]; + var db = redis.GetDatabase(0); + + _ = db.KeyDelete("foo"); + + var res1 = db.Execute("VADD", ["foo", "XB8", addData1, new byte[] { 0, 0, 0, 0 }, "CAS", "Q8", "EF", "16", "M", "32", "SETATTR", "hello world"]); + ClassicAssert.AreEqual(1, (int)res1); + +#pragma warning disable CS0618 // Intentionally doing bad things + s.Save(SaveType.ForegroundSave); +#pragma warning restore CS0618 + + var commit = server.Store.WaitForCommit(); + ClassicAssert.IsTrue(commit); + server.Dispose(deleteDir: false); + + server = TestUtils.CreateGarnetServer(TestUtils.MethodTestDir, tryRecover: true, enableAOF: true); + server.Start(); + } + + using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig(allowAdmin: true))) + { + var db = redis.GetDatabase(0); + + var res2 = (string[])db.Execute("VEMB", ["foo", new byte[] { 0, 0, 0, 0 
}]); + ClassicAssert.AreEqual(res2.Length, addData1.Length); + + for (var i = 0; i < res2.Length; i++) + { + ClassicAssert.AreEqual((float)addData1[i], float.Parse(res2[i])); + } + } + } + + // VREM + { + using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig(allowAdmin: true))) + { + var s = redis.GetServers()[0]; + var db = redis.GetDatabase(0); + + _ = db.KeyDelete("foo"); + + var res1 = db.Execute("VADD", ["foo", "XB8", addData1, new byte[] { 0, 0, 0, 0 }, "CAS", "Q8", "EF", "16", "M", "32", "SETATTR", "hello world"]); + ClassicAssert.AreEqual(1, (int)res1); + + var res2 = db.Execute("VADD", ["foo", "XB8", addData2, new byte[] { 0, 0, 0, 1 }, "CAS", "Q8", "EF", "16", "M", "32", "SETATTR", "hello world"]); + ClassicAssert.AreEqual(1, (int)res1); + +#pragma warning disable CS0618 // Intentionally doing bad things + s.Save(SaveType.ForegroundSave); +#pragma warning restore CS0618 + + var commit = server.Store.WaitForCommit(); + ClassicAssert.IsTrue(commit); + server.Dispose(deleteDir: false); + + server = TestUtils.CreateGarnetServer(TestUtils.MethodTestDir, tryRecover: true, enableAOF: true); + server.Start(); + } + + using (var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig(allowAdmin: true))) + { + var db = redis.GetDatabase(0); + + var res1 = (int)db.Execute("VREM", ["foo", new byte[] { 0, 0, 0, 0 }]); + ClassicAssert.AreEqual(1, res1); + + var res2 = (string[])db.Execute("VEMB", ["foo", new byte[] { 0, 0, 0, 1 }]); + ClassicAssert.AreEqual(res2.Length, addData1.Length); + + for (var i = 0; i < res2.Length; i++) + { + ClassicAssert.AreEqual((float)addData2[i], float.Parse(res2[i])); + } + } + } + } + + // TODO: FLUSHDB needs to cleanup too... + + [Test] + public void VREM() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(0); + + // Populate + var res1 = db.Execute("VADD", ["foo", "REDUCE", "50", "VALUES", "75", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 0, 0, 0, 0 }, "CAS", "Q8", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res1); + + var res2 = db.Execute("VADD", ["foo", "REDUCE", "50", "VALUES", "75", "100.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", new byte[] { 1, 0, 0, 0 }, "CAS", "Q8", "EF", "16", "M", "32"]); + ClassicAssert.AreEqual(1, (int)res2); + + // Remove on non-vector set fails + // TODO: test against Redis, how do they respond (I expect WRONGTYPE, but needs verification) + //_ = db.StringSet("fizz", "buzz"); + //var exc1 = ClassicAssert.Throws(() => db.Execute("VREM", "fizz", new byte[] { 0, 0, 0, 0 })); + 
//ClassicAssert.AreEqual("", exc1.Message); + + // Remove exists + var res3 = db.Execute("VREM", ["foo", new byte[] { 0, 0, 0, 0 }]); + ClassicAssert.AreEqual(1, (int)res3); + + // Remove again fails + var res4 = db.Execute("VREM", ["foo", new byte[] { 0, 0, 0, 0 }]); + ClassicAssert.AreEqual(0, (int)res4); + + // Remove not present + var res5 = db.Execute("VREM", ["foo", new byte[] { 1, 2, 3, 4 }]); + ClassicAssert.AreEqual(0, (int)res5); + + // VSIM doesn't return removed element + var res6 = (byte[][])db.Execute("VSIM", ["foo", "VALUES", "75", "110.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "4.0", "1.0", "2.0", "3.0", "COUNT", "5", "EPSILON", "1.0", "EF", "40"]); + ClassicAssert.AreEqual(1, res6.Length); + ClassicAssert.IsTrue(res6.Any(static x => x.SequenceEqual(new byte[] { 1, 0, 0, 0 }))); + } + } +} \ No newline at end of file diff --git a/test/Garnet.test/TestUtils.cs b/test/Garnet.test/TestUtils.cs index eb3f64563ca..b5ec5531fec 100644 --- a/test/Garnet.test/TestUtils.cs +++ b/test/Garnet.test/TestUtils.cs @@ -125,6 +125,9 @@ internal static bool IsRunningAzureTests } } + internal static bool IsRunningAsGitHubAction + => "true".Equals(Environment.GetEnvironmentVariable("GITHUB_ACTIONS"), StringComparison.OrdinalIgnoreCase); + [MethodImpl(MethodImplOptions.AggressiveInlining)] internal static void AssertEqualUpToExpectedLength(string expectedResponse, byte[] response) { @@ -273,8 +276,9 @@ public static GarnetServer CreateGarnetServer( int expiredKeyDeletionScanFrequencySecs = -1, bool useReviv = false, bool useInChainRevivOnly = false, - bool useLogNullDevice = false - ) + bool useLogNullDevice = false, + bool enableVectorSetPreview = true + ) { if (useAzureStorage) IgnoreIfNotRunningAzureTests(); @@ -361,6 +365,7 @@ public static GarnetServer CreateGarnetServer( UnixSocketPermission = unixSocketPermission, SlowLogThreshold = slowLogThreshold, ExpiredKeyDeletionScanFrequencySecs = expiredKeyDeletionScanFrequencySecs, + EnableVectorSetPreview = enableVectorSetPreview, }; if (!string.IsNullOrEmpty(memorySize)) @@ -653,7 +658,8 @@ public static GarnetServerOptions GetGarnetServerOptions( int loggingFrequencySecs = 5, int checkpointThrottleFlushDelayMs = 0, bool clusterReplicaResumeWithData = false, - int replicaSyncTimeout = 60) + int replicaSyncTimeout = 60, + bool enableVectorSetPreview = true) { if (useAzureStorage) IgnoreIfNotRunningAzureTests(); @@ -775,6 +781,7 @@ public static GarnetServerOptions GetGarnetServerOptions( CheckpointThrottleFlushDelayMs = checkpointThrottleFlushDelayMs, ClusterReplicaResumeWithData = clusterReplicaResumeWithData, ReplicaSyncTimeout = replicaSyncTimeout <= 0 ? 
Timeout.InfiniteTimeSpan : TimeSpan.FromSeconds(replicaSyncTimeout),
+                EnableVectorSetPreview = enableVectorSetPreview,
             };
 
             if (lowMemory)
diff --git a/website/docs/dev/vector-sets.md b/website/docs/dev/vector-sets.md
new file mode 100644
index 00000000000..fc7a6271ffc
--- /dev/null
+++ b/website/docs/dev/vector-sets.md
@@ -0,0 +1,413 @@
+---
+id: vector-sets
+sidebar_label: Vector Sets
+title: Vector Sets
+---
+
+# Overview
+
+Garnet has partial support for Vector Sets, implemented on top of the [DiskANN project](https://www.nuget.org/packages/diskann-garnet/).
+
+This data type behaves very differently from the other types Garnet supports.
+
+> [!IMPORTANT]
+> The DiskANN link needs to be updated once OSS'd.
+
+# Design
+
+Vector Sets are a combination of one "index" key, which stores metadata and a pointer to the DiskANN data structure, and many "element" keys, which store vectors/quantized vectors/attributes/etc. All Vector Set keys are kept in the main store, but only the index key is visible - this is accomplished by putting all element keys in different namespaces.
+
+## Global Metadata
+
+In order to track allocated Vector Sets (and their respective hash slots), in-progress cleanups, and in-progress migrations, we keep a single `ContextMetadata` struct under the empty key in namespace 0.
+
+This is loaded and cached on startup, and updated (both in memory and in Tsavorite) whenever a Vector Set is created or deleted. Simple locking (on the `VectorManager` instance) is used to serialize these updates, as they should be rare.
+
+> [!IMPORTANT]
+> Today `ContextMetadata` can track only 64 Vector Sets in some state of creation or cleanup.
+>
+> The practical limit is actually 31, because a context must be < 256, divisible by 8, and not 0 (which is reserved).
+>
+> This limitation will be lifted eventually, perhaps after Store V2 lands.
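+
+To make the context constraints concrete, here is a minimal sketch of the allocation arithmetic, assuming a plain `ulong` bitmap of in-use slots (the actual `ContextMetadata` layout may differ):
+
+```csharp
+// Sketch only: contexts must be non-zero multiples of 8 below 256,
+// which leaves exactly 31 usable slots (8, 16, ..., 248).
+static ulong? TryAllocateContext(ref ulong inUseBitmap)
+{
+    for (var slot = 1; slot <= 31; slot++)
+    {
+        var mask = 1UL << slot;
+        if ((inUseBitmap & mask) == 0)
+        {
+            inUseBitmap |= mask;
+            return (ulong)slot * 8; // the context; namespaces are context + 0..7
+        }
+    }
+
+    return null; // every context is in some state of creation or cleanup
+}
+```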
+
+## Indexes
+
+The index key (represented by the `Index` struct) contains the following data:
+ - `ulong Context` - used to derive namespaces, detailed below
+ - `ulong IndexPtr` - a pointer to the DiskANN data structure, note this may be _dangling_ after [recovery](#recovery) or [replication](#replication)
+ - `uint Dimensions` - the expected dimension of vectors in commands targeting the Vector Set, this is inferred from the `VADD` that creates the Vector Set
+ - `uint ReduceDims` - if a Vector Set was created with the `REDUCE` option, that value; otherwise zero
+   * > [!NOTE]
+     > Today this is ignored except for validation purposes, eventually DiskANN will use it.
+ - `uint NumLinks` - the `M` used to create the Vector Set, or the default value of 16 if not specified
+ - `uint BuildExplorationFactor` - the `EF` used to create the Vector Set, or the default value of 200 if not specified
+ - `VectorQuantType QuantType` - the quantizer specified at creation time, or the default value of `Q8` if not specified
+   * > [!NOTE]
+     > We have an extension here, `XPREQ8`, which is not from Redis.
+     > This is a quantizer for data sets which have already been 8-bit quantized or are otherwise naturally small byte vectors, and is extremely optimized for reducing reads during queries.
+     > It forbids the `REDUCE` option and requires 4-byte element ids.
+   * > [!IMPORTANT]
+     > Today only `XPREQ8` is actually implemented, eventually DiskANN will provide reasonable versions of all the Redis builtin quantizers.
+ - `Guid ProcessInstanceId` - an identifier used to distinguish the current process from previous instances, this is used after [recovery](#recovery) or [replication](#replication) to detect if `IndexPtr` is dangling
+
+The index key is in the main store alongside other binary values like strings, hyperloglogs, and so on. It is distinguished for `WRONGTYPE` purposes with the `VectorSet` bit on `RecordInfo`.
+
+> [!IMPORTANT]
+> `RecordInfo.VectorSet` is checked in a few places to correctly produce `WRONGTYPE` responses, but we need more coverage for all commands. Probably something akin to how ACLs required per-command tests.
+
+> [!IMPORTANT]
+> A generalization of the `VectorSet`-bit should be used for all data types, this can happen once we have Store V2.
+
+## Elements
+
+While the Vector Set API only concerns itself with top-level index keys, ids, vectors, and attributes, DiskANN has different storage needs. To abstract around these needs a bit, we reserve a number of different "namespaces" for each Vector Set.
+
+These namespaces are simple numbers, starting at the `Context` value stored in the `Index` struct - we currently reserve 8 namespaces per Vector Set. What goes in which namespace is mostly hidden from Garnet; DiskANN indicates the namespace (and index) to use with a modified `Context` passed to relevant callbacks.
+> There are two cases where we "know" the namespace involved: attributes (+3) and full vectors (+0), which are used to implement the `WITHATTR` option and the `VEMB` command respectively. These exceptions _may_ go away in the future, but don't have to.
+
+Using namespaces prevents other commands from accessing keys which store element data.
+
+To illustrate, this means that:
+```
+VADD vector-set-key VALUES 1 123 element-key
+SET element-key string-value
+```
+can work as expected. Without namespacing, the `SET` would overwrite (or otherwise mangle) the element data of the Vector Set.
+
+# Operations
+
+We implement the [Redis Vector Set API](https://redis.io/docs/latest/commands/?group=vector_set):
+
+Command support:
+ - [x] VADD
+ - [ ] VCARD
+ - [x] VDIM
+ - [x] VEMB
+ - [ ] VGETATTR
+ - [ ] VINFO
+ - [ ] VISMEMBER
+ - [ ] VLINKS
+ - [ ] VRANDMEMBER
+ - [x] VREM
+ - [ ] VSETATTR
+ - [x] VSIM
+
+## Creation (via `VADD`)
+
+[`VADD`](https://redis.io/docs/latest/commands/vadd/) implicitly creates a Vector Set when run on an empty key.
+
+DiskANN index creation must be serialized, so this requires holding an exclusive lock ([more details on locking](#locking)) that covers just that key. During the `create_index` call to DiskANN the read/write/delete callbacks provided may be invoked - accordingly, creation is re-entrant, and we cannot call `create_index` directly from any Tsavorite session functions.
+
+## Insertion (via `VADD`)
+
+Once a Vector Set exists, insertions (which also use `VADD`) can proceed in parallel.
+
+Every insertion begins with a Tsavorite read, to get the [`Index`](#indexes) metadata (for validation) and the pointer to DiskANN's index. As a consequence, most `VADD` operations, despite _semantically_ being writes, are reads from Tsavorite's perspective. This has implications for replication, [which is discussed below](#replication).
+
+To prevent the index from being deleted mid-insertion, we hold a shared lock while calling DiskANN's `insert` function. These locks are sharded for performance purposes, [which is discussed below](#locking).
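+
+A minimal sketch of the insert path described above (the method and type names here are illustrative, not the actual Garnet implementation):
+
+```csharp
+// Sketch only: the VADD hot path once the Vector Set already exists.
+bool TryInsert(SpanByte key, SpanByte elementId, ReadOnlySpan<byte> vector)
+{
+    // Shared, sharded lock: blocks DEL, allows parallel VADD/VREM/VSIM
+    using var shared = locks.AcquireShared(key);
+
+    // A Tsavorite *read*, even though VADD is semantically a write
+    if (!TryReadIndex(key, out var index))
+        return PromoteAndCreateIndex(key, elementId, vector); // exclusive path
+
+    ValidateDimensions(index, vector);
+
+    // DiskANN performs the actual element writes through the callbacks
+    var ok = DiskAnnInsert(index.Context, index.IndexPtr, elementId, vector);
+
+    if (ok)
+        ReplicateVectorSetAdd(key); // synthetic write so the AOF sees the VADD
+
+    return ok;
+}
+```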
+
+## Removal (via `VREM`)
+
+Removal works much the same as insertion, using shared locks so it can proceed in parallel. The only meaningful difference is calling DiskANN's `remove` instead of `insert`.
+
+> [!NOTE]
+> Removing all elements from a Vector Set is not the same as deleting it. While it is not possible to create an empty Vector Set with a single command, it is legal for one to exist after a `VREM`.
+
+## Search (via `VSIM`)
+
+Searching is a pure read operation, and so holds shared locks and proceeds in parallel like insertions and removals.
+
+Great care is taken to avoid copying during `VSIM`. In particular, values and element ids are passed directly from the receive buffer for all encodings except `VALUES`. Callbacks from DiskANN to Garnet likewise take great care to avoid copying, and are [detailed below](#diskann-integration).
+
+## Element Data (via `VEMB` and `VGETATTR`)
+
+These operations are handled purely on the Garnet side by first reading out the [`Index`](#indexes) structure, and then using the context value to look for data in the appropriate namespaces.
+
+> [!NOTE]
+> Strictly speaking we don't need the DiskANN index to access this data, but the current implementation does make sure the index is valid.
+
+## Metadata (via `VDIM` and `VINFO`)
+
+Metadata is handled purely on the Garnet side by reading out the [`Index`](#indexes) structure.
+
+> [!NOTE]
+> `VINFO` directly exposes Redis implementation details in addition to "normal" data.
+> Because our implementation is different, we intentionally will not expose all the same information.
+> To be concrete, `max-level`, `vset-uid`, and `hnsw-max-node-uid` are not returned.
+
+> [!IMPORTANT]
+> We _may_ return more details of our own implementation. What those are, and why, needs to be documented when we implement `VINFO`.
+
+## Deletion (via `DEL` and `UNLINK`)
+
+`DEL` (and its equivalent `UNLINK`) is the only non-Vector Set command routinely expected on a Vector Set key. It is complicated by the fact that we do not know we are operating on a Vector Set until we get rather far into deletion.
+
+We cope with this by _cancelling_ the Tsavorite delete operation once we have a `RecordInfo` with the `VectorSet`-bit set and a value which is not all zeros, detecting that cancellation in `MainStoreOps`, and shunting the delete attempt to `VectorManager`.
+
+`VectorManager` performs the delete in five steps, sketched after this list:
+ - Acquire exclusive locks covering the Vector Set ([more locking details](#locking))
+ - If the index was initialized in the current process ([see recovery for more details](#recovery)), call DiskANN's `drop_index` function
+ - Perform a write to zero out the index key in Tsavorite
+ - Reattempt the Tsavorite delete
+ - Cleanup ancillary metadata and schedule element data for cleanup ([more details below](#cleanup))
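+
+In code form, the five steps look roughly like this (names are hypothetical, not the actual `VectorManager` members):
+
+```csharp
+// Sketch only: the shape of the five-step delete described above.
+void DeleteVectorSet(SpanByte key, in Index index)
+{
+    using var exclusive = locks.AcquireAllExclusive(key);      // 1. exclusive locks
+
+    if (index.ProcessInstanceId == currentProcessInstanceId)   // 2. only if created
+        DiskAnnDropIndex(index.Context, index.IndexPtr);       //    in this process
+
+    ZeroIndexValue(key);                                       // 3. zero the index key
+    TsavoriteDelete(key);                                      // 4. reattempt the delete
+    ScheduleContextCleanup(index.Context);                     // 5. see Cleanup below
+}
+```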
+
+## FlushDB
+
+`FLUSHDB` (and its relative `FLUSHALL`) requires special handling.
+
+> [!IMPORTANT]
+> This is not currently implemented.
+
+# Locking
+
+Vector Set workloads require extreme parallelism, and so intricate locking protocols are required for both performance and correctness.
+
+Concretely, there are 3 sorts of locks involved:
+ - Tsavorite hashbucket locks
+ - A `ReadOptimizedLock` instance
+ - The `VectorManager` lock around `ContextMetadata`
+
+## Tsavorite Locks
+
+Whenever we read or write a key/value pair in the main store, we acquire locks in Tsavorite. Importantly, we cannot start a new Tsavorite operation while still holding these locks - we must copy the index out before each operation so Garnet can use the read/write/delete callbacks.
+
+> [!NOTE]
+> Based on profiling, Tsavorite shared locks are a significant source of contention. Even though reads will not block each other, we still pay a cache coherency tax. Accordingly, reducing the number of Tsavorite operations (even reads) can lead to significant performance gains.
+
+> [!IMPORTANT]
+> Some effort was spent early on attempting to elide the initial index read in common cases. This did not pay dividends on smaller clusters, but is worth exploring again on large SKUs.
+
+## `ReadOptimizedLock`
+
+As noted above, to prevent `DEL` from clobbering in-use Vector Sets and concurrent `VADD`s from calling `create_index` multiple times, we have to hold locks based on the Vector Set key. As every Vector Set operation starts by taking these locks, we have sharded them into separate locks. To derive many related keys from a single key, we mangle the low bits of a key's hash value - this is implemented in the new (but not bound to Vector Sets) type `ReadOptimizedLock`.
+
+For operations which remain reads, we only acquire a single shared lock (based on the current thread) to prevent destructive operations.
+
+For operations which are always writes (like `DEL`), we acquire all sharded locks in exclusive mode.
+
+For operations which might be either (like `VADD`), we first acquire the usual single sharded lock (in shared mode), then promote to an exclusive lock if needed.
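+
+The three acquisition patterns above look roughly like this (a sketch assuming hypothetical `AcquireShared`/`AcquireAllExclusive` helpers, not the actual `ReadOptimizedLock` API):
+
+```csharp
+// Sketch only: the three locking patterns described above.
+void LockingPatterns(ReadOptimizedLock locks, long keyHash)
+{
+    // Reads (VSIM, VEMB, ...): one shard, shared - cheap under contention.
+    using (locks.AcquireShared(keyHash))
+        RunSearch();
+
+    // Always-writes (DEL): every shard, exclusive - excludes all readers.
+    using (locks.AcquireAllExclusive(keyHash))
+        RunDelete();
+
+    // Might-be-either (VADD): start shared, promote only if the index turns
+    // out not to exist yet, and re-check after promotion.
+    var inserted = false;
+    using (locks.AcquireShared(keyHash))
+        inserted = TryInsertIntoExistingIndex();
+    if (!inserted)
+        using (locks.AcquireAllExclusive(keyHash))
+            CreateIndexThenInsert();
+}
+```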
+
+## `VectorManager` Lock Around `ContextMetadata`
+
+Whenever we need to allocate a new context or mark an old one for cleanup, we need to modify the cached `ContextMetadata` and write the new value to Tsavorite. To simplify this, we take a plain `lock` on `VectorManager` while preparing a new `ContextMetadata`.
+
+The `RMW` into Tsavorite still proceeds in parallel, outside of the lock, but a version counter in `ContextMetadata` allows us to keep only the latest version in the store.
+
+> [!NOTE]
+> Rapid creation or deletion of Vector Sets is expected to perform poorly due to this lock.
+> This isn't a case we're very interested in right now, but if that changes this will need to be reworked.
+
+# Replication
+
+Replicating Vector Sets is tricky because of the unusual "writes are actually reads" semantics of most operations.
+
+## On Primaries
+
+As noted above, inserts (via `VADD`) and deletes (via `VREM`) are reads from Tsavorite's perspective. As a consequence, normal replication (which is triggered via `MainSessionFunctions.WriteLog(Delete|RMW|Upsert)`) does not happen on those operations.
+
+To fix that, synthetic writes against related keys are made after an insert or remove. These writes are against the same Vector Set key, but in namespace 0. See `VectorManager.ReplicateVectorSetAdd` and `VectorManager.ReplicateVectorSetRemove` for details.
+
+> [!IMPORTANT]
+> There is a failure case here where we crash between the insert operation completing and the replication operation completing.
+>
+> This appears to simply extend a window that already existed between when a Tsavorite operation completed and an entry was written to the AOF.
+> This needs to be confirmed - if it is not the case, handling this failure needs to be figured out.
+
+> [!IMPORTANT]
+> This code assumes a Vector Set under the empty string is illegal. That needs to be tested against Redis, and if it's not true we need to use
+> one of the other reserved namespaces.
+
+> [!NOTE]
+> These synthetic writes might appear to double write volume, but that is not the case. Actual inserts and deletes have extreme write amplification (that is, each causes DiskANN to perform many writes against the Main Store), whereas the synthetic writes cause a single (no-op) modification to the Main Store plus an AOF entry.
+
+> [!NOTE]
+> The replication key is the same for all operations against the same Vector Set; this could be sharded, which may improve performance.
+
+## On Replicas
+
+The synthetic writes on the primary are intercepted on replicas and redirected to `VectorManager.HandleVectorSetAddReplication` and `VectorManager.HandleVectorSetRemoveReplication`, rather than being handled directly by `AofProcessor`.
+
+For performance reasons, replicated `VADD`s are applied across many threads instead of serially. This introduces a new source of non-determinism, since `VADD`s will occur in a different order than on the primary, but this is acceptable as Vector Sets are inherently non-deterministic. While not _exactly_ the same, Redis also permits a degree of non-determinism with its `CAS` option for `VADD`, so we're not diverging an incredible amount here.
+
+While a `VADD` can proceed in parallel with respect to other `VADD`s, that is not the case for any other commands. Accordingly, `AofProcessor` now calls `VectorManager.WaitForVectorOperationsToComplete()` before applying any other updates to maintain coherency.
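+
+The replica-side ordering rule can be summarized with a small sketch (the dispatch shape is illustrative; only `WaitForVectorOperationsToComplete` is a real name from this design):
+
+```csharp
+// Sketch only: how replicated records are gated on a replica.
+void ApplyAofRecord(AofRecord record)
+{
+    if (record.IsSyntheticVectorSetOp)
+    {
+        // VADD/VREM replay may fan out across threads and reorder freely,
+        // because Vector Sets are inherently non-deterministic anyway.
+        vectorManager.QueueParallelReplay(record);
+        return;
+    }
+
+    // Everything else must observe all in-flight vector operations first.
+    vectorManager.WaitForVectorOperationsToComplete();
+    ApplySerially(record);
+}
+```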
+
+## Migration
+
+Migrating a Vector Set between two primaries (either as part of a `MIGRATE ... KEYS` or migration of a whole hash slot) is complicated by storing element data in namespaces.
+
+Namespaces (intentionally) do not participate in hash slots or clustering, and are a node-specific concept. This means that migration must also update the namespaces of elements as they are migrated.
+
+At a high level, migration between the originating primary and a destination primary behaves as follows:
+ 1. Once target slots transition to `MIGRATING`...
+    * An addition to `ClusterSession.SingleKeySlotVerify` causes all WRITE Vector Set commands to pause once a slot is `MIGRATING` or `IMPORTING` - this is necessary because we cannot block based on the key, as Vector Sets are composed of many key-value pairs across several namespaces
+ 2. `VectorManager` on the originating primary enumerates all _namespaces_ and Vector Sets that are covered by those slots
+ 3. The originating primary contacts the destination primary and reserves enough new Vector Set contexts to handle those found in step 2
+    * These Vector Sets are "in use" but also in a migrating state in `ContextMetadata`
+ 4. During the scan of the main store in `MigrateOperation`, any keys found with namespaces found in step 2 are migrated, but their namespace is updated prior to transmission to the appropriate new namespaces reserved in step 3
+    * Unlike with normal keys, we do not _delete_ the keys in namespaces as we enumerate them
+    * Also unlike with normal keys, we synthesize a write on the _destination_ (using a special arg and `VADD`) so replicas of the destination also get these writes
+ 5. Once all namespace keys are migrated, we migrate the Vector Set index keys, but mutate their values to have the appropriate context reserved in step 3
+    * As in 4, we synthesize a write on the _destination_ to tell any replicas to also create the index key
+ 6. When the target slots transition back to `STABLE`, we delete the Vector Set index keys, drop the DiskANN indexes, and schedule the original contexts for cleanup on the originating primary
+    * Unlike in 4 & 5, we do no synthetic writes here. The normal replication of `DEL` will clean up replicas of the originating primary.
+
+`KEYS` migrations differ only in the slot discovery being omitted. We still have to determine the migrating namespaces, reserve new ones on the destination primary, and schedule cleanup only once migration is completed. This does mean that, if any of the keys being migrated is a Vector Set, `MIGRATE ... KEYS` now causes a scan of the main store.
+
+> [!NOTE]
+> This approach prevents the Vector Set from being visible when it is partially migrated, which has the desirable property of not returning weird results during a migration.
+
+> [!NOTE]
+> While we explicitly reserve contexts on primaries, they are implicit on replicas. This is because a replica should always come up with the same determination of reserved contexts.
+>
+> To keep that determinism, the synthetic `VADD`s introduced by migration are not executed in parallel.
+
+# Cleanup
+
+Deleting a Vector Set only drops the DiskANN index and removes the top-level keys (i.e. the visible key and related hidden keys for replication). This leaves all elements, attributes, neighbor lists, etc. still in the Main Store.
+
+To clean up the remaining data, we record the deleted index's context value in `ContextMetadata` and then schedule a full sweep of the Main Store looking for any keys under namespaces related to that context. When we find those keys we delete them; see `VectorManager.RunCleanupTaskAsync()` and `VectorManager.PostDropCleanupFunctions` for details.
+
+> [!NOTE]
+> There isn't really an elegant way to avoid scanning the whole keyspace, which can take a while to free everything up.
+>
+> If we wanted to explore better options, we'd need to build something that can drop whole namespaces at once in Tsavorite.
+
+> [!IMPORTANT]
+> Today, because we only have ~30 available Vector Set contexts, it is quite likely that deleting a Vector Set and then immediately creating a new one will fail if you're near the limit.
+>
+> This will be fixed once we have arbitrarily long namespaces in Store V2, and have updated `ContextMetadata` to track those.
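+
+The sweep itself is conceptually simple; a sketch of its shape (assuming hypothetical scan/delete helpers - see `VectorManager.RunCleanupTaskAsync()` for the real implementation):
+
+```csharp
+// Sketch only: delete every key whose namespace belongs to a dropped context.
+// Each context reserves 8 consecutive namespaces (context + 0 .. context + 7).
+async Task SweepDroppedContextAsync(ulong droppedContext)
+{
+    await foreach (var key in ScanMainStoreAsync()) // full keyspace scan
+    {
+        var ns = GetNamespace(key);
+        if (ns >= droppedContext && ns < droppedContext + 8)
+            Delete(key);
+    }
+
+    ReleaseContext(droppedContext); // free the slot in ContextMetadata
+}
+```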
+
+# Recovery
+
+Vector Sets present a unique recovery problem because most operations are mediated through DiskANN, for which we only ever have a pointer to a data structure. This means that recovery needs to both deal with Vector Set metadata AND the recreation of the DiskANN side of things.
+
+## Vector Set Metadata
+
+During startup we read any old `ContextMetadata` out of the Main Store, cache it, and resume any in-progress cleanups.
+
+## Vector Sets
+
+While reading out [`Index`](#indexes) before performing a DiskANN function call, we check the stored `ProcessInstanceId` against the (randomly generated) one in our `VectorManager` instance. If they do not match, we know that the DiskANN `IndexPtr` is dangling and we need to recreate the index.
+
+To recreate, we acquire exclusive locks (in the same way we would for `VADD` or `DEL`) and invoke `create_index` again. From DiskANN's perspective, there's no difference between creating a new empty index and recreating an old one which has existing data.
+
+This means we recreate indexes lazily after recovery. Consequently the _first_ command (regardless of whether it's a `VADD`, a `VSIM`, or whatever) against an index after recovery will be slower, since it needs to do extra work, and will block other commands, since it needs exclusive locking.
+
+> [!NOTE]
+> Today `ProcessInstanceId` is a `Guid`, which means we're paying for a 16-byte comparison on every command.
+>
+> This comparison is highly predictable, but we could try to remove it entirely (with caching, as mentioned for `Index` above).
+> We could also make it cheaper by using a random `ulong` instead, but would need to do some math to convince ourselves collisions aren't possible in realistic scenarios.
+
+# DiskANN Integration
+
+Almost all of how Vector Sets actually function is handled by DiskANN. Garnet just embeds it, translates between RESP commands and DiskANN functions, and manages storage.
+
+In order for DiskANN to access and store data in Garnet, we provide a set of callbacks. All callbacks are `[UnmanagedCallersOnly]` and converted to function pointers before they are passed to DiskANN.
+
+All callbacks take a `ulong context` parameter which identifies the Vector Set involved (the high 61 bits of the context) and the associated namespace (the low 3 bits of the context). On the Garnet side, the whole `context` is effectively a namespace, but from DiskANN's perspective the top 61 bits are an opaque identifier.
+
+> [!IMPORTANT]
+> As noted elsewhere, we only have a byte's worth of namespaces today - so although `context` could handle quintillions of Vector Sets, today we're limited to just 31.
+>
+> This restriction will go away with Store V2, but we expect "lower" Vector Sets to outperform "higher" ones due to the need for intermediate data copies with longer namespaces.
+
+## Read Callback
+
+The most complicated of our callbacks, the signature is:
+```csharp
+void ReadCallbackUnmanaged(ulong context, uint numKeys, nint keysData, nuint keysLength, nint dataCallback, nint dataCallbackContext)
+```
+
+`context` identifies which Vector Set is being operated on AND the associated namespace, `numKeys` tells us how many keys have been encoded into `keysData`, `keysData` and `keysLength` define a `Span<byte>` of length-prefixed keys, `dataCallback` is a `delegate* unmanaged[Cdecl, SuppressGCTransition]` used to push found keys back into DiskANN, and `dataCallbackContext` is passed back unaltered to `dataCallback`.
+
+In the `Span<byte>` defined by `keysData` and `keysLength`, the keys are length-prefixed with a 4-byte little endian `int`. This is necessary to support variable length element ids, but also gives us some scratch space to store a namespace when we convert these to `SpanByte`s. This mangling is done as part of the `IReadArgBatch` implementation we use to read keys from Tsavorite.
+
+> [!NOTE]
+> Once variable sized namespaces are supported we'll have to handle the case where the namespace can't fit in 4 bytes. However, we expect that to be rare (4 bytes would give us ~53,000,000 Vector Sets) and the performance benefits of _not_ copying during querying are very large.
+
+As we find keys, we invoke `dataCallback(index, dataCallbackContext, keyPointer, keyLength)`. If a key is not found, its index is simply skipped. The benefit of this is that we don't copy data out of the Tsavorite log as part of reads; DiskANN is able to do distance calculations and traversal over in-place data.
+
+> [!NOTE]
+> Each invocation of `dataCallback` is a managed -> native transition, which can add up very quickly. We've reduced that as much as possible with function pointers and `SuppressGCTransition`, but that comes with risks.
+>
+> In particular, if DiskANN raises an error or blocks in the `dataCallback`, expect very bad things to happen, up to the runtime corrupting itself. Great care must be taken to keep the DiskANN side of this call cheap and reliable.
+
+> [!IMPORTANT]
+> Tsavorite has been extended with a `ContextReadWithPrefetch` method to accommodate this pattern, which also employs prefetching when we have batches of keys to look up.
+>
+> Additionally, some experimentation to figure out good prefetch sizes (and whether [AMAC](https://dl.acm.org/doi/10.14778/2856318.2856321) is useful) based on hardware is merited. Right now we've chosen 12 based on testing with some 96-core Intel machines, but that is unlikely to be correct in all interesting circumstances.
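+
+The length-prefixed layout is simple to walk; a sketch (assuming the 4-byte little-endian prefix described above, with lookup and error handling elided):
+
+```csharp
+// Sketch only: walking the key buffer handed to the read callback.
+var remaining = new ReadOnlySpan<byte>((void*)keysData, (int)keysLength);
+for (var i = 0u; i < numKeys; i++)
+{
+    var len = BinaryPrimitives.ReadInt32LittleEndian(remaining);
+    var key = remaining.Slice(sizeof(int), len);
+
+    // Look the key up (via ContextReadWithPrefetch) and, when found, hand the
+    // in-place value straight back: dataCallback(i, dataCallbackContext, ptr, len).
+    // The 4-byte prefix doubles as scratch space for the namespace.
+
+    remaining = remaining.Slice(sizeof(int) + len);
+}
+```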
+
+## Write Callback
+
+A simpler callback, the signature is:
+```csharp
+byte WriteCallbackUnmanaged(ulong context, nint keyData, nuint keyLength, nint writeData, nuint writeLength)
+```
+
+`context` identifies which Vector Set is being operated on AND the associated namespace, `keyData` and `keyLength` represent a `Span<byte>` of the key to write, and `writeData` and `writeLength` represent a `Span<byte>` of the value to write.
+
+DiskANN guarantees an extra 4 bytes BEFORE `keyData` that we can safely modify. This is used to avoid copying the key value when we add a namespace to the `SpanByte` before invoking Tsavorite's `Upsert`.
+
+This callback returns 1 if successful, and 0 otherwise.
+
+## Delete Callback
+
+Another simple callback, the signature is:
+```csharp
+byte DeleteCallbackUnmanaged(ulong context, nint keyData, nuint keyLength)
+```
+
+`context` identifies which Vector Set is being operated on AND the associated namespace, and `keyData` and `keyLength` represent a `Span<byte>` of the key to delete.
+
+As with the write callback, DiskANN guarantees an extra 4 bytes BEFORE `keyData` that we use to store a namespace, and thus avoid copying the key value before invoking Tsavorite's `Delete`.
+
+This callback returns 1 if the key was found and removed, and 0 otherwise.
+
+## Read Modify Write Callback
+
+A more complicated callback, the signature is:
+```csharp
+byte ReadModifyWriteCallbackUnmanaged(ulong context, nint keyData, nuint keyLength, nuint writeLength, nint dataCallback, nint dataCallbackContext)
+```
+
+`context` identifies which Vector Set is being operated on AND the associated namespace, and `keyData` and `keyLength` represent a `Span<byte>` of the key to create, read, or update.
+
+`writeLength` is the desired number of bytes; this is only used if we are creating a new key-value pair.
+
+As with the write and delete callbacks, DiskANN guarantees an extra 4 bytes BEFORE `keyData` that we use to store a namespace, and thus avoid copying the key value before invoking Tsavorite's `RMW`.
+
+After we allocate a new key-value pair or find an existing one, `dataCallback(nint dataCallbackContext, nint dataPointer, nuint dataLength)` is called. Changes made to data in this callback are persisted. This needs to be _fast_ to prevent gumming up Tsavorite, as we are under epoch protection.
+
+Newly allocated values are guaranteed to be all zeros.
+
+The callback returns 1 if the key-value pair was found or created, and 0 if some error occurred.
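+
+The "4 bytes before `keyData`" trick shared by these callbacks can be sketched as follows (the exact namespaced-`SpanByte` layout is an assumption here, not the actual encoding):
+
+```csharp
+// Sketch only: prepend a namespace to an unmanaged key without copying it,
+// using the scratch space DiskANN guarantees immediately before the key bytes.
+static unsafe SpanByte WithNamespace(ulong context, nint keyData, nuint keyLength)
+{
+    // Assumed layout: a 4-byte namespace immediately before the key bytes.
+    var start = (byte*)keyData - sizeof(int);
+    *(int*)start = (int)context; // the whole context fits in a byte today
+
+    return SpanByte.FromPinnedPointer(start, sizeof(int) + (int)keyLength);
+}
+```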
+ +## DiskANN Functions + +Garnet calls into the following DiskANN functions: + + - [x] `nint create_index(ulong context, uint dimensions, uint reduceDims, VectorQuantType quantType, uint buildExplorationFactor, uint numLinks, nint readCallback, nint writeCallback, nint deleteCallback, nint readModifyWriteCallback)` + - [x] `void drop_index(ulong context, nint index)` + - [x] `byte insert(ulong context, nint index, nint id_data, nuint id_len, VectorValueType vector_value_type, nint vector_data, nuint vector_len, nint attribute_data, nuint attribute_len)` + - [x] `byte remove(ulong context, nint index, nint id_data, nuint id_len)` + - [ ] `byte set_attribute(ulong context, nint index, nint id_data, nuint id_len, nint attribute_data, nuint attribute_len)` + - [x] `int search_vector(ulong context, nint index, VectorValueType vector_value_type, nint vector_data, nuint vector_len, float delta, int search_exploration_factor, nint filter_data, nuint filter_len, nuint max_filtering_effort, nint output_ids, nuint output_ids_len, nint output_distances, nuint output_distances_len, nint continuation)` + - [x] `int search_element(ulong context, nint index, nint id_data, nuint id_len, float delta, int search_exploration_factor, nint filter_data, nuint filter_len, nuint max_filtering_effort, nint output_ids, nuint output_ids_len, nint output_distances, nuint output_distances_len, nint continuation)` + - [ ] `int continue_search(ulong context, nint index, nint continuation, nint output_ids, nuint output_ids_len, nint output_distances, nuint output_distances_len, nint new_continuation)` + - [ ] `ulong card(ulong context, nint index)` + + Some non-obvious subtleties: + - The number of results _requested_ from `search_vector` and `search_element` is indicated by `output_distances_len` + - `output_distances_len` is the number of _floats_ in `output_distances`, not bytes + - When inserting, if `vector_value_type == FP32` then `vector_len` is the number of _floats_ in `vector_data`, otherwise it is the number of bytes + - `byte` returning functions are effectively returning booleans, `0 == false` and `1 == true` + - `index` is always a pointer created by DiskANN and returned from `create_index` + - `context` is always the `Context` value created by Garnet and stored in [`Index`](#indexes) for a Vector Set, this implies it is always a non-0 multiple of 8 + - `search_vector`, `search_element`, and `continue_search` all return the number of ids written into `output_ids`, and if there are more values to return they set the `nint` _pointed to by_ `continuation` or `new_continuation` + +> [!IMPORTANT] +> These p/invoke definitions are all a little rough and should be cleaned up. +> +> They were defined very loosely to ease getting the .NET <-> Rust interface working quickly. \ No newline at end of file