-
Notifications
You must be signed in to change notification settings - Fork 86
[feat] hma connector supports GPU buffer MR for GPUDirct RDMA #981
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
6962645
4ca6d9e
2bbdfa0
8d66c31
6737e8d
222620c
1f8aaa6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -53,6 +53,7 @@ def __init__(self, kvcaches: dict[str, torch.Tensor]) -> None: | |
| self.kvcaches = dict(sorted(kvcaches.items(), key=self._sort_key)) | ||
| self.base_ptrs: np.ndarray | ||
| self.block_strides: np.ndarray | ||
| self.buffer_sizes: np.ndarray | ||
| self.tensor_token_strides: np.ndarray | ||
| self.tensor_sizes_per_token: np.ndarray | ||
| self.tensor_block_sizes: np.ndarray | ||
|
|
@@ -68,6 +69,7 @@ def _build_layout(self) -> None: | |
|
|
||
| ptrs: list[int] = [] | ||
| strides: list[int] = [] | ||
| buffer_sizes: list[int] = [] | ||
| tensor_token_strides: list[int] = [] | ||
| tensor_sizes_per_token: list[int] = [] | ||
| tensor_block_sizes: list[int] = [] | ||
|
|
@@ -79,8 +81,12 @@ def handle_tensor( | |
| layer_name: str, | ||
| ) -> None: | ||
| ptrs.append(t[0].data_ptr()) | ||
| strides.append(t.stride(0) * t.element_size()) | ||
| block_stride = t.stride(0) * t.element_size() | ||
| strides.append(block_stride) | ||
| tensor_size = math.prod([t.shape[i] for i in size_dims]) * t.element_size() | ||
| # GPU buffer sizes for GPUDirect RDMA registration in store. | ||
| # Total buffer size = number of blocks (shape[0]) × bytes per block stride. | ||
| buffer_sizes.append(int(t.shape[0]) * block_stride) | ||
|
relat-ivity marked this conversation as resolved.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🔴 Critical: Potential integer overflow. When |
||
| token_dim = 1 | ||
| tensor_block_size = int(t.shape[token_dim]) | ||
| tensor_token_strides.append(t.stride(token_dim) * t.element_size()) | ||
|
|
@@ -137,6 +143,7 @@ def handle_kv_layer_tensor(tensor: torch.Tensor, layer_name: str) -> None: | |
|
|
||
| self.base_ptrs = np.asarray(ptrs, dtype=np.uint64) | ||
| self.block_strides = np.asarray(strides, dtype=np.uint64) | ||
| self.buffer_sizes = np.asarray(buffer_sizes, dtype=np.uint64) | ||
| self.tensor_token_strides = np.asarray(tensor_token_strides, dtype=np.uint64) | ||
| self.tensor_sizes_per_token = np.asarray( | ||
| tensor_sizes_per_token, dtype=np.uint64 | ||
|
|
@@ -155,6 +162,7 @@ def handle_kv_layer_tensor(tensor: torch.Tensor, layer_name: str) -> None: | |
| logger.info( | ||
| f"KV cache group layout: views={len(self.kvcaches)}, " | ||
| f"ptrs={len(ptrs)}, " | ||
| f"buffer_bytes={sum(int(size) for size in self.buffer_sizes)}, " | ||
| f"tensor_block_sizes={sorted(set(tensor_block_sizes))}" | ||
| ) | ||
|
|
||
|
|
@@ -478,17 +486,23 @@ def _create_fa_store( | |
| """Create the backing store used for full-attention rows.""" | ||
|
|
||
| tensor_size_list = None | ||
| gpu_kv_buffer_config = None | ||
| if self._role == KVConnectorRole.WORKER: | ||
| if group_layouts is None: | ||
| raise RuntimeError("Worker FA store needs layouts.") | ||
| tensor_size_list = self._store_tensor_size_list( | ||
| group_layouts, | ||
| self.fa_group_ids, | ||
| ) | ||
| gpu_kv_buffer_config = self._gpu_kv_buffer_config( | ||
| group_layouts, | ||
| self.fa_group_ids, | ||
| ) | ||
| return self._create_store( | ||
| "FA", | ||
| "fa", | ||
| tensor_size_list, | ||
| gpu_kv_buffer_config, | ||
| cpu_affinity_cores, | ||
| ) | ||
|
|
||
|
|
@@ -500,17 +514,23 @@ def _create_wa_store( | |
| """Create the backing store used for window-tail rows.""" | ||
|
|
||
| tensor_size_list = None | ||
| gpu_kv_buffer_config = None | ||
| if self._role == KVConnectorRole.WORKER: | ||
| if group_layouts is None: | ||
| raise RuntimeError("Worker WA store needs layouts.") | ||
| tensor_size_list = self._store_tensor_size_list( | ||
| group_layouts, | ||
| self.window_group_ids, | ||
| ) | ||
| gpu_kv_buffer_config = self._gpu_kv_buffer_config( | ||
| group_layouts, | ||
| self.window_group_ids, | ||
| ) | ||
| return self._create_store( | ||
| "WA", | ||
| "wa", | ||
| tensor_size_list, | ||
| gpu_kv_buffer_config, | ||
| cpu_affinity_cores, | ||
| ) | ||
|
|
||
|
|
@@ -569,6 +589,7 @@ def _create_store( | |
| label: str, | ||
| store_suffix: str, | ||
| tensor_size_list: Optional[list[int]], | ||
| gpu_kv_buffer_config: Optional[tuple[list[int], list[int]]] = None, | ||
| cpu_affinity_cores: Optional[list[int]] = None, | ||
| ) -> UcmKVStoreBaseV1: | ||
| """Instantiate one UCM store with worker tensor layout metadata.""" | ||
|
|
@@ -591,6 +612,21 @@ def _create_store( | |
| ) | ||
| # MLA stores aggregate TP shards under one logical rank group. | ||
| config["local_rank_size"] = self.tp_size if self.is_mla else 1 | ||
| if gpu_kv_buffer_config is not None: | ||
| gpu_kv_buffer_addrs, gpu_kv_buffer_sizes = gpu_kv_buffer_config | ||
| if not gpu_kv_buffer_addrs or not gpu_kv_buffer_sizes: | ||
| raise RuntimeError( | ||
| f"Worker FAWA {label} store needs non-empty GPU KV " | ||
| "buffer addresses and sizes." | ||
| ) | ||
| config["gpu_kv_buffer_addrs"] = gpu_kv_buffer_addrs | ||
|
relat-ivity marked this conversation as resolved.
|
||
| config["gpu_kv_buffer_sizes"] = gpu_kv_buffer_sizes | ||
| logger.debug( | ||
| f"register FAWA {label} GPU KV buffers: " | ||
|
relat-ivity marked this conversation as resolved.
|
||
| f"count={len(gpu_kv_buffer_addrs)}, " | ||
| f"bytes={sum(int(size) for size in gpu_kv_buffer_sizes)}, " | ||
| f"first_5={[(addr, size) for addr, size in zip(gpu_kv_buffer_addrs[:5], gpu_kv_buffer_sizes[:5])]}" | ||
| ) | ||
| if cpu_affinity_cores: | ||
| config["cpu_affinity_cores"] = list(cpu_affinity_cores) | ||
| else: | ||
|
|
@@ -611,6 +647,16 @@ def _summarize_store_config(config: dict[str, object]) -> dict[str, object]: | |
| tensor_sizes = [int(size) for size in tensor_size_list] | ||
| summary["tensor_count"] = len(tensor_sizes) | ||
| summary["tensor_bytes"] = sum(tensor_sizes) | ||
| gpu_kv_buffer_addrs = summary.pop("gpu_kv_buffer_addrs", None) | ||
|
relat-ivity marked this conversation as resolved.
|
||
| gpu_kv_buffer_sizes = summary.pop("gpu_kv_buffer_sizes", None) | ||
| assert (gpu_kv_buffer_addrs is None) == ( | ||
| gpu_kv_buffer_sizes is None | ||
| ), "GPU KV buffer addresses and sizes must be both None or both non-None" | ||
| if gpu_kv_buffer_addrs is not None: | ||
|
relat-ivity marked this conversation as resolved.
|
||
| summary["gpu_kv_buffer_count"] = len(gpu_kv_buffer_addrs) | ||
| summary["gpu_kv_buffer_bytes"] = sum( | ||
| int(size) for size in gpu_kv_buffer_sizes | ||
| ) | ||
| return summary | ||
|
|
||
| def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): | ||
|
|
@@ -679,6 +725,36 @@ def _store_tensor_size_list( | |
| raise RuntimeError(f"Worker FAWA {group_label} layout is empty.") | ||
| return tensor_size_list | ||
|
|
||
| @staticmethod | ||
| def _gpu_kv_buffer_config( | ||
| group_layouts: dict[int, KVCacheGroupLayout], | ||
| group_ids: tuple[int, ...], | ||
| ) -> tuple[list[int], list[int]]: | ||
| gpu_kv_buffer_set: set[tuple[int, int]] = set() | ||
| gpu_kv_buffer_addrs: list[int] = [] | ||
| gpu_kv_buffer_sizes: list[int] = [] | ||
|
relat-ivity marked this conversation as resolved.
|
||
| for group_id in group_ids: | ||
| layout = group_layouts.get(group_id) | ||
| if layout is None: | ||
|
relat-ivity marked this conversation as resolved.
|
||
| logger.warning( | ||
| f"Skip GPU KV buffer registration for group_id={group_id}: " | ||
| "no KV cache layout was registered." | ||
| ) | ||
| continue | ||
| buffer_addrs = layout.base_ptrs.reshape(-1).tolist() | ||
| buffer_sizes = layout.buffer_sizes.reshape(-1).tolist() | ||
| assert len(buffer_addrs) == len( | ||
| buffer_sizes | ||
| ), "KV cache buffer addresses and sizes must have the same length." | ||
| for addr, size in zip(buffer_addrs, buffer_sizes): | ||
| key = (addr, size) | ||
| if key in gpu_kv_buffer_set: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 💡 Suggestion: For better readability, use gpu_kv_buffer_set.add((addr, size))
gpu_kv_buffer_addrs.append(addr)
gpu_kv_buffer_sizes.append(size) |
||
| continue | ||
| gpu_kv_buffer_set.add(key) | ||
| gpu_kv_buffer_addrs.append(key[0]) | ||
| gpu_kv_buffer_sizes.append(key[1]) | ||
| return gpu_kv_buffer_addrs, gpu_kv_buffer_sizes | ||
|
relat-ivity marked this conversation as resolved.
|
||
|
|
||
| def _lookup_external_hit_blocks(self, external_keys: list[bytes]) -> int: | ||
| """Find the longest reusable prefix present in both FA and WA stores.""" | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.