Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions benchmark/data/all_benchmark_data.csv
Original file line number Diff line number Diff line change
Expand Up @@ -1703,3 +1703,59 @@ llama4_rope,huggingface,full,memory,MB,T,sequence length,2048,314.01611328125,31
llama4_rope,huggingface,full,memory,MB,T,sequence length,4096,596.03173828125,596.03173828125,596.03173828125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1
llama4_rope,huggingface,full,memory,MB,T,sequence length,8192,1160.06298828125,1160.06298828125,1160.06298828125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1
llama4_rope,huggingface,full,memory,MB,T,sequence length,16384,2288.12548828125,2288.12548828125,2288.12548828125,"{""dtype"": ""torch.bfloat16"", ""hidden_size"": 8192, ""num_q_heads"": 32, ""num_kv_heads"": 8}",NVIDIA H100 80GB HBM3,2025-08-07 21:42:21,0.6.1
rms_norm,liger,forward,speed,ms,H,hidden size,1024,0.012608000077307224,0.01228800043463707,0.013311999849975109,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:22,0.6.0
rms_norm,liger,forward,speed,ms,H,hidden size,2048,0.023552000522613525,0.02252800017595291,0.023552000522613525,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:22,0.6.0
rms_norm,liger,forward,speed,ms,H,hidden size,4096,0.04198399931192398,0.04193919897079468,0.04198399931192398,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:22,0.6.0
rms_norm,liger,forward,speed,ms,H,hidden size,8192,0.08191999793052673,0.08089599758386612,0.08191999793052673,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:22,0.6.0
rms_norm,liger,forward,speed,ms,H,hidden size,16384,0.16076800227165222,0.15987199544906616,0.16150400042533875,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:22,0.6.0
rms_norm,liger,forward,speed,ms,H,hidden size,32768,0.4485119879245758,0.44543999433517456,0.45158401131629944,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:22,0.6.0
rms_norm,liger,forward,speed,ms,H,hidden size,65536,3.617824077606201,3.6136960983276367,3.6296703815460205,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:22,0.6.0
rms_norm,huggingface,forward,speed,ms,H,hidden size,1024,0.11366400122642517,0.11366400122642517,0.11468800157308578,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:26,0.6.0
rms_norm,huggingface,forward,speed,ms,H,hidden size,2048,0.20684799551963806,0.20582400262355804,0.20684799551963806,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:26,0.6.0
rms_norm,huggingface,forward,speed,ms,H,hidden size,4096,0.39321601390838623,0.3929600119590759,0.39423999190330505,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:26,0.6.0
rms_norm,huggingface,forward,speed,ms,H,hidden size,8192,0.7639039754867554,0.7628800272941589,0.7649279832839966,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:26,0.6.0
rms_norm,huggingface,forward,speed,ms,H,hidden size,16384,1.5032320022583008,1.5022079944610596,1.5052800178527832,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:26,0.6.0
rms_norm,huggingface,forward,speed,ms,H,hidden size,32768,2.9818880558013916,2.979840040206909,2.984665632247925,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:26,0.6.0
rms_norm,huggingface,forward,speed,ms,H,hidden size,65536,5.941247940063477,5.938176155090332,5.954355239868164,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:26,0.6.0
rms_norm,liger,full,speed,ms,H,hidden size,1024,0.6062080264091492,0.5888000130653381,0.6339135766029358,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:30,0.6.0
rms_norm,liger,full,speed,ms,H,hidden size,2048,0.6021119952201843,0.5847616195678711,0.6299136281013489,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:30,0.6.0
rms_norm,liger,full,speed,ms,H,hidden size,4096,0.6082559823989868,0.5939711928367615,0.6330047845840454,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:30,0.6.0
rms_norm,liger,full,speed,ms,H,hidden size,8192,0.6133760213851929,0.5920000076293945,0.6430720090866089,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:30,0.6.0
rms_norm,liger,full,speed,ms,H,hidden size,16384,0.6082559823989868,0.6072319746017456,0.626035213470459,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:30,0.6.0
rms_norm,liger,full,speed,ms,H,hidden size,32768,3.246112108230591,3.240755081176758,3.2557055950164795,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:30,0.6.0
rms_norm,liger,full,speed,ms,H,hidden size,65536,12.116991996765137,12.101632118225098,12.507136344909668,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:30,0.6.0
rms_norm,huggingface,full,speed,ms,H,hidden size,1024,0.6912000179290771,0.6686336398124695,0.7167999744415283,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:34,0.6.0
rms_norm,huggingface,full,speed,ms,H,hidden size,2048,0.8130559921264648,0.8120319843292236,0.814079999923706,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:34,0.6.0
rms_norm,huggingface,full,speed,ms,H,hidden size,4096,1.5523840188980103,1.5508607625961304,1.5544320344924927,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:34,0.6.0
rms_norm,huggingface,full,speed,ms,H,hidden size,8192,3.018752098083496,3.0158207416534424,3.0257151126861572,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:34,0.6.0
rms_norm,huggingface,full,speed,ms,H,hidden size,16384,5.935103893280029,5.933055877685547,5.961728096008301,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:34,0.6.0
rms_norm,huggingface,full,speed,ms,H,hidden size,32768,11.782511711120605,11.776410102844238,12.265472412109375,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:34,0.6.0
rms_norm,huggingface,full,speed,ms,H,hidden size,65536,23.446529388427734,23.42975425720215,24.190053939819336,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:34,0.6.0
rms_norm,liger,backward,speed,ms,H,hidden size,1024,0.05222399905323982,0.031007999554276466,0.22066561877727509,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:38,0.6.0
rms_norm,liger,backward,speed,ms,H,hidden size,2048,0.04198399931192398,0.04095999896526337,0.1603199690580368,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:38,0.6.0
rms_norm,liger,backward,speed,ms,H,hidden size,4096,0.07168000191450119,0.07168000191450119,0.2099200040102005,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:38,0.6.0
rms_norm,liger,backward,speed,ms,H,hidden size,8192,0.13516800105571747,0.13414399325847626,0.13619199395179749,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:38,0.6.0
rms_norm,liger,backward,speed,ms,H,hidden size,16384,0.4505600035190582,0.449535995721817,0.45260798931121826,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:38,0.6.0
rms_norm,liger,backward,speed,ms,H,hidden size,32768,2.8037118911743164,2.7996160984039307,2.811903953552246,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:38,0.6.0
rms_norm,liger,backward,speed,ms,H,hidden size,65536,8.493056297302246,8.485785484313965,8.521522521972656,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:38,0.6.0
rms_norm,huggingface,backward,speed,ms,H,hidden size,1024,0.32972800731658936,0.32870399951934814,0.3307519853115082,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:41,0.6.0
rms_norm,huggingface,backward,speed,ms,H,hidden size,2048,0.6082559823989868,0.6072319746017456,0.609279990196228,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:41,0.6.0
rms_norm,huggingface,backward,speed,ms,H,hidden size,4096,1.1612160205841064,1.1601344347000122,1.1624319553375244,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:41,0.6.0
rms_norm,huggingface,backward,speed,ms,H,hidden size,8192,2.2568960189819336,2.254848003387451,2.2599680423736572,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:41,0.6.0
rms_norm,huggingface,backward,speed,ms,H,hidden size,16384,4.436992168426514,4.433919906616211,4.460339069366455,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:41,0.6.0
rms_norm,huggingface,backward,speed,ms,H,hidden size,32768,8.804351806640625,8.800000190734863,9.1975679397583,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:41,0.6.0
rms_norm,huggingface,backward,speed,ms,H,hidden size,65536,17.515518188476562,17.501888275146484,18.062015533447266,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:41,0.6.0
rms_norm,liger,full,memory,MB,H,hidden size,1024,12.34375,12.34375,12.34375,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:42,0.6.0
rms_norm,liger,full,memory,MB,H,hidden size,2048,24.6796875,24.6796875,24.6796875,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:42,0.6.0
rms_norm,liger,full,memory,MB,H,hidden size,4096,49.3515625,49.3515625,49.3515625,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:42,0.6.0
rms_norm,liger,full,memory,MB,H,hidden size,8192,98.6953125,98.6953125,98.6953125,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:42,0.6.0
rms_norm,liger,full,memory,MB,H,hidden size,16384,197.3828125,197.3828125,197.3828125,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:42,0.6.0
rms_norm,liger,full,memory,MB,H,hidden size,32768,394.7578125,394.7578125,394.7578125,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:42,0.6.0
rms_norm,liger,full,memory,MB,H,hidden size,65536,789.5078125,789.5078125,789.5078125,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:42,0.6.0
rms_norm,huggingface,full,memory,MB,H,hidden size,1024,80.01953125,80.01953125,80.01953125,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:42,0.6.0
rms_norm,huggingface,full,memory,MB,H,hidden size,2048,160.03125,160.03125,160.03125,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:42,0.6.0
rms_norm,huggingface,full,memory,MB,H,hidden size,4096,320.0546875,320.0546875,320.0546875,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:42,0.6.0
rms_norm,huggingface,full,memory,MB,H,hidden size,8192,640.1015625,640.1015625,640.1015625,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:42,0.6.0
rms_norm,huggingface,full,memory,MB,H,hidden size,16384,1280.1953125,1280.1953125,1280.1953125,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:42,0.6.0
rms_norm,huggingface,full,memory,MB,H,hidden size,32768,2560.3828125,2560.3828125,2560.3828125,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:42,0.6.0
rms_norm,huggingface,full,memory,MB,H,hidden size,65536,5120.7578125,5120.7578125,5120.7578125,"{""M"": 2048, ""dtype"": ""torch.bfloat16"", ""eps"": 1e-06}",NVIDIA GeForce RTX 3090,2025-07-12 18:21:42,0.6.0
2 changes: 1 addition & 1 deletion benchmark/scripts/benchmark_rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def full():
"kernel_name": "rms_norm",
"x_name": "H",
"x_label": "hidden size",
"x_values": [2**i for i in range(10, 16)],
"x_values": [2**i for i in range(10, 17)],
"kernel_providers": ["liger", "huggingface"],
"extra_benchmark_configs": [{"M": 2048, "dtype": torch.bfloat16, "eps": 1e-6}],
"overwrite": args.overwrite,
Expand Down
2 changes: 1 addition & 1 deletion benchmark/scripts/benchmark_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def full():
kernel_name="softmax",
x_name="N",
x_label="hidden size",
x_values=[128, 256, 512, 1024, 2048, 4096],
x_values=[2**i for i in range(8, 17)],
kernel_providers=["liger", "torch"],
extra_benchmark_configs=[
{"M": 2048, "dtype": torch.float32},
Expand Down
26 changes: 22 additions & 4 deletions src/liger_kernel/ops/rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import triton
import triton.language as tl

from liger_kernel.ops.utils import calculate_num_stages
from liger_kernel.ops.utils import calculate_settings
from liger_kernel.ops.utils import compare_version
from liger_kernel.ops.utils import ensure_contiguous
Expand Down Expand Up @@ -139,7 +140,7 @@ def _rms_norm_backward_kernel(

row_block_id = tl.program_id(0).to(tl.int64)
row_start = row_block_id * rows_per_program
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
row_end = tl.minimum((row_block_id + 1) * rows_per_program, n_rows)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols

Expand Down Expand Up @@ -312,7 +313,7 @@ def _block_rms_norm_backward_kernel(
W_row = tl.load(W_ptr + col_offsets, mask=col_mask, other=0.0)
W_row = W_row + offset

for start in range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW):
for start in tl.range(pid * BLOCK_ROW, n_rows, NUM_SMS * BLOCK_ROW, num_stages=2):
row_idx = start + tl.arange(0, BLOCK_ROW)
row_mask = row_idx < n_rows
dY_row = tl.load(
Expand Down Expand Up @@ -382,6 +383,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
X = X.view(-1, dim)
n_rows, n_cols = X.shape
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
num_stages = calculate_num_stages()

Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
# RSTD is to cache rstd for each row
Expand Down Expand Up @@ -412,6 +414,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
casting_mode,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
num_stages=num_stages,
**kernel_args, # XPU-specific optimization
)
else:
Expand All @@ -433,12 +436,13 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
casting_mode,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
num_stages=num_stages,
**kernel_args, # XPU-specific optimization
)
return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode


def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place, row_mode):
def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, num_stages, in_place, row_mode):
shape = dY.shape
dim = shape[-1]
dY = dY.view(-1, dim)
Expand Down Expand Up @@ -490,6 +494,7 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
casting_mode,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
num_stages=num_stages,
**kernel_args, # XPU-specific optimization
)
else:
Expand All @@ -516,6 +521,7 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
casting_mode,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
num_stages=num_stages,
**kernel_args, # XPU-specific optimization
)
dX = dX.view(*shape)
Expand Down Expand Up @@ -554,12 +560,14 @@ def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row
W: (H,)
"""
Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
num_stages = calculate_num_stages()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: We can just return num_stages from rms_norm_forward() like num_warps to avoid calling it again.

ctx.offset = offset
ctx.casting_mode = casting_mode
ctx.in_place = in_place
ctx.row_mode = row_mode
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.num_stages = num_stages
ctx.save_for_backward(X, W, RSTD)
return Y

Expand All @@ -571,6 +579,16 @@ def backward(ctx, dY):
"""
X, W, RSTD = ctx.saved_tensors
dX, dW = rms_norm_backward(
dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode
dY,
X,
W,
RSTD,
ctx.offset,
ctx.casting_mode,
ctx.BLOCK_SIZE,
ctx.num_warps,
ctx.num_stages,
ctx.in_place,
ctx.row_mode,
)
return dX, dW, None, None, None, None, None
Loading
Loading