diff --git a/loadgen/mlperf.conf b/loadgen/mlperf.conf index 1b825514bd..bb28759996 100644 --- a/loadgen/mlperf.conf +++ b/loadgen/mlperf.conf @@ -77,6 +77,7 @@ retinanet.Server.target_latency = 100 bert.Server.target_latency = 130 dlrm.Server.target_latency = 60 dlrm-v2.Server.target_latency = 60 +dlrm-v3.Server.target_latency = 80 rnnt.Server.target_latency = 1000 gptj.Server.target_latency = 20000 stable-diffusion-xl.Server.target_latency = 20000 diff --git a/recommendation/dlrm_v3/README.md b/recommendation/dlrm_v3/README.md new file mode 100644 index 0000000000..ea3e07c588 --- /dev/null +++ b/recommendation/dlrm_v3/README.md @@ -0,0 +1,90 @@ +# MLPerf Inference reference implementation for DLRMv3 + +## Install dependencies and build loadgen + +The reference implementation has been tested on a single host, with x86_64 CPUs and 8 NVIDIA H100/B200 GPUs. Dependencies can be installed below, +``` +sh setup.sh +``` + +## Dataset download + +DLRMv3 uses a synthetic dataset specifically designed to match the model and system characteristics of large-scale sequential recommendation (large item set and long average sequence length for each request). To generate the dataset used for both training and inference, run +``` +python streaming_synthetic_data.py +``` +The generated dataset has 2TB size, and contains 5 million users interacting with a billion items over 100 timestamps. + +Only 1% of the dataset is used in the inference benchmark. The sampled DLRMv3 dataset and trained checkpoint are available at https://inference.mlcommons-storage.org/. + +Script to download the sampled dataset used in inference benchmark: +``` +bash <(curl -s https://raw.githubusercontent.com/mlcommons/r2-downloader/refs/heads/main/mlc-r2-downloader.sh) https://inference.mlcommons-storage.org/metadata/dlrm-v3-dataset.uri +``` +Script to download the 1TB trained checkpoint: +``` +bash <(curl -s https://raw.githubusercontent.com/mlcommons/r2-downloader/refs/heads/main/mlc-r2-downloader.sh) https://inference.mlcommons-storage.org/metadata/dlrm-v3-checkpoint.uri +``` + +## Inference benchmark + +``` +WORLD_SIZE=8 python main.py --dataset sampled-streaming-100b +``` + +`WORLD_SIZE` is the number of GPUs used in the inference benchmark. + +``` +usage: main.py [-h] [--dataset {streaming-100b,sampled-streaming-100b}] [--model-path MODEL_PATH] [--scenario-name {Server,Offline}] [--batchsize BATCHSIZE] + [--output-trace OUTPUT_TRACE] [--data-producer-threads DATA_PRODUCER_THREADS] [--compute-eval COMPUTE_EVAL] [--find-peak-performance FIND_PEAK_PERFORMANCE] + [--dataset-path-prefix DATASET_PATH_PREFIX] [--warmup-ratio WARMUP_RATIO] [--num-queries NUM_QUERIES] [--target-qps TARGET_QPS] [--numpy-rand-seed NUMPY_RAND_SEED] + [--sparse-quant SPARSE_QUANT] [--dataset-percentage DATASET_PERCENTAGE] + +options: + -h, --help show this help message and exit + --dataset {streaming-100b,sampled-streaming-100b} + name of the dataset + --model-path MODEL_PATH + path to the model checkpoint. 
Example: /home/username/ckpts/streaming_100b/89/ + --scenario-name {Server,Offline} + inference benchmark scenario + --batchsize BATCHSIZE + batch size used in the benchmark + --output-trace OUTPUT_TRACE + Whether to output trace + --data-producer-threads DATA_PRODUCER_THREADS + Number of threads used in data producer + --compute-eval COMPUTE_EVAL + If true, will run AccuracyOnly mode and outputs both predictions and labels for accuracy calcuations + --find-peak-performance FIND_PEAK_PERFORMANCE + Whether to find peak performance in the benchmark + --dataset-path-prefix DATASET_PATH_PREFIX + Prefix to the dataset path. Example: /home/username/ + --warmup-ratio WARMUP_RATIO + The ratio of the dataset used to warmup SUT + --num-queries NUM_QUERIES + Number of queries to run in the benchmark + --target-qps TARGET_QPS + Benchmark target QPS. Needs to be tuned for different implementations to balance latency and throughput + --numpy-rand-seed NUMPY_RAND_SEED + Numpy random seed + --sparse-quant SPARSE_QUANT + Whether to quantize sparse arch + --dataset-percentage DATASET_PERCENTAGE + Percentage of the dataset to run in the benchmark +``` + +## Accuracy test + +Set `run.compute_eval` will run the accuracy test and dump prediction outputs in +`mlperf_log_accuracy.json`. To check the accuracy, run + +``` +python accuracy.py --path path/to/mlperf_log_accuracy.json +``` +We use normalized entropy (NE), accuracy, and AUC as the metrics to evaluate the model quality. For accepted submissions, all three metrics (NE, Accuracy, AUC) must be within 99% of the reference implementation values. The accuracy for the reference implementation evaluated on 34,996 requests across 10 inference timestamps are listed below: +``` +NE: 86.687% +Accuracy: 69.651% +AUC: 78.663% +``` diff --git a/recommendation/dlrm_v3/accuracy.py b/recommendation/dlrm_v3/accuracy.py new file mode 100644 index 0000000000..5d2d0ff11a --- /dev/null +++ b/recommendation/dlrm_v3/accuracy.py @@ -0,0 +1,86 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-strict +""" +Tool to calculate accuracy for loadgen accuracy output found in mlperf_log_accuracy.json +""" + +import argparse +import json +import logging + +import numpy as np +import torch +from configs import get_hstu_configs +from utils import MetricsLogger + +logger: logging.Logger = logging.getLogger("main") + + +def get_args() -> argparse.Namespace: + """Parse commandline.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "--path", + required=True, + help="path to mlperf_log_accuracy.json", + ) + args = parser.parse_args() + return args + + +def main() -> None: + """ + Main function to calculate accuracy metrics from loadgen output. + + Reads the mlperf_log_accuracy.json file, parses the results, and computes + accuracy metrics using the MetricsLogger. Each result entry contains + predictions, labels, and weights packed as float32 numpy arrays. 
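+
+    The hex-encoded "data" field of each entry decodes to a flat float32 array
+    laid out as [predictions (n), labels (n), weights (n), num_candidates (1)],
+    where n is the number of candidates for that request.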
+ """ + args = get_args() + logger.warning("Parsing loadgen accuracy log...") + with open(args.path, "r") as f: + results = json.load(f) + hstu_config = get_hstu_configs(dataset="sampled-streaming-100b") + metrics = MetricsLogger( + multitask_configs=hstu_config.multitask_configs, + batch_size=1, + window_size=3000, + device=torch.device("cpu"), + rank=0, + ) + logger.warning(f"results have {len(results)} entries") + for result in results: + data = np.frombuffer(bytes.fromhex(result["data"]), np.float32) + num_candidates = data[-1].astype(int) + assert len(data) == 1 + num_candidates * 3 + mt_target_preds = torch.from_numpy(data[0:num_candidates]) + mt_target_labels = torch.from_numpy(data[num_candidates : num_candidates * 2]) + mt_target_weights = torch.from_numpy( + data[num_candidates * 2 : num_candidates * 3] + ) + num_candidates = torch.tensor([num_candidates]) + metrics.update( + predictions=mt_target_preds.view(1, -1), + labels=mt_target_labels.view(1, -1), + weights=mt_target_weights.view(1, -1), + num_candidates=num_candidates, + ) + for k, v in metrics.compute().items(): + logger.warning(f"{k}: {v}") + + +if __name__ == "__main__": + main() diff --git a/recommendation/dlrm_v3/checkpoint.py b/recommendation/dlrm_v3/checkpoint.py new file mode 100644 index 0000000000..33dbaf3c58 --- /dev/null +++ b/recommendation/dlrm_v3/checkpoint.py @@ -0,0 +1,259 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-strict +""" +Checkpoint utilities for saving and loading DLRMv3 model checkpoints. + +This module provides functions for saving and loading distributed model checkpoints, +including both sparse (embedding) and dense (non-embedding) components. +""" + +import gc +import os +from datetime import datetime +from typing import Any, Dict, Optional, Set + +import gin + +import torch +from utils import MetricsLogger +from torch.distributed.checkpoint.stateful import Stateful +from torch.optim.optimizer import Optimizer +from torchrec.distributed.types import ShardedTensor + + +class SparseState(Stateful): + """ + Stateful wrapper for sparse (embedding) tensors in a model. + + This class implements the Stateful interface for distributed checkpointing, + allowing sparse tensors to be saved and loaded separately from dense tensors. + + Args: + model: The PyTorch model containing sparse tensors. + sparse_tensor_keys: Set of keys identifying sparse tensors in the model's state dict. 
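+
+    Note:
+        state_dict() returns only the entries of the model state dict whose keys
+        are in sparse_tensor_keys, and load_state_dict() loads them back with
+        strict=False, asserting that no unexpected keys are reported.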
+ """ + + def __init__(self, model: torch.nn.Module, sparse_tensor_keys: Set[str]) -> None: + self.model = model + self.sparse_tensor_keys = sparse_tensor_keys + + def state_dict(self) -> Dict[str, torch.Tensor]: + out_dict: Dict[str, torch.Tensor] = {} + is_sharded_tensor: Optional[bool] = None + for k, v in self.model.state_dict().items(): + if k in self.sparse_tensor_keys: + if is_sharded_tensor is None: + is_sharded_tensor = isinstance(v, ShardedTensor) + assert is_sharded_tensor == isinstance(v, ShardedTensor) + out_dict[k] = v + return out_dict + + def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None: + incompatible_keys = self.model.load_state_dict(state_dict, strict=False) + assert not incompatible_keys.unexpected_keys + + +def is_sparse_key(k: str, v: torch.Tensor) -> bool: + return isinstance(v, ShardedTensor) or "embedding_collection" in k + + +def load_dense_state_dict(model: torch.nn.Module, state_dict: Dict[str, Any]) -> None: + own_state = model.state_dict() + own_state_dense_keys = {k for k, v in own_state.items() if not is_sparse_key(k, v)} + state_dict_dense_keys = { + k for k, v in state_dict.items() if not is_sparse_key(k, v) + } + assert ( + own_state_dense_keys == state_dict_dense_keys + ), f"expects {own_state_dense_keys} but gets {state_dict_dense_keys}" + for name in state_dict_dense_keys: + param = state_dict[name] + if isinstance(param, torch.nn.Parameter): + # backwards compatibility for serialized parameters + param = param.data + own_state[name].copy_(param) + + +@gin.configurable +def save_dmp_checkpoint( + model: torch.nn.Module, + optimizer: Optimizer, + metric_logger: MetricsLogger, + rank: int, + batch_idx: int, + path: str = "", +) -> None: + """ + Save a distributed model checkpoint including sparse and dense components. + + Saves the model's sparse tensors using distributed checkpointing and dense + tensors, optimizer state, and metrics using standard PyTorch serialization. + + Args: + model: The model to checkpoint. + optimizer: The optimizer whose state should be saved. + metric_logger: The metrics logger containing training/eval metrics. + rank: The current process rank in distributed training. + batch_idx: The current batch index (used for checkpoint naming). + path: Base path for saving the checkpoint. If empty, no checkpoint is saved. 
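+
+    On-disk layout: rank 0 writes dense parameters, optimizer state, and metric
+    state to {path}/{batch_idx}/non_sparse.ckpt via torch.save, while all ranks
+    collectively write the sharded sparse tensors to {path}/{batch_idx}/sparse/
+    via torch.distributed.checkpoint.save.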
+ """ + if path == "": + return + now = datetime.now() + formatted_datetime = now.strftime("%Y_%m_%d_%H_%M_%S") + path = f"{path}/{batch_idx}" + if not os.path.exists(path) and rank == 0: + os.makedirs(path) + sparse_path = f"{path}/sparse/" + if not os.path.exists(sparse_path) and rank == 0: + os.makedirs(sparse_path) + non_sparse_ckpt = f"{path}/non_sparse.ckpt" + + sparse_tensor_keys = { + k for k, v in model.state_dict().items() if isinstance(v, ShardedTensor) + } + if rank == 0: + dense_state_dict = { + k: v + for k, v in model.state_dict().items() + if not isinstance(v, ShardedTensor) + } + class_metric_state_dict = { + "train": [m.state_dict() for m in metric_logger.class_metrics["train"]], + "eval": [m.state_dict() for m in metric_logger.class_metrics["eval"]], + } + regression_metric_state_dict = { + "train": [ + m.state_dict() for m in metric_logger.regression_metrics["train"] + ], + "eval": [m.state_dict() for m in metric_logger.regression_metrics["eval"]], + } + torch.save( + { + "dense_dict": dense_state_dict, + "optimizer_dict": optimizer.state_dict(), + "class_metrics": class_metric_state_dict, + "reg_metrics": regression_metric_state_dict, + "global_step": metric_logger.global_step, + "sparse_tensor_keys": sparse_tensor_keys, + }, + non_sparse_ckpt, + ) + torch.distributed.barrier() + sparse_dict = {"sparse_dict": SparseState(model, sparse_tensor_keys)} + torch.distributed.checkpoint.save( + sparse_dict, + storage_writer=torch.distributed.checkpoint.FileSystemWriter(sparse_path), + ) + torch.distributed.barrier() + print("checkpoint successfully saved") + + +@gin.configurable +def load_sparse_checkpoint( + model: torch.nn.Module, + path: str = "", +) -> None: + if path == "": + return + sparse_path = f"{path}/sparse/" + + sparse_tensor_keys = { + k for k, v in model.state_dict().items() if is_sparse_key(k, v) + } + sparse_dict = {"sparse_dict": SparseState(model, sparse_tensor_keys)} + gc.collect() + torch.distributed.checkpoint.load( + sparse_dict, + storage_reader=torch.distributed.checkpoint.FileSystemReader(sparse_path), + ) + gc.collect() + print("sparse checkpoint successfully loaded") + + +@gin.configurable +def load_nonsparse_checkpoint( + model: torch.nn.Module, + device: torch.device, + optimizer: Optional[Optimizer] = None, + metric_logger: Optional[MetricsLogger] = None, + path: str = "", +) -> None: + """ + Load non-sparse (dense) components from a checkpoint. + + Loads dense model parameters, and optionally optimizer state and metrics. + + Args: + model: The model to load dense parameters into. + device: The device to load tensors onto. + optimizer: Optional optimizer to restore state for. + metric_logger: Optional metrics logger to restore state for. + path: Base path of the checkpoint. If empty, no loading is performed. 
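+
+    Expects {path}/non_sparse.ckpt as written by save_dmp_checkpoint, i.e. path
+    should point at the per-batch checkpoint directory ({save_path}/{batch_idx}).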
+ """ + if path == "": + return + non_sparse_ckpt = f"{path}/non_sparse.ckpt" + + non_sparse_state_dict = torch.load(non_sparse_ckpt, map_location=device) + load_dense_state_dict(model, non_sparse_state_dict["dense_dict"]) + print("dense checkpoint successfully loaded") + if optimizer is not None: + optimizer.load_state_dict(non_sparse_state_dict["optimizer_dict"]) + print("optimizer checkpoint successfully loaded") + if metric_logger is not None: + metric_logger.global_step = non_sparse_state_dict["global_step"] + class_metric_state_dict = non_sparse_state_dict["class_metrics"] + regression_metric_state_dict = non_sparse_state_dict["reg_metrics"] + for i, m in enumerate(metric_logger.class_metrics["train"]): + m.load_state_dict(class_metric_state_dict["train"][i]) + for i, m in enumerate(metric_logger.class_metrics["eval"]): + m.load_state_dict(class_metric_state_dict["eval"][i]) + for i, m in enumerate(metric_logger.regression_metrics["train"]): + m.load_state_dict(regression_metric_state_dict["train"][i]) + for i, m in enumerate(metric_logger.regression_metrics["eval"]): + m.load_state_dict(regression_metric_state_dict["eval"][i]) + + +@gin.configurable +def load_dmp_checkpoint( + model: torch.nn.Module, + optimizer: Optimizer, + metric_logger: MetricsLogger, + device: torch.device, + path: str = "", +) -> None: + """ + Load a complete distributed model checkpoint (both sparse and dense components). + + This is a convenience function that calls both load_sparse_checkpoint and + load_nonsparse_checkpoint. + + Args: + model: The model to load the checkpoint into. + optimizer: The optimizer to restore state for. + metric_logger: The metrics logger to restore state for. + device: The device to load tensors onto. + path: Base path of the checkpoint. If empty, no loading is performed. + """ + load_sparse_checkpoint(model=model, path=path) + load_nonsparse_checkpoint( + model=model, + optimizer=optimizer, + metric_logger=metric_logger, + path=path, + device=device, + ) diff --git a/recommendation/dlrm_v3/configs.py b/recommendation/dlrm_v3/configs.py new file mode 100644 index 0000000000..3d053b6512 --- /dev/null +++ b/recommendation/dlrm_v3/configs.py @@ -0,0 +1,154 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-strict +""" +Configuration module for DLRMv3 model. + +This module provides configuration functions for the HSTU model architecture and embedding table configurations. +""" +from typing import Dict + +from generative_recommenders.modules.dlrm_hstu import DlrmHSTUConfig +from generative_recommenders.modules.multitask_module import ( + MultitaskTaskType, + TaskConfig, +) +from torchrec.modules.embedding_configs import DataType, EmbeddingConfig + +HSTU_EMBEDDING_DIM = 512 # final DLRMv3 model +HASH_SIZE = 1_000_000_000 + + +def get_hstu_configs(dataset: str = "debug") -> DlrmHSTUConfig: + """ + Create and return HSTU model configuration. 
+ + Builds a complete DlrmHSTUConfig with default hyperparameters for the HSTU + architecture including attention settings, embedding dimensions, dropout rates, + and feature name mappings. + + Args: + dataset: Dataset identifier (currently unused, reserved for dataset-specific configs). + + Returns: + DlrmHSTUConfig: Complete configuration object for the HSTU model. + """ + hstu_config = DlrmHSTUConfig( + hstu_num_heads=4, + hstu_attn_linear_dim=128, + hstu_attn_qk_dim=128, + hstu_attn_num_layers=5, + hstu_embedding_table_dim=HSTU_EMBEDDING_DIM, + hstu_preprocessor_hidden_dim=256, + hstu_transducer_embedding_dim=512, + hstu_group_norm=False, + hstu_input_dropout_ratio=0.2, + hstu_linear_dropout_rate=0.1, + causal_multitask_weights=0.2, + ) + hstu_config.user_embedding_feature_names = [ + "item_id", + "user_id", + "item_category_id", + ] + hstu_config.item_embedding_feature_names = [ + "item_candidate_id", + "item_candidate_category_id", + ] + hstu_config.uih_post_id_feature_name = "item_id" + hstu_config.uih_action_time_feature_name = "action_timestamp" + hstu_config.candidates_querytime_feature_name = "item_query_time" + hstu_config.candidates_weight_feature_name = "item_action_weights" + hstu_config.uih_weight_feature_name = "item_weights" + hstu_config.candidates_watchtime_feature_name = "item_rating" + hstu_config.action_weights = [1, 2, 4, 8, 16] + hstu_config.action_embedding_init_std = 5.0 + hstu_config.contextual_feature_to_max_length = {"user_id": 1} + hstu_config.contextual_feature_to_min_uih_length = {"user_id": 20} + hstu_config.merge_uih_candidate_feature_mapping = [ + ("item_id", "item_candidate_id"), + ("item_rating", "item_candidate_rating"), + ("action_timestamp", "item_query_time"), + ("item_weights", "item_action_weights"), + ("dummy_watch_time", "item_dummy_watchtime"), + ("item_category_id", "item_candidate_category_id"), + ] + hstu_config.hstu_uih_feature_names = [ + "user_id", + "item_id", + "item_rating", + "action_timestamp", + "item_weights", + "dummy_watch_time", + "item_category_id", + ] + hstu_config.hstu_candidate_feature_names = [ + "item_candidate_id", + "item_candidate_rating", + "item_query_time", + "item_action_weights", + "item_dummy_watchtime", + "item_candidate_category_id", + ] + hstu_config.max_num_candidates = 32 + hstu_config.max_num_candidates_inference = 2048 + hstu_config.multitask_configs = [ + TaskConfig( + task_name="rating", + task_weight=1, + task_type=MultitaskTaskType.BINARY_CLASSIFICATION, + ) + ] + return hstu_config + + +def get_embedding_table_config(dataset: str = "debug") -> Dict[str, EmbeddingConfig]: + """ + Create and return embedding table configurations. + + Defines the embedding table configurations for item IDs, category IDs, and user IDs + with their respective dimensions and data types. + + Args: + dataset: Dataset identifier (currently unused, reserved for dataset-specific configs). + + Returns: + Dict mapping table names to their EmbeddingConfig objects. 
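+
+    The returned tables are item_id (1B rows, shared by the item_id and
+    item_candidate_id features), item_category_id (128 rows), and user_id
+    (10M rows); all use embedding dimension 512 with FP16 weights.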
+ """ + return { + "item_id": EmbeddingConfig( + num_embeddings=HASH_SIZE, + embedding_dim=HSTU_EMBEDDING_DIM, + name="item_id", + data_type=DataType.FP16, + feature_names=["item_id", "item_candidate_id"], + ), + "item_category_id": EmbeddingConfig( + num_embeddings=128, + embedding_dim=HSTU_EMBEDDING_DIM, + name="item_category_id", + data_type=DataType.FP16, + weight_init_max=1.0, + weight_init_min=-1.0, + feature_names=["item_category_id", "item_candidate_category_id"], + ), + "user_id": EmbeddingConfig( + num_embeddings=10_000_000, + embedding_dim=HSTU_EMBEDDING_DIM, + name="user_id", + data_type=DataType.FP16, + feature_names=["user_id"], + ), + } diff --git a/recommendation/dlrm_v3/data_producer.py b/recommendation/dlrm_v3/data_producer.py new file mode 100644 index 0000000000..a2b8e18e09 --- /dev/null +++ b/recommendation/dlrm_v3/data_producer.py @@ -0,0 +1,227 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-strict +""" +Data producer module for DLRMv3 inference. + +This module provides classes for producing and managing query data during inference, +supporting both single-threaded and multi-threaded data production modes. +""" + +import logging +import threading +import time +from queue import Queue +from typing import List, Optional, Tuple, Union + +import torch +from datasets.dataset import Dataset, Samples + +logging.basicConfig(level=logging.INFO) +logger: logging.Logger = logging.getLogger("data_producer") + + +class QueryItem: + """ + Container for a query item to be processed by the inference thread pool. + + Attributes: + query_ids: List of unique identifiers for the queries in this batch. + samples: The sample data containing features for the queries. + start: Time when the query was first received. + dt_queue: Time spent in the queue before processing. + dt_batching: Time spent on batching the data. + """ + + def __init__( + self, + query_ids: List[int], + samples: Samples, + start: float, + dt_queue: float, + dt_batching: float, + ) -> None: + self.query_ids = query_ids + self.samples = samples + self.start: float = start + self.dt_queue: float = dt_queue + self.dt_batching: float = dt_batching + + +class SingleThreadDataProducer: + """ + Single-threaded data producer for synchronous query processing. + + This producer processes queries on the main thread without any parallelism, + suitable for debugging or low-throughput scenarios. + + Args: + ds: The dataset to fetch samples from. + run_one_item: Callback function to process a single QueryItem. + """ + + def __init__(self, ds: Dataset, run_one_item) -> None: # pyre-ignore [2] + self.ds = ds + self.run_one_item = run_one_item # pyre-ignore [4] + + def enqueue( + self, query_ids: List[int], content_ids: List[int], t0: float, dt_queue: float + ) -> None: + """ + Enqueue queries for immediate synchronous processing. + + Args: + query_ids: List of unique query identifiers. + content_ids: List of content/sample identifiers to fetch. 
+ t0: Timestamp when the query batch was created. + dt_queue: Time spent waiting in the queue. + """ + with torch.profiler.record_function("data batching"): + t0_batching: float = time.time() + samples: Union[Samples, List[Samples]] = self.ds.get_samples(content_ids) + dt_batching: float = time.time() - t0_batching + if isinstance(samples, Samples): + query = QueryItem( + query_ids=query_ids, + samples=samples, + start=t0, + dt_queue=dt_queue, + dt_batching=dt_batching, + ) + self.run_one_item(query) + else: + start_idx = 0 + for sample in samples: + batch_size: int = sample.batch_size() + query = QueryItem( + query_ids=query_ids[start_idx : start_idx + batch_size], + samples=sample, + start=t0, + dt_queue=dt_queue, + dt_batching=dt_batching, + ) + start_idx += batch_size + self.run_one_item(query) + + def finish(self) -> None: + """Finalize the producer. No-op for single-threaded mode.""" + pass + + +class MultiThreadDataProducer: + """ + Multi-threaded data producer for parallel query processing. + + Uses a thread pool to fetch and batch data in parallel with model inference, + improving throughput for high-load scenarios. + + Args: + ds: The dataset to fetch samples from. + threads: Number of worker threads to use. + run_one_item: Callback function to process a single QueryItem. + """ + + def __init__( + self, + ds: Dataset, + threads: int, + run_one_item, # pyre-ignore [2] + ) -> None: + queue_size_multiplier = 4 + self.ds = ds + self.threads = threads + self.run_one_item = run_one_item # pyre-ignore [4] + self.tasks: Queue[Optional[Tuple[List[int], List[int], float, float]]] = Queue( + maxsize=threads * queue_size_multiplier + ) + self.workers: List[threading.Thread] = [] + for _ in range(self.threads): + worker = threading.Thread(target=self.handle_tasks, args=(self.tasks,)) + worker.daemon = True + self.workers.append(worker) + worker.start() + + def handle_tasks( + self, tasks_queue: Queue[Optional[Tuple[List[int], List[int], float, float]]] + ) -> None: + """ + Worker thread main loop to process tasks from the queue. + + Each worker maintains its own CUDA stream for parallel execution. + + Args: + tasks_queue: Queue containing task tuples or None for termination. + """ + stream = torch.cuda.Stream() + while True: + query_and_content_ids = tasks_queue.get() + if query_and_content_ids is None: + tasks_queue.task_done() + break + query_ids, content_ids, t0, dt_queue = query_and_content_ids + t0_batching: float = time.time() + samples: Union[Samples, List[Samples]] = self.ds.get_samples(content_ids) + dt_batching: float = time.time() - t0_batching + if isinstance(samples, Samples): + qitem = QueryItem( + query_ids=query_ids, + samples=samples, + start=t0, + dt_queue=dt_queue, + dt_batching=dt_batching, + ) + with torch.inference_mode(), torch.cuda.stream(stream): + self.run_one_item(qitem) + else: + start_idx = 0 + for sample in samples: + batch_size: int = sample.batch_size() + qitem = QueryItem( + query_ids=query_ids[start_idx : start_idx + batch_size], + samples=sample, + start=t0, + dt_queue=dt_queue, + dt_batching=dt_batching, + ) + start_idx += batch_size + with torch.inference_mode(), torch.cuda.stream(stream): + self.run_one_item(qitem) + tasks_queue.task_done() + + def enqueue( + self, query_ids: List[int], content_ids: List[int], t0: float, dt_queue: float + ) -> None: + """ + Enqueue queries for asynchronous processing by worker threads. + + Args: + query_ids: List of unique query identifiers. + content_ids: List of content/sample identifiers to fetch. 
+ t0: Timestamp when the query batch was created. + dt_queue: Time spent waiting in the queue. + """ + with torch.profiler.record_function("data batching"): + self.tasks.put((query_ids, content_ids, t0, dt_queue)) + + def finish(self) -> None: + """ + Signal all worker threads to terminate and wait for completion. + + Sends None to each worker to trigger graceful shutdown. + """ + for _ in self.workers: + self.tasks.put(None) + for worker in self.workers: + worker.join() diff --git a/recommendation/dlrm_v3/datasets/dataset.py b/recommendation/dlrm_v3/datasets/dataset.py new file mode 100644 index 0000000000..495c5836c1 --- /dev/null +++ b/recommendation/dlrm_v3/datasets/dataset.py @@ -0,0 +1,399 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-unsafe +""" +Dataset implementations for DLRMv3. + +This module provides dataset classes for loading and processing recommendation +data, including sample containers, collation functions, and random data generation. +""" + +import logging +import time +from dataclasses import dataclass +from typing import Dict, List, Tuple + +import torch +from generative_recommenders.modules.dlrm_hstu import DlrmHSTUConfig + +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +logging.basicConfig(level=logging.INFO) +logger: logging.Logger = logging.getLogger("dlrmv3_dataset") + + +@dataclass +class Samples: + """ + Container for batched samples with user interaction history and candidate features. + + Attributes: + uih_features_kjt: User interaction history features as KeyedJaggedTensor. + candidates_features_kjt: Candidate item features as KeyedJaggedTensor. + """ + + uih_features_kjt: KeyedJaggedTensor + candidates_features_kjt: KeyedJaggedTensor + + def to(self, device: torch.device) -> None: + """ + Move all tensors to the specified device. + + Args: + device: Target device to move tensors to. + """ + for attr in vars(self): + setattr(self, attr, getattr(self, attr).to(device=device)) + + def batch_size(self) -> int: + """ + Get the batch size of the samples. + + Returns: + Number of samples in the batch. + """ + return self.uih_features_kjt.stride() + + +def collate_fn( + samples: List[Tuple[KeyedJaggedTensor, KeyedJaggedTensor]], +) -> Samples: + """ + Collate multiple samples into a batched Samples object. + + Args: + samples: List of (uih_features, candidates_features) tuples. + + Returns: + Batched Samples object with concatenated features. + """ + ( + uih_features_kjt_list, + candidates_features_kjt_list, + ) = list(zip(*samples)) + + return Samples( + uih_features_kjt=kjt_batch_func(uih_features_kjt_list), + candidates_features_kjt=kjt_batch_func(candidates_features_kjt_list), + ) + + +class Dataset: + """ + Base dataset class for DLRMv3. + + Provides the interface for loading, accessing, and managing samples + for recommendation model training and inference. + + Args: + hstu_config: HSTU model configuration. + **args: Additional arguments (unused in base class). 
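+
+    Subclasses are expected to implement preprocess, load_query_samples,
+    unload_query_samples, and get_sample; get_samples batches individual samples
+    into a single Samples object via collate_fn.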
+ """ + + def __init__(self, hstu_config: DlrmHSTUConfig, **args): + self.arrival = None + self.image_list = [] + self.label_list = [] + self.image_list_inmemory = {} + self.last_loaded = -1.0 + + def preprocess(self, use_cache=True): + """ + Preprocess the dataset. + + Args: + use_cache: Whether to use cached preprocessed data. + + Raises: + NotImplementedError: Subclasses must implement this method. + """ + raise NotImplementedError("Dataset:preprocess") + + def get_item_count(self): + """ + Get the total number of items in the dataset. + + Returns: + Number of items. + """ + return len(self.image_list) + + def load_query_samples(self, sample_list): + """ + Load specified samples into memory. + + Args: + sample_list: List of sample indices to load. + + Raises: + NotImplementedError: Subclasses must implement this method. + """ + raise NotImplementedError("Dataset:load_query_samples") + + def unload_query_samples(self, sample_list): + """ + Unload specified samples from memory. + + Args: + sample_list: List of sample indices to unload. + + Raises: + NotImplementedError: Subclasses must implement this method. + """ + raise NotImplementedError("Dataset:unload_query_samples") + + def get_sample(self, id: int): + """ + Get a single sample by ID. + + Args: + id: Sample identifier. + + Raises: + NotImplementedError: Subclasses must implement this method. + """ + raise NotImplementedError("Dataset:get_sample") + + def get_samples(self, id_list: List[int]) -> Samples: + """ + Get multiple samples and collate them into a batch. + + Args: + id_list: List of sample identifiers. + + Returns: + Collated Samples object containing the batch. + """ + list_samples = [self.get_sample(ix) for ix in id_list] + return collate_fn(list_samples) + + +@torch.jit.script +def kjt_batch_func( + kjt_list: List[KeyedJaggedTensor], +) -> KeyedJaggedTensor: + """ + Batch multiple KeyedJaggedTensors into a single tensor. + + Uses FBGEMM operations for efficient batching and reordering of + jagged tensor data. + + Args: + kjt_list: List of KeyedJaggedTensors to batch. + + Returns: + Batched KeyedJaggedTensor with reordered indices and lengths. + """ + bs_list = [kjt.stride() for kjt in kjt_list] + bs = sum(bs_list) + batched_length = torch.cat([kjt.lengths() for kjt in kjt_list], dim=0) + batched_indices = torch.cat([kjt.values() for kjt in kjt_list], dim=0) + bs_offset = torch.ops.fbgemm.asynchronous_complete_cumsum( + torch.tensor(bs_list) + ).int() + batched_offset = torch.ops.fbgemm.asynchronous_complete_cumsum(batched_length) + reorder_length = torch.ops.fbgemm.reorder_batched_ad_lengths( + batched_length, bs_offset, bs + ) + reorder_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(reorder_length) + reorder_indices = torch.ops.fbgemm.reorder_batched_ad_indices( + batched_offset, batched_indices, reorder_offsets, bs_offset, bs + ) + out = KeyedJaggedTensor( + keys=kjt_list[0].keys(), + lengths=reorder_length.long(), + values=reorder_indices.long(), + ) + return out + + +def get_random_data( + contexual_features: List[str], + hstu_uih_keys: List[str], + hstu_candidates_keys: List[str], + uih_max_seq_len: int, + max_num_candidates: int, + value_bound: int = 1000, +): + """ + Generate random sample data for testing and debugging. + + Creates synthetic user interaction history and candidate features + with random values. + + Args: + contexual_features: List of contextual feature names. + hstu_uih_keys: List of UIH feature keys. + hstu_candidates_keys: List of candidate feature keys. 
+ uih_max_seq_len: Maximum sequence length for UIH. + max_num_candidates: Maximum number of candidates. + value_bound: Upper bound for random values. + + Returns: + Tuple of (uih_features_kjt, candidates_features_kjt). + """ + uih_non_seq_feature_keys = contexual_features + uih_seq_feature_keys = [ + k for k in hstu_uih_keys if k not in uih_non_seq_feature_keys + ] + uih_seq_len = torch.randint( + int(uih_max_seq_len * 0.8), + uih_max_seq_len + 1, + (1,), + ).item() + uih_lengths = torch.tensor( + [1 for _ in uih_non_seq_feature_keys] + + [uih_seq_len for _ in uih_seq_feature_keys] + ) + # logging.info(f"uih_lengths: {uih_lengths}") + uih_values = torch.randint( + 1, + value_bound, + # pyre-ignore[6] + (uih_seq_len * len(uih_seq_feature_keys) + len(uih_non_seq_feature_keys),), + ) + uih_features_kjt = KeyedJaggedTensor( + keys=uih_non_seq_feature_keys + uih_seq_feature_keys, + lengths=uih_lengths.long(), + values=uih_values.long(), + ) + num_candidates = torch.randint( + 1, + max_num_candidates + 1, + (1,), + ).item() + candidates_lengths = num_candidates * torch.ones(len(hstu_candidates_keys)) + candidates_values = torch.randint( + 1, + value_bound, + (num_candidates * len(hstu_candidates_keys),), # pyre-ignore[6] + ) + candidates_features_kjt = KeyedJaggedTensor( + keys=hstu_candidates_keys, + lengths=candidates_lengths.long(), + values=candidates_values.long(), + ) + return uih_features_kjt, candidates_features_kjt + + +class DLRMv3RandomDataset(Dataset): + """ + Dataset that generates random synthetic data for DLRMv3. + + Useful for testing and benchmarking without real data dependencies. + + Args: + hstu_config: HSTU model configuration. + num_aggregated_samples: Total number of samples to generate. + is_inference: Whether the dataset is used for inference mode. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + """ + + def __init__( + self, + hstu_config: DlrmHSTUConfig, + num_aggregated_samples: int = 10000, + is_inference: bool = False, + *args, + **kwargs, + ): + super().__init__( + hstu_config=hstu_config, + ) + self.hstu_config: DlrmHSTUConfig = hstu_config + self._max_num_candidates: int = hstu_config.max_num_candidates + self._max_num_candidates_inference: int = ( + hstu_config.max_num_candidates_inference + ) + self._max_seq_len: int = hstu_config.max_seq_len + self._uih_keys: List[str] = hstu_config.hstu_uih_feature_names + self._candidates_keys: List[str] = hstu_config.hstu_candidate_feature_names + self._contextual_feature_to_max_length: Dict[str, int] = ( + hstu_config.contextual_feature_to_max_length + ) + self._max_uih_len: int = ( + self._max_seq_len + - self._max_num_candidates + - ( + len(self._contextual_feature_to_max_length) + if self._contextual_feature_to_max_length + else 0 + ) + ) + self._is_inference = is_inference + + self.contexual_features = [] + if hstu_config.contextual_feature_to_max_length is not None: + self.contexual_features = [ + p[0] for p in hstu_config.contextual_feature_to_max_length + ] + + self.num_aggregated_samples = num_aggregated_samples + self.items_in_memory = {} + + def get_sample(self, id: int) -> Tuple[KeyedJaggedTensor, KeyedJaggedTensor]: + """ + Get a sample by ID from in-memory storage. + + Args: + id: Sample identifier. + + Returns: + Tuple of (uih_features_kjt, candidates_features_kjt). + """ + return self.items_in_memory[id] + + def get_item_count(self): + """ + Get the total number of samples in the dataset. + + Returns: + Number of aggregated samples. 
+ """ + return self.num_aggregated_samples + + def unload_query_samples(self, sample_list): + """ + Clear all samples from memory. + + Args: + sample_list: Ignored; clears all samples. + """ + self.items_in_memory = {} + + def load_query_samples(self, sample_list): + """ + Generate and load random samples into memory. + + Args: + sample_list: List of sample IDs to generate. + """ + max_num_candidates = ( + self._max_num_candidates_inference + if self._is_inference + else self._max_num_candidates + ) + self.items_in_memory = {} + for sample in sample_list: + self.items_in_memory[sample] = get_random_data( + contexual_features=self.contexual_features, + hstu_uih_keys=self.hstu_config.hstu_uih_feature_names, + hstu_candidates_keys=self.hstu_config.hstu_candidate_feature_names, + uih_max_seq_len=self._max_uih_len, + max_num_candidates=max_num_candidates, + ) + self.last_loaded = time.time() diff --git a/recommendation/dlrm_v3/datasets/synthetic_streaming.py b/recommendation/dlrm_v3/datasets/synthetic_streaming.py new file mode 100644 index 0000000000..8cddcc36d2 --- /dev/null +++ b/recommendation/dlrm_v3/datasets/synthetic_streaming.py @@ -0,0 +1,400 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-strict +""" +Synthetic streaming dataset for DLRMv3 inference benchmarking. + +This module provides a streaming dataset implementation that loads user interaction +data from pre-generated CSV files with temporal (timestamp) organization, suitable +for simulating real-time recommendation scenarios. +""" + +import csv +import logging +import sys +import time +from typing import Any, Dict, List, Set, Tuple + +import pandas as pd +import torch +from datasets.dataset import ( + collate_fn, + DLRMv3RandomDataset, + Samples, +) +from datasets.utils import ( + json_loads, + maybe_truncate_seq, +) +from generative_recommenders.modules.dlrm_hstu import DlrmHSTUConfig +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + +csv.field_size_limit(sys.maxsize) +logger: logging.Logger = logging.getLogger(__name__) + + +class DLRMv3SyntheticStreamingDataset(DLRMv3RandomDataset): + """ + Streaming dataset that loads pre-generated synthetic recommendation data. + + Supports timestamp-based data organization for simulating streaming scenarios + where user interaction histories evolve over time. + + Args: + hstu_config: HSTU model configuration. + ratings_file_prefix: Path prefix for rating data files. + is_inference: Whether dataset is used for inference. + train_ts: Number of timestamps used for training. + total_ts: Total number of timestamps in the data. + num_files: Number of data files (for parallelization). + num_users: Total number of users in the dataset. + num_items: Total number of items in the catalog. + num_categories: Number of item categories. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. 
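+
+    Expected files under ratings_file_prefix: offset.csv,
+    requests_per_ts_offset.csv, users_cumsum_per_ts.csv, requests_per_ts.csv,
+    and one ratings shard per file index named {file_idx}.csv.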
+ """ + + def __init__( + self, + hstu_config: DlrmHSTUConfig, + ratings_file_prefix: str, + is_inference: bool, + train_ts: int, + total_ts: int, + num_files: int, + num_users: int, + num_items: int, + num_categories: int, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(hstu_config=hstu_config, is_inference=is_inference) + self.ratings_file_prefix = ratings_file_prefix + self.file_to_offsets: Dict[int, List[int]] = {} + with open(f"{self.ratings_file_prefix}offset.csv", "r") as file: + reader = csv.reader(file) + for size in range(num_files): + row = next(reader) + assert len(row) == 1 + offset = json_loads(row[0]) + assert len(offset) == num_users // num_files + self.file_to_offsets[size] = offset + self.ts_requests_offsets: List[int] = [] + with open(f"{self.ratings_file_prefix}requests_per_ts_offset.csv", "r") as file: + reader = csv.reader(file) + row = next(reader) + assert len(row) == 1 + self.ts_requests_offsets = json_loads(row[0]) + assert len(self.ts_requests_offsets) == total_ts + self.requests: List[int] = [] + self.ts_to_users_cumsum: Dict[int, List[int]] = {} + with open( + f"{self.ratings_file_prefix}users_cumsum_per_ts.csv", "r" + ) as cumsum_file: + reader = csv.reader(cumsum_file) + ts = 0 + for row in reader: + assert len(row) == 1 + cumsum = json_loads(row[0]) + self.ts_to_users_cumsum[ts] = cumsum + ts += 1 + self.train_ts = train_ts + self.total_ts = total_ts + self.num_files = num_files + self.ts: int = -1 + self.is_inference: bool = False + self.is_eval: bool = False + self.users_per_file: int = num_users // num_files + self.cached_files: Set[str] = set() + self.items_per_category: int = num_items // num_categories + assert hstu_config.action_weights is not None + self.action_weights: List[int] = hstu_config.action_weights + self.items_in_memory: Dict[ + int, Dict[int, Tuple[KeyedJaggedTensor, KeyedJaggedTensor]] + ] = {} + + def get_item_count(self) -> int: + return len(self.requests) + + def load_query_samples(self, sample_list: List[int]) -> None: + max_num_candidates = ( + self._max_num_candidates_inference + if self._is_inference + else self._max_num_candidates + ) + for idx in sample_list: + data = self.iloc(idx) + sample = self.load_item(data, max_num_candidates) + if self.ts not in self.items_in_memory: + self.items_in_memory[self.ts] = {} + self.items_in_memory[self.ts][idx] = sample + + self.last_loaded = time.time() + + def unload_query_samples(self, sample_list: List[int]) -> None: + self.items_in_memory = {} + + def get_sample(self, id: int) -> Tuple[KeyedJaggedTensor, KeyedJaggedTensor]: + return self.items_in_memory[self.ts][id] + + def get_sample_with_ts( + self, id: int, ts: int + ) -> Tuple[KeyedJaggedTensor, KeyedJaggedTensor]: + """ + Get a sample for a specific timestamp. + + Args: + id: Sample identifier. + ts: Timestamp index. + + Returns: + Tuple of (uih_features_kjt, candidates_features_kjt). + """ + return self.items_in_memory[ts][id] + + def get_samples_with_ts(self, id_list: List[int], ts: int) -> Samples: + """ + Get and collate multiple samples for a specific timestamp. + + Args: + id_list: List of sample identifiers. + ts: Timestamp index. + + Returns: + Collated Samples object. + """ + list_samples = [self.get_sample_with_ts(ix, ts) for ix in id_list] + return collate_fn(list_samples) + + def _process_line(self, line: str, user_id: int) -> pd.Series: + """ + Parse a CSV line into a pandas Series with user interaction data. + + Args: + line: CSV line containing user data. + user_id: User identifier. 
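+
+        Each row stores, per timestamp, four JSON-encoded columns in the order
+        candidate_ids, candidate_ratings, uih_ids, uih_ratings; which timestamps
+        contribute to the UIH depends on train / eval / inference mode.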
+ + Returns: + pd.Series with parsed user interaction history and candidates. + """ + reader = csv.reader([line]) + parsed_line = next(reader) + # total ts + one more eval ts + one base ts so that uih won't be zero + # for each ts, ordered as candidate_ids, candidate_ratings, uih_ids, uih_ratings + assert len(parsed_line) == 4 * (self.total_ts + 2) + uih_item_ids_list = [] + uih_ratings_list = [] + candidate_item_ids = "" + candidate_ratings = "" + if (not self.is_eval) and (not self.is_inference): + assert self.ts < self.train_ts + for i in range(self.ts + 1): + if parsed_line[4 * i]: + uih_item_ids_list.append(parsed_line[2 + 4 * i]) + uih_ratings_list.append(parsed_line[3 + 4 * i]) + candidate_item_ids = parsed_line[4 * (self.ts + 1)] + candidate_ratings = parsed_line[1 + 4 * (self.ts + 1)] + elif self.is_eval: + for i in range(self.ts + 1): + if parsed_line[4 * i]: + uih_item_ids_list.append(parsed_line[2 + 4 * i]) + uih_ratings_list.append(parsed_line[3 + 4 * i]) + candidate_item_ids = parsed_line[4 * (self.ts + 1)] + candidate_ratings = parsed_line[1 + 4 * (self.ts + 1)] + else: + assert self.is_inference is True + assert self.ts >= self.train_ts + for i in range(self.train_ts + 1): + if parsed_line[4 * i]: + uih_item_ids_list.append(parsed_line[2 + 4 * i]) + uih_ratings_list.append(parsed_line[3 + 4 * i]) + for i in range(self.train_ts + 2, self.ts + 2): + if parsed_line[4 * i]: + uih_item_ids_list.append(parsed_line[2 + 4 * i]) + uih_ratings_list.append(parsed_line[3 + 4 * i]) + candidate_item_ids = parsed_line[4 * (self.ts + 2)] + candidate_ratings = parsed_line[1 + 4 * (self.ts + 2)] + uih_item_ids = ",".join(uih_item_ids_list) + uih_ratings = ",".join(uih_ratings_list) + assert candidate_item_ids != "" and candidate_ratings != "" + return pd.Series( + data={ + "user_id": user_id, + "uih_item_ids": uih_item_ids, + "uih_ratings": uih_ratings, + "candidate_item_ids": candidate_item_ids, + "candidate_ratings": candidate_ratings, + } + ) + + def iloc(self, idx: int) -> pd.Series: + """ + Get user data by request index using file offsets for efficient access. + + Args: + idx: Request index within the current timestamp. + + Returns: + pd.Series with parsed user interaction data. + """ + cumsum: List[int] = self.ts_to_users_cumsum[self.ts] + assert cumsum != [] + assert idx < cumsum[-1] + file_idx: int = 0 + while cumsum[file_idx] <= idx: + file_idx += 1 + user_idx = self.requests[idx] + filename = f"{self.ratings_file_prefix}{file_idx}.csv" + with open(filename, "r") as file: + idx = user_idx % self.users_per_file + file.seek(self.file_to_offsets[file_idx][idx]) + line = file.readline() + data = self._process_line(line=line, user_id=user_idx) + return data + + def get_timestamp_uih( + self, data: pd.Series, max_num_candidates: int, size: int + ) -> List[int]: + return [1] * size + + def set_ts(self, ts: int) -> None: + """ + Set the current timestamp and load associated request data. + + Args: + ts: Timestamp index to set. 
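+
+        No-op if ts is unchanged; otherwise seeks into requests_per_ts.csv using
+        the precomputed per-timestamp offset and loads the request (user id) list
+        for that timestamp into self.requests.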
+ """ + logger.warning(f"Streaming dataset ts set to {ts}") + if ts == self.ts: + return + self.ts = ts + with open( + f"{self.ratings_file_prefix}requests_per_ts.csv", "r" + ) as request_file: + request_file.seek(self.ts_requests_offsets[self.ts]) + line = request_file.readline() + reader = csv.reader([line]) + row = next(reader) + assert len(row) == 1 + requests = json_loads(row[0]) + self.requests = requests + logger.warning(f"DLRMv3SyntheticStreamingDataset: ts={ts} requests loaded") + assert self.ts_to_users_cumsum[self.ts][-1] == len(self.requests) + logger.warning( + f"DLRMv3SyntheticStreamingDataset: ts={ts} users_cumsum={self.ts_to_users_cumsum[self.ts]}" + ) + + def load_item( + self, data: pd.Series, max_num_candidates: int + ) -> Tuple[KeyedJaggedTensor, KeyedJaggedTensor]: + """ + Load and process a single user's data into KeyedJaggedTensors. + + Converts parsed user data into feature tensors suitable for model input, + including truncation to maximum sequence lengths. + + Args: + data: pd.Series with user interaction history and candidates. + max_num_candidates: Maximum number of candidates to include. + + Returns: + Tuple of (uih_features_kjt, candidates_features_kjt). + """ + ids_uih = json_loads(data.uih_item_ids) + ids_candidates = json_loads(data.candidate_item_ids) + ratings_uih = json_loads(data.uih_ratings) + ratings_candidates = json_loads(data.candidate_ratings) + timestamps_uih = self.get_timestamp_uih( + data=data, + max_num_candidates=max_num_candidates, + size=len(ids_uih), + ) + assert len(ids_uih) == len( + timestamps_uih + ), "history len differs from timestamp len." + assert len(ids_uih) == len( + ratings_uih + ), f"history len {len(ids_uih)} differs from ratings len {len(ratings_uih)}." + assert ( + len(ids_candidates) == len(ratings_candidates) + ), f"candidates len {len(ids_candidates)} differs from ratings len {len(ratings_candidates)}." 
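+        # Below: truncate UIH and candidate sequences to their configured maximums,
+        # derive per-event action weights, and assemble the UIH and candidate
+        # KeyedJaggedTensors expected by the model.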
+ + ids_uih = maybe_truncate_seq(ids_uih, self._max_uih_len) + ratings_uih = maybe_truncate_seq(ratings_uih, self._max_uih_len) + timestamps_uih = maybe_truncate_seq(timestamps_uih, self._max_uih_len) + ids_candidates = maybe_truncate_seq(ids_candidates, max_num_candidates) + num_candidates = len(ids_candidates) + ratings_candidates = maybe_truncate_seq(ratings_candidates, max_num_candidates) + action_weights_uih = [ + self.action_weights[int(rating) - 1] for rating in ratings_uih + ] + action_weights_candidates = [ + int(rating >= 3.5) for rating in ratings_candidates + ] + + uih_kjt_values: List[int] = [] + uih_kjt_lengths: List[int] = [] + for name, length in self._contextual_feature_to_max_length.items(): + uih_kjt_values.append(data[name]) + uih_kjt_lengths.append(length) + + uih_seq_len = len(ids_uih) + dummy_watch_times_uih = [0 for _ in range(uih_seq_len)] + item_category_ids = [id // self.items_per_category for id in ids_uih] + extend_uih_kjt_values: List[int] = ( + ids_uih + + ratings_uih + + timestamps_uih + + action_weights_uih + + dummy_watch_times_uih + + item_category_ids + ) + uih_kjt_values.extend(extend_uih_kjt_values) + uih_kjt_lengths.extend( + [ + uih_seq_len + for _ in range( + len(self._uih_keys) - len(self._contextual_feature_to_max_length) + ) + ] + ) + + dummy_query_time = 0 if timestamps_uih == [] else max(timestamps_uih) + uih_kjt_values.append(dummy_query_time) + uih_kjt_lengths.append(1) + uih_features_kjt: KeyedJaggedTensor = KeyedJaggedTensor( + keys=self._uih_keys + ["dummy_query_time"], + lengths=torch.tensor(uih_kjt_lengths).long(), + values=torch.tensor(uih_kjt_values).long(), + ) + + candidates_kjt_lengths = num_candidates * torch.ones(len(self._candidates_keys)) + item_candidate_category_ids = [ + id // self.items_per_category for id in ids_candidates + ] + candidates_kjt_values = ( + ids_candidates + + ratings_candidates + + [dummy_query_time] * num_candidates # item_query_time + + action_weights_candidates + + [1] * num_candidates # item_dummy_watchtime + + item_candidate_category_ids + ) + candidates_features_kjt: KeyedJaggedTensor = KeyedJaggedTensor( + keys=self._candidates_keys, + lengths=candidates_kjt_lengths.detach().clone().long(), + values=torch.tensor(candidates_kjt_values).long(), + ) + return uih_features_kjt, candidates_features_kjt diff --git a/recommendation/dlrm_v3/datasets/utils.py b/recommendation/dlrm_v3/datasets/utils.py new file mode 100644 index 0000000000..c85c3cf706 --- /dev/null +++ b/recommendation/dlrm_v3/datasets/utils.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-unsafe +""" +Utility functions for dataset processing. + +This module provides helper functions for parsing and processing data +in the DLRMv3 dataset pipeline. +""" +import json +from typing import List, Tuple + + +def json_loads( + x: str | int | List[int], +) -> List[int]: + """ + Parse a JSON-like string into a list of integers. 
+ + Handles multiple input formats including JSON arrays, comma-separated + strings, and single values. + + Args: + x: Input that can be a JSON array string, a single integer, + or already a list of integers. + + Returns: + List of integers parsed from the input. + """ + if isinstance(x, str): + if x[0] != "[" and x[-1] != "]": + x = "[" + x + "]" + y = json.loads(x) + else: + y = x + y_list = [y] if type(y) == int else list(y) + return y_list + + +def separate_uih_candidates( + x: str | int | List[int], + candidates_max_seq_len: int, +) -> Tuple[List[int], List[int]]: + """ + Separate a sequence into user interaction history (UIH) and candidates. + + Splits the input sequence such that the last `candidates_max_seq_len` + elements become candidates and the rest become UIH. + + Args: + x: Input sequence as JSON string, single int, or list of ints. + candidates_max_seq_len: Number of items at the end to use as candidates. + + Returns: + Tuple of (uih, candidates) where both are lists of integers. + """ + if isinstance(x, str): + if x[0] != "[" and x[-1] != "]": + x = "[" + x + "]" + y = json.loads(x) + else: + y = x + y_list = [y] if type(y) == int else list(y) + candidates, uih = ( + y_list[-candidates_max_seq_len:], + y_list[:-candidates_max_seq_len], + ) + return uih, candidates + + +def maybe_truncate_seq( + y: List[int], + max_seq_len: int, +) -> List[int]: + """ + Truncate a sequence if it exceeds the maximum length. + + Args: + y: Input sequence to potentially truncate. + max_seq_len: Maximum allowed sequence length. + + Returns: + The input sequence, truncated to max_seq_len if necessary. + """ + y_len = len(y) + if y_len > max_seq_len: + y = y[:max_seq_len] + return y diff --git a/recommendation/dlrm_v3/generative_recommenders/common.py b/recommendation/dlrm_v3/generative_recommenders/common.py new file mode 100644 index 0000000000..9ba5821d9f --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/common.py @@ -0,0 +1,436 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +#!/usr/bin/env python3 + +# pyre-strict + +import abc +import copy +from enum import Enum, unique +from typing import Any, List, Optional, Tuple + +import torch + +# @manual=//triton:triton +import triton +from generative_recommenders.ops.utils import is_sm100 + +# @manual=//triton:triton +from triton.runtime.autotuner import Autotuner + +try: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") +except OSError: + pass + +try: + # @manual=//triton:triton + import triton.language.extra.tlx # type: ignore + + HAS_TLX = True +except ImportError: + HAS_TLX = False + +try: + from generative_recommenders.fb.triton_cc.utils import triton_cc + from hammer.ops.triton.utils import triton_autotune + from hammer.utils import is_dev_mode, set_dev_mode, set_verbose_level +except ImportError: + # pyre-ignore + def triton_cc(annotations): + # pyre-ignore + def decorator(fn): + return fn + + return decorator + + # pyre-ignore + def triton_autotune( + configs: List[triton.Config], + key: List[str], + # pyre-ignore + prune_configs_by=None, + # pyre-ignore + reset_to_zero=None, + # pyre-ignore + restore_value=None, + warmup: int = 25, + rep: int = 100, + ): + # pyre-ignore + def decorator(fn): + return Autotuner( + fn, + fn.arg_names, + configs, + key, + reset_to_zero, + restore_value, + pre_hook=None, + post_hook=None, + prune_configs_by=prune_configs_by, + warmup=warmup, + rep=rep, + ) + + return decorator + + DEV_MODE: bool = False + VERBOSE_LEVEL: int = 0 + + def set_dev_mode(val: bool) -> None: + global DEV_MODE + DEV_MODE = val + + def is_dev_mode() -> bool: + global DEV_MODE + return DEV_MODE + + def set_verbose_level(level: int) -> None: + global VERBOSE_LEVEL + VERBOSE_LEVEL = level + + def get_verbose_level() -> int: + global VERBOSE_LEVEL + return VERBOSE_LEVEL + + +@unique +class HammerKernel(Enum): + TRITON = "TRITON" + PYTORCH = "PYTORCH" + CUDA = "CUDA" + TRITON_CC = "TRITON_CC" + + +class HammerModule(torch.nn.Module, abc.ABC): + _is_inference: bool = False + _use_triton_cc: bool = True + _training_dtype: torch.dtype = torch.float32 + _hammer_kernel: Optional[HammerKernel] = None + + def __init__( + self, + is_inference: bool, + training_dytpe: torch.dtype = torch.float32, + use_triton_cc: bool = _use_triton_cc, + hammer_kernel: Optional[HammerKernel] = None, + ) -> None: + super().__init__() + self._is_inference = is_inference + self._training_dtype = training_dytpe + self._hammer_kernel = hammer_kernel + self._use_triton_cc = use_triton_cc + + def hammer_kernel(self) -> HammerKernel: + kernel = self._hammer_kernel + if kernel is not None: + return kernel + if self._is_inference and self._use_triton_cc: + return HammerKernel.TRITON_CC + else: + return HammerKernel.TRITON + + # pyre-ignore[2] + def recursive_setattr(self, name: str, value: Any) -> None: + for _, module in self.named_modules(): + if hasattr(module, name): + setattr(module, name, value) + + def set_use_triton_cc(self, use_triton_cc: bool) -> None: + self._use_triton_cc = use_triton_cc + self.recursive_setattr("_use_triton_cc", use_triton_cc) + + def set_is_inference(self, is_inference: bool) -> None: + self._is_inference = is_inference + self.recursive_setattr("_is_inference", is_inference) + + def set_training_dtype(self, training_dtype: torch.dtype) -> None: + self._training_dtype = training_dtype + self.recursive_setattr("_training_dtype", training_dtype) + + def set_hammer_kernel(self, hammer_kernel: HammerKernel) -> None: + 
self._hammer_kernel = hammer_kernel + self.recursive_setattr("_hammer_kernel", hammer_kernel) + + @property + def is_inference(self) -> bool: + return self._is_inference + + @property + def is_eval(self) -> bool: + return (not self._is_inference) and (not self.training) + + @property + def is_train(self) -> bool: + return (not self._is_inference) and self.training + + +def generate_sparse_seq_len( + size: int, + max_seq_len: int, + sparsity: float, + device: torch.device, +) -> torch.Tensor: + if sparsity == 0.0: + return torch.zeros(size=(size,), device=device, dtype=torch.int) + elif sparsity == 1.0: + return torch.ones(size=(size,), device=device, dtype=torch.int) * max_seq_len + elif sparsity >= 0.5: + min_seq_len: int = int((2 * sparsity - 1.0) * max_seq_len) + return torch.randint( + low=min_seq_len, + high=max_seq_len, + size=(size,), + device=device, + dtype=torch.int, + ) + else: + min_seq_len: int = 0 + max_seq_len: int = int(2 * sparsity * max_seq_len) + return torch.randint( + low=min_seq_len, + high=max_seq_len, + size=(size,), + device=device, + dtype=torch.int, + ) + + +def apply_sampling( + lengths: torch.Tensor, + alpha: float, + max_seq_len: int, +) -> torch.Tensor: + threshold = int(max_seq_len ** (alpha / 2)) + no_sample_prob = (max_seq_len**alpha) / torch.pow(lengths, 2) + users_to_sample = torch.logical_and( + lengths > threshold, + torch.rand_like(no_sample_prob) < 1 - no_sample_prob, + ) + lengths = torch.where(users_to_sample, threshold, lengths) + return lengths + + +nv_gpu_unavailable: Tuple[bool, str] = ( + not torch.cuda.is_available() or torch.cuda.device_count() == 0, + "CUDA is not available or no GPUs detected", +) +nv_gpu_available: bool = not nv_gpu_unavailable[0] + + +amd_gpu_unavailable: Tuple[bool, str] = ( + not torch.version.hip, + "AMD HIP not available or no GPUs detected", +) +amd_gpu_available: bool = not amd_gpu_unavailable[0] + +gpu_unavailable: Tuple[bool, str] = ( + not nv_gpu_available and not amd_gpu_available, + "CUDA/HIP is not available or no GPUs detected", +) + +gpu_available: bool = not gpu_unavailable[0] + +blackwell_tlx_unavailable: Tuple[bool, str] = ( + not is_sm100() or not HAS_TLX, + "Skip TLX and blackwell only tests", +) + + +def switch_to_contiguous_if_needed(x: torch.Tensor) -> torch.Tensor: + if not torch.jit.is_scripting() and torch.compiler.is_compiling(): + # Tell Dynamo this data-dependent value is in the range (0, 10**9) + torch._check(x.size(0) > 0) + torch._check(x.size(0) < 10**9) + if x.stride(-1) == 1: + return x + return x.contiguous() + + +@torch.fx.wrap +def prev_power_of_2(x: int) -> int: + if torch.compiler.is_compiling(): + # Re-write to make Dynamo happy + x_tensor = torch.scalar_tensor(x, dtype=torch.int64) # type: ignore[arg-type] + x_tensor_orig = x_tensor.clone() + out = triton.next_power_of_2(x_tensor) # type: ignore[arg-type] + return int(torch.where(torch.lt(x_tensor_orig, out), out // 2, out).item()) # type: ignore[return-value] + else: + out = triton.next_power_of_2(x) + return out // 2 if out > x else out + + +STATIC_MAX_SEQ_LENS: List[int] = [] +USE_RUNTIME_MAX_SEQ_LEN: bool = False + + +def set_static_max_seq_lens(max_seq_lens: List[int]) -> None: + global STATIC_MAX_SEQ_LENS + STATIC_MAX_SEQ_LENS = copy.deepcopy(max_seq_lens) + STATIC_MAX_SEQ_LENS.sort() + + +def set_use_runtime_max_seq_len(use_runtime_max_seq_len: bool) -> None: + global USE_RUNTIME_MAX_SEQ_LEN + USE_RUNTIME_MAX_SEQ_LEN = use_runtime_max_seq_len + + +def autotune_max_seq_len(runtime_max_seq_len: int) -> int: + global 
USE_RUNTIME_MAX_SEQ_LEN + + if USE_RUNTIME_MAX_SEQ_LEN: + return prev_power_of_2(runtime_max_seq_len) + else: + if STATIC_MAX_SEQ_LENS == []: + return 1 + for max_len in STATIC_MAX_SEQ_LENS: + if max_len >= runtime_max_seq_len: + return max_len + return STATIC_MAX_SEQ_LENS[-1] + + +def fine_grained_autotune_max_seq_len(runtime_max_seq_len: int) -> int: + global USE_RUNTIME_MAX_SEQ_LEN + + if USE_RUNTIME_MAX_SEQ_LEN: + return _fine_grained_bucket_size(runtime_max_seq_len) + else: + if STATIC_MAX_SEQ_LENS == []: + return 1 + for max_len in STATIC_MAX_SEQ_LENS: + if max_len >= runtime_max_seq_len: + return max_len + return STATIC_MAX_SEQ_LENS[-1] + + +def _generate_fine_grained_buckets() -> List[int]: + buckets = [ + 1024, + 2048, + 4096, + 8192, + 12288, + 16384, + 24576, + 32768, + 40960, + 49152, + 65536, + 81920, + 98304, + ] + return buckets + + +@torch.fx.wrap +def _fine_grained_bucket_size(x: int) -> int: + if torch.compiler.is_compiling(): + x_tensor = torch.scalar_tensor(x, dtype=torch.int64) + buckets = torch.tensor(_generate_fine_grained_buckets(), dtype=torch.int64) + + mask = buckets >= x_tensor + valid_buckets = torch.where( + mask, buckets, torch.tensor(2**31 - 1, dtype=torch.int64) + ) + + result = torch.where(mask.any(), valid_buckets.min(), buckets[-1]) + + return int(result.item()) + else: + buckets = _generate_fine_grained_buckets() + + for bucket in buckets: + if x <= bucket: + return bucket + + return buckets[-1] + + +@torch.fx.wrap +def fx_unwrap_optional_tensor(optional: Optional[torch.Tensor]) -> torch.Tensor: + assert optional is not None, "Expected optional to be non-None Tensor" + return optional + + +@torch.fx.wrap +def fx_arange(len: int, device: torch.device) -> torch.Tensor: + return torch.arange(len, device=device) + + +@torch.fx.wrap +def fx_infer_max_len( + lengths: torch.Tensor, +) -> int: + max_len = int(lengths.max().item()) + if not torch.jit.is_scripting() and torch.compiler.is_compiling(): + # Tell Dynamo this data-dependent value is in the range [0, 10**9) + torch._check_is_size(max_len) + torch._check(max_len < 10**9) + torch._check(max_len > 0) + return max_len + + +@torch.fx.wrap +def fx_mark_length_features(tensor: torch.Tensor) -> torch.Tensor: + return tensor + + +@torch.fx.wrap +def fx_torch_ones( + shape: List[int], + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + return torch.ones(shape, device=device, dtype=dtype) + + +@torch.fx.wrap +def fx_torch_zeros(shape: List[int], device: torch.device) -> torch.Tensor: + return torch.zeros(shape, device=device) + + +@torch.fx.wrap +def jagged_to_padded_dense( + values: torch.Tensor, + offsets: List[torch.Tensor], + max_lengths: List[int], + padding_value: float, +) -> torch.Tensor: + return torch.ops.fbgemm.jagged_to_padded_dense( + values=values, + offsets=offsets, + max_lengths=max_lengths, + padding_value=padding_value, + ) + + +@torch.fx.wrap +def dense_to_jagged( + dense: torch.Tensor, + x_offsets: List[torch.Tensor], +) -> torch.Tensor: + return torch.ops.fbgemm.dense_to_jagged( + dense=dense, + x_offsets=x_offsets, + )[0] + + +def init_mlp_weights_optional_bias(m: torch.nn.Module) -> None: + if isinstance(m, torch.nn.Linear): + torch.nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + m.bias.data.fill_(0.0) diff --git a/recommendation/dlrm_v3/generative_recommenders/modules/action_encoder.py b/recommendation/dlrm_v3/generative_recommenders/modules/action_encoder.py new file mode 100644 index 0000000000..0116b99b43 --- /dev/null +++ 
b/recommendation/dlrm_v3/generative_recommenders/modules/action_encoder.py @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +from typing import Dict, List, Optional, Tuple + +import torch + +from generative_recommenders.common import HammerModule +from generative_recommenders.ops.jagged_tensors import concat_2D_jagged + + +class ActionEncoder(HammerModule): + def __init__( + self, + action_embedding_dim: int, + action_feature_name: str, + action_weights: List[int], + watchtime_feature_name: str = "", + watchtime_to_action_thresholds_and_weights: Optional[ + List[Tuple[int, int]] + ] = None, + embedding_init_std: float = 0.1, + is_inference: bool = False, + ) -> None: + super().__init__(is_inference=is_inference) + self._watchtime_feature_name: str = watchtime_feature_name + self._action_feature_name: str = action_feature_name + self._watchtime_to_action_thresholds_and_weights: List[Tuple[int, int]] = ( + watchtime_to_action_thresholds_and_weights + if watchtime_to_action_thresholds_and_weights is not None + else [] + ) + self.register_buffer( + "_combined_action_weights", + torch.tensor( + action_weights + + [x[1] for x in self._watchtime_to_action_thresholds_and_weights] + ), + ) + self._num_action_types: int = len(action_weights) + len( + self._watchtime_to_action_thresholds_and_weights + ) + self._action_embedding_dim = action_embedding_dim + self._action_embedding_table: torch.nn.Parameter = torch.nn.Parameter( + torch.empty((self._num_action_types, action_embedding_dim)).normal_( + mean=0, std=embedding_init_std + ), + ) + self._target_action_embedding_table: torch.nn.Parameter = torch.nn.Parameter( + torch.empty((1, self._num_action_types * action_embedding_dim)).normal_( + mean=0, std=embedding_init_std + ), + ) + + @property + def output_embedding_dim(self) -> int: + return self._action_embedding_dim * self._num_action_types + + def forward( + self, + max_uih_len: int, + max_targets: int, + uih_offsets: torch.Tensor, + target_offsets: torch.Tensor, + seq_embeddings: torch.Tensor, + seq_payloads: Dict[str, torch.Tensor], + ) -> torch.Tensor: + seq_actions = seq_payloads[self._action_feature_name] + if len(self._watchtime_to_action_thresholds_and_weights) > 0: + watchtimes = seq_payloads[self._watchtime_feature_name] + for threshold, weight in self._watchtime_to_action_thresholds_and_weights: + seq_actions = torch.bitwise_or( + seq_actions, (watchtimes >= threshold).to(torch.int64) * weight + ) + exploded_actions = ( + torch.bitwise_and( + seq_actions.unsqueeze(-1), self._combined_action_weights.unsqueeze(0) + ) + > 0 + ) + action_embeddings = ( + exploded_actions.unsqueeze(-1) * self._action_embedding_table.unsqueeze(0) + ).view(-1, self._num_action_types * self._action_embedding_dim) + total_targets: int = seq_embeddings.size(0) - action_embeddings.size(0) + action_embeddings = concat_2D_jagged( + max_seq_len=max_uih_len + max_targets, + values_left=action_embeddings, + 
values_right=self._target_action_embedding_table.tile( + total_targets, + 1, + ), + max_len_left=max_uih_len, + max_len_right=max_targets, + offsets_left=uih_offsets, + offsets_right=target_offsets, + kernel=self.hammer_kernel(), + ) + return action_embeddings diff --git a/recommendation/dlrm_v3/generative_recommenders/modules/content_encoder.py b/recommendation/dlrm_v3/generative_recommenders/modules/content_encoder.py new file mode 100644 index 0000000000..75d73298a4 --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/modules/content_encoder.py @@ -0,0 +1,110 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +from typing import Dict, List, Optional + +import torch + +from generative_recommenders.common import HammerModule +from generative_recommenders.ops.jagged_tensors import concat_2D_jagged + + +class ContentEncoder(HammerModule): + def __init__( + self, + input_embedding_dim: int, + additional_content_features: Optional[Dict[str, int]] = None, + target_enrich_features: Optional[Dict[str, int]] = None, + is_inference: bool = False, + ) -> None: + super().__init__(is_inference=is_inference) + self._input_embedding_dim: int = input_embedding_dim + self._additional_content_features: Dict[str, int] = ( + additional_content_features + if additional_content_features is not None + else {} + ) + self._target_enrich_features: Dict[str, int] = ( + target_enrich_features if target_enrich_features is not None else {} + ) + self._target_enrich_dummy_embeddings: torch.nn.ParameterDict = ( + torch.nn.ParameterDict( + { + name: torch.nn.Parameter( + torch.empty((1, dim)).normal_(mean=0, std=0.1), + ) + for name, dim in self._target_enrich_features.items() + } + ) + ) + + @property + def output_embedding_dim(self) -> int: + return self._input_embedding_dim + sum( + list(self._additional_content_features.values()) + + list(self._target_enrich_features.values()) + ) + + def forward( + self, + max_uih_len: int, + max_targets: int, + uih_offsets: torch.Tensor, + target_offsets: torch.Tensor, + seq_embeddings: torch.Tensor, + seq_payloads: Dict[str, torch.Tensor], + ) -> torch.Tensor: + content_embeddings_list: List[torch.Tensor] = [seq_embeddings] + if len(self._additional_content_features) > 0: + content_embeddings_list = content_embeddings_list + [ + (seq_payloads[x].to(seq_embeddings.dtype)) + for x in self._additional_content_features.keys() + ] + + if self._target_enrich_dummy_embeddings: + total_seq_len: int = seq_embeddings.size(0) + for name, param in self._target_enrich_dummy_embeddings.items(): + enrich_embeddings_target = seq_payloads[name].to(seq_embeddings.dtype) + total_targets: int = enrich_embeddings_target.size(0) + total_uih_len: int = total_seq_len - total_targets + enrich_embeddings_uih = param.tile(total_uih_len, 1).to( + seq_embeddings.dtype + ) + enrich_embeddings = concat_2D_jagged( + max_seq_len=max_uih_len + max_targets, + values_left=enrich_embeddings_uih, + 
values_right=enrich_embeddings_target, + max_len_left=max_uih_len, + max_len_right=max_targets, + offsets_left=uih_offsets, + offsets_right=target_offsets, + kernel=self.hammer_kernel(), + ) + content_embeddings_list.append(enrich_embeddings) + + if ( + len(self._target_enrich_features) == 0 + and len(self._additional_content_features) == 0 + ): + return seq_embeddings + else: + content_embeddings = torch.cat( + content_embeddings_list, + dim=1, + ) + return content_embeddings diff --git a/recommendation/dlrm_v3/generative_recommenders/modules/contextual_interleave_preprocessor.py b/recommendation/dlrm_v3/generative_recommenders/modules/contextual_interleave_preprocessor.py new file mode 100644 index 0000000000..fff0d72f0d --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/modules/contextual_interleave_preprocessor.py @@ -0,0 +1,357 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +from math import sqrt +from typing import Callable, Dict, Optional, Tuple + +import torch +from generative_recommenders.common import fx_unwrap_optional_tensor +from generative_recommenders.modules.action_encoder import ActionEncoder +from generative_recommenders.modules.content_encoder import ContentEncoder +from generative_recommenders.modules.contextualize_mlps import ( + ContextualizedMLP, + ParameterizedContextualizedMLP, +) +from generative_recommenders.modules.preprocessors import ( + get_contextual_input_embeddings, + InputPreprocessor, +) +from generative_recommenders.ops.jagged_tensors import concat_2D_jagged + + +class ContextualInterleavePreprocessor(InputPreprocessor): + def __init__( + self, + input_embedding_dim: int, + output_embedding_dim: int, + contextual_feature_to_max_length: Dict[str, int], + contextual_feature_to_min_uih_length: Dict[str, int], + content_encoder: ContentEncoder, + content_contextualize_mlp_fn: Callable[ + [int, int, int, bool], ContextualizedMLP + ], + action_encoder: ActionEncoder, + action_contextualize_mlp_fn: Callable[[int, int, int, bool], ContextualizedMLP], + pmlp_contextual_dropout_ratio: float = 0.0, + enable_interleaving: bool = False, + is_inference: bool = False, + ) -> None: + super().__init__(is_inference=is_inference) + self._input_embedding_dim: int = input_embedding_dim + self._output_embedding_dim: int = output_embedding_dim + self._contextual_feature_to_max_length: Dict[str, int] = ( + contextual_feature_to_max_length + ) + self._max_contextual_seq_len: int = sum( + contextual_feature_to_max_length.values() + ) + self._contextual_feature_to_min_uih_length: Dict[str, int] = ( + contextual_feature_to_min_uih_length + ) + std = 1.0 * sqrt(2.0 / float(input_embedding_dim + output_embedding_dim)) + self._batched_contextual_linear_weights = torch.nn.Parameter( + torch.empty( + ( + self._max_contextual_seq_len, + input_embedding_dim, + output_embedding_dim, + ) + ).normal_(0.0, std) + ) + self._pmlp_contextual_dropout_ratio: float = 
pmlp_contextual_dropout_ratio + self._batched_contextual_linear_bias = torch.nn.Parameter( + torch.empty((self._max_contextual_seq_len, 1, output_embedding_dim)).fill_( + 0.0 + ) + ) + contextual_embedding_dim: int = ( + self._max_contextual_seq_len * input_embedding_dim + ) + self._content_encoder: ContentEncoder = content_encoder + self._content_embedding_mlp: ContextualizedMLP = content_contextualize_mlp_fn( + self._content_encoder.output_embedding_dim, + output_embedding_dim, + contextual_embedding_dim, + is_inference, + ) + self._action_encoder: ActionEncoder = action_encoder + self._action_embedding_mlp: ContextualizedMLP = action_contextualize_mlp_fn( + self._action_encoder.output_embedding_dim, + output_embedding_dim, + contextual_embedding_dim, + is_inference, + ) + self._enable_interleaving: bool = enable_interleaving + + def combine_embeddings( + self, + max_uih_len: int, + max_targets: int, + total_uih_len: int, + total_targets: int, + seq_lengths: torch.Tensor, + seq_timestamps: torch.Tensor, + content_embeddings: torch.Tensor, + action_embeddings: torch.Tensor, + contextual_embeddings: Optional[torch.Tensor], + num_targets: torch.Tensor, + ) -> Tuple[ + int, + int, + int, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ]: + if self._enable_interleaving: + output_seq_timestamps = seq_timestamps.repeat_interleave(2) + output_seq_embeddings = torch.stack( + [content_embeddings, action_embeddings], dim=1 + ).reshape(-1, self._output_embedding_dim) + if self.interleave_targets(): + output_seq_lengths = seq_lengths * 2 + output_max_seq_len = (max_uih_len + max_targets) * 2 + output_num_targets = num_targets * 2 + output_total_uih_len = total_uih_len * 2 + output_total_targets = total_targets * 2 + else: + seq_lengths_by_2 = seq_lengths * 2 + output_seq_lengths = seq_lengths_by_2 - num_targets + output_max_seq_len = 2 * max_uih_len + max_targets + indices = torch.arange( + 2 * (max_uih_len + max_targets), device=seq_lengths.device + ).view(1, -1) + valid_mask = torch.logical_and( + indices < seq_lengths_by_2.view(-1, 1), + torch.logical_or( + indices < (output_seq_lengths - num_targets).view(-1, 1), + torch.remainder(indices, 2) == 0, + ), + ) + jagged_valid_mask = ( + torch.ops.fbgemm.dense_to_jagged( + valid_mask.int().unsqueeze(-1), + [ + torch.ops.fbgemm.asynchronous_complete_cumsum( + seq_lengths_by_2 + ) + ], + )[0] + .to(torch.bool) + .squeeze(1) + ) + output_seq_embeddings = output_seq_embeddings[jagged_valid_mask] + output_seq_timestamps = output_seq_timestamps[jagged_valid_mask] + output_num_targets = num_targets + output_total_uih_len = total_uih_len * 2 + output_total_targets = total_targets + else: + output_max_seq_len = max_uih_len + max_targets + output_seq_lengths = seq_lengths + output_num_targets = num_targets + output_seq_timestamps = seq_timestamps + output_seq_embeddings = content_embeddings + action_embeddings + output_total_uih_len = total_uih_len + output_total_targets = total_targets + + # concat contextual embeddings + output_seq_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum( + output_seq_lengths + ) + if self._max_contextual_seq_len > 0: + output_seq_embeddings = concat_2D_jagged( + max_seq_len=self._max_contextual_seq_len + output_max_seq_len, + values_left=fx_unwrap_optional_tensor(contextual_embeddings).reshape( + -1, self._output_embedding_dim + ), + values_right=output_seq_embeddings, + max_len_left=self._max_contextual_seq_len, + max_len_right=output_max_seq_len, + offsets_left=None, + 
offsets_right=output_seq_offsets, + kernel=self.hammer_kernel(), + ) + output_seq_timestamps = concat_2D_jagged( + max_seq_len=self._max_contextual_seq_len + output_max_seq_len, + values_left=torch.zeros( + (output_seq_lengths.size(0) * self._max_contextual_seq_len, 1), + dtype=output_seq_timestamps.dtype, + device=output_seq_timestamps.device, + ), + values_right=output_seq_timestamps.unsqueeze(-1), + max_len_left=self._max_contextual_seq_len, + max_len_right=output_max_seq_len, + offsets_left=None, + offsets_right=output_seq_offsets, + kernel=self.hammer_kernel(), + ).squeeze(-1) + output_max_seq_len = output_max_seq_len + self._max_contextual_seq_len + output_total_uih_len = ( + output_total_uih_len + + self._max_contextual_seq_len * output_seq_lengths.size(0) + ) + output_seq_lengths = output_seq_lengths + self._max_contextual_seq_len + output_seq_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum( + output_seq_lengths + ) + + return ( + output_max_seq_len, + output_total_uih_len, + output_total_targets, + output_seq_lengths, + output_seq_offsets, + output_seq_timestamps, + output_seq_embeddings, + output_num_targets, + ) + + def forward( # noqa C901 + self, + max_uih_len: int, + max_targets: int, + total_uih_len: int, + total_targets: int, + seq_lengths: torch.Tensor, + seq_timestamps: torch.Tensor, + seq_embeddings: torch.Tensor, + num_targets: torch.Tensor, + seq_payloads: Dict[str, torch.Tensor], + ) -> Tuple[ + int, + int, + int, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Dict[str, torch.Tensor], + ]: + max_seq_len = max_uih_len + max_targets + with torch.autocast( + "cuda", + dtype=torch.bfloat16, + enabled=(not self.is_inference and self._training_dtype == torch.bfloat16), + ): + # get contextual_embeddings + contextual_embeddings: Optional[torch.Tensor] = None + pmlp_contextual_embeddings: Optional[torch.Tensor] = None + if self._max_contextual_seq_len > 0: + contextual_input_embeddings = get_contextual_input_embeddings( + seq_lengths=seq_lengths, + seq_payloads=seq_payloads, + contextual_feature_to_max_length=self._contextual_feature_to_max_length, + contextual_feature_to_min_uih_length=self._contextual_feature_to_min_uih_length, + dtype=seq_embeddings.dtype, + ) + if isinstance( + self._action_embedding_mlp, ParameterizedContextualizedMLP + ) or isinstance( + self._action_embedding_mlp, ParameterizedContextualizedMLP + ): + pmlp_contextual_embeddings = torch.nn.functional.dropout( + contextual_input_embeddings, + p=self._pmlp_contextual_dropout_ratio, + training=self.training, + ) + contextual_embeddings = torch.baddbmm( + self._batched_contextual_linear_bias.to( + contextual_input_embeddings.dtype + ), + contextual_input_embeddings.view( + -1, self._max_contextual_seq_len, self._input_embedding_dim + ).transpose(0, 1), + self._batched_contextual_linear_weights.to( + contextual_input_embeddings.dtype + ), + ).transpose(0, 1) + + # content embeddings + seq_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(seq_lengths) + target_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(num_targets) + uih_offsets = seq_offsets - target_offsets + content_embeddings = self._content_encoder( + max_uih_len=max_uih_len, + max_targets=max_targets, + uih_offsets=uih_offsets, + target_offsets=target_offsets, + seq_embeddings=seq_embeddings, + seq_payloads=seq_payloads, + ) + content_embeddings = self._content_embedding_mlp( + seq_embeddings=content_embeddings, + seq_offsets=seq_offsets, + max_seq_len=max_seq_len, + 
contextual_embeddings=pmlp_contextual_embeddings, + ) + + # action embeddings + action_embeddings = self._action_encoder( + max_uih_len=max_uih_len, + max_targets=max_targets, + uih_offsets=uih_offsets, + target_offsets=target_offsets, + seq_embeddings=seq_embeddings, + seq_payloads=seq_payloads, + ).to(seq_embeddings.dtype) + action_embeddings = self._action_embedding_mlp( + seq_embeddings=action_embeddings, + seq_offsets=seq_offsets, + max_seq_len=max_seq_len, + contextual_embeddings=pmlp_contextual_embeddings, + ) + + ( + output_max_seq_len, + output_total_uih_len, + output_total_targets, + output_seq_lengths, + output_seq_offsets, + output_seq_timestamps, + output_seq_embeddings, + output_num_targets, + ) = self.combine_embeddings( + max_uih_len=max_uih_len, + max_targets=max_targets, + total_uih_len=total_uih_len, + total_targets=total_targets, + seq_lengths=seq_lengths, + seq_timestamps=seq_timestamps, + content_embeddings=content_embeddings, + action_embeddings=action_embeddings, + contextual_embeddings=contextual_embeddings, + num_targets=num_targets, + ) + + return ( + output_max_seq_len, + output_total_uih_len, + output_total_targets, + output_seq_lengths, + output_seq_offsets, + output_seq_timestamps, + output_seq_embeddings, + output_num_targets, + seq_payloads, + ) + + def interleave_targets(self) -> bool: + return self.is_train and self._enable_interleaving diff --git a/recommendation/dlrm_v3/generative_recommenders/modules/contextualize_mlps.py b/recommendation/dlrm_v3/generative_recommenders/modules/contextualize_mlps.py new file mode 100644 index 0000000000..95c29f0381 --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/modules/contextualize_mlps.py @@ -0,0 +1,143 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
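To make the interleaving path in `ContextualInterleavePreprocessor.combine_embeddings` above easier to follow: with `enable_interleaving=True`, each event's content embedding is paired with its action embedding, so the token count doubles and each timestamp is repeated for the pair. The sketch below shows only that stack-and-reshape pattern on a dense toy batch; the jagged offsets, target handling, and contextual concatenation of the real code are omitted, and all shapes are illustrative.
```
import torch

# Interleaving pattern (illustrative shapes): 4 events, embedding dim 8.
L, D = 4, 8
content_embeddings = torch.randn(L, D)
action_embeddings = torch.randn(L, D)
seq_timestamps = torch.tensor([10, 11, 12, 13])

# Stack (content_i, action_i) pairs and flatten: rows alternate content and
# action tokens, so the flattened sequence length doubles.
interleaved = torch.stack(
    [content_embeddings, action_embeddings], dim=1
).reshape(-1, D)
assert interleaved.shape == (2 * L, D)
assert torch.equal(interleaved[0], content_embeddings[0])
assert torch.equal(interleaved[1], action_embeddings[0])

# Each event's timestamp is shared by its content token and its action token.
assert torch.equal(
    seq_timestamps.repeat_interleave(2),
    torch.tensor([10, 10, 11, 11, 12, 12, 13, 13]),
)
```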
+ +#!/usr/bin/env python3 + +# pyre-strict +import abc + +from typing import Optional + +import torch + +from generative_recommenders.common import HammerModule, init_mlp_weights_optional_bias +from generative_recommenders.ops.jagged_tensors import jagged_dense_bmm_broadcast_add +from generative_recommenders.ops.layer_norm import LayerNorm, SwishLayerNorm +from libfb.py.pyre import none_throws + + +class ContextualizedMLP(HammerModule): + @abc.abstractmethod + def forward( + self, + max_seq_len: int, + seq_embeddings: torch.Tensor, + seq_offsets: torch.Tensor, + contextual_embeddings: Optional[torch.Tensor], + ) -> torch.Tensor: + """ + Args: + seq_embeddings: (L, D) + seq_offsets: (B + 1,) + max_seq_len: int + contextual_embeddings: (B, D') + """ + pass + + +class SimpleContextualizedMLP(ContextualizedMLP): + def __init__( + self, + sequential_input_dim: int, + sequential_output_dim: int, + hidden_dim: int, + is_inference: bool = False, + ) -> None: + super().__init__(is_inference=is_inference) + self._mlp: torch.nn.Module = torch.nn.Sequential( + torch.nn.Linear( + in_features=sequential_input_dim, + out_features=hidden_dim, + ), + SwishLayerNorm(hidden_dim, is_inference=is_inference), + torch.nn.Linear( + in_features=hidden_dim, + out_features=sequential_output_dim, + ), + LayerNorm(sequential_output_dim), + ).apply(init_mlp_weights_optional_bias) + + def forward( + self, + seq_embeddings: torch.Tensor, + seq_offsets: torch.Tensor, + max_seq_len: int, + contextual_embeddings: Optional[torch.Tensor], + ) -> torch.Tensor: + return self._mlp(seq_embeddings) + + +class ParameterizedContextualizedMLP(ContextualizedMLP): + def __init__( + self, + contextual_embedding_dim: int, + sequential_input_dim: int, + sequential_output_dim: int, + hidden_dim: int, + is_inference: bool = False, + ) -> None: + super().__init__(is_inference=is_inference) + + self._sequential_input_dim: int = sequential_input_dim + self._sequential_output_dim: int = sequential_output_dim + + self._dense_features_compress: torch.nn.Module = torch.nn.Linear( + in_features=contextual_embedding_dim, + out_features=hidden_dim, + ).apply(init_mlp_weights_optional_bias) + + self._attn_raw_weights: torch.nn.Module = torch.nn.Sequential( + torch.nn.Linear( + in_features=hidden_dim, + out_features=sequential_input_dim * sequential_output_dim, + ), + ).apply(init_mlp_weights_optional_bias) + + self._attn_weights_norm: torch.nn.Module = torch.nn.LayerNorm( + [sequential_input_dim, sequential_output_dim] + ) + + self._res_weights: torch.nn.Module = torch.nn.Sequential( + torch.nn.Linear( + in_features=hidden_dim, + out_features=hidden_dim, + ), + SwishLayerNorm(hidden_dim), + torch.nn.Linear( + in_features=hidden_dim, + out_features=sequential_output_dim, + ), + ).apply(init_mlp_weights_optional_bias) + + def forward( + self, + seq_embeddings: torch.Tensor, + seq_offsets: torch.Tensor, + max_seq_len: int, + contextual_embeddings: Optional[torch.Tensor], + ) -> torch.Tensor: + shared_input = self._dense_features_compress(none_throws(contextual_embeddings)) + attn_weights = self._attn_weights_norm( + self._attn_raw_weights(shared_input).reshape( + -1, self._sequential_input_dim, self._sequential_output_dim + ) + ) + return jagged_dense_bmm_broadcast_add( + max_seq_len=max_seq_len, + seq_offsets=seq_offsets, + jagged=seq_embeddings, + dense=attn_weights.to(seq_embeddings.dtype), + bias=self._res_weights(shared_input), + kernel=self.hammer_kernel(), + ) diff --git a/recommendation/dlrm_v3/generative_recommenders/modules/dlrm_hstu.py 
b/recommendation/dlrm_v3/generative_recommenders/modules/dlrm_hstu.py new file mode 100644 index 0000000000..003abe77dd --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/modules/dlrm_hstu.py @@ -0,0 +1,581 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + + +import logging +from dataclasses import dataclass, field +from typing import Dict, List, NamedTuple, Optional, Tuple + +import torch +from generative_recommenders.common import ( + fx_infer_max_len, + fx_mark_length_features, + HammerKernel, + HammerModule, + init_mlp_weights_optional_bias, + set_static_max_seq_lens, +) +from generative_recommenders.modules.hstu_transducer import HSTUTransducer +from generative_recommenders.modules.multitask_module import ( + DefaultMultitaskModule, + MultitaskTaskType, + TaskConfig, +) +from generative_recommenders.modules.positional_encoder import HSTUPositionalEncoder +from generative_recommenders.modules.postprocessors import ( + LayerNormPostprocessor, + TimestampLayerNormPostprocessor, +) +from generative_recommenders.modules.preprocessors import ContextualPreprocessor +from generative_recommenders.modules.stu import STU, STULayer, STULayerConfig, STUStack +from generative_recommenders.ops.jagged_tensors import concat_2D_jagged +from generative_recommenders.ops.layer_norm import LayerNorm, SwishLayerNorm +from torch.autograd.profiler import record_function +from torchrec import KeyedJaggedTensor +from torchrec.modules.embedding_configs import EmbeddingConfig +from torchrec.modules.embedding_modules import EmbeddingCollection + +logger: logging.Logger = logging.getLogger(__name__) + +torch.fx.wrap("fx_infer_max_len") +torch.fx.wrap("len") + + +class SequenceEmbedding(NamedTuple): + lengths: torch.Tensor + embedding: torch.Tensor + + +@dataclass +class DlrmHSTUConfig: + max_seq_len: int = 16384 + max_num_candidates: int = 10 + max_num_candidates_inference: int = 5 + hstu_num_heads: int = 1 + hstu_attn_linear_dim: int = 256 + hstu_attn_qk_dim: int = 128 + hstu_attn_num_layers: int = 12 + hstu_embedding_table_dim: int = 192 + hstu_preprocessor_hidden_dim: int = 256 + hstu_transducer_embedding_dim: int = 0 + hstu_group_norm: bool = False + hstu_input_dropout_ratio: float = 0.2 + hstu_linear_dropout_rate: float = 0.2 + contextual_feature_to_max_length: Dict[str, int] = field(default_factory=dict) + contextual_feature_to_min_uih_length: Dict[str, int] = field(default_factory=dict) + candidates_weight_feature_name: str = "" + candidates_watchtime_feature_name: str = "" + candidates_querytime_feature_name: str = "" + causal_multitask_weights: float = 0.2 + multitask_configs: List[TaskConfig] = field(default_factory=list) + user_embedding_feature_names: List[str] = field(default_factory=list) + item_embedding_feature_names: List[str] = field(default_factory=list) + uih_post_id_feature_name: str = "" + uih_action_time_feature_name: str = "" + uih_weight_feature_name: str = "" + 
hstu_uih_feature_names: List[str] = field(default_factory=list) + hstu_candidate_feature_names: List[str] = field(default_factory=list) + merge_uih_candidate_feature_mapping: List[Tuple[str, str]] = field( + default_factory=list + ) + action_weights: Optional[List[int]] = None + action_embedding_init_std: float = 0.1 + enable_postprocessor: bool = True + use_layer_norm_postprocessor: bool = False + + +def _get_supervision_labels_and_weights( + supervision_bitmasks: torch.Tensor, + watchtime_sequence: torch.Tensor, + task_configs: List[TaskConfig], +) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + supervision_labels: Dict[str, torch.Tensor] = {} + supervision_weights: Dict[str, torch.Tensor] = {} + for task in task_configs: + if task.task_type == MultitaskTaskType.REGRESSION: + supervision_labels[task.task_name] = watchtime_sequence.to(torch.float32) + elif task.task_type == MultitaskTaskType.BINARY_CLASSIFICATION: + supervision_labels[task.task_name] = ( + torch.bitwise_and(supervision_bitmasks, task.task_weight) > 0 + ).to(torch.float32) + else: + raise RuntimeError("Unsupported MultitaskTaskType") + return supervision_labels, supervision_weights + + +class DlrmHSTU(HammerModule): + def __init__( # noqa C901 + self, + hstu_configs: DlrmHSTUConfig, + embedding_tables: Dict[str, EmbeddingConfig], + is_inference: bool, + is_dense: bool = False, + bf16_training: bool = True, + ) -> None: + super().__init__(is_inference=is_inference) + logger.info(f"Initialize HSTU module with configs {hstu_configs}") + self._hstu_configs = hstu_configs + self._bf16_training: bool = bf16_training + set_static_max_seq_lens([self._hstu_configs.max_seq_len]) + + if not is_dense: + self._embedding_collection: EmbeddingCollection = EmbeddingCollection( + tables=list(embedding_tables.values()), + need_indices=False, + device=torch.device("meta"), + ) + + # multitask configs must be sorted by task types + self._multitask_configs: List[TaskConfig] = hstu_configs.multitask_configs + self._multitask_module = DefaultMultitaskModule( + task_configs=self._multitask_configs, + embedding_dim=hstu_configs.hstu_transducer_embedding_dim, + prediction_fn=lambda in_dim, num_tasks: torch.nn.Sequential( + torch.nn.Linear(in_features=in_dim, out_features=512), + SwishLayerNorm(512), + torch.nn.Linear(in_features=512, out_features=num_tasks), + ).apply(init_mlp_weights_optional_bias), + causal_multitask_weights=hstu_configs.causal_multitask_weights, + is_inference=self._is_inference, + ) + self._additional_embedding_features: List[str] = [ + uih_feature_name + for ( + uih_feature_name, + candidate_feature_name, + ) in self._hstu_configs.merge_uih_candidate_feature_mapping + if ( + candidate_feature_name + in self._hstu_configs.item_embedding_feature_names + ) + and (uih_feature_name in self._hstu_configs.user_embedding_feature_names) + and (uih_feature_name is not self._hstu_configs.uih_post_id_feature_name) + ] + + # preprocessor setup + preprocessor = ContextualPreprocessor( + input_embedding_dim=hstu_configs.hstu_embedding_table_dim, + hidden_dim=hstu_configs.hstu_preprocessor_hidden_dim, + output_embedding_dim=hstu_configs.hstu_transducer_embedding_dim, + contextual_feature_to_max_length=hstu_configs.contextual_feature_to_max_length, + contextual_feature_to_min_uih_length=hstu_configs.contextual_feature_to_min_uih_length, + action_embedding_dim=8, + action_feature_name=self._hstu_configs.uih_weight_feature_name, + action_weights=self._hstu_configs.action_weights, + 
action_embedding_init_std=self._hstu_configs.action_embedding_init_std, + additional_embedding_features=self._additional_embedding_features, + is_inference=is_inference, + ) + + # positional encoder + positional_encoder = HSTUPositionalEncoder( + num_position_buckets=8192, + num_time_buckets=2048, + embedding_dim=hstu_configs.hstu_transducer_embedding_dim, + contextual_seq_len=sum( + dict(hstu_configs.contextual_feature_to_max_length).values() + ), + is_inference=self._is_inference, + ) + + if hstu_configs.enable_postprocessor: + if hstu_configs.use_layer_norm_postprocessor: + postprocessor = LayerNormPostprocessor( + embedding_dim=hstu_configs.hstu_transducer_embedding_dim, + eps=1e-5, + is_inference=self._is_inference, + ) + else: + postprocessor = TimestampLayerNormPostprocessor( + embedding_dim=hstu_configs.hstu_transducer_embedding_dim, + time_duration_features=[ + (60 * 60, 24), # hour of day + (24 * 60 * 60, 7), # day of week + # (24 * 60 * 60, 365), # time of year (approximate) + ], + eps=1e-5, + is_inference=self._is_inference, + ) + else: + postprocessor = None + + # construct HSTU + stu_module: STU = STUStack( + stu_list=[ + STULayer( + config=STULayerConfig( + embedding_dim=hstu_configs.hstu_transducer_embedding_dim, + num_heads=hstu_configs.hstu_num_heads, + hidden_dim=hstu_configs.hstu_attn_linear_dim, + attention_dim=hstu_configs.hstu_attn_qk_dim, + output_dropout_ratio=hstu_configs.hstu_linear_dropout_rate, + use_group_norm=hstu_configs.hstu_group_norm, + causal=True, + target_aware=True, + max_attn_len=None, + attn_alpha=None, + recompute_normed_x=True, + recompute_uvqk=True, + recompute_y=True, + sort_by_length=True, + contextual_seq_len=0, + ), + is_inference=is_inference, + ) + for _ in range(hstu_configs.hstu_attn_num_layers) + ], + is_inference=is_inference, + ) + self._hstu_transducer: HSTUTransducer = HSTUTransducer( + stu_module=stu_module, + input_preprocessor=preprocessor, + output_postprocessor=postprocessor, + input_dropout_ratio=hstu_configs.hstu_input_dropout_ratio, + positional_encoder=positional_encoder, + is_inference=self._is_inference, + return_full_embeddings=False, + listwise=False, + ) + + # item embeddings + self._item_embedding_mlp: torch.nn.Module = torch.nn.Sequential( + torch.nn.Linear( + in_features=hstu_configs.hstu_embedding_table_dim + * len(self._hstu_configs.item_embedding_feature_names), + out_features=512, + ), + SwishLayerNorm(512), + torch.nn.Linear( + in_features=512, + out_features=hstu_configs.hstu_transducer_embedding_dim, + ), + LayerNorm(hstu_configs.hstu_transducer_embedding_dim), + ).apply(init_mlp_weights_optional_bias) + + def _construct_payload( + self, + payload_features: Dict[str, torch.Tensor], + seq_embeddings: Dict[str, SequenceEmbedding], + ) -> Dict[str, torch.Tensor]: + if len(self._hstu_configs.contextual_feature_to_max_length) > 0: + contextual_offsets: List[torch.Tensor] = [] + for x in self._hstu_configs.contextual_feature_to_max_length.keys(): + contextual_offsets.append( + torch.ops.fbgemm.asynchronous_complete_cumsum( + seq_embeddings[x].lengths + ) + ) + else: + # Dummy, offsets are unused + contextual_offsets = torch.empty((0, 0)) + return { + **payload_features, + **{ + x: seq_embeddings[x].embedding + for x in self._hstu_configs.contextual_feature_to_max_length.keys() + }, + **{ + x + "_offsets": contextual_offsets[i] + for i, x in enumerate( + list(self._hstu_configs.contextual_feature_to_max_length.keys()) + ) + }, + **{ + x: seq_embeddings[x].embedding + for x in self._additional_embedding_features + }, 
+ } + + def _user_forward( + self, + max_uih_len: int, + max_candidates: int, + seq_embeddings: Dict[str, SequenceEmbedding], + payload_features: Dict[str, torch.Tensor], + num_candidates: torch.Tensor, + ) -> torch.Tensor: + source_lengths = seq_embeddings[ + self._hstu_configs.uih_post_id_feature_name + ].lengths + source_timestamps = concat_2D_jagged( + max_seq_len=max_uih_len + max_candidates, + max_len_left=max_uih_len, + offsets_left=payload_features["uih_offsets"], + values_left=payload_features[ + self._hstu_configs.uih_action_time_feature_name + ].unsqueeze(-1), + max_len_right=max_candidates, + offsets_right=payload_features["candidate_offsets"], + values_right=payload_features[ + self._hstu_configs.candidates_querytime_feature_name + ].unsqueeze(-1), + kernel=self.hammer_kernel(), + ).squeeze(-1) + total_targets = int(num_candidates.sum().item()) + embedding = seq_embeddings[ + self._hstu_configs.uih_post_id_feature_name + ].embedding + dtype = embedding.dtype + if (not self.is_inference) and self._bf16_training: + embedding = embedding.to(torch.bfloat16) + with torch.autocast( + "cuda", + dtype=torch.bfloat16, + enabled=(not self.is_inference) and self._bf16_training, + ): + candidates_user_embeddings, _ = self._hstu_transducer( + max_uih_len=max_uih_len, + max_targets=max_candidates, + total_uih_len=source_timestamps.numel() - total_targets, + total_targets=total_targets, + seq_embeddings=embedding, + seq_lengths=source_lengths, + seq_timestamps=source_timestamps, + seq_payloads=self._construct_payload( + payload_features=payload_features, + seq_embeddings=seq_embeddings, + ), + num_targets=num_candidates, + ) + candidates_user_embeddings = candidates_user_embeddings.to(dtype) + + return candidates_user_embeddings + + def _item_forward( + self, + seq_embeddings: Dict[str, SequenceEmbedding], + ) -> torch.Tensor: # [L, D] + all_embeddings = torch.cat( + [ + seq_embeddings[name].embedding + for name in self._hstu_configs.item_embedding_feature_names + ], + dim=-1, + ) + item_embeddings = self._item_embedding_mlp(all_embeddings) + return item_embeddings + + def preprocess( + self, + uih_features: KeyedJaggedTensor, + candidates_features: KeyedJaggedTensor, + ) -> Tuple[ + Dict[str, SequenceEmbedding], + Dict[str, torch.Tensor], + int, + torch.Tensor, + int, + torch.Tensor, + ]: + # embedding lookup for uih and candidates + merged_sparse_features = KeyedJaggedTensor.from_lengths_sync( + keys=uih_features.keys() + candidates_features.keys(), + values=torch.cat( + [uih_features.values(), candidates_features.values()], + dim=0, + ), + lengths=torch.cat( + [uih_features.lengths(), candidates_features.lengths()], + dim=0, + ), + ) + seq_embeddings_dict = self._embedding_collection(merged_sparse_features) + num_candidates = fx_mark_length_features( + candidates_features.lengths().view(len(candidates_features.keys()), -1) + )[0] + max_num_candidates = fx_infer_max_len(num_candidates) + uih_seq_lengths = uih_features[ + self._hstu_configs.uih_post_id_feature_name + ].lengths() + max_uih_len = fx_infer_max_len(uih_seq_lengths) + + # prepare payload features + payload_features: Dict[str, torch.Tensor] = {} + for ( + uih_feature_name, + candidate_feature_name, + ) in self._hstu_configs.merge_uih_candidate_feature_mapping: + if ( + candidate_feature_name + not in self._hstu_configs.item_embedding_feature_names + and uih_feature_name + not in self._hstu_configs.user_embedding_feature_names + ): + values_left = uih_features[uih_feature_name].values() + if self._is_inference and ( + 
candidate_feature_name + == self._hstu_configs.candidates_weight_feature_name + or candidate_feature_name + == self._hstu_configs.candidates_watchtime_feature_name + ): + total_candidates = torch.sum(num_candidates).item() + values_right = torch.zeros( + total_candidates, # pyre-ignore + dtype=torch.int64, + device=values_left.device, + ) + else: + values_right = candidates_features[candidate_feature_name].values() + payload_features[uih_feature_name] = values_left + payload_features[candidate_feature_name] = values_right + payload_features["uih_offsets"] = torch.ops.fbgemm.asynchronous_complete_cumsum( + uih_seq_lengths + ) + payload_features["candidate_offsets"] = ( + torch.ops.fbgemm.asynchronous_complete_cumsum(num_candidates) + ) + + seq_embeddings = { + k: SequenceEmbedding( + lengths=seq_embeddings_dict[k].lengths(), + embedding=seq_embeddings_dict[k].values(), + ) + for k in self._hstu_configs.user_embedding_feature_names + + self._hstu_configs.item_embedding_feature_names + } + + return ( + seq_embeddings, + payload_features, + max_uih_len, + uih_seq_lengths, + max_num_candidates, + num_candidates, + ) + + def main_forward( + self, + seq_embeddings: Dict[str, SequenceEmbedding], + payload_features: Dict[str, torch.Tensor], + max_uih_len: int, + uih_seq_lengths: torch.Tensor, + max_num_candidates: int, + num_candidates: torch.Tensor, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + Dict[str, torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + ]: + # merge uih and candidates embeddings + for ( + uih_feature_name, + candidate_feature_name, + ) in self._hstu_configs.merge_uih_candidate_feature_mapping: + if uih_feature_name in seq_embeddings: + seq_embeddings[uih_feature_name] = SequenceEmbedding( + lengths=uih_seq_lengths + num_candidates, + embedding=concat_2D_jagged( + max_seq_len=max_uih_len + max_num_candidates, + max_len_left=max_uih_len, + offsets_left=torch.ops.fbgemm.asynchronous_complete_cumsum( + uih_seq_lengths + ), + values_left=seq_embeddings[uih_feature_name].embedding, + max_len_right=max_num_candidates, + offsets_right=torch.ops.fbgemm.asynchronous_complete_cumsum( + num_candidates + ), + values_right=seq_embeddings[candidate_feature_name].embedding, + kernel=self.hammer_kernel(), + ), + ) + + with record_function("## item_forward ##"): + candidates_item_embeddings = self._item_forward( + seq_embeddings, + ) + with record_function("## user_forward ##"): + candidates_user_embeddings = self._user_forward( + max_uih_len=max_uih_len, + max_candidates=max_num_candidates, + seq_embeddings=seq_embeddings, + payload_features=payload_features, + num_candidates=num_candidates, + ) + with record_function("## multitask_module ##"): + supervision_labels, supervision_weights = ( + _get_supervision_labels_and_weights( + supervision_bitmasks=payload_features[ + self._hstu_configs.candidates_weight_feature_name + ], + watchtime_sequence=payload_features[ + self._hstu_configs.candidates_watchtime_feature_name + ], + task_configs=self._multitask_configs, + ) + ) + mt_target_preds, mt_target_labels, mt_target_weights, mt_losses = ( + self._multitask_module( + encoded_user_embeddings=candidates_user_embeddings, + item_embeddings=candidates_item_embeddings, + supervision_labels=supervision_labels, + supervision_weights=supervision_weights, + ) + ) + + aux_losses: Dict[str, torch.Tensor] = {} + if not self._is_inference and self.training: + for i, task in enumerate(self._multitask_configs): + aux_losses[task.task_name] = mt_losses[i] + + return ( + 
candidates_user_embeddings, + candidates_item_embeddings, + aux_losses, + mt_target_preds, + mt_target_labels, + mt_target_weights, + ) + + def forward( + self, + uih_features: KeyedJaggedTensor, + candidates_features: KeyedJaggedTensor, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + Dict[str, torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + ]: + with record_function("## preprocess ##"): + ( + seq_embeddings, + payload_features, + max_uih_len, + uih_seq_lengths, + max_num_candidates, + num_candidates, + ) = self.preprocess( + uih_features=uih_features, + candidates_features=candidates_features, + ) + + with record_function("## main_forward ##"): + return self.main_forward( + seq_embeddings=seq_embeddings, + payload_features=payload_features, + max_uih_len=max_uih_len, + uih_seq_lengths=uih_seq_lengths, + max_num_candidates=max_num_candidates, + num_candidates=num_candidates, + ) diff --git a/recommendation/dlrm_v3/generative_recommenders/modules/dynamic_stu.py b/recommendation/dlrm_v3/generative_recommenders/modules/dynamic_stu.py new file mode 100644 index 0000000000..e1fe8ad161 --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/modules/dynamic_stu.py @@ -0,0 +1,304 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
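One detail of `DlrmHSTU.main_forward` above worth spelling out is how `_get_supervision_labels_and_weights` turns the candidates' action bitmask into per-task labels: each binary-classification task carries a bit weight, and a candidate is a positive for that task when the corresponding bit is set (regression tasks take the watchtime sequence directly). The snippet below is a hedged sketch of that bit test; the task names and bit weights are invented for illustration and do not correspond to the benchmark's actual task configuration.
```
import torch

# Hypothetical task bit weights (illustration only).
task_weights = {"is_click": 0b001, "is_like": 0b010, "is_follow": 0b100}

# One bitmask per candidate, encoding which actions occurred.
supervision_bitmasks = torch.tensor([0b011, 0b100, 0b000, 0b101])

labels = {
    name: (torch.bitwise_and(supervision_bitmasks, weight) > 0).to(torch.float32)
    for name, weight in task_weights.items()
}
assert labels["is_click"].tolist() == [1.0, 0.0, 0.0, 1.0]
assert labels["is_like"].tolist() == [1.0, 0.0, 0.0, 0.0]
assert labels["is_follow"].tolist() == [0.0, 1.0, 0.0, 1.0]
```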
+ +#!/usr/bin/env python3 + +# pyre-strict +import abc +import contextlib +from typing import Any, Generator, Optional, Tuple + +import torch +from generative_recommenders.common import fx_infer_max_len +from generative_recommenders.modules.stu import STU +from generative_recommenders.ops.jagged_tensors import ( + hstu_concat_l2_embeddings, + hstu_split_l2_embeddings, +) + + +try: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") +except OSError: + pass + + +@contextlib.contextmanager +# pyre-ignore[3] +def _freeze_rng_state() -> Generator[Any, None, None]: + rng_state = torch.get_rng_state() + if torch.cuda.is_available(): + cuda_rng_state = torch.cuda.get_rng_state() + try: + yield + finally: + if torch.cuda.is_available(): + # pyre-ignore[61] + torch.cuda.set_rng_state(cuda_rng_state) + torch.set_rng_state(rng_state) + + +class DynamicSTU(STU): + def __init__(self, stu: STU, is_inference: bool) -> None: + super().__init__(is_inference) + self._stu = stu + + @abc.abstractmethod + def _preprocess( + self, + x: torch.Tensor, + x_lengths: torch.Tensor, + x_offsets: torch.Tensor, + max_seq_len: int, + num_targets: torch.Tensor, + max_kv_caching_len: int = 0, + kv_caching_lengths: Optional[torch.Tensor] = None, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + int, + torch.Tensor, + int, + Optional[torch.Tensor], + ]: + pass + + @abc.abstractmethod + def _postprocess( + self, + stu_output: torch.Tensor, + ) -> torch.Tensor: + pass + + def forward( + self, + x: torch.Tensor, + x_lengths: torch.Tensor, + x_offsets: torch.Tensor, + max_seq_len: int, + num_targets: torch.Tensor, + max_kv_caching_len: int = 0, + kv_caching_lengths: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + ( + x, + x_lengths, + x_offsets, + max_seq_len, + num_targets, + max_kv_caching_len, + kv_caching_lengths, + ) = self._preprocess( + x=x, + x_lengths=x_lengths, + x_offsets=x_offsets, + max_seq_len=max_seq_len, + num_targets=num_targets, + max_kv_caching_len=max_kv_caching_len, + kv_caching_lengths=kv_caching_lengths, + ) + + stu_output = self._stu( + x=x, + x_lengths=x_lengths, + x_offsets=x_offsets, + max_seq_len=max_seq_len, + num_targets=num_targets, + max_kv_caching_len=max_kv_caching_len, + kv_caching_lengths=kv_caching_lengths, + ) + + return self._postprocess( + stu_output=stu_output, + ) + + +class SDSTU(DynamicSTU): + def __init__( + self, + stu: STU, + is_inference: bool, + dropout_ratio: float = 0.5, + seed: int = 0, + ) -> None: + """ + Stochastic Depth STU + """ + super().__init__(stu=stu, is_inference=is_inference) + self._dropout_ratio: float = dropout_ratio + self._iter: int = 0 + self._seed: int = seed + self._skip_x: Optional[torch.Tensor] = None + + def _preprocess( + self, + x: torch.Tensor, + x_lengths: torch.Tensor, + x_offsets: torch.Tensor, + max_seq_len: int, + num_targets: torch.Tensor, + max_kv_caching_len: int = 0, + kv_caching_lengths: Optional[torch.Tensor] = None, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + int, + torch.Tensor, + int, + Optional[torch.Tensor], + ]: + if self.training: + with _freeze_rng_state(): + torch.manual_seed(self._iter + self._seed) + prob = torch.rand(1) + if prob.item() <= self._dropout_ratio: + new_x = torch.empty(size=(0, x.shape[1]), device=x.device) + self._skip_x = x + new_x_lengths = torch.zeros_like(x_lengths) + new_x_offsets = torch.zeros_like(x_offsets) + new_max_seq_len = 1 + else: + new_x = x + new_x_lengths = x_lengths 
+ new_x_offsets = x_offsets + new_max_seq_len = max_seq_len + self._iter += 1 + else: + new_x = x + new_x_lengths = x_lengths + new_x_offsets = x_offsets + new_max_seq_len = max_seq_len + return ( + new_x, + new_x_lengths, + new_x_offsets, + new_max_seq_len, + num_targets, + max_kv_caching_len, + kv_caching_lengths, + ) + + def _postprocess( + self, + stu_output: torch.Tensor, + ) -> torch.Tensor: + if self.training and self._skip_x is not None: + ret = self._skip_x + self._skip_x = None + return ret + else: + return stu_output + + +@torch.fx.wrap +def _fx_unwrap_optional_tuple_tensor( + optional: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]], +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert optional is not None, "Expected optional to be non-None" + return optional + + +class L2STU(DynamicSTU): + def __init__( + self, + stu: STU, + max_l2_len: int, + is_inference: bool, + contextual_seq_len: int = 0, + ) -> None: + """ + Stochastic Depth STU + """ + super().__init__(stu=stu, is_inference=is_inference) + self._max_l2_len: int = max_l2_len + self._contextual_seq_len: int = contextual_seq_len + self._saved_tensors: Optional[ + Tuple[torch.Tensor, torch.Tensor, torch.Tensor] + ] = None + self._runtime_max_l2_len: int = 0 + self._runtime_prefix_len: int = 0 + + def _preprocess( + self, + x: torch.Tensor, + x_lengths: torch.Tensor, + x_offsets: torch.Tensor, + max_seq_len: int, + num_targets: torch.Tensor, + max_kv_caching_len: int = 0, + kv_caching_lengths: Optional[torch.Tensor] = None, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + int, + torch.Tensor, + int, + Optional[torch.Tensor], + ]: + prefix_lengths = ( + x_lengths - self._max_l2_len - num_targets - self._contextual_seq_len + ) + prefix_lengths = torch.clamp(prefix_lengths, min=0) + prefix_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(prefix_lengths) + l2_lengths = x_lengths - prefix_lengths + l2_offsets = x_offsets - prefix_offsets + self._runtime_max_l2_len: int = fx_infer_max_len(l2_lengths) + self._runtime_prefix_len: int = fx_infer_max_len(prefix_lengths) + prefix_x, l2_x = hstu_split_l2_embeddings( + max_seq_len=max_seq_len, + x=x, + prefix_offsets=prefix_offsets, + l2_offsets=l2_offsets, + contextual_seq_len=self._contextual_seq_len, + kernel=self.hammer_kernel(), + ) + self._saved_tensors = ( + prefix_offsets, + prefix_x, + l2_offsets, + ) + return ( + l2_x, + l2_lengths, + l2_offsets, + self._runtime_max_l2_len, + num_targets, + max_kv_caching_len, + kv_caching_lengths, + ) + + def _postprocess( + self, + stu_output: torch.Tensor, + ) -> torch.Tensor: + ( + prefix_offsets, + prefix_x, + l2_offsets, + ) = _fx_unwrap_optional_tuple_tensor(self._saved_tensors) + self._saved_tensors = None + return hstu_concat_l2_embeddings( + max_prefix_len=self._runtime_prefix_len, + prefix_x=prefix_x, + prefix_offsets=prefix_offsets, + max_l2_len=self._runtime_max_l2_len, + l2_x=stu_output, + l2_offsets=l2_offsets, + contextual_seq_len=self._contextual_seq_len, + kernel=self.hammer_kernel(), + ) diff --git a/recommendation/dlrm_v3/generative_recommenders/modules/hstu_transducer.py b/recommendation/dlrm_v3/generative_recommenders/modules/hstu_transducer.py new file mode 100644 index 0000000000..b4ae836ada --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/modules/hstu_transducer.py @@ -0,0 +1,323 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +import logging +from typing import Dict, Optional, Tuple + +import torch +from generative_recommenders.common import fx_unwrap_optional_tensor, HammerModule +from generative_recommenders.modules.positional_encoder import HSTUPositionalEncoder +from generative_recommenders.modules.postprocessors import ( + L2NormPostprocessor, + OutputPostprocessor, +) +from generative_recommenders.modules.preprocessors import InputPreprocessor +from generative_recommenders.modules.stu import STU +from generative_recommenders.ops.jagged_tensors import split_2D_jagged + +from torch.profiler import record_function + +logger: logging.Logger = logging.getLogger(__name__) +torch.fx.wrap("len") + +try: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") +except OSError: + pass + + +@torch.fx.wrap +def default_seq_payload( + seq_payloads: Optional[Dict[str, torch.Tensor]], +) -> Dict[str, torch.Tensor]: + if seq_payloads is None: + return {} + else: + return torch.jit._unwrap_optional(seq_payloads) + + +class HSTUTransducer(HammerModule): + def __init__( + self, + stu_module: STU, + input_preprocessor: InputPreprocessor, + output_postprocessor: Optional[OutputPostprocessor] = None, + input_dropout_ratio: float = 0.0, + positional_encoder: Optional[HSTUPositionalEncoder] = None, + is_inference: bool = True, + return_full_embeddings: bool = False, + listwise: bool = False, + ) -> None: + super().__init__(is_inference=is_inference) + self._stu_module = stu_module + self._input_preprocessor: InputPreprocessor = input_preprocessor + self._output_postprocessor: OutputPostprocessor = ( + output_postprocessor + if output_postprocessor is not None + else L2NormPostprocessor(is_inference=is_inference) + ) + assert ( + self._is_inference == self._input_preprocessor._is_inference + ), f"input_preprocessor must have the same mode; self: {self._is_inference} vs input_preprocessor {self._input_preprocessor._is_inference}" + self._positional_encoder: Optional[HSTUPositionalEncoder] = positional_encoder + self._input_dropout_ratio: float = input_dropout_ratio + self._return_full_embeddings: bool = return_full_embeddings + self._listwise_training: bool = listwise and self.is_train + + for name, m in self.named_modules(): + if "_stu_module" in name: + continue + elif isinstance(m, torch.nn.Linear): + torch.nn.init.xavier_normal_(m.weight) + elif isinstance(m, torch.nn.LayerNorm): + if m.weight.dim() >= 2: + torch.nn.init.xavier_normal_(m.weight) + if m.bias is not None and m.bias.dim() >= 2: + torch.nn.init.xavier_normal_(m.bias) + + def _preprocess( + self, + max_uih_len: int, + max_targets: int, + total_uih_len: int, + total_targets: int, + seq_lengths: torch.Tensor, + seq_timestamps: torch.Tensor, + seq_embeddings: torch.Tensor, + num_targets: torch.Tensor, + seq_payloads: Dict[str, torch.Tensor], + ) -> Tuple[ + int, + int, + int, + torch.Tensor, 
+ torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Dict[str, torch.Tensor], + ]: + seq_payloads = default_seq_payload(seq_payloads) + + with record_function("hstu_input_preprocessor"): + ( + output_max_seq_len, + output_total_uih_len, + output_total_targets, + output_seq_lengths, + output_seq_offsets, + output_seq_timestamps, + output_seq_embeddings, + output_num_targets, + output_seq_payloads, + ) = self._input_preprocessor( + max_uih_len=max_uih_len, + max_targets=max_targets, + total_uih_len=total_uih_len, + total_targets=total_targets, + seq_lengths=seq_lengths, + seq_timestamps=seq_timestamps, + seq_embeddings=seq_embeddings, + num_targets=num_targets, + seq_payloads=seq_payloads, + ) + + with record_function("hstu_positional_encoder"): + if self._positional_encoder is not None: + output_seq_embeddings = self._positional_encoder( + max_seq_len=output_max_seq_len, + seq_lengths=output_seq_lengths, + seq_offsets=output_seq_offsets, + seq_timestamps=output_seq_timestamps, + seq_embeddings=output_seq_embeddings, + num_targets=( + None if self._listwise_training else output_num_targets + ), + ) + + output_seq_embeddings = torch.nn.functional.dropout( + output_seq_embeddings, + p=self._input_dropout_ratio, + training=self.training, + ) + + return ( + output_max_seq_len, + output_total_uih_len, + output_total_targets, + output_seq_lengths, + output_seq_offsets, + output_seq_timestamps, + output_seq_embeddings, + output_num_targets, + output_seq_payloads, + ) + + def _hstu_compute( + self, + max_seq_len: int, + seq_lengths: torch.Tensor, + seq_offsets: torch.Tensor, + seq_timestamps: torch.Tensor, + seq_embeddings: torch.Tensor, + num_targets: torch.Tensor, + ) -> torch.Tensor: + with record_function("hstu"): + seq_embeddings = self._stu_module( + max_seq_len=max_seq_len, + x=seq_embeddings, + x_lengths=seq_lengths, + x_offsets=seq_offsets, + num_targets=(None if self._listwise_training else num_targets), + ) + return seq_embeddings + + def _postprocess( + self, + max_seq_len: int, + total_uih_len: int, + total_targets: int, + seq_lengths: torch.Tensor, + seq_timestamps: torch.Tensor, + seq_embeddings: torch.Tensor, + num_targets: torch.Tensor, + seq_payloads: Dict[str, torch.Tensor], + ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: + with record_function("hstu_output_postprocessor"): + if self._return_full_embeddings: + seq_embeddings = self._output_postprocessor( + seq_embeddings=seq_embeddings, + seq_timestamps=seq_timestamps, + seq_payloads=seq_payloads, + ) + uih_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum( + seq_lengths - num_targets + ) + candidates_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum( + num_targets + ) + _, candidate_embeddings = split_2D_jagged( + values=seq_embeddings, + max_seq_len=max_seq_len, + total_len_left=total_uih_len, + total_len_right=total_targets, + offsets_left=uih_offsets, + offsets_right=candidates_offsets, + kernel=self.hammer_kernel(), + ) + interleave_targets: bool = self._input_preprocessor.interleave_targets() + if interleave_targets: + candidate_embeddings = candidate_embeddings.view( + -1, 2, candidate_embeddings.size(-1) + )[:, 0, :] + if not self._return_full_embeddings: + _, candidate_timestamps = split_2D_jagged( + values=seq_timestamps.unsqueeze(-1), + max_seq_len=max_seq_len, + total_len_left=total_uih_len, + total_len_right=total_targets, + offsets_left=uih_offsets, + offsets_right=candidates_offsets, + kernel=self.hammer_kernel(), + ) + candidate_timestamps = candidate_timestamps.squeeze(-1) + if 
interleave_targets: + candidate_timestamps = candidate_timestamps.view(-1, 2)[:, 0] + candidate_embeddings = self._output_postprocessor( + seq_embeddings=candidate_embeddings, + seq_timestamps=candidate_timestamps, + seq_payloads=seq_payloads, + ) + + return ( + seq_embeddings if self._return_full_embeddings else None, + candidate_embeddings, + ) + + def forward( + self, + max_uih_len: int, + max_targets: int, + total_uih_len: int, + total_targets: int, + seq_lengths: torch.Tensor, + seq_embeddings: torch.Tensor, + seq_timestamps: torch.Tensor, + num_targets: torch.Tensor, + seq_payloads: Dict[str, torch.Tensor], + ) -> Tuple[ + torch.Tensor, + Optional[torch.Tensor], + ]: + orig_dtype = seq_embeddings.dtype + if not self._is_inference: + seq_embeddings = seq_embeddings.to(self._training_dtype) + + ( + max_seq_len, + total_uih_len, + total_targets, + seq_lengths, + seq_offsets, + seq_timestamps, + seq_embeddings, + num_targets, + seq_payloads, + ) = self._preprocess( + max_uih_len=max_uih_len, + max_targets=max_targets, + total_uih_len=total_uih_len, + total_targets=total_targets, + seq_lengths=seq_lengths, + seq_timestamps=seq_timestamps, + seq_embeddings=seq_embeddings, + num_targets=num_targets, + seq_payloads=seq_payloads, + ) + + encoded_embeddings = self._hstu_compute( + max_seq_len=max_seq_len, + seq_lengths=seq_lengths, + seq_offsets=seq_offsets, + seq_timestamps=seq_timestamps, + seq_embeddings=seq_embeddings, + num_targets=num_targets, + ) + + encoded_embeddings, encoded_candidate_embeddings = self._postprocess( + max_seq_len=max_seq_len, + total_uih_len=total_uih_len, + total_targets=total_targets, + seq_lengths=seq_lengths, + seq_embeddings=encoded_embeddings, + seq_timestamps=seq_timestamps, + num_targets=num_targets, + seq_payloads=seq_payloads, + ) + + if not self._is_inference: + encoded_candidate_embeddings = encoded_candidate_embeddings.to(orig_dtype) + if self._return_full_embeddings: + encoded_embeddings = fx_unwrap_optional_tensor(encoded_embeddings).to( + orig_dtype + ) + return ( + encoded_candidate_embeddings, + encoded_embeddings, + ) diff --git a/recommendation/dlrm_v3/generative_recommenders/modules/multitask_module.py b/recommendation/dlrm_v3/generative_recommenders/modules/multitask_module.py new file mode 100644 index 0000000000..d5efe237ea --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/modules/multitask_module.py @@ -0,0 +1,259 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
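The `_postprocess` step above slices the candidate (target) embeddings off the tail of each request's jagged sequence via `split_2D_jagged`. Below is a minimal pure-PyTorch sketch of that split, assuming the usual `[user-history..., targets...]` per-request layout and cumulative-sum offsets; the helper name and sizes are illustrative, not part of the reference implementation, which delegates this bookkeeping to `split_2D_jagged`.
```
import torch

def split_candidates(seq_embeddings, seq_lengths, num_targets):
    # seq_embeddings: (L, D) jagged values, L = seq_lengths.sum()
    # seq_lengths:    (B,) total length per request (history + targets)
    # num_targets:    (B,) number of candidate items per request
    seq_offsets = torch.cat(
        [torch.zeros(1, dtype=torch.long), seq_lengths.cumsum(0)]
    )
    uih, cand = [], []
    for b in range(seq_lengths.numel()):
        start, end = seq_offsets[b], seq_offsets[b + 1]
        split = end - num_targets[b]          # candidates sit at the tail
        uih.append(seq_embeddings[start:split])
        cand.append(seq_embeddings[split:end])
    return torch.cat(uih), torch.cat(cand)

emb = torch.randn(7, 4)                        # two requests: lengths 4 and 3
uih, cand = split_candidates(emb, torch.tensor([4, 3]), torch.tensor([2, 1]))
assert uih.shape == (4, 4) and cand.shape == (3, 4)
```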
+ +#!/usr/bin/env python3 + +# pyre-strict + +import abc +import logging +from dataclasses import dataclass +from enum import IntEnum +from typing import Callable, Dict, List, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from generative_recommenders.common import HammerModule + +logger: logging.Logger = logging.getLogger(__name__) + + +class MultitaskTaskType(IntEnum): + BINARY_CLASSIFICATION = 0 + REGRESSION = 1 + + +@dataclass +class TaskConfig: + task_name: str + task_weight: int + task_type: MultitaskTaskType + + +class MultitaskModule(HammerModule): + @abc.abstractmethod + def forward( + self, + encoded_user_embeddings: torch.Tensor, + item_embeddings: torch.Tensor, + supervision_labels: Dict[str, torch.Tensor], + supervision_weights: Dict[str, torch.Tensor], + ) -> Tuple[ + torch.Tensor, + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + ]: + """ + Computes multi-task predictions. + + Args: + encoded_user_embeddings: (L, D) x float. + item_embeddings: (L, D) x float. + supervision_labels: Dict[T, L] x float or int + supervision_weights: Dict[T', L] x float or int, T' <= T + Returns: + (T, L) x float, predictions, labels, weights, losses + """ + pass + + +def _compute_pred_and_logits( + prediction_module: torch.nn.Module, + encoded_user_embeddings: torch.Tensor, + item_embeddings: torch.Tensor, + task_offsets: List[int], + has_multiple_task_types: bool, +) -> Tuple[torch.Tensor, torch.Tensor]: + mt_logits = prediction_module(encoded_user_embeddings * item_embeddings).transpose( + 0, 1 + ) + mt_preds_list: List[torch.Tensor] = [] + for task_type in MultitaskTaskType: + logits = mt_logits[ + task_offsets[task_type] : task_offsets[task_type + 1], + :, + ] + if task_offsets[task_type + 1] - task_offsets[task_type] > 0: + if task_type == MultitaskTaskType.REGRESSION: + mt_preds_list.append(logits) + else: + mt_preds_list.append(F.sigmoid(logits)) + if has_multiple_task_types: + mt_preds: torch.Tensor = torch.concat(mt_preds_list, dim=0) + else: + mt_preds: torch.Tensor = mt_preds_list[0] + + return mt_preds, mt_logits + + +def _compute_labels_and_weights( + supervision_labels: Dict[str, torch.Tensor], + supervision_weights: Dict[str, torch.Tensor], + task_configs: List[TaskConfig], + device: torch.device, + dtype: torch.dtype = torch.float32, +) -> Tuple[torch.Tensor, torch.Tensor]: + first_label: torch.Tensor = list(supervision_labels.values())[0] + default_supervision_weight = torch.ones_like( + first_label, + dtype=dtype, + device=device, + ) + mt_lables_list: List[torch.Tensor] = [] + mt_weights_list: List[torch.Tensor] = [] + for task in task_configs: + mt_lables_list.append(supervision_labels[task.task_name]) + mt_weights_list.append( + supervision_weights.get(task.task_name, default_supervision_weight) + ) + if len(task_configs) > 1: + mt_labels = torch.stack(mt_lables_list, dim=0) + mt_weights = torch.stack(mt_weights_list, dim=0) + else: + mt_labels = mt_lables_list[0].unsqueeze(0) + mt_weights = mt_weights_list[0].unsqueeze(0) + return mt_labels, mt_weights + + +def _compute_loss( + task_offsets: List[int], + causal_multitask_weights: float, + mt_logits: torch.Tensor, + mt_labels: torch.Tensor, + mt_weights: torch.Tensor, + has_multiple_task_types: bool, +) -> torch.Tensor: + mt_losses_list: List[torch.Tensor] = [] + for task_type in MultitaskTaskType: + if task_offsets[task_type + 1] - task_offsets[task_type] > 0: + logits = mt_logits[ + task_offsets[task_type] : task_offsets[task_type + 1], + :, + ] + labels 
= mt_labels[ + task_offsets[task_type] : task_offsets[task_type + 1], + :, + ] + weights = mt_weights[ + task_offsets[task_type] : task_offsets[task_type + 1], + :, + ] + if task_type == MultitaskTaskType.REGRESSION: + mt_losses_list.append( + F.mse_loss(logits, labels, reduction="none") * weights + ) + else: + mt_losses_list.append( + F.binary_cross_entropy_with_logits( + input=logits, target=labels, reduction="none" + ) + * weights + ) + + if has_multiple_task_types: + mt_losses = torch.concat(mt_losses_list, dim=0) + else: + mt_losses = mt_losses_list[0] + mt_losses = ( + mt_losses.sum(-1) / mt_weights.sum(-1).clamp(min=1.0) * causal_multitask_weights + ) + return mt_losses + + +class DefaultMultitaskModule(MultitaskModule): + def __init__( + self, + task_configs: List[TaskConfig], + embedding_dim: int, + prediction_fn: Callable[[int, int], torch.nn.Module], + causal_multitask_weights: float, + is_inference: bool, + ) -> None: + super().__init__(is_inference) + assert ( + sorted(task_configs, key=lambda x: x.task_type) == task_configs + ), "task_configs must be sorted by task_type." + assert len(task_configs) > 0, "task_configs must be non-empty." + self._task_configs: List[TaskConfig] = task_configs + self._task_offsets: List[int] = [0] * (len(MultitaskTaskType) + 1) + for task in self._task_configs: + self._task_offsets[task.task_type + 1] += 1 + self._has_multiple_task_types: bool = self._task_offsets.count(0) < len( + MultitaskTaskType + ) + self._task_offsets[1:] = np.cumsum(self._task_offsets[1:]).tolist() + self._causal_multitask_weights: float = causal_multitask_weights + self._prediction_module: torch.nn.Module = prediction_fn( + embedding_dim, len(task_configs) + ) + + def forward( + self, + encoded_user_embeddings: torch.Tensor, + item_embeddings: torch.Tensor, + supervision_labels: Dict[str, torch.Tensor], + supervision_weights: Dict[str, torch.Tensor], + ) -> Tuple[ + torch.Tensor, + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + ]: + orig_dtype = encoded_user_embeddings.dtype + if not self._is_inference: + encoded_user_embeddings = encoded_user_embeddings.to(self._training_dtype) + item_embeddings = item_embeddings.to(self._training_dtype) + + with torch.autocast( + "cuda", + dtype=torch.bfloat16, + enabled=(not self.is_inference and self._training_dtype == torch.bfloat16), + ): + mt_preds, mt_logits = _compute_pred_and_logits( + prediction_module=self._prediction_module, + encoded_user_embeddings=encoded_user_embeddings, + item_embeddings=item_embeddings, + task_offsets=self._task_offsets, + has_multiple_task_types=self._has_multiple_task_types, + ) + + # losses are always computed in fp32 + mt_labels: Optional[torch.Tensor] = None + mt_weights: Optional[torch.Tensor] = None + mt_losses: Optional[torch.Tensor] = None + if not self._is_inference: + mt_labels, mt_weights = _compute_labels_and_weights( + supervision_labels=supervision_labels, + supervision_weights=supervision_weights, + task_configs=self._task_configs, + device=encoded_user_embeddings.device, + ) + mt_losses = _compute_loss( + task_offsets=self._task_offsets, + causal_multitask_weights=self._causal_multitask_weights, + mt_logits=mt_logits.to(mt_labels.dtype), + mt_labels=mt_labels, + mt_weights=mt_weights, + has_multiple_task_types=self._has_multiple_task_types, + ) + mt_preds = mt_preds.to(orig_dtype) + + return ( + mt_preds, + mt_labels, + mt_weights, + mt_losses, + ) diff --git a/recommendation/dlrm_v3/generative_recommenders/modules/positional_encoder.py 
b/recommendation/dlrm_v3/generative_recommenders/modules/positional_encoder.py new file mode 100644 index 0000000000..99d904fd4f --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/modules/positional_encoder.py @@ -0,0 +1,75 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +from math import sqrt +from typing import Optional + +import torch +from generative_recommenders.common import HammerModule +from generative_recommenders.ops.position import add_timestamp_positional_embeddings + + +class HSTUPositionalEncoder(HammerModule): + def __init__( + self, + num_position_buckets: int, + num_time_buckets: int, + embedding_dim: int, + contextual_seq_len: int, + is_inference: bool = True, + ) -> None: + super().__init__(is_inference=is_inference) + self._embedding_dim: int = embedding_dim + self._contextual_seq_len: int = contextual_seq_len + self._position_embeddings_weight: torch.nn.Parameter = torch.nn.Parameter( + torch.empty(num_position_buckets, embedding_dim).uniform_( + -sqrt(1.0 / num_position_buckets), + sqrt(1.0 / num_position_buckets), + ), + ) + self._timestamp_embeddings_weight: torch.nn.Parameter = torch.nn.Parameter( + torch.empty(num_time_buckets + 1, embedding_dim).uniform_( + -sqrt(1.0 / num_time_buckets), + sqrt(1.0 / num_time_buckets), + ), + ) + + def forward( + self, + max_seq_len: int, + seq_lengths: torch.Tensor, + seq_offsets: torch.Tensor, + seq_timestamps: torch.Tensor, + seq_embeddings: torch.Tensor, + num_targets: Optional[torch.Tensor], + ) -> torch.Tensor: + seq_embeddings = add_timestamp_positional_embeddings( + alpha=self._embedding_dim**0.5, + max_seq_len=max_seq_len, + max_contextual_seq_len=self._contextual_seq_len, + position_embeddings_weight=self._position_embeddings_weight, + timestamp_embeddings_weight=self._timestamp_embeddings_weight, + seq_offsets=seq_offsets, + seq_lengths=seq_lengths, + seq_embeddings=seq_embeddings, + timestamps=seq_timestamps, + num_targets=num_targets, + interleave_targets=False, + kernel=self.hammer_kernel(), + ) + return seq_embeddings diff --git a/recommendation/dlrm_v3/generative_recommenders/modules/postprocessors.py b/recommendation/dlrm_v3/generative_recommenders/modules/postprocessors.py new file mode 100644 index 0000000000..32fa660602 --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/modules/postprocessors.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +from abc import abstractmethod +from typing import Dict, List, Tuple + +import torch +from generative_recommenders.common import HammerModule, init_mlp_weights_optional_bias + + +@torch.fx.wrap +def _cast_dtype(t: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + if t.dtype != dtype: + return t.to(dtype) + return t + + +class OutputPostprocessor(HammerModule): + """An abstract class for post-processing user embeddings after HSTU layers.""" + + @abstractmethod + def forward( + self, + seq_embeddings: torch.Tensor, + seq_timestamps: torch.Tensor, + seq_payloads: Dict[str, torch.Tensor], + ) -> torch.Tensor: + """ + Args: + seq_embeddings: (L, D) + seq_timestamps: (L, ) + seq_payloads: str-keyed tensors. Implementation specific. + + Returns: + postprocessed seq_embeddings, (L, D) + """ + pass + + +class L2NormPostprocessor(OutputPostprocessor): + """Postprocesses user embeddings with l2 norm.""" + + def __init__(self, is_inference: bool = False) -> None: + super().__init__(is_inference=is_inference) + + def forward( + self, + seq_embeddings: torch.Tensor, + seq_timestamps: torch.Tensor, + seq_payloads: Dict[str, torch.Tensor], + ) -> torch.Tensor: + return seq_embeddings / torch.linalg.norm( + seq_embeddings, ord=2, dim=-1, keepdim=True + ).clamp(min=1e-6) + + +class LayerNormPostprocessor(OutputPostprocessor): + """Postprocesses user embeddings with layer norm.""" + + def __init__( + self, + embedding_dim: int, + eps: float = 1e-5, + is_inference: bool = False, + ) -> None: + super().__init__(is_inference=is_inference) + + self._layer_norm: torch.nn.Module = torch.nn.LayerNorm( + normalized_shape=[embedding_dim], eps=eps + ) + + def forward( + self, + seq_embeddings: torch.Tensor, + seq_timestamps: torch.Tensor, + seq_payloads: Dict[str, torch.Tensor], + ) -> torch.Tensor: + # pyre-fixme[6]: For 1st argument expected `dtype` but got `Union[dtype, + # Tensor, Module]`. 
+ return self._layer_norm(seq_embeddings.to(self._layer_norm.weight.dtype)) + + +@torch.fx.wrap +def _unsqueeze_if_needed(t: torch.Tensor, embedding: torch.Tensor) -> torch.Tensor: + if embedding.dim() == 3: + return t.unsqueeze(0) + return t + + +class TimestampLayerNormPostprocessor(OutputPostprocessor): + """Postprocesses user embeddings with timestamp-based MLP -> layer norm.""" + + def __init__( + self, + embedding_dim: int, + time_duration_features: List[Tuple[int, int]], + eps: float = 1e-5, + is_inference: bool = False, + ) -> None: + super().__init__(is_inference=is_inference) + + self._layer_norm: torch.nn.Module = torch.nn.LayerNorm( + normalized_shape=[embedding_dim], eps=eps + ) + self.register_buffer( + "_period_units", + torch.Tensor([f[0] for f in time_duration_features]).view(1, -1), + ) + self.register_buffer( + "_units_per_period", + torch.Tensor([f[1] for f in time_duration_features]).view(1, -1), + ) + self._time_feature_combiner: torch.nn.Module = torch.nn.Linear( + embedding_dim + 2 * len(time_duration_features), + embedding_dim, + ).apply(init_mlp_weights_optional_bias) + + def _concat_time_features( + self, + combined_embeddings: torch.Tensor, + timestamps: torch.Tensor, # [B] or [B, D] + ) -> torch.Tensor: + # concat time representation to combined embeddings + period_units = self._period_units + units_per_period = self._units_per_period + + timestamps = timestamps.unsqueeze(-1) + period_units = _unsqueeze_if_needed(period_units, combined_embeddings) + units_per_period = _unsqueeze_if_needed(units_per_period, combined_embeddings) + _units_since_epoch = torch.div( + timestamps, period_units, rounding_mode="floor" + ) # [sum(N_i), num_time_features] or [B, N, num_time_features] + _units_elapsed = ( + (torch.remainder(_units_since_epoch, units_per_period) / units_per_period) + * 2 + * 3.14 + ) + # Note: `torch.polar` does not support bfloat16 datatype + _units_elapsed_type: torch.dtype = _units_elapsed.dtype + _units_elapsed = torch.view_as_real( + torch.polar( + _cast_dtype(torch.ones_like(_units_elapsed), torch.float32), + _cast_dtype(_units_elapsed, torch.float32), + ) + ).flatten( + -2, -1 + ) # [sum(N_i), num_time_features * 2] or [B, N, num_time_features * 2] + _units_elapsed = _cast_dtype(_units_elapsed, _units_elapsed_type) + combined_embeddings = torch.cat([combined_embeddings, _units_elapsed], dim=-1) + return combined_embeddings + + def forward( + self, + seq_embeddings: torch.Tensor, + seq_timestamps: torch.Tensor, + seq_payloads: Dict[str, torch.Tensor], + ) -> torch.Tensor: + user_embeddings = self._time_feature_combiner( + self._concat_time_features(seq_embeddings, timestamps=seq_timestamps) + ) + return self._layer_norm(user_embeddings) diff --git a/recommendation/dlrm_v3/generative_recommenders/modules/preprocessors.py b/recommendation/dlrm_v3/generative_recommenders/modules/preprocessors.py new file mode 100644 index 0000000000..dc7806bb45 --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/modules/preprocessors.py @@ -0,0 +1,334 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +import abc +from math import sqrt +from typing import Dict, List, Optional, Tuple + +import torch +from generative_recommenders.common import ( + fx_unwrap_optional_tensor, + HammerModule, + init_mlp_weights_optional_bias, + jagged_to_padded_dense, +) +from generative_recommenders.modules.action_encoder import ActionEncoder +from generative_recommenders.ops.jagged_tensors import concat_2D_jagged +from generative_recommenders.ops.layer_norm import LayerNorm, SwishLayerNorm + + +class InputPreprocessor(HammerModule): + """An abstract class for pre-processing sequence embeddings before HSTU layers.""" + + @abc.abstractmethod + def forward( + self, + max_uih_len: int, + max_targets: int, + total_uih_len: int, + total_targets: int, + seq_lengths: torch.Tensor, + seq_timestamps: torch.Tensor, + seq_embeddings: torch.Tensor, + num_targets: torch.Tensor, + seq_payloads: Dict[str, torch.Tensor], + ) -> Tuple[ + int, + int, + int, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Dict[str, torch.Tensor], + ]: + """ + Args: + max_uih_len: int + max_targets: int + total_uih_len: int + total_targets: int + seq_lengths: (B,) + seq_embeddings: (L, D) + seq_timestamps: (B, N) + num_targets: (B,) Optional. + seq_payloads: str-keyed tensors. Implementation specific. + + Returns: + (max_seq_len, total_uih_len, total_targets, lengths, offsets, timestamps, embeddings, num_targets, payloads) updated based on input preprocessor. + """ + pass + + def interleave_targets(self) -> bool: + return False + + +def get_contextual_input_embeddings( + seq_lengths: torch.Tensor, + seq_payloads: Dict[str, torch.Tensor], + contextual_feature_to_max_length: Dict[str, int], + contextual_feature_to_min_uih_length: Dict[str, int], + dtype: torch.dtype, +) -> torch.Tensor: + padded_values: List[torch.Tensor] = [] + for key, max_len in contextual_feature_to_max_length.items(): + v = torch.flatten( + jagged_to_padded_dense( + values=seq_payloads[key].to(dtype), + offsets=[seq_payloads[key + "_offsets"]], + max_lengths=[max_len], + padding_value=0.0, + ), + 1, + 2, + ) + min_uih_length = contextual_feature_to_min_uih_length.get(key, 0) + if min_uih_length > 0: + v = v * (seq_lengths.view(-1, 1) >= min_uih_length) + padded_values.append(v) + return torch.cat(padded_values, dim=1) + + +class ContextualPreprocessor(InputPreprocessor): + def __init__( + self, + input_embedding_dim: int, + hidden_dim: int, + output_embedding_dim: int, + contextual_feature_to_max_length: Dict[str, int], + contextual_feature_to_min_uih_length: Dict[str, int], + action_embedding_dim: int = 8, + action_feature_name: str = "", + action_weights: Optional[List[int]] = None, + additional_embedding_features: List[str] = [], + action_embedding_init_std: float = 0.1, + is_inference: bool = True, + ) -> None: + super().__init__(is_inference=is_inference) + self._output_embedding_dim: int = output_embedding_dim + self._input_embedding_dim: int = input_embedding_dim + self._hidden_dim: int = hidden_dim + self._contextual_feature_to_max_length: Dict[str, int] = ( + contextual_feature_to_max_length + ) + self._max_contextual_seq_len: int = sum( + contextual_feature_to_max_length.values() + ) + self._contextual_feature_to_min_uih_length: Dict[str, int] = ( + contextual_feature_to_min_uih_length + ) + if self._max_contextual_seq_len > 0: + std = 1.0 * sqrt( + 2.0 / float(input_embedding_dim + 
self._output_embedding_dim) + ) + self._batched_contextual_linear_weights: torch.nn.Parameter = ( + torch.nn.Parameter( + torch.empty( + ( + self._max_contextual_seq_len, + input_embedding_dim, + self._output_embedding_dim, + ) + ).normal_(0.0, std) + ) + ) + self._batched_contextual_linear_bias: torch.nn.Parameter = ( + torch.nn.Parameter( + torch.empty( + (self._max_contextual_seq_len, self._output_embedding_dim) + ).fill_(0.0) + ) + ) + self._content_embedding_mlp: torch.nn.Module = torch.nn.Sequential( + torch.nn.Linear( + in_features=self._input_embedding_dim, + out_features=self._hidden_dim, + ), + SwishLayerNorm(self._hidden_dim), + torch.nn.Linear( + in_features=self._hidden_dim, + out_features=self._output_embedding_dim, + ), + LayerNorm(self._output_embedding_dim), + ).apply(init_mlp_weights_optional_bias) + self._additional_embedding_features: List[str] = additional_embedding_features + self._additional_embedding_mlp: torch.nn.Module = torch.nn.Sequential( + torch.nn.Linear( + in_features=self._input_embedding_dim + * len(additional_embedding_features), + out_features=self._hidden_dim, + ), + SwishLayerNorm(self._hidden_dim), + torch.nn.Linear( + in_features=self._hidden_dim, + out_features=self._output_embedding_dim, + ), + LayerNorm(self._output_embedding_dim), + ).apply(init_mlp_weights_optional_bias) + self._action_feature_name: str = action_feature_name + self._action_weights: Optional[List[int]] = action_weights + if self._action_weights is not None: + self._action_encoder: ActionEncoder = ActionEncoder( + action_feature_name=action_feature_name, + action_weights=self._action_weights, + action_embedding_dim=action_embedding_dim, + embedding_init_std=action_embedding_init_std, + is_inference=is_inference, + ) + self._action_embedding_mlp: torch.nn.Module = torch.nn.Sequential( + torch.nn.Linear( + in_features=self._action_encoder.output_embedding_dim, + out_features=self._hidden_dim, + ), + SwishLayerNorm(self._hidden_dim), + torch.nn.Linear( + in_features=self._hidden_dim, + out_features=self._output_embedding_dim, + ), + LayerNorm(self._output_embedding_dim), + ).apply(init_mlp_weights_optional_bias) + + def forward( # noqa C901 + self, + max_uih_len: int, + max_targets: int, + total_uih_len: int, + total_targets: int, + seq_lengths: torch.Tensor, + seq_timestamps: torch.Tensor, + seq_embeddings: torch.Tensor, + num_targets: torch.Tensor, + seq_payloads: Dict[str, torch.Tensor], + ) -> Tuple[ + int, + int, + int, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Dict[str, torch.Tensor], + ]: + output_seq_embeddings = self._content_embedding_mlp(seq_embeddings) + if len(self._additional_embedding_features) > 0: + additional_embeddings = torch.cat( + [ + seq_payloads[feature] + for feature in self._additional_embedding_features + ], + dim=1, + ) + output_seq_embeddings = ( + output_seq_embeddings + + self._additional_embedding_mlp(additional_embeddings) + ) + max_seq_len = max_uih_len + max_targets + target_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(num_targets) + seq_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(seq_lengths) + uih_offsets = seq_offsets - target_offsets + if self._action_weights is not None: + action_embeddings = self._action_encoder( + max_uih_len=max_uih_len, + max_targets=max_targets, + uih_offsets=uih_offsets, + target_offsets=target_offsets, + seq_embeddings=seq_embeddings, + seq_payloads=seq_payloads, + ) + output_seq_embeddings = output_seq_embeddings + self._action_embedding_mlp( + 
action_embeddings + ) + + output_max_seq_len = max_seq_len + output_total_uih_len = total_uih_len + output_total_targets = total_targets + output_seq_lengths = seq_lengths + output_num_targets = num_targets + output_seq_timestamps = seq_timestamps + output_seq_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum( + output_seq_lengths + ) + # concat contextual embeddings + if self._max_contextual_seq_len > 0: + contextual_input_embeddings = get_contextual_input_embeddings( + seq_lengths=seq_lengths, + seq_payloads=seq_payloads, + contextual_feature_to_max_length=self._contextual_feature_to_max_length, + contextual_feature_to_min_uih_length=self._contextual_feature_to_min_uih_length, + dtype=seq_embeddings.dtype, + ) + contextual_embeddings = torch.baddbmm( + self._batched_contextual_linear_bias.view( + -1, 1, self._output_embedding_dim + ).to(contextual_input_embeddings.dtype), + contextual_input_embeddings.view( + -1, self._max_contextual_seq_len, self._input_embedding_dim + ).transpose(0, 1), + self._batched_contextual_linear_weights.to( + contextual_input_embeddings.dtype + ), + ).transpose(0, 1) + output_seq_embeddings = concat_2D_jagged( + max_seq_len=self._max_contextual_seq_len + output_max_seq_len, + values_left=fx_unwrap_optional_tensor(contextual_embeddings).reshape( + -1, self._output_embedding_dim + ), + values_right=output_seq_embeddings, + max_len_left=self._max_contextual_seq_len, + max_len_right=output_max_seq_len, + offsets_left=None, + offsets_right=output_seq_offsets, + kernel=self.hammer_kernel(), + ) + output_seq_timestamps = concat_2D_jagged( + max_seq_len=self._max_contextual_seq_len + output_max_seq_len, + values_left=torch.zeros( + (output_seq_lengths.size(0) * self._max_contextual_seq_len, 1), + dtype=output_seq_timestamps.dtype, + device=output_seq_timestamps.device, + ), + values_right=output_seq_timestamps.unsqueeze(-1), + max_len_left=self._max_contextual_seq_len, + max_len_right=output_max_seq_len, + offsets_left=None, + offsets_right=output_seq_offsets, + kernel=self.hammer_kernel(), + ).squeeze(-1) + output_max_seq_len = output_max_seq_len + self._max_contextual_seq_len + output_seq_lengths = output_seq_lengths + self._max_contextual_seq_len + output_seq_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum( + output_seq_lengths + ) + output_total_uih_len = ( + output_total_uih_len + + self._max_contextual_seq_len * output_seq_lengths.size(0) + ) + + return ( + output_max_seq_len, + output_total_uih_len, + output_total_targets, + output_seq_lengths, + output_seq_offsets, + output_seq_timestamps, + output_seq_embeddings, + output_num_targets, + seq_payloads, + ) diff --git a/recommendation/dlrm_v3/generative_recommenders/modules/stu.py b/recommendation/dlrm_v3/generative_recommenders/modules/stu.py new file mode 100644 index 0000000000..d186000e38 --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/modules/stu.py @@ -0,0 +1,467 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
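`ContextualPreprocessor.forward` above projects each contextual feature slot with its own weight matrix in a single `torch.baddbmm` before prepending the result to the sequence. The following shape-only sketch mirrors that batched projection; all sizes are made up for illustration and are not the benchmark's model dimensions.
```
import torch

B, C, D_in, D_out = 3, 2, 8, 16                    # requests, contextual slots, dims
ctx_inputs = torch.randn(B, C * D_in)              # flattened contextual features per request
weights = torch.randn(C, D_in, D_out)              # one projection per contextual slot
bias = torch.zeros(C, D_out)

ctx_emb = torch.baddbmm(
    bias.view(C, 1, D_out),                        # bias broadcast over the batch dim
    ctx_inputs.view(B, C, D_in).transpose(0, 1),   # (C, B, D_in)
    weights,                                       # (C, D_in, D_out)
).transpose(0, 1)                                  # back to (B, C, D_out)
assert ctx_emb.shape == (B, C, D_out)
```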
+ +#!/usr/bin/env python3 + +# pyre-strict +import abc +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import torch +from generative_recommenders.common import fx_unwrap_optional_tensor, HammerModule +from generative_recommenders.ops.hstu_attention import delta_hstu_mha +from generative_recommenders.ops.hstu_compute import ( + hstu_compute_output, + hstu_compute_uqvk, + hstu_preprocess_and_attention, +) +from generative_recommenders.ops.jagged_tensors import concat_2D_jagged, split_2D_jagged +from torch.autograd.profiler import record_function + + +try: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") +except OSError: + pass + + +class STU(HammerModule, abc.ABC): + def cached_forward( + self, + delta_x: torch.Tensor, + num_targets: torch.Tensor, + max_kv_caching_len: int = 0, + kv_caching_lengths: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError + + @abc.abstractmethod + def forward( + self, + x: torch.Tensor, + x_lengths: torch.Tensor, + x_offsets: torch.Tensor, + max_seq_len: int, + num_targets: torch.Tensor, + max_kv_caching_len: int = 0, + kv_caching_lengths: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + pass + + +@dataclass +class STULayerConfig: + embedding_dim: int + num_heads: int + hidden_dim: int + attention_dim: int + output_dropout_ratio: float = 0.3 + causal: bool = True + target_aware: bool = True + max_attn_len: Optional[int] = None + attn_alpha: Optional[float] = None + use_group_norm: bool = False + recompute_normed_x: bool = True + recompute_uvqk: bool = True + recompute_y: bool = True + sort_by_length: bool = True + contextual_seq_len: int = 0 + + +@torch.fx.wrap +def _update_kv_cache( + max_seq_len: int, + seq_offsets: torch.Tensor, + k: Optional[torch.Tensor], + v: Optional[torch.Tensor], + max_kv_caching_len: int, + kv_caching_lengths: Optional[torch.Tensor], + orig_k_cache: Optional[torch.Tensor], + orig_v_cache: Optional[torch.Tensor], + orig_max_kv_caching_len: int, + orig_kv_caching_offsets: Optional[torch.Tensor], +) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], int, Optional[torch.Tensor]]: + if kv_caching_lengths is not None: + kv_caching_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum( + kv_caching_lengths + ) + delta_offsets = seq_offsets - kv_caching_offsets + k_cache, _ = split_2D_jagged( + max_seq_len=max_seq_len, + values=fx_unwrap_optional_tensor(k).flatten(1, 2), + max_len_left=None, + max_len_right=None, + offsets_left=kv_caching_offsets, + offsets_right=delta_offsets, + ) + v_cache, _ = split_2D_jagged( + max_seq_len=max_seq_len, + values=fx_unwrap_optional_tensor(v).flatten(1, 2), + max_len_left=None, + max_len_right=None, + offsets_left=kv_caching_offsets, + offsets_right=delta_offsets, + ) + if max_kv_caching_len == 0: + max_kv_caching_len = int(kv_caching_lengths.max().item()) + return ( + k_cache, + v_cache, + max_kv_caching_len, + kv_caching_offsets, + ) + else: + return ( + orig_k_cache, + orig_v_cache, + orig_max_kv_caching_len, + orig_kv_caching_offsets, + ) + + +@torch.fx.wrap +def _construct_full_kv( + delta_k: torch.Tensor, + delta_v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + max_kv_caching_len: int, + kv_caching_offsets: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, int, torch.Tensor]: + L, _ = delta_k.shape + B = kv_caching_offsets.shape[0] - 1 + delta_size = L // B + full_k = concat_2D_jagged( + 
max_seq_len=max_kv_caching_len + delta_size, + values_left=k_cache, + values_right=delta_k, + max_len_left=max_kv_caching_len, + max_len_right=delta_size, + offsets_left=kv_caching_offsets, + offsets_right=None, + ) + full_v = concat_2D_jagged( + max_seq_len=max_kv_caching_len + delta_size, + values_left=v_cache, + values_right=delta_v, + max_len_left=max_kv_caching_len, + max_len_right=delta_size, + offsets_left=kv_caching_offsets, + offsets_right=None, + ) + full_kv_caching_offsets = kv_caching_offsets + delta_size * torch.arange( + B + 1, device=delta_k.device + ) + return ( + full_k, + full_v, + max_kv_caching_len + delta_size, + full_kv_caching_offsets, + ) + + +class STULayer(STU): + max_kv_caching_len: int + k_cache: Optional[torch.Tensor] + v_cache: Optional[torch.Tensor] + kv_caching_offsets: Optional[torch.Tensor] + + def __init__( + self, + config: STULayerConfig, + is_inference: bool = False, + ) -> None: + super().__init__( + is_inference=is_inference, + ) + self.reset_kv_cache() + self._num_heads: int = config.num_heads + self._embedding_dim: int = config.embedding_dim + self._hidden_dim: int = config.hidden_dim + self._attention_dim: int = config.attention_dim + self._output_dropout_ratio: float = config.output_dropout_ratio + self._target_aware: bool = config.target_aware + self._causal: bool = config.causal + self._max_attn_len: int = config.max_attn_len or 0 + self._attn_alpha: float = config.attn_alpha or 1.0 / (self._attention_dim**0.5) + self._use_group_norm: bool = config.use_group_norm + self._recompute_normed_x: bool = config.recompute_normed_x + self._recompute_uvqk: bool = config.recompute_uvqk + self._recompute_y: bool = config.recompute_y + self._sort_by_length: bool = config.sort_by_length + self._contextual_seq_len: int = config.contextual_seq_len + + self._uvqk_weight: torch.nn.Parameter = torch.nn.Parameter( + torch.empty( + ( + self._embedding_dim, + (self._hidden_dim * 2 + self._attention_dim * 2) * self._num_heads, + ) + ), + ) + torch.nn.init.xavier_uniform_(self._uvqk_weight) + self._uvqk_beta: torch.nn.Parameter = torch.nn.Parameter( + torch.zeros( + (self._hidden_dim * 2 + self._attention_dim * 2) * self._num_heads, + ), + ) + self._input_norm_weight: torch.nn.Parameter = torch.nn.Parameter( + torch.ones((self._embedding_dim,)), + ) + self._input_norm_bias: torch.nn.Parameter = torch.nn.Parameter( + torch.zeros((self._embedding_dim,)), + ) + self._output_weight = torch.nn.Parameter( + torch.empty( + ( + self._hidden_dim * self._num_heads * 3, + self._embedding_dim, + ) + ), + ) + torch.nn.init.xavier_uniform_(self._output_weight) + output_norm_shape: int = ( + self._hidden_dim * self._num_heads + if not self._use_group_norm + else self._num_heads + ) + self._output_norm_weight: torch.nn.Parameter = torch.nn.Parameter( + torch.ones((output_norm_shape,)), + ) + self._output_norm_bias: torch.nn.Parameter = torch.nn.Parameter( + torch.zeros((output_norm_shape,)), + ) + + def reset_kv_cache(self) -> None: + self.k_cache = None + self.v_cache = None + self.kv_caching_offsets = None + self.max_kv_caching_len = 0 + + def update_kv_cache( + self, + max_seq_len: int, + seq_offsets: torch.Tensor, + k: Optional[torch.Tensor], + v: Optional[torch.Tensor], + max_kv_caching_len: int, + kv_caching_lengths: Optional[torch.Tensor], + ) -> None: + self.k_cache, self.v_cache, self.max_kv_caching_len, self.kv_caching_offsets = ( + _update_kv_cache( + max_seq_len=max_seq_len, + seq_offsets=seq_offsets, + k=k, + v=v, + max_kv_caching_len=max_kv_caching_len, + 
kv_caching_lengths=kv_caching_lengths, + orig_k_cache=self.k_cache, + orig_v_cache=self.v_cache, + orig_max_kv_caching_len=self.max_kv_caching_len, + orig_kv_caching_offsets=self.kv_caching_offsets, + ) + ) + + def construct_full_kv( + self, + delta_k: torch.Tensor, + delta_v: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, int, torch.Tensor]: + return _construct_full_kv( + delta_k=delta_k, + delta_v=delta_v, + k_cache=fx_unwrap_optional_tensor(self.k_cache), + v_cache=fx_unwrap_optional_tensor(self.v_cache), + max_kv_caching_len=self.max_kv_caching_len, + kv_caching_offsets=self.kv_caching_offsets, + ) + + def forward( + self, + x: torch.Tensor, + x_lengths: torch.Tensor, + x_offsets: torch.Tensor, + max_seq_len: int, + num_targets: torch.Tensor, + max_kv_caching_len: int = 0, + kv_caching_lengths: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + with record_function("## stu_preprocess_and_attention ##"): + u, attn_output, k, v = hstu_preprocess_and_attention( + x=x, + norm_weight=self._input_norm_weight.to(x.dtype), + norm_bias=self._input_norm_bias.to(x.dtype), + norm_eps=1e-6, + num_heads=self._num_heads, + attn_dim=self._attention_dim, + hidden_dim=self._hidden_dim, + uvqk_weight=self._uvqk_weight.to(x.dtype), + uvqk_bias=self._uvqk_beta.to(x.dtype), + max_seq_len=max_seq_len, + seq_offsets=x_offsets, + attn_alpha=self._attn_alpha, + causal=self._causal, + num_targets=num_targets if self._target_aware else None, + max_attn_len=self._max_attn_len, + contextual_seq_len=self._contextual_seq_len, + recompute_uvqk_in_backward=self._recompute_uvqk, + recompute_normed_x_in_backward=self._recompute_normed_x, + sort_by_length=self._sort_by_length, + prefill=kv_caching_lengths is not None, + kernel=self.hammer_kernel(), + ) + + self.update_kv_cache( + max_seq_len=max_seq_len, + seq_offsets=x_offsets, + k=k, + v=v, + max_kv_caching_len=max_kv_caching_len, + kv_caching_lengths=kv_caching_lengths, + ) + + with record_function("## stu_compute_output ##"): + return hstu_compute_output( + attn=attn_output, + u=u, + x=x, + norm_weight=self._output_norm_weight.to(x.dtype), + norm_bias=self._output_norm_bias.to(x.dtype), + norm_eps=1e-6, + dropout_ratio=self._output_dropout_ratio, + output_weight=self._output_weight.to(x.dtype), + group_norm=self._use_group_norm, + num_heads=self._num_heads, + linear_dim=self._hidden_dim, + concat_ux=True, + training=self.training, + kernel=self.hammer_kernel(), + recompute_y_in_backward=self._recompute_y, + ) + + def cached_forward( + self, + delta_x: torch.Tensor, + num_targets: torch.Tensor, + max_kv_caching_len: int = 0, + kv_caching_lengths: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + with record_function("## stu_compute_uqvk ##"): + delta_u, delta_q, delta_k, delta_v = hstu_compute_uqvk( + x=delta_x, + norm_weight=self._input_norm_weight.to(delta_x.dtype), + norm_bias=self._input_norm_bias.to(delta_x.dtype), + norm_eps=1e-6, + num_heads=self._num_heads, + attn_dim=self._attention_dim, + hidden_dim=self._hidden_dim, + uvqk_weight=self._uvqk_weight.to(delta_x.dtype), + uvqk_bias=self._uvqk_beta.to(delta_x.dtype), + kernel=self.hammer_kernel(), + ) + k, v, max_seq_len, seq_offsets = self.construct_full_kv( + delta_k=delta_k.flatten(1, 2), + delta_v=delta_v.flatten(1, 2), + ) + self.update_kv_cache( + max_seq_len=max_seq_len, + seq_offsets=seq_offsets, + k=k, + v=v, + max_kv_caching_len=max_kv_caching_len, + kv_caching_lengths=kv_caching_lengths, + ) + k = k.view(-1, self._num_heads, self._attention_dim) + v = v.view(-1, self._num_heads, 
self._hidden_dim) + with record_function("## delta_hstu_mha ##"): + delta_attn_output = delta_hstu_mha( + max_seq_len=max_seq_len, + alpha=self._attn_alpha, + delta_q=delta_q, + k=k, + v=v, + seq_offsets=seq_offsets, + num_targets=num_targets if self._target_aware else None, + max_attn_len=self._max_attn_len, + contextual_seq_len=self._contextual_seq_len, + kernel=self.hammer_kernel(), + ).view(-1, self._hidden_dim * self._num_heads) + with record_function("## stu_compute_output ##"): + return hstu_compute_output( + attn=delta_attn_output, + u=delta_u, + x=delta_x, + norm_weight=self._output_norm_weight.to(delta_x.dtype), + norm_bias=self._output_norm_bias.to(delta_x.dtype), + norm_eps=1e-6, + dropout_ratio=self._output_dropout_ratio, + output_weight=self._output_weight.to(delta_x.dtype), + group_norm=self._use_group_norm, + num_heads=self._num_heads, + linear_dim=self._hidden_dim, + concat_ux=True, + training=self.training, + kernel=self.hammer_kernel(), + recompute_y_in_backward=self._recompute_y, + ) + + +class STUStack(STU): + def __init__( + self, + stu_list: List[STU], + is_inference: bool = False, + ) -> None: + super().__init__(is_inference=is_inference) + self._stu_layers: torch.nn.ModuleList = torch.nn.ModuleList(modules=stu_list) + + def forward( + self, + x: torch.Tensor, + x_lengths: torch.Tensor, + x_offsets: torch.Tensor, + max_seq_len: int, + num_targets: torch.Tensor, + max_kv_caching_len: int = 0, + kv_caching_lengths: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + for layer in self._stu_layers: + x = layer( + x=x, + x_lengths=x_lengths, + x_offsets=x_offsets, + max_seq_len=max_seq_len, + num_targets=num_targets, + max_kv_caching_len=max_kv_caching_len, + kv_caching_lengths=kv_caching_lengths, + ) + return x + + def cached_forward( + self, + delta_x: torch.Tensor, + num_targets: torch.Tensor, + max_kv_caching_len: int = 0, + kv_caching_lengths: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + for layer in self._stu_layers: + delta_x = layer.cached_forward( # pyre-ignore [29] + delta_x=delta_x, + num_targets=num_targets, + max_kv_caching_len=max_kv_caching_len, + kv_caching_lengths=kv_caching_lengths, + ) + return delta_x diff --git a/recommendation/dlrm_v3/generative_recommenders/ops/hstu_attention.py b/recommendation/dlrm_v3/generative_recommenders/ops/hstu_attention.py new file mode 100644 index 0000000000..b7021bb075 --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/ops/hstu_attention.py @@ -0,0 +1,206 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
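In the decode path, `STULayer.cached_forward` above computes U/Q/K/V only for the new (delta) positions and stitches the fresh keys and values onto the per-request KV cache via `construct_full_kv`. Here is a small pure-PyTorch sketch of that jagged concatenation and the offset shift, assuming `concat_2D_jagged` places each request's cached block before its delta block; the helper name and sizes are illustrative only.
```
import torch

def concat_cached_and_delta(k_cache, cache_offsets, delta_k, B):
    delta_size = delta_k.shape[0] // B             # fixed number of new positions per request
    full = []
    for b in range(B):
        full.append(k_cache[cache_offsets[b]:cache_offsets[b + 1]])
        full.append(delta_k[b * delta_size:(b + 1) * delta_size])
    # same arithmetic as the reference: offsets shift by delta_size * arange(B + 1)
    full_offsets = cache_offsets + delta_size * torch.arange(B + 1)
    return torch.cat(full), full_offsets

k_cache = torch.randn(5, 8)                        # cached lengths [2, 3]
offsets = torch.tensor([0, 2, 5])
delta_k = torch.randn(4, 8)                        # 2 new positions per request
full_k, full_offsets = concat_cached_and_delta(k_cache, offsets, delta_k, B=2)
assert full_k.shape == (9, 8)
assert full_offsets.tolist() == [0, 4, 9]
```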
+ +#!/usr/bin/env python3 + +# pyre-strict + +from typing import Optional + +import torch + +from generative_recommenders.common import HammerKernel, switch_to_contiguous_if_needed +from generative_recommenders.ops.pytorch.pt_hstu_attention import ( + pytorch_cached_hstu_mha, + pytorch_hstu_mha, +) +from generative_recommenders.ops.triton.triton_hstu_attention import ( + triton_cached_hstu_mha, + triton_hstu_mha, +) + +try: + from hammer.ops.triton.cc.hstu_attention.triton_cc_hstu_attention import ( + triton_cc_hstu_mha, + ) +except: + from generative_recommenders.ops.triton.triton_hstu_attention import ( + triton_hstu_mha as triton_cc_hstu_mha, + ) +from torch.fx._symbolic_trace import is_fx_tracing + + +def hstu_mha( + max_seq_len: int, + alpha: float, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_offsets: torch.Tensor, + causal: bool = True, + dropout_pr: float = 0.0, + training: bool = True, + num_targets: Optional[torch.Tensor] = None, + attn_scale: Optional[torch.Tensor] = None, + max_attn_len: int = 0, + contextual_seq_len: int = 0, + min_full_attn_seq_len: int = 0, + sort_by_length: bool = False, + kernel: HammerKernel = HammerKernel.PYTORCH, + enable_tma: bool = False, +) -> torch.Tensor: + _, H, _ = q.shape + if not is_fx_tracing(): + torch._assert(max_seq_len > 0, "max_seq_len must be larger than 0") + torch._assert(q.dim() == 3, "q must be 3-D") + torch._assert(k.shape == q.shape, "k must be the same shape as q") + torch._assert(v.dim() == 3, "v must be 3-D") + torch._assert(v.shape[0] == q.shape[0], "wrong v shape[0]") + torch._assert(v.shape[1] == H, "wrong v shape[1]") + torch._assert(causal, "only support causal attention") + + if kernel in [HammerKernel.TRITON, HammerKernel.TRITON_CC]: + if not is_fx_tracing() and kernel == HammerKernel.TRITON: + torch._assert(q.is_cuda, "q must be CUDA tensor") + torch._assert(k.is_cuda, "k must be CUDA tensor") + torch._assert(v.is_cuda, "v must be CUDA tensor") + torch._assert(seq_offsets.is_cuda, "seq_offsets must be CUDA tensor") + torch._assert(dropout_pr < 1e-6, "dropout for triton path not implemented") + torch._assert( + min_full_attn_seq_len == 0, "min_full_attn_seq_len not implemented" + ) + assert attn_scale is None, "attn_scale not implemented" + q = switch_to_contiguous_if_needed(q) + k = switch_to_contiguous_if_needed(k) + v = switch_to_contiguous_if_needed(v) + seq_offsets = seq_offsets.contiguous() + + if kernel == HammerKernel.TRITON: + return triton_hstu_mha( + N=max_seq_len, + alpha=alpha, + q=q, + k=k, + v=v, + seq_offsets=seq_offsets, + num_targets=num_targets, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + sort_by_length=sort_by_length, + enable_tma=enable_tma, + ) + elif kernel == HammerKernel.TRITON_CC: + return triton_cc_hstu_mha( + N=max_seq_len, + alpha=alpha, + q=q, + k=k, + v=v, + seq_offsets=seq_offsets, + num_targets=num_targets, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + ) + else: + return pytorch_hstu_mha( + max_seq_len=max_seq_len, + alpha=alpha, + q=q, + k=k, + v=v, + seq_offsets=seq_offsets, + causal=True, + dropout_pr=dropout_pr, + training=training, + num_targets=num_targets, + attn_scale=attn_scale, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + min_full_attn_seq_len=min_full_attn_seq_len, + ) + + +def delta_hstu_mha( + max_seq_len: int, + alpha: float, + delta_q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_offsets: torch.Tensor, + num_targets: Optional[torch.Tensor] = None, + max_attn_len: 
int = 0, + contextual_seq_len: int = 0, + kernel: HammerKernel = HammerKernel.PYTORCH, + enable_tma: bool = False, +) -> torch.Tensor: + L, H, D = delta_q.shape + B = seq_offsets.size(0) - 1 + DeltaSize = L // B + if not is_fx_tracing(): + torch._assert(max_seq_len > 0, "max_seq_len must be larger than 0") + torch._assert(delta_q.dim() == 3, "delta_q must be 3-D") + torch._assert(L % B == 0, "delta_q must be padded") + torch._assert(k.dim() == 3, "k must be 3-D") + torch._assert(k.shape[1] == H, "wrong k shape[1]") + torch._assert(k.shape[2] == D, "wrong k shape[2]") + torch._assert(v.dim() == 3, "v must be 3-D") + torch._assert(v.shape[1] == H, "wrong v shape[1]") + if kernel in [HammerKernel.TRITON, HammerKernel.TRITON_CC]: + if not is_fx_tracing() and kernel == HammerKernel.TRITON: + torch._assert(delta_q.is_cuda, "q must be CUDA tensor") + torch._assert(seq_offsets.is_cuda, "seq_offsets must be CUDA tensor") + if num_targets is not None: + torch._assert(num_targets.is_cuda, "num_targets must be CUDA tensor") + seq_offsets = seq_offsets.contiguous() + delta_q = switch_to_contiguous_if_needed(delta_q) + k = switch_to_contiguous_if_needed(k) + v = switch_to_contiguous_if_needed(v) + + if kernel == HammerKernel.TRITON: + return triton_cached_hstu_mha( + N=max_seq_len, + alpha=alpha, + delta_q=delta_q, + k=k, + v=v, + seq_offsets=seq_offsets, + num_targets=num_targets, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + enable_tma=enable_tma, + ) + elif kernel == HammerKernel.TRITON_CC: + return triton_cc_hstu_mha( + N=max_seq_len, + alpha=alpha, + q=delta_q, + k=k, + v=v, + seq_offsets=seq_offsets, + num_targets=num_targets, + is_delta_q=True, + delta_size=DeltaSize, + ) + else: + return pytorch_cached_hstu_mha( + max_seq_len=max_seq_len, + alpha=alpha, + delta_q=delta_q, + k=k, + v=v, + seq_offsets=seq_offsets, + num_targets=num_targets, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + ) diff --git a/recommendation/dlrm_v3/generative_recommenders/ops/hstu_compute.py b/recommendation/dlrm_v3/generative_recommenders/ops/hstu_compute.py new file mode 100644 index 0000000000..909a3582fe --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/ops/hstu_compute.py @@ -0,0 +1,259 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
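The next file, `hstu_compute.py`, fuses the U, V, Q and K projections of HSTU into a single matmul (`hstu_compute_uqvk`) and then splits and reshapes the result per head. A minimal shape sketch of that layout follows; the dimensions are illustrative, not the benchmark's actual model sizes.
```
import torch
import torch.nn.functional as F

L, D, heads, hidden, attn = 10, 64, 4, 32, 16
x = torch.randn(L, D)                              # (already layer-normed) input tokens
uvqk_weight = torch.randn(D, (hidden * 2 + attn * 2) * heads)
uvqk_bias = torch.zeros((hidden * 2 + attn * 2) * heads)

uvqk = torch.addmm(uvqk_bias, x, uvqk_weight)      # one fused projection
u, v, q, k = torch.split(
    uvqk, [hidden * heads, hidden * heads, attn * heads, attn * heads], dim=1
)
u = F.silu(u)                                      # gating branch
q = q.view(L, heads, attn)                         # per-head query
k = k.view(L, heads, attn)
v = v.view(L, heads, hidden)
assert q.shape == (L, heads, attn) and v.shape == (L, heads, hidden)
```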
+ +#!/usr/bin/env python3 + +# pyre-strict + +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from generative_recommenders.ops.layer_norm import layer_norm +from generative_recommenders.ops.mm import addmm +from generative_recommenders.ops.pytorch.pt_hstu_linear import ( + pytorch_hstu_compute_output, +) + +try: + from hammer.ops.triton.cc.addmm.triton_cc_addmm import triton_cc_addmm + from hammer.ops.triton.cc.group_norm_mul_dropout.triton_cc_group_norm_mul_dropout import ( + triton_cc_group_norm_mul_dropout_wrapper, + ) + from hammer.ops.triton.cc.layer_norm_mul_dropout.triton_cc_layer_norm_mul_dropout import ( + triton_cc_layer_norm_mul_dropout_wrapper, + ) +except ImportError: + pass +from generative_recommenders.common import HammerKernel +from generative_recommenders.ops.hstu_attention import hstu_mha +from generative_recommenders.ops.triton.triton_hstu_linear import ( + triton_hstu_compute_output, +) +from generative_recommenders.ops.triton.triton_hstu_preprocess_and_attention import ( + triton_hstu_preprocess_and_attention, +) +from torch.fx._symbolic_trace import is_fx_tracing + + +def hstu_compute_uqvk( + x: torch.Tensor, + norm_weight: torch.Tensor, + norm_bias: torch.Tensor, + norm_eps: float, + num_heads: int, + attn_dim: int, + hidden_dim: int, + uvqk_weight: torch.Tensor, + uvqk_bias: torch.Tensor, + kernel: HammerKernel = HammerKernel.PYTORCH, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + normed_x = layer_norm( + x, + weight=norm_weight, + bias=norm_bias, + eps=norm_eps, + kernel=kernel, + ) + # NOTE: for AMD training, we go with torch.addmm instead of the triton + # version before Triton on AMD achieves on-par perf with NV GPU. + if torch.version.hip and kernel == HammerKernel.TRITON: + uvqk = torch.addmm(uvqk_bias, normed_x, uvqk_weight) + else: + uvqk = addmm(uvqk_bias, normed_x, uvqk_weight, kernel) + u, v, q, k = torch.split( + uvqk, + [ + hidden_dim * num_heads, + hidden_dim * num_heads, + attn_dim * num_heads, + attn_dim * num_heads, + ], + dim=1, + ) + u = F.silu(u) + q = q.view(-1, num_heads, attn_dim) + k = k.view(-1, num_heads, attn_dim) + v = v.view(-1, num_heads, hidden_dim) + return u, q, k, v + + +def hstu_compute_output( + attn: torch.Tensor, + u: torch.Tensor, + x: torch.Tensor, + norm_weight: torch.Tensor, + norm_bias: torch.Tensor, + norm_eps: float, + output_weight: torch.Tensor, + num_heads: int, + linear_dim: int, + dropout_ratio: float, + training: bool, + concat_ux: bool, + group_norm: bool, + recompute_y_in_backward: bool, + kernel: HammerKernel = HammerKernel.PYTORCH, +) -> torch.Tensor: + if kernel == HammerKernel.TRITON: + return triton_hstu_compute_output( + attn=attn, + u=u, + x=x, + norm_weight=norm_weight, + norm_bias=norm_bias, + output_weight=output_weight, + eps=norm_eps, + dropout_ratio=dropout_ratio, + training=training, + concat_ux=concat_ux, + group_norm=group_norm, + num_heads=num_heads, + linear_dim=linear_dim, + seed=None, + recompute_y_in_backward=recompute_y_in_backward, + ) + elif kernel == HammerKernel.TRITON_CC: + if group_norm: + y = triton_cc_group_norm_mul_dropout_wrapper( + x=attn, + u=u, + weight=norm_weight, + bias=norm_bias, + eps=norm_eps, + dropout_ratio=dropout_ratio, + training=training, + concat_ux=concat_ux, + num_heads=num_heads, + linear_dim=linear_dim, + ) + else: + y = triton_cc_layer_norm_mul_dropout_wrapper( + x=attn, + u=u, + weight=norm_weight, + bias=norm_bias, + eps=norm_eps, + dropout_ratio=dropout_ratio, + training=training, + 
concat_ux=concat_ux, + ) + return triton_cc_addmm(x, y, output_weight) + else: + return pytorch_hstu_compute_output( + attn=attn, + u=u, + x=x, + norm_weight=norm_weight, + norm_bias=norm_bias, + output_weight=output_weight, + eps=norm_eps, + dropout_ratio=dropout_ratio, + training=training, + concat_ux=concat_ux, + group_norm=group_norm, + num_heads=num_heads, + linear_dim=linear_dim, + ) + + +def hstu_preprocess_and_attention( + x: torch.Tensor, + norm_weight: torch.Tensor, + norm_bias: torch.Tensor, + norm_eps: float, + num_heads: int, + attn_dim: int, + hidden_dim: int, + uvqk_weight: torch.Tensor, + uvqk_bias: torch.Tensor, + max_seq_len: int, + seq_offsets: torch.Tensor, + attn_alpha: float, + causal: bool, + num_targets: Optional[torch.Tensor], + max_attn_len: int, + contextual_seq_len: int, + recompute_uvqk_in_backward: bool, + recompute_normed_x_in_backward: bool, + sort_by_length: bool, + prefill: bool = False, + kernel: HammerKernel = HammerKernel.PYTORCH, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + if not is_fx_tracing(): + torch._assert(max_seq_len > 0, "max_seq_len must be larger than 0") + torch._assert(x.dim() == 2, "x must be 2-D") + torch._assert( + x.shape[1] == uvqk_weight.shape[0], + "x.shape[1] must equal uvqk_weight.shape[0]", + ) + torch._assert( + uvqk_weight.shape[1] == 2 * num_heads * (hidden_dim + attn_dim), + "uvqk_weight.shape[1] must equal 2 * num_heads * (hidden_dim + attn_dim)", + ) + torch._assert(causal is True, "only causal attention is supported.") + if kernel == HammerKernel.TRITON and prefill is False: + u, attn_output = triton_hstu_preprocess_and_attention( + x=x, + norm_weight=norm_weight, + norm_bias=norm_bias, + norm_eps=norm_eps, + num_heads=num_heads, + attn_dim=attn_dim, + hidden_dim=hidden_dim, + uvqk_weight=uvqk_weight, + uvqk_bias=uvqk_bias, + max_seq_len=max_seq_len, + seq_offsets=seq_offsets, + attn_alpha=attn_alpha, + num_targets=num_targets, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + recompute_uvqk_in_backward=recompute_uvqk_in_backward, + recompute_normed_x_in_backward=recompute_normed_x_in_backward, + sort_by_length=sort_by_length, + ) + attn_output = attn_output.view(-1, hidden_dim * num_heads) + k = None + v = None + else: + u, q, k, v = hstu_compute_uqvk( + x=x, + norm_weight=norm_weight, + norm_bias=norm_bias, + norm_eps=norm_eps, + num_heads=num_heads, + attn_dim=attn_dim, + hidden_dim=hidden_dim, + uvqk_weight=uvqk_weight, + uvqk_bias=uvqk_bias, + kernel=kernel, + ) + attn_output = hstu_mha( + max_seq_len=max_seq_len, + alpha=attn_alpha, + q=q, + k=k, + v=v, + seq_offsets=seq_offsets, + causal=causal, + dropout_pr=0.0, + training=False, + num_targets=num_targets, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + sort_by_length=sort_by_length, + kernel=kernel, + ).view(-1, hidden_dim * num_heads) + return u, attn_output, k, v diff --git a/recommendation/dlrm_v3/generative_recommenders/ops/jagged_tensors.py b/recommendation/dlrm_v3/generative_recommenders/ops/jagged_tensors.py new file mode 100644 index 0000000000..0ca24daa55 --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/ops/jagged_tensors.py @@ -0,0 +1,356 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +from typing import Optional, Tuple + +import torch + +from generative_recommenders.common import HammerKernel +from generative_recommenders.ops.pytorch.pt_jagged import pytorch_jagged_dense_bmm_add +from generative_recommenders.ops.pytorch.pt_jagged_tensors import ( + pytorch_concat_2D_jagged, + pytorch_hstu_concat_l2_embeddings, + pytorch_hstu_split_l2_embeddings, + pytorch_split_2D_jagged, +) + +from generative_recommenders.ops.triton.triton_jagged import triton_jagged_dense_bmm_add +from generative_recommenders.ops.triton.triton_jagged_tensors import ( + triton_concat_2D_jagged, + triton_concat_2D_jagged_multirow, + triton_split_2D_jagged, + triton_split_2D_jagged_multirow, +) +from torch.fx._symbolic_trace import is_fx_tracing + +try: + from hammer.ops.triton.cc.jagged_dense_bmm.triton_cc_jagged_dense_bmm import ( + triton_cc_jagged_dense_bmm, + ) +except ImportError: + pass + + +torch.fx.wrap("triton_concat_2D_jagged") +torch.fx.wrap("triton_split_2D_jagged") + + +def concat_2D_jagged( + max_seq_len: int, + values_left: torch.Tensor, + values_right: torch.Tensor, + max_len_left: Optional[int] = None, + max_len_right: Optional[int] = None, + offsets_left: Optional[torch.Tensor] = None, + offsets_right: Optional[torch.Tensor] = None, + kernel: HammerKernel = HammerKernel.PYTORCH, +) -> torch.Tensor: + if not is_fx_tracing(): + torch._assert(values_left.dim() == 2, "values_left must be 2D") + torch._assert(values_right.dim() == 2, "values_right must be 2D") + torch._assert( + values_right.shape[1] == values_left.shape[1], + f"values_left shape[1] must be equal to values_right shape[1] {values_left.shape[1]} vs {values_right.shape[1]}", + ) + if kernel == HammerKernel.TRITON: + return triton_concat_2D_jagged( + max_seq_len=max_seq_len, + values_left=values_left, + values_right=values_right, + max_len_left=max_len_left, + max_len_right=max_len_right, + offsets_left=offsets_left, + offsets_right=offsets_right, + ) + else: + return pytorch_concat_2D_jagged( + values_left=values_left, + values_right=values_right, + max_len_left=max_len_left, + max_len_right=max_len_right, + offsets_left=offsets_left, + offsets_right=offsets_right, + ) + + +def split_2D_jagged( + max_seq_len: int, + values: torch.Tensor, + total_len_left: Optional[int] = None, + total_len_right: Optional[int] = None, + max_len_left: Optional[int] = None, + max_len_right: Optional[int] = None, + offsets_left: Optional[torch.Tensor] = None, + offsets_right: Optional[torch.Tensor] = None, + kernel: HammerKernel = HammerKernel.PYTORCH, +) -> Tuple[torch.Tensor, torch.Tensor]: + if not is_fx_tracing(): + torch._assert(values.dim() == 2, "values must be 2D") + torch._assert( + offsets_left is not None or offsets_right is not None, + "offsets_left and offsets_right cannot be None at the same time", + ) + if offsets_left is None: + torch._assert( + max_len_left is not None, + "max_len_left must be provided when offsets_left is None", + ) + if offsets_right is None: + torch._assert( + max_len_right is not None, + "max_len_right must be provided when offsets_right is None", + ) + if 
offsets_left is not None and offsets_right is not None: + torch._assert( + offsets_left.shape[0] == offsets_right.shape[0], + "offsets_left shape[0] must be equal to offsets_right shape[0]", + ) + if kernel == HammerKernel.TRITON: + return triton_split_2D_jagged( + max_seq_len=max_seq_len, + values=values, + total_len_left=total_len_left, + total_len_right=total_len_right, + max_len_left=max_len_left, + max_len_right=max_len_right, + offsets_left=offsets_left, + offsets_right=offsets_right, + ) + else: + return pytorch_split_2D_jagged( + max_seq_len=max_seq_len, + values=values, + max_len_left=max_len_left, + max_len_right=max_len_right, + offsets_left=offsets_left, + offsets_right=offsets_right, + ) + + +def hstu_split_l2_embeddings( + max_seq_len: int, + x: torch.Tensor, + prefix_offsets: torch.Tensor, + l2_offsets: torch.Tensor, + contextual_seq_len: int, + kernel: HammerKernel = HammerKernel.PYTORCH, +) -> Tuple[torch.Tensor, torch.Tensor]: + if kernel == HammerKernel.TRITON: + return triton_split_2D_jagged( + max_seq_len=max_seq_len, + values=x, + total_len_right=None, + total_len_left=None, + max_len_left=None, + max_len_right=None, + offsets_left=prefix_offsets, + offsets_right=l2_offsets, + n_prefix_to_right=contextual_seq_len, + ) + else: + return pytorch_hstu_split_l2_embeddings( + max_seq_len=max_seq_len, + x=x, + prefix_offsets=prefix_offsets, + l2_offsets=l2_offsets, + contextual_seq_len=contextual_seq_len, + ) + + +def hstu_concat_l2_embeddings( + max_prefix_len: int, + prefix_x: torch.Tensor, + prefix_offsets: torch.Tensor, + max_l2_len: int, + l2_x: torch.Tensor, + l2_offsets: torch.Tensor, + contextual_seq_len: int, + kernel: HammerKernel = HammerKernel.PYTORCH, +) -> torch.Tensor: + if kernel == HammerKernel.TRITON: + return triton_concat_2D_jagged( + max_seq_len=max_prefix_len + max_l2_len, + values_left=prefix_x, + values_right=l2_x, + max_len_left=max_prefix_len, + max_len_right=max_l2_len, + offsets_left=prefix_offsets, + offsets_right=l2_offsets, + n_prefix_from_right=contextual_seq_len, + ) + else: + return pytorch_hstu_concat_l2_embeddings( + contextual_seq_len=contextual_seq_len, + max_prefix_len=max_prefix_len, + prefix_x=prefix_x, + prefix_offsets=prefix_offsets, + max_l2_len=max_l2_len, + l2_x=l2_x, + l2_offsets=l2_offsets, + ) + + +def jagged_dense_bmm_broadcast_add( + max_seq_len: int, + seq_offsets: torch.Tensor, + jagged: torch.Tensor, + dense: torch.Tensor, + bias: torch.Tensor, + kernel: HammerKernel = HammerKernel.PYTORCH, +) -> torch.Tensor: + """ + Computing out = jagged x dense + bias + jagged has shape (sum_B(M_i), K), dense has shape (B, K, N), and bias has shape (B, N) + out has shape (sum_B(M_i), N) + """ + if not is_fx_tracing(): + _, K = jagged.shape + B, _, N = dense.shape + torch._assert(dense.shape[1] == K, "wrong dense shape[1]") + torch._assert(seq_offsets.shape[0] == B + 1, "wrong seq_offsets shape[0]") + torch._assert(bias.shape[0] == B, "wrong bias shape[0]") + torch._assert(bias.shape[1] == N, "wrong bias shape[1]") + if kernel == HammerKernel.TRITON: + return triton_jagged_dense_bmm_add( + max_seq_len=max_seq_len, + seq_offsets=seq_offsets, + jagged=jagged, + dense=dense, + bias=bias, + elementwise=False, + ) + elif kernel == HammerKernel.TRITON_CC: + return triton_cc_jagged_dense_bmm( + max_seq_len=max_seq_len, + seq_offsets=seq_offsets, + jagged=jagged, + dense=dense, + bias=bias, + ) + else: + return pytorch_jagged_dense_bmm_add( + max_seq_len=max_seq_len, + seq_offsets=seq_offsets, + jagged=jagged, + dense=dense, + bias=bias, + ) + 
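
To make the shape contract in the `jagged_dense_bmm_broadcast_add` docstring concrete, here is a small sketch using the PyTorch fallback. The batch size and dimensions are invented for illustration, and the fallback relies on fbgemm's jagged ops being available.

```
import torch

from generative_recommenders.common import HammerKernel
from generative_recommenders.ops.jagged_tensors import jagged_dense_bmm_broadcast_add

# B = 2 sequences with lengths 2 and 3, so sum_B(M_i) = 5; K = 4, N = 3.
seq_offsets = torch.tensor([0, 2, 5])
jagged = torch.randn(5, 4)    # (sum_B(M_i), K)
dense = torch.randn(2, 4, 3)  # (B, K, N): one K x N matrix per sequence
bias = torch.randn(2, 3)      # (B, N): broadcast over each sequence's rows

out = jagged_dense_bmm_broadcast_add(
    max_seq_len=3,
    seq_offsets=seq_offsets,
    jagged=jagged,
    dense=dense,
    bias=bias,
    kernel=HammerKernel.PYTORCH,
)
assert out.shape == (5, 3)  # (sum_B(M_i), N)
```
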
+ +def concat_2D_jagged_multirow( + max_seq_len: int, + values_left: torch.Tensor, + values_right: torch.Tensor, + offsets_left: Optional[torch.Tensor], + offsets_right: Optional[torch.Tensor], + max_len_left: int, + max_len_right: int, + kernel: HammerKernel = HammerKernel.TRITON, +) -> torch.Tensor: + if not is_fx_tracing(): + torch._assert(values_left.dim() == 2, "values_left must be 2D") + torch._assert(values_right.dim() == 2, "values_right must be 2D") + torch._assert( + values_right.shape[1] == values_left.shape[1], + f"values_left shape[1] must be equal to values_right shape[1] {values_left.shape[1]} vs {values_right.shape[1]}", + ) + if offsets_left is not None and offsets_right is not None: + torch._assert( + offsets_left.shape[0] == offsets_right.shape[0], + "offsets_left and offsets_right must have the same batch dimension", + ) + + if kernel == HammerKernel.TRITON: + return triton_concat_2D_jagged_multirow( + max_seq_len=max_seq_len, + values_a=values_left, + values_b=values_right, + offsets_a=offsets_left, + offsets_b=offsets_right, + max_len_a=max_len_left, + max_len_b=max_len_right, + ) + else: + return concat_2D_jagged( + max_seq_len=max_seq_len, + values_left=values_left, + values_right=values_right, + max_len_left=max_len_left, + max_len_right=max_len_right, + offsets_left=offsets_left, + offsets_right=offsets_right, + kernel=kernel, + ) + + +def split_2D_jagged_multirow( + max_seq_len: int, + values: torch.Tensor, + total_len_left: Optional[int] = None, + total_len_right: Optional[int] = None, + max_len_left: Optional[int] = None, + max_len_right: Optional[int] = None, + offsets_left: Optional[torch.Tensor] = None, + offsets_right: Optional[torch.Tensor] = None, + kernel: HammerKernel = HammerKernel.TRITON, +) -> Tuple[torch.Tensor, torch.Tensor]: + if not is_fx_tracing(): + torch._assert(values.dim() == 2, "values must be 2D") + torch._assert( + offsets_left is not None or offsets_right is not None, + "offsets_left and offsets_right cannot be None at the same time", + ) + if offsets_left is None: + torch._assert( + max_len_left is not None, + "max_len_left must be provided when offsets_left is None", + ) + if offsets_right is None: + torch._assert( + max_len_right is not None, + "max_len_right must be provided when offsets_right is None", + ) + if offsets_left is not None and offsets_right is not None: + torch._assert( + offsets_left.shape[0] == offsets_right.shape[0], + "offsets_left and offsets_right must have the same batch dimension", + ) + + if kernel == HammerKernel.TRITON: + return triton_split_2D_jagged_multirow( + max_seq_len=max_seq_len, + values=values, + total_len_left=total_len_left, + total_len_right=total_len_right, + max_len_left=max_len_left, + max_len_right=max_len_right, + offsets_left=offsets_left, + offsets_right=offsets_right, + ) + else: + return split_2D_jagged( + max_seq_len=max_seq_len, + values=values, + total_len_left=total_len_left, + total_len_right=total_len_right, + max_len_left=max_len_left, + max_len_right=max_len_right, + offsets_left=offsets_left, + offsets_right=offsets_right, + kernel=kernel, + ) diff --git a/recommendation/dlrm_v3/generative_recommenders/ops/layer_norm.py b/recommendation/dlrm_v3/generative_recommenders/ops/layer_norm.py new file mode 100644 index 0000000000..5b21b463eb --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/ops/layer_norm.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + + +from typing import List + +import torch +from generative_recommenders.ops.pytorch.pt_layer_norm import ( + pytorch_layer_norm, + pytorch_rms_norm, + pytorch_swish_layer_norm, +) +from generative_recommenders.ops.triton.triton_layer_norm import triton_rms_norm + +try: + from hammer.ops.triton.cc.swish_layer_norm.triton_cc_swish_layer_norm import ( + triton_cc_swish_layer_norm, + ) +except ImportError: + pass +try: + from hammer.ops.triton.cc.rms_norm.triton_cc_rms_norm import triton_cc_rms_norm +except ImportError: + pass +from generative_recommenders.common import HammerKernel, HammerModule +from generative_recommenders.ops.triton.triton_layer_norm import ( + triton_layer_norm, + triton_swish_layer_norm, +) +from torch.fx._symbolic_trace import is_fx_tracing + +torch.fx.wrap("triton_layer_norm") +torch.fx.wrap("triton_swish_layer_norm") + + +def layer_norm( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-5, + kernel: HammerKernel = HammerKernel.PYTORCH, +) -> torch.Tensor: + if kernel == HammerKernel.TRITON: + if not is_fx_tracing(): + torch._assert(not x.is_cpu, "x must be device tensor") + torch._assert(not weight.is_cpu, "weight must be device tensor") + torch._assert(not bias.is_cpu, "bias must be device tensor") + return triton_layer_norm(x, weight, bias, eps) + elif kernel == HammerKernel.TRITON_CC: + return triton_cc_swish_layer_norm( + x, + weight, + bias, + eps, + is_swish=False, + ) + else: + return pytorch_layer_norm( + x, + [ + x.shape[-1], + ], + weight, + bias, + eps, + ) + + +def rms_norm( + x: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-5, + kernel: HammerKernel = HammerKernel.PYTORCH, + silu: bool = False, +) -> torch.Tensor: + if kernel == HammerKernel.TRITON: + if not is_fx_tracing(): + torch._assert(not x.is_cpu, "x must be device tensor") + torch._assert(not weight.is_cpu, "weight must be device tensor") + return triton_rms_norm(x, weight, eps, silu) + elif kernel == HammerKernel.TRITON_CC: + return triton_cc_rms_norm( + x, + weight, + eps, + silu=silu, + ) + else: + return pytorch_rms_norm( + x, + [ + x.shape[-1], + ], + weight, + eps, + silu, + ) + + +def swish_layer_norm( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float = 1e-5, + kernel: HammerKernel = HammerKernel.PYTORCH, +) -> torch.Tensor: + if kernel == HammerKernel.TRITON: + if not is_fx_tracing(): + torch._assert(not x.is_cpu, "x must be device tensor") + torch._assert(not weight.is_cpu, "weight must be device tensor") + torch._assert(not bias.is_cpu, "bias must be device tensor") + return triton_swish_layer_norm(x, [x.shape[-1]], weight, bias, eps) + elif kernel == HammerKernel.TRITON_CC: + return triton_cc_swish_layer_norm( + x, + weight, + bias, + eps, + is_swish=True, + ) + else: + return pytorch_swish_layer_norm( + x, + [ + x.shape[-1], + ], + weight, + bias, + eps, + ) + + +class LayerNorm(HammerModule): + def __init__( + self, + dim: 
int, + eps: float = 1e-5, + is_inference: bool = False, + ) -> None: + super().__init__(is_inference=is_inference) + self._normalized_shape: List[int] = [dim] + self._eps = eps + self.weight = torch.nn.Parameter( + torch.ones(self._normalized_shape), + ) + self.bias = torch.nn.Parameter( + torch.zeros(self._normalized_shape), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return layer_norm( + x=x, + weight=self.weight, + bias=self.bias, + eps=self._eps, + kernel=self.hammer_kernel(), + ) + + +class RMSNorm(HammerModule): + def __init__( + self, + dim: int, + eps: float = 1e-5, + is_inference: bool = False, + ) -> None: + super().__init__(is_inference=is_inference) + self._eps = eps + self.weight = torch.nn.Parameter(torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return rms_norm( + x, + self.weight, + self._eps, + silu=False, + kernel=self.hammer_kernel(), + ) + + +class RMSNormSilu(HammerModule): + def __init__( + self, + dim: int, + eps: float = 1e-5, + is_inference: bool = False, + ) -> None: + super().__init__(is_inference=is_inference) + self._eps = eps + self.weight = torch.nn.Parameter(torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return rms_norm( + x, + self.weight, + self._eps, + silu=True, + kernel=self.hammer_kernel(), + ) + + +class SwishLayerNorm(HammerModule): + def __init__( + self, + dim: int, + eps: float = 1e-5, + is_inference: bool = False, + ) -> None: + super().__init__(is_inference=is_inference) + self._normalized_shape: List[int] = [dim] + self.weight = torch.nn.Parameter(torch.ones(self._normalized_shape)) + self.bias = torch.nn.Parameter(torch.zeros(self._normalized_shape)) + self._eps = eps + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + return swish_layer_norm( + x=x, + weight=self.weight, + bias=self.bias, + eps=self._eps, + kernel=self.hammer_kernel(), + ) diff --git a/recommendation/dlrm_v3/generative_recommenders/ops/mm.py b/recommendation/dlrm_v3/generative_recommenders/ops/mm.py new file mode 100644 index 0000000000..0dd9c89cfb --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/ops/mm.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
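
The normalization wrappers above all follow the same pattern: a functional entry point that routes to a Triton, Triton-CC, or pure-PyTorch implementation based on `HammerKernel`, plus thin `HammerModule` classes (`LayerNorm`, `RMSNorm`, `RMSNormSilu`, `SwishLayerNorm`) that own the parameters. A brief sketch of the functional form with illustrative dimensions, using the CPU-friendly PyTorch path:

```
import torch

from generative_recommenders.common import HammerKernel
from generative_recommenders.ops.layer_norm import layer_norm, rms_norm

x = torch.randn(6, 64)  # (tokens, dim); sizes chosen only for illustration
weight = torch.ones(64)
bias = torch.zeros(64)

# PyTorch fallback; the TRITON path asserts that x, weight, and bias are
# device tensors. rms_norm with silu=True matches the RMSNormSilu module
# (its fallback uses torch.nn.functional.rms_norm, so a recent PyTorch is
# assumed).
y = layer_norm(x, weight=weight, bias=bias, eps=1e-5, kernel=HammerKernel.PYTORCH)
z = rms_norm(x, weight=weight, eps=1e-5, silu=True, kernel=HammerKernel.PYTORCH)
assert y.shape == x.shape and z.shape == x.shape
```
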
+ +#!/usr/bin/env python3 + +# pyre-strict + +import torch + +try: + from hammer.ops.triton.cc.addmm.triton_cc_addmm import triton_cc_addmm +except ImportError: + pass +from generative_recommenders.common import HammerKernel +from generative_recommenders.ops.triton.triton_addmm import triton_addmm + + +def addmm( + input: torch.Tensor, + mat1: torch.Tensor, + mat2: torch.Tensor, + kernel: HammerKernel = HammerKernel.PYTORCH, +) -> torch.Tensor: + if kernel == HammerKernel.TRITON: + return triton_addmm(input, mat1, mat2) + elif kernel == HammerKernel.TRITON_CC: + return triton_cc_addmm(input, mat1, mat2) + else: + return torch.addmm(input, mat1, mat2) diff --git a/recommendation/dlrm_v3/generative_recommenders/ops/position.py b/recommendation/dlrm_v3/generative_recommenders/ops/position.py new file mode 100644 index 0000000000..dd476d6100 --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/ops/position.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +from typing import Optional + +import torch +from generative_recommenders.ops.pytorch.pt_position import ( + pytorch_add_timestamp_positional_embeddings, +) + +try: + from hammer.ops.triton.cc.add_timestamp_position_embeddings.triton_cc_add_timestamp_position_embeddings import ( + triton_cc_add_timestamp_position_embeddings, + ) +except ImportError: + pass +from generative_recommenders.common import HammerKernel +from generative_recommenders.ops.triton.triton_position import ( + triton_add_timestamp_positional_embeddings, +) + + +def add_timestamp_positional_embeddings( + alpha: float, + max_seq_len: int, + max_contextual_seq_len: int, + position_embeddings_weight: torch.Tensor, + timestamp_embeddings_weight: torch.Tensor, + seq_offsets: torch.Tensor, + seq_lengths: torch.Tensor, + seq_embeddings: torch.Tensor, + timestamps: torch.Tensor, + num_targets: Optional[torch.Tensor], + interleave_targets: bool, + time_bucket_fn: str = "sqrt", + kernel: HammerKernel = HammerKernel.PYTORCH, +) -> torch.Tensor: + assert time_bucket_fn in ["sqrt", "log"] + seq_embeddings = seq_embeddings * alpha + if kernel == HammerKernel.TRITON: + return triton_add_timestamp_positional_embeddings( + seq_embeddings=seq_embeddings, + seq_offsets=seq_offsets, + pos_embeddings=position_embeddings_weight, + ts_embeddings=timestamp_embeddings_weight, + timestamps=timestamps, + max_seq_len=max_seq_len, + max_contextual_seq_len=max_contextual_seq_len, + seq_lengths=seq_lengths, + num_targets=num_targets, + interleave_targets=interleave_targets, + time_bucket_fn=time_bucket_fn, + ) + elif kernel == HammerKernel.TRITON_CC: + return triton_cc_add_timestamp_position_embeddings( + seq_embeddings=seq_embeddings, + seq_offsets=seq_offsets, + pos_embeddings=position_embeddings_weight, + ts_embeddings=timestamp_embeddings_weight, + timestamps=timestamps, + max_seq_len=max_seq_len, + max_contextual_seq_len=max_contextual_seq_len, + seq_lengths=seq_lengths, + 
num_targets=num_targets, + interleave_targets=interleave_targets, + time_bucket_fn=time_bucket_fn, + ) + else: + return pytorch_add_timestamp_positional_embeddings( + seq_embeddings=seq_embeddings, + seq_offsets=seq_offsets, + pos_embeddings=position_embeddings_weight, + ts_embeddings=timestamp_embeddings_weight, + timestamps=timestamps, + max_seq_len=max_seq_len, + max_contextual_seq_len=max_contextual_seq_len, + seq_lengths=seq_lengths, + num_targets=num_targets, + interleave_targets=interleave_targets, + time_bucket_fn=time_bucket_fn, + ) diff --git a/recommendation/dlrm_v3/generative_recommenders/ops/pytorch/pt_hstu_attention.py b/recommendation/dlrm_v3/generative_recommenders/ops/pytorch/pt_hstu_attention.py new file mode 100644 index 0000000000..e4e5f64f61 --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/ops/pytorch/pt_hstu_attention.py @@ -0,0 +1,254 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F + + +try: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") +except OSError: + pass + + +@torch.fx.wrap +def _get_valid_attn_mask( + device: torch.device, + causal: bool, + N: int, + seq_lengths: torch.Tensor, + num_targets: Optional[torch.Tensor] = None, + max_attn_len: int = 0, + contextual_seq_len: int = 0, + min_full_attn_seq_len: int = 0, +) -> torch.Tensor: + ids = torch.arange(0, N, device=device).view(1, N) + max_ids = seq_lengths.view(-1, 1, 1) + if contextual_seq_len > 0: + ids = ids - contextual_seq_len + 1 + ids = torch.clamp(ids, min=0) + max_ids = max_ids - contextual_seq_len + 1 + if num_targets is not None: + max_ids = max_ids - num_targets.view(-1, 1, 1) + ids = torch.clamp( + ids, + max=max_ids, + ) + row_ids = ids.view(-1, N, 1).expand(-1, N, N) + col_ids = ids.view(-1, 1, N).expand(-1, N, N) + else: + row_ids = ids.view(N, 1).expand(N, N) + col_ids = row_ids.t() + row_ids = row_ids.view(1, N, N) + col_ids = col_ids.view(1, N, N) + row_col_dist = row_ids - col_ids + valid_attn_mask = torch.eye(N, device=device, dtype=torch.bool).view(1, N, N) + if not causal: + row_col_dist = torch.where(row_col_dist > 0, row_col_dist, -row_col_dist) + valid_attn_mask = torch.logical_or(valid_attn_mask, row_col_dist > 0) + if max_attn_len > 0: + if min_full_attn_seq_len > 0: + valid_attn_mask = torch.logical_and( + valid_attn_mask, + torch.logical_or( + row_col_dist <= max_attn_len, + row_ids >= max_ids - min_full_attn_seq_len, + ), + ) + else: + valid_attn_mask = torch.logical_and( + valid_attn_mask, row_col_dist <= max_attn_len + ) + if contextual_seq_len > 0: + valid_attn_mask = torch.logical_or( + valid_attn_mask, torch.logical_and(row_ids == 0, col_ids < max_ids) + ) + return valid_attn_mask + + +def _pad_qkv( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_offsets: 
torch.Tensor, + N: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + L, H, D = q.shape + V = v.shape[2] + padded_q = ( + torch.ops.fbgemm.jagged_to_padded_dense( + values=q.reshape(L, H * D), + offsets=[seq_offsets], + max_lengths=[N], + padding_value=0.0, + ) + .view(-1, N, H, D) + .transpose(1, 2) + ) # [B, H, N, A] + padded_k = ( + torch.ops.fbgemm.jagged_to_padded_dense( + values=k.reshape(L, H * D), + offsets=[seq_offsets], + max_lengths=[N], + padding_value=0.0, + ) + .view(-1, N, H, D) + .transpose(1, 2) + ) # [B, H, N, A] + padded_v = ( + torch.ops.fbgemm.jagged_to_padded_dense( + values=v.reshape(L, H * V), + offsets=[seq_offsets], + max_lengths=[N], + padding_value=0.0, + ) + .view(-1, N, H, V) + .transpose(1, 2) + ) # [B, H, N, D] + return padded_q, padded_k, padded_v + + +@torch.fx.wrap +def pytorch_hstu_mha( + max_seq_len: int, + alpha: float, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_offsets: torch.Tensor, + causal: bool = True, + dropout_pr: float = 0.0, + training: bool = True, + num_targets: Optional[torch.Tensor] = None, + attn_scale: Optional[torch.Tensor] = None, + max_attn_len: int = 0, + contextual_seq_len: int = 0, + min_full_attn_seq_len: int = 0, +) -> torch.Tensor: + L, H, _ = q.shape + V = v.shape[2] + q, k, v = _pad_qkv( + q, k, v, seq_offsets, max_seq_len + ) # [B, H, N, D) and [B, H, N, V] + qk_attn = torch.einsum("bhxa,bhya->bhxy", q, k) * alpha + if attn_scale is not None: + if attn_scale.ndim > 0: + attn_scale = ( + torch.ops.fbgemm.jagged_to_padded_dense( + values=attn_scale.unsqueeze(-1), + offsets=[seq_offsets], + max_lengths=[max_seq_len], + padding_value=0.0, + ) + .unsqueeze(1) + .to(qk_attn.dtype) + ) + else: + # pyre-ignore[9] + attn_scale = attn_scale.item() + + qk_attn = F.silu(qk_attn) * attn_scale + else: + qk_attn = F.silu(qk_attn) / max_seq_len + valid_attn_mask = _get_valid_attn_mask( + device=q.device, + causal=causal, + N=max_seq_len, + seq_lengths=seq_offsets[1:] - seq_offsets[:-1], + num_targets=num_targets, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + min_full_attn_seq_len=min_full_attn_seq_len, + ) + # raise NotImplementedError(valid_attn_mask[0, :, :].to(torch.int32)) + qk_attn = qk_attn * valid_attn_mask.unsqueeze(1) + if dropout_pr > 0.0: + qk_attn = F.dropout(qk_attn, p=dropout_pr, training=training) + attn_dense = torch.einsum("bhxd,bhdv->bhxv", qk_attn, v) # [B, H, N, V] + return torch.ops.fbgemm.dense_to_jagged( + attn_dense.transpose(1, 2).flatten(2, 3), # [B, N, H, V]->[B, N, H * V] + [seq_offsets], + L, + )[0].view(L, H, V) + + +@torch.fx.wrap +def pytorch_cached_hstu_mha( + max_seq_len: int, + alpha: float, + delta_q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_offsets: torch.Tensor, + num_targets: Optional[torch.Tensor] = None, + max_attn_len: int = 0, + contextual_seq_len: int = 0, +) -> torch.Tensor: + L, H, D = delta_q.shape + _, _, V = v.shape + B = seq_offsets.size(0) - 1 + delta_size = L // B + delta_q = delta_q.view(B, -1, H, D).transpose(1, 2) + full_k = ( + torch.ops.fbgemm.jagged_to_padded_dense( + values=k.reshape(-1, H * D), + offsets=[seq_offsets], + max_lengths=[max_seq_len], + padding_value=0.0, + ) + .view(B, -1, H, D) + .transpose(1, 2) + ) + full_v = ( + torch.ops.fbgemm.jagged_to_padded_dense( + values=v.reshape(-1, H * V), + offsets=[seq_offsets], + max_lengths=[max_seq_len], + padding_value=0.0, + ) + .view(B, -1, H, V) + .transpose(1, 2) + ) + qk_attn = torch.einsum("bhxa,bhya->bhxy", delta_q, full_k) * alpha + qk_attn = 
F.silu(qk_attn) / max_seq_len + full_valid_attn_mask = _get_valid_attn_mask( + device=delta_q.device, + causal=True, + N=max_seq_len, + seq_lengths=seq_offsets[1:] - seq_offsets[:-1], + num_targets=num_targets, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + ) + seq_lengths = seq_offsets[1:] - seq_offsets[:-1] + mask = torch.arange(max_seq_len, device=delta_q.device).view(1, -1) + mask = torch.logical_and( + mask >= (seq_lengths - delta_size).view(-1, 1), + mask < seq_lengths.view(-1, 1), + ) + valid_attn_mask = ( + full_valid_attn_mask.expand(B, -1, -1) + .flatten(0, 1)[mask.view(-1), :] + .view(-1, delta_size, max_seq_len) + ) + qk_attn = qk_attn * valid_attn_mask.unsqueeze(1) + attn_output = torch.einsum("bhxd,bhdv->bhxv", qk_attn, full_v) + return attn_output.transpose(1, 2).reshape(-1, H, V) diff --git a/recommendation/dlrm_v3/generative_recommenders/ops/pytorch/pt_hstu_linear.py b/recommendation/dlrm_v3/generative_recommenders/ops/pytorch/pt_hstu_linear.py new file mode 100644 index 0000000000..e06f9305c5 --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/ops/pytorch/pt_hstu_linear.py @@ -0,0 +1,99 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +import torch +import torch.nn.functional as F + + +def pytorch_norm_mul_dropout( + x: torch.Tensor, + u: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + dropout_ratio: float, + training: bool, + silu_u: bool = False, + concat_ux: bool = False, + group_norm: bool = False, + num_heads: int = 1, + linear_dim: int = -1, +) -> torch.Tensor: + dtype = x.dtype + if silu_u: + u = F.silu(u) + x = x.to(torch.float32) + u = u.to(torch.float32) + if group_norm: + y = u * F.group_norm( + x.view(-1, num_heads, linear_dim), + num_groups=num_heads, + weight=weight.to(torch.float32), + bias=bias.to(torch.float32), + eps=eps, + ).view(-1, num_heads * linear_dim) + else: + y = u * F.layer_norm( + x, + normalized_shape=(x.shape[-1],), + weight=weight.to(torch.float32), + bias=bias.to(torch.float32), + eps=eps, + ) + if concat_ux: + y = torch.cat([u, x, y], dim=1) + y = F.dropout( + y, + p=dropout_ratio, + training=training, + ) + return y.to(dtype) + + +def pytorch_hstu_compute_output( + attn: torch.Tensor, + u: torch.Tensor, + x: torch.Tensor, + norm_weight: torch.Tensor, + norm_bias: torch.Tensor, + output_weight: torch.Tensor, + eps: float, + dropout_ratio: float, + training: bool, + silu_u: bool = False, + concat_ux: bool = False, + group_norm: bool = False, + num_heads: int = 1, + linear_dim: int = -1, +) -> torch.Tensor: + dtype = x.dtype + y = pytorch_norm_mul_dropout( + x=attn, + u=u, + weight=norm_weight, + bias=norm_bias, + eps=eps, + dropout_ratio=dropout_ratio, + training=training, + silu_u=silu_u, + concat_ux=concat_ux, + group_norm=group_norm, + num_heads=num_heads, + linear_dim=linear_dim, + ) + return torch.addmm(x, y, output_weight.to(x.dtype)).to(dtype) diff --git 
a/recommendation/dlrm_v3/generative_recommenders/ops/pytorch/pt_jagged.py b/recommendation/dlrm_v3/generative_recommenders/ops/pytorch/pt_jagged.py new file mode 100644 index 0000000000..67de7cbfce --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/ops/pytorch/pt_jagged.py @@ -0,0 +1,252 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +from typing import Tuple + +import torch + + +try: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") +except OSError: + pass + + +def pytorch_jagged_dense_bmm( + max_seq_len: int, + seq_offsets: torch.Tensor, + jagged: torch.Tensor, + dense: torch.Tensor, +) -> torch.Tensor: + dtype = jagged.dtype + jagged = jagged.to(torch.float32) + dense = dense.to(torch.float32) + padded_jagged = torch.ops.fbgemm.jagged_to_padded_dense( + values=jagged, + offsets=[seq_offsets], + max_lengths=[max_seq_len], + padding_value=0.0, + ) + bmm_out = torch.bmm(padded_jagged, dense) + jagged_bmm_out = torch.ops.fbgemm.dense_to_jagged( + bmm_out, [seq_offsets], total_L=jagged.shape[0] + )[0] + jagged_bmm_out = jagged_bmm_out.to(dtype) + return jagged_bmm_out + + +def pytorch_jagged_dense_broadcast_add( + max_seq_len: int, + seq_offsets: torch.Tensor, + jagged: torch.Tensor, + dense: torch.Tensor, +) -> torch.Tensor: + dtype = jagged.dtype + jagged = jagged.to(torch.float32) + dense = dense.to(torch.float32) + padded_jagged = torch.ops.fbgemm.jagged_to_padded_dense( + values=jagged, + offsets=[seq_offsets], + max_lengths=[max_seq_len], + padding_value=0.0, + ) + out = padded_jagged + dense.unsqueeze(1) + jagged_out = torch.ops.fbgemm.dense_to_jagged( + out, [seq_offsets], total_L=jagged.shape[0] + )[0] + jagged_out = jagged_out.to(dtype) + return jagged_out + + +def pytorch_jagged_dense_bmm_add( + max_seq_len: int, + seq_offsets: torch.Tensor, + jagged: torch.Tensor, + dense: torch.Tensor, + bias: torch.Tensor, + elementwise: bool = False, +) -> torch.Tensor: + dtype = jagged.dtype + jagged = jagged.to(torch.float32) + dense = dense.to(torch.float32) + padded_jagged = torch.ops.fbgemm.jagged_to_padded_dense( + values=jagged, + offsets=[seq_offsets], + max_lengths=[max_seq_len], + padding_value=0.0, + ) + bmm_out = torch.bmm(padded_jagged, dense) + + if elementwise: + jagged_out = ( + torch.ops.fbgemm.dense_to_jagged( + bmm_out, [seq_offsets], total_L=jagged.shape[0] + )[0] + + bias + ) + else: + jagged_out = torch.ops.fbgemm.dense_to_jagged( + bmm_out + bias.unsqueeze(1), [seq_offsets], total_L=jagged.shape[0] + )[0] + + jagged_out = jagged_out.to(dtype) + return jagged_out + + +@torch.fx.wrap +def _arange(len: int, device: torch.device) -> torch.Tensor: + return torch.arange(len, device=device) + + +def pytorch_concat_2D_dense_jagged( + jagged_max_seq_len: int, + jagged_offsets: torch.Tensor, + jagged_values: torch.Tensor, + dense_values: torch.Tensor, +) -> torch.Tensor: 
+ B, dense_size, D = dense_values.size() + jagged_dense = torch.ops.fbgemm.jagged_to_padded_dense( + values=jagged_values, + offsets=[jagged_offsets], + max_lengths=[jagged_max_seq_len], + padding_value=0.0, + ) + concatted_dense = torch.cat([dense_values, jagged_dense], dim=1) + concatted_offsets = ( + dense_size * _arange(B + 1, device=jagged_offsets.device) + jagged_offsets + ) + return torch.ops.fbgemm.dense_to_jagged( + concatted_dense, + [concatted_offsets], + total_L=jagged_values.shape[0] + dense_size * B, + )[0] + + +def pytorch_concat_2D_jagged_jagged( + max_seq_len_left: int, + offsets_left: torch.Tensor, + values_left: torch.Tensor, + max_seq_len_right: int, + offsets_right: torch.Tensor, + values_right: torch.Tensor, + is_replace: bool = False, + n_prefix_from_right: int = 0, +) -> torch.Tensor: + # is_replace with n_prefix_from_right != 0 is not supported yet (neither in triton) + if is_replace: + return pytorch_replace_last_n_with_jagged( + max_seq_len_left, + offsets_left, + values_left, + offsets_right, + values_right, + ) + _, D = values_left.shape + max_seq_len = max_seq_len_left + max_seq_len_right + B = offsets_left.shape[0] - 1 + + lengths_a = offsets_left[1:] - offsets_left[:-1] + lengths_b = offsets_right[1:] - offsets_right[:-1] + dense_a = torch.ops.fbgemm.jagged_to_padded_dense( + values=values_left, + offsets=[offsets_left], + max_lengths=[max_seq_len_left], + padding_value=0.0, + ) + dense_b = torch.ops.fbgemm.jagged_to_padded_dense( + values=values_right, + offsets=[offsets_right], + max_lengths=[max_seq_len_right], + padding_value=0.0, + ) + dense_b_prefix, dense_b_suffix = torch.split( + dense_b, [n_prefix_from_right, max_seq_len_right - n_prefix_from_right], dim=1 + ) + dense = torch.cat([dense_b_prefix, dense_a, dense_b_suffix], dim=1) + mask = _arange(max_seq_len, device=offsets_left.device).expand(B, max_seq_len) + mask = torch.logical_or( + mask < lengths_a.view(B, 1) + n_prefix_from_right, + torch.logical_and( + mask >= max_seq_len_left + n_prefix_from_right, + mask < max_seq_len_left + lengths_b.view(B, 1), + ), + ) + return dense.view(-1, D)[mask.view(-1), :] + + +def pytorch_jagged_remove_first_or_last_1D( + values: torch.Tensor, + lengths: torch.Tensor, + offsets: torch.Tensor, + max_seq_len: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + values = values.view(-1, 1) + shrunk_lengths = lengths - 1 + k_lengths = torch.stack([shrunk_lengths, torch.ones_like(lengths)], dim=1).view(-1) + q_lengths = torch.stack([torch.ones_like(lengths), shrunk_lengths], dim=1).view(-1) + all_indices = torch.arange( + start=0, end=q_lengths.numel(), device=values.device + ).reshape(-1, 2) + q_indices, k_indices = all_indices[:, 1], all_indices[:, 0] + values_no_first, _ = torch.ops.fbgemm.jagged_index_select( + values, q_lengths, q_indices + ) + values_no_last, _ = torch.ops.fbgemm.jagged_index_select( + values, k_lengths, k_indices + ) + return values_no_first.squeeze(), values_no_last.squeeze() + + +@torch.fx.wrap +def fx_apply_mask( + tensor: torch.Tensor, mask: torch.Tensor, fill_value: torch.Tensor +) -> torch.Tensor: + tensor[mask] = fill_value + return tensor + + +def pytorch_replace_last_n_with_jagged( + max_seq_len_left: int, + offsets_left: torch.Tensor, + values_left: torch.Tensor, + offsets_right: torch.Tensor, + values_right: torch.Tensor, +) -> torch.Tensor: + B = offsets_left.shape[0] - 1 + lengths_a = offsets_left[1:] - offsets_left[:-1] + lengths_b = offsets_right[1:] - offsets_right[:-1] + dense_a = torch.ops.fbgemm.jagged_to_padded_dense( + 
values=values_left, + offsets=[offsets_left], + max_lengths=[max_seq_len_left], + padding_value=0.0, + ) + raw_mask = torch.arange(max_seq_len_left, device=offsets_left.device).expand( + B, max_seq_len_left + ) + mask = torch.logical_and( + raw_mask >= (lengths_a - lengths_b).unsqueeze(1), + raw_mask < lengths_a.unsqueeze(1), + ) + dense_a = fx_apply_mask(dense_a, mask, values_right) + jagged_a = torch.ops.fbgemm.dense_to_jagged( + dense_a, + [offsets_left], + )[0] + return jagged_a diff --git a/recommendation/dlrm_v3/generative_recommenders/ops/pytorch/pt_jagged_tensors.py b/recommendation/dlrm_v3/generative_recommenders/ops/pytorch/pt_jagged_tensors.py new file mode 100644 index 0000000000..27817f7fbd --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/ops/pytorch/pt_jagged_tensors.py @@ -0,0 +1,246 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +from typing import Optional, Tuple + +import torch +from generative_recommenders.common import fx_arange + +try: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") +except OSError: + pass + + +def _concat_2D_jagged_jagged( + values_left: torch.Tensor, + values_right: torch.Tensor, + max_len_left: int, + max_len_right: int, + offsets_left: torch.Tensor, + offsets_right: torch.Tensor, +) -> torch.Tensor: + max_seq_len = max_len_left + max_len_right + lengths_left = offsets_left[1:] - offsets_left[:-1] + lengths_right = offsets_right[1:] - offsets_right[:-1] + padded_left = torch.ops.fbgemm.jagged_to_padded_dense( + values=values_left, + offsets=[offsets_left], + max_lengths=[max_len_left], + padding_value=0.0, + ) + padded_right = torch.ops.fbgemm.jagged_to_padded_dense( + values=values_right, + offsets=[offsets_right], + max_lengths=[max_len_right], + padding_value=0.0, + ) + concatted_dense = torch.cat([padded_left, padded_right], dim=1) + mask = fx_arange(max_seq_len, device=offsets_left.device).view(1, -1) + mask = torch.logical_or( + mask < lengths_left.view(-1, 1), + torch.logical_and( + mask >= max_len_left, + mask < max_len_left + lengths_right.view(-1, 1), + ), + ) + return concatted_dense.flatten(0, 1)[mask.view(-1), :] + + +@torch.fx.wrap +def pytorch_concat_2D_jagged( + values_left: torch.Tensor, + values_right: torch.Tensor, + max_len_left: Optional[int], + max_len_right: Optional[int], + offsets_left: Optional[torch.Tensor], + offsets_right: Optional[torch.Tensor], +) -> torch.Tensor: + if offsets_left is None: + assert max_len_left is not None + B = values_left.shape[0] // max_len_left + offsets_left_non_optional = max_len_left * torch.arange( + B + 1, device=values_left.device + ) + else: + offsets_left_non_optional = offsets_left + if offsets_right is None: + assert max_len_right is not None + B = values_right.shape[0] // max_len_right + offsets_right_non_optional = max_len_right * torch.arange( + B + 1, 
device=values_left.device + ) + else: + offsets_right_non_optional = offsets_right + max_len_left = ( + int( + (offsets_left_non_optional[1:] - offsets_left_non_optional[:-1]) + .max() + .item() + ) + if max_len_left is None + else max_len_left + ) + max_len_right = ( + int( + (offsets_right_non_optional[1:] - offsets_right_non_optional[:-1]) + .max() + .item() + ) + if max_len_right is None + else max_len_right + ) + return _concat_2D_jagged_jagged( + values_left=values_left, + values_right=values_right, + max_len_left=max_len_left, + max_len_right=max_len_right, + offsets_left=offsets_left_non_optional, + offsets_right=offsets_right_non_optional, + ) + + +def _split_2D_jagged_jagged( + max_seq_len: int, + values: torch.Tensor, + offsets_left: torch.Tensor, + offsets_right: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + offsets = offsets_left + offsets_right + padded_values = torch.ops.fbgemm.jagged_to_padded_dense( + values=values, + offsets=[offsets], + max_lengths=[max_seq_len], + padding_value=0.0, + ).flatten(0, 1) + lengths_left = offsets_left[1:] - offsets_left[:-1] + lengths_right = offsets_right[1:] - offsets_right[:-1] + mask = fx_arange(max_seq_len, device=values.device).view(1, -1) + mask_left = mask < lengths_left.view(-1, 1) + mask_right = torch.logical_and( + mask >= lengths_left.view(-1, 1), + mask < (lengths_left + lengths_right).view(-1, 1), + ) + return padded_values[mask_left.view(-1), :], padded_values[mask_right.view(-1), :] + + +@torch.fx.wrap +def pytorch_split_2D_jagged( + max_seq_len: int, + values: torch.Tensor, + max_len_left: Optional[int], + max_len_right: Optional[int], + offsets_left: Optional[torch.Tensor], + offsets_right: Optional[torch.Tensor], +) -> Tuple[torch.Tensor, torch.Tensor]: + if offsets_left is None: + assert max_len_left is not None + assert offsets_right is not None + offsets_left_non_optional = max_len_left * torch.arange( + offsets_right.shape[0], device=values.device + ) + else: + offsets_left_non_optional = offsets_left + if offsets_right is None: + assert max_len_right is not None + assert offsets_left is not None + offsets_right_non_optional = max_len_right * torch.arange( + offsets_left.shape[0], device=values.device + ) + else: + offsets_right_non_optional = offsets_right + return _split_2D_jagged_jagged( + max_seq_len=max_seq_len, + values=values, + offsets_left=offsets_left_non_optional, + offsets_right=offsets_right_non_optional, + ) + + +def pytorch_hstu_split_l2_embeddings( + max_seq_len: int, + x: torch.Tensor, + prefix_offsets: torch.Tensor, + l2_offsets: torch.Tensor, + contextual_seq_len: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + x_offsets = prefix_offsets + l2_offsets + x_lengths = x_offsets[1:] - x_offsets[:-1] + padded_x = torch.ops.fbgemm.jagged_to_padded_dense( + values=x, + offsets=[x_offsets], + max_lengths=[max_seq_len], + padding_value=0.0, + ).flatten(0, 1) + prefix_lengths = prefix_offsets[1:] - prefix_offsets[:-1] + mask = fx_arange(max_seq_len, device=x_offsets.device).view(1, -1) + mask_prefix = torch.logical_and( + mask >= contextual_seq_len, + mask < prefix_lengths.view(-1, 1) + contextual_seq_len, + ) + mask_l2 = torch.logical_or( + mask < contextual_seq_len, + torch.logical_and( + mask >= prefix_lengths.view(-1, 1) + contextual_seq_len, + mask < x_lengths.view(-1, 1), + ), + ) + return padded_x[mask_prefix.view(-1), :], padded_x[mask_l2.view(-1), :] + + +def pytorch_hstu_concat_l2_embeddings( + max_prefix_len: int, + prefix_x: torch.Tensor, + prefix_offsets: torch.Tensor, + max_l2_len: int, 
+ l2_x: torch.Tensor, + l2_offsets: torch.Tensor, + contextual_seq_len: int, +) -> torch.Tensor: + padded_prefix_x = torch.ops.fbgemm.jagged_to_padded_dense( + values=prefix_x, + offsets=[prefix_offsets], + max_lengths=[max_prefix_len], + padding_value=0.0, + ) + padded_l2_x = torch.ops.fbgemm.jagged_to_padded_dense( + values=l2_x, + offsets=[l2_offsets], + max_lengths=[max_l2_len], + padding_value=0.0, + ) + padded_x = torch.cat( + [ + padded_l2_x[:, 0:contextual_seq_len, :], + padded_prefix_x, + padded_l2_x[:, contextual_seq_len:, :], + ], + dim=1, + ) + mask = fx_arange(max_prefix_len + max_l2_len, device=prefix_x.device).view(1, -1) + prefix_lengths = prefix_offsets[1:] - prefix_offsets[:-1] + l2_lengths = l2_offsets[1:] - l2_offsets[:-1] + mask = torch.logical_or( + mask < prefix_lengths.view(-1, 1) + contextual_seq_len, + torch.logical_and( + mask >= max_prefix_len + contextual_seq_len, + mask < max_prefix_len + l2_lengths.view(-1, 1), + ), + ) + return padded_x.flatten(0, 1)[mask.view(-1), :] diff --git a/recommendation/dlrm_v3/generative_recommenders/ops/pytorch/pt_layer_norm.py b/recommendation/dlrm_v3/generative_recommenders/ops/pytorch/pt_layer_norm.py new file mode 100644 index 0000000000..0666212ce6 --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/ops/pytorch/pt_layer_norm.py @@ -0,0 +1,81 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# pyre-strict + + +from typing import List + +import torch + + +def pytorch_layer_norm( + x: torch.Tensor, + normalized_shape: List[int], + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, +) -> torch.Tensor: + dtype = x.dtype + return torch.nn.functional.layer_norm( + x.to(torch.float32), + normalized_shape, + weight.to(torch.float32), + bias.to(torch.float32), + eps, + ).to(dtype) + + +def pytorch_rms_norm( + x: torch.Tensor, + normalized_shape: List[int], + weight: torch.Tensor, + eps: float, + silu: bool = False, +) -> torch.Tensor: + dtype = x.dtype + x_float = x.to(torch.float32) + normalized = torch.nn.functional.rms_norm( + x_float, + normalized_shape, + weight.to(torch.float32), + eps, + ) + if silu: + normalized = torch.nn.functional.silu(normalized) + return normalized.to(dtype) + + +def pytorch_swish_layer_norm( + x: torch.Tensor, + normalized_shape: List[int], + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, +) -> torch.Tensor: + dtype = x.dtype + x = x.to(torch.float32) + return ( + x + * torch.sigmoid( + torch.nn.functional.layer_norm( + x, + normalized_shape, + weight.to(torch.float32), + bias.to(torch.float32), + eps, + ) + ) + ).to(dtype) diff --git a/recommendation/dlrm_v3/generative_recommenders/ops/pytorch/pt_position.py b/recommendation/dlrm_v3/generative_recommenders/ops/pytorch/pt_position.py new file mode 100644 index 0000000000..dbe0c7efe9 --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/ops/pytorch/pt_position.py @@ -0,0 +1,134 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +from typing import Optional + +import torch +from generative_recommenders.common import ( + fx_unwrap_optional_tensor, + jagged_to_padded_dense, +) + +try: + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") + torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") +except OSError: + pass + + +@torch.fx.wrap +def torch_arange(end: int, device: torch.device) -> torch.Tensor: + return torch.arange(end, device=device) + + +@torch.fx.wrap +def _get_col_indices( + max_seq_len: int, + max_contextual_seq_len: int, + max_pos_ind: int, + seq_lengths: torch.Tensor, + num_targets: Optional[torch.Tensor], + interleave_targets: bool, +) -> torch.Tensor: + B = seq_lengths.size(0) + col_indices = torch.arange(max_seq_len, device=seq_lengths.device).expand( + B, max_seq_len + ) + if num_targets is not None: + if interleave_targets: + high_inds = seq_lengths - fx_unwrap_optional_tensor(num_targets) * 2 + else: + high_inds = seq_lengths - fx_unwrap_optional_tensor(num_targets) + col_indices = torch.clamp(col_indices, max=high_inds.view(-1, 1)) + col_indices = high_inds.view(-1, 1) - col_indices + else: + col_indices = seq_lengths.view(-1, 1) - col_indices + col_indices = col_indices + max_contextual_seq_len + col_indices = torch.clamp(col_indices, max=max_pos_ind - 1) + if max_contextual_seq_len > 0: + col_indices[:, :max_contextual_seq_len] = torch.arange( + 0, + max_contextual_seq_len, + device=col_indices.device, + dtype=col_indices.dtype, + ).view(1, -1) + return col_indices + + +def pytorch_add_timestamp_positional_embeddings( + seq_embeddings: torch.Tensor, + seq_offsets: torch.Tensor, + pos_embeddings: torch.Tensor, + ts_embeddings: torch.Tensor, + timestamps: torch.Tensor, + max_seq_len: int, + max_contextual_seq_len: int, + seq_lengths: torch.Tensor, + num_targets: Optional[torch.Tensor], + interleave_targets: bool, + time_bucket_fn: str, +) -> torch.Tensor: + max_pos_ind = pos_embeddings.size(0) + # position encoding + pos_inds = _get_col_indices( + max_seq_len=max_seq_len, + max_contextual_seq_len=max_contextual_seq_len, + max_pos_ind=max_pos_ind, + seq_lengths=seq_lengths, + num_targets=num_targets, + interleave_targets=interleave_targets, + ) + B, _ = pos_inds.shape + # timestamp encoding + num_time_buckets = ts_embeddings.size(1) - 1 + time_bucket_increments = 60.0 + time_bucket_divisor = 1.0 + time_delta = 0 + timestamps = jagged_to_padded_dense( + values=timestamps.unsqueeze(-1), + offsets=[seq_offsets], + max_lengths=[max_seq_len], + padding_value=0, + ).squeeze(-1) + query_time = torch.gather( + timestamps, dim=1, index=(seq_lengths - 1).unsqueeze(1).clamp(min=0) + ) + ts = query_time - timestamps + ts = ts + time_delta + ts = ts.clamp(min=1e-6) / time_bucket_increments + if time_bucket_fn == "log": + ts = torch.log(ts) + else: + ts = torch.sqrt(ts) + ts = (ts / time_bucket_divisor).clamp(min=0).int() + ts = torch.clamp( + ts, + min=0, + max=num_time_buckets, + ) 
+ position_embeddings = torch.index_select( + pos_embeddings, 0, pos_inds.reshape(-1) + ).view(B, max_seq_len, -1) + time_embeddings = torch.index_select(ts_embeddings, 0, ts.reshape(-1)).view( + B, max_seq_len, -1 + ) + return torch.ops.fbgemm.jagged_dense_elementwise_add_jagged_output( + seq_embeddings, + [seq_offsets], + (time_embeddings + position_embeddings).to(seq_embeddings.dtype), + )[0] diff --git a/recommendation/dlrm_v3/generative_recommenders/ops/triton/triton_addmm.py b/recommendation/dlrm_v3/generative_recommenders/ops/triton/triton_addmm.py new file mode 100644 index 0000000000..2231387fb6 --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/ops/triton/triton_addmm.py @@ -0,0 +1,1040 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + + +from typing import List, Tuple + +import torch + +# @manual=//triton:triton +import triton + +# @manual=//triton:triton +import triton.language as tl + +try: + # @manual=//triton:triton + import triton.language.extra.tlx as tlx # type: ignore + + HAS_TLX = True +except ImportError: + tlx = None + HAS_TLX = False + +from generative_recommenders.common import triton_autotune, triton_cc +from generative_recommenders.ops.utils import is_sm100 + +try: + # @manual=//triton:triton + from triton.tools.tensor_descriptor import TensorDescriptor + + TMA_AVAILABLE = True +except ImportError: + TMA_AVAILABLE = False + pass + + +ENABLE_FULL_TURNING_SPACE = False + + +def _check_tma_alignment( + x: torch.Tensor, w: torch.Tensor, y: torch.Tensor, min_alignment: int = 16 +) -> bool: + """Check if tensors meet TMA alignment requirements. + + TMA (Tensor Memory Accelerator) on H100 requires: + 1. Base addresses to be 64-byte aligned + 2. Dimensions to be multiples of 64 for optimal performance + 3. 
Contiguous inner dimensions (stride=1) + + Args: + x: Input tensor [M, K] + w: Weight tensor [K, N] + y: Bias tensor [N] or [M, N] + min_alignment: Minimum alignment requirement (default: 64) + + Returns: + True if all tensors meet TMA alignment requirements + """ + _, K = x.shape + KB, N = w.shape + assert K == KB, f"incompatible dimensions {K}, {KB}" + + is_y_1d = y.dim() == 1 + NY = y.shape[0] if is_y_1d else y.shape[1] + assert N == NY, f"incompatible dimensions {N}, {NY}" + + return (K % min_alignment == 0) and (N % min_alignment == 0) + + +def get_mm_configs(pre_hook=None) -> List[triton.Config]: + if torch.version.hip: + if ENABLE_FULL_TURNING_SPACE: + block_m_range = [32, 64, 128, 256] + block_n_range = [32, 64, 128, 256] + block_k_range = [32, 64] + group_m_range = [4, 8] + matrix_instr_nonkdim_range = [16] + waves_per_eu_range = [0] + kpack_range = [1, 2] + num_warps_range = [4, 8] + num_stage_range = [2] if triton.__version__ >= "3.2.0" else [0] + else: + block_m_range = [256] + block_n_range = [256] + block_k_range = [32] + group_m_range = [8] + matrix_instr_nonkdim_range = [16] + waves_per_eu_range = [0] + kpack_range = [2] + num_warps_range = [8] + num_stage_range = [2] if triton.__version__ >= "3.2.0" else [0] + + return [ + triton.Config( + { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "BLOCK_K": block_k, + "GROUP_M": group_m, + "matrix_instr_nonkdim": matrix_instr_nonkdim, + "waves_per_eu": waves_per_eu, + "kpack": kpack, + }, + num_stages=num_stages, + num_warps=num_warps, + pre_hook=pre_hook, + ) + for block_m in block_m_range + for block_n in block_n_range + for block_k in block_k_range + for group_m in group_m_range + for matrix_instr_nonkdim in matrix_instr_nonkdim_range + for waves_per_eu in waves_per_eu_range + for kpack in kpack_range + for num_stages in num_stage_range + for num_warps in num_warps_range + ] + else: + block_m_range = [32, 64, 128, 256] + block_n_range = [32, 64, 128, 256] + block_k_range = [32, 64] + group_m_range = [4, 8] + # WARP_SPECIALIZE only works with num_warps >=4 + num_warps_range = [4, 8] if is_sm100() else [2, 4, 8] + num_stage_range = [2, 3, 4, 5] + if ENABLE_FULL_TURNING_SPACE: + return [ + triton.Config( + { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "BLOCK_K": block_k, + "GROUP_M": group_m, + }, + num_stages=num_stages, + num_warps=num_warps, + pre_hook=pre_hook, + ) + for block_m in block_m_range + for block_n in block_n_range + for block_k in block_k_range + for group_m in group_m_range + for num_stages in num_stage_range + for num_warps in num_warps_range + ] + else: + configs = [ + triton.Config( + { + "BLOCK_M": 32, + "BLOCK_N": 64, + "BLOCK_K": 32, + "GROUP_M": 8, + }, + num_stages=5, + num_warps=2, + pre_hook=pre_hook, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 256, + "BLOCK_K": 64, + "GROUP_M": 8, + }, + num_stages=3, + num_warps=8, + pre_hook=pre_hook, + ), + triton.Config( + { + "BLOCK_M": 64, + "BLOCK_N": 256, + "BLOCK_K": 32, + "GROUP_M": 8, + }, + num_stages=4, + num_warps=4, + pre_hook=pre_hook, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 128, + "BLOCK_K": 32, + "GROUP_M": 8, + }, + num_stages=4, + num_warps=4, + pre_hook=pre_hook, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "BLOCK_K": 32, + "GROUP_M": 8, + }, + num_stages=4, + num_warps=4, + pre_hook=pre_hook, + ), + triton.Config( + { + "BLOCK_M": 64, + "BLOCK_N": 128, + "BLOCK_K": 32, + "GROUP_M": 8, + }, + num_stages=4, + num_warps=4, + pre_hook=pre_hook, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 32, 
+ "BLOCK_K": 32, + "GROUP_M": 8, + }, + num_stages=4, + num_warps=4, + pre_hook=pre_hook, + ), + triton.Config( + { + "BLOCK_M": 64, + "BLOCK_N": 32, + "BLOCK_K": 32, + "GROUP_M": 8, + }, + num_stages=5, + num_warps=2, + pre_hook=pre_hook, + ), + ] + if is_sm100(): + configs += [ + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 256, + "BLOCK_K": 64, + "GROUP_M": 8, + }, + num_stages=3, + num_warps=4, + pre_hook=pre_hook, + ), + ] + return [c for c in configs if c.num_warps >= 4] + + return configs + + +@triton_cc( + annotations={ + "M": "i32", + "N": ("i32", 16), + "K": ("i32", 16), + "stride_xm": ("i32", 16), + "stride_xk": ("i32", 1), + "stride_wk": ("i32", 16), + "stride_wn": ("i32", 1), + "stride_ym": ("i32", 16), + "stride_yn": ("i32", 1), + "stride_zm": ("i32", 16), + "stride_zn": ("i32", 1), + }, +) +@triton_autotune( + configs=get_mm_configs(), + key=["N", "K"], +) +@triton.jit +def _addmm_fwd( + x_ptr, + w_ptr, + y_ptr, + z_ptr, + M, + N, + K, + stride_xm, + stride_xk, + stride_wk, + stride_wn, + stride_ym, + stride_yn, + stride_zm, + stride_zn, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BROADCAST_Y: tl.constexpr, +): + pid_0, pid_1 = tl.program_id(axis=0), tl.program_id(axis=1) + pid = pid_0 * tl.num_programs(axis=1) + pid_1 + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_in_group = GROUP_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_m = tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, BLOCK_K) + offs_n = tl.arange(0, BLOCK_N) + mask_m = (pid_m * BLOCK_M + offs_m)[:, None] < M + mask_n = (pid_n * BLOCK_N + offs_n)[None, :] < N + x_ptr += pid_m.to(tl.int64) * BLOCK_M * stride_xm + x_ptrs = x_ptr + (offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk) + w_ptr += pid_n.to(tl.int64) * BLOCK_N * stride_wn + w_ptrs = w_ptr + (offs_k[:, None] * stride_wk + offs_n[None, :] * stride_wn) + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + mask_k = offs_k[None, :] < K - k * BLOCK_K + x = tl.load(x_ptrs, mask=mask_k & mask_m, other=0.0) + mask_k = offs_k[:, None] < K - k * BLOCK_K + w = tl.load(w_ptrs, mask=mask_k & mask_n, other=0.0) + accumulator += tl.dot(x, w, allow_tf32=ALLOW_TF32) + x_ptrs += BLOCK_K * stride_xk + w_ptrs += BLOCK_K * stride_wk + + z_mask = mask_m & mask_n + if BROADCAST_Y: + # y is a vector, broadcast to add to z + y_ptr += pid_n.to(tl.int64) * BLOCK_N * stride_yn + y_ptrs = y_ptr + stride_yn * offs_n[None, :] + y = tl.load(y_ptrs, mask=mask_n) + else: + y_ptr += pid_m.to(tl.int64) * BLOCK_M * stride_ym + y_ptr += pid_n.to(tl.int64) * BLOCK_N * stride_yn + y_ptrs = y_ptr + stride_ym * offs_m[:, None] + stride_yn * offs_n[None, :] + y = tl.load(y_ptrs, mask=z_mask) + z = (accumulator + y.to(tl.float32)).to(z_ptr.dtype.element_ty) + z_ptr += pid_m.to(tl.int64) * BLOCK_M * stride_zm + z_ptr += pid_n.to(tl.int64) * BLOCK_N * stride_zn + z_ptrs = z_ptr + stride_zm * offs_m[:, None] + stride_zn * offs_n[None, :] + tl.store(z_ptrs, z, mask=z_mask) + + +def _addmm_tma_set_block_size_hook(nargs): + BLOCK_M = nargs["BLOCK_M"] + BLOCK_N = nargs["BLOCK_N"] + BLOCK_K = nargs["BLOCK_K"] + nargs["x_desc"].block_shape = [BLOCK_M, BLOCK_K] + nargs["w_desc"].block_shape = [BLOCK_K, BLOCK_N] + 
nargs["z_desc"].block_shape = [BLOCK_M, BLOCK_N] + if nargs["BROADCAST_Y"]: + nargs["y_desc"].block_shape = [1, BLOCK_N] + else: + nargs["y_desc"].block_shape = [BLOCK_M, BLOCK_N] + + +@triton.jit +def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS): + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + return pid_m, pid_n + + +@triton_autotune( + configs=get_mm_configs(pre_hook=_addmm_tma_set_block_size_hook), + key=["N", "K", "WARP_SPECIALIZE"], +) +@triton.jit +def _addmm_fwd_tma_persistent( + x_desc, + w_desc, + y_desc, + z_desc, + M, + N, + K, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BROADCAST_Y: tl.constexpr, + WARP_SPECIALIZE: tl.constexpr, + NUM_SMS: tl.constexpr, +): + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + k_tiles = tl.cdiv(K, BLOCK_K) + num_tiles = num_pid_m * num_pid_n + + num_pid_in_group = GROUP_M * num_pid_n + + for tile_id in tl.range( + start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=WARP_SPECIALIZE + ): + pid_m, pid_n = _compute_pid( + tile_id, num_pid_in_group, num_pid_m, GROUP_M, NUM_SMS + ) + offs_xm = pid_m * BLOCK_M + offs_wn = pid_n * BLOCK_N + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in tl.range(0, k_tiles, warp_specialize=WARP_SPECIALIZE): + offs_k = k * BLOCK_K + x = x_desc.load([offs_xm, offs_k]) + w = w_desc.load([offs_k, offs_wn]) + accumulator = tl.dot(x, w, accumulator, allow_tf32=ALLOW_TF32) + if BROADCAST_Y: + y = y_desc.load([0, offs_wn]) + else: + y = y_desc.load([offs_xm, offs_wn]) + z = (accumulator + y.to(tl.float32)).to(z_desc.dtype) + z_desc.store([offs_xm, offs_wn], z) + + +@triton_autotune( + configs=get_mm_configs(pre_hook=_addmm_tma_set_block_size_hook), + key=["N", "K"], +) +@triton.jit +def _addmm_fwd_tma_ws( + x_desc, + w_desc, + y_desc, + z_desc, + M, + N, + K, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BROADCAST_Y: tl.constexpr, + NUM_SMEM_BUFFERS: tl.constexpr, +): + x_buffers = tlx.local_alloc((BLOCK_M, BLOCK_K), x_desc.dtype, NUM_SMEM_BUFFERS) + w_buffers = tlx.local_alloc((BLOCK_K, BLOCK_N), w_desc.dtype, NUM_SMEM_BUFFERS) + acc_tmem_buffer = tlx.local_alloc( + (BLOCK_M, BLOCK_N), tl.float32, tl.constexpr(1), tlx.storage_kind.tmem + ) + + if BROADCAST_Y: + y_buffer = tlx.local_alloc((1, BLOCK_N), y_desc.dtype, tl.constexpr(1)) + else: + y_buffer = tlx.local_alloc((BLOCK_M, BLOCK_N), y_desc.dtype, tl.constexpr(1)) + z_buffer = tlx.local_alloc((BLOCK_M, BLOCK_N), z_desc.dtype, tl.constexpr(1)) + + smem_full_bars = tlx.alloc_barriers(num_barriers=NUM_SMEM_BUFFERS, arrive_count=1) + smem_empty_bars = tlx.alloc_barriers(num_barriers=NUM_SMEM_BUFFERS, arrive_count=1) + y_load_barrier = tlx.alloc_barriers(num_barriers=1, arrive_count=1) + + with tlx.async_tasks(): + # Producer task: TMA loads + with tlx.async_task("default"): + pid_0, pid_1 = tl.program_id(axis=0), tl.program_id(axis=1) + pid = pid_0 * tl.num_programs(axis=1) + pid_1 + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_in_group = GROUP_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_M + group_size_m = min(num_pid_m - 
first_pid_m, GROUP_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_xm = pid_m * BLOCK_M + offs_wn = pid_n * BLOCK_N + k_tiles = tl.cdiv(K, BLOCK_K) + + load_phase = 0 + for k in range(0, k_tiles): + buf = k % int(NUM_SMEM_BUFFERS) + + # Wait for buffer to be free + if k >= NUM_SMEM_BUFFERS: + tlx.barrier_wait(smem_empty_bars[buf], load_phase ^ 1) + + offs_k = k * BLOCK_K + tlx.barrier_expect_bytes( + smem_full_bars[buf], + 2 * (BLOCK_M * BLOCK_K + BLOCK_K * BLOCK_N), + ) + tlx.async_descriptor_load( + x_desc, x_buffers[buf], [offs_xm, offs_k], smem_full_bars[buf] + ) + tlx.async_descriptor_load( + w_desc, w_buffers[buf], [offs_k, offs_wn], smem_full_bars[buf] + ) + + load_phase = load_phase ^ (buf == NUM_SMEM_BUFFERS - 1) + + # Consumer task: async_dot MMA + with tlx.async_task(num_warps=4, num_regs=232): + pid_0, pid_1 = tl.program_id(axis=0), tl.program_id(axis=1) + pid = pid_0 * tl.num_programs(axis=1) + pid_1 + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_in_group = GROUP_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_xm = pid_m * BLOCK_M + offs_wn = pid_n * BLOCK_N + k_tiles = tl.cdiv(K, BLOCK_K) + + # Start async load of y early + y_buf_view = tlx.local_view(y_buffer, 0) + y_load_bar = tlx.local_view(y_load_barrier, 0) + if BROADCAST_Y: + tlx.barrier_expect_bytes(y_load_bar, 1 * BLOCK_N * 2) + tlx.async_descriptor_load(y_desc, y_buf_view, [0, offs_wn], y_load_bar) + else: + tlx.barrier_expect_bytes(y_load_bar, BLOCK_M * BLOCK_N * 2) + tlx.async_descriptor_load( + y_desc, y_buf_view, [offs_xm, offs_wn], y_load_bar + ) + + dot_phase = 0 + for k in range(0, k_tiles): + buf = k % int(NUM_SMEM_BUFFERS) + tlx.barrier_wait(smem_full_bars[buf], dot_phase) + + tlx.async_dot( + x_buffers[buf], + w_buffers[buf], + acc_tmem_buffer[0], + use_acc=k > 0, + mBarriers=[smem_empty_bars[buf]], + out_dtype=tl.float32, + ) + + dot_phase = dot_phase ^ (buf == NUM_SMEM_BUFFERS - 1) + + last_buf = (k_tiles - 1) % NUM_SMEM_BUFFERS + last_dot_phase = dot_phase ^ (last_buf == NUM_SMEM_BUFFERS - 1) + tlx.barrier_wait(smem_empty_bars[last_buf], last_dot_phase) + + tmem_result = tlx.local_load(acc_tmem_buffer[0]) + + tlx.barrier_wait(y_load_bar, 0) + y = tlx.local_load(y_buf_view) + + z = (tmem_result + y.to(tl.float32)).to(z_desc.dtype) + z_buf_view = tlx.local_view(z_buffer, 0) + tlx.local_store(z_buf_view, z) + tlx.async_descriptor_store(z_desc, z_buf_view, [offs_xm, offs_wn]) + tlx.async_descriptor_store_wait(0) + + +@triton_autotune( + configs=get_mm_configs(pre_hook=_addmm_tma_set_block_size_hook), + key=["N", "K"], +) +@triton.jit +def _addmm_fwd_tma_ws_persistent( + x_desc, + w_desc, + y_desc, + z_desc, + M, + N, + K, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BROADCAST_Y: tl.constexpr, + NUM_SMEM_BUFFERS: tl.constexpr, + NUM_TMEM_BUFFERS: tl.constexpr, + NUM_SMS: tl.constexpr, +): + # Allocate buffers once for all tiles + x_buffers = tlx.local_alloc((BLOCK_M, BLOCK_K), x_desc.dtype, NUM_SMEM_BUFFERS) + w_buffers = tlx.local_alloc((BLOCK_K, BLOCK_N), w_desc.dtype, NUM_SMEM_BUFFERS) + tmem_buffers = tlx.local_alloc( + (BLOCK_M, BLOCK_N), tl.float32, NUM_TMEM_BUFFERS, tlx.storage_kind.tmem + ) + + # Barriers for producer <-> MMA + 
smem_full_bars = tlx.alloc_barriers(num_barriers=NUM_SMEM_BUFFERS, arrive_count=1) + smem_empty_bars = tlx.alloc_barriers(num_barriers=NUM_SMEM_BUFFERS, arrive_count=1) + # Barriers for MMA <-> Epilogue + tmem_full_bars = tlx.alloc_barriers(num_barriers=NUM_TMEM_BUFFERS, arrive_count=1) + tmem_empty_bars = tlx.alloc_barriers(num_barriers=NUM_TMEM_BUFFERS, arrive_count=1) + + with tlx.async_tasks(): + # Epilogue consumer: loads Y, adds bias, stores Z + with tlx.async_task("default"): + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_in_group = GROUP_M * num_pid_n + num_tiles = num_pid_m * num_pid_n + + tmem_read_phase = 0 + cur_tmem_buf = 0 + + for tile_id in range(start_pid, num_tiles, NUM_SMS): + pid_m, pid_n = _compute_pid( + tile_id, num_pid_in_group, num_pid_m, GROUP_M, NUM_SMS + ) + offs_xm = pid_m * BLOCK_M + offs_wn = pid_n * BLOCK_N + + # Wait for MMA to finish computing this tile + tlx.barrier_wait(tmem_full_bars[cur_tmem_buf], tmem_read_phase) + tmem_read_phase = tmem_read_phase ^ ( + cur_tmem_buf == int(NUM_TMEM_BUFFERS) - 1 + ) + + # Load Y synchronously + if BROADCAST_Y: + y = y_desc.load([0, offs_wn]) + else: + y = y_desc.load([offs_xm, offs_wn]) + + # Load result from TMEM and add bias + acc_tmem = tmem_buffers[cur_tmem_buf] + result = tlx.local_load(acc_tmem) + z = (result + y.to(tl.float32)).to(z_desc.dtype) + + # Store result directly via TMA + z_desc.store([offs_xm, offs_wn], z) + + # Signal MMA that this TMEM buffer is now free + tlx.barrier_arrive(tmem_empty_bars[cur_tmem_buf], 1) + + cur_tmem_buf = (cur_tmem_buf + 1) % int(NUM_TMEM_BUFFERS) + + # MMA consumer: performs matrix multiplication + with tlx.async_task(num_warps=4, num_regs=232): + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_in_group = GROUP_M * num_pid_n + num_tiles = num_pid_m * num_pid_n + k_tiles = tl.cdiv(K, BLOCK_K) + + dot_phase = 0 + tmem_write_phase = 1 + cur_tmem_buf = 0 + processed_k_iters = 0 + + for tile_id in range(start_pid, num_tiles, NUM_SMS): + pid_m, pid_n = _compute_pid( + tile_id, num_pid_in_group, num_pid_m, GROUP_M, NUM_SMS + ) + + # Wait for epilogue to finish with this TMEM buffer + tlx.barrier_wait(tmem_empty_bars[cur_tmem_buf], tmem_write_phase) + tmem_write_phase = tmem_write_phase ^ ( + cur_tmem_buf == int(NUM_TMEM_BUFFERS) - 1 + ) + + # Perform K-dimension reduction + for k in range(0, k_tiles): + buf = (processed_k_iters + k) % int(NUM_SMEM_BUFFERS) + tlx.barrier_wait(smem_full_bars[buf], dot_phase) + + tlx.async_dot( + x_buffers[buf], + w_buffers[buf], + tmem_buffers[cur_tmem_buf], + use_acc=(k > 0), + mBarriers=[smem_empty_bars[buf]], + out_dtype=tl.float32, + ) + + dot_phase = dot_phase ^ (buf == int(NUM_SMEM_BUFFERS) - 1) + + # Wait for last MMA to complete + last_buf = (processed_k_iters + k_tiles - 1) % int(NUM_SMEM_BUFFERS) + last_dot_phase = dot_phase ^ (last_buf == int(NUM_SMEM_BUFFERS) - 1) + tlx.barrier_wait(smem_empty_bars[last_buf], last_dot_phase) + + # Signal epilogue that result is ready + tlx.barrier_arrive(tmem_full_bars[cur_tmem_buf], 1) + + cur_tmem_buf = (cur_tmem_buf + 1) % int(NUM_TMEM_BUFFERS) + processed_k_iters += k_tiles + + # Producer: TMA loads for X and W + with tlx.async_task(num_warps=1, num_regs=24): + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_in_group = GROUP_M * num_pid_n + num_tiles = num_pid_m * num_pid_n + k_tiles = tl.cdiv(K, BLOCK_K) + + 
load_phase = 0 + processed_k_iters = 0 + + for tile_id in range(start_pid, num_tiles, NUM_SMS): + pid_m, pid_n = _compute_pid( + tile_id, num_pid_in_group, num_pid_m, GROUP_M, NUM_SMS + ) + offs_xm = pid_m * BLOCK_M + offs_wn = pid_n * BLOCK_N + + for k in range(0, k_tiles): + buf = (processed_k_iters + k) % int(NUM_SMEM_BUFFERS) + + # Wait for buffer to be free + tlx.barrier_wait(smem_empty_bars[buf], load_phase ^ 1) + + offs_k = k * BLOCK_K + tlx.barrier_expect_bytes( + smem_full_bars[buf], + 2 * (BLOCK_M + BLOCK_N) * BLOCK_K, + ) + tlx.async_descriptor_load( + x_desc, x_buffers[buf], [offs_xm, offs_k], smem_full_bars[buf] + ) + tlx.async_descriptor_load( + w_desc, w_buffers[buf], [offs_k, offs_wn], smem_full_bars[buf] + ) + + load_phase = load_phase ^ (buf == int(NUM_SMEM_BUFFERS) - 1) + + processed_k_iters += k_tiles + + +@torch.fx.wrap +def triton_addmm_fwd_tma_persistent( + x: torch.Tensor, + w: torch.Tensor, + y: torch.Tensor, + warp_specialize: bool = False, +) -> torch.Tensor: + M, K = x.shape + _, N = w.shape + + is_y_1d = y.dim() == 1 + + # Allocate output + z = torch.empty((M, N), device=x.device, dtype=x.dtype) + if M == 0 or N == 0: + return z + + # A dummy block value that will be overwritten when we have the real block size + dummy_block = [1, 1] + # pyre-ignore[6]: In call `TensorDescriptor.__init__`, for 2nd positional + # argument, expected `List[int]` but got `Size` + x_desc = TensorDescriptor(x, x.shape, x.stride(), dummy_block) + # pyre-ignore[6]: In call `TensorDescriptor.__init__`, for 2nd positional + # argument, expected `List[int]` but got `Size` + w_desc = TensorDescriptor(w, w.shape, w.stride(), dummy_block) + y = y.reshape(1, -1) if is_y_1d else y + # pyre-ignore[6]: In call `TensorDescriptor.__init__`, for 2nd positional + # argument, expected `List[int]` but got `Size` + y_desc = TensorDescriptor(y, y.shape, y.stride(), dummy_block) + # pyre-ignore[6]: In call `TensorDescriptor.__init__`, for 2nd positional + # argument, expected `List[int]` but got `Size` + z_desc = TensorDescriptor(z, z.shape, z.stride(), dummy_block) + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + + def grid(meta): + nonlocal x_desc, w_desc, z_desc + BLOCK_M = meta["BLOCK_M"] + BLOCK_N = meta["BLOCK_N"] + return ( + min( + NUM_SMS, + triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), + ), + ) + + _addmm_fwd_tma_persistent[grid]( + x_desc, + w_desc, + y_desc, + z_desc, + M, + N, + K, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + BROADCAST_Y=is_y_1d, + WARP_SPECIALIZE=warp_specialize, + NUM_SMS=NUM_SMS, + ) + return z + + +@torch.fx.wrap +def triton_addmm_fwd_tma_ws_tlx( + x: torch.Tensor, + w: torch.Tensor, + y: torch.Tensor, +) -> torch.Tensor: + M, K = x.shape + _, N = w.shape + + is_y_1d = y.dim() == 1 + + # Allocate output + z = torch.empty((M, N), device=x.device, dtype=x.dtype) + if M == 0 or N == 0: + return z + + # A dummy block value that will be overwritten when we have the real block size + dummy_block = [1, 1] + # pyre-ignore[6]: In call `TensorDescriptor.__init__`, for 2nd positional + # argument, expected `List[int]` but got `Size` + x_desc = TensorDescriptor(x, x.shape, x.stride(), dummy_block) + # pyre-ignore[6]: In call `TensorDescriptor.__init__`, for 2nd positional + # argument, expected `List[int]` but got `Size` + w_desc = TensorDescriptor(w, w.shape, w.stride(), dummy_block) + y = y.reshape(1, -1) if is_y_1d else y + # pyre-ignore[6]: In call `TensorDescriptor.__init__`, for 2nd positional + # argument, expected `List[int]` but got 
`Size` + y_desc = TensorDescriptor(y, y.shape, y.stride(), dummy_block) + # pyre-ignore[6]: In call `TensorDescriptor.__init__`, for 2nd positional + # argument, expected `List[int]` but got `Size` + z_desc = TensorDescriptor(z, z.shape, z.stride(), dummy_block) + + def grid(meta): + BLOCK_M = meta["BLOCK_M"] + BLOCK_N = meta["BLOCK_N"] + return ( + triton.cdiv(M, BLOCK_M), + triton.cdiv(N, BLOCK_N), + ) + + _addmm_fwd_tma_ws[grid]( + x_desc, + w_desc, + y_desc, + z_desc, + M, + N, + K, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + BROADCAST_Y=is_y_1d, + NUM_SMEM_BUFFERS=2, # Double buffering + ) + return z + + +@torch.fx.wrap +def triton_addmm_fwd_tma_ws_persistent_tlx( + x: torch.Tensor, + w: torch.Tensor, + y: torch.Tensor, +) -> torch.Tensor: + M, K = x.shape + _, N = w.shape + + is_y_1d = y.dim() == 1 + + # Allocate output + z = torch.empty((M, N), device=x.device, dtype=x.dtype) + if M == 0 or N == 0: + return z + + NUM_SMEM_BUFFERS = 2 + NUM_TMEM_BUFFERS = 2 + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + + # A dummy block value that will be overwritten by the hook + dummy_block = [1, 1] + # pyre-ignore[6]: In call `TensorDescriptor.__init__`, for 2nd positional + # argument, expected `List[int]` but got `Size` + x_desc = TensorDescriptor(x, x.shape, x.stride(), dummy_block) + # pyre-ignore[6]: In call `TensorDescriptor.__init__`, for 2nd positional + # argument, expected `List[int]` but got `Size` + w_desc = TensorDescriptor(w, w.shape, w.stride(), dummy_block) + y = y.reshape(1, -1) if is_y_1d else y + # pyre-ignore[6]: In call `TensorDescriptor.__init__`, for 2nd positional + # argument, expected `List[int]` but got `Size` + y_desc = TensorDescriptor(y, y.shape, y.stride(), dummy_block) + # pyre-ignore[6]: In call `TensorDescriptor.__init__`, for 2nd positional + # argument, expected `List[int]` but got `Size` + z_desc = TensorDescriptor(z, z.shape, z.stride(), dummy_block) + + def grid(meta): + BLOCK_M = meta["BLOCK_M"] + BLOCK_N = meta["BLOCK_N"] + return ( + min( + NUM_SMS, + triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), + ), + ) + + _addmm_fwd_tma_ws_persistent[grid]( + x_desc, + w_desc, + y_desc, + z_desc, + M, + N, + K, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + BROADCAST_Y=is_y_1d, + NUM_SMEM_BUFFERS=NUM_SMEM_BUFFERS, + NUM_TMEM_BUFFERS=NUM_TMEM_BUFFERS, + NUM_SMS=NUM_SMS, + ) + return z + + +@torch.fx.wrap +def triton_addmm_fwd( + x: torch.Tensor, + w: torch.Tensor, + y: torch.Tensor, +) -> torch.Tensor: + M, K = x.shape + KB, N = w.shape + assert K == KB, f"incompatible dimensions {K}, {KB}" + + is_y_1d = y.dim() == 1 + NY = y.shape[0] if is_y_1d else y.shape[1] + assert N == NY, f"incompatible dimensions {N}, {NY}" + + # Allocate output + z = torch.empty((M, N), device=x.device, dtype=x.dtype) + if M == 0 or N == 0: + return z + + grid = lambda meta: ( # noqa E731 + triton.cdiv(M, meta["BLOCK_M"]), + triton.cdiv(N, meta["BLOCK_N"]), + ) + + _addmm_fwd[grid]( + x, + w, + y, + z, + M, + N, + K, + x.stride(0), + x.stride(1), + w.stride(0), + w.stride(1), + y.stride(0) if not is_y_1d else 0, + y.stride(1) if not is_y_1d else y.stride(0), + z.stride(0), + z.stride(1), + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + BROADCAST_Y=is_y_1d, + ) + return z + + +def triton_addmm_bwd( + x: torch.Tensor, + w: torch.Tensor, + dz: torch.Tensor, + is_y_1d: bool, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if is_y_1d: + dy = torch.sum(dz, dim=0) + else: + dy = dz + dw = torch.mm(x.t(), dz) + dx = torch.mm(dz, w.t()) 
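+    # Gradients of z = x @ w + y: dx = dz @ w.T, dw = x.T @ dz, and dy = dz
+    # (summed over dim 0 above when y was broadcast from a 1-D bias).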
+ + return dx, dw, dy + + +@torch.fx.wrap +def maybe_triton_addmm_fwd( + x: torch.Tensor, + w: torch.Tensor, + y: torch.Tensor, +) -> torch.Tensor: + # triton addmm is slower than torch (cublas) on AMD/Blackwell. + # Default to pytorch addmm on AMD/Blackwell for now. + if is_sm100() or torch.version.hip is not None: + return torch.addmm(y, x, w) + else: + return triton_addmm_fwd(x=x, w=w, y=y) + + +class _AddMmFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + x: torch.Tensor, + w: torch.Tensor, + y: torch.Tensor, + ) -> torch.Tensor: + ctx.save_for_backward(x, w) + ctx.is_y_1d = y.dim() == 1 + if is_sm100() and TMA_AVAILABLE and _check_tma_alignment(x, w, y): + if x.dtype == torch.float32 or HAS_TLX == False: + # use TMA persistent kernel on sm100 + return triton_addmm_fwd_tma_persistent(x, w, y, warp_specialize=True) + else: + return triton_addmm_fwd_tma_ws_persistent_tlx( + x, w, y + ) # tlx.async_dot doesn't support fp32 inputs because of WGMMA requirements + else: + return triton_addmm_fwd(x, w, y) + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, dz: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + (x, w) = ctx.saved_tensors + return triton_addmm_bwd(x, w, dz, ctx.is_y_1d) + + +def triton_addmm( + input: torch.Tensor, + mat1: torch.Tensor, + mat2: torch.Tensor, +) -> torch.Tensor: + return _AddMmFunction.apply(mat1, mat2, input) diff --git a/recommendation/dlrm_v3/generative_recommenders/ops/triton/triton_attention_utils.py b/recommendation/dlrm_v3/generative_recommenders/ops/triton/triton_attention_utils.py new file mode 100644 index 0000000000..4022709c51 --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/ops/triton/triton_attention_utils.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
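+#
+# Shared Triton helper for the HSTU attention backward pass: `acc_dq`
+# accumulates partial dQ contributions, optionally serialized through a
+# spin lock (tl.atomic_cas / tl.atomic_xchg) when ATOMIC_ADD is set so
+# concurrent thread blocks do not race on the same dQ tile.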
+ +#!/usr/bin/env python3 + + +# @manual=//triton:triton +import triton + +# @manual=//triton:triton +import triton.language as tl + + +@triton.jit +def acc_dq( + dq_ptrs_trans, + start_m, + stride_dqm, + k, + dqk_trans, + alpha, + mask_m, + MAX_SEQ_LEN, + LOCK, + BLOCK_M: tl.constexpr, + ATOMIC_ADD: tl.constexpr, + ALLOW_TF32: tl.constexpr, +): + if ATOMIC_ADD: + lock_id = start_m // BLOCK_M + stride_lock = tl.cdiv(MAX_SEQ_LEN, BLOCK_M) + lock = LOCK + tl.program_id(0) * stride_lock + lock_id + tl.debug_barrier() # add a barrier to force sync + while tl.atomic_cas(lock, 0, 1) == 1: + pass + dq_trans = tl.load( + dq_ptrs_trans + start_m * stride_dqm, + mask=mask_m[None, :], + other=0.0, + eviction_policy="evict_last", + ) + dq_trans += tl.dot(tl.trans(k), dqk_trans, allow_tf32=ALLOW_TF32) * alpha + dq_trans = dq_trans.to(k.dtype) + tl.store( + dq_ptrs_trans + start_m * stride_dqm, + dq_trans, + mask=mask_m[None, :], + eviction_policy="evict_last", + ) + if ATOMIC_ADD: + tl.atomic_xchg(lock, 0) # pyre-ignore [61] diff --git a/recommendation/dlrm_v3/generative_recommenders/ops/triton/triton_hstu_attention.py b/recommendation/dlrm_v3/generative_recommenders/ops/triton/triton_hstu_attention.py new file mode 100644 index 0000000000..36080561fc --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/ops/triton/triton_hstu_attention.py @@ -0,0 +1,3011 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
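+#
+# Triton kernels for HSTU jagged attention. The forward path below includes a
+# block-pointer implementation plus TMA and TLX warp-specialized variants,
+# together with their autotuning configurations.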
+#!/usr/bin/env python3 + +from typing import List, Optional, Tuple + +import torch + +# @manual=//triton:triton +import triton + +# @manual=//triton:triton +import triton.language as tl + +try: + # @manual=//triton:triton + import triton.language.extra.tlx as tlx # type: ignore + + HAS_TLX = True +except ImportError: + # suppress type checking errors + tlx = None + + HAS_TLX = False + +from generative_recommenders.common import ( + autotune_max_seq_len, + prev_power_of_2, + switch_to_contiguous_if_needed, + triton_autotune, +) + +from triton.language.extra.libdevice import ( # @manual=//triton:triton + fast_dividef, + fast_expf, +) + +try: + # @manual=//triton:triton + from triton.tools.tensor_descriptor import TensorDescriptor + + tensor_descriptor_tma = True +except ImportError: + tensor_descriptor_tma = False + +try: + from generative_recommenders.ops.triton.fb.triton_attention_utils import acc_dq +except ImportError: + from generative_recommenders.ops.triton.triton_attention_utils import acc_dq + + +def _host_descriptor_pre_hook(nargs): + if not tensor_descriptor_tma: + return + + if not isinstance(nargs["Q"], TensorDescriptor): + return + BLOCK_M = nargs["BLOCK_M"] + BLOCK_N = nargs["BLOCK_N"] + BLOCK_D_Q = nargs["BLOCK_D_Q"] + BLOCK_D_V = nargs["BLOCK_D_V"] + if "USE_TLX" in nargs and nargs["USE_TLX"]: + BLOCK_M = BLOCK_M // nargs["NUM_MMA_GROUPS"] + nargs["Q"].block_shape = [BLOCK_M, BLOCK_D_Q] + nargs["V"].block_shape = [BLOCK_N, BLOCK_D_V] + nargs["K"].block_shape = [BLOCK_N, BLOCK_D_Q] + + +def _get_fw_configs() -> List[triton.Config]: # noqa: C901 + configs = [] + if torch.version.hip: + for BLOCK_M in [32, 64, 128]: + for BLOCK_N in [32, 64]: + for num_stages in [1, 2]: + for num_warps in [4, 8]: + for matrix_instr_nonkdim in [16, 32]: + configs.append( + triton.Config( + { + "BLOCK_M": BLOCK_M, + "BLOCK_N": BLOCK_N, + "matrix_instr_nonkdim": matrix_instr_nonkdim, + "waves_per_eu": 0, + "kpack": 2, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + ) + else: + configs = [ + triton.Config( + {"BLOCK_M": 16, "BLOCK_N": 32}, + num_stages=2, + num_warps=2, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32}, + num_stages=2, + num_warps=2, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32}, + num_stages=4, + num_warps=2, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32}, + num_stages=2, + num_warps=4, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32}, + num_stages=4, + num_warps=4, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 64}, + num_stages=2, + num_warps=4, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 64}, + num_stages=4, + num_warps=4, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 64}, + num_stages=4, + num_warps=8, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 128}, + num_stages=2, + num_warps=4, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 128}, + num_stages=2, + num_warps=8, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 32}, + num_stages=4, + num_warps=2, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 32}, + num_stages=2, + num_warps=4, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + 
{"BLOCK_M": 64, "BLOCK_N": 32}, + num_stages=4, + num_warps=4, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 32}, + num_stages=2, + num_warps=8, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64}, + num_stages=2, + num_warps=2, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64}, + num_stages=2, + num_warps=4, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64}, + num_stages=4, + num_warps=4, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64}, + num_stages=4, + num_warps=8, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32}, + num_stages=2, + num_warps=2, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32}, + num_stages=4, + num_warps=2, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32}, + num_stages=2, + num_warps=4, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32}, + num_stages=4, + num_warps=4, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32}, + num_stages=2, + num_warps=8, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 32}, + num_stages=4, + num_warps=8, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64}, + num_stages=2, + num_warps=4, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64}, + num_stages=2, + num_warps=8, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64}, + num_stages=4, + num_warps=8, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128}, + num_stages=4, + num_warps=4, + pre_hook=_host_descriptor_pre_hook, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128}, + num_stages=2, + num_warps=8, + pre_hook=_host_descriptor_pre_hook, + ), + ] + + # Add 'USE_TLX' : False, 'NUM_BUFFERS': 1, 'NUM_MMA_WARPS_PER_GROUP': 1, 'NUM_MMA_GROUPS': 1 to non-TLX configs + for config in configs: + if not config.kwargs.get("USE_TLX", False): + config.kwargs["USE_TLX"] = False + config.kwargs["NUM_BUFFERS"] = 1 + config.kwargs["NUM_MMA_WARPS_PER_GROUP"] = 1 + config.kwargs["NUM_MMA_GROUPS"] = 1 + + # Add TLX configs if TLX is available + if HAS_TLX: + try: + device_capability = torch.cuda.get_device_capability()[0] + except (RuntimeError, AssertionError): + # No CUDA device available + device_capability = None + + if device_capability == 9: + # H100 configs + configs.append( + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "USE_TLX": True, + "NUM_BUFFERS": 2, + "NUM_MMA_WARPS_PER_GROUP": 4, + "NUM_MMA_GROUPS": 2, + }, + num_stages=0, + num_warps=4, + pre_hook=_host_descriptor_pre_hook, + ), + ) + + return configs + + +@triton.jit +def _hstu_attn_fwd_one_block( # noqa: C901 + start_n, + seq_len, + offs_m, + offs_n, + q, + K, + V, + K_block_ptr, + V_block_ptr, + offset_kh, + offset_vh, + seq_start, + n_targets, + alpha, + MAX_SEQ_LEN, + contextual_seq_len, + max_attn_len, + HAS_MULTIPLE_TARGETS: tl.constexpr, + HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr, + HAS_MAX_ATTN_LEN: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BLOCK_D_Q: tl.constexpr, + BLOCK_D_V: tl.constexpr, + BLOCK_N: tl.constexpr, + ENABLE_TMA: tl.constexpr, +): + start_n = tl.multiple_of(start_n, 
BLOCK_N) + # -- compute qk ---- + k = None + qk = None + if ENABLE_TMA: + k = K.load( + [(seq_start + start_n).to(tl.int32), offset_kh.to(tl.int32)], + ) + # tma can only be loaded in one order, use trans afterwards + qk = tl.dot(q, tl.trans(k), allow_tf32=ALLOW_TF32) * alpha + else: + k = tl.load(K_block_ptr, boundary_check=(1,), padding_option="zero") + qk = tl.dot(q, k, allow_tf32=ALLOW_TF32) * alpha + invalid_mask = offs_m[:, None] == offs_n[None, :] + max_ids = seq_len + if HAS_CONTEXTUAL_SEQ_LEN: + offs_m = offs_m - contextual_seq_len + 1 + offs_m = tl.where( + offs_m > 0, + offs_m, + 0, + ) + offs_n = offs_n - contextual_seq_len + 1 + offs_n = tl.where( + offs_n > 0, + offs_n, + 0, + ) + max_ids = max_ids - contextual_seq_len + 1 + if HAS_MULTIPLE_TARGETS: + max_ids = max_ids - n_targets + offs_m = tl.where( + offs_m < max_ids, + offs_m, + max_ids, + ) + offs_n = tl.where( + offs_n < max_ids, + offs_n, + max_ids, + ) + offs_m_minus_n = offs_m[:, None] - offs_n[None, :] + invalid_mask = invalid_mask or (offs_m_minus_n > 0) + if HAS_MAX_ATTN_LEN: + invalid_mask = invalid_mask and offs_m_minus_n <= max_attn_len + if HAS_CONTEXTUAL_SEQ_LEN: + invalid_mask = invalid_mask or ( + offs_m[:, None] == 0 and offs_n[None, :] < max_ids + ) + scale = tl.where(invalid_mask, (1.0 / MAX_SEQ_LEN), 0.0) + silu = fast_dividef(qk, 1.0 + fast_expf(-qk)) * scale + v = None + if ENABLE_TMA: + v = V.load( + [(seq_start + start_n).to(tl.int32), offset_vh.to(tl.int32)], + ) + else: + v = tl.load(V_block_ptr, boundary_check=(0,), padding_option="zero") + silu = silu.to(v.dtype) + return tl.dot(silu, v, allow_tf32=ALLOW_TF32) + + +@triton.jit +def _hstu_attn_fwd_compute( # noqa C901 + Q, + K, + V, + H, + DimQ, + DimV, + workspace_ptr, + seq_offsets, + num_targets, + Out, + stride_qm, + stride_qh, + stride_kn, + stride_kh, + stride_vn, + stride_vh, + stride_om, + stride_oh, + alpha, + MAX_SEQ_LEN, + DeltaSize, + contextual_seq_len, + max_attn_len, + off_z, + off_h, + pid, + HAS_MULTIPLE_TARGETS: tl.constexpr, + IS_DELTA_Q: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BLOCK_D_Q: tl.constexpr, + BLOCK_D_V: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr, + HAS_MAX_ATTN_LEN: tl.constexpr, + ENABLE_TMA: tl.constexpr, + TMA_DESC_SIZE: tl.constexpr, +): + seq_start = tl.load(seq_offsets + off_z).to(tl.int64) + off_h = off_h.to(tl.int64) + off_z = off_z.to(tl.int64) + seq_end = tl.load(seq_offsets + off_z + 1) + seq_len = (seq_end - seq_start).to(tl.int32) + + if IS_DELTA_Q: + start_m_delta = pid * BLOCK_M + start_m = (start_m_delta + seq_len - DeltaSize).to(tl.int32) + else: + start_m_delta = 0 + start_m = pid * BLOCK_M + if start_m < seq_len: + if HAS_MULTIPLE_TARGETS: + n_targets = tl.load(num_targets + off_z).to(tl.int32) + else: + n_targets = None + + # initialize offsets + offs_m = start_m + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + Q_block_ptr = None + K_block_ptr = None + V_block_ptr = None + if not ENABLE_TMA: + if IS_DELTA_Q: + Q_block_ptr = tl.make_block_ptr( + base=Q + off_h * stride_qh + off_z * DeltaSize * stride_qm, + shape=(DeltaSize, BLOCK_D_Q), + strides=(stride_qm, 1), + offsets=(start_m_delta, 0), + block_shape=(BLOCK_M, BLOCK_D_Q), + order=(1, 0), + ) + else: + Q_block_ptr = tl.make_block_ptr( + base=Q + off_h * stride_qh + seq_start * stride_qm, + shape=(seq_len, BLOCK_D_Q), + strides=(stride_qm, 1), + offsets=(start_m, 0), + block_shape=(BLOCK_M, BLOCK_D_Q), + order=(1, 0), + ) + q = tl.load(Q_block_ptr, boundary_check=(0,), 
padding_option="zero") + + K_block_ptr = tl.make_block_ptr( + base=K + off_h * stride_kh + seq_start * stride_kn, + shape=(BLOCK_D_Q, seq_len), + strides=(1, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_D_Q, BLOCK_N), + order=(0, 1), + ) + V_block_ptr = tl.make_block_ptr( + base=V + off_h * stride_vh + seq_start * stride_vn, + shape=(seq_len, BLOCK_D_V), + strides=(stride_vn, 1), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_D_V), + order=(1, 0), + ) + else: + if IS_DELTA_Q: + q = Q.load( + [ + (off_z * DeltaSize + start_m_delta).to(tl.int32), + (off_h * stride_qh).to(tl.int32), + ] + ) + else: + q = Q.load( + [ + (seq_start + start_m).to(tl.int32), + (off_h * stride_qh).to(tl.int32), + ] + ) + + acc = tl.zeros([BLOCK_M, BLOCK_D_V], dtype=tl.float32) + if HAS_MULTIPLE_TARGETS: + uih_end = seq_len - n_targets + else: + uih_end = seq_len + if HAS_CONTEXTUAL_SEQ_LEN is True and start_m < contextual_seq_len: + # uih_end must be larger than start_m + low = 0 + high = seq_len + else: + low = 0 + high = start_m + BLOCK_M + if HAS_MAX_ATTN_LEN: + if start_m > uih_end: + low = uih_end - max_attn_len + else: + low = start_m - max_attn_len + if HAS_CONTEXTUAL_SEQ_LEN: + low = low if low > contextual_seq_len else 0 + else: + low = low if low > 0 else 0 + if HAS_MULTIPLE_TARGETS: + uih_end = (uih_end + BLOCK_N - 1) // BLOCK_N * BLOCK_N + if uih_end < start_m: + high = seq_len - n_targets + + if low > 0: + if not ENABLE_TMA: + K_block_ptr = tl.advance(K_block_ptr, (0, low)) + V_block_ptr = tl.advance(V_block_ptr, (low, 0)) + end_n = low + for start_n in range(low, high, BLOCK_N): + acc += _hstu_attn_fwd_one_block( + start_n=start_n, + seq_len=seq_len, + offs_m=offs_m, + offs_n=offs_n + start_n, + q=q, + K=K, + V=V, + K_block_ptr=K_block_ptr, + V_block_ptr=V_block_ptr, + offset_kh=off_h * stride_kh, + offset_vh=off_h * stride_vh, + seq_start=seq_start, + n_targets=n_targets if HAS_MULTIPLE_TARGETS else None, + alpha=alpha, + MAX_SEQ_LEN=MAX_SEQ_LEN, + contextual_seq_len=contextual_seq_len, + max_attn_len=max_attn_len, + HAS_MULTIPLE_TARGETS=HAS_MULTIPLE_TARGETS, + HAS_CONTEXTUAL_SEQ_LEN=HAS_CONTEXTUAL_SEQ_LEN, + HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN, + ALLOW_TF32=ALLOW_TF32, + BLOCK_D_Q=BLOCK_D_Q, + BLOCK_D_V=BLOCK_D_V, + BLOCK_N=BLOCK_N, + ENABLE_TMA=ENABLE_TMA, + ) + if not ENABLE_TMA: + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + end_n += BLOCK_N + + if HAS_MULTIPLE_TARGETS: + # pyre-ignore[61] + if uih_end < start_m: + low_delta = start_m + high_delta = start_m + BLOCK_M + offset = (low_delta - end_n).to(tl.int32) + if not ENABLE_TMA: + K_block_ptr = tl.advance(K_block_ptr, (0, offset)) + V_block_ptr = tl.advance(V_block_ptr, (offset, 0)) + for start_delta in tl.range( + low_delta, high_delta, BLOCK_N, num_stages=0 + ): + acc += _hstu_attn_fwd_one_block( + start_n=start_delta, + seq_len=seq_len, + offs_m=offs_m, + offs_n=offs_n + start_delta, + q=q, + K=K, + V=V, + K_block_ptr=K_block_ptr, + V_block_ptr=V_block_ptr, + offset_kh=off_h * stride_kh, + offset_vh=off_h * stride_vh, + seq_start=seq_start, + n_targets=n_targets if HAS_MULTIPLE_TARGETS else None, + alpha=alpha, + MAX_SEQ_LEN=MAX_SEQ_LEN, + contextual_seq_len=contextual_seq_len, + max_attn_len=max_attn_len, + HAS_MULTIPLE_TARGETS=HAS_MULTIPLE_TARGETS, + HAS_CONTEXTUAL_SEQ_LEN=HAS_CONTEXTUAL_SEQ_LEN, + HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN, + ALLOW_TF32=ALLOW_TF32, + BLOCK_D_Q=BLOCK_D_Q, + BLOCK_D_V=BLOCK_D_V, + BLOCK_N=BLOCK_N, + ENABLE_TMA=ENABLE_TMA, + ) + if not ENABLE_TMA: + 
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + # Don't use TMA in Jagged case since we don't want to overwrite + # the output of another sequence + if IS_DELTA_Q: + start_m_delta = pid * BLOCK_M + offs_m_delta = start_m_delta + tl.arange(0, BLOCK_M) + offs_v_d = tl.arange(0, BLOCK_D_V) + off_o = Out + off_z * DeltaSize * stride_om + off_h * stride_oh + out_ptrs = off_o + offs_m_delta[:, None] * stride_om + offs_v_d[None, :] + tl.store(out_ptrs, acc, mask=(offs_m_delta < DeltaSize)[:, None]) + else: + # rematerialize offsets to save registers + start_m = pid * BLOCK_M + offs_m = start_m + tl.arange(0, BLOCK_M) + offs_v_d = tl.arange(0, BLOCK_D_V) + off_o = Out + seq_start * stride_om + off_h * stride_oh + out_ptrs = off_o + offs_m[:, None] * stride_om + offs_v_d[None, :] + tl.store(out_ptrs, acc, mask=(offs_m < seq_len)[:, None]) + + +@triton.jit +def _hstu_attn_fwd_compute_main_loop_tlx( # noqa C901 + low, + high, + seq_len, + offs_m, + offs_n, + acc, + q_tiles, + k_tiles, + v_tiles, + q_fulls, + k_fulls, + v_fulls, + k_empties, + v_empties, + v_dtype, + n_targets, + alpha, + end_n, + loop_trip_cnt, + max_attn_len, + HAS_MULTIPLE_TARGETS: tl.constexpr, + HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr, + HAS_MAX_ATTN_LEN: tl.constexpr, + cid: tl.constexpr, + BLOCK_N: tl.constexpr, + NUM_BUFFERS: tl.constexpr, + MAX_SEQ_LEN: tl.constexpr, + WAIT_FOR_Q: tl.constexpr, +): + if WAIT_FOR_Q: + # wait for the Q buffer to be populated by the producer + q_full = tlx.local_view(q_fulls, cid) + tlx.barrier_wait(q_full, 0) + + q_tile = tlx.local_view(q_tiles, cid) + + for start in tl.range(low + BLOCK_N, high, BLOCK_N, num_stages=0): + buf_id = loop_trip_cnt % NUM_BUFFERS + # buffers in a row share the same phase + kv_phase = (loop_trip_cnt // NUM_BUFFERS) % 2 + + start_n = tl.multiple_of(start, BLOCK_N) + offs_n_start = offs_n + offs_n = offs_n_start + start_n + + # wait for the K buffer to be populated by the producer + k_full = tlx.local_view(k_fulls, buf_id) + tlx.barrier_wait(k_full, kv_phase) + k_tile = tlx.local_view(k_tiles, buf_id) + + # tma can only be loaded in one order, use trans afterwards + k_tile = tlx.local_trans(k_tile) + # second + qk = tlx.async_dot(q_tile, k_tile) + # wait for the MMA using to complete + qk = tlx.async_dot_wait(0, qk) + # release the K buffer + k_empty = tlx.local_view(k_empties, buf_id) + tlx.barrier_arrive(k_empty, 1) + + qk = qk * alpha + + invalid_mask = offs_m[:, None] == offs_n[None, :] + max_ids = seq_len + if HAS_MULTIPLE_TARGETS: + max_ids = max_ids - n_targets + offs_m = tl.where( + offs_m < max_ids, + offs_m, + max_ids, + ) + offs_n = tl.where( + offs_n < max_ids, + offs_n, + max_ids, + ) + offs_m_minus_n = offs_m[:, None] - offs_n[None, :] + invalid_mask = invalid_mask or (offs_m_minus_n > 0) + if HAS_MAX_ATTN_LEN: + invalid_mask = invalid_mask and offs_m_minus_n <= max_attn_len + if HAS_CONTEXTUAL_SEQ_LEN: + invalid_mask = invalid_mask or ( + offs_m[:, None] == 0 and offs_n[None, :] < max_ids + ) + scale = tl.where(invalid_mask, (1.0 / MAX_SEQ_LEN), 0.0) + silu = fast_dividef(qk, 1.0 + fast_expf(-qk)) * scale + silu = silu.to(v_dtype) + + # wait for the V buffer to be populated by the producer + v_full = tlx.local_view(v_fulls, buf_id) + tlx.barrier_wait(v_full, kv_phase) + v_tile = tlx.local_view(v_tiles, buf_id) + acc = tlx.async_dot(silu, v_tile, acc) + # wait for the MMA using to complete + acc = tlx.async_dot_wait(0, acc) + # release the V buffer + v_empty = tlx.local_view(v_empties, buf_id) + 
tlx.barrier_arrive(v_empty, 1) + + end_n += BLOCK_N + + # increment loop trip counts + loop_trip_cnt += 1 + + return acc, end_n, loop_trip_cnt + + +@triton.jit +def _hstu_attn_fwd_compute_main_loop_tlx_pipelined( # noqa C901 + low, + high, + seq_len, + offs_m, + offs_n, + acc, + q_tiles, + k_tiles, + v_tiles, + q_fulls, + k_fulls, + v_fulls, + k_empties, + v_empties, + v_dtype, + n_targets, + alpha, + end_n, + loop_trip_cnt, + max_attn_len, + HAS_MULTIPLE_TARGETS: tl.constexpr, + HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr, + HAS_MAX_ATTN_LEN: tl.constexpr, + cid: tl.constexpr, + BLOCK_N: tl.constexpr, + NUM_BUFFERS: tl.constexpr, + MAX_SEQ_LEN: tl.constexpr, + WAIT_FOR_Q: tl.constexpr, +): + if WAIT_FOR_Q: + # wait for the Q buffer to be populated by the producer + q_full = tlx.local_view(q_fulls, cid) + tlx.barrier_wait(q_full, 0) + q_tile = tlx.local_view(q_tiles, cid) + + # wait for the K buffer to be populated by the producer + k_buf_id = loop_trip_cnt % NUM_BUFFERS + # buffers in a row share the same phase + k_phase = (loop_trip_cnt // NUM_BUFFERS) % 2 + + k_full = tlx.local_view(k_fulls, k_buf_id) + tlx.barrier_wait(k_full, k_phase) + k_tile = tlx.local_view(k_tiles, k_buf_id) + + # tma can only be loaded in one order, use trans afterwards + k_tile = tlx.local_trans(k_tile) + + # Pingpong + if cid == 0: + # Consumer 0 waits for Consumer 1 to reach synchronization point at barrier 9. + tlx.named_barrier_wait(9, 256) + else: + # Consumer 1 signals its arrival at barrier 9. + tlx.named_barrier_arrive(9, 256) + # Then waits at barrier 10 until Consumer 0 finishes issuing its async_dot. + tlx.named_barrier_wait(10, 256) + + qk = tlx.async_dot(q_tile, k_tile) + + if cid == 0: + # After issuing async_dot, Consumer 0 signals barrier 10 to unblock Consumer 1. 
+ tlx.named_barrier_arrive(10, 256) + + # wait for the MMA using to complete + qk = tlx.async_dot_wait(0, qk) + # release the K buffer + k_empty = tlx.local_view(k_empties, k_buf_id) + tlx.barrier_arrive(k_empty, 1) + + qk = qk * alpha + + start_n = tl.multiple_of(low, BLOCK_N) + offs_n_start = offs_n + offs_n = offs_n_start + start_n + + invalid_mask = offs_m[:, None] == offs_n[None, :] + max_ids = seq_len + if HAS_MULTIPLE_TARGETS: + max_ids = max_ids - n_targets + offs_m = tl.where( + offs_m < max_ids, + offs_m, + max_ids, + ) + offs_n = tl.where( + offs_n < max_ids, + offs_n, + max_ids, + ) + offs_m_minus_n = offs_m[:, None] - offs_n[None, :] + invalid_mask = invalid_mask or (offs_m_minus_n > 0) + if HAS_MAX_ATTN_LEN: + invalid_mask = invalid_mask and offs_m_minus_n <= max_attn_len + if HAS_CONTEXTUAL_SEQ_LEN: + invalid_mask = invalid_mask or ( + offs_m[:, None] == 0 and offs_n[None, :] < max_ids + ) + scale = tl.where(invalid_mask, (1.0 / MAX_SEQ_LEN), 0.0) + silu = fast_dividef(qk, 1.0 + fast_expf(-qk)) * scale + silu = silu.to(v_dtype) + + loop_trip_cnt += 1 + + for start in tl.range(low + BLOCK_N, high, BLOCK_N, num_stages=0): + start_n = tl.multiple_of(start, BLOCK_N) + offs_n = offs_n_start + start_n + + k_buf_id = loop_trip_cnt % NUM_BUFFERS + # buffers in a row share the same phase + k_phase = k_phase ^ (k_buf_id == 0) + + # wait for the K buffer to be populated by the producer + k_full = tlx.local_view(k_fulls, k_buf_id) + tlx.barrier_wait(k_full, k_phase) + k_tile = tlx.local_view(k_tiles, k_buf_id) + + # tma can only be loaded in one order, use trans afterwards + k_tile = tlx.local_trans(k_tile) + + qk = tlx.async_dot(q_tile, k_tile) + # wait for the MMA using to complete + prev_silu = silu + + v_buf_id = (loop_trip_cnt - 1) % NUM_BUFFERS + # v_phase = v_phase ^ (v_buf_id == 0) + v_phase = ((loop_trip_cnt - 1) // NUM_BUFFERS) % 2 + v_full = tlx.local_view(v_fulls, v_buf_id) + tlx.barrier_wait(v_full, v_phase) + v_tile = tlx.local_view(v_tiles, v_buf_id) + acc = tlx.async_dot(prev_silu, v_tile, acc) + qk = tlx.async_dot_wait(1, qk) + + # release the K buffer + k_empty = tlx.local_view(k_empties, k_buf_id) + tlx.barrier_arrive(k_empty, 1) + + qk = qk * alpha + invalid_mask = offs_m[:, None] == offs_n[None, :] + max_ids = seq_len + if HAS_MULTIPLE_TARGETS: + max_ids = max_ids - n_targets + offs_m = tl.where( + offs_m < max_ids, + offs_m, + max_ids, + ) + offs_n = tl.where( + offs_n < max_ids, + offs_n, + max_ids, + ) + offs_m_minus_n = offs_m[:, None] - offs_n[None, :] + invalid_mask = invalid_mask or (offs_m_minus_n > 0) + if HAS_MAX_ATTN_LEN: + invalid_mask = invalid_mask and offs_m_minus_n <= max_attn_len + if HAS_CONTEXTUAL_SEQ_LEN: + invalid_mask = invalid_mask or ( + offs_m[:, None] == 0 and offs_n[None, :] < max_ids + ) + scale = tl.where(invalid_mask, (1.0 / MAX_SEQ_LEN), 0.0) + silu = fast_dividef(qk, 1.0 + fast_expf(-qk)) * scale + silu = silu.to(v_dtype) + + acc = tlx.async_dot_wait(0, acc) + # release the V buffer + v_empty = tlx.local_view(v_empties, v_buf_id) + tlx.barrier_arrive(v_empty, 1) + + end_n += BLOCK_N + + # increment loop trip counts + loop_trip_cnt += 1 + # v_buf_id = loop_trip_cnt % NUM_BUFFERS + # v_phase = (loop_trip_cnt // NUM_BUFFERS) % 2 + + # wait for the V buffer to be populated by the producer + v_buf_id = (loop_trip_cnt - 1) % NUM_BUFFERS + v_phase = ((loop_trip_cnt - 1) // NUM_BUFFERS) % 2 + v_full = tlx.local_view(v_fulls, v_buf_id) + # tlx.barrier_wait(v_full, v_buf_id) + v_tile = tlx.local_view(v_tiles, v_buf_id) + 
tlx.barrier_wait(v_full, v_phase) + acc = tlx.async_dot(silu, v_tile, acc) + acc = tlx.async_dot_wait(0, acc) + # release the V buffer + v_empty = tlx.local_view(v_empties, v_buf_id) + tlx.barrier_arrive(v_empty, 1) + + return acc, end_n, loop_trip_cnt + + +@triton.jit +def _hstu_attn_fwd_load_K_or_V( + K, + k_tiles, + k_empties, + k_fulls, + buf_id, + k_phase, + start_n, + seq_start, + offset_kh, + BLOCK_D_Q: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # wait for the K buffer to be released by the consumer + k_empty = tlx.local_view(k_empties, buf_id) + tlx.barrier_wait(k_empty, k_phase) + # load K + k_full = tlx.local_view(k_fulls, buf_id) + k_tile = tlx.local_view(k_tiles, buf_id) + tlx.barrier_expect_bytes(k_full, 2 * BLOCK_N * BLOCK_D_Q) # float16 + tlx.async_descriptor_load( + K, + k_tile, + [(seq_start + start_n).to(tl.int32), offset_kh.to(tl.int32)], + k_full, + ) + + +@triton.jit +def _hstu_attn_fwd_load_Q( + Q, + q_tiles, + q_fulls, + cid, + off_z, + off_h, + stride_qh, + start_m, + seq_start, + DeltaSize, + IS_DELTA_Q: tl.constexpr, + BLOCK_D_Q: tl.constexpr, + BLOCK_M: tl.constexpr, +): + q_full = tlx.local_view(q_fulls, cid) + tlx.barrier_expect_bytes(q_full, 2 * BLOCK_M * BLOCK_D_Q) # float16 + q_tile = tlx.local_view(q_tiles, cid) + seq_offset = start_m + cid * BLOCK_M + if IS_DELTA_Q: + tlx.async_descriptor_load( + Q, + q_tile, + [ + (off_z * DeltaSize + start_m).to(tl.int32), + (off_h * stride_qh).to(tl.int32), + ], + q_full, + ) + else: + tlx.async_descriptor_load( + Q, + q_tile, + [ + (seq_start + seq_offset).to(tl.int32), + (off_h * stride_qh).to(tl.int32), + ], + q_full, + ) + + +@triton.jit +def _hstu_attn_fwd_caculate_range( + seq_len, + start_m, + n_targets, + contextual_seq_len, + max_attn_len, + HAS_MULTIPLE_TARGETS: tl.constexpr, + HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr, + HAS_MAX_ATTN_LEN: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + if HAS_MULTIPLE_TARGETS: + uih_end = seq_len - n_targets + else: + uih_end = seq_len + + if HAS_CONTEXTUAL_SEQ_LEN is True and start_m < contextual_seq_len: + # uih_end must be larger than start_m + low = 0 + high = seq_len + else: + low = 0 + high = start_m + BLOCK_M + if HAS_MAX_ATTN_LEN: + if start_m > uih_end: + low = uih_end - max_attn_len + else: + low = start_m - max_attn_len + if HAS_CONTEXTUAL_SEQ_LEN: + low = low if low > contextual_seq_len else 0 + else: + low = low if low > 0 else 0 + if HAS_MULTIPLE_TARGETS: + uih_end = (uih_end + BLOCK_N - 1) // BLOCK_N * BLOCK_N + if uih_end < start_m: + high = seq_len - n_targets + + return low, high, uih_end + + +@triton.jit +def _hstu_attn_fwd_load_Q_K_V( + Q, + K, + V, + q_tiles, + k_tiles, + v_tiles, + q_fulls, + k_fulls, + v_fulls, + k_empties, + v_empties, + stride_qh, + stride_kh, + stride_vh, + contextual_seq_len, + max_attn_len, + DeltaSize, + off_z, + off_h, + start_m, + seq_start, + seq_len, + n_targets, + HAS_MULTIPLE_TARGETS: tl.constexpr, + IS_DELTA_Q: tl.constexpr, + BLOCK_D_Q: tl.constexpr, + BLOCK_D_V: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + NUM_BUFFERS: tl.constexpr, + NUM_MMA_GROUPS: tl.constexpr, + HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr, + HAS_MAX_ATTN_LEN: tl.constexpr, +): + # load q: it will stay in SRAM throughout + BLOCK_M_SPLIT: tl.constexpr = BLOCK_M // NUM_MMA_GROUPS + + _hstu_attn_fwd_load_Q( + Q=Q, + q_tiles=q_tiles, + q_fulls=q_fulls, + cid=0, + off_z=off_z, + off_h=off_h, + stride_qh=stride_qh, + start_m=start_m, + seq_start=seq_start, + DeltaSize=DeltaSize, + IS_DELTA_Q=IS_DELTA_Q, + BLOCK_D_Q=BLOCK_D_Q, 
+ BLOCK_M=BLOCK_M_SPLIT, + ) + + off_h = off_h.to(tl.int64) + off_z = off_z.to(tl.int64) + offset_kh = off_h * stride_kh + offset_vh = off_h * stride_vh + + low, high, uih_end = _hstu_attn_fwd_caculate_range( + seq_len, + start_m, + n_targets, + contextual_seq_len, + max_attn_len, + HAS_MULTIPLE_TARGETS, + HAS_CONTEXTUAL_SEQ_LEN, + HAS_MAX_ATTN_LEN, + BLOCK_M, + BLOCK_N, + ) + + kv_phase = 0 + loop_trip_cnt = 0 + + # pyre-ignore[58] + buf_id = loop_trip_cnt % NUM_BUFFERS + # buffers in a row share the same phase + kv_phase = kv_phase ^ (buf_id == 0) + + start_n = tl.multiple_of(low, BLOCK_N) + + _hstu_attn_fwd_load_K_or_V( + K, + k_tiles, + k_empties, + k_fulls, + buf_id, + kv_phase, + start_n, + seq_start, + offset_kh, + BLOCK_D_Q, + BLOCK_N, + ) + + for cid in tl.range(1, NUM_MMA_GROUPS, loop_unroll_factor=NUM_MMA_GROUPS - 1): + _hstu_attn_fwd_load_Q( + Q, + q_tiles, + q_fulls, + cid, + off_z, + off_h, + stride_qh, + start_m, + seq_start, + DeltaSize, + IS_DELTA_Q, + BLOCK_D_Q, + BLOCK_M_SPLIT, + ) + + _hstu_attn_fwd_load_K_or_V( + V, + v_tiles, + v_empties, + v_fulls, + buf_id, + kv_phase, + start_n, + seq_start, + offset_vh, + BLOCK_D_V, + BLOCK_N, + ) + + loop_trip_cnt += 1 + + for start in range(low + BLOCK_N, high, BLOCK_N): + # pyre-ignore[58] + buf_id = loop_trip_cnt % NUM_BUFFERS + # buffers in a row share the same phase + kv_phase = kv_phase ^ (buf_id == 0) + + start_n = tl.multiple_of(start, BLOCK_N) + + _hstu_attn_fwd_load_K_or_V( + K, + k_tiles, + k_empties, + k_fulls, + buf_id, + kv_phase, + start_n, + seq_start, + offset_kh, + BLOCK_D_Q, + BLOCK_N, + ) + + _hstu_attn_fwd_load_K_or_V( + V, + v_tiles, + v_empties, + v_fulls, + buf_id, + kv_phase, + start_n, + seq_start, + offset_vh, + BLOCK_D_V, + BLOCK_N, + ) + + # increment loop trip counts + loop_trip_cnt += 1 + + # pyre-ignore[61] + if uih_end < start_m: + low_delta = start_m + high_delta = start_m + BLOCK_M + for start_delta in tl.range(low_delta, high_delta, BLOCK_N, num_stages=0): + # pyre-ignore[58] + buf_id = loop_trip_cnt % NUM_BUFFERS + # buffers in a row share the same phase + kv_phase = kv_phase ^ (buf_id == 0) + + start_n = tl.multiple_of(start_delta, BLOCK_N) + + _hstu_attn_fwd_load_K_or_V( + K, + k_tiles, + k_empties, + k_fulls, + buf_id, + kv_phase, + start_n, + seq_start, + offset_kh, + BLOCK_D_Q, + BLOCK_N, + ) + + _hstu_attn_fwd_load_K_or_V( + V, + v_tiles, + v_empties, + v_fulls, + buf_id, + kv_phase, + start_n, + seq_start, + offset_vh, + BLOCK_D_V, + BLOCK_N, + ) + + # increment loop trip counts + loop_trip_cnt += 1 + + +@triton.jit +def _hstu_attn_fwd_compute_tlx( # noqa C901 + Q, + K, + V, + H, + DimQ, + DimV, + seq_offsets, + num_targets, + Out, + stride_qh, + stride_kh, + stride_vh, + stride_om, + stride_oh, + alpha, + MAX_SEQ_LEN, + DeltaSize, + contextual_seq_len, + max_attn_len, + off_z, + off_h, + pid, + HAS_MULTIPLE_TARGETS: tl.constexpr, + IS_DELTA_Q: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BLOCK_D_Q: tl.constexpr, + BLOCK_D_V: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + NUM_BUFFERS: tl.constexpr, # + NUM_MMA_WARPS_PER_GROUP: tl.constexpr, # + NUM_MMA_GROUPS: tl.constexpr, # + HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr, + HAS_MAX_ATTN_LEN: tl.constexpr, +): + seq_start = tl.load(seq_offsets + off_z).to(tl.int64) + seq_end = tl.load(seq_offsets + off_z + 1) + seq_len = (seq_end - seq_start).to(tl.int32) + + if IS_DELTA_Q: + start_m = pid * BLOCK_M + start_m = (start_m + seq_len - DeltaSize).to(tl.int32) + else: + start_m = pid * BLOCK_M + + if start_m >= seq_len: + return + 
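+ # Warp-specialized (TLX) layout: the code below allocates NUM_MMA_GROUPS Q tile
+ # buffers and NUM_BUFFERS circular K/V tile buffers, plus "full"/"empty" barriers
+ # that hand tiles between tasks. The "default" async task is the producer: it
+ # issues async descriptor (TMA) loads of Q, K, and V. The NUM_MMA_GROUPS
+ # replicated tasks are consumers: each waits on a full barrier, accumulates
+ # silu(q @ k^T) @ v for its BLOCK_M_SPLIT slice of the query block, then arrives
+ # at the matching empty barrier so the producer can reuse that K/V buffer.
+ # kv_phase flips every time buf_id wraps back to 0, which is how barrier_wait
+ # tells successive generations of the same circular buffer apart.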
+ if HAS_MULTIPLE_TARGETS: + n_targets = tl.load(num_targets + off_z).to(tl.int32) + else: + n_targets = None + + BLOCK_M_SPLIT: tl.constexpr = BLOCK_M // NUM_MMA_GROUPS + # allocate buffers + q_tiles = tlx.local_alloc( + (BLOCK_M_SPLIT, BLOCK_D_Q), tlx.dtype_of(Q), NUM_MMA_GROUPS + ) + k_tiles = tlx.local_alloc((BLOCK_N, BLOCK_D_Q), tlx.dtype_of(K), NUM_BUFFERS) + v_tiles = tlx.local_alloc((BLOCK_N, BLOCK_D_V), tlx.dtype_of(V), NUM_BUFFERS) + + # allocate barriers + q_fulls = tlx.alloc_barriers(num_barriers=NUM_MMA_GROUPS, arrive_count=1) + k_empties = tlx.alloc_barriers( + num_barriers=NUM_BUFFERS, arrive_count=NUM_MMA_GROUPS + ) + k_fulls = tlx.alloc_barriers(num_barriers=NUM_BUFFERS, arrive_count=1) + v_empties = tlx.alloc_barriers( + num_barriers=NUM_BUFFERS, arrive_count=NUM_MMA_GROUPS + ) + v_fulls = tlx.alloc_barriers(num_barriers=NUM_BUFFERS, arrive_count=1) + + with tlx.async_tasks(): + # producer group + with tlx.async_task("default"): + _hstu_attn_fwd_load_Q_K_V( + Q=Q, + K=K, + V=V, + q_tiles=q_tiles, + k_tiles=k_tiles, + v_tiles=v_tiles, + q_fulls=q_fulls, + k_fulls=k_fulls, + v_fulls=v_fulls, + k_empties=k_empties, + v_empties=v_empties, + stride_qh=stride_qh, + stride_kh=stride_kh, + stride_vh=stride_vh, + contextual_seq_len=contextual_seq_len, + max_attn_len=max_attn_len, + DeltaSize=DeltaSize, + off_z=off_z, + off_h=off_h, + start_m=start_m, + seq_start=seq_start, + seq_len=seq_len, + n_targets=n_targets, + HAS_MULTIPLE_TARGETS=HAS_MULTIPLE_TARGETS, + IS_DELTA_Q=IS_DELTA_Q, + BLOCK_D_Q=BLOCK_D_Q, + BLOCK_D_V=BLOCK_D_V, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + NUM_BUFFERS=NUM_BUFFERS, + NUM_MMA_GROUPS=NUM_MMA_GROUPS, + HAS_CONTEXTUAL_SEQ_LEN=HAS_CONTEXTUAL_SEQ_LEN, + HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN, + ) + + # consumer groups + with tlx.async_task( + num_warps=NUM_MMA_WARPS_PER_GROUP, registers=232, replicate=NUM_MMA_GROUPS + ): + cid = tlx.async_task_replica_id() + acc = tl.zeros([BLOCK_M_SPLIT, BLOCK_D_V], dtype=tl.float32) + # initialize offsets + offs_m = start_m + tl.arange(0, BLOCK_M_SPLIT) + cid * BLOCK_M_SPLIT + offs_n = tl.arange(0, BLOCK_N) + + low, high, uih_end = _hstu_attn_fwd_caculate_range( + seq_len, + start_m, + n_targets, + contextual_seq_len, + max_attn_len, + HAS_MULTIPLE_TARGETS, + HAS_CONTEXTUAL_SEQ_LEN, + HAS_MAX_ATTN_LEN, + BLOCK_M, + BLOCK_N, + ) + + end_n = low + loop_trip_cnt = 0 + + acc, end_n, loop_trip_cnt = _hstu_attn_fwd_compute_main_loop_tlx_pipelined( + low=low, + high=high, + seq_len=seq_len, + offs_m=offs_m, + offs_n=offs_n, + acc=acc, + q_tiles=q_tiles, + k_tiles=k_tiles, + v_tiles=v_tiles, + q_fulls=q_fulls, + k_fulls=k_fulls, + v_fulls=v_fulls, + k_empties=k_empties, + v_empties=v_empties, + v_dtype=tlx.dtype_of(V), + n_targets=n_targets, + alpha=alpha, + end_n=end_n, + loop_trip_cnt=loop_trip_cnt, + max_attn_len=max_attn_len, + HAS_MULTIPLE_TARGETS=HAS_MULTIPLE_TARGETS, + HAS_CONTEXTUAL_SEQ_LEN=HAS_CONTEXTUAL_SEQ_LEN, + HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN, + cid=cid, + BLOCK_N=BLOCK_N, + NUM_BUFFERS=NUM_BUFFERS, + MAX_SEQ_LEN=MAX_SEQ_LEN, + WAIT_FOR_Q=1, + ) + + # pyre-ignore[61] + if uih_end < start_m: + low_delta = start_m + high_delta = start_m + BLOCK_M + acc, end_n, loop_trip_cnt = _hstu_attn_fwd_compute_main_loop_tlx( + low=low_delta, + high=high_delta, + seq_len=seq_len, + offs_m=offs_m, + offs_n=offs_n, + acc=acc, + q_tiles=q_tiles, + k_tiles=k_tiles, + v_tiles=v_tiles, + q_fulls=q_fulls, + k_fulls=k_fulls, + v_fulls=v_fulls, + k_empties=k_empties, + v_empties=v_empties, + v_dtype=tlx.dtype_of(V), + n_targets=n_targets, + 
alpha=alpha, + end_n=end_n, + loop_trip_cnt=loop_trip_cnt, + max_attn_len=max_attn_len, + HAS_MULTIPLE_TARGETS=HAS_MULTIPLE_TARGETS, + HAS_CONTEXTUAL_SEQ_LEN=HAS_CONTEXTUAL_SEQ_LEN, + HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN, + cid=cid, + BLOCK_N=BLOCK_N, + NUM_BUFFERS=NUM_BUFFERS, + MAX_SEQ_LEN=MAX_SEQ_LEN, + WAIT_FOR_Q=0, + ) + + # Don't use TMA in Jagged case since we don't want to overwrite + # the output of another sequence + if IS_DELTA_Q: + start_m_delta = pid * BLOCK_M + cid * BLOCK_M_SPLIT + offs_m_delta = start_m_delta + tl.arange(0, BLOCK_M_SPLIT) + offs_v_d = tl.arange(0, BLOCK_D_V) + off_o = Out + off_z * DeltaSize * stride_om + off_h * stride_oh + out_ptrs = off_o + offs_m_delta[:, None] * stride_om + offs_v_d[None, :] + tl.store(out_ptrs, acc, mask=(offs_m_delta < DeltaSize)[:, None]) + else: + # rematerialize offsets to save registers + start_m = pid * BLOCK_M + cid * BLOCK_M_SPLIT + offs_m = start_m + tl.arange(0, BLOCK_M_SPLIT) + offs_v_d = tl.arange(0, BLOCK_D_V) + off_o = Out + seq_start * stride_om + off_h * stride_oh + out_ptrs = off_o + offs_m[:, None] * stride_om + offs_v_d[None, :] + tl.store(out_ptrs, acc, mask=(offs_m < seq_len)[:, None]) + + +@triton_autotune( + configs=_get_fw_configs(), + key=[ + "AUTOTUNE_Z", + "H", + "AUTOTUNE_MAX_SEQ_LEN", + "DimQ", + "DimV", + "DeltaSize", + "IS_DELTA_Q", + ], +) +@triton.jit +def _hstu_attn_fwd( # noqa C901 + Q, + K, + V, + workspace_ptr, + sort_by_length_indices, + seq_offsets, + num_targets, + Out, + stride_qm, + stride_qh, + stride_kn, + stride_kh, + stride_vn, + stride_vh, + stride_om, + stride_oh, + alpha, + Z, + AUTOTUNE_Z, + H, + MAX_SEQ_LEN, + AUTOTUNE_MAX_SEQ_LEN, # Quantized MAX_SEQ_LEN used as an autotuning key + DimQ, + DimV, + DeltaSize, + contextual_seq_len, + max_attn_len, + HAS_MULTIPLE_TARGETS: tl.constexpr, + IS_DELTA_Q: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BLOCK_D_Q: tl.constexpr, + BLOCK_D_V: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + USE_TLX: tl.constexpr, + NUM_BUFFERS: tl.constexpr, # + NUM_MMA_WARPS_PER_GROUP: tl.constexpr, # + NUM_MMA_GROUPS: tl.constexpr, # + HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr, + HAS_MAX_ATTN_LEN: tl.constexpr, + HAS_SORT_BY_LENGTH_INDICES: tl.constexpr, + ENABLE_TMA: tl.constexpr, + TMA_DESC_SIZE: tl.constexpr, +): + off_hz = tl.program_id(1) + off_z = off_hz // H + if HAS_SORT_BY_LENGTH_INDICES: + off_z = tl.load(sort_by_length_indices + off_z) + off_h = off_hz % H + pid = tl.program_id(0) + if USE_TLX: + _hstu_attn_fwd_compute_tlx( + Q=Q, + K=K, + V=V, + H=H, + DimQ=DimQ, + DimV=DimV, + seq_offsets=seq_offsets, + num_targets=num_targets, + Out=Out, + stride_qh=stride_qh, + stride_kh=stride_kh, + stride_vh=stride_vh, + stride_om=stride_om, + stride_oh=stride_oh, + alpha=alpha, + MAX_SEQ_LEN=MAX_SEQ_LEN, + DeltaSize=DeltaSize, + contextual_seq_len=contextual_seq_len, + max_attn_len=max_attn_len, + off_z=off_z, + off_h=off_h, + pid=pid, + HAS_MULTIPLE_TARGETS=HAS_MULTIPLE_TARGETS, + IS_DELTA_Q=IS_DELTA_Q, + ALLOW_TF32=ALLOW_TF32, + BLOCK_D_Q=BLOCK_D_Q, + BLOCK_D_V=BLOCK_D_V, + HAS_CONTEXTUAL_SEQ_LEN=HAS_CONTEXTUAL_SEQ_LEN, + HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + NUM_BUFFERS=NUM_BUFFERS, + NUM_MMA_WARPS_PER_GROUP=NUM_MMA_WARPS_PER_GROUP, + NUM_MMA_GROUPS=NUM_MMA_GROUPS, + ) + else: + _hstu_attn_fwd_compute( + Q=Q, + K=K, + V=V, + H=H, + DimQ=DimQ, + DimV=DimV, + workspace_ptr=workspace_ptr, + seq_offsets=seq_offsets, + num_targets=num_targets, + Out=Out, + stride_qm=stride_qm, + stride_qh=stride_qh, + 
stride_kn=stride_kn, + stride_kh=stride_kh, + stride_vn=stride_vn, + stride_vh=stride_vh, + stride_om=stride_om, + stride_oh=stride_oh, + alpha=alpha, + MAX_SEQ_LEN=MAX_SEQ_LEN, + DeltaSize=DeltaSize, + contextual_seq_len=contextual_seq_len, + max_attn_len=max_attn_len, + off_z=off_z, + off_h=off_h, + pid=pid, + HAS_MULTIPLE_TARGETS=HAS_MULTIPLE_TARGETS, + IS_DELTA_Q=IS_DELTA_Q, + ALLOW_TF32=ALLOW_TF32, + BLOCK_D_Q=BLOCK_D_Q, + BLOCK_D_V=BLOCK_D_V, + HAS_CONTEXTUAL_SEQ_LEN=HAS_CONTEXTUAL_SEQ_LEN, + HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ENABLE_TMA=ENABLE_TMA, + TMA_DESC_SIZE=TMA_DESC_SIZE, + ) + + +@triton_autotune( + configs=_get_fw_configs(), + key=[ + "AUTOTUNE_Z", + "H", + "AUTOTUNE_MAX_SEQ_LEN", + "DimQ", + "DimV", + "DeltaSize", + "IS_DELTA_Q", + ], +) +@triton.jit +def _hstu_attn_fwd_persistent( # noqa C901 + Q, + K, + V, + workspace_ptr, + sort_by_length_indices, + seq_offsets, + num_targets, + Out, + stride_qm, + stride_qh, + stride_kn, + stride_kh, + stride_vn, + stride_vh, + stride_om, + stride_oh, + alpha, + Z, + AUTOTUNE_Z, + H, + MAX_SEQ_LEN, + AUTOTUNE_MAX_SEQ_LEN, # Quantized MAX_SEQ_LEN used as an autotuning key + DimQ, + DimV, + DeltaSize, + contextual_seq_len, + max_attn_len, + HAS_MULTIPLE_TARGETS: tl.constexpr, + IS_DELTA_Q: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BLOCK_D_Q: tl.constexpr, + BLOCK_D_V: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + USE_TLX: tl.constexpr, + NUM_BUFFERS: tl.constexpr, # + NUM_MMA_WARPS_PER_GROUP: tl.constexpr, # + NUM_MMA_GROUPS: tl.constexpr, # + HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr, + HAS_MAX_ATTN_LEN: tl.constexpr, + HAS_SORT_BY_LENGTH_INDICES: tl.constexpr, + ENABLE_TMA: tl.constexpr, + TMA_DESC_SIZE: tl.constexpr, +): + n_tile_num = tl.cdiv(MAX_SEQ_LEN, BLOCK_M) + prog_id = tl.program_id(0) + num_progs = tl.num_programs(0) + + total_tiles = n_tile_num * Z * H + + tiles_per_sm = total_tiles // num_progs + if prog_id < total_tiles % num_progs: + tiles_per_sm += 1 + + tile_idx = prog_id + for _ in range(0, tiles_per_sm): + pid = (total_tiles - tile_idx - 1) // (Z * H) + off_hz = (total_tiles - tile_idx - 1) % (Z * H) + off_z = off_hz // H + off_h = off_hz % H + _hstu_attn_fwd_compute( + Q=Q, + K=K, + V=V, + H=H, + DimQ=DimQ, + DimV=DimV, + workspace_ptr=workspace_ptr, + seq_offsets=seq_offsets, + num_targets=num_targets, + Out=Out, + stride_qm=stride_qm, + stride_qh=stride_qh, + stride_kn=stride_kn, + stride_kh=stride_kh, + stride_vn=stride_vn, + stride_vh=stride_vh, + stride_om=stride_om, + stride_oh=stride_oh, + alpha=alpha, + MAX_SEQ_LEN=MAX_SEQ_LEN, + DeltaSize=DeltaSize, + contextual_seq_len=contextual_seq_len, + max_attn_len=max_attn_len, + off_z=off_z, + off_h=off_h, + pid=pid, + HAS_MULTIPLE_TARGETS=HAS_MULTIPLE_TARGETS, + IS_DELTA_Q=IS_DELTA_Q, + ALLOW_TF32=ALLOW_TF32, + BLOCK_D_Q=BLOCK_D_Q, + BLOCK_D_V=BLOCK_D_V, + HAS_CONTEXTUAL_SEQ_LEN=HAS_CONTEXTUAL_SEQ_LEN, + HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ENABLE_TMA=ENABLE_TMA, + TMA_DESC_SIZE=TMA_DESC_SIZE, + ) + tile_idx += num_progs + + +@triton.jit +def _hstu_attn_bwd_one_block( # noqa C901 + start_m, + offs_n, + offs_m, + q_ptrs_trans, + dq_ptrs_trans, + do_ptrs, + device_desc_q, + device_desc_do, + dk, + dv, + k, + v, + pos_offs_n, + seq_len, + max_ids, + contextual_seq_len, + max_attn_len, + LOCK, + off_h, + stride_qh, + stride_doh, + stride_qm, + stride_dom, + stride_dqm, + alpha, + MAX_SEQ_LEN, + HAS_MULTIPLE_TARGETS: tl.constexpr, + HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr, + 
HAS_MAX_ATTN_LEN: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BLOCK_M: tl.constexpr, + ATOMIC_ADD: tl.constexpr, + ENABLE_TMA: tl.constexpr, + BLOCK_D_Q: tl.constexpr, + BLOCK_D_V: tl.constexpr, +): + pos_offs_m = offs_m + start_m + mask_m = pos_offs_m < seq_len + invalid_mask_trans = pos_offs_m[None, :] == offs_n[:, None] + # recompute qk and silu + if HAS_CONTEXTUAL_SEQ_LEN: + pos_offs_m = pos_offs_m - contextual_seq_len + 1 + pos_offs_m = tl.where( + pos_offs_m > 0, + pos_offs_m, + 0, + ) + if HAS_MULTIPLE_TARGETS: + pos_offs_m = tl.where( + pos_offs_m < max_ids, + pos_offs_m, + max_ids, + ) + if ENABLE_TMA: + q = device_desc_q.load( + [start_m, (off_h * stride_qh).to(tl.int32)], + ) + q_trans = tl.trans(q) + else: + q_trans = tl.load( + q_ptrs_trans + start_m * stride_qm, + mask=mask_m[None, :], + other=0.0, + ) + qk_trans = tl.dot(k, q_trans, allow_tf32=ALLOW_TF32) * alpha + sig_trans = fast_dividef(1.0, 1.0 + tl.exp(-qk_trans)) + silu_trans = qk_trans * sig_trans * (1.0 / MAX_SEQ_LEN) + pos_offs_m_minus_n = pos_offs_m[None, :] - pos_offs_n[:, None] + invalid_mask_trans = invalid_mask_trans or (pos_offs_m_minus_n > 0) + if HAS_MAX_ATTN_LEN: + invalid_mask_trans = invalid_mask_trans and pos_offs_m_minus_n <= max_attn_len + if HAS_CONTEXTUAL_SEQ_LEN: + invalid_mask_trans = invalid_mask_trans or ( + pos_offs_m[None, :] == 0 and pos_offs_n[:, None] < max_ids + ) + silu_trans = tl.where(invalid_mask_trans, silu_trans, 0) + silu_trans = silu_trans.to(k.dtype) + # compute dv + if ENABLE_TMA: + do = device_desc_do.load( + [start_m, (off_h * stride_doh).to(tl.int32)], + ) + else: + do = tl.load( + do_ptrs + start_m * stride_dom, + mask=mask_m[:, None], + other=0.0, + ) + dv += tl.dot(silu_trans, do, allow_tf32=ALLOW_TF32) + + # compute dk and dq + dqk_trans = tl.dot(v, tl.trans(do), allow_tf32=ALLOW_TF32) + dqk_trans = ( + dqk_trans * sig_trans * (1 + qk_trans * (1 - sig_trans)) * (1.0 / MAX_SEQ_LEN) + ) + dqk_trans = tl.where(invalid_mask_trans, dqk_trans, 0) + dqk_trans = dqk_trans.to(k.dtype) + + # Note: the factor `alpha` is delayed until the end of the function to reduce the cost + dk += tl.dot(dqk_trans, tl.trans(q_trans), allow_tf32=ALLOW_TF32) + acc_dq( + dq_ptrs_trans=dq_ptrs_trans, + start_m=start_m, + stride_dqm=stride_dqm, + k=k, + dqk_trans=dqk_trans, + alpha=alpha, + mask_m=mask_m, + MAX_SEQ_LEN=MAX_SEQ_LEN, + LOCK=LOCK, + BLOCK_M=BLOCK_M, + ATOMIC_ADD=ATOMIC_ADD, + ALLOW_TF32=ALLOW_TF32, + ) + return dk, dv + + +@triton.jit +def _hstu_attn_bwd_one_col_block( # noqa C901 + start_n, + seq_len, + n_targets, + contextual_seq_len, + max_attn_len, + Q, + K, + V, + DOut, + DQ, + DK, + DV, + device_desc_q, + device_desc_k, + device_desc_v, + device_desc_do, + device_desc_dk, + device_desc_dv, + LOCK, + off_h, + stride_qh, + stride_kh, + stride_vh, + stride_doh, + stride_dkh, + stride_dvh, + stride_qm, + stride_kn, + stride_vn, + stride_dom, + stride_dqm, + stride_dkn, + stride_dvn, + alpha, + MAX_SEQ_LEN, + HAS_MULTIPLE_TARGETS: tl.constexpr, + HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr, + HAS_MAX_ATTN_LEN: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BLOCK_D_Q: tl.constexpr, + BLOCK_D_V: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + UNROLL: tl.constexpr, + ATOMIC_ADD: tl.constexpr, + ENABLE_TMA: tl.constexpr, +): + if HAS_MULTIPLE_TARGETS: + low = start_n + if HAS_MAX_ATTN_LEN: + high = start_n + max_attn_len + BLOCK_N + high = high if high + n_targets < seq_len else seq_len + else: + high = seq_len + else: + low = start_n + if HAS_MAX_ATTN_LEN: + high = start_n + 
max_attn_len + BLOCK_N + high = high if high < seq_len else seq_len + else: + high = seq_len + if HAS_CONTEXTUAL_SEQ_LEN: + contextual_block_end = tl.cdiv(contextual_seq_len, BLOCK_M) * BLOCK_M + if low < contextual_block_end: + low = contextual_block_end + + offs_m = tl.arange(0, BLOCK_M) + offs_qk_d = tl.arange(0, BLOCK_D_Q) + offs_v_d = tl.arange(0, BLOCK_D_V) + offs_n = start_n + tl.arange(0, BLOCK_N) + + dq_ptrs_trans = DQ + (offs_m[None, :] * stride_dqm + offs_qk_d[:, None]) + dv = tl.zeros([BLOCK_N, BLOCK_D_V], dtype=tl.float32) + dk = tl.zeros([BLOCK_N, BLOCK_D_Q], dtype=tl.float32) + if ENABLE_TMA: + q_ptrs_trans = None + do_ptrs = None + k = device_desc_k.load( + [start_n, (off_h * stride_kh).to(tl.int32)], + ) + v = device_desc_v.load( + [start_n, (off_h * stride_vh).to(tl.int32)], + ) + else: + mask_n = offs_n < seq_len + q_ptrs_trans = Q + (offs_m[None, :] * stride_qm + offs_qk_d[:, None]) + do_ptrs = DOut + (offs_m[:, None] * stride_dom + offs_v_d[None, :]) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_qk_d[None, :]) + v_ptrs = V + (offs_n[:, None] * stride_vn + offs_v_d[None, :]) + k = tl.load(k_ptrs, mask=mask_n[:, None], other=0.0) + v = tl.load(v_ptrs, mask=mask_n[:, None], other=0.0) + max_ids = seq_len + if HAS_CONTEXTUAL_SEQ_LEN: + pos_offs_n = offs_n - contextual_seq_len + 1 + pos_offs_n = tl.where( + pos_offs_n > 0, + pos_offs_n, + 0, + ) + max_ids = max_ids - contextual_seq_len + 1 + else: + pos_offs_n = offs_n + if HAS_MULTIPLE_TARGETS: + max_ids = max_ids - n_targets + pos_offs_n = tl.where( + pos_offs_n < max_ids, + pos_offs_n, + max_ids, + ) + # loop over rows + if HAS_CONTEXTUAL_SEQ_LEN: + for start_m in range(0, contextual_seq_len, BLOCK_M): + start_m = tl.multiple_of(start_m, BLOCK_M) + dk, dv = _hstu_attn_bwd_one_block( + start_m=start_m, + offs_n=offs_n, + offs_m=offs_m, + q_ptrs_trans=q_ptrs_trans, + dq_ptrs_trans=dq_ptrs_trans, + do_ptrs=do_ptrs, + device_desc_q=device_desc_q, + device_desc_do=device_desc_do, + dk=dk, + dv=dv, + k=k, + v=v, + pos_offs_n=pos_offs_n, + seq_len=seq_len, + max_ids=max_ids, + contextual_seq_len=contextual_seq_len, + max_attn_len=max_attn_len, + LOCK=LOCK, + off_h=off_h, + stride_qh=stride_qh, + stride_doh=stride_doh, + stride_qm=stride_qm, + stride_dom=stride_dom, + stride_dqm=stride_dqm, + alpha=alpha, + MAX_SEQ_LEN=MAX_SEQ_LEN, + HAS_MULTIPLE_TARGETS=HAS_MULTIPLE_TARGETS, + HAS_CONTEXTUAL_SEQ_LEN=HAS_CONTEXTUAL_SEQ_LEN, + HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN, + ALLOW_TF32=ALLOW_TF32, + BLOCK_M=BLOCK_M, + ATOMIC_ADD=ATOMIC_ADD, + ENABLE_TMA=ENABLE_TMA, + BLOCK_D_Q=BLOCK_D_Q, + BLOCK_D_V=BLOCK_D_V, + ) + for start_m in tl.range(low, high, BLOCK_M, loop_unroll_factor=UNROLL): + start_m = tl.multiple_of(start_m, BLOCK_M) + dk, dv = _hstu_attn_bwd_one_block( + start_m=start_m, + offs_n=offs_n, + offs_m=offs_m, + q_ptrs_trans=q_ptrs_trans, + dq_ptrs_trans=dq_ptrs_trans, + do_ptrs=do_ptrs, + device_desc_q=device_desc_q, + device_desc_do=device_desc_do, + dk=dk, + dv=dv, + k=k, + v=v, + pos_offs_n=pos_offs_n, + seq_len=seq_len, + max_ids=max_ids, + contextual_seq_len=contextual_seq_len, + max_attn_len=max_attn_len, + LOCK=LOCK, + off_h=off_h, + stride_qh=stride_qh, + stride_doh=stride_doh, + stride_qm=stride_qm, + stride_dom=stride_dom, + stride_dqm=stride_dqm, + alpha=alpha, + MAX_SEQ_LEN=MAX_SEQ_LEN, + HAS_MULTIPLE_TARGETS=HAS_MULTIPLE_TARGETS, + HAS_CONTEXTUAL_SEQ_LEN=HAS_CONTEXTUAL_SEQ_LEN, + HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN, + ALLOW_TF32=ALLOW_TF32, + BLOCK_M=BLOCK_M, + ATOMIC_ADD=ATOMIC_ADD, + ENABLE_TMA=ENABLE_TMA, + 
BLOCK_D_Q=BLOCK_D_Q, + BLOCK_D_V=BLOCK_D_V, + ) + # write-back + dk = dk * alpha + if ENABLE_TMA: + device_desc_dv.store( + [start_n, (off_h * stride_dvh).to(tl.int32)], + dv.to(k.dtype), + ) + device_desc_dk.store( + [start_n, (off_h * stride_dkh).to(tl.int32)], + dk.to(k.dtype), + ) + else: + dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_v_d[None, :]) + dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_qk_d[None, :]) + tl.store(dv_ptrs, dv.to(k.dtype), mask=mask_n[:, None]) # pyre-ignore[61] + tl.store(dk_ptrs, dk.to(k.dtype), mask=mask_n[:, None]) # pyre-ignore[61] + + +def _bwd_pre_hook(nargs): + nargs["DQ"].zero_() + if nargs["SEQUENCE_PARALLEL"] is True: + nargs["LOCK"].zero_() + + +def _get_bw_configs() -> List[triton.Config]: + if torch.version.hip: + configs = [] + for BLOCK_M in [32, 64]: + for BLOCK_N in [32, 64, 128]: + for num_stages in [1, 2]: + for num_warps in [4, 8]: + for matrix_instr_nonkdim in [16, 32]: + for waves_per_eu in [0, 2, 4]: + for sp in [True, False]: + configs.append( + triton.Config( + { + "BLOCK_M": BLOCK_M, + "BLOCK_N": BLOCK_N, + "matrix_instr_nonkdim": matrix_instr_nonkdim, + "waves_per_eu": waves_per_eu, + "SEQUENCE_PARALLEL": sp, + "UNROLL": 1, + }, + num_stages=num_stages, + num_warps=num_warps, + pre_hook=_bwd_pre_hook, + ) + ) + return configs + + configs = [ + triton.Config( + {"BLOCK_M": 16, "BLOCK_N": 32, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=2, + num_warps=2, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 16, "BLOCK_N": 16, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=2, + num_warps=2, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 16, "BLOCK_N": 32, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=2, + num_warps=4, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 16, "BLOCK_N": 32, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=1, + num_warps=8, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=1, + num_warps=4, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=2, + num_warps=4, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=2, + num_warps=4, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=2, + num_warps=8, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=1, + num_warps=4, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=2, + num_warps=4, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=1, + num_warps=8, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=2, + num_warps=8, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=2, + num_warps=8, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=3, + num_warps=8, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 128, 
"SEQUENCE_PARALLEL": False, "UNROLL": 4}, + num_stages=2, + num_warps=8, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 16, "BLOCK_N": 32, "SEQUENCE_PARALLEL": True, "UNROLL": 1}, + num_stages=2, + num_warps=2, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32, "SEQUENCE_PARALLEL": True, "UNROLL": 1}, + num_stages=1, + num_warps=4, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 32, "SEQUENCE_PARALLEL": True, "UNROLL": 1}, + num_stages=2, + num_warps=4, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True, "UNROLL": 1}, + num_stages=1, + num_warps=4, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True, "UNROLL": 1}, + num_stages=2, + num_warps=4, + pre_hook=_bwd_pre_hook, + ), + ] + if torch.cuda.is_available() and torch.version.cuda < "12.8": + configs += [ + triton.Config( + {"BLOCK_M": 16, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=1, + num_warps=4, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=1, + num_warps=4, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False, "UNROLL": 1}, + num_stages=1, + num_warps=8, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True, "UNROLL": 1}, + num_stages=1, + num_warps=8, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True, "UNROLL": 1}, + num_stages=3, + num_warps=8, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True, "UNROLL": 1}, + num_stages=1, + num_warps=4, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True, "UNROLL": 1}, + num_stages=2, + num_warps=4, + pre_hook=_bwd_pre_hook, + ), + triton.Config( + { + "BLOCK_M": 32, + "BLOCK_N": 128, + "SEQUENCE_PARALLEL": False, + "UNROLL": 2, + }, + num_stages=2, + num_warps=8, + pre_hook=_bwd_pre_hook, + ), + ] + else: + print("WARNING: temporarily disabled some autotune configs for CUDA 12.8+") + return configs + + +@triton_autotune( + configs=_get_bw_configs(), + key=[ + "AUTOTUNE_Z", + "H", + "AUTOTUNE_MAX_SEQ_LEN", + "DimQ", + "DimV", + ], +) +@triton.jit +def _hstu_attn_bwd( # noqa C901 + Q, + K, + V, + tma_workspace_ptr, + sort_by_length_indices, + seq_offsets, + num_targets, + DOut, + DQ, + DK, + DV, + LOCK, + stride_qm, + stride_qh, + stride_kn, + stride_kh, + stride_vn, + stride_vh, + stride_dom, + stride_doh, + stride_dqm, + stride_dqh, + stride_dkn, + stride_dkh, + stride_dvn, + stride_dvh, + alpha, + contextual_seq_len, + max_attn_len, + Z, + AUTOTUNE_Z, + H, + MAX_SEQ_LEN, + AUTOTUNE_MAX_SEQ_LEN, # Quantized MAX_SEQ_LEN used as an autotuning key + DimQ, + DimV, + HAS_MULTIPLE_TARGETS: tl.constexpr, + HAS_CONTEXTUAL_SEQ_LEN: tl.constexpr, + HAS_MAX_ATTN_LEN: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BLOCK_D_Q: tl.constexpr, + BLOCK_D_V: tl.constexpr, + SEQUENCE_PARALLEL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + UNROLL: tl.constexpr, + HAS_SORT_BY_LENGTH_INDICES: tl.constexpr, + ENABLE_TMA: tl.constexpr, + TMA_DESC_SIZE: tl.constexpr, + ENABLE_BUFFER_OPS_ASSUMES: tl.constexpr, +): + off_hz = tl.program_id(0) + off_z = off_hz // H + if HAS_SORT_BY_LENGTH_INDICES: + off_z = tl.load(sort_by_length_indices 
+ off_z) + off_h = off_hz % H + off_h = off_h.to(tl.int64) + seq_start = tl.load(seq_offsets + off_z).to(tl.int64) + seq_end = tl.load(seq_offsets + off_z + 1) + seq_len = (seq_end - seq_start).to(tl.int32) + if HAS_MULTIPLE_TARGETS: + n_targets = tl.load(num_targets + off_z).to(tl.int32) + else: + n_targets = None + if ENABLE_BUFFER_OPS_ASSUMES: + tl.assume(off_hz >= 0) + tl.assume(off_z >= 0) + tl.assume(off_h >= 0) + tl.assume(seq_start >= 0) + tl.assume(stride_qm >= 0) + tl.assume(stride_qh >= 0) + tl.assume(stride_kn >= 0) + tl.assume(stride_kh >= 0) + tl.assume(stride_vn >= 0) + tl.assume(stride_vh >= 0) + tl.assume(stride_dom >= 0) + tl.assume(stride_doh >= 0) + tl.assume(stride_dqm >= 0) + tl.assume(stride_dqh >= 0) + tl.assume(stride_dkn >= 0) + tl.assume(stride_dkh >= 0) + tl.assume(stride_dvn >= 0) + tl.assume(stride_dvh >= 0) + + # offset pointers for batch/head + Q = Q + seq_start * stride_qm + K = K + seq_start * stride_kn + V = V + seq_start * stride_vn + DOut = DOut + seq_start * stride_dom + DQ = DQ + seq_start * stride_dqm + off_h * stride_dqh + DK = DK + seq_start * stride_dkn + DV = DV + seq_start * stride_dvn + device_desc_q = None + device_desc_k = None + device_desc_v = None + device_desc_do = None + device_desc_dk = None + device_desc_dv = None + if ENABLE_TMA: + device_desc_q = tl.make_tensor_descriptor( + Q, + shape=[seq_len, H * DimQ], + strides=[H * DimQ, 1], + block_shape=[BLOCK_M, BLOCK_D_Q], + ) + device_desc_do = tl.make_tensor_descriptor( + DOut, + shape=[seq_len, H * DimV], + strides=[H * DimV, 1], + block_shape=[BLOCK_M, BLOCK_D_V], + ) + device_desc_k = tl.make_tensor_descriptor( + K, + shape=[seq_len, H * DimQ], + strides=[H * DimQ, 1], + block_shape=[BLOCK_N, BLOCK_D_Q], + ) + device_desc_dk = tl.make_tensor_descriptor( + DK, + shape=[seq_len, H * DimQ], + strides=[H * DimQ, 1], + block_shape=[BLOCK_N, BLOCK_D_Q], + ) + device_desc_v = tl.make_tensor_descriptor( + V, + shape=[seq_len, H * DimV], + strides=[H * DimV, 1], + block_shape=[BLOCK_N, BLOCK_D_V], + ) + device_desc_dv = tl.make_tensor_descriptor( + DV, + shape=[seq_len, H * DimV], + strides=[H * DimV, 1], + block_shape=[BLOCK_N, BLOCK_D_V], + ) + else: + Q += off_h * stride_qh + K += off_h * stride_kh + V += off_h * stride_vh + DOut += off_h * stride_doh + DK += off_h * stride_dkh + DV += off_h * stride_dvh + if SEQUENCE_PARALLEL: + start_n = tl.program_id(1) * BLOCK_N + if start_n >= seq_len: + return + _hstu_attn_bwd_one_col_block( + start_n=start_n, + seq_len=seq_len, + n_targets=n_targets, + contextual_seq_len=contextual_seq_len, + max_attn_len=max_attn_len, + Q=Q, + K=K, + V=V, + DOut=DOut, + DQ=DQ, + DK=DK, + DV=DV, + device_desc_q=device_desc_q, + device_desc_k=device_desc_k, + device_desc_v=device_desc_v, + device_desc_do=device_desc_do, + device_desc_dk=device_desc_dk, + device_desc_dv=device_desc_dv, + LOCK=LOCK, + off_h=off_h, + stride_qh=stride_qh, + stride_kh=stride_kh, + stride_vh=stride_vh, + stride_doh=stride_doh, + stride_dkh=stride_dkh, + stride_dvh=stride_dvh, + stride_qm=stride_qm, + stride_kn=stride_kn, + stride_vn=stride_vn, + stride_dom=stride_dom, + stride_dqm=stride_dqm, + stride_dkn=stride_dkn, + stride_dvn=stride_dvn, + alpha=alpha, + MAX_SEQ_LEN=MAX_SEQ_LEN, + HAS_MULTIPLE_TARGETS=HAS_MULTIPLE_TARGETS, + HAS_CONTEXTUAL_SEQ_LEN=HAS_CONTEXTUAL_SEQ_LEN, + HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN, + ALLOW_TF32=ALLOW_TF32, + BLOCK_D_Q=BLOCK_D_Q, + BLOCK_D_V=BLOCK_D_V, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + UNROLL=UNROLL, + ATOMIC_ADD=True, + ENABLE_TMA=ENABLE_TMA, + ) + else: + 
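+ # Non-sequence-parallel path: this single program walks every key/value block of
+ # the current (batch, head) pair, so dq contributions are accumulated without
+ # atomics (ATOMIC_ADD=False below), unlike the SEQUENCE_PARALLEL branch above
+ # where each program handles one column block.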
for start_n in range(0, seq_len, BLOCK_N): + _hstu_attn_bwd_one_col_block( + start_n=start_n, + seq_len=seq_len, + n_targets=n_targets, + contextual_seq_len=contextual_seq_len, + max_attn_len=max_attn_len, + Q=Q, + K=K, + V=V, + DOut=DOut, + DQ=DQ, + DK=DK, + DV=DV, + device_desc_q=device_desc_q, + device_desc_k=device_desc_k, + device_desc_v=device_desc_v, + device_desc_do=device_desc_do, + device_desc_dk=device_desc_dk, + device_desc_dv=device_desc_dv, + LOCK=LOCK, + off_h=off_h, + stride_qh=stride_qh, + stride_kh=stride_kh, + stride_vh=stride_vh, + stride_doh=stride_doh, + stride_dkh=stride_dkh, + stride_dvh=stride_dvh, + stride_qm=stride_qm, + stride_kn=stride_kn, + stride_vn=stride_vn, + stride_dom=stride_dom, + stride_dqm=stride_dqm, + stride_dkn=stride_dkn, + stride_dvn=stride_dvn, + alpha=alpha, + MAX_SEQ_LEN=MAX_SEQ_LEN, + HAS_MULTIPLE_TARGETS=HAS_MULTIPLE_TARGETS, + HAS_CONTEXTUAL_SEQ_LEN=HAS_CONTEXTUAL_SEQ_LEN, + HAS_MAX_ATTN_LEN=HAS_MAX_ATTN_LEN, + ALLOW_TF32=ALLOW_TF32, + BLOCK_D_Q=BLOCK_D_Q, + BLOCK_D_V=BLOCK_D_V, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + UNROLL=UNROLL, + ATOMIC_ADD=False, + ENABLE_TMA=ENABLE_TMA, + ) + + +def triton_hstu_attention_fwd( + N: int, + alpha: float, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_offsets: torch.Tensor, + num_targets: Optional[torch.Tensor], + max_attn_len: int, + contextual_seq_len: int, + sort_by_length_indices: Optional[torch.Tensor], + enable_tma: bool, +) -> torch.Tensor: + Z = seq_offsets.numel() - 1 + AUTOTUNE_Z = prev_power_of_2(Z) + L, H, DimQ = q.shape + _, _, DimV = v.shape + out = torch.empty_like(v) + has_multiple_targets = num_targets is not None + has_contextual_seq_len = contextual_seq_len > 0 + has_max_attn_len = max_attn_len > 0 + has_sort_by_length_indices = sort_by_length_indices is not None + if L == 0: + return out + + TMA_DESC_SIZE = 128 + workspace = None + desc_q = q + desc_k = k + desc_v = v + + if enable_tma and tensor_descriptor_tma: + dummy_block = [1, 1] + desc_q = TensorDescriptor( + q, + shape=[L, H * DimQ], + strides=[H * DimQ, 1], + block_shape=dummy_block, + ) + desc_v = TensorDescriptor( + v, + shape=[L, H * DimV], + strides=[H * DimV, 1], + block_shape=dummy_block, + ) + desc_k = TensorDescriptor( + k, + shape=[L, H * DimQ], + strides=[H * DimQ, 1], + block_shape=dummy_block, + ) + + def alloc_fn(size: int, align: int, stream: Optional[int]): + assert align == TMA_DESC_SIZE + return torch.empty(size, dtype=torch.int8, device="cuda") + + # pyre-ignore [6] + triton.set_allocator(alloc_fn) + grid = lambda meta: ( # noqa E731 + triton.cdiv(N, meta["BLOCK_M"]), + Z * H, + ) + + _hstu_attn_fwd[grid]( + Q=desc_q, + K=desc_k, + V=desc_v, + workspace_ptr=workspace, + sort_by_length_indices=sort_by_length_indices, + seq_offsets=seq_offsets, + num_targets=num_targets, + Out=out, + stride_qm=q.stride(0), + stride_qh=q.stride(1), + stride_kn=k.stride(0), + stride_kh=k.stride(1), + stride_vn=v.stride(0), + stride_vh=v.stride(1), + stride_om=out.stride(0), + stride_oh=out.stride(1), + alpha=alpha, + Z=Z, + AUTOTUNE_Z=AUTOTUNE_Z, + H=H, + MAX_SEQ_LEN=N, + AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(N), + DimQ=DimQ, + DimV=DimV, + DeltaSize=0, + contextual_seq_len=contextual_seq_len, + max_attn_len=max_attn_len, + HAS_MULTIPLE_TARGETS=has_multiple_targets, + IS_DELTA_Q=False, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + BLOCK_D_Q=DimQ, + BLOCK_D_V=DimV, + HAS_CONTEXTUAL_SEQ_LEN=has_contextual_seq_len, + HAS_MAX_ATTN_LEN=has_max_attn_len, + 
HAS_SORT_BY_LENGTH_INDICES=has_sort_by_length_indices, + ENABLE_TMA=enable_tma, + TMA_DESC_SIZE=TMA_DESC_SIZE, + ) + return out + + +def triton_hstu_attention_bwd( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dq: torch.Tensor, + dk: torch.Tensor, + dv: torch.Tensor, + seq_offsets: torch.Tensor, + num_targets: Optional[torch.Tensor], + N: int, + alpha: float, + max_attn_len: int, + contextual_seq_len: int, + sort_by_length_indices: Optional[torch.Tensor], + enable_tma: bool, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + dout = switch_to_contiguous_if_needed(dout) + dq = switch_to_contiguous_if_needed(dq) + dk = switch_to_contiguous_if_needed(dk) + dv = switch_to_contiguous_if_needed(dv) + if dout.shape[0] == 0: + return torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v) + Z = seq_offsets.numel() - 1 + _, H, DimQ = q.shape + _, _, DimV = v.shape + grid = lambda meta: ( # noqa E731 + Z * H, + (triton.cdiv(N, meta["BLOCK_N"]) if meta["SEQUENCE_PARALLEL"] else 1), + ) + # The minimum size of BLOCK_M used in `_get_bw_configs`. + # TODO (linjianma): avoid hardcoding the value. + MIN_BLOCK_M = 16 + lock = torch.empty( + (Z * H, triton.cdiv(N, MIN_BLOCK_M)), + dtype=torch.int32, + device=q.device, + ) + AUTOTUNE_Z = prev_power_of_2(Z) + TMA_DESC_SIZE = 128 + tma_workspace = None + + def alloc_fn(size: int, align: int, stream: Optional[int]): + assert align == TMA_DESC_SIZE + return torch.empty(size, dtype=torch.int8, device="cuda") + + # pyre-ignore [6] + triton.set_allocator(alloc_fn) + + # Enable BufferOps on AMD + ENABLE_BUFFER_OPS_ASSUMES = torch.version.hip is not None + _hstu_attn_bwd[grid]( + Q=q, + K=k, + V=v, + tma_workspace_ptr=tma_workspace, + sort_by_length_indices=sort_by_length_indices, + seq_offsets=seq_offsets, + num_targets=num_targets, + DOut=dout, + DQ=dq, + DK=dk, + DV=dv, + LOCK=lock, + stride_qm=q.stride(0), + stride_qh=q.stride(1), + stride_kn=k.stride(0), + stride_kh=k.stride(1), + stride_vn=v.stride(0), + stride_vh=v.stride(1), + stride_dom=dout.stride(0), + stride_doh=dout.stride(1), + stride_dqm=dq.stride(0), + stride_dqh=dq.stride(1), + stride_dkn=dk.stride(0), + stride_dkh=dk.stride(1), + stride_dvn=dv.stride(0), + stride_dvh=dv.stride(1), + alpha=alpha, + contextual_seq_len=contextual_seq_len, + max_attn_len=max_attn_len, + Z=Z, + AUTOTUNE_Z=AUTOTUNE_Z, + H=H, + MAX_SEQ_LEN=N, + AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(N), + DimQ=DimQ, + DimV=DimV, + HAS_MULTIPLE_TARGETS=num_targets is not None, + HAS_CONTEXTUAL_SEQ_LEN=contextual_seq_len > 0, + HAS_MAX_ATTN_LEN=max_attn_len > 0, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + BLOCK_D_Q=DimQ, + BLOCK_D_V=DimV, + HAS_SORT_BY_LENGTH_INDICES=sort_by_length_indices is not None, + ENABLE_TMA=enable_tma, + TMA_DESC_SIZE=TMA_DESC_SIZE, + ENABLE_BUFFER_OPS_ASSUMES=ENABLE_BUFFER_OPS_ASSUMES, + ) + + return dq, dk, dv + + +class _AttentionFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + N: int, + alpha: float, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_offsets: torch.Tensor, + num_targets: Optional[torch.Tensor], + max_attn_len: int, + contextual_seq_len: int, + sort_by_length: bool, + enable_tma: bool, + ) -> torch.Tensor: + sort_by_length_indices = None + if sort_by_length: + seq_lengths = seq_offsets[1:] - seq_offsets[:-1] + _, sort_by_length_indices = torch.sort( + seq_lengths, descending=True, stable=False + ) + saved_tensors = [q, k, v, seq_offsets] + if num_targets is not None: + 
saved_tensors.append(num_targets) + if sort_by_length_indices is not None: + saved_tensors.append(sort_by_length_indices) + ctx.save_for_backward(*saved_tensors) + ctx.alpha = alpha + ctx.has_multiple_targets = num_targets is not None + ctx.max_attn_len = max_attn_len + ctx.N = N + ctx.contextual_seq_len = contextual_seq_len + ctx.sort_by_length = sort_by_length + ctx.enable_tma = enable_tma + return triton_hstu_attention_fwd( + N=N, + alpha=alpha, + q=q, + k=k, + v=v, + seq_offsets=seq_offsets, + num_targets=num_targets, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + sort_by_length_indices=sort_by_length_indices, + enable_tma=enable_tma, + ) + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, dout: torch.Tensor + ) -> Tuple[ + None, + None, + torch.Tensor, + torch.Tensor, + torch.Tensor, + None, + None, + None, + None, + None, + None, + ]: + with torch.inference_mode(): + q, k, v, seq_offsets = ctx.saved_tensors[:4] + idx = 4 + if ctx.has_multiple_targets: + num_targets = ctx.saved_tensors[idx] + idx += 1 + else: + num_targets = None + if ctx.sort_by_length: + sort_by_length_indices = ctx.saved_tensors[idx] + else: + sort_by_length_indices = None + + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + dq, dk, dv = triton_hstu_attention_bwd( + dout=dout, + q=q, + k=k, + v=v, + dq=dq, + dk=dk, + dv=dv, + seq_offsets=seq_offsets, + num_targets=num_targets, + N=ctx.N, + alpha=ctx.alpha, + max_attn_len=ctx.max_attn_len, + contextual_seq_len=ctx.contextual_seq_len, + sort_by_length_indices=sort_by_length_indices, + enable_tma=ctx.enable_tma, + ) + return ( + None, + None, + dq, + dk, + dv, + None, + None, + None, + None, + None, + None, + ) + + +@torch.fx.wrap +def triton_hstu_mha( + N: int, + alpha: float, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_offsets: torch.Tensor, + num_targets: Optional[torch.Tensor] = None, + max_attn_len: int = 0, + contextual_seq_len: int = 0, + sort_by_length: bool = False, + enable_tma: bool = False, +) -> torch.Tensor: + return _AttentionFunction.apply( + N, + alpha, + q, + k, + v, + seq_offsets, + num_targets, + max_attn_len, + contextual_seq_len, + sort_by_length, + enable_tma, + ) + + +@torch.fx.wrap +def triton_cached_hstu_mha( + N: int, + alpha: float, + delta_q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_offsets: torch.Tensor, + num_targets: Optional[torch.Tensor] = None, + max_attn_len: int = 0, + contextual_seq_len: int = 0, + enable_tma: bool = False, +) -> torch.Tensor: + Z = seq_offsets.size(0) - 1 + AUTOTUNE_Z = prev_power_of_2(Z) + DELTA_L, H, DimQ = delta_q.shape + DeltaSize = DELTA_L // Z + L, _, DimV = v.shape + out = torch.empty((DELTA_L, H, DimV), dtype=delta_q.dtype, device=delta_q.device) + + TMA_DESC_SIZE = 128 + desc_q = delta_q + desc_k = k + desc_v = v + + if enable_tma and tensor_descriptor_tma: + dummy_block = [1, 1] + desc_q = TensorDescriptor( + delta_q, + shape=[DELTA_L, H * DimQ], + strides=[H * DimQ, 1], + block_shape=dummy_block, + ) + desc_v = TensorDescriptor( + v, + shape=[L, H * DimV], + strides=[H * DimV, 1], + block_shape=dummy_block, + ) + desc_k = TensorDescriptor( + k, + shape=[L, H * DimQ], + strides=[H * DimQ, 1], + block_shape=dummy_block, + ) + + def alloc_fn(size: int, align: int, stream: Optional[int]): + assert align == TMA_DESC_SIZE + return torch.empty(size, dtype=torch.int8, device="cuda") + + # pyre-ignore [6] + triton.set_allocator(alloc_fn) + grid = lambda meta: ( # noqa E731 + triton.cdiv(DeltaSize, 
meta["BLOCK_M"]), + Z * H, + ) + + has_contextual_seq_len = contextual_seq_len > 0 + has_max_attn_len = max_attn_len > 0 + _hstu_attn_fwd[grid]( + Q=desc_q, + K=desc_k, + V=desc_v, + workspace_ptr=None, + sort_by_length_indices=None, + seq_offsets=seq_offsets, + num_targets=num_targets, + Out=out, + stride_qm=delta_q.stride(0), + stride_qh=delta_q.stride(1), + stride_kn=k.stride(0), + stride_kh=k.stride(1), + stride_vn=v.stride(0), + stride_vh=v.stride(1), + stride_om=out.stride(0), + stride_oh=out.stride(1), + alpha=alpha, + contextual_seq_len=contextual_seq_len, + max_attn_len=max_attn_len, + Z=Z, + AUTOTUNE_Z=AUTOTUNE_Z, + H=H, + MAX_SEQ_LEN=N, + AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(N), + DimQ=DimQ, + DimV=DimV, + DeltaSize=DeltaSize, + HAS_MULTIPLE_TARGETS=num_targets is not None, + IS_DELTA_Q=True, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + BLOCK_D_Q=DimQ, + BLOCK_D_V=DimV, + HAS_CONTEXTUAL_SEQ_LEN=has_contextual_seq_len, + HAS_MAX_ATTN_LEN=has_max_attn_len, + HAS_SORT_BY_LENGTH_INDICES=False, + ENABLE_TMA=enable_tma, + TMA_DESC_SIZE=TMA_DESC_SIZE, + ) + return out diff --git a/recommendation/dlrm_v3/generative_recommenders/ops/triton/triton_hstu_linear.py b/recommendation/dlrm_v3/generative_recommenders/ops/triton/triton_hstu_linear.py new file mode 100644 index 0000000000..8b0c288696 --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/ops/triton/triton_hstu_linear.py @@ -0,0 +1,2063 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +#!/usr/bin/env python3 + + +from typing import List, Optional, Tuple + +import torch + +# @manual=//triton:triton +import triton + +# @manual=//triton:triton +import triton.language as tl + +from generative_recommenders.common import ( + switch_to_contiguous_if_needed, + triton_autotune, +) +from generative_recommenders.ops.triton.triton_addmm import maybe_triton_addmm_fwd + + +def _get_layer_norm_mul_dropout_fwd_multirow_configs() -> List[triton.Config]: + """Generate autotune configs for multi-row LayerNorm multiplication with dropout kernels.""" + configs = [] + for BLOCK_N in [1, 2, 4, 8, 16]: + for num_warps in [1, 2, 4]: + configs.append( + triton.Config( + {"BLOCK_N": BLOCK_N}, + num_warps=num_warps, + ) + ) + return configs + + +from generative_recommenders.ops.utils import is_sm100 + +# @manual=//triton:triton +from triton.language.extra import libdevice + +try: + # @manual=//triton:triton + from triton.language.extra.libdevice import fast_dividef +except ImportError: + try: + # @manual=//triton:triton + from triton.language.extra.cuda.libdevice import fast_dividef + except ImportError: + # pyre-ignore: Undefined import [21] + # @manual=//triton:triton + from triton.language.math import fast_dividef + + +COMPUTE_OUTPUT_LN_FAST_DROPOUT = False + + +def set_compute_output_ln_fast_dropout(value: bool) -> None: + global COMPUTE_OUTPUT_LN_FAST_DROPOUT + COMPUTE_OUTPUT_LN_FAST_DROPOUT = value + + +FUSE_OUTPUT_LN_RNG_BLACKWELL = False + + +# Only impact B200 training when CONCAT_UX is False +def set_fuse_output_ln_rng_blackwell(value: bool) -> None: + global FUSE_OUTPUT_LN_RNG_BLACKWELL + FUSE_OUTPUT_LN_RNG_BLACKWELL = value + + +@triton.jit +def rand3x(seed, offsets, n_rounds: tl.constexpr = 10): # pyre-ignore [9] + i1, i2, i3, _ = tl.randint4x(seed, offsets, n_rounds) + u1 = tl.uint_to_uniform_float(i1) + u2 = tl.uint_to_uniform_float(i2) + u3 = tl.uint_to_uniform_float(i3) + return u1, u2, u3 + + +@triton.jit +def _generate_random_mask( + MASK_BUFFER, + N_MASK, + dropout_ratio, + seed, + D: tl.constexpr, + STRIDE: tl.constexpr, + BLOCK_D: tl.constexpr, +): + # NOTE: This function appears to be incomplete/unused - kept for compatibility + pid = tl.program_id(0) + cols = tl.arange(0, BLOCK_D) + col_mask = cols < D + random_offsets = pid * BLOCK_D + cols + rand1, rand2, rand3, rand4 = tl.rand4x(seed, random_offsets) + start_row = pid * 4 + MASK_BUFFER += start_row * STRIDE + row_mask = start_row < N_MASK + mask1 = rand1 > dropout_ratio + tl.store(MASK_BUFFER + cols, mask1, mask=row_mask & col_mask) + row_mask = (start_row + 1) < N_MASK + mask2 = rand2 > dropout_ratio + tl.store(MASK_BUFFER + STRIDE + cols, mask2, mask=row_mask & col_mask) + row_mask = (start_row + 2) < N_MASK + mask3 = rand3 > dropout_ratio + tl.store( + MASK_BUFFER + 2 * STRIDE + cols, + mask3, + mask=row_mask & col_mask, + ) + row_mask = (start_row + 3) < N_MASK + mask4 = rand4 > dropout_ratio + tl.store( + MASK_BUFFER + 3 * STRIDE + cols, + mask4, + mask=row_mask & col_mask, + ) + + +@triton_autotune( + configs=_get_layer_norm_mul_dropout_fwd_multirow_configs(), + key=["BLOCK_D"], +) +@triton.jit +def _ln_mul_dropout_fwd_rng( + X, + U, + Y, + W, + B, + Mean, + Rstd, + RANDOM_MASK, + N, + D, + eps, + dropout_ratio, + stride_x, + stride_u, + stride_y, + stride_mask, + SILU_U: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, + TRAINING: tl.constexpr, + CONCAT_UX: tl.constexpr, +): + block_id = tl.program_id(0) + start_row = block_id * BLOCK_N + + # Create block pointers for X, U, and Y + X_block_ptr = 
tl.make_block_ptr( + base=X, + shape=(N, D), + strides=(stride_x, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + U_block_ptr = tl.make_block_ptr( + base=U, + shape=(N, D), + strides=(stride_u, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + # Load data blocks + x_block = tl.load(X_block_ptr, boundary_check=(0, 1), padding_option="zero").to( + tl.float32 + ) + u_block = tl.load(U_block_ptr, boundary_check=(0, 1), padding_option="zero").to( + tl.float32 + ) + + cols = tl.arange(0, BLOCK_D) + col_mask = cols < D + rows = start_row + tl.arange(0, BLOCK_N) + row_mask = rows < N + + mean = tl.sum(x_block, axis=1) / D + tl.store(Mean + rows, mean, mask=row_mask) + mean = tl.expand_dims(mean, 1) + + x_mean = x_block - mean + x_mean = tl.where(row_mask[:, None] & col_mask[None, :], x_mean, 0.0) + _var = x_mean * x_mean + var = tl.sum(_var, axis=1) / D + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + rows, rstd, mask=row_mask) + rstd = tl.expand_dims(rstd, 1) + + y = x_mean * rstd + w = tl.load(W + cols, mask=col_mask).to(tl.float32) + b = tl.load(B + cols, mask=col_mask).to(tl.float32) + y = y * w[None, :] + b[None, :] + + if SILU_U: + # pyre-fixme[16] + u_block = fast_dividef(u_block, 1.0 + tl.exp(-u_block)) + + y = y * u_block + + if TRAINING: + if CONCAT_UX: + row_offsets = start_row + tl.arange(0, BLOCK_N) + col_offsets = tl.arange(0, BLOCK_D) + + # Load precomputed random masks for u, x, y + u_offsets = row_offsets[:, None] * stride_mask + col_offsets[None, :] + x_offsets = (row_offsets[:, None] + N) * stride_mask + col_offsets[None, :] + y_offsets = (row_offsets[:, None] + 2 * N) * stride_mask + col_offsets[ + None, : + ] + + mask = (row_offsets[:, None] < N) & (col_offsets[None, :] < D) + + u_keep = tl.load(RANDOM_MASK + u_offsets, mask=mask, other=True) + x_keep = tl.load(RANDOM_MASK + x_offsets, mask=mask, other=True) + y_keep = tl.load(RANDOM_MASK + y_offsets, mask=mask, other=True) + + u_block = tl.where(u_keep, u_block / (1.0 - dropout_ratio), 0.0) + x_block = tl.where(x_keep, x_block / (1.0 - dropout_ratio), 0.0) + y = tl.where(y_keep, y / (1.0 - dropout_ratio), 0.0) + else: + row_offsets = start_row + tl.arange(0, BLOCK_N) + col_offsets = tl.arange(0, BLOCK_D) + + # Load precomputed random mask for y + y_offsets = row_offsets[:, None] * stride_mask + col_offsets[None, :] + mask = (row_offsets[:, None] < N) & (col_offsets[None, :] < D) + + y_keep = tl.load(RANDOM_MASK + y_offsets, mask=mask, other=True) + y = tl.where(y_keep, y / (1.0 - dropout_ratio), 0.0) + + if CONCAT_UX: + Y_block_ptr_u = tl.make_block_ptr( + base=Y, + shape=(N, 3 * D), + strides=(stride_y, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + Y_block_ptr_x = tl.make_block_ptr( + base=Y, + shape=(N, 3 * D), + strides=(stride_y, 1), + offsets=(start_row, D), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + Y_block_ptr_y = tl.make_block_ptr( + base=Y, + shape=(N, 3 * D), + strides=(stride_y, 1), + offsets=(start_row, 2 * D), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + tl.store(Y_block_ptr_u, u_block.to(Y.dtype.element_ty), boundary_check=(0, 1)) + tl.store(Y_block_ptr_x, x_block.to(Y.dtype.element_ty), boundary_check=(0, 1)) + tl.store(Y_block_ptr_y, y.to(Y.dtype.element_ty), boundary_check=(0, 1)) + else: + Y_block_ptr = tl.make_block_ptr( + base=Y, + shape=(N, D), + strides=(stride_y, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + 
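+ # Non-CONCAT_UX case: Y is (N, D) and only the gated LayerNorm output y is
+ # written back; the CONCAT_UX branch above instead writes u, x, and y side by
+ # side into an (N, 3 * D) output.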
tl.store(Y_block_ptr, y.to(Y.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def _ln_mul_dropout_fwd( + X, + U, + Y, + W, + B, + Mean, + Rstd, + D, + eps, + seed, + dropout_ratio, + stride_x, + stride_u, + stride_y, + SILU_U: tl.constexpr, + BLOCK_D: tl.constexpr, + TRAINING: tl.constexpr, + CONCAT_UX: tl.constexpr, + FAST_DROPOUT: tl.constexpr, +): + row = tl.program_id(0) + X += row.to(tl.int64) * stride_x + U += row.to(tl.int64) * stride_u + Y += row.to(tl.int64) * stride_y + cols = tl.arange(0, BLOCK_D) + + # Compute mean + mean = 0.0 + x = tl.load(X + cols, mask=cols < D, other=0.0).to(tl.float32) + mean = tl.sum(x, axis=0) / D + + # Compute variance + _var = tl.zeros([BLOCK_D], dtype=tl.float32) + x_mean = tl.where(cols < D, x - mean, 0.0) + _var += x_mean * x_mean + var = tl.sum(_var, axis=0) / D + rstd = 1 / tl.sqrt(var + eps) + tl.store(Mean + row, mean) + tl.store(Rstd + row, rstd) + + # Normalize and apply linear transformation + mask = cols < D + y = x_mean * rstd + w = tl.load(W + cols, mask=mask).to(tl.float32) + b = tl.load(B + cols, mask=mask).to(tl.float32) + y = y * w + b + u = tl.load(U + cols, mask=cols < D, other=0.0).to(tl.float32) + if SILU_U: + # pyre-fixme[16] + u = fast_dividef(u, 1.0 + tl.exp(-u)) + y = y * u + + if TRAINING: + random_offsets = 3 * row * BLOCK_D + cols + if CONCAT_UX: + # apply dropout on u + if FAST_DROPOUT: + random_u, random_x, random_y = rand3x(seed, random_offsets) + else: + random_u = tl.rand(seed, random_offsets) + u_keep = random_u > dropout_ratio + u = tl.where(u_keep, u / (1.0 - dropout_ratio), 0.0) + # apply dropout on x + if not FAST_DROPOUT: + random_x = tl.rand(seed, random_offsets + D) + x_keep = random_x > dropout_ratio # pyre-ignore [61] + x = tl.where(x_keep, x / (1.0 - dropout_ratio), 0.0) + # apply dropout on y + if not FAST_DROPOUT: + random_y = tl.rand(seed, random_offsets + 2 * D) + y_keep = random_y > dropout_ratio # pyre-ignore [61] + y = tl.where(y_keep, y / (1.0 - dropout_ratio), 0.0) + else: + random = tl.rand(seed, random_offsets) + y_keep = random > dropout_ratio + # write-back + y = tl.where(y_keep, y / (1.0 - dropout_ratio), 0.0) + + # Write output + if CONCAT_UX: + tl.store(Y + cols, u.to(Y.dtype.element_ty), mask=mask) + tl.store(Y + D + cols, x.to(Y.dtype.element_ty), mask=mask) + tl.store(Y + 2 * D + cols, y.to(Y.dtype.element_ty), mask=mask) + else: + tl.store(Y + cols, y.to(Y.dtype.element_ty), mask=mask) + + +@triton.jit +def _ln_mul_dropout_bwd_dx_du_rng( + DX, + DU, + DY, + DW, + DB, + X, + U, + Y, + W, + B, + Mean, + Rstd, + RANDOM_MASK, + stride_dx, + stride_du, + stride_dy, + stride_x, + stride_u, + stride_y, + stride_mask, + D, + eps, + dropout_ratio, + N, + SILU_U: tl.constexpr, + BLOCK_D: tl.constexpr, + TRAINING: tl.constexpr, + CONCAT_UX: tl.constexpr, + COMPUTE_Y: tl.constexpr, +): + pid = tl.program_id(0) + tile_num = tl.num_programs(0) + rows_per_tile = N // tile_num + if pid < N % tile_num: + rows_per_tile += 1 + + if rows_per_tile == 0: + return + + cols = tl.arange(0, BLOCK_D) + mask = cols < D + + row = pid + X += row.to(tl.int64) * stride_x + U += row.to(tl.int64) * stride_u + if COMPUTE_Y: + Y += row.to(tl.int64) * stride_y + DY += row.to(tl.int64) * stride_dy + DX += row.to(tl.int64) * stride_dx + DU += row.to(tl.int64) * stride_du + DW = DW + pid * D + cols + DB = DB + pid * D + cols + + num_random = 1 + if CONCAT_UX: + num_random = 3 + RANDOM_MASK += row.to(tl.int64) * stride_mask * num_random + + partial_dw = tl.zeros((BLOCK_D,), dtype=tl.float32) + partial_db = 
tl.zeros((BLOCK_D,), dtype=tl.float32) + w = tl.load(W + cols, mask=mask).to(tl.float32) + b = tl.load(B + cols, mask=mask).to(tl.float32) + for _ in range(0, rows_per_tile): + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + if CONCAT_UX: + du = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + dx = tl.load(DY + D + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + 2 * D + cols, mask=mask, other=0).to(tl.float32) + else: + du = tl.zeros([BLOCK_D], dtype=tl.float32) + dx = tl.zeros([BLOCK_D], dtype=tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + if TRAINING: + if CONCAT_UX: + # Load dropout masks for u, x, y from pre-generated mask tensor + du_keep = tl.load(RANDOM_MASK + cols, mask=mask, other=True) + dx_keep = tl.load( + RANDOM_MASK + stride_mask + cols, mask=mask, other=True + ) + dy_keep = tl.load( + RANDOM_MASK + 2 * stride_mask + cols, mask=mask, other=True + ) + du = tl.where(du_keep, du / (1.0 - dropout_ratio), 0.0) + dx = tl.where(dx_keep, dx / (1.0 - dropout_ratio), 0.0) + dy = tl.where(dy_keep, dy / (1.0 - dropout_ratio), 0.0) + else: + # Load dropout mask directly instead of generating random numbers + dy_keep = tl.load(RANDOM_MASK + cols, mask=mask, other=True) + dy = tl.where(dy_keep, dy / (1.0 - dropout_ratio), 0.0) + + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + + # Compute dx + xhat = (x - mean) * rstd + u = tl.load(U + cols, mask=mask, other=0).to(tl.float32) + ln = xhat * w + b + du += dy * ln + if SILU_U: + # pyre-ignore[16] + sig_u = fast_dividef(1.0, 1.0 + tl.exp(-u)) + du = du * (sig_u + u * sig_u * (1.0 - sig_u)) + u = u * sig_u + tl.store(DU + cols, du.to(DU.dtype.element_ty), mask=mask) + dy = dy * u + wdy = w * dy + if COMPUTE_Y: + y = ln * u + if TRAINING: + if CONCAT_UX: + u = tl.where( + du_keep, # pyre-ignore [61] + u / (1.0 - dropout_ratio), + 0.0, + ) + x = tl.where( + dx_keep, # pyre-ignore [61] + x / (1.0 - dropout_ratio), + 0.0, + ) + y = tl.where( + dy_keep, # pyre-ignore [61] + y / (1.0 - dropout_ratio), + 0.0, + ) + else: + y = tl.where( + dy_keep, # pyre-ignore [61] + y / (1.0 - dropout_ratio), + 0.0, + ) + if CONCAT_UX: + tl.store(Y + cols, u.to(Y.dtype.element_ty), mask=mask) + tl.store(Y + D + cols, x.to(Y.dtype.element_ty), mask=mask) + tl.store(Y + 2 * D + cols, y.to(Y.dtype.element_ty), mask=mask) + else: + tl.store(Y + cols, y.to(Y.dtype.element_ty), mask=mask) + Y += tile_num.to(tl.int64) * stride_y + + xhat = tl.where(mask, xhat, 0.0) + wdy = tl.where(mask, wdy, 0.0) + c1 = tl.sum(xhat * wdy, axis=0) / D + c2 = tl.sum(wdy, axis=0) / D + dx += (wdy - (xhat * c1 + c2)) * rstd + # Write dx + tl.store(DX + cols, dx, mask=mask) + + # Accumulate partial sums for dw/db + partial_dw += dy * xhat + partial_db += dy + X += tile_num.to(tl.int64) * stride_x + U += tile_num.to(tl.int64) * stride_u + DY += tile_num.to(tl.int64) * stride_dy + DX += tile_num.to(tl.int64) * stride_dx + DU += tile_num.to(tl.int64) * stride_du + RANDOM_MASK += tile_num.to(tl.int64) * stride_mask * num_random + row += tile_num + tl.store(DW, partial_dw, mask=mask) + tl.store(DB, partial_db, mask=mask) + + +@triton.jit +def _ln_mul_dropout_bwd_dx_du( + DX, + DU, + DY, + DW, + DB, + X, + U, + Y, + W, + B, + Mean, + Rstd, + stride_dx, + stride_du, + stride_dy, + stride_x, + stride_u, + stride_y, + D, + eps, + seed, + dropout_ratio, + N, + SILU_U: tl.constexpr, + BLOCK_D: tl.constexpr, + TRAINING: tl.constexpr, + CONCAT_UX: tl.constexpr, + COMPUTE_Y: tl.constexpr, + FAST_DROPOUT: tl.constexpr, 
+): + pid = tl.program_id(0) + tile_num = tl.num_programs(0) + rows_per_tile = N // tile_num + if pid < N % tile_num: + rows_per_tile += 1 + + if rows_per_tile == 0: + return + + cols = tl.arange(0, BLOCK_D) + mask = cols < D + + row = pid + X += row.to(tl.int64) * stride_x + U += row.to(tl.int64) * stride_u + if COMPUTE_Y: + Y += row.to(tl.int64) * stride_y + DY += row.to(tl.int64) * stride_dy + DX += row.to(tl.int64) * stride_dx + DU += row.to(tl.int64) * stride_du + DW = DW + pid * D + cols + DB = DB + pid * D + cols + + partial_dw = tl.zeros((BLOCK_D,), dtype=tl.float32) + partial_db = tl.zeros((BLOCK_D,), dtype=tl.float32) + w = tl.load(W + cols, mask=mask).to(tl.float32) + b = tl.load(B + cols, mask=mask).to(tl.float32) + for _idx in range(0, rows_per_tile): + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + if CONCAT_UX: + du = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + dx = tl.load(DY + D + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + 2 * D + cols, mask=mask, other=0).to(tl.float32) + else: + du = tl.zeros([BLOCK_D], dtype=tl.float32) + dx = tl.zeros([BLOCK_D], dtype=tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + if TRAINING: + random_offsets = 3 * row * BLOCK_D + cols + if CONCAT_UX: + # apply dropout on du + if FAST_DROPOUT: + random_du, random_dx, random_dy = rand3x(seed, random_offsets) + else: + random_du = tl.rand(seed, random_offsets) + du_keep = random_du > dropout_ratio + du = tl.where(du_keep, du / (1.0 - dropout_ratio), 0.0) + # apply dropout on dx + if not FAST_DROPOUT: + random_dx = tl.rand(seed, random_offsets + D) + dx_keep = random_dx > dropout_ratio # pyre-ignore [61] + dx = tl.where(dx_keep, dx / (1.0 - dropout_ratio), 0.0) + # apply dropout on dy + if not FAST_DROPOUT: + random_dy = tl.rand(seed, random_offsets + 2 * D) + dy_keep = random_dy > dropout_ratio # pyre-ignore [61] + dy = tl.where(dy_keep, dy / (1.0 - dropout_ratio), 0.0) + else: + random = tl.rand(seed, random_offsets) + dy_keep = random > dropout_ratio + # write-back + dy = tl.where(dy_keep, dy / (1.0 - dropout_ratio), 0.0) + + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + + # Compute dx + xhat = (x - mean) * rstd + u = tl.load(U + cols, mask=mask, other=0).to(tl.float32) + ln = xhat * w + b + du += dy * ln + if SILU_U: + # pyre-ignore[16] + sig_u = fast_dividef(1.0, 1.0 + tl.exp(-u)) + du = du * (sig_u + u * sig_u * (1.0 - sig_u)) + u = u * sig_u + tl.store(DU + cols, du.to(DU.dtype.element_ty), mask=mask) + dy = dy * u + wdy = w * dy + if COMPUTE_Y: + y = ln * u + if TRAINING: + if CONCAT_UX: + u = tl.where( + du_keep, # pyre-ignore [61] + u / (1.0 - dropout_ratio), + 0.0, + ) + x = tl.where( + dx_keep, # pyre-ignore [61] + x / (1.0 - dropout_ratio), + 0.0, + ) + y = tl.where( + dy_keep, # pyre-ignore [61] + y / (1.0 - dropout_ratio), + 0.0, + ) + else: + y = tl.where( + dy_keep, # pyre-ignore [61] + y / (1.0 - dropout_ratio), + 0.0, + ) + if CONCAT_UX: + tl.store(Y + cols, u.to(Y.dtype.element_ty), mask=mask) + tl.store(Y + D + cols, x.to(Y.dtype.element_ty), mask=mask) + tl.store(Y + 2 * D + cols, y.to(Y.dtype.element_ty), mask=mask) + else: + tl.store(Y + cols, y.to(Y.dtype.element_ty), mask=mask) + Y += tile_num.to(tl.int64) * stride_y + + xhat = tl.where(mask, xhat, 0.0) + wdy = tl.where(mask, wdy, 0.0) + c1 = tl.sum(xhat * wdy, axis=0) / D + c2 = tl.sum(wdy, axis=0) / D + dx += (wdy - (xhat * c1 + c2)) * rstd + # Write dx + tl.store(DX + cols, dx, mask=mask) + + # Accumulate partial sums for 
dw/db + partial_dw += dy * xhat + partial_db += dy + X += tile_num.to(tl.int64) * stride_x + U += tile_num.to(tl.int64) * stride_u + DY += tile_num.to(tl.int64) * stride_dy + DX += tile_num.to(tl.int64) * stride_dx + DU += tile_num.to(tl.int64) * stride_du + row += tile_num + tl.store(DW, partial_dw, mask=mask) + tl.store(DB, partial_db, mask=mask) + + +def _get_bwd_dwdb_configs() -> List[triton.Config]: + configs = [] + for BLOCK_N in [32, 64, 128, 256]: + for num_warps in [8, 16] + ([] if torch.ops.hip else [32]): + configs.append( + triton.Config( + {"BLOCK_N": BLOCK_N}, + num_warps=num_warps, + ) + ) + return configs + + +@triton_autotune( + configs=_get_bwd_dwdb_configs(), + key=["D"], +) +@triton.jit +def _ln_mul_dropout_bwd_dwdb( + DW, + DB, + FINAL_DW, + FINAL_DB, + N, + D, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid = tl.program_id(0) + cols = pid * BLOCK_D + tl.arange(0, BLOCK_D) + dw = tl.zeros((BLOCK_N, BLOCK_D), dtype=tl.float32) + db = tl.zeros((BLOCK_N, BLOCK_D), dtype=tl.float32) + + for i in range(0, N, BLOCK_N): + rows = i + tl.arange(0, BLOCK_N) + # pyre-fixme[16]: `int` has no attribute `__getitem__`. + mask = (rows[:, None] < N) & (cols[None, :] < D) + offs = rows[:, None] * D + cols[None, :] + dw += tl.load(DW + offs, mask=mask, other=0.0) + db += tl.load(DB + offs, mask=mask, other=0.0) + + sum_dw = tl.sum(dw, axis=0) + sum_db = tl.sum(db, axis=0) + tl.store(FINAL_DW + cols, sum_dw.to(FINAL_DW.dtype.element_ty), mask=cols < D) + tl.store(FINAL_DB + cols, sum_db.to(FINAL_DB.dtype.element_ty), mask=cols < D) + + +def triton_layer_norm_mul_dropout_fwd( + x: torch.Tensor, + u: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + dropout_ratio: float, + training: bool, + silu_u: bool = False, + concat_ux: bool = False, + seed: Optional[int] = None, +) -> Tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, int, int, int +]: # y, mean, rstd, BLOCK_D, num_warps, seed + assert x.dim() == 2 + x = switch_to_contiguous_if_needed(x) + N, D = x.shape + assert weight.dim() == 1 + assert bias.dim() == 1 + assert weight.numel() == D + assert bias.numel() == D + + if concat_ux: + y = torch.empty((N, 3 * D), dtype=x.dtype, device=x.device) + else: + y = torch.empty_like(x) + mean = torch.empty((N,), dtype=torch.float32, device=x.device) + rstd = torch.empty((N,), dtype=torch.float32, device=x.device) + if N == 0: + return y, mean, rstd, 0, 0, 0 + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_D: int = min(MAX_FUSED_SIZE, triton.next_power_of_2(D)) + if D > BLOCK_D: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + + if seed is None: + seed = torch.randint(low=0, high=2**62, size=(1,), dtype=torch.int64).item() + num_warps: int = min(max(BLOCK_D // 256, 1), 8) + sms = torch.cuda.get_device_properties("cuda").multi_processor_count + # Benchmark shows separating RNG from ln_mul_dropout kernel only benefits on + # blackwell when CONCAT_UX is enabled. 
(fused RNG kernel can benefit from rand3x fast + # dropout) + if not FUSE_OUTPUT_LN_RNG_BLACKWELL and is_sm100() and not concat_ux and training: + random_mask = torch.empty([N, D], dtype=torch.bool, device=x.device) + + _generate_random_mask[(triton.cdiv(N, 4),)]( + random_mask, + N, + dropout_ratio, + seed, + D, # pyre-ignore [6] + random_mask.stride(0), # pyre-ignore [6] + BLOCK_D, # pyre-ignore [6] + ) + + def grid(META): + return (triton.cdiv(N, META["BLOCK_N"]),) + + # pyre-ignore[28] + _ln_mul_dropout_fwd_rng[grid]( + x, + u, + y, + weight, + bias, + mean, + rstd, + random_mask, + N, + D, + eps, + dropout_ratio, + x.stride(0), + u.stride(0), + y.stride(0), + random_mask.stride(0), + SILU_U=silu_u, + BLOCK_D=BLOCK_D, + TRAINING=training, + CONCAT_UX=concat_ux, + ) + + else: + # pyre-ignore[28] + _ln_mul_dropout_fwd[(N,)]( + x, + u, + y, + weight, + bias, + mean, + rstd, + D, + eps, + seed, + dropout_ratio, + x.stride(0), + u.stride(0), + y.stride(0), + SILU_U=silu_u, + BLOCK_D=BLOCK_D, + TRAINING=training, + CONCAT_UX=concat_ux, + FAST_DROPOUT=COMPUTE_OUTPUT_LN_FAST_DROPOUT, + num_warps=num_warps, + ) + return y, mean, rstd, BLOCK_D, num_warps, seed # pyre-ignore [7] + + +def triton_layer_norm_mul_dropout_bwd( + dy: torch.Tensor, + x: torch.Tensor, + u: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + mean: torch.Tensor, + rstd: torch.Tensor, + BLOCK_D: int, + num_warps: int, + eps: float, + training: bool, + dropout_ratio: float, + seed: Optional[int] = None, + silu_u: bool = False, + concat_ux: bool = False, + compute_y: bool = False, +) -> Tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor] +]: + y = None + N, D = x.shape + if compute_y: + if concat_ux: + y = torch.empty((N, 3 * D), dtype=x.dtype, device=x.device) + else: + y = torch.empty_like(x) + if N == 0: + return ( + torch.zeros_like(x), + torch.zeros_like(u), + torch.zeros((D,), dtype=weight.dtype, device=x.device), + torch.zeros((D,), dtype=weight.dtype, device=x.device), + y, + ) + dx = torch.empty_like(x) + du = torch.empty_like(u) + sms = torch.cuda.get_device_properties(x.device).multi_processor_count + tile_num = max(1, min(sms * 64, N // 4)) + _dweight = torch.empty((tile_num, D), dtype=torch.float32, device=x.device) + _dbias = torch.empty((tile_num, D), dtype=torch.float32, device=x.device) + dweight = torch.empty((D,), dtype=weight.dtype, device=x.device) + dbias = torch.empty((D,), dtype=weight.dtype, device=x.device) + + if not FUSE_OUTPUT_LN_RNG_BLACKWELL and is_sm100() and not concat_ux and training: + random_mask = torch.empty([N, D], dtype=torch.bool, device=x.device) + + _generate_random_mask[(triton.cdiv(N, 4),)]( + random_mask, + N, + dropout_ratio, + seed, + D, # pyre-ignore [6] + random_mask.stride(0), # pyre-ignore [6] + BLOCK_D, # pyre-ignore [6] + ) + + # pyre-ignore[28] + _ln_mul_dropout_bwd_dx_du_rng[(tile_num,)]( + dx, + du, + dy, + _dweight, + _dbias, + x, + u, + y, + weight, + bias, + mean, + rstd, + random_mask, + dx.stride(0), + du.stride(0), + dy.stride(0), + x.stride(0), + u.stride(0), + y.stride(0) if compute_y else 0, # pyre-ignore [16] + random_mask.stride(0), + D, + eps, + dropout_ratio, + N=N, + SILU_U=silu_u, + BLOCK_D=BLOCK_D, + TRAINING=training, + CONCAT_UX=concat_ux, + COMPUTE_Y=compute_y, + num_warps=num_warps, + ) + + else: + # pyre-ignore[28] + _ln_mul_dropout_bwd_dx_du[(tile_num,)]( + dx, + du, + dy, + _dweight, + _dbias, + x, + u, + y, + weight, + bias, + mean, + rstd, + dx.stride(0), + du.stride(0), + dy.stride(0), + 
x.stride(0), + u.stride(0), + y.stride(0) if compute_y else 0, # pyre-ignore [16] + D, + eps, + seed, + dropout_ratio, + N=N, + SILU_U=silu_u, + BLOCK_D=BLOCK_D, + TRAINING=training, + CONCAT_UX=concat_ux, + COMPUTE_Y=compute_y, + FAST_DROPOUT=COMPUTE_OUTPUT_LN_FAST_DROPOUT, + num_warps=num_warps, + ) + + def grid(META): + return (triton.cdiv(D, META["BLOCK_D"]),) + + blocks = triton.next_power_of_2(sms * 4) + BLOCK_D = triton.next_power_of_2(triton.cdiv(D, blocks)) + BLOCK_D = min(max(BLOCK_D, 4), 128) + _ln_mul_dropout_bwd_dwdb[grid]( + _dweight, + _dbias, + dweight, + dbias, + tile_num, + D, + BLOCK_D=BLOCK_D, + ) + return dx, du, dweight, dbias, y + + +class LayerNormMulDropoutFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + x: torch.Tensor, + u: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + dropout_ratio: float, + training: bool, + concat_ux: bool = False, + seed: Optional[int] = None, + ) -> torch.Tensor: + if dropout_ratio == 0.0: + # skip dropout computation if dropout ratio is 0 + training = False + y, mean, rstd, BLOCK_D, num_warps, seed = triton_layer_norm_mul_dropout_fwd( + x=x, + u=u, + weight=weight, + bias=bias, + eps=eps, + dropout_ratio=dropout_ratio, + training=training, + concat_ux=concat_ux, + seed=seed, + ) + ctx.save_for_backward(x, u, weight, bias, mean, rstd) + ctx.BLOCK_D = BLOCK_D + ctx.num_warps = num_warps + ctx.eps = eps + ctx.seed = seed + ctx.training = training + ctx.concat_ux = concat_ux + ctx.dropout_ratio = dropout_ratio + return y + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, dy: torch.Tensor + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + None, + None, + None, + None, + None, + ]: + x, u, weight, bias, mean, rstd = ctx.saved_tensors + dx, du, dweight, dbias, _ = triton_layer_norm_mul_dropout_bwd( + dy=dy, + x=x, + u=u, + weight=weight, + bias=bias, + mean=mean, + rstd=rstd, + BLOCK_D=ctx.BLOCK_D, + num_warps=ctx.num_warps, + eps=ctx.eps, + training=ctx.training, + dropout_ratio=ctx.dropout_ratio, + seed=ctx.seed, + concat_ux=ctx.concat_ux, + compute_y=False, + ) + return dx, du, dweight, dbias, None, None, None, None, None + + +@triton.jit +def _group_norm_mul_dropout_fwd( + X, + U, + Y, + W, + B, + Mean, + Rstd, + D, + Heads, + eps, + seed, + dropout_ratio, + stride_x, + stride_u, + stride_y, + SILU_U: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_H: tl.constexpr, + TRAINING: tl.constexpr, + CONCAT_UX: tl.constexpr, +): + row = tl.program_id(0) + X += row.to(tl.int64) * stride_x + U += row.to(tl.int64) * stride_u + Y += row.to(tl.int64) * stride_y + cols = tl.arange(0, BLOCK_D) + heads = tl.arange(0, BLOCK_H) + offsets = heads[:, None] * D + cols[None, :] + mask_h = heads < Heads + mask_c = cols < D + mask = mask_c[None, :] & mask_h[:, None] + + # Compute mean + mean = 0.0 + x = tl.load(X + offsets, mask=mask, other=0.0).to(tl.float32) + mean = tl.sum(x, axis=1) / D + mean = tl.ravel(mean) + + # Compute variance + _var = tl.zeros([BLOCK_H, BLOCK_D], dtype=tl.float32) + x_mean = tl.where(mask, x - mean[:, None], 0.0) + _var += x_mean * x_mean + var = tl.sum(_var, axis=1) / D + var = tl.ravel(var) + rstd = 1 / tl.sqrt(var + eps) + tl.store(Mean + row * Heads + heads, mean, mask=mask_h) + tl.store(Rstd + row * Heads + heads, rstd, mask=mask_h) + + # Normalize and apply linear transformation + y = x_mean * rstd[:, None] # pyre-ignore [16] + w = tl.load(W + heads, mask=mask_h).to(tl.float32) + b = tl.load(B + heads, 
mask=mask_h).to(tl.float32) + y = y * w[:, None] + b[:, None] + u = tl.load(U + offsets, mask=mask, other=0.0).to(tl.float32) + if SILU_U: + # pyre-fixme[16] + u = fast_dividef(u, 1.0 + tl.exp(-u)) + y = y * u + + if TRAINING: + if CONCAT_UX: + random_offsets = row * 3 * D * Heads + offsets + # apply dropout on u + random_u = tl.rand(seed, random_offsets) + u_keep = random_u > dropout_ratio + u = tl.where(u_keep, u / (1.0 - dropout_ratio), 0.0) + # apply dropout on x + random_x = tl.rand(seed, random_offsets + Heads * D) + x_keep = random_x > dropout_ratio + x = tl.where(x_keep, x / (1.0 - dropout_ratio), 0.0) + # apply dropout on y + random_y = tl.rand(seed, random_offsets + 2 * Heads * D) + y_keep = random_y > dropout_ratio + y = tl.where(y_keep, y / (1.0 - dropout_ratio), 0.0) + else: + random_offsets = row * D * Heads + offsets + random = tl.rand(seed, random_offsets) + y_keep = random > dropout_ratio + # write-back + y = tl.where(y_keep, y / (1.0 - dropout_ratio), 0.0) + + # Write output + if CONCAT_UX: + tl.store(Y + offsets, u.to(Y.dtype.element_ty), mask=mask) + tl.store(Y + Heads * D + offsets, x.to(Y.dtype.element_ty), mask=mask) + tl.store(Y + 2 * Heads * D + offsets, y.to(Y.dtype.element_ty), mask=mask) + else: + tl.store(Y + offsets, y.to(Y.dtype.element_ty), mask=mask) + + +@triton.jit +def _group_norm_mul_dropout_bwd_dx_du( + DX, + DU, + DY, + DW, + DB, + X, + U, + Y, + W, + B, + Mean, + Rstd, + stride_dx, + stride_du, + stride_dy, + stride_x, + stride_u, + stride_y, + D, + Heads, + eps, + seed, + dropout_ratio, + SILU_U: tl.constexpr, + GROUP_N: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_H: tl.constexpr, + TRAINING: tl.constexpr, + CONCAT_UX: tl.constexpr, + COMPUTE_Y: tl.constexpr, +): + row = tl.program_id(0) + cols = tl.arange(0, BLOCK_D) + off_heads = tl.arange(0, BLOCK_H) + mask_c = cols < D + mask_h = off_heads < Heads + mask = mask_c[None, :] & mask_h[:, None] + X += row.to(tl.int64) * stride_x + U += row.to(tl.int64) * stride_u + DY += row.to(tl.int64) * stride_dy + DX += row.to(tl.int64) * stride_dx + DU += row.to(tl.int64) * stride_du + offsets = off_heads[:, None] * D + cols[None, :] + + # Load data to SRAM + x = tl.load(X + offsets, mask=mask, other=0).to(tl.float32) + if CONCAT_UX: + du = tl.load(DY + offsets, mask=mask, other=0).to(tl.float32) + dx = tl.load(DY + Heads * D + offsets, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + 2 * Heads * D + offsets, mask=mask, other=0).to(tl.float32) + else: + du = tl.zeros([BLOCK_H, BLOCK_D], dtype=tl.float32) + dx = tl.zeros([BLOCK_H, BLOCK_D], dtype=tl.float32) + dy = tl.load(DY + offsets, mask=mask, other=0).to(tl.float32) + if TRAINING: + if CONCAT_UX: + random_offsets = row * 3 * D * Heads + offsets + # apply dropout on du + random_du = tl.rand(seed, random_offsets) + du_keep = random_du > dropout_ratio + du = tl.where(du_keep, du / (1.0 - dropout_ratio), 0.0) + # apply dropout on dx + random_dx = tl.rand(seed, random_offsets + Heads * D) + dx_keep = random_dx > dropout_ratio + dx = tl.where(dx_keep, dx / (1.0 - dropout_ratio), 0.0) + # apply dropout on dy + random_dy = tl.rand(seed, random_offsets + 2 * Heads * D) + dy_keep = random_dy > dropout_ratio + dy = tl.where(dy_keep, dy / (1.0 - dropout_ratio), 0.0) + else: + random_offsets = row * D * Heads + offsets + random = tl.rand(seed, random_offsets) + dy_keep = random > dropout_ratio + # write-back + dy = tl.where(dy_keep, dy / (1.0 - dropout_ratio), 0.0) + + mean = tl.load(Mean + row * Heads + off_heads) + rstd = tl.load(Rstd + row * Heads + 
off_heads) + + # Compute dx + xhat = (x - mean[:, None]) * rstd[:, None] + w = tl.load(W + off_heads, mask=mask_h).to(tl.float32) + b = tl.load(B + off_heads, mask=mask_h).to(tl.float32) + u = tl.load(U + offsets, mask=mask, other=0).to(tl.float32) + ln = xhat * w[:, None] + b[:, None] + du += dy * ln + if SILU_U: + # pyre-ignore[16] + sig_u = fast_dividef(1.0, 1.0 + tl.exp(-u)) + du = du * (sig_u + u * sig_u * (1.0 - sig_u)) + u = u * sig_u + tl.store(DU + offsets, du.to(DU.dtype.element_ty), mask=mask) + dy = dy * u + wdy = w[:, None] * dy + if COMPUTE_Y: + Y += row.to(tl.int64) * stride_y + y = ln * u + if TRAINING: + if CONCAT_UX: + u = tl.where( + du_keep, # pyre-ignore [61] + u / (1.0 - dropout_ratio), + 0.0, + ) + x = tl.where( + dx_keep, # pyre-ignore [61] + x / (1.0 - dropout_ratio), + 0.0, + ) + y = tl.where( + dy_keep, # pyre-ignore [61] + y / (1.0 - dropout_ratio), + 0.0, + ) + else: + y = tl.where( + dy_keep, # pyre-ignore [61] + y / (1.0 - dropout_ratio), + 0.0, + ) + if CONCAT_UX: + tl.store(Y + offsets, u.to(Y.dtype.element_ty), mask=mask) + tl.store(Y + Heads * D + offsets, x.to(Y.dtype.element_ty), mask=mask) + tl.store(Y + 2 * Heads * D + offsets, y.to(Y.dtype.element_ty), mask=mask) + else: + tl.store(Y + offsets, y.to(Y.dtype.element_ty), mask=mask) + + xhat = tl.where(mask, xhat, 0.0) + wdy = tl.where(mask, wdy, 0.0) + c1 = tl.sum(xhat * wdy, axis=1) / D + c2 = tl.sum(wdy, axis=1) / D + dx += (wdy - (xhat * c1[:, None] + c2[:, None])) * rstd[:, None] + # Write dx + tl.store(DX + offsets, dx, mask=mask) + + # Offset locks and weights/biases gradient pointer for parallel reduction + lock_id = row % GROUP_N + DW = DW + lock_id * Heads + off_heads + DB = DB + lock_id * Heads + off_heads + # Accumulate partial sums for dw/db + partial_dw = tl.sum(dy * xhat, axis=1) + partial_dw = tl.ravel(partial_dw) + partial_db = tl.sum(dy, axis=1) + partial_db = tl.ravel(partial_db) + tl.atomic_add( + DW, + partial_dw, + mask=mask_h, + sem="relaxed", + ) + tl.atomic_add( + DB, + partial_db, + mask=mask_h, + sem="relaxed", + ) + + +def triton_group_norm_mul_dropout_fwd( + x: torch.Tensor, + u: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + dropout_ratio: float, + training: bool, + silu_u: bool = False, + concat_ux: bool = False, + num_heads: int = 1, + linear_dim: int = -1, + seed: Optional[int] = None, +) -> Tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, int +]: # y, mean, rstd, BLOCK_D, BLOCK_H, num_warps, seed + assert x.dim() == 2 + assert x.shape == u.shape + assert x.shape[1] == num_heads * linear_dim + x = switch_to_contiguous_if_needed(x) + u = switch_to_contiguous_if_needed(u) + N, _ = x.shape + assert weight.dim() == 1 + assert bias.dim() == 1 + assert weight.numel() == num_heads + assert bias.numel() == num_heads + + if concat_ux: + y = torch.empty((N, 3 * num_heads * linear_dim), dtype=x.dtype, device=x.device) + else: + y = torch.empty((N, num_heads * linear_dim), dtype=x.dtype, device=x.device) + mean = torch.empty((N * num_heads,), dtype=torch.float32, device=x.device) + rstd = torch.empty((N * num_heads,), dtype=torch.float32, device=x.device) + if N == 0: + return y, mean, rstd, 0, 0, 0, 0 + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_D: int = triton.next_power_of_2(linear_dim) + BLOCK_H: int = triton.next_power_of_2(num_heads) + if BLOCK_D * BLOCK_H > MAX_FUSED_SIZE: + raise RuntimeError( + "This group norm doesn't support num_heads * linear_dim >= 64KB." 
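+            # (BLOCK_D and BLOCK_H are padded up to powers of two; their product must
+            # fit within MAX_FUSED_SIZE = 65536 // x.element_size() elements.)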
+ ) + + if seed is None: + seed = torch.randint(low=0, high=2**62, size=(1,), dtype=torch.int64).item() + num_warps: int = min(max(BLOCK_D * BLOCK_H // 256, 1), 8) + # pyre-ignore[28] + _group_norm_mul_dropout_fwd[(N,)]( + x, + u, + y, + weight, + bias, + mean, + rstd, + linear_dim, + num_heads, + eps, + seed, + dropout_ratio, + x.stride(0), + u.stride(0), + y.stride(0), + SILU_U=silu_u, + BLOCK_D=BLOCK_D, + BLOCK_H=BLOCK_H, + TRAINING=training, + CONCAT_UX=concat_ux, + num_warps=num_warps, + ) + return y, mean, rstd, BLOCK_D, BLOCK_H, num_warps, seed # pyre-ignore [7] + + +def triton_group_norm_mul_dropout_bwd( + dy: torch.Tensor, + x: torch.Tensor, + u: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + mean: torch.Tensor, + rstd: torch.Tensor, + BLOCK_D: int, + BLOCK_H: int, + num_warps: int, + eps: float, + training: bool, + dropout_ratio: float, + seed: Optional[int] = None, + silu_u: bool = False, + concat_ux: bool = False, + num_heads: int = 1, + linear_dim: int = -1, + compute_y: bool = False, +) -> Tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor] +]: + y = None + N, dim = x.shape + if compute_y: + if concat_ux: + y = torch.empty( + (N, 3 * num_heads * linear_dim), dtype=x.dtype, device=x.device + ) + else: + y = torch.empty((N, num_heads * linear_dim), dtype=x.dtype, device=x.device) + if N == 0: + return ( + torch.zeros_like(x), + torch.zeros_like(u), + torch.zeros_like(weight), + torch.zeros_like(bias), + y, + ) + dx = torch.empty_like(x) + du = torch.empty_like(u) + if dim <= 1024: + GROUP_N = 256 * 8 + elif dim <= 4096: + GROUP_N = 128 * 8 + elif dim <= 8192: + GROUP_N = 96 * 8 + else: + GROUP_N = 64 * 8 + GROUP_N = N if GROUP_N > N else GROUP_N + _dweight = torch.zeros((GROUP_N, num_heads), dtype=torch.float32, device=x.device) + _dbias = torch.zeros((GROUP_N, num_heads), dtype=torch.float32, device=x.device) + dweight = torch.empty((num_heads,), dtype=weight.dtype, device=x.device) + dbias = torch.empty((num_heads,), dtype=weight.dtype, device=x.device) + # pyre-ignore[28] + _group_norm_mul_dropout_bwd_dx_du[(N,)]( + dx, + du, + dy, + _dweight, + _dbias, + x, + u, + y, + weight, + bias, + mean, + rstd, + dx.stride(0), + du.stride(0), + dy.stride(0), + x.stride(0), + u.stride(0), + y.stride(0) if compute_y else 0, # pyre-ignore [16] + linear_dim, + num_heads, + eps, + seed, + dropout_ratio, + SILU_U=silu_u, + GROUP_N=GROUP_N, + BLOCK_D=BLOCK_D, + BLOCK_H=BLOCK_H, + TRAINING=training, + CONCAT_UX=concat_ux, + COMPUTE_Y=compute_y, + num_warps=num_warps, + ) + _group_norm_bwd_dwdb[(num_heads,)]( + _dweight, + _dbias, + dweight, + dbias, + GROUP_N, + ) + return dx, du, dweight, dbias, y + + +def _get_bwd_dwdb_configs() -> List[triton.Config]: + configs = [] + for BLOCK_N in [32, 64, 128, 256]: + for num_warps in [8, 16] + ([] if torch.ops.hip else [32]): + configs.append( + triton.Config( + {"BLOCK_N": BLOCK_N}, + num_warps=num_warps, + ) + ) + return configs + + +@triton_autotune( + configs=_get_bwd_dwdb_configs(), + key=[], +) +@triton.jit +def _group_norm_bwd_dwdb( + DW, + DB, + FINAL_DW, + FINAL_DB, + N, + BLOCK_N: tl.constexpr, +): + col = tl.program_id(0) + num_heads = tl.num_programs(0) + dw = tl.zeros((BLOCK_N,), dtype=tl.float32) + db = tl.zeros((BLOCK_N,), dtype=tl.float32) + + for i in range(0, N, BLOCK_N): + rows = i + tl.arange(0, BLOCK_N) + mask = rows < N + offs = rows * num_heads + col + dw += tl.load(DW + offs, mask=mask, other=0.0) + db += tl.load(DB + offs, mask=mask, other=0.0) + + sum_dw = tl.sum(dw, axis=0) + 
sum_db = tl.sum(db, axis=0) + tl.store(FINAL_DW + col, sum_dw.to(FINAL_DW.dtype.element_ty)) + tl.store(FINAL_DB + col, sum_db.to(FINAL_DB.dtype.element_ty)) + + +class GroupNormMulDropoutFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + x: torch.Tensor, + u: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + dropout_ratio: float, + training: bool, + concat_ux: bool = False, + num_heads: int = 1, + linear_dim: int = -1, + seed: Optional[int] = None, + ) -> torch.Tensor: + y, mean, rstd, BLOCK_D, BLOCK_H, num_warps, seed = ( + triton_group_norm_mul_dropout_fwd( + x=x, + u=u, + weight=weight, + bias=bias, + eps=eps, + dropout_ratio=dropout_ratio, + training=training, + concat_ux=concat_ux, + num_heads=num_heads, + linear_dim=linear_dim, + seed=seed, + ) + ) + ctx.save_for_backward(x, u, weight, bias, mean, rstd) + ctx.BLOCK_D = BLOCK_D + ctx.BLOCK_H = BLOCK_H + ctx.num_warps = num_warps + ctx.eps = eps + ctx.seed = seed + ctx.training = training + ctx.concat_ux = concat_ux + ctx.dropout_ratio = dropout_ratio + ctx.num_heads = num_heads + ctx.linear_dim = linear_dim + return y + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, dy: torch.Tensor + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + None, + None, + None, + None, + None, + None, + None, + ]: + x, u, weight, bias, mean, rstd = ctx.saved_tensors + dx, du, dweight, dbias, _ = triton_group_norm_mul_dropout_bwd( + dy=dy, + x=x, + u=u, + weight=weight, + bias=bias, + mean=mean, + rstd=rstd, + BLOCK_D=ctx.BLOCK_D, + BLOCK_H=ctx.BLOCK_H, + num_warps=ctx.num_warps, + eps=ctx.eps, + training=ctx.training, + dropout_ratio=ctx.dropout_ratio, + seed=ctx.seed, + concat_ux=ctx.concat_ux, + num_heads=ctx.num_heads, + linear_dim=ctx.linear_dim, + compute_y=False, + ) + return ( + dx, + du, + dweight, + dbias, + None, + None, + None, + None, + None, + None, + None, + ) + + +class HSTUComputeOutputFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + attn: torch.Tensor, + u: torch.Tensor, + x: torch.Tensor, + norm_weight: torch.Tensor, + norm_bias: torch.Tensor, + output_weight: torch.Tensor, + eps: float, + dropout_ratio: float, + training: bool, + silu_u: bool = False, + concat_ux: bool = False, + group_norm: bool = False, + num_heads: int = 1, + linear_dim: int = -1, + seed: Optional[int] = None, + recompute_y_in_backward: bool = False, + ) -> torch.Tensor: + if dropout_ratio == 0.0: + training = False + + if group_norm: + y, mean, rstd, BLOCK_D, BLOCK_H, num_warps, seed = ( + triton_group_norm_mul_dropout_fwd( + x=attn, + u=u, + weight=norm_weight, + bias=norm_bias, + eps=eps, + dropout_ratio=dropout_ratio, + training=training, + silu_u=silu_u, + concat_ux=concat_ux, + num_heads=num_heads, + linear_dim=linear_dim, + seed=seed, + ) + ) + ctx.BLOCK_H = BLOCK_H + else: + y, mean, rstd, BLOCK_D, num_warps, seed = triton_layer_norm_mul_dropout_fwd( + x=attn, + u=u, + weight=norm_weight, + bias=norm_bias, + eps=eps, + dropout_ratio=dropout_ratio, + training=training, + silu_u=silu_u, + concat_ux=concat_ux, + seed=seed, + ) + + out = maybe_triton_addmm_fwd(x=y, w=output_weight, y=x) + + saved_tensors = [attn, u, norm_weight, norm_bias, mean, rstd, output_weight] + if not recompute_y_in_backward: + saved_tensors.append(y) + ctx.save_for_backward(*saved_tensors) + ctx.BLOCK_D = BLOCK_D + ctx.num_warps = num_warps + ctx.eps = eps + ctx.seed = seed + ctx.training = training + ctx.concat_ux = concat_ux + 
ctx.dropout_ratio = dropout_ratio + ctx.num_heads = num_heads + ctx.linear_dim = linear_dim + ctx.group_norm = group_norm + ctx.recompute_y_in_backward = recompute_y_in_backward + ctx.silu_u = silu_u + return out + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, dout: torch.Tensor + ) -> Tuple[ + torch.Tensor, # dattn + torch.Tensor, # du + torch.Tensor, # dx + torch.Tensor, # d_norm_weight + torch.Tensor, # d_norm_bias + torch.Tensor, # d_output_weight + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ]: + attn, u, norm_weight, norm_bias, mean, rstd, output_weight = ctx.saved_tensors[ + :7 + ] + dy = torch.mm(dout, output_weight.t()) + + if ctx.group_norm: + dattn, du, d_norm_weight, d_norm_bias, y = ( + triton_group_norm_mul_dropout_bwd( + dy=dy, + x=attn, + u=u, + weight=norm_weight, + bias=norm_bias, + mean=mean, + rstd=rstd, + BLOCK_D=ctx.BLOCK_D, + BLOCK_H=ctx.BLOCK_H, + num_warps=ctx.num_warps, + eps=ctx.eps, + training=ctx.training, + dropout_ratio=ctx.dropout_ratio, + seed=ctx.seed, + silu_u=ctx.silu_u, + concat_ux=ctx.concat_ux, + num_heads=ctx.num_heads, + linear_dim=ctx.linear_dim, + compute_y=ctx.recompute_y_in_backward, + ) + ) + else: + dattn, du, d_norm_weight, d_norm_bias, y = ( + triton_layer_norm_mul_dropout_bwd( + dy=dy, + x=attn, + u=u, + weight=norm_weight, + bias=norm_bias, + mean=mean, + rstd=rstd, + BLOCK_D=ctx.BLOCK_D, + num_warps=ctx.num_warps, + eps=ctx.eps, + training=ctx.training, + dropout_ratio=ctx.dropout_ratio, + seed=ctx.seed, + silu_u=ctx.silu_u, + concat_ux=ctx.concat_ux, + compute_y=ctx.recompute_y_in_backward, + ) + ) + if not ctx.recompute_y_in_backward: + y = ctx.saved_tensors[7] + d_output_weight = torch.mm(y.t(), dout) + return ( + dattn, + du, + dout, + d_norm_weight, + d_norm_bias, + d_output_weight, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +@triton.jit +def _helion_ln_mul_dropout_fwd( + x, + weight, + bias, + u, + y, + mean, + rstd, + eps, + seed, + dropout_ratio, + D: tl.constexpr, + stride_x: tl.constexpr, + stride_u: tl.constexpr, + stride_y: tl.constexpr, + _RDIM_SIZE_1: tl.constexpr, + CONCAT_UX: tl.constexpr, + SILU_U: tl.constexpr, + TRAINING: tl.constexpr, +): + row = tl.program_id(0) + x += row.to(tl.int64) * stride_x + u += row.to(tl.int64) * stride_u + y += row.to(tl.int64) * stride_y + cols = tl.arange(0, _RDIM_SIZE_1) + mask = cols < D + + # Load input + x_val = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32) + + # Precompute inverse of D for faster computation + inv_D = 1.0 / D + + # Compute mean + mean_val = tl.sum(x_val, axis=0) * inv_D + + # Center the data + x_mean = tl.where(mask, x_val - mean_val, 0.0) + + # Compute variance + var = tl.sum(x_mean * x_mean, axis=0) * inv_D + + # Compute reciprocal standard deviation + # pyre-fixme[16] + rstd_val = libdevice.rsqrt(var + eps) + + # Normalize + y_norm = x_mean * rstd_val + + # Apply weight and bias + w = tl.load(weight + cols, mask=mask, other=0.0).to(tl.float32) + b = tl.load(bias + cols, mask=mask, other=0.0).to(tl.float32) + y_ln = y_norm * w + b + + # Load u and optionally apply SiLU activation + u_val = tl.load(u + cols, mask=mask, other=0.0).to(tl.float32) + if SILU_U: + # pyre-fixme[16] + u_processed = fast_dividef(u_val, 1.0 + tl.exp(-u_val)) + else: + u_processed = u_val + + y_out = y_ln * u_processed + + if TRAINING: + # Compute dropout scale + # pyre-fixme[16] + dropout_scale = fast_dividef(1.0, 1.0 - dropout_ratio) + + if CONCAT_UX: + # Generate dropout masks + 
random_offsets = 3 * row * _RDIM_SIZE_1 + cols + random_u, random_x, random_y = rand3x(seed, random_offsets) + + u_keep = random_u > dropout_ratio + x_keep = random_x > dropout_ratio + y_keep = random_y > dropout_ratio + + # Apply dropout to u, x, y + u_output = tl.where(u_keep, u_processed * dropout_scale, 0.0) + x_output = tl.where(x_keep, x_val * dropout_scale, 0.0) + y_output = tl.where(y_keep, y_out * dropout_scale, 0.0) + else: + # Generate dropout mask for y + random_offsets = row * _RDIM_SIZE_1 + cols + random_y = tl.rand(seed, random_offsets) + y_keep = random_y > dropout_ratio + + # Apply dropout to y + y_output = tl.where(y_keep, y_out * dropout_scale, 0.0) + else: + if CONCAT_UX: + u_output = u_processed + x_output = x_val + y_output = y_out + + # Store outputs + if CONCAT_UX: + tl.store(y + cols, u_output.to(y.dtype.element_ty), mask=mask) + tl.store(y + D + cols, x_output.to(y.dtype.element_ty), mask=mask) + tl.store(y + 2 * D + cols, y_output.to(y.dtype.element_ty), mask=mask) + else: + tl.store(y + cols, y_output.to(y.dtype.element_ty), mask=mask) + + # Store mean and rstd + tl.store(mean + row, mean_val) + tl.store(rstd + row, rstd_val) + + +def helion_layer_norm_mul_dropout_fwd( + x: torch.Tensor, + u: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + dropout_ratio: float, + training: bool, + silu_u: bool = False, + concat_ux: bool = False, + seed: Optional[int] = None, +) -> Tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, int, int, int +]: # y, mean, rstd, BLOCK_D, num_warps, seed + N, D = x.shape + + if seed is None: + seed = torch.randint(low=0, high=2**62, size=(1,), dtype=torch.int64).item() + + if concat_ux: + y = torch.empty([N, 3 * D], dtype=x.dtype, device=x.device) + else: + y = torch.empty([N, D], dtype=x.dtype, device=x.device) + mean = torch.empty([N], dtype=torch.float32, device=x.device) + rstd = torch.empty([N], dtype=torch.float32, device=x.device) + + BLOCK_D = triton.next_power_of_2(D) + # pyre-ignore[28] + _helion_ln_mul_dropout_fwd[(N,)]( + x, + weight, + bias, + u, + y, + mean, + rstd, + eps, + seed, + dropout_ratio, + D, + x.stride(0), + u.stride(0), + y.stride(0), + BLOCK_D, + CONCAT_UX=concat_ux, + SILU_U=silu_u, + TRAINING=training, + num_warps=1, + ) + + return y, mean, rstd, BLOCK_D, 1, seed # pyre-ignore [7] + + +@torch.fx.wrap +def triton_norm_mul_dropout( + x: torch.Tensor, + u: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + dropout_ratio: float, + training: bool, + concat_ux: bool = False, + group_norm: bool = False, + num_heads: int = 1, + linear_dim: int = -1, + seed: Optional[int] = None, +) -> torch.Tensor: + if group_norm: + return GroupNormMulDropoutFunction.apply( + x, + u, + weight, + bias, + eps, + dropout_ratio, + training, + concat_ux, + num_heads, + linear_dim, + seed, + ) + else: + return LayerNormMulDropoutFunction.apply( + x, u, weight, bias, eps, dropout_ratio, training, concat_ux, seed + ) + + +@torch.fx.wrap +def triton_hstu_compute_output( + attn: torch.Tensor, + u: torch.Tensor, + x: torch.Tensor, + norm_weight: torch.Tensor, + norm_bias: torch.Tensor, + output_weight: torch.Tensor, + eps: float, + dropout_ratio: float, + training: bool, + silu_u: bool = False, + concat_ux: bool = False, + group_norm: bool = False, + num_heads: int = 1, + linear_dim: int = -1, + seed: Optional[int] = None, + recompute_y_in_backward: bool = False, +) -> torch.Tensor: + return HSTUComputeOutputFunction.apply( + attn, + u, + x, + norm_weight, + norm_bias, + output_weight, + eps, + 
dropout_ratio, + training, + silu_u, + concat_ux, + group_norm, + num_heads, + linear_dim, + seed, + recompute_y_in_backward, + ) diff --git a/recommendation/dlrm_v3/generative_recommenders/ops/triton/triton_hstu_preprocess_and_attention.py b/recommendation/dlrm_v3/generative_recommenders/ops/triton/triton_hstu_preprocess_and_attention.py new file mode 100644 index 0000000000..85e60db3c7 --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/ops/triton/triton_hstu_preprocess_and_attention.py @@ -0,0 +1,338 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +from typing import Optional, Tuple + +import torch +from generative_recommenders.ops.triton.triton_addmm import ( + maybe_triton_addmm_fwd, + triton_addmm_bwd, + triton_addmm_fwd, +) +from generative_recommenders.ops.triton.triton_hstu_attention import ( + triton_hstu_attention_bwd, + triton_hstu_attention_fwd, +) +from generative_recommenders.ops.triton.triton_layer_norm import ( + triton_weighted_layer_norm_bwd, + triton_weighted_layer_norm_fwd, +) +from torch.nn import functional as F + + +class _HSTUPreprocessAndAttentionFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore [14] + def forward( + ctx, # pyre-ignore [2] + x: torch.Tensor, + norm_weight: torch.Tensor, + norm_bias: torch.Tensor, + norm_eps: float, + num_heads: int, + attn_dim: int, + hidden_dim: int, + uvqk_weight: torch.Tensor, + uvqk_bias: torch.Tensor, + max_seq_len: int, + seq_offsets: torch.Tensor, + attn_alpha: float, + num_targets: Optional[torch.Tensor], + max_attn_len: int, + contextual_seq_len: int, + recompute_uvqk_in_backward: bool, + recompute_normed_x_in_backward: bool, + sort_by_length: bool, + enable_tma: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + normed_x, x_mean, x_rstd, BLOCK_D = triton_weighted_layer_norm_fwd( + x=x, + weight=norm_weight, + bias=norm_bias, + eps=norm_eps, + ) + uvqk = maybe_triton_addmm_fwd( + x=normed_x, w=uvqk_weight, y=uvqk_bias + ).contiguous() + u, v, q, k = uvqk.split( + [ + hidden_dim * num_heads, + hidden_dim * num_heads, + attn_dim * num_heads, + attn_dim * num_heads, + ], + dim=1, + ) + q = q.view(-1, num_heads, attn_dim) + k = k.view(-1, num_heads, attn_dim) + v = v.view(-1, num_heads, hidden_dim) + silu_u = F.silu(u) + sort_by_length_indices = None + if sort_by_length: + seq_lengths = seq_offsets[1:] - seq_offsets[:-1] + _, sort_by_length_indices = torch.sort( + seq_lengths, descending=True, stable=False + ) + out = triton_hstu_attention_fwd( + N=max_seq_len, + alpha=attn_alpha, + q=q, + k=k, + v=v, + seq_offsets=seq_offsets, + num_targets=num_targets, + max_attn_len=max_attn_len, + contextual_seq_len=contextual_seq_len, + sort_by_length_indices=sort_by_length_indices, + enable_tma=enable_tma, + ) + # update ctx + saved_tensors = [ + x, + norm_weight, + norm_bias, + x_mean, + x_rstd, + uvqk_weight, + seq_offsets, + ] + if num_targets is not None: + saved_tensors.append(num_targets) + if not 
recompute_normed_x_in_backward: + saved_tensors.append(normed_x) + if recompute_uvqk_in_backward: + saved_tensors.append(uvqk_bias) + else: + saved_tensors.append(uvqk) + if sort_by_length: + saved_tensors.append(sort_by_length_indices) + ctx.save_for_backward(*saved_tensors) + ctx.attn_alpha = attn_alpha + ctx.has_multiple_targets = num_targets is not None + ctx.max_seq_len = max_seq_len + ctx.max_attn_len = max_attn_len + ctx.recompute_normed_x_in_backward = recompute_normed_x_in_backward + ctx.recompute_uvqk_in_backward = recompute_uvqk_in_backward + ctx.hidden_dim = hidden_dim + ctx.attn_dim = attn_dim + ctx.num_heads = num_heads + ctx.uvqk_bias_1d = uvqk_bias.dim() == 1 + ctx.norm_eps = norm_eps + ctx.norm_BLOCK_D = BLOCK_D + ctx.contextual_seq_len = contextual_seq_len + ctx.sort_by_length = sort_by_length + ctx.enable_tma = enable_tma + return silu_u, out + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, # pyre-ignore[2] + dsilu_u: torch.Tensor, + dout: torch.Tensor, + ) -> Tuple[ + torch.Tensor, # d_x + torch.Tensor, # d_norm_weight + torch.Tensor, # d_norm_bias + None, + None, + None, + None, + torch.Tensor, # d_uvqk_weight + torch.Tensor, # d_uvqk_bias + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ]: + x, norm_weight, norm_bias, x_mean, x_rstd, uvqk_weight, seq_offsets = ( + ctx.saved_tensors[:7] + ) + idx = 7 + if ctx.has_multiple_targets: + num_targets = ctx.saved_tensors[idx] + idx += 1 + else: + num_targets = None + if ctx.recompute_normed_x_in_backward: + normed_x, _, _, _ = triton_weighted_layer_norm_fwd( + x=x, + weight=norm_weight, + bias=norm_bias, + eps=ctx.norm_eps, + mean=x_mean, + rstd=x_rstd, + ) + else: + normed_x = ctx.saved_tensors[idx] + idx += 1 + if ctx.recompute_uvqk_in_backward: + uvqk_bias = ctx.saved_tensors[idx] + uvqk = maybe_triton_addmm_fwd(x=normed_x, w=uvqk_weight, y=uvqk_bias) + idx += 1 + else: + uvqk = ctx.saved_tensors[idx] + idx += 1 + if ctx.sort_by_length: + sort_by_length_indices = ctx.saved_tensors[idx] + else: + sort_by_length_indices = None + + duvqk = torch.empty_like(uvqk) + du, dv, dq, dk = duvqk.split( + [ + ctx.hidden_dim * ctx.num_heads, + ctx.hidden_dim * ctx.num_heads, + ctx.attn_dim * ctx.num_heads, + ctx.attn_dim * ctx.num_heads, + ], + dim=1, + ) + u, v, q, k = uvqk.split( + [ + ctx.hidden_dim * ctx.num_heads, + ctx.hidden_dim * ctx.num_heads, + ctx.attn_dim * ctx.num_heads, + ctx.attn_dim * ctx.num_heads, + ], + dim=1, + ) + q = q.view(-1, ctx.num_heads, ctx.attn_dim) + k = k.view(-1, ctx.num_heads, ctx.attn_dim) + v = v.view(-1, ctx.num_heads, ctx.hidden_dim) + dq = dq.view(-1, ctx.num_heads, ctx.attn_dim) + dk = dk.view(-1, ctx.num_heads, ctx.attn_dim) + dv = dv.view(-1, ctx.num_heads, ctx.hidden_dim) + # Note: the two operations below update duvqk in place + ( + _dq, + _dk, + _dv, + ) = triton_hstu_attention_bwd( + dout=dout, + q=q, + k=k, + v=v, + dq=dq, + dk=dk, + dv=dv, + seq_offsets=seq_offsets, + num_targets=num_targets, + N=ctx.max_seq_len, + max_attn_len=ctx.max_attn_len, + alpha=ctx.attn_alpha, + contextual_seq_len=ctx.contextual_seq_len, + sort_by_length_indices=sort_by_length_indices, + enable_tma=ctx.enable_tma, + ) + if dq.data_ptr() != _dq.data_ptr(): + dq.copy_(_dq) + if dk.data_ptr() != _dk.data_ptr(): + dk.copy_(_dk) + if dv.data_ptr() != _dv.data_ptr(): + dv.copy_(_dv) + torch.ops.aten.silu_backward(dsilu_u, u, grad_input=du) + d_normed_x, d_uvqk_weight, d_uvqk_bias = triton_addmm_bwd( + x=normed_x, + w=uvqk_weight, + dz=duvqk, + is_y_1d=ctx.uvqk_bias_1d, + ) + 
d_x, d_norm_weight, d_norm_bias = triton_weighted_layer_norm_bwd( + dy=d_normed_x, + x=x, + weight=norm_weight, + bias=norm_bias, + mean=x_mean, + rstd=x_rstd, + learnable=True, + eps=ctx.norm_eps, + BLOCK_D=ctx.norm_BLOCK_D, + ) + # pyre-ignore[7] + return ( + d_x, + d_norm_weight, + d_norm_bias, + None, + None, + None, + None, + d_uvqk_weight, + d_uvqk_bias, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +def triton_hstu_preprocess_and_attention( + x: torch.Tensor, + norm_weight: torch.Tensor, + norm_bias: torch.Tensor, + norm_eps: float, + num_heads: int, + attn_dim: int, + hidden_dim: int, + uvqk_weight: torch.Tensor, + uvqk_bias: torch.Tensor, + max_seq_len: int, + seq_offsets: torch.Tensor, + attn_alpha: float, + num_targets: Optional[torch.Tensor], + max_attn_len: int = 0, + contextual_seq_len: int = 0, + recompute_uvqk_in_backward: bool = False, + recompute_normed_x_in_backward: bool = False, + sort_by_length: bool = False, + enable_tma: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _HSTUPreprocessAndAttentionFunction.apply( + x, + norm_weight, + norm_bias, + norm_eps, + num_heads, + attn_dim, + hidden_dim, + uvqk_weight, + uvqk_bias, + max_seq_len, + seq_offsets, + attn_alpha, + num_targets, + max_attn_len, + contextual_seq_len, + recompute_uvqk_in_backward, + recompute_normed_x_in_backward, + sort_by_length, + enable_tma, + ) diff --git a/recommendation/dlrm_v3/generative_recommenders/ops/triton/triton_jagged.py b/recommendation/dlrm_v3/generative_recommenders/ops/triton/triton_jagged.py new file mode 100644 index 0000000000..46884a63d0 --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/ops/triton/triton_jagged.py @@ -0,0 +1,2209 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
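+
+# Layout note: the jagged ops in this file take a 2D values tensor of shape
+# (sum_B(N_i), D) together with a `seq_offsets` tensor of B + 1 entries, so
+# sequence b occupies rows seq_offsets[b]:seq_offsets[b + 1]. For example,
+# seq_offsets = [0, 3, 5] describes two sequences of lengths 3 and 2.
+#
+# A minimal eager-mode sketch of the jagged x dense bmm computed below (names
+# are illustrative only, not part of this module):
+#
+#     def reference_jagged_dense_bmm(values, seq_offsets, dense):
+#         # values: (sum_B(N_i), K), dense: (B, K, N) -> out: (sum_B(N_i), N)
+#         outs = []
+#         for b in range(dense.shape[0]):
+#             rows = values[seq_offsets[b]:seq_offsets[b + 1]]
+#             outs.append(rows @ dense[b])
+#         return torch.cat(outs, dim=0)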
+ +#!/usr/bin/env python3 + + +from typing import List, Optional, Tuple + +import torch + +# @manual=//triton:triton +import triton + +# @manual=//triton:triton +import triton.language as tl + +from generative_recommenders.common import ( + autotune_max_seq_len, + fine_grained_autotune_max_seq_len, + switch_to_contiguous_if_needed, + triton_autotune, +) +from generative_recommenders.ops.utils import is_sm100 +from torch._inductor.runtime import triton_helpers + + +def _triton_concat_2D_jagged_internal( + values_a: torch.Tensor, + values_b: torch.Tensor, + values_out: torch.Tensor, + max_seq_len: int, + B: int, + offsets_a: Optional[torch.Tensor], + offsets_b: Optional[torch.Tensor], + D: int, + dense_size: int, + stride_dense_batch: int, + n_prefix: int, + is_dense_a: bool, + is_dense_b: bool, + is_replace: bool, + BLOCK_D: int, +) -> None: + if n_prefix != 0: + if is_sm100(): + + def grid(meta): + return (triton.cdiv(max_seq_len, meta["BLOCK_N"]), B) + + concat_2D_jagged_jagged_w_prefix_multirow[grid]( + OffsetsA=offsets_a, + ValuesA=values_a, + OffsetsB=offsets_b, + ValuesB=values_b, + Out=values_out, + D=D, + stride_ad=values_a.stride(-2), + stride_bd=values_b.stride(-2), + stride_od=values_out.stride(0), + n_prefix_from_B=n_prefix, + BLOCK_D=BLOCK_D, + ) + else: + concat_2D_jagged_jagged_w_prefix[(max_seq_len, B)]( + OffsetsA=offsets_a, + ValuesA=values_a, + OffsetsB=offsets_b, + ValuesB=values_b, + Out=values_out, + D=D, + stride_ad=values_a.stride(-2), + stride_bd=values_b.stride(-2), + stride_od=values_out.stride(0), + n_prefix_from_B=n_prefix, + BLOCK_D=BLOCK_D, # pyre-ignore[6] + ) + else: + if is_sm100(): + + def grid(meta): + return (triton.cdiv(max_seq_len, meta["BLOCK_N"]), B) + + concat_2D_jagged_multirow[grid]( + OffsetsA=offsets_a, + ValuesA=values_a, + OffsetsB=offsets_b, + ValuesB=values_b, + DenseSize=dense_size, + Out=values_out, + D=D, + stride_ad=values_a.stride(-2), + stride_bd=values_b.stride(-2), + stride_dense_batch=stride_dense_batch, + stride_od=values_out.stride(0), + IS_DENSE_A=is_dense_a, # pyre-ignore[6] + IS_DENSE_B=is_dense_b, # pyre-ignore[6] + BLOCK_D=BLOCK_D, # pyre-ignore[6] + IS_REPLACE=is_replace, # pyre-ignore[6] + ) + else: + concat_2D_jagged[(max_seq_len, B)]( + OffsetsA=offsets_a, + ValuesA=values_a, + OffsetsB=offsets_b, + ValuesB=values_b, + DenseSize=dense_size, + Out=values_out, + D=D, + stride_ad=values_a.stride(-2), + stride_bd=values_b.stride(-2), + stride_dense_batch=stride_dense_batch, + stride_od=values_out.stride(0), + IS_DENSE_A=is_dense_a, # pyre-ignore[6] + IS_DENSE_B=is_dense_b, # pyre-ignore[6] + BLOCK_D=BLOCK_D, # pyre-ignore[6] + IS_REPLACE=is_replace, # pyre-ignore[6] + ) + + +def _get_split_concat_2d_jagged_multirow_configs() -> List[triton.Config]: + configs = [] + for BLOCK_N in [1, 2, 4, 8]: + for num_warps in [1, 2, 4]: + configs.append( + triton.Config( + {"BLOCK_N": BLOCK_N}, + num_warps=num_warps, + ) + ) + return configs + + +def _get_bmm_configs() -> List[triton.Config]: + configs = [] + for BLOCK_M in [64, 128]: + for BLOCK_N in [64, 128, 256]: + for BLOCK_K in [32, 64]: + for num_stages in [3, 5]: + for num_warps in [4, 8]: + configs.append( + triton.Config( + { + "BLOCK_M": BLOCK_M, + "BLOCK_N": BLOCK_N, + "BLOCK_K": BLOCK_K, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + ) + return configs + + +@triton_autotune( + configs=_get_bmm_configs(), + key=["AUTOTUNE_MAX_SEQ_LEN", "N", "K", "ELEMENTWISE", "HAS_BIAS"], +) +@triton.jit +def jagged_dense_bmm_broadcast_add_kernel( + seq_offsets, + Jagged, + Dense, 
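+    # Bias is a per-batch bias of shape (B, N) when HAS_BIAS and not ELEMENTWISE,
+    # or a per-row jagged bias of shape (sum_B(M_i), N) when ELEMENTWISE is set.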
+ Bias, + Out, + AUTOTUNE_MAX_SEQ_LEN, + N, + K, + stride_jm, + stride_db, + stride_dk, + stride_dn, + stride_bias_b, + stride_om, + HAS_BIAS: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + ELEMENTWISE: tl.constexpr, +): + """ + Computing bmm Out = Jagged x Dense + Bias + M is the jagged dimension + Jagged has shape (sum_B(M_i), K), Dense has shape (B, K, N), Bias has shape (B, N), and Out has shape (sum_B(M_i), N) + """ + + off_n = tl.program_id(0) + off_m = tl.program_id(1).to(tl.int64) + off_b = tl.program_id(2) + + seq_start = tl.load(seq_offsets + off_b).to(tl.int64) + seq_end = tl.load(seq_offsets + off_b + 1) + seq_len = seq_end - seq_start + start_m = off_m * BLOCK_M + start_n = off_n * BLOCK_N + if start_m >= seq_len: + return + + Jagged += (seq_start + start_m) * stride_jm + Dense += off_b.to(tl.int64) * stride_db + Out += seq_start * stride_om + + offs_m = tl.arange(0, BLOCK_M) + offs_n = start_n + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_K) + jg_ptrs = Jagged + offs_m[:, None] * stride_jm + offs_k[None, :] + dn_ptrs = Dense + offs_k[:, None] * stride_dk + offs_n[None, :] * stride_dn + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, K, BLOCK_K): + jg = tl.load( + jg_ptrs, + # pyre-fixme[16]: `int` has no attribute `__getitem__`. + mask=(offs_m[:, None] < (seq_len - start_m)) & ((k + offs_k)[None, :] < K), + other=0.0, + ) + dn = tl.load( + dn_ptrs, + mask=((k + offs_k)[:, None] < K) and (offs_n[None, :] < N), + other=0.0, + ) + accumulator += tl.dot(jg, dn, allow_tf32=ALLOW_TF32) + jg_ptrs += BLOCK_K + dn_ptrs += BLOCK_K * stride_dk + + if HAS_BIAS: + if ELEMENTWISE: + Bias += (seq_start + start_m) * stride_bias_b + bias_ptrs = Bias + offs_m[:, None] * stride_bias_b + offs_n[None, :] + bias = tl.load( + bias_ptrs, + mask=(offs_m[:, None] < (seq_len - start_m)) & (offs_n[None, :] < N), + other=0.0, + ) + accumulator += bias.to(tl.float32) + else: + bias_ptrs = Bias + off_b.to(tl.int64) * stride_bias_b + offs_n + bias = tl.load(bias_ptrs, mask=offs_n < N) + accumulator += bias[None, :].to(tl.float32) + + out = accumulator.to(Out.dtype.element_ty) + + offs_m = tl.arange(0, BLOCK_M) + offs_n = start_n + tl.arange(0, BLOCK_N) + Out += start_m * stride_om + out_ptrs = Out + offs_m[:, None] * stride_om + offs_n[None, :] + tl.store( + out_ptrs, + out, + mask=(offs_m[:, None] < (seq_len - start_m)) & (offs_n[None, :] < N), + ) + + +def _get_bmm_reduce_sum_configs() -> List[triton.Config]: + configs = [] + for BLOCK_M in [64, 128]: + for BLOCK_N in [64, 128]: + for BLOCK_K in [64, 128]: + for num_stages in [3, 4]: + for num_warps in [4, 8]: + configs.append( + triton.Config( + { + "BLOCK_M": BLOCK_M, + "BLOCK_N": BLOCK_N, + "BLOCK_K": BLOCK_K, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + ) + return configs + + +@triton_autotune( + configs=_get_bmm_reduce_sum_configs(), + key=["M", "N", "AUTOTUNE_MAX_SEQ_LEN"], +) +@triton.jit +def _jagged_jagged_bmm_reduce_sum( + seq_offsets, + JaggedA, + JaggedB, + Out, + ReduceOut, + M, + N, + AUTOTUNE_MAX_SEQ_LEN, + stride_ak, + stride_bk, + stride_ob, + stride_om, + stride_on, + stride_orb, + stride_orn, + REDUCE_JAGGEDB: tl.constexpr, + ALLOW_TF32: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + """ + Computing bmm Out = Jagged x Jagged + K is the jagged dimension + JaggedA has shape (sum_B(K_i), M), JaggedB has shape (sum_B(K_i), N), and Out has shape (B, M, N) + """ + + 
off_m = tl.program_id(0).to(tl.int64) + off_n = tl.program_id(1) + off_b = tl.program_id(2) + + seq_start = tl.load(seq_offsets + off_b).to(tl.int64) + seq_end = tl.load(seq_offsets + off_b + 1) + seq_len = seq_end - seq_start + + start_m = off_m * BLOCK_M + start_n = off_n * BLOCK_N + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + Out += off_b.to(tl.int64) * stride_ob + offs_m = tl.arange(0, BLOCK_M) + offs_n = start_n + tl.arange(0, BLOCK_N) + Out += start_m * stride_om + out_ptrs = Out + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on + if REDUCE_JAGGEDB: + out_reduce_ptrs = ( + ReduceOut + off_b.to(tl.int64) * stride_orb + offs_n * stride_orn + ) + acc_reduce = tl.zeros((BLOCK_N,), dtype=tl.float32) + if seq_len == 0: + out = accumulator.to(Out.dtype.element_ty) + tl.store( + out_ptrs, + out, + mask=(offs_m[:, None] < (M - start_m)) & (offs_n[None, :] < N), + ) + if REDUCE_JAGGEDB: + if off_m == 0: + tl.store( + out_reduce_ptrs, # pyre-ignore [61] + acc_reduce.to(ReduceOut.dtype.element_ty), + mask=(offs_n < N), + ) + return + + JaggedA += seq_start * stride_ak + JaggedB += seq_start * stride_bk + offs_k = tl.arange(0, BLOCK_K) + jg_a_ptrs = JaggedA + offs_k[None, :] * stride_ak + (start_m + offs_m)[:, None] + jg_b_ptrs = JaggedB + offs_k[:, None] * stride_bk + offs_n[None, :] + + for k in range(0, seq_len, BLOCK_K): + jg_a = tl.load( + jg_a_ptrs, + # pyre-fixme[16]: `int` has no attribute `__getitem__`. + mask=(offs_m[:, None] < (M - start_m)) & ((k + offs_k)[None, :] < seq_len), + other=0.0, + ) + jg_b = tl.load( + jg_b_ptrs, + mask=(offs_n[None, :] < N) and ((k + offs_k)[:, None] < seq_len), + other=0.0, + ) + + accumulator += tl.dot(jg_a, jg_b, allow_tf32=ALLOW_TF32) + if REDUCE_JAGGEDB: + if off_m == 0: + acc_reduce += tl.sum(jg_b.to(tl.float32), axis=0) + + jg_a_ptrs += BLOCK_K * stride_ak + jg_b_ptrs += BLOCK_K * stride_bk + + out = accumulator.to(Out.dtype.element_ty) + tl.store( + out_ptrs, + out, + mask=(offs_m[:, None] < (M - start_m)) & (offs_n[None, :] < N), + ) + if REDUCE_JAGGEDB: + if off_m == 0: + tl.store( + out_reduce_ptrs, # pyre-ignore [61] + acc_reduce.to(ReduceOut.dtype.element_ty), + mask=(offs_n < N), + ) + + +class _JaggedDenseBmmFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + max_seq_len: int, + seq_offsets: torch.Tensor, + jagged: torch.Tensor, + dense: torch.Tensor, + ): + jagged = switch_to_contiguous_if_needed(jagged) + L, D = jagged.shape + B, _, K = dense.shape + bmm_out = torch.empty((L, K), dtype=jagged.dtype, device=jagged.device) + + grid = lambda meta: ( # noqa E731 + triton.cdiv(K, meta["BLOCK_N"]), + triton.cdiv(max_seq_len, meta["BLOCK_M"]), + B, + ) + + jagged_dense_bmm_broadcast_add_kernel[grid]( + seq_offsets=seq_offsets, + Jagged=jagged, + Dense=dense, + Bias=0, + Out=bmm_out, + AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(max_seq_len), + N=K, + K=D, + stride_jm=jagged.stride(0), + stride_db=dense.stride(0), + stride_dk=dense.stride(1), + stride_dn=dense.stride(2), + stride_bias_b=0, + stride_om=bmm_out.stride(0), + HAS_BIAS=False, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + ELEMENTWISE=False, + ) + + ctx.save_for_backward(seq_offsets, jagged, dense) + ctx.B = B + ctx.max_seq_len = max_seq_len + ctx.K = K + ctx.D = D + return bmm_out + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, d_bmm_out: torch.Tensor + ) -> Tuple[None, None, torch.Tensor, torch.Tensor]: + seq_offsets, jagged, dense = ctx.saved_tensors + d_jagged = torch.empty_like(jagged) + 
d_dense = torch.empty_like(dense) + + grid = lambda meta: ( # noqa E731 + triton.cdiv(ctx.D, meta["BLOCK_N"]), + triton.cdiv(ctx.max_seq_len, meta["BLOCK_M"]), + ctx.B, + ) + jagged_dense_bmm_broadcast_add_kernel[grid]( + seq_offsets=seq_offsets, + Jagged=d_bmm_out, + Dense=dense, + Bias=None, + Out=d_jagged, + AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(ctx.max_seq_len), + N=ctx.D, + K=ctx.K, + stride_jm=d_bmm_out.stride(0), + stride_db=dense.stride(0), + stride_dk=dense.stride(2), + stride_dn=dense.stride(1), + stride_bias_b=0, + stride_om=d_jagged.stride(0), + HAS_BIAS=False, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + ELEMENTWISE=False, + ) + + grid = lambda meta: ( # noqa E731 + triton.cdiv(ctx.D, meta["BLOCK_M"]), + triton.cdiv(ctx.K, meta["BLOCK_N"]), + ctx.B, + ) + _jagged_jagged_bmm_reduce_sum[grid]( + seq_offsets=seq_offsets, + JaggedA=jagged, + JaggedB=d_bmm_out, + Out=d_dense, + ReduceOut=None, + M=ctx.D, + N=ctx.K, + AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(ctx.max_seq_len), + stride_ak=jagged.stride(0), + stride_bk=d_bmm_out.stride(0), + stride_ob=d_dense.stride(0), + stride_om=d_dense.stride(1), + stride_on=d_dense.stride(2), + stride_orb=0, + stride_orn=0, + REDUCE_JAGGEDB=False, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + ) + + return None, None, d_jagged, d_dense + + +def _get_jagged_dense_broadcast_add_configs() -> List[triton.Config]: + configs = [] + for BLOCK_N in [16, 32, 64]: + for num_stages in [1, 2]: + for num_warps in [2, 4, 8]: + configs.append( + triton.Config( + { + "BLOCK_N": BLOCK_N, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + ) + return configs + + +@triton_autotune( + configs=_get_jagged_dense_broadcast_add_configs(), + key=["AUTOTUNE_MAX_SEQ_LEN"], +) +@triton.jit +def jagged_dense_broadcast_add_kernel( + seq_offsets, + Jagged, + Dense, + Out, + AUTOTUNE_MAX_SEQ_LEN, + D, + stride_jn, + stride_db, + stride_on, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + """ + Computing Out = Jagged + Dense + JaggedA has shape (sum_B(N_i), D), Dense has shape (B, D), and Out has shape (sum_B(N_i), D) + """ + + off_b = tl.program_id(0) + off_n = tl.program_id(1) + seq_start = tl.load(seq_offsets + off_b) + seq_end = tl.load(seq_offsets + off_b + 1) + seq_len = seq_end - seq_start + start_n = off_n * BLOCK_N + if start_n >= seq_len: + return + Jagged += seq_start * stride_jn + Dense += off_b * stride_db + Out += seq_start * stride_on + offs_n = start_n + tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + jagged_ptrs = Jagged + offs_n[:, None] * stride_jn + offs_d[None, :] + dense_ptrs = Dense + offs_d + out_ptrs = Out + offs_n[:, None] * stride_jn + offs_d[None, :] + for d in range(0, D, BLOCK_D): + jg = tl.load( + jagged_ptrs, + # pyre-fixme[16]: `int` has no attribute `__getitem__`. 
+            mask=(offs_n[:, None] < seq_len) & ((d + offs_d)[None, :] < D),
+        )
+        dn = tl.load(dense_ptrs, mask=d + offs_d < D)
+        out = jg + dn[None, :]
+        tl.store(
+            out_ptrs,
+            out,
+            mask=(offs_n[:, None] < seq_len) & ((d + offs_d)[None, :] < D),
+        )
+        dense_ptrs += BLOCK_D
+        jagged_ptrs += BLOCK_D
+        out_ptrs += BLOCK_D
+
+
+@triton.jit
+def jagged_reduce_sum(
+    seq_offsets,
+    Jagged,
+    Out,
+    D,
+    stride_jn,
+    stride_ob,
+    BLOCK_D: tl.constexpr,
+):
+    """
+    Computing Out[b] = sum of the rows of sequence b in Jagged
+    Jagged has shape (sum_B(N_i), D) and Out has shape (B, D)
+    """
+    off_b = tl.program_id(0)
+    off_d = tl.program_id(1) * BLOCK_D
+    seq_start = tl.load(seq_offsets + off_b)
+    seq_end = tl.load(seq_offsets + off_b + 1)
+    seq_len = seq_end - seq_start
+    Jagged += seq_start * stride_jn
+    Out += off_b * stride_ob
+    offs_d = off_d + tl.arange(0, BLOCK_D)
+    jagged_ptrs = Jagged + offs_d
+    out_ptrs = Out + offs_d
+    accumulator = tl.zeros((BLOCK_D,), dtype=tl.float32)
+    for _ in range(0, seq_len):
+        jg = tl.load(
+            jagged_ptrs,
+            mask=offs_d < D,
+        )
+        accumulator += jg
+        jagged_ptrs += stride_jn
+    out = accumulator.to(Out.dtype.element_ty)
+    tl.store(
+        out_ptrs,
+        out,
+        mask=offs_d < D,
+    )
+
+
+class _JaggedDenseBroadcastAddFunction(torch.autograd.Function):
+    @staticmethod
+    # pyre-ignore[14]
+    def forward(
+        ctx,
+        max_seq_len: int,
+        seq_offsets: torch.Tensor,
+        jagged: torch.Tensor,
+        dense: torch.Tensor,
+    ):
+        jagged = switch_to_contiguous_if_needed(jagged)
+        dense = switch_to_contiguous_if_needed(dense)
+        L, D = jagged.shape
+        B, _ = dense.shape
+        out = torch.empty_like(jagged)
+
+        grid = lambda meta: ( # noqa E731
+            B,
+            triton.cdiv(max_seq_len, meta["BLOCK_N"]),
+        )
+        BLOCK_D = triton.next_power_of_2(D) if D < 64 else 64
+        jagged_dense_broadcast_add_kernel[grid](
+            seq_offsets=seq_offsets,
+            Jagged=jagged,
+            Dense=dense,
+            Out=out,
+            AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(max_seq_len),
+            D=D,
+            stride_jn=jagged.stride(0),
+            stride_db=dense.stride(0),
+            stride_on=out.stride(0),
+            BLOCK_D=BLOCK_D,
+        )
+
+        ctx.save_for_backward(seq_offsets)
+        ctx.max_seq_len = max_seq_len
+        ctx.B = B
+        ctx.D = D
+        return out
+
+    @staticmethod
+    # pyre-ignore[14]
+    def backward(
+        ctx, d_out: torch.Tensor
+    ) -> Tuple[None, None, torch.Tensor, torch.Tensor]:
+        seq_offsets = ctx.saved_tensors[0]
+        d_dense = torch.empty((ctx.B, ctx.D), device=d_out.device, dtype=d_out.dtype)
+        BLOCK_D = triton.next_power_of_2(ctx.D) if ctx.D < 64 else 64
+        jagged_reduce_sum[(ctx.B, triton.cdiv(ctx.D, BLOCK_D))](
+            seq_offsets=seq_offsets,
+            Jagged=d_out,
+            Out=d_dense,
+            D=ctx.D,
+            stride_jn=d_out.stride(0),
+            stride_ob=d_dense.stride(0),
+            BLOCK_D=BLOCK_D,
+        )
+        return None, None, d_out, d_dense
+
+
+def triton_jagged_dense_bmm_add_fwd(
+    max_seq_len: int,
+    seq_offsets: torch.Tensor,
+    jagged: torch.Tensor,
+    dense: torch.Tensor,
+    bias: torch.Tensor,
+    elementwise: bool = False,
+) -> Tuple[torch.Tensor, int, int, int]:
+    jagged = switch_to_contiguous_if_needed(jagged)
+    bias = switch_to_contiguous_if_needed(bias)
+    L, K = jagged.shape
+    B, _, N = dense.shape
+    out = torch.empty((L, N), dtype=jagged.dtype, device=jagged.device)
+
+    grid = lambda meta: ( # noqa E731
+        triton.cdiv(N, meta["BLOCK_N"]),
+        triton.cdiv(max_seq_len, meta["BLOCK_M"]),
+        B,
+    )
+
+    jagged_dense_bmm_broadcast_add_kernel[grid](
+        seq_offsets=seq_offsets,
+        Jagged=jagged,
+        Dense=dense,
+        Bias=bias,
+        Out=out,
+        AUTOTUNE_MAX_SEQ_LEN=fine_grained_autotune_max_seq_len(max_seq_len),
+        N=N,
+        K=K,
+        stride_jm=jagged.stride(0),
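+        # dense strides below describe its (B, K, N) layout; Bias is broadcast over each
+        # sequence's rows unless ELEMENTWISE, in which case it is added row-wise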
stride_db=dense.stride(0), + stride_dk=dense.stride(1), + stride_dn=dense.stride(2), + stride_bias_b=bias.stride(0), + stride_om=out.stride(0), + HAS_BIAS=True, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + ELEMENTWISE=elementwise, + ) + + return out, B, K, N + + +def triton_jagged_dense_bmm_add_bwd_jagged( + max_seq_len: int, + seq_offsets: torch.Tensor, + d_jagged: torch.Tensor, + dense: torch.Tensor, + d_out: torch.Tensor, + K: int, + B: int, + N: int, +) -> torch.Tensor: + grid = lambda meta: ( # noqa E731 + triton.cdiv(K, meta["BLOCK_N"]), + triton.cdiv(max_seq_len, meta["BLOCK_M"]), + B, + ) + jagged_dense_bmm_broadcast_add_kernel[grid]( + seq_offsets=seq_offsets, + Jagged=d_out, + Dense=dense, + Bias=None, + Out=d_jagged, + AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(max_seq_len), + N=K, + K=N, + stride_jm=d_out.stride(0), + stride_db=dense.stride(0), + stride_dk=dense.stride(2), + stride_dn=dense.stride(1), + stride_bias_b=0, + stride_om=d_jagged.stride(0), + HAS_BIAS=False, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + ELEMENTWISE=False, + ) + + return d_jagged + + +def triton_jagged_dense_bmm_add_bwd_dense_bias( + max_seq_len: int, + seq_offsets: torch.Tensor, + jagged: torch.Tensor, + d_dense: torch.Tensor, + B: int, + K: int, + N: int, + d_out: torch.Tensor, + elementwise: bool, +) -> Tuple[torch.Tensor, torch.Tensor]: + d_bias = torch.empty((B, N), device=d_out.device, dtype=d_out.dtype) + + grid = lambda meta: ( # noqa E731 + triton.cdiv(K, meta["BLOCK_M"]), + triton.cdiv(N, meta["BLOCK_N"]), + B, + ) + + if elementwise: + d_bias = d_out + reduce_out = None + stride_orb = 0 + stride_orn = 0 + reduce_jaggedb = False + else: + reduce_out = d_bias + stride_orb = d_bias.stride(0) + stride_orn = d_bias.stride(1) + reduce_jaggedb = True + + _jagged_jagged_bmm_reduce_sum[grid]( + seq_offsets=seq_offsets, + JaggedA=jagged, + JaggedB=d_out, + Out=d_dense, + ReduceOut=reduce_out, + M=K, + N=N, + AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(max_seq_len), + stride_ak=jagged.stride(0), + stride_bk=d_out.stride(0), + stride_ob=d_dense.stride(0), + stride_om=d_dense.stride(1), + stride_on=d_dense.stride(2), + stride_orb=stride_orb, + stride_orn=stride_orn, + REDUCE_JAGGEDB=reduce_jaggedb, + ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32, + ) + + return d_dense, d_bias + + +class _JaggedDenseBmmAddFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + max_seq_len: int, + seq_offsets: torch.Tensor, + jagged: torch.Tensor, + dense: torch.Tensor, + bias: torch.Tensor, + elementwise: bool = False, + ): + out, B, K, N = triton_jagged_dense_bmm_add_fwd( + max_seq_len, seq_offsets, jagged, dense, bias, elementwise + ) + + ctx.save_for_backward(seq_offsets, jagged, dense) + ctx.B = B + ctx.max_seq_len = max_seq_len + ctx.K = K + ctx.N = N + ctx.elementwise = elementwise + return out + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, d_out: torch.Tensor + ) -> Tuple[None, None, torch.Tensor, torch.Tensor, torch.Tensor, None]: + seq_offsets, jagged, dense = ctx.saved_tensors + d_jagged = triton_jagged_dense_bmm_add_bwd_jagged( + ctx.max_seq_len, + seq_offsets, + torch.empty_like(jagged), + dense, + d_out, + ctx.K, + ctx.B, + ctx.N, + ) + d_dense, d_bias = triton_jagged_dense_bmm_add_bwd_dense_bias( + ctx.max_seq_len, + seq_offsets, + jagged, + torch.empty_like(dense), + ctx.B, + ctx.K, + ctx.N, + d_out, + ctx.elementwise, + ) + + return None, None, d_jagged, d_dense, d_bias, None + + +@triton.jit +def concat_2D_jagged_w_prefix( + OffsetsA, 
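+    # OffsetsA / OffsetsB are length-(B+1) prefix sums of per-sample sequence lengths;
+    # the one belonging to a dense side (IS_DENSE_A / IS_DENSE_B) is not read.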
+ ValuesA, + OffsetsB, + ValuesB, + DenseSize, + Out, + D, + stride_ad, + stride_bd, + stride_dense_batch, + stride_od, + n_prefix_from_B, # nonzero is not supported when IS_REPLACE=True + IS_DENSE_A: tl.constexpr, + IS_DENSE_B: tl.constexpr, + BLOCK_D: tl.constexpr, + IS_REPLACE: tl.constexpr, +): + off_z = tl.program_id(1) + off_n = tl.program_id(0) + if IS_DENSE_A: + seq_start_a = off_z * DenseSize + seq_len_a = DenseSize + seq_start_b = tl.load(OffsetsB + off_z) + seq_end_b = tl.load(OffsetsB + off_z + 1) + seq_len_b = seq_end_b - seq_start_b + elif IS_DENSE_B: + seq_start_a = tl.load(OffsetsA + off_z) + seq_end_a = tl.load(OffsetsA + off_z + 1) + seq_len_a = seq_end_a - seq_start_a + seq_start_b = off_z * DenseSize + seq_len_b = DenseSize + else: + seq_start_a = tl.load(OffsetsA + off_z) + seq_end_a = tl.load(OffsetsA + off_z + 1) + seq_len_a = seq_end_a - seq_start_a + seq_start_b = tl.load(OffsetsB + off_z) + seq_end_b = tl.load(OffsetsB + off_z + 1) + seq_len_b = seq_end_b - seq_start_b + + if IS_REPLACE: + seq_len = seq_len_a + else: + seq_len = seq_len_a + seq_len_b + if off_n >= seq_len: + return + + offs_d = tl.arange(0, BLOCK_D) + if IS_REPLACE: + out_seq_start = seq_start_a + off_n + out_seq_b_start = seq_len_a - seq_len_b + else: + out_seq_start = seq_start_a + seq_start_b + off_n + out_seq_b_start = seq_len_a + n_prefix_from_B + + out_ptrs = Out + out_seq_start.to(tl.int64) * stride_od + offs_d + if off_n < out_seq_b_start and off_n >= n_prefix_from_B: + off_a = off_n - n_prefix_from_B + if IS_DENSE_A: + in_ptrs = ( + ValuesA + + off_a.to(tl.int64) * stride_ad + + off_z.to(tl.int64) * stride_dense_batch + + offs_d + ) + else: + in_ptrs = ValuesA + (off_a + seq_start_a).to(tl.int64) * stride_ad + offs_d + else: + off_b = off_n - out_seq_b_start + n_prefix_from_B + if off_n < n_prefix_from_B: + off_b += out_seq_b_start - n_prefix_from_B + if IS_DENSE_B: + in_ptrs = ( + ValuesB + + off_b.to(tl.int64) * stride_bd + + off_z.to(tl.int64) * stride_dense_batch + + offs_d + ) + else: + in_ptrs = ValuesB + (off_b + seq_start_b).to(tl.int64) * stride_bd + offs_d + v = tl.load(in_ptrs, mask=offs_d < D) + tl.store(out_ptrs, v, mask=offs_d < D) + + +@triton.jit +def concat_2D_jagged( + OffsetsA, + ValuesA, + OffsetsB, + ValuesB, + DenseSize, + Out, + D, + stride_ad, + stride_bd, + stride_dense_batch, + stride_od, + IS_DENSE_A: tl.constexpr, + IS_DENSE_B: tl.constexpr, + BLOCK_D: tl.constexpr, + IS_REPLACE: tl.constexpr, +): + concat_2D_jagged_w_prefix( + OffsetsA, + ValuesA, + OffsetsB, + ValuesB, + DenseSize, + Out, + D, + stride_ad, + stride_bd, + stride_dense_batch, + stride_od, + 0, + IS_DENSE_A, + IS_DENSE_B, + BLOCK_D, + IS_REPLACE, + ) + + +@triton.jit +def concat_2D_jagged_jagged_w_prefix( + OffsetsA, + ValuesA, + OffsetsB, + ValuesB, + Out, + D, + stride_ad, + stride_bd, + stride_od, + n_prefix_from_B, + BLOCK_D: tl.constexpr, +): + concat_2D_jagged_w_prefix( + OffsetsA, + ValuesA, + OffsetsB, + ValuesB, + 0, + Out, + D, + stride_ad, + stride_bd, + 0, + stride_od, + n_prefix_from_B, + IS_DENSE_A=False, + IS_DENSE_B=False, + BLOCK_D=BLOCK_D, + IS_REPLACE=False, + ) + + +@triton.jit +def split_2D_jagged_w_prefix( + JaggedIn, + DenseSize, + OffsetsA, + OffsetsB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + n_prefix_to_B, + IS_DENSE_A: tl.constexpr, + IS_DENSE_B: tl.constexpr, + BLOCK_D: tl.constexpr, + IS_REPLACE: tl.constexpr, +): + off_z = tl.program_id(1) + off_n = tl.program_id(0) + if IS_DENSE_A: + seq_start_b = tl.load(OffsetsB + off_z) + seq_end_b = 
tl.load(OffsetsB + off_z + 1) + seq_start_a = off_z * DenseSize + seq_len_a = DenseSize + seq_len_b = seq_end_b - seq_start_b + elif IS_DENSE_B: + seq_start_a = tl.load(OffsetsA + off_z) + seq_end_a = tl.load(OffsetsA + off_z + 1) + seq_len_a = seq_end_a - seq_start_a + seq_start_b = off_z * DenseSize + seq_len_b = DenseSize + else: + seq_start_a = tl.load(OffsetsA + off_z) + seq_end_a = tl.load(OffsetsA + off_z + 1) + seq_len_a = seq_end_a - seq_start_a + seq_start_b = tl.load(OffsetsB + off_z) + seq_end_b = tl.load(OffsetsB + off_z + 1) + seq_len_b = seq_end_b - seq_start_b + if IS_REPLACE: + seq_len = seq_len_a + else: + seq_len = seq_len_a + seq_len_b + if off_n >= seq_len: + return + + if IS_REPLACE: + seq_start = seq_start_a + out_seq_b_start = seq_len_a - seq_len_b + else: + seq_start = seq_start_a + seq_start_b + out_seq_b_start = seq_len_a + n_prefix_to_B + + offs_d = tl.arange(0, BLOCK_D) + in_ptrs = JaggedIn + (seq_start + off_n).to(tl.int64) * stride_id + offs_d + if off_n < out_seq_b_start and off_n >= n_prefix_to_B: + off_a = off_n - n_prefix_to_B + out_ptrs = OutA + (off_a + seq_start_a).to(tl.int64) * stride_ad + offs_d + else: + off_b = off_n - out_seq_b_start + n_prefix_to_B + if off_n < n_prefix_to_B: + off_b += out_seq_b_start - n_prefix_to_B + out_ptrs = OutB + (off_b + seq_start_b).to(tl.int64) * stride_bd + offs_d + v = tl.load(in_ptrs, mask=offs_d < D) + tl.store(out_ptrs, v, mask=offs_d < D) + + +@triton.jit +def split_2D_jagged( + JaggedIn, + DenseSize, + OffsetsA, + OffsetsB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + IS_DENSE_A: tl.constexpr, + IS_DENSE_B: tl.constexpr, + BLOCK_D: tl.constexpr, + IS_REPLACE: tl.constexpr, +): + split_2D_jagged_w_prefix( + JaggedIn, + DenseSize, + OffsetsA, + OffsetsB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + 0, + IS_DENSE_A, + IS_DENSE_B, + BLOCK_D, + IS_REPLACE, + ) + + +@triton.jit +def split_2D_jagged_jagged_w_prefix( + JaggedIn, + OffsetsA, + OffsetsB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + n_prefix_to_B, + BLOCK_D: tl.constexpr, +): + split_2D_jagged_w_prefix( + JaggedIn, + 0, + OffsetsA, + OffsetsB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + n_prefix_to_B, + IS_DENSE_A=False, + IS_DENSE_B=False, + BLOCK_D=BLOCK_D, + IS_REPLACE=False, + ) + + +def _triton_split_2D_jagged_internal( + jagged_in: torch.Tensor, + max_seq_len: int, + B: int, + offsets_a: Optional[torch.Tensor], + offsets_b: Optional[torch.Tensor], + out_a: torch.Tensor, + out_b: torch.Tensor, + D: int, + dense_size: int, + n_prefix: int, + is_dense_a: bool, + is_dense_b: bool, + is_replace: bool, + BLOCK_D: int, +) -> None: + if n_prefix != 0: + if is_sm100(): + + def grid(meta): + return (triton.cdiv(max_seq_len, meta["BLOCK_N"]), B) + + split_2D_jagged_jagged_w_prefix_multirow[grid]( + JaggedIn=jagged_in, + OffsetsA=offsets_a, + OffsetsB=offsets_b, + OutA=out_a, + OutB=out_b, + D=D, + stride_id=jagged_in.stride(0), + stride_ad=out_a.stride(0), + stride_bd=out_b.stride(0), + n_prefix_to_B=n_prefix, + BLOCK_D=BLOCK_D, + ) + else: + split_2D_jagged_jagged_w_prefix[(max_seq_len, B)]( + JaggedIn=jagged_in, + OffsetsA=offsets_a, + OffsetsB=offsets_b, + OutA=out_a, + OutB=out_b, + D=D, + stride_id=jagged_in.stride(0), + stride_ad=out_a.stride(0), + stride_bd=out_b.stride(0), + n_prefix_to_B=n_prefix, + BLOCK_D=BLOCK_D, # pyre-ignore[6] + ) + else: + if is_sm100(): + + def grid(meta): + return (triton.cdiv(max_seq_len, meta["BLOCK_N"]), B) + + split_2D_jagged_multirow[grid]( + 
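+                # SM100 path: the autotuned multi-row kernel copies BLOCK_N rows per program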
JaggedIn=jagged_in, + DenseSize=dense_size, + OffsetsA=offsets_a, + OffsetsB=offsets_b, + OutA=out_a, + OutB=out_b, + D=D, + stride_id=jagged_in.stride(0), + stride_ad=out_a.stride(0), + stride_bd=out_b.stride(0), + IS_DENSE_A=is_dense_a, # pyre-ignore[6] + IS_DENSE_B=is_dense_b, # pyre-ignore[6] + BLOCK_D=BLOCK_D, # pyre-ignore[6] + IS_REPLACE=is_replace, # pyre-ignore[6] + ) + else: + split_2D_jagged[(max_seq_len, B)]( + JaggedIn=jagged_in, + DenseSize=dense_size, + OffsetsA=offsets_a, + OffsetsB=offsets_b, + OutA=out_a, + OutB=out_b, + D=D, + stride_id=jagged_in.stride(0), + stride_ad=out_a.stride(0), + stride_bd=out_b.stride(0), + IS_DENSE_A=is_dense_a, # pyre-ignore[6] + IS_DENSE_B=is_dense_b, # pyre-ignore[6] + BLOCK_D=BLOCK_D, # pyre-ignore[6] + IS_REPLACE=is_replace, # pyre-ignore[6] + ) + + +class _Concat2DJaggedFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + max_seq_len: int, + values_a: torch.Tensor, + values_b: torch.Tensor, + offsets_a: Optional[torch.Tensor] = None, + offsets_b: Optional[torch.Tensor] = None, + is_replace: bool = False, + n_prefix_from_right: int = 0, + ): + values_a = switch_to_contiguous_if_needed(values_a) + values_b = switch_to_contiguous_if_needed(values_b) + is_dense_a = offsets_a is None + is_dense_b = offsets_b is None + dense_size: int = 0 + if is_dense_a: + assert offsets_b is not None + B, dense_size, D = values_a.shape + seq_len_a = dense_size * B + seq_len_b, _ = values_b.shape + device = values_b.device + dtype = values_b.dtype + stride_dense_batch = values_a.stride(0) + elif is_dense_b: + assert offsets_a is not None + B, dense_size, D = values_b.shape + seq_len_a, _ = values_a.shape + seq_len_b = dense_size * B + device = values_a.device + dtype = values_a.dtype + stride_dense_batch = values_b.stride(0) + else: + assert offsets_a is not None and offsets_b is not None + B = offsets_a.shape[0] - 1 + seq_len_a, D = values_a.shape + seq_len_b, _ = values_b.shape + device = values_a.device + dtype = values_a.dtype + stride_dense_batch = 0 + + BLOCK_D = triton.next_power_of_2(D) + if is_replace: + values_out = torch.empty_like(values_a) + else: + values_out = torch.empty( + (seq_len_a + seq_len_b, D), device=device, dtype=dtype + ) + _triton_concat_2D_jagged_internal( + values_a=values_a, + values_b=values_b, + values_out=values_out, + max_seq_len=max_seq_len, + B=B, + offsets_a=offsets_a, + offsets_b=offsets_b, + D=D, + dense_size=dense_size, + stride_dense_batch=stride_dense_batch, + n_prefix=n_prefix_from_right, + is_dense_a=is_dense_a, + is_dense_b=is_dense_b, + is_replace=is_replace, + BLOCK_D=BLOCK_D, + ) + ctx.save_for_backward(offsets_a, offsets_b) + ctx.max_seq_len = max_seq_len + ctx.seq_len_a = seq_len_a + ctx.seq_len_b = seq_len_b + ctx.is_dense_a = is_dense_a + ctx.is_dense_b = is_dense_b + ctx.dense_size = dense_size + ctx.is_replace = is_replace + ctx.n_prefix_from_right = n_prefix_from_right + return values_out + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, d_out: torch.Tensor + ) -> Tuple[None, torch.Tensor, torch.Tensor, None, None, None, None]: + offsets_a, offsets_b = ctx.saved_tensors + is_dense_a, is_dense_b, is_replace = ( + ctx.is_dense_a, + ctx.is_dense_b, + ctx.is_replace, + ) + dense_size = ctx.dense_size + if is_dense_a: + B = offsets_b.shape[0] - 1 + else: + B = offsets_a.shape[0] - 1 + _, D = d_out.shape + BLOCK_D = triton.next_power_of_2(D) + values_a = torch.zeros( + (ctx.seq_len_a, D), device=d_out.device, dtype=d_out.dtype + ) + values_b = torch.empty( + 
(ctx.seq_len_b, D), device=d_out.device, dtype=d_out.dtype + ) + _triton_split_2D_jagged_internal( + jagged_in=d_out, + max_seq_len=ctx.max_seq_len, + B=B, + offsets_a=offsets_a, + offsets_b=offsets_b, + out_a=values_a, + out_b=values_b, + D=D, + dense_size=dense_size, + n_prefix=ctx.n_prefix_from_right, + is_dense_a=is_dense_a, + is_dense_b=is_dense_b, + is_replace=is_replace, + BLOCK_D=BLOCK_D, + ) + + if is_dense_a: + values_a = values_a.reshape((B, dense_size, D)) + elif is_dense_b: + values_b = values_b.reshape((B, dense_size, D)) + return None, values_a, values_b, None, None, None, None + + +class _Split2DJaggedFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + values: torch.Tensor, + max_seq_len: int, + offsets_a: Optional[torch.Tensor] = None, + offsets_b: Optional[torch.Tensor] = None, + dense_size: int = 0, + n_prefix_to_right: int = 0, + seq_len_a: Optional[int] = None, + seq_len_b: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + values = switch_to_contiguous_if_needed(values) + is_dense_a: bool = offsets_a is None + is_dense_b: bool = offsets_b is None + if is_dense_a: + L, _ = values.shape + assert offsets_b is not None + B = offsets_b.shape[0] - 1 + seq_len_a = dense_size * B + seq_len_b = L - seq_len_a + offsets_a = offsets_b.new_empty(0) + elif is_dense_b: + L, _ = values.shape + assert offsets_a is not None + B = offsets_a.shape[0] - 1 + seq_len_b = dense_size * B + seq_len_a = L - seq_len_b + offsets_b = offsets_a.new_empty(0) + else: + assert offsets_a is not None and offsets_b is not None + B = offsets_a.shape[0] - 1 + + # Select the last offset item using torch.index_select instead of + # "int(offsets_a[-1].item())" so that it won't cause "Cannot cast + # FakeTensor to python number" error for AOTI. 
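+            # Under compilation the lengths therefore stay as tensors rather than Python
+            # ints, which is why the torch.empty calls below carry pyre-ignore[6] notes.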
+ if torch.compiler.is_compiling(): + offsets_a_last_idx = torch.tensor(offsets_a.size(0) - 1).to( + offsets_a.device, non_blocking=True + ) + offsets_b_last_idx = torch.tensor(offsets_b.size(0) - 1).to( + offsets_b.device, non_blocking=True + ) + if seq_len_a is None: + seq_len_a = offsets_a.index_select(dim=0, index=offsets_a_last_idx) + if seq_len_b is None: + seq_len_b = offsets_b.index_select(dim=0, index=offsets_b_last_idx) + else: + if seq_len_a is None: + seq_len_a = int(offsets_a[-1].item()) + if seq_len_b is None: + seq_len_b = int(offsets_b[-1].item()) + _, D = values.shape + BLOCK_D = triton.next_power_of_2(D) + # pyre-ignore[6] Incompatible parameter type + values_a = torch.empty((seq_len_a, D), device=values.device, dtype=values.dtype) + # pyre-ignore[6] Incompatible parameter type + values_b = torch.empty((seq_len_b, D), device=values.device, dtype=values.dtype) + _triton_split_2D_jagged_internal( + jagged_in=values, + max_seq_len=max_seq_len, + B=B, + offsets_a=offsets_a, + offsets_b=offsets_b, + out_a=values_a, + out_b=values_b, + D=D, + dense_size=dense_size, + n_prefix=n_prefix_to_right, + is_dense_a=is_dense_a, + is_dense_b=is_dense_b, + is_replace=False, + BLOCK_D=BLOCK_D, + ) + if is_dense_a: + values_a = values_a.reshape(B, dense_size, D) + if is_dense_b: + values_b = values_b.reshape(B, dense_size, D) + ctx.save_for_backward(offsets_a, offsets_b) + ctx.max_seq_len = max_seq_len + ctx.seq_len_a = seq_len_a + ctx.seq_len_b = seq_len_b + ctx.is_dense_a = is_dense_a + ctx.is_dense_b = is_dense_b + ctx.dense_size = dense_size + ctx.B = B + ctx.D = D + ctx.n_prefix_to_right = n_prefix_to_right + return values_a, values_b + + @staticmethod + def backward( + ctx, *d_values + ) -> Tuple[torch.Tensor, None, None, None, None, None, None, None]: + offsets_a, offsets_b = ctx.saved_tensors + is_dense_a, is_dense_b = ctx.is_dense_a, ctx.is_dense_b + values_a, values_b = d_values + if is_dense_a: + stride_dense_batch = values_a.stride(0) + elif is_dense_b: + stride_dense_batch = values_b.stride(0) + else: + stride_dense_batch = 0 + + BLOCK_D = triton.next_power_of_2(ctx.D) + dvalues = torch.empty( + (ctx.seq_len_a + ctx.seq_len_b, ctx.D), + device=values_a.device, + dtype=values_b.dtype, + ) + _triton_concat_2D_jagged_internal( + values_a=values_a, + values_b=values_b, + values_out=dvalues, + max_seq_len=ctx.max_seq_len, + B=ctx.B, + offsets_a=offsets_a, + offsets_b=offsets_b, + D=ctx.D, + dense_size=ctx.dense_size, + stride_dense_batch=stride_dense_batch, + n_prefix=ctx.n_prefix_to_right, + is_dense_a=is_dense_a, + is_dense_b=is_dense_b, + is_replace=False, + BLOCK_D=BLOCK_D, + ) + + return dvalues, None, None, None, None, None, None, None + + +@torch.jit.unused +@torch.fx.wrap +def triton_jagged_dense_bmm_add( + max_seq_len: int, + seq_offsets: torch.Tensor, + jagged: torch.Tensor, + dense: torch.Tensor, + bias: torch.Tensor, + elementwise: bool = False, +) -> torch.Tensor: + """ + Computing bmm Out = Jagged x Dense + Bias + M is the jagged dimension + Jagged has shape (sum_B(M_i), K), Dense has shape (B, K, N), Bias has shape (B, N) or (sum_B(M_i), N) depending on Elementwise, and Out has shape (sum_B(M_i), N) + """ + return _JaggedDenseBmmAddFunction.apply( + max_seq_len, seq_offsets, jagged, dense, bias, elementwise + ) + + +@torch.fx.wrap +def triton_concat_2D_jagged( + max_seq_len: int, + values_a: torch.Tensor, + values_b: torch.Tensor, + offsets_a: Optional[torch.Tensor] = None, + offsets_b: Optional[torch.Tensor] = None, + is_replace: bool = False, + n_prefix_from_right: 
int = 0, +) -> torch.Tensor: + return _Concat2DJaggedFunction.apply( + max_seq_len, + values_a, + values_b, + offsets_a, + offsets_b, + is_replace, + n_prefix_from_right, + ) + + +@torch.fx.wrap +def triton_concat_2D_jagged_jagged( + max_seq_len_left: int, + offsets_left: torch.Tensor, + values_left: torch.Tensor, + max_seq_len_right: int, + offsets_right: torch.Tensor, + values_right: torch.Tensor, + is_replace: bool, + n_prefix_from_right: int, +) -> torch.Tensor: + return triton_concat_2D_jagged( + max_seq_len=max_seq_len_left + max_seq_len_right, + values_a=values_left, + values_b=values_right, + offsets_a=offsets_left, + offsets_b=offsets_right, + is_replace=is_replace, + n_prefix_from_right=n_prefix_from_right, + ) + + +@torch.fx.wrap +def triton_concat_2D_dense_jagged( + jagged_max_seq_len: int, + jagged_offsets: torch.Tensor, + jagged_values: torch.Tensor, + dense_values: torch.Tensor, +) -> torch.Tensor: + B, dense_size, D = dense_values.size() + max_seq_len = jagged_max_seq_len + dense_size + return triton_concat_2D_jagged( + max_seq_len=max_seq_len, + values_a=dense_values, + values_b=jagged_values, + offsets_a=None, + offsets_b=jagged_offsets, + ) + + +def triton_jagged_dense_bmm( + max_seq_len: int, + seq_offsets: torch.Tensor, + jagged: torch.Tensor, + dense: torch.Tensor, +) -> torch.Tensor: + return _JaggedDenseBmmFunction.apply(max_seq_len, seq_offsets, jagged, dense) + + +@torch.jit.unused +def triton_split_2D_jagged( + values: torch.Tensor, + max_seq_len: int, + offsets_a: Optional[torch.Tensor] = None, + offsets_b: Optional[torch.Tensor] = None, + dense_size: int = 0, + n_prefix_to_right: int = 0, + seq_len_a: Optional[int] = None, + seq_len_b: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _Split2DJaggedFunction.apply( + values, + max_seq_len, + offsets_a, + offsets_b, + dense_size, + n_prefix_to_right, + seq_len_a, + seq_len_b, + ) + + +@triton.jit +def concat_2D_jagged_w_prefix_multirow( + OffsetsA, + ValuesA, + OffsetsB, + ValuesB, + DenseSize, + Out, + D, + stride_ad, + stride_bd, + stride_dense_batch, + stride_od, + n_prefix_from_B, + IS_DENSE_A: tl.constexpr, + IS_DENSE_B: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_REPLACE: tl.constexpr, +): + off_z = tl.program_id(1) + off_block_n = tl.program_id(0) + + if IS_DENSE_A: + seq_start_a = off_z * DenseSize + seq_len_a = DenseSize + seq_start_b = tl.load(OffsetsB + off_z) + seq_end_b = tl.load(OffsetsB + off_z + 1) + seq_len_b = seq_end_b - seq_start_b + elif IS_DENSE_B: + seq_start_a = tl.load(OffsetsA + off_z) + seq_end_a = tl.load(OffsetsA + off_z + 1) + seq_len_a = seq_end_a - seq_start_a + seq_start_b = off_z * DenseSize + seq_len_b = DenseSize + else: + seq_start_a = tl.load(OffsetsA + off_z) + seq_end_a = tl.load(OffsetsA + off_z + 1) + seq_len_a = seq_end_a - seq_start_a + seq_start_b = tl.load(OffsetsB + off_z) + seq_end_b = tl.load(OffsetsB + off_z + 1) + seq_len_b = seq_end_b - seq_start_b + + if IS_REPLACE: + seq_len = seq_len_a + out_seq_start = seq_start_a + out_seq_b_start = seq_len_a - seq_len_b + else: + seq_len = seq_len_a + seq_len_b + out_seq_start = seq_start_a + seq_start_b + out_seq_b_start = seq_len_a + n_prefix_from_B + + start_n = off_block_n * BLOCK_N + offs_n = start_n + tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + if start_n >= seq_len: + return + valid_mask = offs_n < seq_len + + out_ptrs = ( + Out + + (out_seq_start + offs_n[:, None]).to(tl.int64) * stride_od + + offs_d[None, :] + ) + + to_a_mask = (offs_n < out_seq_b_start) 
& (offs_n >= n_prefix_from_B) & valid_mask + to_b_mask = ~to_a_mask & valid_mask + + off_a = offs_n - n_prefix_from_B + if IS_DENSE_A: + in_a_ptrs = ( + ValuesA + + off_a[:, None].to(tl.int64) * stride_ad + + off_z.to(tl.int64) * stride_dense_batch + + offs_d[None, :] + ) + else: + in_a_ptrs = ( + ValuesA + + (off_a[:, None] + seq_start_a).to(tl.int64) * stride_ad + + offs_d[None, :] + ) + + v_a = tl.load(in_a_ptrs, mask=to_a_mask[:, None] & (offs_d[None, :] < D), other=0.0) + tl.store(out_ptrs, v_a, mask=to_a_mask[:, None] & (offs_d[None, :] < D)) + + prefix_mask = offs_n < n_prefix_from_B + + off_b = tl.where(prefix_mask, offs_n, offs_n - out_seq_b_start + n_prefix_from_B) + if IS_DENSE_B: + in_b_ptrs = ( + ValuesB + + off_b[:, None].to(tl.int64) * stride_bd + + off_z.to(tl.int64) * stride_dense_batch + + offs_d[None, :] + ) + else: + in_b_ptrs = ( + ValuesB + + (off_b[:, None] + seq_start_b).to(tl.int64) * stride_bd + + offs_d[None, :] + ) + + v_b = tl.load(in_b_ptrs, mask=to_b_mask[:, None] & (offs_d[None, :] < D), other=0.0) + tl.store(out_ptrs, v_b, mask=to_b_mask[:, None] & (offs_d[None, :] < D)) + + +@triton_autotune( + configs=_get_split_concat_2d_jagged_multirow_configs(), + key=["BLOCK_D"], +) +@triton.jit +def concat_2D_jagged_multirow( + OffsetsA, + ValuesA, + OffsetsB, + ValuesB, + DenseSize, + Out, + D, + stride_ad, + stride_bd, + stride_dense_batch, + stride_od, + IS_DENSE_A: tl.constexpr, + IS_DENSE_B: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_REPLACE: tl.constexpr, +): + concat_2D_jagged_w_prefix_multirow( + OffsetsA, + ValuesA, + OffsetsB, + ValuesB, + DenseSize, + Out, + D, + stride_ad, + stride_bd, + stride_dense_batch, + stride_od, + 0, + IS_DENSE_A, + IS_DENSE_B, + BLOCK_D, + BLOCK_N, + IS_REPLACE, + ) + + +@triton_autotune( + configs=_get_split_concat_2d_jagged_multirow_configs(), + key=["BLOCK_D"], +) +@triton.jit +def concat_2D_jagged_jagged_w_prefix_multirow( + OffsetsA, + ValuesA, + OffsetsB, + ValuesB, + Out, + D, + stride_ad, + stride_bd, + stride_od, + n_prefix_from_B, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, +): + concat_2D_jagged_w_prefix_multirow( + OffsetsA, + ValuesA, + OffsetsB, + ValuesB, + 0, + Out, + D, + stride_ad, + stride_bd, + 0, + stride_od, + n_prefix_from_B, + IS_DENSE_A=False, + IS_DENSE_B=False, + BLOCK_D=BLOCK_D, + BLOCK_N=BLOCK_N, + IS_REPLACE=False, + ) + + +@triton.jit +def split_2D_jagged_w_prefix_multirow( + JaggedIn, + DenseSize, + OffsetsA, + OffsetsB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + n_prefix_to_B, + IS_DENSE_A: tl.constexpr, + IS_DENSE_B: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_REPLACE: tl.constexpr, +): + off_z = tl.program_id(1) + off_block_n = tl.program_id(0) + + if IS_DENSE_A: + seq_start_b = tl.load(OffsetsB + off_z) + seq_end_b = tl.load(OffsetsB + off_z + 1) + seq_start_a = off_z * DenseSize + seq_len_a = DenseSize + seq_len_b = seq_end_b - seq_start_b + elif IS_DENSE_B: + seq_start_a = tl.load(OffsetsA + off_z) + seq_end_a = tl.load(OffsetsA + off_z + 1) + seq_len_a = seq_end_a - seq_start_a + seq_start_b = off_z * DenseSize + seq_len_b = DenseSize + else: + seq_start_a = tl.load(OffsetsA + off_z) + seq_end_a = tl.load(OffsetsA + off_z + 1) + seq_len_a = seq_end_a - seq_start_a + seq_start_b = tl.load(OffsetsB + off_z) + seq_end_b = tl.load(OffsetsB + off_z + 1) + seq_len_b = seq_end_b - seq_start_b + + if IS_REPLACE: + seq_len = seq_len_a + else: + seq_len = seq_len_a + seq_len_b + + if IS_REPLACE: + seq_start = seq_start_a + 
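+        # IS_REPLACE: the input shares A's layout and only the last seq_len_b rows of
+        # each sequence are routed to OutB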
out_seq_b_start = seq_len_a - seq_len_b + else: + seq_start = seq_start_a + seq_start_b + out_seq_b_start = seq_len_a + n_prefix_to_B + + start_n = off_block_n * BLOCK_N + offs_n = start_n + tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + if start_n >= seq_len: + return + valid_mask = offs_n < seq_len + + in_ptrs = ( + JaggedIn + + (seq_start + offs_n[:, None]).to(tl.int64) * stride_id + + offs_d[None, :] + ) + + v = tl.load(in_ptrs, mask=valid_mask[:, None] & (offs_d[None, :] < D), other=0.0) + + to_a_mask = (offs_n < out_seq_b_start) & (offs_n >= n_prefix_to_B) & valid_mask + to_b_mask = ~to_a_mask & valid_mask + + off_a = offs_n - n_prefix_to_B + out_a_ptrs = ( + OutA + (off_a[:, None] + seq_start_a).to(tl.int64) * stride_ad + offs_d[None, :] + ) + tl.store(out_a_ptrs, v, mask=to_a_mask[:, None] & (offs_d[None, :] < D)) + + prefix_mask = offs_n < n_prefix_to_B + + off_b = tl.where(prefix_mask, offs_n, offs_n - out_seq_b_start + n_prefix_to_B) + out_b_ptrs = ( + OutB + (off_b[:, None] + seq_start_b).to(tl.int64) * stride_bd + offs_d[None, :] + ) + tl.store(out_b_ptrs, v, mask=to_b_mask[:, None] & (offs_d[None, :] < D)) + + +@triton_autotune( + configs=_get_split_concat_2d_jagged_multirow_configs(), + key=["BLOCK_D"], +) +@triton.jit +def split_2D_jagged_multirow( + JaggedIn, + DenseSize, + OffsetsA, + OffsetsB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + IS_DENSE_A: tl.constexpr, + IS_DENSE_B: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_REPLACE: tl.constexpr, +): + split_2D_jagged_w_prefix_multirow( + JaggedIn, + DenseSize, + OffsetsA, + OffsetsB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + 0, + IS_DENSE_A, + IS_DENSE_B, + BLOCK_D, + BLOCK_N, + IS_REPLACE, + ) + + +@triton_autotune( + configs=_get_split_concat_2d_jagged_multirow_configs(), + key=["BLOCK_D"], +) +@triton.jit +def split_2D_jagged_jagged_w_prefix_multirow( + JaggedIn, + OffsetsA, + OffsetsB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + n_prefix_to_B, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, +): + split_2D_jagged_w_prefix_multirow( + JaggedIn, + 0, + OffsetsA, + OffsetsB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + n_prefix_to_B, + IS_DENSE_A=False, + IS_DENSE_B=False, + BLOCK_D=BLOCK_D, + BLOCK_N=BLOCK_N, + IS_REPLACE=False, + ) + + +def triton_jagged_dense_broadcast_add( + max_seq_len: int, + seq_offsets: torch.Tensor, + jagged: torch.Tensor, + dense: torch.Tensor, +) -> torch.Tensor: + return _JaggedDenseBroadcastAddFunction.apply( + max_seq_len, seq_offsets, jagged, dense + ) + + +@triton.jit +def _helion_split_2d_jagged_kernel( + offsets_a, + offsets_b, + values_flat, + out_a_flat, + out_b_flat, + max_seq_len, + D: tl.constexpr, + _BLOCK_SIZE_0: tl.constexpr, + _BLOCK_SIZE_1: tl.constexpr, +) -> None: + # Get program ID and decompose to batch and sequence block coordinates + program_id = tl.program_id(0) + flat_program_id = program_id + batch_id = triton_helpers.div_floor_integer( + flat_program_id, + triton_helpers.div_floor_integer( + -1 + _BLOCK_SIZE_0 + max_seq_len, _BLOCK_SIZE_0 + ), + ) + seq_block_id = triton_helpers.remainder_integer( # noqa: F841 + flat_program_id, + triton_helpers.div_floor_integer( + -1 + _BLOCK_SIZE_0 + max_seq_len, _BLOCK_SIZE_0 + ), + ) + # Load output boundaries for part A + out_a_start = tl.load(offsets_a + batch_id * 1, None, eviction_policy="evict_last") + batch_id_plus_1 = 1 + triton_helpers.div_floor_integer( + flat_program_id, + triton_helpers.div_floor_integer( + -1 + 
_BLOCK_SIZE_0 + max_seq_len, _BLOCK_SIZE_0 + ), + ) + out_a_end = tl.load( + offsets_a + batch_id_plus_1 * 1, None, eviction_policy="evict_last" + ) + len_a = out_a_end - out_a_start + # Load output boundaries for part B + out_b_start = tl.load(offsets_b + batch_id * 1, None) + out_b_end = tl.load( + offsets_b + batch_id_plus_1 * 1, None, eviction_policy="evict_last" + ) + len_b = out_b_end - out_b_start + # Compute input start and total length for this batch + input_start = out_a_start + out_b_start + total_len = len_a + len_b + # Calculate sequence offset for this block + seq_offset = _BLOCK_SIZE_0 * triton_helpers.remainder_integer( + flat_program_id, + triton_helpers.div_floor_integer( + -1 + _BLOCK_SIZE_0 + max_seq_len, _BLOCK_SIZE_0 + ), + ) + has_work = total_len > seq_offset + if has_work: + # Generate row indices for this sequence block + seq_range = tl.arange(0, _BLOCK_SIZE_0) + seq_offset_i32 = tl.cast(seq_offset, tl.int32) + row_indices = seq_range + seq_offset_i32 + + # Create masks for valid rows and parts A/B + total_len_i32 = tl.cast(total_len[None], tl.int32) + len_a_i32 = tl.cast(len_a[None], tl.int32) + valid_mask = row_indices < total_len_i32 + is_part_a = row_indices < len_a_i32 + is_part_b = (row_indices >= len_a_i32) & valid_mask + + # Extract scalar values once + input_start_i32 = tl.cast(input_start[None, None], tl.int32) + out_a_start_i32 = tl.cast(out_a_start[None, None], tl.int32) + out_b_start_i32 = tl.cast(out_b_start[None, None], tl.int32) + + # Process features in smaller tiles + for feature_offset in tl.range( + 0, + D, + _BLOCK_SIZE_1, + loop_unroll_factor=1, + num_stages=4, + disallow_acc_multi_buffer=True, + flatten=True, + ): + feature_indices = feature_offset + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) + + # Compute D constant and feature mask once per feature iteration + D_const = tl.full([], tl.cast(D, tl.int32), tl.int32) + D_i32 = tl.cast(D, tl.int32) + feature_mask = feature_indices < D_i32 + + # Compute indices for part A + row_subscript = row_indices[:, None] + input_row_a = input_start_i32 + row_subscript + input_idx_a = ( + tl.cast(input_row_a * D_const, tl.int32) + feature_indices[None, :] + ) + + out_a_row = out_a_start_i32 + row_subscript + out_a_idx = ( + tl.cast(out_a_row * D_const, tl.int32) + feature_indices[None, :] + ) + + mask_a = is_part_a[:, None] & valid_mask[:, None] & feature_mask[None, :] + + # Load and store part A data + slice_a = tl.load( + values_flat + input_idx_a * 1, + mask_a, + other=0, + eviction_policy="evict_first", + ) + tl.store(out_a_flat + out_a_idx * 1, slice_a, mask_a) + + # Compute indices for part B + input_idx_b = ( + tl.cast((input_start_i32 + row_subscript) * D_const, tl.int32) + + feature_indices[None, :] + ) + + row_minus_len_a = row_subscript - len_a_i32 + out_b_row = out_b_start_i32 + row_minus_len_a + out_b_idx = ( + tl.cast(out_b_row * D_const, tl.int32) + feature_indices[None, :] + ) + + mask_b = is_part_b[:, None] & feature_mask[None, :] + + # Load and store part B data + slice_b = tl.load( + values_flat + input_idx_b * 1, + mask_b, + other=0, + eviction_policy="evict_first", + ) + tl.store(out_b_flat + out_b_idx * 1, slice_b, mask_b) + + +def helion_split_2D_jagged( + values: torch.Tensor, + max_seq_len: int, + offsets_a: torch.Tensor, + offsets_b: torch.Tensor, + dense_size: int = 0, # noqa: F841 +) -> Tuple[torch.Tensor, torch.Tensor]: + D = values.size(1) + + # Select dtype-specific optimal parameters + if values.dtype == torch.float32: + # FP32-optimized parameters + block_size_0 = 64 + 
block_size_1 = 64 + num_warps = 4 + num_stages = 4 + else: + # BF16/FP16-optimized parameters + block_size_0 = 128 + block_size_1 = triton.next_power_of_2(D) + num_warps = 32 + num_stages = 7 + + return _helion_split_2d_jagged( + values, + max_seq_len, + offsets_a, + offsets_b, + dense_size, + block_size_0=block_size_0, + block_size_1=block_size_1, + num_warps=num_warps, + num_stages=num_stages, + ) + + +def _helion_split_2d_jagged( + values: torch.Tensor, + max_seq_len: int, + offsets_a: torch.Tensor, + offsets_b: torch.Tensor, + dense_size: int, # noqa: F841 + block_size_0: int = 64, + block_size_1: int = 64, + num_warps: int = 4, + num_stages: int = 4, +) -> Tuple[torch.Tensor, torch.Tensor]: + num_batches = offsets_a.size(0) - 1 + D = values.size(1) + num_seq_blocks = (max_seq_len + block_size_0 - 1) // block_size_0 + total_len_a = int(offsets_a[-1].item()) + total_len_b = int(offsets_b[-1].item()) + out_a = torch.empty([total_len_a, D], dtype=values.dtype, device=values.device) + out_b = torch.empty([total_len_b, D], dtype=values.dtype, device=values.device) + values_flat = values.view(-1) + out_a_flat = out_a.view(-1) + out_b_flat = out_b.view(-1) + total_programs = num_batches * num_seq_blocks + + # pyre-ignore[28] + _helion_split_2d_jagged_kernel[(total_programs,)]( + offsets_a, + offsets_b, + values_flat, + out_a_flat, + out_b_flat, + max_seq_len, + D, + block_size_0, + block_size_1, + num_warps=num_warps, + num_stages=num_stages, + ) + return (out_a, out_b) diff --git a/recommendation/dlrm_v3/generative_recommenders/ops/triton/triton_jagged_tensors.py b/recommendation/dlrm_v3/generative_recommenders/ops/triton/triton_jagged_tensors.py new file mode 100644 index 0000000000..7fd79ad99d --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/ops/triton/triton_jagged_tensors.py @@ -0,0 +1,1062 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +#!/usr/bin/env python3 + + +from typing import Optional, Tuple + +import torch + +# @manual=//triton:triton +import triton + +# @manual=//triton:triton +import triton.language as tl + +from generative_recommenders.common import ( + switch_to_contiguous_if_needed, + triton_autotune, +) +from generative_recommenders.ops.utils import is_sm100 + + +def _triton_concat_2D_jagged_internal( + values_a: torch.Tensor, + values_b: torch.Tensor, + values_out: torch.Tensor, + max_seq_len: int, + B: int, + offsets_a: Optional[torch.Tensor], + offsets_b: Optional[torch.Tensor], + max_len_a: Optional[int], + max_len_b: Optional[int], + D: int, + n_prefix_from_B: int, + is_dense_a: bool, + is_dense_b: bool, + BLOCK_D: int, +) -> None: + if is_sm100(): + + def grid(meta): + return (triton.cdiv(max_seq_len, meta["BLOCK_N"]), B) + + concat_2D_jagged_multirow[grid]( + ValuesA=values_a, + ValuesB=values_b, + OffsetsA=offsets_a, + OffsetsB=offsets_b, + MaxLenA=max_len_a, + MaxLenB=max_len_b, + Out=values_out, + D=D, + stride_ad=values_a.stride(-2), + stride_bd=values_b.stride(-2), + stride_od=values_out.stride(-2), + n_prefix_from_B=n_prefix_from_B, + IS_DENSE_A=is_dense_a, # pyre-ignore[6] + IS_DENSE_B=is_dense_b, # pyre-ignore[6] + BLOCK_D=BLOCK_D, # pyre-ignore[6] + ) + else: + _concat_2D_jagged[(max_seq_len, B)]( + ValuesA=values_a, + ValuesB=values_b, + OffsetsA=offsets_a, + OffsetsB=offsets_b, + MaxLenA=max_len_a, + MaxLenB=max_len_b, + Out=values_out, + D=D, + stride_ad=values_a.stride(-2), + stride_bd=values_b.stride(-2), + stride_od=values_out.stride(-2), + n_prefix_from_B=n_prefix_from_B, + IS_DENSE_A=is_dense_a, # pyre-ignore[6] + IS_DENSE_B=is_dense_b, # pyre-ignore[6] + BLOCK_D=BLOCK_D, # pyre-ignore[6] + ) + + +def _triton_split_2D_jagged_internal( + jagged_in: torch.Tensor, + max_seq_len: int, + B: int, + offsets_a: Optional[torch.Tensor], + offsets_b: Optional[torch.Tensor], + max_len_a: Optional[int], + max_len_b: Optional[int], + out_a: torch.Tensor, + out_b: torch.Tensor, + D: int, + n_prefix_to_B: int, + is_dense_a: bool, + is_dense_b: bool, + BLOCK_D: int, +) -> None: + if is_sm100(): + + def grid(meta): + return (triton.cdiv(max_seq_len, meta["BLOCK_N"]), B) + + split_2D_jagged_multirow[grid]( + JaggedIn=jagged_in, + OffsetsA=offsets_a, + OffsetsB=offsets_b, + MaxLenA=max_len_a, + MaxLenB=max_len_b, + OutA=out_a, + OutB=out_b, + D=D, + stride_id=jagged_in.stride(0), + stride_ad=out_a.stride(0), + stride_bd=out_b.stride(0), + n_prefix_to_B=n_prefix_to_B, + IS_DENSE_A=is_dense_a, # pyre-ignore[6] + IS_DENSE_B=is_dense_b, # pyre-ignore[6] + BLOCK_D=BLOCK_D, # pyre-ignore[6] + ) + else: + _split_2D_jagged[(max_seq_len, B)]( + JaggedIn=jagged_in, + OffsetsA=offsets_a, + OffsetsB=offsets_b, + MaxLenA=max_len_a, + MaxLenB=max_len_b, + OutA=out_a, + OutB=out_b, + D=D, + stride_id=jagged_in.stride(0), + stride_ad=out_a.stride(0), + stride_bd=out_b.stride(0), + n_prefix_to_B=n_prefix_to_B, + IS_DENSE_A=is_dense_a, # pyre-ignore[6] + IS_DENSE_B=is_dense_b, # pyre-ignore[6] + BLOCK_D=BLOCK_D, # pyre-ignore[6] + ) + + +def _get_concat_split_2d_jagged_multirow_configs(): + configs = [] + for BLOCK_N in [1, 2, 4, 8]: + for num_warps in [1, 2, 4]: + configs.append( + triton.Config( + {"BLOCK_N": BLOCK_N}, + num_warps=num_warps, + ) + ) + return configs + + +@triton.jit +def _concat_2D_jagged_multirow( + ValuesA, + ValuesB, + OffsetsA, + OffsetsB, + MaxLenA, + MaxLenB, + Out, + D, + stride_ad, + stride_bd, + stride_od, + n_prefix_from_B, + IS_DENSE_A: tl.constexpr, + IS_DENSE_B: tl.constexpr, + BLOCK_D: 
tl.constexpr, + BLOCK_N: tl.constexpr, +): + off_z = tl.program_id(1) + block_n = tl.program_id(0) + + if IS_DENSE_A: + seq_start_a = off_z * MaxLenA + seq_len_a = MaxLenA + else: + seq_start_a = tl.load(OffsetsA + off_z) + seq_end_a = tl.load(OffsetsA + off_z + 1) + seq_len_a = seq_end_a - seq_start_a + if IS_DENSE_B: + seq_start_b = off_z * MaxLenB + seq_len_b = MaxLenB + else: + seq_start_b = tl.load(OffsetsB + off_z) + seq_end_b = tl.load(OffsetsB + off_z + 1) + seq_len_b = seq_end_b - seq_start_b + seq_len = seq_len_a + seq_len_b + + start_n = block_n * BLOCK_N + offs_n = start_n + tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + + valid_mask = offs_n < seq_len + + out_seq_start = seq_start_a + seq_start_b + offs_n + out_ptrs = Out + out_seq_start[:, None].to(tl.int64) * stride_od + offs_d[None, :] + + from_prefix_b_mask = (offs_n < n_prefix_from_B) & valid_mask + from_a_mask = ( + (offs_n >= n_prefix_from_B) + & (offs_n < seq_len_a + n_prefix_from_B) + & valid_mask + ) + from_suffix_b_mask = (offs_n >= seq_len_a + n_prefix_from_B) & valid_mask + + in_b1_ptrs = ( + ValuesB + + (offs_n[:, None] + seq_start_b).to(tl.int64) * stride_bd + + offs_d[None, :] + ) + v_b1 = tl.load( + in_b1_ptrs, mask=from_prefix_b_mask[:, None] & (offs_d[None, :] < D), other=0.0 + ) + tl.store(out_ptrs, v_b1, mask=from_prefix_b_mask[:, None] & (offs_d[None, :] < D)) + + off_a = offs_n - n_prefix_from_B + in_a_ptrs = ( + ValuesA + + (off_a[:, None] + seq_start_a).to(tl.int64) * stride_ad + + offs_d[None, :] + ) + v_a = tl.load( + in_a_ptrs, mask=from_a_mask[:, None] & (offs_d[None, :] < D), other=0.0 + ) + tl.store(out_ptrs, v_a, mask=from_a_mask[:, None] & (offs_d[None, :] < D)) + + off_b = offs_n - seq_len_a + in_b2_ptrs = ( + ValuesB + + (off_b[:, None] + seq_start_b).to(tl.int64) * stride_bd + + offs_d[None, :] + ) + v_b2 = tl.load( + in_b2_ptrs, mask=from_suffix_b_mask[:, None] & (offs_d[None, :] < D), other=0.0 + ) + tl.store(out_ptrs, v_b2, mask=from_suffix_b_mask[:, None] & (offs_d[None, :] < D)) + + +@triton_autotune( + configs=_get_concat_split_2d_jagged_multirow_configs(), + key=["BLOCK_D"], +) +@triton.jit +def concat_2D_jagged_multirow( + ValuesA, + ValuesB, + OffsetsA, + OffsetsB, + MaxLenA, + MaxLenB, + Out, + D, + stride_ad, + stride_bd, + stride_od, + n_prefix_from_B, + IS_DENSE_A: tl.constexpr, + IS_DENSE_B: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, +): + _concat_2D_jagged_multirow( + ValuesA, + ValuesB, + OffsetsA, + OffsetsB, + MaxLenA, + MaxLenB, + Out, + D, + stride_ad, + stride_bd, + stride_od, + n_prefix_from_B, + IS_DENSE_A, + IS_DENSE_B, + BLOCK_D, + BLOCK_N, + ) + + +@triton.jit +def _split_2D_jagged_multirow( + JaggedIn, + OffsetsA, + OffsetsB, + MaxLenA, + MaxLenB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + n_prefix_to_B, + IS_DENSE_A: tl.constexpr, + IS_DENSE_B: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, +): + off_z = tl.program_id(1) + block_n = tl.program_id(0) + + if IS_DENSE_A: + seq_start_a = off_z * MaxLenA + seq_len_a = MaxLenA + else: + seq_start_a = tl.load(OffsetsA + off_z) + seq_end_a = tl.load(OffsetsA + off_z + 1) + seq_len_a = seq_end_a - seq_start_a + if IS_DENSE_B: + seq_start_b = off_z * MaxLenB + seq_len_b = MaxLenB + else: + seq_start_b = tl.load(OffsetsB + off_z) + seq_end_b = tl.load(OffsetsB + off_z + 1) + seq_len_b = seq_end_b - seq_start_b + seq_len = seq_len_a + seq_len_b + seq_start = seq_start_a + seq_start_b + + start_n = block_n * BLOCK_N + offs_n = start_n + tl.arange(0, BLOCK_N) + 
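+    # each program handles BLOCK_N consecutive rows of this sample's combined (A + B) sequence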
offs_d = tl.arange(0, BLOCK_D) + + valid_mask = offs_n < seq_len + + in_ptrs = ( + JaggedIn + + (seq_start + offs_n[:, None]).to(tl.int64) * stride_id + + offs_d[None, :] + ) + + v = tl.load(in_ptrs, mask=valid_mask[:, None] & (offs_d[None, :] < D), other=0.0) + + to_prefix_b_mask = (offs_n < n_prefix_to_B) & valid_mask + to_a_mask = ( + (offs_n >= n_prefix_to_B) & (offs_n < seq_len_a + n_prefix_to_B) & valid_mask + ) + to_suffix_b_mask = (offs_n >= seq_len_a + n_prefix_to_B) & valid_mask + + out_b1_ptrs = ( + OutB + + (offs_n[:, None] + seq_start_b).to(tl.int64) * stride_bd + + offs_d[None, :] + ) + tl.store(out_b1_ptrs, v, mask=to_prefix_b_mask[:, None] & (offs_d[None, :] < D)) + + off_a = offs_n - n_prefix_to_B + out_a_ptrs = ( + OutA + (off_a[:, None] + seq_start_a).to(tl.int64) * stride_ad + offs_d[None, :] + ) + tl.store(out_a_ptrs, v, mask=to_a_mask[:, None] & (offs_d[None, :] < D)) + + off_b = offs_n - seq_len_a + out_b2_ptrs = ( + OutB + (off_b[:, None] + seq_start_b).to(tl.int64) * stride_bd + offs_d[None, :] + ) + tl.store(out_b2_ptrs, v, mask=to_suffix_b_mask[:, None] & (offs_d[None, :] < D)) + + +@triton_autotune( + configs=_get_concat_split_2d_jagged_multirow_configs(), + key=["BLOCK_D"], +) +@triton.jit +def split_2D_jagged_multirow( + JaggedIn, + OffsetsA, + OffsetsB, + MaxLenA, + MaxLenB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + n_prefix_to_B, + IS_DENSE_A: tl.constexpr, + IS_DENSE_B: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, +): + _split_2D_jagged_multirow( + JaggedIn, + OffsetsA, + OffsetsB, + MaxLenA, + MaxLenB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + n_prefix_to_B, + IS_DENSE_A, + IS_DENSE_B, + BLOCK_D, + BLOCK_N, + ) + + +@triton.jit +def _concat_2D_jagged( + ValuesA, + ValuesB, + OffsetsA, + OffsetsB, + MaxLenA, + MaxLenB, + Out, + D, + stride_ad, + stride_bd, + stride_od, + n_prefix_from_B, + IS_DENSE_A: tl.constexpr, + IS_DENSE_B: tl.constexpr, + BLOCK_D: tl.constexpr, +): + off_z = tl.program_id(1) + off_n = tl.program_id(0) + if IS_DENSE_A: + seq_start_a = off_z * MaxLenA + seq_len_a = MaxLenA + else: + seq_start_a = tl.load(OffsetsA + off_z) + seq_end_a = tl.load(OffsetsA + off_z + 1) + seq_len_a = seq_end_a - seq_start_a + if IS_DENSE_B: + seq_start_b = off_z * MaxLenB + seq_len_b = MaxLenB + else: + seq_start_b = tl.load(OffsetsB + off_z) + seq_end_b = tl.load(OffsetsB + off_z + 1) + seq_len_b = seq_end_b - seq_start_b + seq_len = seq_len_a + seq_len_b + if off_n >= seq_len: + return + offs_d = tl.arange(0, BLOCK_D) + out_seq_start = seq_start_a + seq_start_b + off_n + out_ptrs = Out + out_seq_start.to(tl.int64) * stride_od + offs_d + if off_n < n_prefix_from_B: + in_ptrs = ValuesB + (off_n + seq_start_b).to(tl.int64) * stride_bd + offs_d + elif off_n < seq_len_a + n_prefix_from_B: + in_ptrs = ( + ValuesA + + (off_n - n_prefix_from_B + seq_start_a).to(tl.int64) * stride_ad + + offs_d + ) + else: + in_ptrs = ( + ValuesB + + (off_n - seq_len_a + seq_start_b).to(tl.int64) * stride_bd + + offs_d + ) + v = tl.load(in_ptrs, mask=offs_d < D) + tl.store(out_ptrs, v, mask=offs_d < D) + + +@triton.jit +def _split_2D_jagged( + JaggedIn, + OffsetsA, + OffsetsB, + MaxLenA, + MaxLenB, + OutA, + OutB, + D, + stride_id, + stride_ad, + stride_bd, + n_prefix_to_B, + IS_DENSE_A: tl.constexpr, + IS_DENSE_B: tl.constexpr, + BLOCK_D: tl.constexpr, +): + off_z = tl.program_id(1) + off_n = tl.program_id(0) + if IS_DENSE_A: + seq_start_a = off_z * MaxLenA + seq_len_a = MaxLenA + else: + seq_start_a = tl.load(OffsetsA + 
off_z) + seq_end_a = tl.load(OffsetsA + off_z + 1) + seq_len_a = seq_end_a - seq_start_a + if IS_DENSE_B: + seq_start_b = off_z * MaxLenB + seq_len_b = MaxLenB + else: + seq_start_b = tl.load(OffsetsB + off_z) + seq_end_b = tl.load(OffsetsB + off_z + 1) + seq_len_b = seq_end_b - seq_start_b + seq_len = seq_len_a + seq_len_b + if off_n >= seq_len: + return + seq_start = seq_start_a + seq_start_b + offs_d = tl.arange(0, BLOCK_D) + in_ptrs = JaggedIn + (seq_start + off_n).to(tl.int64) * stride_id + offs_d + if off_n < n_prefix_to_B: + out_ptrs = OutB + (off_n + seq_start_b).to(tl.int64) * stride_bd + offs_d + elif off_n < seq_len_a + n_prefix_to_B: + out_ptrs = ( + OutA + + (off_n - n_prefix_to_B + seq_start_a).to(tl.int64) * stride_ad + + offs_d + ) + else: + out_ptrs = ( + OutB + (off_n - seq_len_a + seq_start_b).to(tl.int64) * stride_bd + offs_d + ) + v = tl.load(in_ptrs, mask=offs_d < D) + tl.store(out_ptrs, v, mask=offs_d < D) + + +class _Concat2DJaggedFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + max_seq_len: int, + values_a: torch.Tensor, + values_b: torch.Tensor, + max_len_a: Optional[int], + max_len_b: Optional[int], + offsets_a: Optional[torch.Tensor], + offsets_b: Optional[torch.Tensor], + n_prefix_from_B: int, + ): + values_a = switch_to_contiguous_if_needed(values_a) + values_b = switch_to_contiguous_if_needed(values_b) + is_dense_a = offsets_a is None + is_dense_b = offsets_b is None + total_len_a, D = values_a.shape + total_len_b, _ = values_b.shape + if is_dense_a: + assert max_len_a is not None + B = total_len_a // max_len_a + else: + assert offsets_a is not None + B = offsets_a.shape[0] - 1 + if is_dense_b: + assert max_len_b is not None + B = total_len_b // max_len_b + else: + assert offsets_b is not None + B = offsets_b.shape[0] - 1 + total_seq_len = total_len_a + total_len_b + BLOCK_D = triton.next_power_of_2(D) + values_out = torch.empty( + (total_seq_len, D), device=values_a.device, dtype=values_a.dtype + ) + _triton_concat_2D_jagged_internal( + values_a=values_a, + values_b=values_b, + values_out=values_out, + max_seq_len=max_seq_len, + B=B, + offsets_a=offsets_a, + offsets_b=offsets_b, + max_len_a=max_len_a, + max_len_b=max_len_b, + D=D, + n_prefix_from_B=n_prefix_from_B, + is_dense_a=is_dense_a, + is_dense_b=is_dense_b, + BLOCK_D=BLOCK_D, + ) + ctx.save_for_backward(offsets_a, offsets_b) + ctx.max_seq_len = max_seq_len + ctx.total_len_a = total_len_a + ctx.total_len_b = total_len_b + ctx.is_dense_a = is_dense_a + ctx.is_dense_b = is_dense_b + ctx.max_len_a = max_len_a + ctx.max_len_b = max_len_b + ctx.B = B + ctx.n_prefix_from_B = n_prefix_from_B + return values_out + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, d_out: torch.Tensor + ) -> Tuple[None, torch.Tensor, torch.Tensor, None, None, None, None, None]: + offsets_a, offsets_b = ctx.saved_tensors + _, D = d_out.shape + BLOCK_D = triton.next_power_of_2(D) + d_values_a = torch.zeros( + (ctx.total_len_a, D), device=d_out.device, dtype=d_out.dtype + ) + d_values_b = torch.empty( + (ctx.total_len_b, D), device=d_out.device, dtype=d_out.dtype + ) + _split_2D_jagged[(ctx.max_seq_len, ctx.B)]( + JaggedIn=d_out, + OffsetsA=offsets_a, + OffsetsB=offsets_b, + MaxLenA=ctx.max_len_a, + MaxLenB=ctx.max_len_b, + OutA=d_values_a, + OutB=d_values_b, + D=D, + stride_id=d_out.stride(-2), + stride_ad=d_values_a.stride(-2), + stride_bd=d_values_b.stride(-2), + n_prefix_to_B=ctx.n_prefix_from_B, + BLOCK_D=BLOCK_D, + IS_DENSE_A=ctx.is_dense_a, + IS_DENSE_B=ctx.is_dense_b, + 
) + return None, d_values_a, d_values_b, None, None, None, None, None + + +class _Split2DJaggedFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + max_seq_len: int, + values: torch.Tensor, + total_len_left: Optional[int], + total_len_right: Optional[int], + max_len_a: Optional[int], + max_len_b: Optional[int], + offsets_a: Optional[torch.Tensor], + offsets_b: Optional[torch.Tensor], + n_prefix_to_B: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + values = switch_to_contiguous_if_needed(values) + is_dense_a: bool = offsets_a is None + is_dense_b: bool = offsets_b is None + total_seq_len, D = values.shape + if is_dense_a: + assert is_dense_b is False + assert offsets_b is not None + assert max_len_a is not None + B = offsets_b.shape[0] - 1 + total_len_a = max_len_a * B + total_len_b = total_seq_len - total_len_a + elif is_dense_b: + assert is_dense_a is False + assert offsets_a is not None + assert max_len_b is not None + B = offsets_a.shape[0] - 1 + total_len_b = max_len_b * B + total_len_a = total_seq_len - total_len_b + else: + assert offsets_a is not None and offsets_b is not None + B = offsets_a.shape[0] - 1 + if total_len_left is not None and total_len_right is not None: + assert total_len_left + total_len_right == total_seq_len + total_len_a = total_len_left + total_len_b = total_len_right + else: + total_len_a = int(offsets_a[-1].item()) + total_len_b = values.size(0) - total_len_a + _, D = values.shape + BLOCK_D = triton.next_power_of_2(D) + values_a = torch.empty( + (total_len_a, D), device=values.device, dtype=values.dtype + ) + values_b = torch.empty( + (total_len_b, D), device=values.device, dtype=values.dtype + ) + _triton_split_2D_jagged_internal( + jagged_in=values, + max_seq_len=max_seq_len, + B=B, + offsets_a=offsets_a, + offsets_b=offsets_b, + max_len_a=max_len_a, + max_len_b=max_len_b, + out_a=values_a, + out_b=values_b, + D=D, + n_prefix_to_B=n_prefix_to_B, + is_dense_a=is_dense_a, + is_dense_b=is_dense_b, + BLOCK_D=BLOCK_D, + ) + ctx.save_for_backward(offsets_a, offsets_b) + ctx.max_seq_len = max_seq_len + ctx.total_seq_len = total_seq_len + ctx.max_len_a = max_len_a + ctx.max_len_b = max_len_b + ctx.is_dense_a = is_dense_a + ctx.is_dense_b = is_dense_b + ctx.B = B + ctx.D = D + ctx.n_prefix_to_B = n_prefix_to_B + return values_a, values_b + + @staticmethod + def backward( + ctx, *d_values + ) -> Tuple[None, torch.Tensor, None, None, None, None, None, None, None]: + offsets_a, offsets_b = ctx.saved_tensors + d_values_a, d_values_b = d_values + BLOCK_D = triton.next_power_of_2(ctx.D) + d_jagged_in = torch.empty( + (ctx.total_seq_len, ctx.D), + device=d_values_a.device, + dtype=d_values_a.dtype, + ) + _triton_concat_2D_jagged_internal( + values_a=d_values_a, + values_b=d_values_b, + values_out=d_jagged_in, + max_seq_len=ctx.max_seq_len, + B=ctx.B, + offsets_a=offsets_a, + offsets_b=offsets_b, + max_len_a=ctx.max_len_a, + max_len_b=ctx.max_len_b, + D=ctx.D, + n_prefix_from_B=ctx.n_prefix_to_B, + is_dense_a=ctx.is_dense_a, + is_dense_b=ctx.is_dense_b, + BLOCK_D=BLOCK_D, + ) + + return None, d_jagged_in, None, None, None, None, None, None, None + + +@torch.fx.wrap +def triton_concat_2D_jagged( + max_seq_len: int, + values_left: torch.Tensor, + values_right: torch.Tensor, + max_len_left: Optional[int], + max_len_right: Optional[int], + offsets_left: Optional[torch.Tensor], + offsets_right: Optional[torch.Tensor], + n_prefix_from_right: int = 0, +) -> torch.Tensor: + return _Concat2DJaggedFunction.apply( + max_seq_len, + values_left, 
+ values_right, + max_len_left, + max_len_right, + offsets_left, + offsets_right, + n_prefix_from_right, + ) + + +@torch.fx.wrap +def triton_split_2D_jagged( + max_seq_len: int, + values: torch.Tensor, + total_len_left: Optional[int], + total_len_right: Optional[int], + max_len_left: Optional[int], + max_len_right: Optional[int], + offsets_left: Optional[torch.Tensor], + offsets_right: Optional[torch.Tensor], + n_prefix_to_right: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _Split2DJaggedFunction.apply( + max_seq_len, + values, + total_len_left, + total_len_right, + max_len_left, + max_len_right, + offsets_left, + offsets_right, + n_prefix_to_right, + ) + + +class _Concat2DJaggedMultirowFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + max_seq_len: int, + values_left: torch.Tensor, + values_right: torch.Tensor, + max_len_left: Optional[int], + max_len_right: Optional[int], + offsets_left: Optional[torch.Tensor], + offsets_right: Optional[torch.Tensor], + n_prefix_from_right: int, + ) -> torch.Tensor: + values_left = switch_to_contiguous_if_needed(values_left) + values_right = switch_to_contiguous_if_needed(values_right) + is_dense_left = offsets_left is None + is_dense_right = offsets_right is None + total_len_left, D = values_left.shape + total_len_right, _ = values_right.shape + if is_dense_left: + assert max_len_left is not None + B = total_len_left // max_len_left + else: + assert offsets_left is not None + B = offsets_left.shape[0] - 1 + if is_dense_right: + assert max_len_right is not None + B = total_len_right // max_len_right + else: + assert offsets_right is not None + B = offsets_right.shape[0] - 1 + total_seq_len = total_len_left + total_len_right + BLOCK_D = triton.next_power_of_2(D) + values_out = torch.empty( + (total_seq_len, D), device=values_left.device, dtype=values_left.dtype + ) + + def grid(meta): + return (triton.cdiv(max_seq_len, meta["BLOCK_N"]), B) + + concat_2D_jagged_multirow[grid]( + ValuesA=values_left, + ValuesB=values_right, + OffsetsA=offsets_left, + OffsetsB=offsets_right, + MaxLenA=max_len_left, + MaxLenB=max_len_right, + Out=values_out, + D=D, + stride_ad=values_left.stride(-2), + stride_bd=values_right.stride(-2), + stride_od=values_out.stride(-2), + n_prefix_from_B=n_prefix_from_right, + IS_DENSE_A=is_dense_left, + IS_DENSE_B=is_dense_right, + BLOCK_D=BLOCK_D, + ) + ctx.save_for_backward(offsets_left, offsets_right) + ctx.max_seq_len = max_seq_len + ctx.total_len_left = total_len_left + ctx.total_len_right = total_len_right + ctx.is_dense_left = is_dense_left + ctx.is_dense_right = is_dense_right + ctx.max_len_left = max_len_left + ctx.max_len_right = max_len_right + ctx.B = B + ctx.n_prefix_from_right = n_prefix_from_right + return values_out + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, d_out: torch.Tensor + ) -> Tuple[None, torch.Tensor, torch.Tensor, None, None, None, None, None]: + offsets_left, offsets_right = ctx.saved_tensors + _, D = d_out.shape + BLOCK_D = triton.next_power_of_2(D) + d_values_left = torch.zeros( + (ctx.total_len_left, D), device=d_out.device, dtype=d_out.dtype + ) + d_values_right = torch.empty( + (ctx.total_len_right, D), device=d_out.device, dtype=d_out.dtype + ) + + def grid(meta): + return (triton.cdiv(ctx.max_seq_len, meta["BLOCK_N"]), ctx.B) + + split_2D_jagged_multirow[grid]( + JaggedIn=d_out, + OffsetsA=offsets_left, + OffsetsB=offsets_right, + MaxLenA=ctx.max_len_left, + MaxLenB=ctx.max_len_right, + OutA=d_values_left, + OutB=d_values_right, + 
D=D, + stride_id=d_out.stride(-2), + stride_ad=d_values_left.stride(-2), + stride_bd=d_values_right.stride(-2), + n_prefix_to_B=ctx.n_prefix_from_right, + IS_DENSE_A=ctx.is_dense_left, + IS_DENSE_B=ctx.is_dense_right, + BLOCK_D=BLOCK_D, + ) + return None, d_values_left, d_values_right, None, None, None, None, None + + +class _Split2DJaggedMultirowFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + max_seq_len: int, + values: torch.Tensor, + total_len_left: Optional[int], + total_len_right: Optional[int], + max_len_left: Optional[int], + max_len_right: Optional[int], + offsets_left: Optional[torch.Tensor], + offsets_right: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + values = switch_to_contiguous_if_needed(values) + is_dense_left: bool = offsets_left is None + is_dense_right: bool = offsets_right is None + total_seq_len, D = values.shape + + if is_dense_left: + assert is_dense_right is False + assert offsets_right is not None + assert max_len_left is not None + B = offsets_right.shape[0] - 1 + total_len_a = max_len_left * B + total_len_b = total_seq_len - total_len_a + elif is_dense_right: + assert is_dense_left is False + assert offsets_left is not None + assert max_len_right is not None + B = offsets_left.shape[0] - 1 + total_len_b = max_len_right * B + total_len_a = total_seq_len - total_len_b + else: + assert offsets_left is not None and offsets_right is not None + B = offsets_left.shape[0] - 1 + if total_len_left is not None and total_len_right is not None: + assert total_len_left + total_len_right == total_seq_len + total_len_a = total_len_left + total_len_b = total_len_right + else: + total_len_a = int(offsets_left[-1].item()) + total_len_b = values.size(0) - total_len_a + + BLOCK_D = triton.next_power_of_2(D) + values_a = torch.empty( + (total_len_a, D), device=values.device, dtype=values.dtype + ) + values_b = torch.empty( + (total_len_b, D), device=values.device, dtype=values.dtype + ) + + def grid(meta): + return (triton.cdiv(max_seq_len, meta["BLOCK_N"]), B) + + split_2D_jagged_multirow[grid]( + JaggedIn=values, + OffsetsA=offsets_left, + OffsetsB=offsets_right, + MaxLenA=max_len_left, + MaxLenB=max_len_right, + OutA=values_a, + OutB=values_b, + D=D, + stride_id=values.stride(-2), + stride_ad=values_a.stride(-2), + stride_bd=values_b.stride(-2), + n_prefix_to_B=0, + IS_DENSE_A=is_dense_left, + IS_DENSE_B=is_dense_right, + BLOCK_D=BLOCK_D, + ) + + ctx.save_for_backward(offsets_left, offsets_right) + ctx.max_seq_len = max_seq_len + ctx.total_seq_len = total_seq_len + ctx.max_len_left = max_len_left + ctx.max_len_right = max_len_right + ctx.is_dense_left = is_dense_left + ctx.is_dense_right = is_dense_right + ctx.B = B + ctx.D = D + + return values_a, values_b + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, *d_values + ) -> Tuple[None, torch.Tensor, None, None, None, None, None, None]: + offsets_left, offsets_right = ctx.saved_tensors + d_values_a, d_values_b = d_values + BLOCK_D = triton.next_power_of_2(ctx.D) + d_jagged_in = torch.empty( + (ctx.total_seq_len, ctx.D), + device=d_values_a.device, + dtype=d_values_a.dtype, + ) + + def grid(meta): + return (triton.cdiv(ctx.max_seq_len, meta["BLOCK_N"]), ctx.B) + + concat_2D_jagged_multirow[grid]( + ValuesA=d_values_a, + ValuesB=d_values_b, + OffsetsA=offsets_left, + OffsetsB=offsets_right, + MaxLenA=ctx.max_len_left, + MaxLenB=ctx.max_len_right, + Out=d_jagged_in, + D=ctx.D, + stride_ad=d_values_a.stride(-2), + stride_bd=d_values_b.stride(-2), + 
stride_od=d_jagged_in.stride(-2), + n_prefix_from_B=0, + IS_DENSE_A=ctx.is_dense_left, + IS_DENSE_B=ctx.is_dense_right, + BLOCK_D=BLOCK_D, + ) + + return None, d_jagged_in, None, None, None, None, None, None + + +@torch.fx.wrap +def triton_concat_2D_jagged_multirow( + max_seq_len: int, + values_a: torch.Tensor, + values_b: torch.Tensor, + offsets_a: Optional[torch.Tensor], + offsets_b: Optional[torch.Tensor], + max_len_a: int, + max_len_b: int, +) -> torch.Tensor: + return _Concat2DJaggedMultirowFunction.apply( + max_seq_len, + values_a, + values_b, + max_len_a, + max_len_b, + offsets_a, + offsets_b, + 0, + ) + + +@torch.fx.wrap +def triton_split_2D_jagged_multirow( + max_seq_len: int, + values: torch.Tensor, + total_len_left: Optional[int] = None, + total_len_right: Optional[int] = None, + max_len_left: Optional[int] = None, + max_len_right: Optional[int] = None, + offsets_left: Optional[torch.Tensor] = None, + offsets_right: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _Split2DJaggedMultirowFunction.apply( + max_seq_len, + values, + total_len_left, + total_len_right, + max_len_left, + max_len_right, + offsets_left, + offsets_right, + ) diff --git a/recommendation/dlrm_v3/generative_recommenders/ops/triton/triton_layer_norm.py b/recommendation/dlrm_v3/generative_recommenders/ops/triton/triton_layer_norm.py new file mode 100644 index 0000000000..2327ab14c6 --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/ops/triton/triton_layer_norm.py @@ -0,0 +1,1222 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
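+
+# The kernels in this file implement fused, row-wise normalization over the last
+# dimension. As an illustrative (non-authoritative) reference, the forward passes
+# compute the same quantities as the following eager PyTorch sketch, up to
+# floating-point differences:
+#
+#   def layer_norm_ref(x, w, b, eps):          # _weighted_layer_norm_fwd, IS_SWISH=False
+#       mean = x.mean(dim=-1, keepdim=True)
+#       rstd = torch.rsqrt(x.var(dim=-1, unbiased=False, keepdim=True) + eps)
+#       return (x - mean) * rstd * w + b
+#
+#   def swish_layer_norm_ref(x, w, b, eps):    # _weighted_layer_norm_fwd, IS_SWISH=True
+#       return torch.sigmoid(layer_norm_ref(x, w, b, eps)) * x
+#
+#   def rms_norm_ref(x, w, eps, silu=False):   # _weighted_rms_norm_fwd
+#       rstd = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
+#       y = x * rstd * w
+#       return torch.nn.functional.silu(y) if silu else y
+#
+# The Triton versions additionally cache per-row mean/rstd for the backward pass
+# and autotune how many rows each program handles (BLOCK_N).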
+ +#!/usr/bin/env python3 + + +from typing import List, Optional, Tuple + +import torch + +# @manual=//triton:triton +import triton + +# @manual=//triton:triton +import triton.language as tl + +from generative_recommenders.common import ( + switch_to_contiguous_if_needed, + triton_autotune, +) +from generative_recommenders.ops.utils import is_sm100 + +try: + # @manual=//triton:triton + from triton.language.extra.libdevice import fast_dividef +except ImportError: + try: + # @manual=//triton:triton + from triton.language.extra.cuda.libdevice import fast_dividef + except ImportError: + # pyre-ignore: Undefined import [21] + # @manual=//triton:triton + from triton.language.math import fast_dividef + + +def _get_layer_norm_fwd_configs() -> List[triton.Config]: + """Generate autotune configs for multi-row LayerNorm kernels.""" + configs = [] + for BLOCK_N in [1, 2, 4, 8, 16]: + for num_warps in [1, 2, 4]: + configs.append( + triton.Config( + {"BLOCK_N": BLOCK_N}, + num_warps=num_warps, + ) + ) + return configs + + +def _bwd_pre_hook(nargs): + nargs["DW"].zero_() + if "DB" in nargs: + nargs["DB"].zero_() + + +def _get_norm_bwd_configs() -> List[triton.Config]: + """Generate autotune configs for multi-row LayerNorm kernels.""" + configs = [] + for BLOCK_N in [1, 4, 8, 16]: + for num_warps in [2, 4]: + configs.append( + triton.Config( + {"BLOCK_N": BLOCK_N}, + num_warps=num_warps, + pre_hook=_bwd_pre_hook, + ) + ) + return configs + + +@triton_autotune( + configs=_get_layer_norm_fwd_configs(), + key=["BLOCK_D"], +) +@triton.jit +def _layer_norm_fwd( + X, + Y, + Mean, + Rstd, + N, + D, + eps, + stride_x, + stride_y, + TRAINING: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, + COMPUTE_MEAN_AND_RSTD: tl.constexpr, +): + block_id = tl.program_id(0) + start_row = block_id * BLOCK_N + + X_block_ptr = tl.make_block_ptr( + base=X, + shape=(N, D), + strides=(stride_x, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + Y_block_ptr = tl.make_block_ptr( + base=Y, + shape=(N, D), + strides=(stride_y, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + x_block = tl.load(X_block_ptr, boundary_check=(0, 1), padding_option="zero").to( + tl.float32 + ) + + cols = tl.arange(0, BLOCK_D) + col_mask = cols < D + rows = start_row + tl.arange(0, BLOCK_N) + row_mask = rows < N + + if COMPUTE_MEAN_AND_RSTD: + mean = tl.sum(x_block, axis=1) / D + if TRAINING: + tl.store(Mean + rows, mean, row_mask) + mean = tl.expand_dims(mean, 1) + else: + mean = tl.load(Mean + rows, row_mask, other=0.0) + mean = tl.expand_dims(mean, 1) + + x_mean = x_block - mean + x_mean = tl.where(row_mask[:, None] & col_mask[None, :], x_mean, 0.0) + + if COMPUTE_MEAN_AND_RSTD: + _var = x_mean * x_mean + var = tl.sum(_var, axis=1) / D + rstd = 1 / tl.sqrt(var + eps) + if TRAINING: + tl.store(Rstd + rows, rstd, row_mask) + else: + rstd = tl.load(Rstd + rows, row_mask, other=0.0) + + rstd = tl.expand_dims(rstd, 1) + y = x_mean * rstd + + tl.store(Y_block_ptr, y.to(Y.dtype.element_ty), boundary_check=(0, 1)) + + +@triton_autotune( + configs=_get_layer_norm_fwd_configs(), + key=["BLOCK_D"], +) +@triton.jit +def _weighted_layer_norm_fwd( + X, + Y, + W, + B, + Mean, + Rstd, + N, + D, + eps, + stride_x, + stride_y, + IS_SWISH: tl.constexpr, + TRAINING: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, + COMPUTE_MEAN_AND_RSTD: tl.constexpr, +): + # Get the block ID and calculate starting row + block_id = tl.program_id(0) + start_row = block_id * BLOCK_N + + # 
Load weight and bias once (shared across all rows in this block) + cols = tl.arange(0, BLOCK_D) + col_mask = cols < D + w = tl.load(W + cols, mask=col_mask, other=0.0).to(tl.float32) + b = tl.load(B + cols, mask=col_mask, other=0.0).to(tl.float32) + + # Create block pointers for X and Y + X_block_ptr = tl.make_block_ptr( + base=X, + shape=(N, D), + strides=(stride_x, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + Y_block_ptr = tl.make_block_ptr( + base=Y, + shape=(N, D), + strides=(stride_y, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + x_block = tl.load(X_block_ptr, boundary_check=(0, 1), padding_option="zero").to( + tl.float32 + ) + + rows = start_row + tl.arange(0, BLOCK_N) + row_mask = rows < N + + if COMPUTE_MEAN_AND_RSTD: + mean = tl.sum(x_block, axis=1) / D + if TRAINING: + tl.store(Mean + rows, mean, row_mask) + mean = tl.expand_dims(mean, 1) + else: + mean = tl.load(Mean + rows, row_mask, other=0.0) + mean = tl.expand_dims(mean, 1) + + x_mean = x_block - mean + x_mean = tl.where(row_mask[:, None] & col_mask[None, :], x_mean, 0.0) + + if COMPUTE_MEAN_AND_RSTD: + _var = x_mean * x_mean + var = tl.sum(_var, axis=1) / D + rstd = 1 / tl.sqrt(var + eps) + if TRAINING: + tl.store(Rstd + rows, rstd, row_mask) + else: + rstd = tl.load(Rstd + rows, row_mask, other=0.0) + + rstd = tl.expand_dims(rstd, 1) + y = x_mean * rstd + y = y * w[None, :] + b[None, :] + + if IS_SWISH: + y = tl.sigmoid(y) * x_block + + tl.store(Y_block_ptr, y.to(Y.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def _layer_norm_bwd_dx( + DX, + DY, + X, + Mean, + Rstd, + stride_dx, + stride_dy, + stride_x, + D, + eps, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + cols = tl.arange(0, BLOCK_D) + mask = cols < D + X += row.to(tl.int64) * stride_x + DY += row.to(tl.int64) * stride_dy + DX += row.to(tl.int64) * stride_dx + + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + mean = tl.load(Mean + row) + rstd = tl.load(Rstd + row) + + # Compute dx + xhat = (x - mean) * rstd + xhat = tl.where(mask, xhat, 0.0) + dy = tl.where(mask, dy, 0.0) + c1 = tl.sum(xhat * dy, axis=0) / D + c2 = tl.sum(dy, axis=0) / D + dx = (dy - (xhat * c1 + c2)) * rstd + # Write dx + tl.store(DX + cols, dx, mask=mask) + + +@triton_autotune( + configs=_get_layer_norm_fwd_configs(), + key=["BLOCK_D"], +) +@triton.jit +def _weighted_layer_norm_bwd_dx( + DX, + DY, + DW, + DB, + X, + W, + B, + Mean, + Rstd, + stride_dx, + stride_dy, + stride_x, + D, + eps, + IS_SWISH: tl.constexpr, + N, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid = tl.program_id(0) + tile_num = tl.num_programs(0) + num_blocks = tl.cdiv(N, BLOCK_N) + blocks_per_tile = num_blocks // tile_num + if pid < num_blocks % tile_num: + blocks_per_tile += 1 + + cols = tl.arange(0, BLOCK_D) + col_mask = cols < D + w = tl.load(W + cols, mask=col_mask, other=0.0).to(tl.float32) + + acc_dw = tl.zeros([BLOCK_D], dtype=tl.float32) + acc_db = tl.zeros([BLOCK_D], dtype=tl.float32) + + start_block = pid + + for idx in range(blocks_per_tile): + current_block = start_block + idx * tile_num + start_row = current_block * BLOCK_N + + X_block_ptr = tl.make_block_ptr( + base=X, + shape=(N, D), + strides=(stride_x, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + DX_block_ptr = tl.make_block_ptr( + base=DX, + shape=(N, D), + strides=(stride_dx, 1), + offsets=(start_row, 0), + 
block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + DY_block_ptr = tl.make_block_ptr( + base=DY, + shape=(N, D), + strides=(stride_dy, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + # Load data blocks + x_block = tl.load(X_block_ptr, boundary_check=(0, 1), padding_option="zero").to( + tl.float32 + ) + dy_block = tl.load( + DY_block_ptr, boundary_check=(0, 1), padding_option="zero" + ).to(tl.float32) + + # Load mean and rstd for all rows in this block + rows = start_row + tl.arange(0, BLOCK_N) + row_mask = rows < N + mean = tl.load(Mean + rows, row_mask, other=0.0) + rstd = tl.load(Rstd + rows, row_mask, other=0.0) + + # Expand dimensions for broadcasting + mean = tl.expand_dims(mean, 1) + rstd = tl.expand_dims(rstd, 1) + + xhat = (x_block - mean) * rstd + + xhat = tl.where(row_mask[:, None] & col_mask[None, :], xhat, 0.0) + wdy = w[None, :] * dy_block + wdy = tl.where(row_mask[:, None] & col_mask[None, :], wdy, 0.0) + + # Compute dx + if IS_SWISH: + b = tl.load(B + cols, mask=col_mask, other=0.0).to(tl.float32) + sigmoid_layer_norm = tl.sigmoid(xhat * w[None, :] + b[None, :]) + sigmoid_layer_norm = tl.where( + row_mask[:, None] & col_mask[None, :], sigmoid_layer_norm, 0.0 + ) + + sigmoid_deriv = sigmoid_layer_norm * (1 - sigmoid_layer_norm) + x_ = wdy * x_block * sigmoid_deriv + x_ = tl.where(row_mask[:, None] & col_mask[None, :], x_, 0.0) + + c1 = tl.sum(xhat * x_, axis=1) / D + c2 = tl.sum(x_, axis=1) / D + c1 = tl.expand_dims(c1, 1) + c2 = tl.expand_dims(c2, 1) + dx = (x_ - (xhat * c1 + c2)) * rstd + + dx = dy_block * sigmoid_layer_norm + dx + # Write dx + tl.store(DX_block_ptr, dx.to(DX.dtype.element_ty), boundary_check=(0, 1)) + partial_dw = tl.sum(dy_block * x_block * xhat * sigmoid_deriv, axis=0) + partial_db = tl.sum(dy_block * x_block * sigmoid_deriv, axis=0) + else: + c1 = tl.sum(xhat * wdy, axis=1) / D + c2 = tl.sum(wdy, axis=1) / D + c1 = tl.expand_dims(c1, 1) + c2 = tl.expand_dims(c2, 1) + dx = (wdy - (xhat * c1 + c2)) * rstd + # Write dx + tl.store(DX_block_ptr, dx.to(DX.dtype.element_ty), boundary_check=(0, 1)) + partial_dw = tl.sum(dy_block * xhat, axis=0) + partial_db = tl.sum(dy_block, axis=0) + + # Accumulate partial sums in shared memory + acc_dw += partial_dw + acc_db += partial_db + + # Store accumulated sums back to global memory + dw_ptrs = DW + pid.to(tl.int64) * D + cols + db_ptrs = DB + pid.to(tl.int64) * D + cols + tl.store(dw_ptrs, acc_dw, mask=col_mask) + tl.store(db_ptrs, acc_db, mask=col_mask) + + +def _get_bwd_dwdb_configs() -> List[triton.Config]: + configs = [] + BLOCK_N_CHOICES = [32, 64, 128, 256] + if is_sm100(): + BLOCK_N_CHOICES = [128, 256, 512, 1024] + for BLOCK_N in BLOCK_N_CHOICES: + for num_warps in [8, 16] + ([] if torch.ops.hip else [32]): + configs.append( + triton.Config( + {"BLOCK_N": BLOCK_N}, + num_warps=num_warps, + ) + ) + return configs + + +@triton_autotune( + configs=_get_bwd_dwdb_configs(), + key=["D"], +) +@triton.jit +def _layer_norm_bwd_dwdb( + DW, + DB, + FINAL_DW, + FINAL_DB, + N, + D, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid = tl.program_id(0) + cols = pid * BLOCK_D + tl.arange(0, BLOCK_D) + dw = tl.zeros((BLOCK_N, BLOCK_D), dtype=tl.float32) + db = tl.zeros((BLOCK_N, BLOCK_D), dtype=tl.float32) + + for i in range(0, N, BLOCK_N): + rows = i + tl.arange(0, BLOCK_N) + # pyre-fixme[16]: `int` has no attribute `__getitem__`. 
+ mask = (rows[:, None] < N) & (cols[None, :] < D) + offs = rows[:, None] * D + cols[None, :] + dw += tl.load(DW + offs, mask=mask, other=0.0) + db += tl.load(DB + offs, mask=mask, other=0.0) + + sum_dw = tl.sum(dw, axis=0) + sum_db = tl.sum(db, axis=0) + tl.store(FINAL_DW + cols, sum_dw.to(FINAL_DW.dtype.element_ty), mask=cols < D) + tl.store(FINAL_DB + cols, sum_db.to(FINAL_DB.dtype.element_ty), mask=cols < D) + + +def triton_weighted_layer_norm_fwd( + x: torch.Tensor, + weight: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + eps: float, + mean: Optional[torch.Tensor] = None, + rstd: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]: + assert x.dim() == 2, f"x.dim() == {x.dim()}, expected 2" + x = switch_to_contiguous_if_needed(x) + N, D = x.shape + learnable = weight is not None + if learnable: + assert bias is not None and weight is not None + assert weight.dim() == 1 + assert bias.dim() == 1 + assert weight.numel() == D + assert bias.numel() == D + + y = torch.empty_like(x) + compute_mean_and_rstd = mean is None or rstd is None + if mean is None: + mean = torch.empty((N,), dtype=torch.float32, device=x.device) + if rstd is None: + rstd = torch.empty((N,), dtype=torch.float32, device=x.device) + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_D: int = min(MAX_FUSED_SIZE, triton.next_power_of_2(D)) + if D > BLOCK_D: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + + if N == 0: + return y, mean, rstd, BLOCK_D + + # pyre-ignore[28] + grid = lambda meta: ( # noqa E731 + triton.cdiv(N, meta["BLOCK_N"]), + ) + if learnable: + _weighted_layer_norm_fwd[grid]( + x, + y, + weight, + bias, + mean, + rstd, + N, + D, + eps, + x.stride(0), + y.stride(0), + IS_SWISH=False, + TRAINING=True, + BLOCK_D=BLOCK_D, + COMPUTE_MEAN_AND_RSTD=compute_mean_and_rstd, + ) + else: + _layer_norm_fwd[grid]( + x, + y, + mean, + rstd, + N, + D, + eps, + x.stride(0), + y.stride(0), + TRAINING=True, + BLOCK_D=BLOCK_D, + COMPUTE_MEAN_AND_RSTD=compute_mean_and_rstd, + ) + + return y, mean, rstd, BLOCK_D + + +def triton_weighted_layer_norm_bwd( + dy: torch.Tensor, + x: torch.Tensor, + weight: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + mean: torch.Tensor, + rstd: torch.Tensor, + learnable: bool, + eps: float, + BLOCK_D: int, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + num_warps: int = min(max(BLOCK_D // 256, 1), 8) + if learnable: + assert weight is not None and bias is not None + N, D = x.shape + dx = torch.empty_like(x) + sms = torch.cuda.get_device_properties(x.device).multi_processor_count + tile_num = max(1, min(sms * 8, N // 4)) + _dweight = torch.empty((tile_num, D), dtype=torch.float32, device=x.device) + _dbias = torch.empty((tile_num, D), dtype=torch.float32, device=x.device) + dweight = torch.empty((D,), dtype=weight.dtype, device=x.device) + dbias = torch.empty((D,), dtype=weight.dtype, device=x.device) + if N == 0: + dweight.zero_() + dbias.zero_() + return dx, dweight, dbias + # pyre-ignore[28] + _weighted_layer_norm_bwd_dx[(tile_num,)]( + dx, + dy, + _dweight, + _dbias, + x, + weight, + bias, + mean, + rstd, + dx.stride(0), + dy.stride(0), + x.stride(0), + D, + eps, + IS_SWISH=False, + N=N, + BLOCK_D=BLOCK_D, + ) + + def grid(META): + return (triton.cdiv(D, META["BLOCK_D"]),) + + blocks = triton.next_power_of_2(sms * 4) + BLOCK_D = triton.next_power_of_2(triton.cdiv(D, blocks)) + BLOCK_D = min(max(BLOCK_D, 4), 128) + 
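+        # Second reduction stage: _weighted_layer_norm_bwd_dx above wrote one partial
+        # (D,)-row of dW/dB per tile into _dweight/_dbias (each of shape (tile_num, D));
+        # _layer_norm_bwd_dwdb below sums over the tile dimension to produce the final
+        # (D,)-shaped weight and bias gradients. BLOCK_D is recomputed here so that the
+        # D columns are split across roughly `blocks` programs.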
_layer_norm_bwd_dwdb[grid]( + _dweight, + _dbias, + dweight, + dbias, + tile_num, + D, + BLOCK_D=BLOCK_D, + ) + + return dx, dweight, dbias + else: + N, D = x.shape + dx = torch.empty_like(x) + if N == 0: + return dx, None, None + # pyre-ignore[28] + _layer_norm_bwd_dx[(N,)]( + dx, + dy, + x, + mean, + rstd, + dx.stride(0), + dy.stride(0), + x.stride(0), + D, + eps, + BLOCK_D=BLOCK_D, + num_warps=num_warps, + ) + return dx, None, None + + +class LayerNormFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + x: torch.Tensor, + weight: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + eps: float, + ) -> torch.Tensor: + y, mean, rstd, BLOCK_D = triton_weighted_layer_norm_fwd( + x=x, + weight=weight, + bias=bias, + eps=eps, + ) + learnable = weight is not None + if learnable: + ctx.save_for_backward(x, weight, bias, mean, rstd) + else: + ctx.save_for_backward(x, mean, rstd) + ctx.BLOCK_D = BLOCK_D + ctx.eps = eps + ctx.learnable = learnable + return y + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, dy: torch.Tensor + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], None]: + if ctx.learnable: + x, weight, bias, mean, rstd = ctx.saved_tensors + else: + x, mean, rstd = ctx.saved_tensors + weight, bias = None, None + dx, dweight, dbias = triton_weighted_layer_norm_bwd( + dy=dy, + x=x, + weight=weight, + bias=bias, + mean=mean, + rstd=rstd, + learnable=ctx.learnable, + eps=ctx.eps, + BLOCK_D=ctx.BLOCK_D, + ) + return dx, dweight, dbias, None + + +def _get_rms_norm_fwd_configs() -> List[triton.Config]: + """Generate autotune configs for multi-row RMSNorm kernels.""" + configs = [] + for BLOCK_N in [1, 2, 4, 8, 16]: + for num_warps in [1, 2, 4]: + configs.append( + triton.Config( + {"BLOCK_N": BLOCK_N}, + num_warps=num_warps, + ) + ) + return configs + + +@triton.autotune( + configs=_get_rms_norm_fwd_configs(), + key=["BLOCK_D", "SILU"], +) +@triton.jit +def _weighted_rms_norm_fwd( + X, + Y, + W, + Rstd, + N, + D, + eps, + stride_x, + stride_y, + SILU: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, +): + block_id = tl.program_id(0) + start_row = block_id * BLOCK_N + + # Load weight once (shared across all rows in this block) + cols = tl.arange(0, BLOCK_D) + col_mask = cols < D + w = tl.load(W + cols, mask=col_mask, other=0.0).to(tl.float32) + + # Create block pointers for X and Y + X_block_ptr = tl.make_block_ptr( + base=X, + shape=(N, D), + strides=(stride_x, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + Y_block_ptr = tl.make_block_ptr( + base=Y, + shape=(N, D), + strides=(stride_y, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + x_block = tl.load(X_block_ptr, boundary_check=(0, 1), padding_option="zero").to( + tl.float32 + ) + + rows = start_row + tl.arange(0, BLOCK_N) + row_mask = rows < N + + # Compute variance (RMS norm uses x directly, not x - mean) + x_masked = tl.where(row_mask[:, None] & col_mask[None, :], x_block, 0.0) + _var = x_masked * x_masked + var = tl.sum(_var, axis=1) / D + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + rows, rstd, row_mask) + + # Normalize and apply linear transformation + rstd = tl.expand_dims(rstd, 1) + y = x_block * rstd + y = y * w[None, :] + + if SILU: + # pyre-ignore[16]: Module `triton.language.math` has no attribute `fast_dividef` + y = fast_dividef(y, 1.0 + tl.exp(-y)) + + tl.store(Y_block_ptr, y.to(Y.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def 
_weighted_rms_norm_bwd_dx( + DX, + DY, + DW, + X, + W, + Rstd, + Lock, + stride_dx, + stride_dy, + stride_x, + D, + eps, + GROUP_N, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + cols = tl.arange(0, BLOCK_D) + mask = cols < D + X += row.to(tl.int64) * stride_x + DY += row.to(tl.int64) * stride_dy + DX += row.to(tl.int64) * stride_dx + + # Load data to SRAM + x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) + dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) + rstd = tl.load(Rstd + row) + + # Compute dx + xhat = x * rstd + w = tl.load(W + cols, mask=mask).to(tl.float32) + wdy = w * dy + + xhat = tl.where(mask, xhat, 0.0) + wdy = tl.where(mask, wdy, 0.0) + c1 = tl.sum(xhat * wdy, axis=0) / D + dx = (wdy - (xhat * c1)) * rstd + # Write dx + tl.store(DX + cols, dx, mask=mask) + + # Offset locks and weights/biases gradient pointer for parallel reduction + lock_id = row % GROUP_N + Lock += lock_id + Count = Lock + GROUP_N + DW = DW + lock_id * D + cols + # Accumulate partial sums for dw/db + partial_dw = dy * xhat + while tl.atomic_cas(Lock, 0, 1) == 1: + pass + count = tl.load(Count) + # First store doesn't accumulate + if count == 0: + tl.atomic_xchg(Count, 1) + else: + partial_dw += tl.load(DW, mask=mask) + tl.store(DW, partial_dw, mask=mask) + # Release the lock + tl.atomic_xchg(Lock, 0) + + +@triton_autotune( + configs=_get_norm_bwd_configs(), + key=["BLOCK_D", "SILU"], +) +@triton.jit +def _weighted_rms_norm_bwd( + DX, + DY, + DW, + X, + W, + Rstd, + stride_dx, + stride_dy, + stride_x, + D, + eps, + N, + SILU: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid = tl.program_id(0) + tile_num = tl.num_programs(0) + num_blocks = tl.cdiv(N, BLOCK_N) + blocks_per_tile = num_blocks // tile_num + if pid < num_blocks % tile_num: + blocks_per_tile += 1 + + cols = tl.arange(0, BLOCK_D) + col_mask = cols < D + w = tl.load(W + cols, mask=col_mask, other=0.0).to(tl.float32) + + start_block = pid + + acc_dw = tl.zeros([BLOCK_D], dtype=tl.float32) + + for idx in range(blocks_per_tile): + current_block = start_block + idx * tile_num + start_row = current_block * BLOCK_N + + X_block_ptr = tl.make_block_ptr( + base=X, + shape=(N, D), + strides=(stride_x, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + DX_block_ptr = tl.make_block_ptr( + base=DX, + shape=(N, D), + strides=(stride_dx, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + DY_block_ptr = tl.make_block_ptr( + base=DY, + shape=(N, D), + strides=(stride_dy, 1), + offsets=(start_row, 0), + block_shape=(BLOCK_N, BLOCK_D), + order=(1, 0), + ) + + # Load data blocks + x_block = tl.load(X_block_ptr, boundary_check=(0, 1), padding_option="zero").to( + tl.float32 + ) + dy_block = tl.load( + DY_block_ptr, boundary_check=(0, 1), padding_option="zero" + ).to(tl.float32) + + # Load rstd for all rows in this block + rows = start_row + tl.arange(0, BLOCK_N) + row_mask = rows < N + rstd = tl.load(Rstd + rows, row_mask, other=0.0) + + # Expand dimensions for broadcasting + rstd = tl.expand_dims(rstd, 1) + + # Compute dx + xhat = x_block * rstd + + # Apply SILU backward if enabled + if SILU: + y_before_silu = xhat * w[None, :] + # pyre-fixme[16] + sig_y = fast_dividef(1.0, 1.0 + tl.exp(-y_before_silu)) + # SILU derivative: sigmoid(y) + y * sigmoid(y) * (1 - sigmoid(y)) + dy_block = dy_block * (sig_y + y_before_silu * sig_y * (1.0 - sig_y)) + + wdy = w[None, :] * dy_block + + c1 = tl.sum(xhat * wdy, axis=1) / D + c1 = tl.expand_dims(c1, 1) + dx 
= (wdy - (xhat * c1)) * rstd + + # Write dx + tl.store(DX_block_ptr, dx.to(DX.dtype.element_ty), boundary_check=(0, 1)) + + # Accumulate partial sums for dw + # Compute dw for all rows, then sum locally before atomic operation + partial_dw_block = dy_block * xhat + # Local reduction: sum across all rows in this block + partial_dw = tl.sum(partial_dw_block, axis=0) + acc_dw += partial_dw + + DW_ptr = DW + cols + tl.atomic_add(DW_ptr, acc_dw, col_mask) + + +@triton_autotune( + configs=_get_bwd_dwdb_configs(), + key=["D"], +) +@triton.jit +def _rms_norm_bwd_dwdb( + DW, + FINAL_DW, + N, + D, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid = tl.program_id(0) + cols = pid * BLOCK_D + tl.arange(0, BLOCK_D) + dw = tl.zeros((BLOCK_N, BLOCK_D), dtype=tl.float32) + + for i in range(0, N, BLOCK_N): + rows = i + tl.arange(0, BLOCK_N) + # pyre-fixme[16]: `int` has no attribute `__getitem__`. + mask = (rows[:, None] < N) & (cols[None, :] < D) + offs = rows[:, None] * D + cols[None, :] + dw += tl.load(DW + offs, mask=mask, other=0.0) + + sum_dw = tl.sum(dw, axis=0) + tl.store(FINAL_DW + cols, sum_dw.to(FINAL_DW.dtype.element_ty), mask=cols < D) + + +class RMSNormFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + x: torch.Tensor, + weight: torch.Tensor, + eps: float, + silu: bool, + ) -> torch.Tensor: + assert x.dim() == 2 + x = switch_to_contiguous_if_needed(x) + N, D = x.shape + assert weight.dim() == 1 + assert weight.numel() == D + + y = torch.empty_like(x) + rstd = torch.empty((N,), dtype=torch.float32, device=x.device) + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_D = min(MAX_FUSED_SIZE, triton.next_power_of_2(D)) + if D > BLOCK_D: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + + ctx.save_for_backward(x, weight, rstd) + ctx.silu = silu + if N == 0: + return y + + # pyre-ignore[28] + grid = lambda meta: ( # noqa E731 + triton.cdiv(N, meta["BLOCK_N"]), + ) + _weighted_rms_norm_fwd[grid]( + x, + y, + weight, + rstd, + N, + D, + eps, + x.stride(0), + y.stride(0), + SILU=silu, + BLOCK_D=BLOCK_D, + ) + + ctx.BLOCK_D = BLOCK_D + ctx.eps = eps + return y + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, dy: torch.Tensor + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], None, None]: + x, weight, rstd = ctx.saved_tensors + N, D = x.shape + dx = torch.empty_like(x) + if D <= 1024: + GROUP_N = 256 * 8 + elif D <= 4096: + GROUP_N = 128 * 8 + elif D <= 8192: + GROUP_N = 96 * 8 + else: + GROUP_N = 64 * 8 + GROUP_N = N if GROUP_N > N else GROUP_N + dweight = torch.zeros((D,), dtype=weight.dtype, device=x.device) + if N == 0: + dweight.zero_() + return dx, dweight, None, None + + sms = torch.cuda.get_device_properties(x.device).multi_processor_count + tile_num = max(1, min(sms * 8, N // 4)) + + _weighted_rms_norm_bwd[(tile_num,)]( + dx, + dy, + dweight, + x, + weight, + rstd, + dx.stride(0), + dy.stride(0), + x.stride(0), + D, + ctx.eps, + N=N, + SILU=ctx.silu, + BLOCK_D=ctx.BLOCK_D, + ) + + return dx, dweight, None, None + + +class SwishLayerNormFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + ) -> torch.Tensor: + assert x.dim() == 2, f"x.dim() == {x.dim()}, expected 2" + x = switch_to_contiguous_if_needed(x) + N, D = x.shape + + assert bias is not None and weight is not None + assert weight.dim() == 1 + assert bias.dim() == 1 + 
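+        # Swish layer norm: out = x * sigmoid(layer_norm(x) * weight + bias),
+        # fused in a single kernel via IS_SWISH=True below.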
assert weight.numel() == D + assert bias.numel() == D + + y = torch.empty_like(x) + mean = torch.empty((N,), dtype=torch.float32, device=x.device) + rstd = torch.empty((N,), dtype=torch.float32, device=x.device) + + BLOCK_D = triton.next_power_of_2(D) + num_warps = min(max(BLOCK_D // 256, 1), 8) + + ctx.save_for_backward(x, weight, bias, mean, rstd) + ctx.BLOCK_D = BLOCK_D + ctx.num_warps = num_warps + ctx.eps = eps + if N == 0: + return y + + # pyre-ignore[28] + grid = lambda meta: ( # noqa E731 + triton.cdiv(N, meta["BLOCK_N"]), + ) + _weighted_layer_norm_fwd[grid]( + x, + y, + weight, + bias, + mean, + rstd, + N, + D, + eps, + x.stride(0), + y.stride(0), + IS_SWISH=True, + TRAINING=True, + BLOCK_D=BLOCK_D, + COMPUTE_MEAN_AND_RSTD=True, + ) + + return y + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, dy: torch.Tensor + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], None]: + x, weight, bias, mean, rstd = ctx.saved_tensors + N, D = x.shape + dx = torch.empty_like(x) + sms = torch.cuda.get_device_properties(x.device).multi_processor_count + tile_num = max(1, min(sms * 8, N // 4)) + _dweight = torch.empty((tile_num, D), dtype=torch.float32, device=x.device) + _dbias = torch.empty((tile_num, D), dtype=torch.float32, device=x.device) + dweight = torch.empty((D,), dtype=weight.dtype, device=x.device) + dbias = torch.empty((D,), dtype=weight.dtype, device=x.device) + if N == 0: + dweight.zero_() + dbias.zero_() + return dx, dweight, dbias, None + # pyre-ignore[28] + _weighted_layer_norm_bwd_dx[(tile_num,)]( + dx, + dy, + _dweight, + _dbias, + x, + weight, + bias, + mean, + rstd, + dx.stride(0), + dy.stride(0), + x.stride(0), + D, + ctx.eps, + IS_SWISH=True, + N=N, + BLOCK_D=ctx.BLOCK_D, + ) + + def grid(META): + return (triton.cdiv(D, META["BLOCK_D"]),) + + blocks = triton.next_power_of_2(sms * 4) + BLOCK_D = triton.next_power_of_2(triton.cdiv(D, blocks)) + BLOCK_D = min(max(BLOCK_D, 4), 128) + _layer_norm_bwd_dwdb[grid]( + _dweight, + _dbias, + dweight, + dbias, + tile_num, + D, + BLOCK_D=BLOCK_D, + ) + + return dx, dweight, dbias, None + + +@torch.fx.wrap +def triton_layer_norm( + x: torch.Tensor, + weight: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + eps: float, +) -> torch.Tensor: + return LayerNormFunction.apply(x, weight, bias, eps) + + +@torch.fx.wrap +def triton_rms_norm( + x: torch.Tensor, + weight: Optional[torch.Tensor], + eps: float, + silu: bool = False, +) -> torch.Tensor: + return RMSNormFunction.apply(x, weight, eps, silu) + + +@torch.fx.wrap +def triton_swish_layer_norm( + x: torch.Tensor, + normalized_shape: List[int], + weight: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + eps: float, +) -> torch.Tensor: + return SwishLayerNormFunction.apply(x, weight, bias, eps) diff --git a/recommendation/dlrm_v3/generative_recommenders/ops/triton/triton_position.py b/recommendation/dlrm_v3/generative_recommenders/ops/triton/triton_position.py new file mode 100644 index 0000000000..793b61f5e0 --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/ops/triton/triton_position.py @@ -0,0 +1,435 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + + +from typing import List, Optional, Tuple + +import torch + +# @manual=//triton:triton +import triton + +# @manual=//triton:triton +import triton.language as tl + +try: + torch.ops.load_library("//hammer/ops/cuda:cuda_ops") +except OSError: + pass + +from generative_recommenders.common import ( + autotune_max_seq_len, + prev_power_of_2, + switch_to_contiguous_if_needed, + triton_autotune, +) + + +def _autotune_configs() -> List[triton.Config]: + configs = [] + for BLOCK_N in [16, 32, 64]: + for num_stages in [1, 2]: + for num_warps in [2, 4, 8]: + configs.append( + triton.Config( + { + "BLOCK_N": BLOCK_N, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + ) + return configs + + +@triton_autotune( + configs=_autotune_configs(), + key=["AUTOTUNE_MAX_SEQ_LEN"], +) +@triton.jit +def _add_timestamp_position_embeddings_kernel( + SeqEmb, + Offsets, + Lengths, + PosEmb, + TsEmb, + Out, + TS, + PosInds, + TsInds, + NumTargets, + AUTOTUNE_MAX_SEQ_LEN, + D, + num_time_buckets, + time_bucket_increments, + time_bucket_scale, + time_delta, + max_contextual_seq_len, + max_pos_ind, + stride_sn, + stride_pn, + stride_tn, + stride_on, + TRAINING: tl.constexpr, + HAS_MULTIPLE_TARGETS: tl.constexpr, + INTERLEAVE_TARGETS: tl.constexpr, + TIME_BUCKET_FN: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_N: tl.constexpr, +): + """ + SeqEmb has shape (sum_B(N_i), D), + PosEmb has shape (N_p, D), + TsEmb has shape (N_t, D), + Out has shape (sum_B(N_i), D) + """ + + off_b = tl.program_id(0) + off_n = tl.program_id(1) + seq_start = tl.load(Offsets + off_b) + seq_end = tl.load(Offsets + off_b + 1) + seq_len = seq_end - seq_start + start_n = off_n * BLOCK_N + if start_n >= seq_len: + return + offs_n = start_n + tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + seq_emb_offsets = offs_n[:, None] * stride_sn + offs_d[None, :] + SeqEmb += seq_start.to(tl.int64) * stride_sn + mask_n = offs_n < seq_len + # position encoding + seq_len = tl.load(Lengths + off_b) + if HAS_MULTIPLE_TARGETS: + num_targets = tl.load(NumTargets + off_b) + if INTERLEAVE_TARGETS: + high_ind = seq_len - num_targets * 2 + else: + high_ind = seq_len - num_targets + else: + high_ind = seq_len + pos_inds = tl.where(offs_n < high_ind, offs_n, high_ind) + pos_inds = high_ind - pos_inds + max_contextual_seq_len + pos_inds = tl.where(pos_inds < max_pos_ind - 1, pos_inds, max_pos_ind - 1) + pos_inds = tl.where(offs_n < max_contextual_seq_len, offs_n, pos_inds) + if TRAINING: + tl.store(PosInds + seq_start + offs_n, pos_inds, mask=mask_n) + pos_emb_offsets = pos_inds[:, None] * stride_pn + offs_d[None, :] + # timestamp encoding + ts = tl.load(TS + seq_start + offs_n, mask=mask_n) + query_time = tl.load(TS + seq_end - 1) + ts = query_time - ts + time_delta + ts = tl.where(ts > 1e-6, ts, 1e-6) / time_bucket_increments + if TIME_BUCKET_FN == "log": + ts = tl.log(ts) + else: + ts = tl.sqrt(ts) + ts = ts * time_bucket_scale + ts = ts.to(tl.int32) + ts = tl.where(ts > 0, ts, 0) + ts = tl.where(ts < num_time_buckets, ts, num_time_buckets) + if TRAINING: + tl.store(TsInds + seq_start + offs_n, ts, mask=mask_n) + 
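+    # At this point `pos_inds` and `ts` hold, per position, the row indices into
+    # PosEmb and TsEmb. Timestamps are bucketized roughly as
+    #   clamp(f((query_time - t + time_delta) / time_bucket_increments) * time_bucket_scale,
+    #         0, num_time_buckets)
+    # with f = log or sqrt depending on TIME_BUCKET_FN. The loop below gathers both
+    # embeddings and adds them to the sequence embeddings, D columns at a time in
+    # chunks of BLOCK_D.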
ts_emb_offsets = ts[:, None] * stride_tn + offs_d[None, :] + Out += seq_start.to(tl.int64) * stride_on + out_offsets = Out + offs_n[:, None] * stride_on + offs_d[None, :] + for _d in range(0, D, BLOCK_D): + mask = (offs_n[:, None] < seq_len) and offs_d[None, :] < D + seq_emb = tl.load(SeqEmb + seq_emb_offsets, mask=mask) + pos_emb = tl.load(PosEmb + pos_emb_offsets, mask=mask) + ts_emb = tl.load(TsEmb + ts_emb_offsets, mask=mask) + tl.store(out_offsets, seq_emb + (pos_emb + ts_emb).to(seq_emb.dtype), mask=mask) + seq_emb_offsets += BLOCK_D + pos_emb_offsets += BLOCK_D + ts_emb_offsets += BLOCK_D + out_offsets += BLOCK_D + offs_d += BLOCK_D + + +def bwd_pre_hook(nargs): + nargs["Out"].zero_() + + +def _add_embeddings_bwd_configs() -> List[triton.Config]: + configs = [] + for BLOCK in [32, 64, 128]: + for num_stages in [2, 3, 4]: + for num_warps in [2, 4, 8]: + configs.append( + triton.Config( + { + "BLOCK": BLOCK, + }, + num_stages=num_stages, + num_warps=num_warps, + pre_hook=bwd_pre_hook, + ) + ) + return configs + + +@triton_autotune( + configs=_add_embeddings_bwd_configs(), + key=["AUTOTUNE_MAX_SEQ_LEN", "AUTOTUNE_B", "D"], +) +@triton.jit +def _add_embeddings_bwd_kernel( + In, + KeyInds, + ValueInds, + Out, + AUTOTUNE_MAX_SEQ_LEN, + AUTOTUNE_B, + D, + jagged_size, + stride_in, + stride_on, + BLOCK_D: tl.constexpr, + BLOCK: tl.constexpr, +): + off_block = tl.program_id(0) + offs_d = tl.arange(0, BLOCK_D) + mask_d = offs_d < D + key_ind = -1 + key_ind = key_ind.to(KeyInds.dtype.element_ty) # pyre-ignore[16] + accumulator = tl.zeros((BLOCK_D,), dtype=In.dtype.element_ty) + for off_i in range(0, BLOCK): + off = off_block * BLOCK + off_i + if off < jagged_size: + value_ind = tl.load(ValueInds + off) + in_offset = In + value_ind.to(tl.int64) * stride_in + jagged_in = tl.load(in_offset + offs_d, mask=mask_d) + key_ind_new = tl.load(KeyInds + off) + if key_ind == key_ind_new: + accumulator += jagged_in + else: + if key_ind >= 0: + out_offset = Out + key_ind.to(tl.int64) * stride_on + tl.atomic_add( + out_offset + offs_d, + accumulator.to(Out.dtype.element_ty), + mask=mask_d, + sem="relaxed", + ) + key_ind = key_ind_new + accumulator = jagged_in + if key_ind >= 0: + out_offset = Out + key_ind.to(tl.int64) * stride_on + tl.atomic_add( + out_offset + offs_d, + accumulator.to(Out.dtype.element_ty), + mask=mask_d, + sem="relaxed", + ) + + +class _AddTimestampPositionEmbeddingsFunction(torch.autograd.Function): + @staticmethod + # pyre-ignore[14] + def forward( + ctx, + seq_embeddings: torch.Tensor, + seq_offsets: torch.Tensor, + pos_embeddings: torch.Tensor, + ts_embeddings: torch.Tensor, + timestamps: torch.Tensor, + max_seq_len: int, + max_contextual_seq_len: int, + seq_lengths: torch.Tensor, + num_targets: Optional[torch.Tensor], + interleave_targets: bool, + time_bucket_fn: str, + ): + seq_embeddings = switch_to_contiguous_if_needed(seq_embeddings) + pos_embeddings = switch_to_contiguous_if_needed(pos_embeddings) + ts_embeddings = switch_to_contiguous_if_needed(ts_embeddings) + + max_pos_ind = pos_embeddings.shape[0] + B = seq_lengths.shape[0] + N, D = seq_embeddings.shape + assert len(pos_embeddings.shape) == 2 + assert len(ts_embeddings.shape) == 2 + assert ( + pos_embeddings.shape[1] == D + ), "shape[1] of pos_embeddings much match seq_embeddings" + assert ( + ts_embeddings.shape[1] == D + ), "shape[1] of ts_embeddings much match seq_embeddings" + out = torch.empty_like(seq_embeddings) + + timestamps = switch_to_contiguous_if_needed(timestamps) + ts_inds = torch.empty_like(seq_embeddings[:, 
0], dtype=torch.int32) + pos_inds = torch.empty_like(seq_embeddings[:, 0], dtype=torch.int32) + ts_emb_size = ts_embeddings.shape[0] + + grid = lambda meta: ( # noqa E731 + B, + triton.cdiv(max_seq_len, meta["BLOCK_N"]), + ) + BLOCK_D = triton.next_power_of_2(D) if D < 64 else 64 + _add_timestamp_position_embeddings_kernel[grid]( + SeqEmb=seq_embeddings, + Offsets=seq_offsets, + Lengths=seq_lengths, + PosEmb=pos_embeddings, + TsEmb=ts_embeddings, + Out=out, + TS=timestamps, + PosInds=pos_inds, + TsInds=ts_inds, + NumTargets=num_targets, + AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(max_seq_len), + D=D, + num_time_buckets=ts_emb_size - 1, + time_bucket_increments=60.0, + time_bucket_scale=1.0, + time_delta=0, + max_contextual_seq_len=max_contextual_seq_len, + max_pos_ind=max_pos_ind, + stride_sn=seq_embeddings.stride(0), + stride_pn=pos_embeddings.stride(0), + stride_tn=ts_embeddings.stride(0), + stride_on=out.stride(0), + TRAINING=True, + HAS_MULTIPLE_TARGETS=num_targets is not None, + INTERLEAVE_TARGETS=interleave_targets, + TIME_BUCKET_FN=time_bucket_fn, + BLOCK_D=BLOCK_D, + ) + try: + values = torch.arange(0, N, dtype=torch.int32, device=timestamps.device) + sorted_ts_key_inds, sorted_ts_value_inds = torch.ops.hammer.sort_kv_pairs( + ts_inds, values + ) + sorted_pos_key_inds, sorted_pos_value_inds = torch.ops.hammer.sort_kv_pairs( + pos_inds, values + ) + except Exception: + sorted_ts_key_inds, sorted_ts_value_inds = torch.sort(ts_inds) + sorted_pos_key_inds, sorted_pos_value_inds = torch.sort(pos_inds) + ctx.save_for_backward( + sorted_pos_key_inds, + sorted_pos_value_inds, + sorted_ts_key_inds, + sorted_ts_value_inds, + ) + ctx.B = B + ctx.D = D + ctx.max_seq_len = max_seq_len + ctx.pos_emb_size = pos_embeddings.shape[0] + ctx.ts_emb_size = ts_emb_size + ctx.pos_dtype = pos_embeddings.dtype + ctx.ts_dtype = ts_embeddings.dtype + return out + + @staticmethod + # pyre-ignore[14] + def backward( + ctx, d_out: torch.Tensor + ) -> Tuple[ + torch.Tensor, + None, + torch.Tensor, + torch.Tensor, + None, + None, + None, + None, + None, + None, + None, + ]: + ( + sorted_pos_key_inds, + sorted_pos_value_inds, + sorted_ts_key_inds, + sorted_ts_value_inds, + ) = ctx.saved_tensors + d_pos_embeddings = torch.empty( + (ctx.pos_emb_size, ctx.D), device=d_out.device, dtype=torch.float32 + ) + d_ts_embeddings = torch.empty( + (ctx.ts_emb_size, ctx.D), device=d_out.device, dtype=torch.float32 + ) + grid = lambda meta: (triton.cdiv(d_out.shape[0], meta["BLOCK"]),) # noqa E731 + AUTOTUNE_B = prev_power_of_2(ctx.B) + _add_embeddings_bwd_kernel[grid]( + In=d_out, + KeyInds=sorted_pos_key_inds, + ValueInds=sorted_pos_value_inds, + Out=d_pos_embeddings, + AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(ctx.max_seq_len), + AUTOTUNE_B=AUTOTUNE_B, + D=ctx.D, + jagged_size=d_out.shape[0], + stride_in=d_out.stride(0), + stride_on=d_pos_embeddings.stride(0), + BLOCK_D=triton.next_power_of_2(ctx.D), + ) + _add_embeddings_bwd_kernel[grid]( + In=d_out, + KeyInds=sorted_ts_key_inds, + ValueInds=sorted_ts_value_inds, + Out=d_ts_embeddings, + AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(ctx.max_seq_len), + AUTOTUNE_B=AUTOTUNE_B, + D=ctx.D, + jagged_size=d_out.shape[0], + stride_in=d_out.stride(0), + stride_on=d_ts_embeddings.stride(0), + BLOCK_D=triton.next_power_of_2(ctx.D), + ) + return ( + d_out, + None, + d_pos_embeddings.to(ctx.pos_dtype), + d_ts_embeddings.to(ctx.ts_dtype), + None, + None, + None, + None, + None, + None, + None, + ) + + +@torch.fx.wrap +def triton_add_timestamp_positional_embeddings( + seq_embeddings: 
torch.Tensor, + seq_offsets: torch.Tensor, + pos_embeddings: torch.Tensor, + ts_embeddings: torch.Tensor, + timestamps: torch.Tensor, + max_seq_len: int, + max_contextual_seq_len: int, + seq_lengths: torch.Tensor, + num_targets: Optional[torch.Tensor], + interleave_targets: bool, + time_bucket_fn: str, +) -> torch.Tensor: + return _AddTimestampPositionEmbeddingsFunction.apply( + seq_embeddings, + seq_offsets, + pos_embeddings, + ts_embeddings, + timestamps, + max_seq_len, + max_contextual_seq_len, + seq_lengths, + num_targets, + interleave_targets, + time_bucket_fn, + ) diff --git a/recommendation/dlrm_v3/generative_recommenders/ops/utils.py b/recommendation/dlrm_v3/generative_recommenders/ops/utils.py new file mode 100644 index 0000000000..d1fa08ae41 --- /dev/null +++ b/recommendation/dlrm_v3/generative_recommenders/ops/utils.py @@ -0,0 +1,33 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env python3 + +# pyre-strict + +import torch + + +def is_sm100() -> bool: + if not torch.cuda.is_available(): + return False + props = torch.cuda.get_device_properties(0) + return props.major == 10 and props.minor == 0 + + +def is_sm90() -> bool: + if not torch.cuda.is_available(): + return False + props = torch.cuda.get_device_properties(0) + return props.major == 9 and props.minor == 0 diff --git a/recommendation/dlrm_v3/inference_modules.py b/recommendation/dlrm_v3/inference_modules.py new file mode 100644 index 0000000000..6bc78694f2 --- /dev/null +++ b/recommendation/dlrm_v3/inference_modules.py @@ -0,0 +1,204 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-unsafe +""" +Inference modules for DLRMv3. + +This module provides inference-specific components for the HSTU model, +including sparse inference modules and utilities for moving tensors between devices. +""" +from typing import Dict, Optional, Tuple + +import torch +from generative_recommenders.modules.dlrm_hstu import ( + DlrmHSTU, + DlrmHSTUConfig, + SequenceEmbedding, +) +from torchrec.modules.embedding_modules import ( + EmbeddingBagCollection, + EmbeddingCollection, +) +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +IS_INFERENCE: bool = True + + +def set_is_inference(is_inference: bool = False) -> None: + """ + Set the global inference mode flag. + + Args: + is_inference: If True, model operates in inference mode (no labels/weights). + If False, model operates in training/eval mode with labels. 
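+
+    Example (illustrative):
+        >>> set_is_inference(True)
+        >>> # models created afterwards via get_hstu_model() are built with
+        >>> # is_inference=True and skip label/weight handling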
+ """ + global IS_INFERENCE + IS_INFERENCE = is_inference + + +def get_hstu_model( + table_config, + hstu_config: DlrmHSTUConfig, + table_device: str = "meta", + max_hash_size: Optional[int] = None, + is_dense: bool = False, +) -> DlrmHSTU: + """ + Create and initialize an HSTU model for inference. + + Args: + table_config: Dictionary of embedding table configurations. + hstu_config: HSTU model configuration object. + table_device: Device to place embedding tables on ('meta', 'cpu', or 'cuda'). + max_hash_size: Optional maximum hash size to cap embedding table sizes. + is_dense: If True, creates model for dense-only operations. + + Returns: + Initialized DlrmHSTU model in eval mode. + """ + if max_hash_size is not None: + for t in table_config.values(): + t.num_embeddings = ( + max_hash_size if t.num_embeddings > max_hash_size else t.num_embeddings + ) + model = DlrmHSTU( + hstu_configs=hstu_config, + embedding_tables=table_config, + is_inference=IS_INFERENCE, + is_dense=is_dense, + ) + model.eval() + model.recursive_setattr("_use_triton_cc", False) + for _, module in model.named_modules(): + if isinstance(module, EmbeddingBagCollection) or isinstance( + module, EmbeddingCollection + ): + module.to_empty(device=table_device) + return model + + +class HSTUSparseInferenceModule(torch.nn.Module): + """ + Module for sparse (embedding) inference operations. + + Handles embedding lookups and preprocessing for the HSTU model, + running on CPU to handle large embedding tables. + + Args: + table_config: Dictionary of embedding table configurations. + hstu_config: HSTU model configuration object. + """ + + def __init__( + self, + table_config, + hstu_config: DlrmHSTUConfig, + ) -> None: + super().__init__() + self._hstu_model: DlrmHSTU = get_hstu_model( + table_config, + hstu_config, + table_device="cpu", + ) + + def forward( + self, + uih_features: KeyedJaggedTensor, + candidates_features: KeyedJaggedTensor, + ) -> Tuple[ + Dict[str, SequenceEmbedding], + Dict[str, torch.Tensor], + int, + torch.Tensor, + int, + torch.Tensor, + ]: + """ + Run sparse preprocessing and embedding lookups. + + Args: + uih_features: User interaction history features as KeyedJaggedTensor. + candidates_features: Candidate item features as KeyedJaggedTensor. + + Returns: + Tuple containing: + - seq_embeddings: Dictionary of sequence embeddings per feature. + - payload_features: Dictionary of payload feature tensors. + - max_uih_len: Maximum user interaction history length. + - uih_seq_lengths: Tensor of UIH sequence lengths per batch item. + - max_num_candidates: Maximum number of candidates. + - num_candidates: Tensor of candidate counts per batch item. + """ + ( + seq_embeddings, + payload_features, + max_uih_len, + uih_seq_lengths, + max_num_candidates, + num_candidates, + ) = self._hstu_model.preprocess( + uih_features=uih_features, + candidates_features=candidates_features, + ) + return ( + seq_embeddings, + payload_features, + max_uih_len, + uih_seq_lengths, + max_num_candidates, + num_candidates, + ) + + +def move_sparse_output_to_device( + seq_embeddings: Dict[str, SequenceEmbedding], + payload_features: Dict[str, torch.Tensor], + uih_seq_lengths: torch.Tensor, + num_candidates: torch.Tensor, + device: torch.device, +) -> Tuple[ + Dict[str, SequenceEmbedding], + Dict[str, torch.Tensor], + torch.Tensor, + torch.Tensor, +]: + """ + Move sparse module outputs from CPU to the target device (typically GPU). + + Converts embeddings to bfloat16 for efficient GPU computation. 
+ + Args: + seq_embeddings: Dictionary of sequence embeddings to move. + payload_features: Dictionary of payload features to move. + uih_seq_lengths: UIH sequence lengths tensor to move. + num_candidates: Number of candidates tensor to move. + device: Target device (e.g., torch.device('cuda:0')). + + Returns: + Tuple of moved tensors on the target device. + """ + num_candidates = num_candidates.to(device) + uih_seq_lengths = uih_seq_lengths.to(device) + seq_embeddings = { + k: SequenceEmbedding( + lengths=seq_embeddings[k].lengths.to(device), + embedding=seq_embeddings[k].embedding.to(device).to(torch.bfloat16), + ) + for k in seq_embeddings.keys() + } + for k, v in payload_features.items(): + payload_features[k] = v.to(device) + return seq_embeddings, payload_features, uih_seq_lengths, num_candidates diff --git a/recommendation/dlrm_v3/main.py b/recommendation/dlrm_v3/main.py new file mode 100644 index 0000000000..b5dbbe3169 --- /dev/null +++ b/recommendation/dlrm_v3/main.py @@ -0,0 +1,851 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-strict +""" +mlperf dlrm_v3 inference benchmarking tool. +""" + +import argparse +import array +import logging +import random +import threading + +logging.basicConfig(level=logging.INFO) +import os +import sys +import time +from typing import Any, Dict, List, Optional, Union + +# pyre-ignore [21] +import mlperf_loadgen as lg # @manual +import numpy as np +import torch +from generative_recommenders.common import set_dev_mode, set_verbose_level +from configs import get_embedding_table_config, get_hstu_configs +from datasets.dataset import Dataset, Samples +from datasets.synthetic_streaming import ( + DLRMv3SyntheticStreamingDataset, +) +from data_producer import ( + MultiThreadDataProducer, + QueryItem, + SingleThreadDataProducer, +) +from inference_modules import set_is_inference +from model_family import HSTUModelFamily +from utils import ( + get_dataset, + profiler_or_nullcontext, + SUPPORTED_DATASETS, +) + +logger: logging.Logger = logging.getLogger("main") + +torch.multiprocessing.set_start_method("spawn", force=True) + +USER_CONF = f"{os.path.dirname(__file__)}/user.conf" + + +SCENARIO_MAP = { # pyre-ignore [5] + "Server": lg.TestScenario.Server, + "Offline": lg.TestScenario.Offline, +} + + +def get_args(): # pyre-ignore [3] + """ + Parse command-line arguments for the MLPerf DLRMv3 inference benchmark. + + Returns: + argparse.Namespace: Parsed arguments including dataset selection, model path, + scenario configuration, batch size, and various benchmark parameters. + """ + parser = argparse.ArgumentParser() + parser.add_argument( + "--dataset", type=str, default="sampled-streaming-100b", choices=SUPPORTED_DATASETS, help="name of the dataset" + ) + parser.add_argument( + "--model-path", type=str, default="", help="path to the model checkpoint. 
Example: /home/username/data/dlrmv3_trained_checkpoint/" + ) + parser.add_argument( + "--scenario-name", type=str, default="Server", choices={"Server", "Offline"}, help="inference benchmark scenario" + ) + parser.add_argument( + "--batchsize", type=int, default=10, help="batch size used in the benchmark" + ) + parser.add_argument( + "--output-trace", type=bool, default=False, help="Whether to output trace" + ) + parser.add_argument( + "--data-producer-threads", type=int, default=8, help="Number of threads used in data producer" + ) + parser.add_argument( + "--compute-eval", type=bool, default=False, help="If true, will run AccuracyOnly mode and outputs both predictions and labels for accuracy calcuations" + ) + parser.add_argument( + "--find-peak-performance", type=bool, default=False, help="Whether to find peak performance in the benchmark" + ) + parser.add_argument( + "--dataset-path-prefix", type=str, default=f"/home/{os.getlogin()}/dlrmv3_dataset/", help="Prefix to the dataset path. Example: /home/username/" + ) + parser.add_argument( + "--warmup-ratio", type=float, default=0.1, help="The ratio of the dataset used to warmup SUT" + ) + parser.add_argument( + "--num-queries", type=int, default=None, help="Number of queries to run in the benchmark" + ) + parser.add_argument( + "--target-qps", type=int, default=1000, help="Benchmark target QPS. Needs to be tuned for different implementations to balance latency and throughput" + ) + parser.add_argument( + "--numpy-rand-seed", type=int, default=123, help="Numpy random seed" + ) + parser.add_argument( + "--sparse-quant", type=bool, default=False, help="Whether to quantize sparse arch" + ) + parser.add_argument( + "--dataset-percentage", type=float, default=0.0001, help="Percentage of the dataset to run in the benchmark" + ) + args, unknown_args = parser.parse_known_args() + logger.warning(f"unknown_args: {unknown_args}") + return args + + +class Runner: + """ + Orchestrates inference benchmark execution. + + Manages data production, model inference, and result collection for + MLPerf LoadGen-based benchmarking. + + Args: + model: The HSTU model family instance for making predictions. + ds: Dataset to fetch samples from. + num_queries: Total number of queries to process. + data_producer_threads: Number of threads for data loading (default: 1). + batchsize: Batch size for inference (default: 128). + compute_eval: Whether to compute evaluation metrics (default: False). + """ + + def __init__( + self, + model: HSTUModelFamily, + ds: Dataset, + num_queries: int, + data_producer_threads: int = 1, + batchsize: int = 128, + compute_eval: bool = False, + ) -> None: + self.model = model + if data_producer_threads == 1: + self.data_producer: Union[ + MultiThreadDataProducer, SingleThreadDataProducer + ] = SingleThreadDataProducer(ds, self.run_one_item) + else: + self.data_producer = MultiThreadDataProducer( + ds, data_producer_threads, self.run_one_item + ) + self.batchsize = batchsize + self.compute_eval = compute_eval + self.reset_states(num_queries=num_queries) + + def reset_states(self, num_queries: int) -> None: + """ + Reset all internal state for a new benchmark run. + + Args: + num_queries: Number of queries expected in this run. 
+ """ + self.result_timing: List[Dict[str, float]] = [] + self.result_batches: List[int] = [] + self.current_query_ids: List[int] = [] + self.current_content_ids: List[int] = [] + self.current_t0: List[float] = [] + self.num_queries: int = num_queries + self.processed_queries: int = 0 + + def run_one_item(self, qitem: QueryItem) -> None: + """ + Process a single query item through model inference. + + Runs prediction, records timing metrics, and sends results back to LoadGen. + + Args: + qitem: Query item containing batch of samples to process. + """ + try: + t0_prediction: float = time.time() + prediction_output = self.model.predict(qitem.samples) + dt_prediction: float = time.time() - t0_prediction + assert prediction_output is not None + ( + mt_target_preds, + mt_target_labels, + mt_target_weights, + dt_sparse, + dt_dense, + ) = prediction_output + if self.compute_eval: + assert mt_target_labels is not None + assert mt_target_weights is not None + self.result_timing.append( + { + "total": time.time() - qitem.start, + "prediction": dt_prediction, + "queue": qitem.dt_queue, + "batching": qitem.dt_batching, + "sparse": dt_sparse, + "dense": dt_dense, + } + ) + self.result_batches.append(len(qitem.query_ids)) + except Exception as ex: # pylint: disable=broad-except + logger.error("thread: failed, %s", ex) + finally: + candidate_size = mt_target_preds.size(1) // len(qitem.query_ids) + if not self.compute_eval: + for i, query_id in enumerate(qitem.query_ids): + query_mt_target_preds = ( + mt_target_preds[ # pyre-ignore [61] + 0, + candidate_size * i : candidate_size * (i + 1), + ] + .view(-1) + .float() + .numpy() + ) + response_array = array.array("B", query_mt_target_preds.tobytes()) + bi = response_array.buffer_info() + # since we send buffer to loadgen, needs `response_array` in memory during send + lg.QuerySamplesComplete( + [lg.QuerySampleResponse(query_id, bi[0], bi[1])] + ) + else: + for i, query_id in enumerate(qitem.query_ids): + query_mt_target_preds = ( + mt_target_preds[ # pyre-ignore [61] + 0, candidate_size * i : candidate_size * (i + 1) + ] + .view(-1) + .float() + .numpy() + ) + query_mt_target_labels = ( + mt_target_labels[ # pyre-ignore [16,61] + 0, candidate_size * i : candidate_size * (i + 1) + ] + .view(-1) + .float() + .numpy() + ) + query_mt_target_weights = ( + mt_target_weights[ # pyre-ignore [61] + 0, candidate_size * i : candidate_size * (i + 1) + ] + .view(-1) + .float() + .numpy() + ) + np_array = np.concatenate( + [ + query_mt_target_preds, + query_mt_target_labels, + query_mt_target_weights, + np.array([candidate_size]).astype(np.float32), + ] + ) + response_array = array.array("B", np_array.tobytes()) + bi = response_array.buffer_info() + # since we send buffer to loadgen, needs `response_array` in memory during send + lg.QuerySamplesComplete( + [lg.QuerySampleResponse(query_id, bi[0], bi[1])] + ) + + def enqueue(self, query_samples, t0: float) -> None: # pyre-ignore [2] + """ + Enqueue query samples for batch processing. + + Collects samples until batch size is reached, then dispatches to data producer. + + Args: + query_samples: List of LoadGen query sample objects. + t0: Timestamp when this batch started. 
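+
+        Note (behavior of the implementation below):
+            Incoming samples are buffered until either `batchsize` queries have
+            accumulated or the final expected query arrives; the buffer is then
+            split into full batches plus at most one smaller remainder batch
+            before being handed to the data producer.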
+ """ + self.current_query_ids.extend([q.id for q in query_samples]) + self.current_content_ids.extend([q.index for q in query_samples]) + self.current_t0.append(t0) + self.processed_queries += len(query_samples) + t0: float = min(self.current_t0) + dt_queue: float = max(self.current_t0) - min(self.current_t0) + if ( + self.processed_queries >= self.num_queries + or len(self.current_query_ids) >= self.batchsize + ): + for i in range(len(self.current_query_ids) // self.batchsize): + self.data_producer.enqueue( + query_ids=self.current_query_ids[ + i * self.batchsize : (i + 1) * self.batchsize + ], + content_ids=self.current_content_ids[ + i * self.batchsize : (i + 1) * self.batchsize + ], + t0=t0, + dt_queue=dt_queue, + ) + remaining_s: int = len(self.current_query_ids) % self.batchsize + if remaining_s > 0: + self.data_producer.enqueue( + query_ids=self.current_query_ids[-remaining_s:], + content_ids=self.current_content_ids[-remaining_s:], + t0=t0, + dt_queue=dt_queue, + ) + self.current_query_ids = [] + self.current_content_ids = [] + self.current_t0 = [] + + def finish(self) -> None: + """Signal data producer to finish and wait for completion.""" + self.data_producer.finish() + + +def add_results( + final_results: Dict[str, Any], + result_timing: List[Dict[str, float]], + result_batches: List[int], +) -> None: + """ + Aggregate and log benchmark results. + + Computes percentile statistics and QPS metrics from timing data. + + Args: + final_results: Dictionary to populate with aggregated results. + result_timing: List of timing dictionaries for each batch. + result_batches: List of batch sizes processed. + """ + percentiles: list[float] = [50.0, 80.0, 90.0, 95.0, 99.0, 99.9] + buckets_dict: Dict[str, List[float]] = {} + buckets_str_dict: Dict[str, str] = {} + total_timing: list[float] = [result["total"] for result in result_timing] + for key in ["total", "prediction", "queue", "batching", "sparse", "dense"]: + timing: list[float] = [result[key] for result in result_timing] + buckets: List[float] = np.percentile(timing, percentiles).tolist() + buckets_str: str = ",".join( + ["| {}:{:.4f}| ".format(p, b) for p, b in zip(percentiles, buckets)] + ) + buckets_dict[key] = buckets + buckets_str_dict[key] = buckets_str + total_batches = sum(result_batches) + + final_results["good"] = len(total_timing) + final_results["avg_time"] = np.mean(total_timing) + final_results["percentiles"] = { + str(k): v for k, v in zip(percentiles, buckets_dict["total"]) + } + final_results["qps"] = total_batches / final_results["took"] + final_results["count"] = total_batches + + for i, timing in enumerate(result_timing): + logger.warning(f"timing of {i}: {timing}") + + logger.warning( + "{} qps={:.2f}, avg_query_time={:.4f}, time={:.3f}, queries={}, tiles={}".format( + final_results["scenario"], + final_results["qps"], + final_results["avg_time"], + final_results["took"], + len(result_timing), + buckets_str_dict["total"], + ) + ) + for key in ["prediction", "queue", "batching", "sparse", "dense"]: + logger.warning(f"{key}: {buckets_str_dict[key]}") + + +def get_num_queries( + input_size: Optional[int], + one_pass_size: int, + scenario_name: str, + offline_target_qps: int, + target_duration: float, +) -> int: + """ + Determine the number of queries to run based on scenario and settings. + + Args: + input_size: User-specified query count (None to use defaults). + one_pass_size: Size of one complete pass through the dataset. + scenario_name: MLPerf scenario name ('Server' or 'Offline'). 
+ offline_target_qps: Target QPS for offline scenario. + target_duration: Target duration in milliseconds. + + Returns: + Number of queries to execute in the benchmark run. + """ + if scenario_name == "Offline": + # consistent with https://github.com/mlcommons/inference/blob/8999c4d686f6e4a180da14597c97063fce7c9f33/loadgen/test_settings_internal.cc#L147 + return int(1.1 * target_duration / 1000 * offline_target_qps) + else: + if input_size is None: + return one_pass_size + return input_size + + +class StreamingQuerySampler: + """ + Sampler for streaming dataset + The execution order is determined by `StreamingQuerySampler.run_order`, not by the QSL or input query ID. + This ensures that queries are executed according to their timestamp constraints. + """ + + def __init__( + self, + ds: DLRMv3SyntheticStreamingDataset, + dataset_percentage: float, + scenario_name: str, + offline_target_qps: int, + target_duration: float, + input_queries: Optional[int] = None, + compute_eval: bool = False, + ) -> None: + self.ds: DLRMv3SyntheticStreamingDataset = ds + self.ds.is_inference = True + self.inference_ts: int = self.ds.total_ts - self.ds.train_ts + self.start_ts: int = self.ds.train_ts + self.dataset_percentage: float = dataset_percentage + self.num_unique_requests: List[int] = self.get_num_unique_requests( + warmup_ratio=1.0 + ) + self.num_unique_requests_cumsum: List[int] = np.cumsum( + self.num_unique_requests + ).tolist() + self.total_requests: int = sum(self.num_unique_requests) + self.run_order: List[List[int]] = self.build_random_exec_order() + self.ts_idx: int = 0 + self.ts_processed_cnt: int = 0 + self.last_loaded: float = -1.0 + num_queries: int = get_num_queries( + input_size=input_queries, + one_pass_size=self.total_requests, + scenario_name=scenario_name, + offline_target_qps=offline_target_qps, + target_duration=target_duration, + ) + logger.warning( + f"StreamingQuerySampler constructred to handle {num_queries} queries" + ) + self.num_repeats: int = ( + max(1, num_queries // self.total_requests) if not compute_eval else 1 + ) + self.remaining_queries: int = ( + num_queries % self.total_requests if not compute_eval else 0 + ) + self._lock = threading.Lock() + + def get_num_unique_requests(self, warmup_ratio: float) -> List[int]: + """ + Calculate number of unique requests per timestamp. + + Args: + warmup_ratio: Fraction of users to include in warmup. + + Returns: + List of request counts per timestamp. + """ + num_unique_requests = [ + int( + self.ds.ts_to_users_cumsum[t][-1] + * self.dataset_percentage + * warmup_ratio + ) + for t in range(self.start_ts, self.start_ts + self.inference_ts) + ] + return num_unique_requests + + def build_random_exec_order(self) -> List[List[int]]: + """ + Build randomized execution order for each timestamp. + + Returns: + List of shuffled index lists, one per timestamp. + """ + order = [] + for req_size in self.num_unique_requests: + within_ts_order = list(range(req_size)) + random.shuffle(within_ts_order) + order.append(within_ts_order) + return order + + def init_sut(self) -> None: + """Initialize System Under Test state for a new benchmark run.""" + self.ts_idx = 0 + self.ts_processed_cnt = 0 + self.ds.set_ts(self.start_ts) + + def load_query_samples(self, query_ids: List[Optional[int]]) -> None: + """ + Load query samples into memory for the benchmark. + + Args: + query_ids: List of query identifiers to load. 
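+
+        Note (how ids map to timestamps in the code below):
+            The list is interpreted positionally against the per-timestamp request
+            counts: every timestamp whose cumulative request count fits within
+            len(query_ids) is loaded in full, and the first timestamp that does
+            not fit is loaded only up to the remaining count.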
+ """ + length = len(query_ids) + ts_idx: int = 0 + while self.num_unique_requests_cumsum[ts_idx] < length: + ts_idx += 1 + for i in range(0, ts_idx): + self.ds.set_ts(i + self.start_ts) + self.ds.load_query_samples(self.run_order[i]) + self.ds.set_ts(ts_idx + self.start_ts) + delta_length = ( + length + if ts_idx == 0 + else length - self.num_unique_requests_cumsum[ts_idx - 1] + ) + self.ds.load_query_samples(self.run_order[ts_idx][:delta_length]) + self.init_sut() + self.last_loaded = time.time() + + def unload_query_samples(self, sample_list: List[int]) -> None: + """ + Unload query samples from memory. + + Args: + sample_list: List of sample identifiers to unload. + """ + self.ds.unload_query_samples(sample_list) + + def get_samples(self, id_list: List[int]) -> List[Samples]: + """ + Get samples for a batch of queries, handling timestamp boundaries. + + Args: + id_list: List of query identifiers. + + Returns: + List of Samples objects, potentially spanning multiple timestamps. + """ + batch_size: int = len(id_list) + with self._lock: + curr_ts_idx: int = self.ts_idx + curr_ts_unique_requests: int = self.num_unique_requests[curr_ts_idx] + curr_ts_queries: int = curr_ts_unique_requests * self.num_repeats + if curr_ts_idx == self.inference_ts - 1: + curr_ts_queries += self.remaining_queries + begin_query_idx: int = self.ts_processed_cnt + end_query_idx: int = min(begin_query_idx + batch_size, curr_ts_queries) + begin_request_idx: int = begin_query_idx % curr_ts_unique_requests + end_request_idx: int = end_query_idx % curr_ts_unique_requests + if begin_query_idx + batch_size >= curr_ts_queries: + self.ts_idx += 1 + self.ts_processed_cnt = begin_query_idx + batch_size - curr_ts_queries + else: + self.ts_processed_cnt = begin_query_idx + batch_size + # requests of current ts + outputs: List[Samples] = [] + if end_request_idx > begin_request_idx: + output: Samples = self.ds.get_samples_with_ts( + self.run_order[curr_ts_idx][begin_request_idx:end_request_idx], + curr_ts_idx + self.start_ts, + ) + outputs.append(output) + else: + if begin_request_idx < curr_ts_unique_requests: + output: Samples = self.ds.get_samples_with_ts( + self.run_order[curr_ts_idx][begin_request_idx:], + curr_ts_idx + self.start_ts, + ) + outputs.append(output) + if end_request_idx > 0: + output = self.ds.get_samples_with_ts( + self.run_order[curr_ts_idx][0:end_request_idx], + curr_ts_idx + self.start_ts, + ) + outputs.append(output) + # requests of next ts + if begin_query_idx + batch_size > curr_ts_queries: + output: Samples = self.ds.get_samples_with_ts( + self.run_order[curr_ts_idx + 1][ + : begin_query_idx + batch_size - curr_ts_queries + ], + curr_ts_idx + 1 + self.start_ts, + ) + outputs.append(output) + return outputs + + def get_item_count(self) -> int: + """ + Get total number of items in the dataset. + + Returns: + Total request count across all timestamps. + """ + return self.total_requests + + +def run( + dataset: str = "sampled-streaming-100b", + model_path: str = "", + scenario_name: str = "Server", + batchsize: int = 16, + output_trace: bool = False, + data_producer_threads: int = 4, + compute_eval: bool = False, + find_peak_performance: bool = False, + dataset_path_prefix: str = "", + warmup_ratio: float = 0.1, + target_qps: Optional[int] = None, + num_queries: Optional[int] = None, + numpy_rand_seed: int = 123, + sparse_quant: bool = False, + dataset_percentage: float = 1.0, +) -> None: + """ + Execute the MLPerf DLRMv3 inference benchmark. 
+ + Sets up the model, dataset, and LoadGen infrastructure, then runs + warmup and official benchmark phases. + + Args: + dataset: Dataset identifier to use. + model_path: Path to model checkpoint directory. + scenario_name: MLPerf scenario ('Server' or 'Offline'). + batchsize: Batch size for inference. + output_trace: Whether to output profiling traces. + data_producer_threads: Number of data loading threads. + compute_eval: Whether to compute accuracy metrics. + find_peak_performance: Whether to run peak performance finding mode. + dataset_path_prefix: Prefix path for dataset files. + warmup_ratio: Fraction of data to use for warmup. + target_qps: Target queries per second. + num_queries: Number of queries to run (None for automatic). + numpy_rand_seed: Random seed for reproducibility. + sparse_quant: Whether to quantize sparse embeddings. + dataset_percentage: Fraction of dataset to use. + """ + set_dev_mode(False) + if scenario_name not in SCENARIO_MAP: + raise NotImplementedError("valid scanarios:" + str(list(SCENARIO_MAP.keys()))) + scenario = SCENARIO_MAP[scenario_name] + np.random.seed(numpy_rand_seed) + random.seed(numpy_rand_seed) + + hstu_config = get_hstu_configs(dataset) + hstu_config.max_num_candidates = hstu_config.max_num_candidates_inference + table_config = get_embedding_table_config(dataset) + set_is_inference(is_inference=not compute_eval) + + user_conf = os.path.abspath(USER_CONF) + if not os.path.exists(user_conf): + logger.error("{} not found".format(user_conf)) + sys.exit(1) + + settings = lg.TestSettings() + settings.FromConfig(user_conf, model_path, scenario_name) + settings.scenario = scenario + settings.mode = lg.TestMode.PerformanceOnly + if compute_eval: + settings.mode = lg.TestMode.AccuracyOnly + if find_peak_performance: + settings.mode = lg.TestMode.FindPeakPerformance + if target_qps: + settings.server_target_qps = float(target_qps) + settings.offline_expected_qps = float(target_qps) + + model_family = HSTUModelFamily( + hstu_config=hstu_config, + table_config=table_config, + sparse_quant=sparse_quant, + output_trace=output_trace, + compute_eval=compute_eval, + ) + is_streaming: bool = "streaming" in dataset + dataset, kwargs = get_dataset(dataset, dataset_path_prefix) + + ds: Dataset = dataset( + hstu_config=hstu_config, + embedding_config=table_config, + is_inference=not compute_eval, + **kwargs, + ) + if is_streaming: + ds = StreamingQuerySampler( # pyre-ignore + ds=ds, # pyre-ignore [6] + dataset_percentage=dataset_percentage, + input_queries=num_queries, + compute_eval=compute_eval, + scenario_name=scenario_name, + offline_target_qps=settings.offline_expected_qps, + target_duration=settings.min_duration_ms, + ) + model_family.load(model_path) + + # warmup + for autotune_bs in range(batchsize, 0, -1): + logger.warning(f"Autotune for batch size {autotune_bs}") + warmup_ids = list(range(autotune_bs)) + ds.load_query_samples(warmup_ids) + for _ in range(4 * int(os.environ.get("WORLD_SIZE", 1))): + if is_streaming: + ds.init_sut() # pyre-ignore [16] + sample: Union[Samples, List[Samples]] = ds.get_samples(warmup_ids) + if isinstance(sample, Samples): + model_family.predict(sample) + else: + for s in sample: + model_family.predict(s) + ds.unload_query_samples(None) + for h in logger.handlers: + h.flush() + logger.info("Model forward warmup done") + + count = int( + ds.get_item_count() * dataset_percentage + if not is_streaming + else ds.get_item_count() + ) + train_size: int = 0 + if compute_eval: + count = count - train_size + + runner: Runner = Runner( + 
model_family, + ds, + data_producer_threads=data_producer_threads, + batchsize=batchsize, + compute_eval=compute_eval, + num_queries=count, + ) + + def issue_queries(query_samples) -> None: # pyre-ignore [2] + if compute_eval: + for sample in query_samples: + sample.index = sample.index + train_size + runner.enqueue(query_samples, time.time()) + + def load_query_samples(query_ids: List[int]) -> None: + if compute_eval: + query_ids = [q + train_size for q in query_ids] + ds.load_query_samples(query_ids) + + def flush_queries() -> None: + pass + + if scenario == lg.TestScenario.Server: + # inference benchmark warmup + if is_streaming: + ds.init_sut() + warmup_count: int = sum( + ds.get_num_unique_requests( # pyre-ignore [16] + warmup_ratio=warmup_ratio + ) + ) + else: + warmup_count: int = int(count * warmup_ratio) + runner.reset_states(num_queries=warmup_count) + final_results = { + "runtime": model_family.name(), + "version": model_family.version(), + "time": int(time.time()), + "scenario": str(scenario), + } + settings.min_query_count = warmup_count + settings.max_query_count = warmup_count + sut = lg.ConstructSUT(issue_queries, flush_queries) + qsl = lg.ConstructQSL( + warmup_count, + warmup_count, + load_query_samples, + ds.unload_query_samples, + ) + with profiler_or_nullcontext(enabled=output_trace, with_stack=False): + logger.info(f"starting warmup {scenario} with {warmup_count} queries") + lg.StartTest(sut, qsl, settings) + lg.DestroyQSL(qsl) + lg.DestroySUT(sut) + + # official run + if is_streaming: + ds.init_sut() + final_results = { + "runtime": model_family.name(), + "version": model_family.version(), + "time": int(time.time()), + "scenario": str(scenario), + } + query_size: int = get_num_queries( + input_size=num_queries, + one_pass_size=count, + scenario_name=scenario_name, + offline_target_qps=settings.offline_expected_qps, + target_duration=settings.min_duration_ms, + ) + settings.min_query_count = query_size + settings.max_query_count = query_size + runner.reset_states(num_queries=query_size if not compute_eval else count) + sut = lg.ConstructSUT(issue_queries, flush_queries) + qsl = lg.ConstructQSL( + count, + count, + load_query_samples, + ds.unload_query_samples, + ) + with profiler_or_nullcontext(enabled=output_trace, with_stack=False): + logger.info( + f"starting {scenario} with {query_size} queries and {query_size // count} repeats" + ) + lg.StartTest(sut, qsl, settings) + runner.finish() + final_results["took"] = time.time() - ds.last_loaded + lg.DestroyQSL(qsl) + lg.DestroySUT(sut) + + add_results( + final_results, + runner.result_timing, + runner.result_batches, + ) + # If multiple subprocesses are running the model send a signal to stop them + if int(os.environ.get("WORLD_SIZE", 1)) > 1: + model_family.predict(None) + + +def main() -> None: + set_verbose_level(1) + args = get_args() + logger.info(args) + run( + dataset=args.dataset, + model_path=args.model_path, + scenario_name=args.scenario_name, + batchsize=args.batchsize, + output_trace=args.output_trace, + data_producer_threads=args.data_producer_threads, + compute_eval=args.compute_eval, + find_peak_performance=args.find_peak_performance, + dataset_path_prefix=args.dataset_path_prefix, + warmup_ratio=args.warmup_ratio, + target_qps=args.target_qps, + num_queries=args.num_queries, + numpy_rand_seed=args.numpy_rand_seed, + sparse_quant=args.sparse_quant, + dataset_percentage=args.dataset_percentage, + ) + + +if __name__ == "__main__": + main() diff --git a/recommendation/dlrm_v3/model_family.py 
b/recommendation/dlrm_v3/model_family.py new file mode 100644 index 0000000000..e40b8d4c02 --- /dev/null +++ b/recommendation/dlrm_v3/model_family.py @@ -0,0 +1,705 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-strict +""" +model_family for dlrm_v3. +""" + +import copy +import functools +import logging +import os +import time +import uuid +from threading import Event +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.multiprocessing as mp +import torchrec +from checkpoint import ( + load_nonsparse_checkpoint, + load_sparse_checkpoint, +) +from configs import HASH_SIZE +from datasets.dataset import Samples +from inference_modules import ( + get_hstu_model, + HSTUSparseInferenceModule, + move_sparse_output_to_device, + set_is_inference, +) +from utils import Profiler +from generative_recommenders.modules.dlrm_hstu import DlrmHSTUConfig, SequenceEmbedding +from pyre_extensions import none_throws +from torch import quantization as quant +from torchrec.distributed.quant_embedding import QuantEmbeddingCollection +from torchrec.modules.embedding_configs import EmbeddingConfig, QuantConfig +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor +from torchrec.sparse.tensor_dict import maybe_td_to_kjt +from torchrec.test_utils import get_free_port + +logger: logging.Logger = logging.getLogger(__name__) + + +class HSTUModelFamily: + """ + High-level interface for the HSTU model family. + + Manages both sparse (embedding) and dense (transformer) components of the + HSTU model, supporting distributed inference across multiple GPUs. + + Args: + hstu_config: Configuration object for the HSTU model. + table_config: Dictionary of embedding table configurations. + output_trace: Whether to enable profiling trace output. + sparse_quant: Whether to quantize sparse embeddings. + compute_eval: Whether to compute evaluation metrics (includes labels). + """ + + def __init__( + self, + hstu_config: DlrmHSTUConfig, + table_config: Dict[str, EmbeddingConfig], + output_trace: bool = False, + sparse_quant: bool = False, + compute_eval: bool = False, + ) -> None: + self.hstu_config = hstu_config + self.table_config = table_config + self.sparse: ModelFamilySparseDist = ModelFamilySparseDist( + hstu_config=hstu_config, + table_config=table_config, + quant=sparse_quant, + ) + + assert torch.cuda.is_available(), "CUDA is required for this benchmark." 
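+        # WORLD_SIZE, when set in the environment (e.g. `WORLD_SIZE=8 python main.py ...`),
+        # overrides the GPU count detected below and controls how many dense workers
+        # are used; with a single worker the simpler ModelFamilyDenseSingleWorker
+        # path is taken instead of ModelFamilyDenseDist.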
+ ngpus = torch.cuda.device_count() + self.world_size = int(os.environ.get("WORLD_SIZE", str(ngpus))) + logger.warning(f"Using {self.world_size} GPU(s)...") + dense_model_family_clazz = ( + ModelFamilyDenseDist + if self.world_size > 1 + else ModelFamilyDenseSingleWorker + ) + self.dense: Union[ModelFamilyDenseDist, ModelFamilyDenseSingleWorker] = ( + dense_model_family_clazz( + hstu_config=hstu_config, + table_config=table_config, + output_trace=output_trace, + compute_eval=compute_eval, + ) + ) + + def version(self) -> str: + """Return the PyTorch version string.""" + return torch.__version__ + + def name(self) -> str: + """Return the model family name identifier.""" + return "model-family-hstu" + + def load(self, model_path: str) -> None: + """ + Load model checkpoints from disk. + + Args: + model_path: Base path to the model checkpoint directory. + """ + self.sparse.load(model_path=model_path) + self.dense.load(model_path=model_path) + + def predict( + self, samples: Optional[Samples] + ) -> Optional[ + Tuple[ + torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], float, float + ] + ]: + """ + Run inference on a batch of samples. + + Processes samples through sparse embeddings, then dense forward pass. + + Args: + samples: Input samples containing features. If None, signals shutdown. + + Returns: + Tuple of (predictions, labels, weights, sparse_time, dense_time) or None. + """ + with torch.no_grad(): + if samples is None: + self.dense.predict(None, None, 0, None, 0, None) + return None + ( + seq_embeddings, + payload_features, + max_uih_len, + uih_seq_lengths, + max_num_candidates, + num_candidates, + dt_sparse, + ) = self.sparse.predict(samples) + out = self.dense.predict( + seq_embeddings, + payload_features, + max_uih_len, + uih_seq_lengths, + max_num_candidates, + num_candidates, + ) + ( # pyre-ignore [23] + mt_target_preds, + mt_target_labels, + mt_target_weights, + dt_dense, + ) = out + return ( + mt_target_preds, + mt_target_labels, + mt_target_weights, + dt_sparse, + dt_dense, + ) + + +def ec_patched_forward_wo_embedding_copy( + ec_module: torchrec.EmbeddingCollection, + features: KeyedJaggedTensor, # can also take TensorDict as input +) -> Dict[str, JaggedTensor]: + """ + Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor` + and returns a `Dict[str, JaggedTensor]`, which is the result of the individual embeddings for each feature. + + Args: + features (KeyedJaggedTensor): KJT of form [F X B X L]. + + Returns: + Dict[str, JaggedTensor] + """ + features = maybe_td_to_kjt(features, None) + feature_embeddings: Dict[str, JaggedTensor] = {} + jt_dict: Dict[str, JaggedTensor] = features.to_dict() + for i, emb_module in enumerate(ec_module.embeddings.values()): + feature_names = ec_module._feature_names[i] + embedding_names = ec_module._embedding_names_by_table[i] + for j, embedding_name in enumerate(embedding_names): + feature_name = feature_names[j] + f = jt_dict[feature_name] + indices = torch.clamp(f.values(), min=0, max=HASH_SIZE - 1) + lookup = emb_module( + input=indices + ) # remove the dtype cast at https://github.com/meta-pytorch/torchrec/blob/0a2cebd5472a7edc5072b3c912ad8aaa4179b9d9/torchrec/modules/embedding_modules.py#L486 + feature_embeddings[embedding_name] = JaggedTensor( + values=lookup, + lengths=f.lengths(), + weights=f.values() if ec_module._need_indices else None, + ) + return feature_embeddings + + +class ModelFamilySparseDist: + """ + Sparse Arch module manager. 
+ + Handles loading and inference of sparse embedding lookups, optionally + with quantization for memory efficiency. + + Args: + hstu_config: HSTU model configuration. + table_config: Embedding table configurations. + quant: Whether to apply dynamic quantization to embeddings. + """ + + def __init__( + self, + hstu_config: DlrmHSTUConfig, + table_config: Dict[str, EmbeddingConfig], + quant: bool = False, + ) -> None: + super(ModelFamilySparseDist, self).__init__() + self.hstu_config = hstu_config + self.table_config = table_config + self.module: Optional[torch.nn.Module] = None + self.quant: bool = quant + + def load(self, model_path: str) -> None: + """ + Load sparse model checkpoint and optionally apply quantization. + + Args: + model_path: Path to the model checkpoint directory. + """ + logger.warning(f"Loading sparse module from {model_path}") + + sparse_arch: HSTUSparseInferenceModule = HSTUSparseInferenceModule( + table_config=self.table_config, + hstu_config=self.hstu_config, + ) + load_sparse_checkpoint(model=sparse_arch._hstu_model, path=model_path) + sparse_arch.eval() + if self.quant: + self.module = quant.quantize_dynamic( + sparse_arch, + qconfig_spec={ + torchrec.EmbeddingCollection: QuantConfig( + activation=quant.PlaceholderObserver.with_args( + dtype=torch.float + ), + weight=quant.PlaceholderObserver.with_args(dtype=torch.int8), + ), + }, + mapping={ + torchrec.EmbeddingCollection: QuantEmbeddingCollection, + }, + inplace=False, + ) + else: + sparse_arch._hstu_model._embedding_collection.forward = ( # pyre-ignore[8] + functools.partial( + ec_patched_forward_wo_embedding_copy, + sparse_arch._hstu_model._embedding_collection, + ) + ) + self.module = sparse_arch + logger.warning(f"sparse module is {self.module}") + + def predict( + self, samples: Samples + ) -> Tuple[ + Dict[str, SequenceEmbedding], + Dict[str, torch.Tensor], + int, + torch.Tensor, + int, + torch.Tensor, + float, + ]: + """ + Run sparse forward pass (embedding lookups). + + Args: + samples: Input samples with feature tensors. + + Returns: + Tuple of (seq_embeddings, payload_features, max_uih_len, uih_seq_lengths, + max_num_candidates, num_candidates, elapsed_time). + """ + with torch.profiler.record_function("sparse forward"): + module: torch.nn.Module = none_throws(self.module) + assert self.module is not None + uih_features = samples.uih_features_kjt + candidates_features = samples.candidates_features_kjt + t0: float = time.time() + ( + seq_embeddings, + payload_features, + max_uih_len, + uih_seq_lengths, + max_num_candidates, + num_candidates, + ) = module( + uih_features=uih_features, + candidates_features=candidates_features, + ) + dt_sparse: float = time.time() - t0 + return ( + seq_embeddings, + payload_features, + max_uih_len, + uih_seq_lengths, + max_num_candidates, + num_candidates, + dt_sparse, + ) + + +class ModelFamilyDenseDist: + """ + Distributed dense module manager for multi-GPU inference. + + Spawns worker processes for each GPU to run dense forward passes in parallel, + with samples distributed via inter-process queues. + + Args: + hstu_config: HSTU model configuration. + table_config: Embedding table configurations. + output_trace: Whether to enable profiling traces. + compute_eval: Whether to compute evaluation metrics. 
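+
+    Note (dispatch protocol, as implemented in this class):
+        One worker process is spawned per rank. Work items are pushed round-robin
+        onto per-rank samples_q queues and results are read back from the matching
+        result_q; pushing -1 onto a sample queue (triggered by a None input at the
+        HSTUModelFamily level) tells that worker to shut down.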
+ """ + + def __init__( + self, + hstu_config: DlrmHSTUConfig, + table_config: Dict[str, EmbeddingConfig], + output_trace: bool = False, + compute_eval: bool = False, + ) -> None: + super(ModelFamilyDenseDist, self).__init__() + self.hstu_config = hstu_config + self.table_config = table_config + self.output_trace = output_trace + self.compute_eval = compute_eval + + ngpus = torch.cuda.device_count() + self.world_size = int(os.environ.get("WORLD_SIZE", str(ngpus))) + self.rank = 0 + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(get_free_port()) + self.dist_backend = "nccl" + + ctx = mp.get_context("spawn") + self.samples_q: List[mp.Queue] = [ctx.Queue() for _ in range(self.world_size)] + self.result_q: List[mp.Queue] = [ctx.Queue() for _ in range(self.world_size)] + + def load(self, model_path: str) -> None: + """ + Load dense model and spawn worker processes for distributed inference. + + Args: + model_path: Path to the model checkpoint directory. + """ + logger.warning(f"Loading dense module from {model_path}") + + ctx = mp.get_context("spawn") + processes = [] + for rank in range(self.world_size): + p = ctx.Process( + target=self.distributed_setup, + args=( + rank, + self.world_size, + model_path, + ), + ) + p.start() + processes.append(p) + + def distributed_setup(self, rank: int, world_size: int, model_path: str) -> None: + """ + Initialize and run a dense worker process. + + Each worker loads the model, processes samples from its queue, and + returns results. + + Args: + rank: Process rank (GPU index). + world_size: Total number of worker processes. + model_path: Path to model checkpoint. + """ + # nprocs_per_rank = 16 + # start_core: int = nprocs_per_rank * rank + 128 + # cores: set[int] = set([start_core + i for i in range(nprocs_per_rank)]) + # os.sched_setaffinity(0, cores) + set_is_inference(is_inference=not self.compute_eval) + model = get_hstu_model( + table_config=self.table_config, + hstu_config=self.hstu_config, + table_device="cpu", + max_hash_size=100, + is_dense=True, + ).to(torch.bfloat16) + model.set_training_dtype(torch.bfloat16) + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(f"cuda:{rank}") + load_nonsparse_checkpoint( + model=model, device=device, optimizer=None, path=model_path + ) + model = model.to(device) + model.eval() + profiler = Profiler(rank) if self.output_trace else None + + with torch.no_grad(): + while True: + item = self.samples_q[rank].get() + # If -1 is received terminate all subprocesses + if item == -1: + break + if self.output_trace: + assert profiler is not None + profiler.step() + with torch.profiler.record_function("get_item_from_queue"): + # Copy here to release data in the producer to avoid invalid cuda caching allocator release. 
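+                    # (Until this copy completes, the tensors pulled from the queue still
+                    # reference storage owned by the producer process; copying them first
+                    # lets the producer release its CUDA allocator blocks without leaving
+                    # this worker holding stale references.)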
+ item = copy.deepcopy(item) + ( + id, + seq_embeddings, + payload_features, + max_uih_len, + uih_seq_lengths, + max_num_candidates, + num_candidates, + ) = item + assert seq_embeddings is not None + with torch.profiler.record_function("dense forward"): + ( + _, + _, + _, + mt_target_preds, + mt_target_labels, + mt_target_weights, + ) = model.main_forward( + seq_embeddings=seq_embeddings, + payload_features=payload_features, + max_uih_len=max_uih_len, + uih_seq_lengths=uih_seq_lengths, + max_num_candidates=max_num_candidates, + num_candidates=num_candidates, + ) + # mt_target_preds = torch.empty(1, 2048 * 20).to(device="cpu") + # mt_target_labels = None + # mt_target_weights = None + assert mt_target_preds is not None + mt_target_preds = mt_target_preds.detach().to(device="cpu") + if mt_target_labels is not None: + mt_target_labels = mt_target_labels.detach().to(device="cpu") + if mt_target_weights is not None: + mt_target_weights = mt_target_weights.detach().to(device="cpu") + self.result_q[rank].put( + (id, mt_target_preds, mt_target_labels, mt_target_weights) + ) + + def capture_output( + self, id: uuid.UUID, rank: int + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Retrieve inference results from a worker process. + + Args: + id: Unique identifier for the request. + rank: Worker rank to retrieve from. + + Returns: + Tuple of (predictions, labels, weights). + """ + while True: + recv_id, preds, labels, weights = self.result_q[rank].get() + assert recv_id == id + return preds, labels, weights + + def get_rank(self) -> int: + """ + Get the next worker rank for load balancing. + + Returns: + Rank index, cycling through available workers. + """ + rank = self.rank + self.rank = (self.rank + 1) % self.world_size + return rank + + def predict( + self, + seq_embeddings: Optional[Dict[str, SequenceEmbedding]], + payload_features: Optional[Dict[str, torch.Tensor]], + max_uih_len: int, + uih_seq_lengths: Optional[torch.Tensor], + max_num_candidates: int, + num_candidates: Optional[torch.Tensor], + ) -> Optional[ + Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], float] + ]: + """ + Run distributed dense forward pass. + + Dispatches work to a worker process and collects results. + + Args: + seq_embeddings: Sequence embeddings from sparse module. + payload_features: Additional feature tensors. + max_uih_len: Maximum UIH sequence length. + uih_seq_lengths: Per-sample UIH lengths. + max_num_candidates: Maximum candidates per sample. + num_candidates: Per-sample candidate counts. + + Returns: + Tuple of (predictions, labels, weights, elapsed_time) or None if shutdown. 
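+
+        Example (illustrative sketch only; assumes `dense` is a loaded
+        ModelFamilyDenseDist and the inputs were produced by
+        ModelFamilySparseDist.predict):
+
+            out = dense.predict(
+                seq_embeddings,
+                payload_features,
+                max_uih_len,
+                uih_seq_lengths,
+                max_num_candidates,
+                num_candidates,
+            )
+            if out is not None:
+                preds, labels, weights, dt_dense = out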
+ """ + id = uuid.uuid4() + # If none is received terminate all subprocesses + if seq_embeddings is None: + for rank in range(self.world_size): + self.samples_q[rank].put(-1) + return None + rank = self.get_rank() + device = torch.device(f"cuda:{rank}") + assert ( + payload_features is not None + and num_candidates is not None + and uih_seq_lengths is not None + ) + t0: float = time.time() + seq_embeddings, payload_features, uih_seq_lengths, num_candidates = ( + move_sparse_output_to_device( + seq_embeddings=seq_embeddings, + payload_features=payload_features, + uih_seq_lengths=uih_seq_lengths, + num_candidates=num_candidates, + device=device, + ) + ) + self.samples_q[rank].put( + ( + id, + seq_embeddings, + payload_features, + max_uih_len, + uih_seq_lengths, + max_num_candidates, + num_candidates, + ) + ) + (mt_target_preds, mt_target_labels, mt_target_weights) = self.capture_output( + id, rank + ) + dt_dense = time.time() - t0 + return ( + mt_target_preds, + mt_target_labels, + mt_target_weights, + dt_dense, + ) + + +class ModelFamilyDenseSingleWorker: + """ + Single-worker dense module manager for single-GPU inference. + + Simpler alternative to ModelFamilyDenseDist for single-GPU setups. + + Args: + hstu_config: HSTU model configuration. + table_config: Embedding table configurations. + output_trace: Whether to enable profiling traces. + compute_eval: Whether to compute evaluation metrics. + """ + + def __init__( + self, + hstu_config: DlrmHSTUConfig, + table_config: Dict[str, EmbeddingConfig], + output_trace: bool = False, + compute_eval: bool = False, + ) -> None: + self.model: Optional[torch.nn.Module] = None + self.hstu_config = hstu_config + self.table_config = table_config + self.output_trace = output_trace + self.device: torch.device = torch.device("cuda:0") + torch.cuda.set_device(self.device) + self.profiler: Optional[Profiler] = ( + Profiler(rank=0) if self.output_trace else None + ) + + def load(self, model_path: str) -> None: + """ + Load dense model for single-GPU inference. + + Args: + model_path: Path to the model checkpoint directory. + """ + logger.warning(f"Loading dense module from {model_path}") + self.model = ( + get_hstu_model( + table_config=self.table_config, + hstu_config=self.hstu_config, + table_device="cpu", + is_dense=True, + ) + .to(self.device) + .to(torch.bfloat16) + ) + self.model.set_training_dtype(torch.bfloat16) + load_nonsparse_checkpoint( + model=self.model, device=self.device, optimizer=None, path=model_path + ) + assert self.model is not None + self.model.eval() + + def predict( + self, + seq_embeddings: Optional[Dict[str, SequenceEmbedding]], + payload_features: Optional[Dict[str, torch.Tensor]], + max_uih_len: int, + uih_seq_lengths: Optional[torch.Tensor], + max_num_candidates: int, + num_candidates: Optional[torch.Tensor], + ) -> Optional[ + Tuple[ + torch.Tensor, + Optional[torch.Tensor], + Optional[torch.Tensor], + float, + ] + ]: + """ + Run dense forward pass on single GPU. + + Args: + seq_embeddings: Sequence embeddings from sparse module. + payload_features: Additional feature tensors. + max_uih_len: Maximum UIH sequence length. + uih_seq_lengths: Per-sample UIH lengths. + max_num_candidates: Maximum candidates per sample. + num_candidates: Per-sample candidate counts. + + Returns: + Tuple of (predictions, labels, weights, elapsed_time). 
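+
+        Note:
+            Unlike ModelFamilyDenseDist.predict, this path runs the forward pass
+            in the calling process on cuda:0, so no inter-process queues or
+            cross-process tensor hand-off is involved.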
+ """ + if self.output_trace: + assert self.profiler is not None + self.profiler.step() + assert ( + payload_features is not None + and uih_seq_lengths is not None + and num_candidates is not None + and seq_embeddings is not None + ) + t0: float = time.time() + with torch.profiler.record_function("dense forward"): + seq_embeddings, payload_features, uih_seq_lengths, num_candidates = ( + move_sparse_output_to_device( + seq_embeddings=seq_embeddings, + payload_features=payload_features, + uih_seq_lengths=uih_seq_lengths, + num_candidates=num_candidates, + device=self.device, + ) + ) + assert self.model is not None + ( + _, + _, + _, + mt_target_preds, + mt_target_labels, + mt_target_weights, + ) = self.model.main_forward( # pyre-ignore [29] + seq_embeddings=seq_embeddings, + payload_features=payload_features, + max_uih_len=max_uih_len, + uih_seq_lengths=uih_seq_lengths, + max_num_candidates=max_num_candidates, + num_candidates=num_candidates, + ) + assert mt_target_preds is not None + mt_target_preds = mt_target_preds.detach().to(device="cpu") + if mt_target_labels is not None: + mt_target_labels = mt_target_labels.detach().to(device="cpu") + if mt_target_weights is not None: + mt_target_weights = mt_target_weights.detach().to(device="cpu") + dt_dense: float = time.time() - t0 + return mt_target_preds, mt_target_labels, mt_target_weights, dt_dense diff --git a/recommendation/dlrm_v3/requirements.txt b/recommendation/dlrm_v3/requirements.txt new file mode 100644 index 0000000000..bd76e80d94 --- /dev/null +++ b/recommendation/dlrm_v3/requirements.txt @@ -0,0 +1,6 @@ +torch==2.8.0 +fbgemm_gpu==1.3.0 +torchrec==1.3.0 +gin_config==0.5.0 +pandas==2.3.2 +tensorboard==2.20.0 diff --git a/recommendation/dlrm_v3/run_benchmark.sh b/recommendation/dlrm_v3/run_benchmark.sh new file mode 100644 index 0000000000..18cb4c5ad1 --- /dev/null +++ b/recommendation/dlrm_v3/run_benchmark.sh @@ -0,0 +1,4 @@ +#!/usr/bin/env bash + +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 WORLD_SIZE=8 \ + python main.py --dataset sampled-streaming-100b 2>&1 | tee /home/$USER/dlrmv3-inference-benchmark.log diff --git a/recommendation/dlrm_v3/setup.sh b/recommendation/dlrm_v3/setup.sh new file mode 100644 index 0000000000..e25c63c800 --- /dev/null +++ b/recommendation/dlrm_v3/setup.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env bash + +conda create --name dlrmv3 python=3.13 +conda activate dlrmv3 +pip install -r requirements.txt +git_dir=$(git rev-parse --show-toplevel) +pip install $git_dir/loadgen diff --git a/recommendation/dlrm_v3/streaming_synthetic_data.py b/recommendation/dlrm_v3/streaming_synthetic_data.py new file mode 100644 index 0000000000..8046909e00 --- /dev/null +++ b/recommendation/dlrm_v3/streaming_synthetic_data.py @@ -0,0 +1,664 @@ +# pyre-strict +""" +Streaming synthetic data generator for DLRMv3. + +This module generates synthetic streaming recommendation data for benchmarking +and testing purposes. It creates user-item interaction histories with timestamps, +ratings, and category-based item distributions. +""" + +import csv +import logging +import math +import multiprocessing +import os +import random +import shutil +import time +from typing import Dict, List, Tuple + +import numpy as np + +logger: logging.Logger = logging.getLogger(__name__) + + +class StreamingSyntheticDataGenerator: + """ + Generator for streaming synthetic recommendation data. + + Creates realistic user-item interaction data with temporal dynamics, + category preferences, and rating distributions for benchmarking + recommendation systems. 
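+
+    Each user is written as a single CSV row; for every sampled timestamp the row
+    carries four comma-separated fields (candidate item ids, candidate ratings,
+    UIH item ids, UIH ratings), and timestamps at which the user was not sampled
+    contribute empty fields.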
+ + Args: + num_categories: Number of item categories. + categories_per_user: Number of categories each user is interested in. + num_users: Total number of users to generate. + num_items: Total number of items in the catalog. + num_timestamps: Number of time periods in the streaming data. + avg_samples_per_item: Average number of interactions per item. + train_ratio: Fraction of timestamps used for training. + user_sampling_ratio: Probability of sampling a user at each timestamp. + num_eval_candidates: Number of candidates for evaluation. + num_inference_candidates: Number of candidates for inference. + debug: If True, use deterministic ratings for debugging. + rank: Process rank for distributed generation. + """ + + def __init__( + self, + num_categories: int, + categories_per_user: int, + num_users: int, + num_items: int, + num_timestamps: int, + avg_samples_per_item: int, + train_ratio: float, + user_sampling_ratio: float, + num_eval_candidates: int, + num_inference_candidates: int, + debug: bool = False, + rank: int = 0, + ) -> None: + self.num_categories = num_categories + self.categories_per_user = categories_per_user + self.num_users = num_users + self.num_items = num_items + self.num_timestamps = num_timestamps + self.avg_samples_per_item = avg_samples_per_item + self.avg_seq_len_per_timestamp = int( + num_items * avg_samples_per_item / num_users / num_timestamps + ) + self.items_per_category: int = num_items // num_categories + self.category_to_start_end_item_idx: Dict[int, Tuple[int, int]] = {} + for i in range(num_categories): + start_idx = i * self.items_per_category + end_idx = (i + 1) * self.items_per_category + self.category_to_start_end_item_idx[i] = (start_idx, end_idx) + self.alpha_range = (1, 500) + self.min_seq_len: int = num_eval_candidates + 1 + self.train_ratio = train_ratio + self.num_eval_candidates = num_eval_candidates + self.num_inference_candidates = num_inference_candidates + self.debug = debug + self.total_cnt = 0 + self.rank = rank + logger.warning(f"rank {self.rank}: start generating item rating") + np.random.seed(1001) + self.item_rating = np.random.choice( # pyre-ignore [4] + [5.0, 4.0, 3.0, 2.0, 1.0], size=num_items, p=[0.2, 0.25, 0.25, 0.2, 0.1] + ) + logger.warning(f"rank {self.rank}: finish generating item rating") + self.user_sampling_ratio = user_sampling_ratio + + def generate_one_timestamp( + self, + category_to_cnt: Dict[int, int], + categories: List[int], + t: int, + id: int, + output_folder: str, + uih_seq_len: int, + eval: bool, + inference: bool, + file_idx: int, + ts_buffers: Dict[int, List[int]], + ) -> Tuple[List[int], List[float], List[int], List[float], Dict[int, int]]: + """ + Generate interaction data for a single user at one timestamp. + + Args: + category_to_cnt: Running count of interactions per category. + categories: Categories this user is interested in. + t: Current timestamp index. + id: User ID. + output_folder: Output directory for files. + uih_seq_len: Length of user interaction history to generate. + eval: Whether this is for evaluation. + inference: Whether this is for inference. + file_idx: File index for output. + ts_buffers: Buffer for timestamp data. + + Returns: + Tuple of (uih_item_ids, uih_ratings, candidate_ids, candidate_ratings, + updated_category_counts). 
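+
+        Note (category sampling, as implemented below):
+            Item categories are drawn with probability
+            p(c) = (alpha / num_user_categories + count[c]) / (alpha + total_count),
+            a Polya-urn-style rule that increasingly favors categories the user
+            has already interacted with, where alpha (sampled per call) sets how
+            close the draw stays to uniform.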
+ """ + if t >= 0 and (not eval): + if t not in ts_buffers: + ts_buffers[t] = [] + ts_buffers[t].append(id) + seq_len: int = self.num_inference_candidates if inference else uih_seq_len + self.total_cnt += seq_len + alpha = random.randint(self.alpha_range[0], self.alpha_range[1]) + total_cnt = sum(category_to_cnt.values()) + p = np.array( + [ + (alpha / len(categories) + category_to_cnt[c]) / (alpha + total_cnt) + for c in categories + ] + ) + item_categories = np.random.choice(categories, size=seq_len, p=p) + unique, counts = np.unique(item_categories, return_counts=True) + for cat, cnt in zip(unique, counts): + category_to_cnt[cat] += int(cnt) + sample_end_idx = int( + self.items_per_category * max((t + 1), 1) / self.num_timestamps + ) + sample_inds = np.random.randint(0, sample_end_idx, size=seq_len) + offsets = np.array( + [self.category_to_start_end_item_idx[cat][0] for cat in item_categories] + ) + sample_inds = sample_inds + offsets + num_categories = len(categories) + quarter = num_categories // 4 + half = num_categories // 2 + three_quarter = num_categories // 4 * 3 + category_to_ratings = {} + cos1 = math.cos(t * math.pi / 4) + cos2 = math.cos((t + 2) * math.pi / 4) + cos3 = math.cos((t + 4) * math.pi / 4) + for i, cat in enumerate(categories): + if i < quarter: + if self.debug: + ratings = np.full(seq_len, 5.0) + else: + ratings = np.random.choice( + [4.5 + 0.5 * cos1, 4.0 + 0.5 * cos2], + size=seq_len, + p=[0.8, 0.2], + ) + elif i < half: + if self.debug: + ratings = np.full(seq_len, 4.0) + else: + ratings = np.random.choice( + [4.5 + 0.5 * cos1, 4.0 + 0.5 * cos2, 3.5 + 0.5 * cos3], + size=seq_len, + p=[0.1, 0.8, 0.1], + ) + elif i < three_quarter: + if self.debug: + ratings = np.full(seq_len, 3.0) + else: + ratings = np.random.choice( + [3.5 + 0.5 * cos1, 3.0 + 0.5 * cos2, 2.5 + 0.5 * cos3], + size=seq_len, + p=[0.1, 0.8, 0.1], + ) + else: + if self.debug: + ratings = np.full(seq_len, 2.0) + else: + ratings = np.random.choice( + [2.5 + 0.5 * cos1, 2.0 + 0.5 * cos2, 1.5 + 0.5 * cos3], + size=seq_len, + p=[0.1, 0.8, 0.1], + ) + category_to_ratings[cat] = ratings + sample_inds = sample_inds.tolist() + sample_ratings = [ + ( + category_to_ratings[item_categories[i]][i] + + self.item_rating[sample_inds[i]] + ) + / 2 + for i in range(seq_len) + ] + if not inference: + sub_indices = random.sample(range(seq_len), self.num_eval_candidates) + sample_candidate_inds = [sample_inds[i] for i in sub_indices] + sample_candidate_ratings = [sample_ratings[i] for i in sub_indices] + sample_uih_inds = sample_inds + sample_uih_ratings = sample_ratings + else: + sub_indices = random.sample(range(seq_len), uih_seq_len) + sample_uih_inds = [sample_inds[i] for i in sub_indices] + sample_uih_ratings = [sample_ratings[i] for i in sub_indices] + sample_candidate_inds = sample_inds + sample_candidate_ratings = sample_ratings + return ( + sample_uih_inds, + sample_uih_ratings, + sample_candidate_inds, + sample_candidate_ratings, + category_to_cnt, + ) + + def gen_rand_seq_len(self) -> int: + """ + Generate a random sequence length from a Gaussian distribution. + + Returns: + Sequence length, guaranteed to be at least min_seq_len. + """ + seq_len = round( + random.gauss( + self.avg_seq_len_per_timestamp, self.avg_seq_len_per_timestamp // 4 + ) + ) + seq_len = self.min_seq_len if seq_len < self.min_seq_len else seq_len + return seq_len + + def get_timestamp_sample(self, t: int) -> int: + """ + Determine if a user should be sampled at this timestamp. + + Args: + t: Timestamp index. 
Base timestamp (-1) is always sampled. + + Returns: + 1 if the user should be sampled, 0 otherwise. + """ + if t == -1: + sample = 1 + else: + sample = np.random.choice( + [1, 0], + size=1, + p=[self.user_sampling_ratio, 1 - self.user_sampling_ratio], + )[0] + return sample + + def generate_one_user( + self, + id: int, + output_folder: str, + file_idx: int, + ts_buffers: Dict[int, List[int]], + ) -> List[str]: + """ + Generate complete interaction history for one user. + + Creates training, evaluation, and inference data for a single user + across all timestamps. + + Args: + id: User ID. + output_folder: Output directory. + file_idx: File index for output. + ts_buffers: Buffer for timestamp metadata. + + Returns: + List of CSV row values for this user's data. + """ + categories = random.sample(range(self.num_categories), self.categories_per_user) + category_to_cnt = {c: 0 for c in categories} + out_list: List[str] = [] + # t = -1 as base UIH + ( + sample_inds, + sample_ratings, + sample_candidate_inds, + sample_candidate_ratings, + category_to_cnt, + ) = self.generate_one_timestamp( + category_to_cnt=category_to_cnt, + categories=categories, + t=-1, + id=id, + output_folder=output_folder, + uih_seq_len=self.gen_rand_seq_len(), + eval=False, + inference=False, + file_idx=file_idx, + ts_buffers=ts_buffers, + ) + out_list.append(",".join([str(ind) for ind in sample_candidate_inds])) + out_list.append(",".join([str(rat) for rat in sample_candidate_ratings])) + out_list.append(",".join([str(ind) for ind in sample_inds])) + out_list.append(",".join([str(rat) for rat in sample_ratings])) + # train + for t in range(int(self.num_timestamps * self.train_ratio)): + if self.get_timestamp_sample(t): + ( + sample_inds, + sample_ratings, + sample_candidate_inds, + sample_candidate_ratings, + category_to_cnt, + ) = self.generate_one_timestamp( + category_to_cnt=category_to_cnt, + categories=categories, + t=t, + id=id, + output_folder=output_folder, + uih_seq_len=self.gen_rand_seq_len(), + eval=False, + inference=False, + file_idx=file_idx, + ts_buffers=ts_buffers, + ) + out_list.append(",".join([str(ind) for ind in sample_candidate_inds])) + out_list.append( + ",".join([str(rat) for rat in sample_candidate_ratings]) + ) + out_list.append(",".join([str(ind) for ind in sample_inds])) + out_list.append(",".join([str(rat) for rat in sample_ratings])) + else: + out_list += ["", "", "", ""] + # eval + ( + sample_inds, + sample_ratings, + sample_candidate_inds, + sample_candidate_ratings, + category_to_cnt, + ) = self.generate_one_timestamp( + category_to_cnt=category_to_cnt, + categories=categories, + t=int(self.num_timestamps * self.train_ratio), + id=id, + output_folder=output_folder, + uih_seq_len=self.num_eval_candidates, + eval=True, + inference=False, + file_idx=file_idx, + ts_buffers=ts_buffers, + ) + out_list.append(",".join([str(ind) for ind in sample_candidate_inds])) + out_list.append(",".join([str(rat) for rat in sample_candidate_ratings])) + out_list.append(",".join([str(ind) for ind in sample_inds])) + out_list.append(",".join([str(rat) for rat in sample_ratings])) + # inference + for t in range( + int(self.num_timestamps * self.train_ratio), self.num_timestamps + ): + if self.get_timestamp_sample(t): + ( + sample_inds, + sample_ratings, + sample_candidate_inds, + sample_candidate_ratings, + category_to_cnt, + ) = self.generate_one_timestamp( + category_to_cnt=category_to_cnt, + categories=categories, + t=t, + id=id, + output_folder=output_folder, + uih_seq_len=self.gen_rand_seq_len(), + 
eval=False, + inference=True, + file_idx=file_idx, + ts_buffers=ts_buffers, + ) + out_list.append(",".join([str(ind) for ind in sample_candidate_inds])) + out_list.append( + ",".join([str(rat) for rat in sample_candidate_ratings]) + ) + out_list.append(",".join([str(ind) for ind in sample_inds])) + out_list.append(",".join([str(rat) for rat in sample_ratings])) + else: + out_list += ["", "", "", ""] + return out_list + + def write_dataset( + self, output_folder: str, num_files: int, file_idx: int, seed: int + ) -> None: + """ + Write dataset for a single file partition. + + Args: + output_folder: Output directory path. + num_files: Total number of files in the dataset. + file_idx: Index of this file partition. + seed: Random seed for reproducibility. + """ + t0 = time.time() + num_users_per_file = self.num_users // num_files + user_id: int = num_users_per_file * file_idx + random.seed(seed + file_idx) + np.random.seed(seed + file_idx) + # Buffer timestamp data in memory to avoid excessive file I/O + ts_buffers: Dict[int, List[int]] = {} + output_file = output_folder + f"{file_idx}.csv" + with open(output_file, "w") as file: + writer = csv.writer(file) + for i in range(num_users_per_file): + out_list = self.generate_one_user( + id=user_id, + output_folder=output_folder, + file_idx=file_idx, + ts_buffers=ts_buffers, + ) + user_id += 1 + writer.writerow(out_list) + if i % 10000 == 0: + logger.warning( + f"rank {self.rank}: Done with users {i} for file {file_idx + 1} / {num_files}, total_cnt = {self.total_cnt}, spends {time.time() - t0} seconds." + ) + # Write buffered timestamp data after all users are processed + for ts, user_ids in ts_buffers.items(): + ts_file = output_folder + f"ts_{file_idx}_{ts}.csv" + with open(ts_file, "w") as f: + writer = csv.writer(f) + for uid in user_ids: + writer.writerow([uid]) + logger.warning( + f"rank {self.rank}: Wrote {len(ts_buffers)} timestamp files for file {file_idx}" + ) + + +def worker( + rank: int, + world_size: int, + num_files: int, + num_users: int, + num_items: int, + num_categories: int, + categories_per_user: int, + num_timestamps: int, + avg_samples_per_item: int, + num_eval_candidates: int, + num_inference_candidates: int, + train_ratio: float, + user_sampling_ratio: float, + output_folder: str, +) -> None: + """ + Worker function for parallel data generation. + + Each worker generates a subset of the dataset files. + + Args: + rank: Worker rank. + world_size: Total number of workers. + num_files: Total files to generate. + num_users: Total users in dataset. + num_items: Total items in catalog. + num_categories: Number of item categories. + categories_per_user: Categories per user. + num_timestamps: Number of time periods. + avg_samples_per_item: Average interactions per item. + num_eval_candidates: Eval candidates count. + num_inference_candidates: Inference candidates count. + train_ratio: Training data fraction. + user_sampling_ratio: User sampling probability. + output_folder: Output directory. 
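+
+    Note:
+        Each worker generates a contiguous block of num_files // world_size file
+        indices starting at rank * (num_files // world_size); each file is seeded
+        with seed + file_idx so generation is reproducible per file.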
+ """ + generator = StreamingSyntheticDataGenerator( + num_categories=num_categories, + categories_per_user=categories_per_user, + num_users=num_users, + num_items=num_items, + num_timestamps=num_timestamps, + avg_samples_per_item=avg_samples_per_item, + train_ratio=train_ratio, + user_sampling_ratio=user_sampling_ratio, + num_eval_candidates=num_eval_candidates, + num_inference_candidates=num_inference_candidates, + debug=False, + rank=rank, + ) + num_files_per_rank = num_files // world_size + file_indices = [i + rank * num_files_per_rank for i in range(num_files_per_rank)] + for file_idx in file_indices: + logger.warning(f"rank {rank}: start generating file {file_idx}") + generator.write_dataset( + output_folder=output_folder, + num_files=num_files, + file_idx=file_idx, + seed=1001, + ) + logger.warning(f"rank {rank}: finish generating file {file_idx}") + + +def write_offset(output_folder: str, num_files: int, num_users: int) -> None: + """ + Write file byte offsets for random access to user data. + + Creates an offset.csv file containing byte positions for each user + within their respective data files. + + Args: + output_folder: Directory containing data files. + num_files: Number of data files. + num_users: Total number of users. + """ + with open(output_folder + "offset.csv", "a") as output_file: + writer = csv.writer(output_file) + for i in range(num_files): + input_file = output_folder + f"{i}.csv" + offsets = [] + with open(input_file, "r") as f: + while True: + offset = f.tell() + line = f.readline() + if not line: + break + offsets.append(offset) + assert ( + len(offsets) == num_users // num_files + ), f"num_users {num_users // num_files} != {len(offsets)}" + logger.warning(f"offsets for file {i} finished") + writer.writerow([",".join([str(offset) for offset in offsets])]) + + +def write_ts_metadata(output_folder: str, total_ts: int, num_files: int) -> None: + """ + Write timestamp metadata for streaming simulation. + + Creates files tracking which users are active at each timestamp + and cumulative counts for efficient streaming access. + + Args: + output_folder: Output directory path. + total_ts: Total number of timestamps. + num_files: Number of data files. 
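+
+    Output files (all written under output_folder): requests_per_ts.csv,
+    users_cumsum_per_ts.csv, and requests_per_ts_offset.csv.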
+    """
+    with open(output_folder + "requests_per_ts.csv", "w") as file_requests:
+        with open(output_folder + "users_cumsum_per_ts.csv", "w") as file_cumsum:
+            requests_writer = csv.writer(file_requests)
+            cumsum_writer = csv.writer(file_cumsum)
+            for ts in range(total_ts):
+                requests = []
+                num_users_per_file = []
+                for file_id in range(num_files):
+                    with open(f"{output_folder}ts_{file_id}_{ts}.csv", "r") as ts_file:
+                        reader = csv.reader(ts_file)
+                        size = 0
+                        for row in reader:
+                            requests.append(int(row[0]))
+                            size += 1
+                        num_users_per_file.append(size)
+                cumsum = np.cumsum(num_users_per_file).tolist()
+                assert cumsum[-1] == len(requests)
+                requests_writer.writerow([",".join([str(r) for r in requests])])
+                cumsum_writer.writerow([",".join([str(s) for s in cumsum])])
+                logger.warning(f"ts {ts} finished")
+    with open(
+        output_folder + "requests_per_ts_offset.csv", "w"
+    ) as file_requests_offset:
+        writer = csv.writer(file_requests_offset)
+        input_file = output_folder + "requests_per_ts.csv"
+        offsets = []
+        with open(input_file, "r") as f:
+            while True:
+                offset = f.tell()
+                line = f.readline()
+                if not line:
+                    break
+                offsets.append(offset)
+        assert len(offsets) == total_ts, f"total_ts {total_ts} != {len(offsets)}"
+        logger.warning("offsets for file requests_per_ts.csv finished")
+        writer.writerow([",".join([str(offset) for offset in offsets])])
+
+
+def copy_sub_dataset(src_folder: str) -> None:
+    """
+    Copy a subset of dataset files for quick testing.
+
+    Creates a sampled_data subdirectory with essential files.
+
+    Args:
+        src_folder: Source folder containing full dataset.
+    """
+    dst_folder = src_folder + "sampled_data/"
+    files_to_copy = [
+        "0.csv",
+        "offset.csv",
+        "requests_per_ts.csv",
+        "requests_per_ts_offset.csv",
+        "users_cumsum_per_ts.csv",
+    ]
+    os.makedirs(dst_folder, exist_ok=True)
+    for filename in files_to_copy:
+        src_path = os.path.join(src_folder, filename)
+        dst_path = os.path.join(dst_folder, filename)
+        shutil.copy2(src_path, dst_path)
+    logger.warning("Files copied successfully.")
+
+
+def main() -> None:
+    """
+    Main entry point for synthetic data generation.
+
+    Configures and launches parallel workers to generate a complete
+    streaming recommendation dataset.
+    """
+    processes = []
+    num_files = 100
+    num_users = 5_000_000
+    num_items = 1_000_000_000
+    num_categories = 128
+    categories_per_user = 4
+    num_timestamps = 100
+    avg_samples_per_item = 50
+    num_eval_candidates = 32
+    num_inference_candidates = 2048
+    train_ratio = 0.9
+    user_sampling_ratio = 0.7
+    world_size = 5
+    username = os.getlogin()
+    output_folder = f"/home/{username}/data/streaming-100b/"
+    for i in range(world_size):
+        p = multiprocessing.Process(
+            target=worker,
+            args=(
+                i,
+                world_size,
+                num_files,
+                num_users,
+                num_items,
+                num_categories,
+                categories_per_user,
+                num_timestamps,
+                avg_samples_per_item,
+                num_eval_candidates,
+                num_inference_candidates,
+                train_ratio,
+                user_sampling_ratio,
+                output_folder,
+            ),
+        )
+        processes.append(p)
+        p.start()
+    for p in processes:
+        p.join()
+    write_offset(output_folder, num_files, num_users)
+    write_ts_metadata(output_folder, num_timestamps, num_files)
+    copy_sub_dataset(src_folder=output_folder)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/recommendation/dlrm_v3/user.conf b/recommendation/dlrm_v3/user.conf
new file mode 100644
index 0000000000..6ee9e874f8
--- /dev/null
+++ b/recommendation/dlrm_v3/user.conf
@@ -0,0 +1,3 @@
+# Please set these fields depending on the performance of your system to
+# override default LoadGen settings.
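+# For example, you may also want to override the target QPS for your SUT.
+# The numbers below are illustrative placeholders, not recommendations:
+# *.Server.target_qps = 10
+# *.Offline.target_qps = 20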
+*.Server.target_latency = 80 diff --git a/recommendation/dlrm_v3/utils.py b/recommendation/dlrm_v3/utils.py new file mode 100644 index 0000000000..4d18d360d1 --- /dev/null +++ b/recommendation/dlrm_v3/utils.py @@ -0,0 +1,417 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pyre-unsafe +""" +mlperf dlrm_v3 inference benchmarking tool. +""" + +import contextlib +import logging +import os +from typing import Callable, Dict, List, Optional + +import gin +import tensorboard # @manual=//tensorboard:lib # noqa: F401 - required implicit dep when using torch.utils.tensorboard + +import torch +from datasets.dataset import DLRMv3RandomDataset +from datasets.synthetic_streaming import ( + DLRMv3SyntheticStreamingDataset, +) +from generative_recommenders.modules.multitask_module import ( + MultitaskTaskType, + TaskConfig, +) +from torch.profiler import profile, profiler, ProfilerActivity # pyre-ignore [21] +from torch.utils.tensorboard import SummaryWriter +from torchrec.metrics.accuracy import AccuracyMetricComputation +from torchrec.metrics.gauc import GAUCMetricComputation +from torchrec.metrics.mae import MAEMetricComputation +from torchrec.metrics.mse import MSEMetricComputation +from torchrec.metrics.ne import NEMetricComputation + +from torchrec.metrics.rec_metric import RecMetricComputation + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("utils") + + +def _on_trace_ready_fn( + rank: Optional[int] = None, +) -> Callable[[torch.profiler.profile], None]: + """ + Create a callback function for handling profiler trace output. + + Args: + rank: Optional process rank for distributed training (included in filename). + + Returns: + A callback function that exports profiler traces to Manifold storage. + """ + def handle_fn(p: torch.profiler.profile) -> None: + bucket_name = "hammer_gpu_traces" + pid = os.getpid() + rank_str = f"_rank_{rank}" if rank is not None else "" + file_name = f"libkineto_activities_{pid}_{rank_str}.json" + manifold_path = "tree/dlrm_v3_bench" + target_object_name = manifold_path + "/" + file_name + ".gz" + path = f"manifold://{bucket_name}/{manifold_path}/{file_name}" + logger.warning( + p.key_averages(group_by_input_shape=True).table( + sort_by="self_cuda_time_total" + ) + ) + logger.warning( + f"trace url: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath={target_object_name}&bucket={bucket_name}" + ) + p.export_chrome_trace(path) + + return handle_fn + + +def profiler_or_nullcontext(enabled: bool, with_stack: bool): + """ + Create a profiler context manager or null context based on enabled flag. + + Args: + enabled: Whether to enable profiling. + with_stack: Whether to include stack traces in profile. + + Returns: + Either a torch.profiler.profile context manager or nullcontext. + """ + return ( + profile( + # pyre-fixme[16]: Module `profiler` has no attribute `ProfilerActivity`. 
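+            # Capture both CPU operator and CUDA kernel activity in the exported trace.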
+ activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + on_trace_ready=_on_trace_ready_fn(), + with_stack=with_stack, + ) + if enabled + else contextlib.nullcontext() + ) + + +class Profiler: + """ + Wrapper around PyTorch profiler with scheduled profiling. + + Implements a wait-warmup-active schedule for controlled profiling that + avoids startup noise and captures representative performance data. + + Args: + rank: Process rank for trace file naming. + active: Number of active profiling steps (default: 50). + """ + + def __init__(self, rank, active: int = 50) -> None: + self.rank = rank + self._profiler: profiler.profile = torch.profiler.profile( + schedule=torch.profiler.schedule( + wait=10, + warmup=20, + active=active, + repeat=1, + ), + on_trace_ready=_on_trace_ready_fn(self.rank), + # pyre-fixme[16]: Module `profiler` has no attribute `ProfilerActivity`. + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=True, + profile_memory=False, + with_stack=False, + with_flops=False, + with_modules=False, + ) + + def step(self) -> None: + """Advance the profiler to the next step.""" + self._profiler.step() + + +@gin.configurable +class MetricsLogger: + """ + Logger for tracking and computing recommendation metrics. + + Supports both classification metrics (NE, Accuracy, GAUC) and regression + metrics (MSE, MAE) based on multitask configuration. + + Args: + multitask_configs: List of task configurations defining metric types. + batch_size: Batch size for metric computation. + window_size: Window size for running metric aggregation. + device: Device to place metric tensors on. + rank: Process rank for distributed training. + tensorboard_log_path: Optional path for TensorBoard logging. + """ + + def __init__( + self, + multitask_configs: List[TaskConfig], + batch_size: int, + window_size: int, + device: torch.device, + rank: int, + tensorboard_log_path: str = "", + ) -> None: + self.multitask_configs: List[TaskConfig] = multitask_configs + all_classification_tasks: List[str] = [ + task.task_name + for task in self.multitask_configs + if task.task_type != MultitaskTaskType.REGRESSION + ] + all_regression_tasks: List[str] = [ + task.task_name + for task in self.multitask_configs + if task.task_type == MultitaskTaskType.REGRESSION + ] + assert all_classification_tasks + all_regression_tasks == [ + task.task_name for task in multitask_configs + ] + self.task_names: List[str] = all_classification_tasks + all_regression_tasks + + self.class_metrics: Dict[str, List[RecMetricComputation]] = { + "train": [], + "eval": [], + } + if all_classification_tasks: + for mode in ["train", "eval"]: + self.class_metrics[mode].append( + NEMetricComputation( + my_rank=rank, + batch_size=batch_size, + n_tasks=len(all_classification_tasks), + window_size=window_size, + ).to(device) + ) + self.class_metrics[mode].append( + AccuracyMetricComputation( + my_rank=rank, + batch_size=batch_size, + n_tasks=len(all_classification_tasks), + window_size=window_size, + ).to(device) + ) + self.class_metrics[mode].append( + GAUCMetricComputation( + my_rank=rank, + batch_size=batch_size, + n_tasks=len(all_classification_tasks), + window_size=window_size, + ).to(device) + ) + + self.regression_metrics: Dict[str, List[RecMetricComputation]] = { + "train": [], + "eval": [], + } + if all_regression_tasks: + for mode in ["train", "eval"]: + self.regression_metrics[mode].append( + MSEMetricComputation( + my_rank=rank, + batch_size=batch_size, + n_tasks=len(all_regression_tasks), + window_size=window_size, + 
).to(device) + ) + self.regression_metrics[mode].append( + MAEMetricComputation( + my_rank=rank, + batch_size=batch_size, + n_tasks=len(all_regression_tasks), + window_size=window_size, + ).to(device) + ) + + self.global_step: Dict[str, int] = {"train": 0, "eval": 0} + self.tb_logger: Optional[SummaryWriter] = None + if tensorboard_log_path != "": + self.tb_logger = SummaryWriter(log_dir=tensorboard_log_path, purge_step=0) + self.tb_logger.flush() + + @property + def all_metrics(self) -> Dict[str, List[RecMetricComputation]]: + """ + Get all metrics for train and eval modes. + + Returns: + Dictionary mapping mode ('train'/'eval') to list of metric computations. + """ + return { + "train": self.class_metrics["train"] + self.regression_metrics["train"], + "eval": self.class_metrics["eval"] + self.regression_metrics["eval"], + } + + def update( + self, + predictions: torch.Tensor, + weights: torch.Tensor, + labels: torch.Tensor, + num_candidates: torch.Tensor, + mode: str = "train", + ) -> None: + """ + Update metrics with new batch of predictions and labels. + + Args: + predictions: Model prediction tensor. + weights: Sample weight tensor. + labels: Ground truth label tensor. + num_candidates: Number of candidates per sample (for GAUC). + mode: Either 'train' or 'eval'. + """ + for metric in self.all_metrics[mode]: + if isinstance(metric, GAUCMetricComputation): + metric.update( + predictions=predictions, + labels=labels, + weights=weights, + num_candidates=num_candidates, + ) + else: + metric.update( + predictions=predictions, + labels=labels, + weights=weights, + ) + self.global_step[mode] += 1 + + def compute(self, mode: str = "train") -> Dict[str, float]: + """ + Compute and return all metrics for the current window. + + Args: + mode: Either 'train' or 'eval'. + + Returns: + Dictionary mapping metric names to their computed values. + """ + all_computed_metrics = {} + + for metric in self.all_metrics[mode]: + computed_metrics = metric.compute() + for computed in computed_metrics: + all_values = computed.value.cpu() + for i, task_name in enumerate(self.task_names): + key = f"metric/{str(computed.metric_prefix) + str(computed.name)}/{task_name}" + all_computed_metrics[key] = all_values[i] + + logger.info( + f"{mode} - Step {self.global_step[mode]} metrics: {all_computed_metrics}" + ) + return all_computed_metrics + + def compute_and_log( + self, + mode: str = "train", + additional_logs: Optional[Dict[str, Dict[str, torch.Tensor]]] = None, + ) -> Dict[str, float]: + """ + Compute metrics and log to TensorBoard. + + Args: + mode: Either 'train' or 'eval'. + additional_logs: Optional additional data to log. + + Returns: + Dictionary mapping metric names to their computed values. + + Raises: + AssertionError: If TensorBoard logger is not configured. + """ + assert self.tb_logger is not None + all_computed_metrics = self.compute(mode) + for k, v in all_computed_metrics.items(): + self.tb_logger.add_scalar( # pyre-ignore [16] + f"{mode}_{k}", + v, + global_step=self.global_step[mode], + ) + + if additional_logs is not None: + for tag, data in additional_logs.items(): + for data_name, data_value in data.items(): + self.tb_logger.add_scalar( + f"{tag}/{mode}_{data_name}", + data_value.detach().clone().cpu(), + global_step=self.global_step[mode], + ) + return all_computed_metrics + + def reset(self, mode: str = "train"): + """ + Reset all metrics for a given mode. + + Args: + mode: Either 'train' or 'eval'. 
+ """ + for metric in self.all_metrics[mode]: + metric.reset() + + +# the datasets we support +SUPPORTED_DATASETS = [ + "streaming-100b", + "sampled-streaming-100b", +] + + +@gin.configurable +def get_dataset(name: str, new_path_prefix: str = ""): + """ + Get dataset class and configuration by name. + + Args: + name: Dataset identifier (must be in SUPPORTED_DATASETS). + new_path_prefix: Optional prefix to prepend to data paths. + + Returns: + Tuple of (dataset_class, kwargs_dict) for dataset instantiation. + + Raises: + AssertionError: If dataset name is not supported. + """ + assert name in SUPPORTED_DATASETS, f"dataset {name} not supported" + if name == "streaming-100b": + return ( + DLRMv3SyntheticStreamingDataset, + { + "ratings_file_prefix": os.path.join( + new_path_prefix, "" + ), + "train_ts": 90, + "total_ts": 100, + "num_files": 100, + "num_users": 5_000_000, + "num_items": 1_000_000_000, + "num_categories": 128, + }, + ) + if name == "sampled-streaming-100b": + return ( + DLRMv3SyntheticStreamingDataset, + { + "ratings_file_prefix": os.path.join( + new_path_prefix, "sampled_data/" + ), + "train_ts": 90, + "total_ts": 100, + "num_files": 1, + "num_users": 50_000, + "num_items": 1_000_000_000, + "num_categories": 128, + }, + )
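
A minimal sketch of driving the `MetricsLogger` defined in `utils.py` directly, outside the LoadGen harness. The `TaskConfig` field names and the `MultitaskTaskType.BINARY_CLASSIFICATION` member are assumptions about the `generative_recommenders` API; the single `is_click` task and the hand-written tensors are purely illustrative:
```
# Hypothetical usage sketch; not part of the reference implementation.
import torch
from generative_recommenders.modules.multitask_module import (
    MultitaskTaskType,
    TaskConfig,
)
from utils import MetricsLogger

metrics = MetricsLogger(
    multitask_configs=[
        TaskConfig(  # field names assumed
            task_name="is_click",
            task_weight=1,
            task_type=MultitaskTaskType.BINARY_CLASSIFICATION,  # member name assumed
        )
    ],
    batch_size=4,
    window_size=1024,
    device=torch.device("cpu"),
    rank=0,
)

# One fake batch; shapes follow torchrec's (n_tasks, batch_size) convention.
predictions = torch.tensor([[0.2, 0.8, 0.4, 0.6]])
labels = torch.tensor([[0.0, 1.0, 0.0, 1.0]])
weights = torch.ones(1, 4)
num_candidates = torch.tensor([4])  # consumed by the GAUC computation

metrics.update(
    predictions=predictions,
    weights=weights,
    labels=labels,
    num_candidates=num_candidates,
    mode="eval",
)
print(metrics.compute(mode="eval"))  # NE / accuracy / GAUC keyed per task
```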