
[feature] Non-zero gamma support in ChoppedTransferCompound#764

Open
Zhaoxian-Wu wants to merge 5 commits into IBM:master from Zhaoxian-Wu:feat/chopper-gamma

Conversation

@Zhaoxian-Wu

Non-zero gamma support, scale_fast_lr, and documentation

Overview

This PR makes three improvements to ChoppedTransferRPUDevice / ChoppedTransferCompound:

  1. Non-zero gamma support — ChoppedTransferCompound can now be used as a residual-learning device where the fast array A contributes directly to the effective weight, with correct chopper de-correlation applied during weight reduction.
  2. scale_fast_lr parameter — a new parameter, analogous to the existing scale_transfer_lr, that controls whether the fast-device LR tracks the current optimizer LR.
  3. Documentation update — a new paragraph discussing the benefits of gamma is added to the documentation. To provide sufficient context, this PR also expands the algorithm documentation to cover TTv1 through TTv4.

1. Non-zero gamma support in ChoppedTransferRPUDevice

Previously, checkSupported() enforced fullyHidden(), which hard-blocked any configuration where the fast array A contributes to the visible weight (gamma != 0). This restriction is lifted, and correct behaviour is implemented via a reduceToWeights override (CPU + CUDA).

Background: A is updated with per-element chopper sign flips and is therefore stored in "chopped" form: A_stored[i,j] ≈ c_d[i]·c_x[j]·A_true[i,j]. The base-class weight-reduction GEMV computes W = gamma·A_stored + C, which is incorrect when gamma != 0 because the chopper factors are not cancelled. The new override applies a correction after the GEMV:

W[i,j] += gamma * (c_d[i]*c_x[j] - 1) * A_stored[i,j]

The - 1 term accounts for the fact that the base GEMV already contributed gamma * A_stored; the correction adds only the remaining gamma * (c_d[i]*c_x[j] - 1) * A_stored to reach the correct gamma * (c_d[i]*c_x[j]) * A_stored. On CUDA this is implemented as the new kernelApplyChopperCorrectionToWeights kernel; on CPU it is a simple loop. Both paths are no-ops when gamma == 0 (the default).
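The correction step can be sketched in plain Python (a minimal model of the reduction, assuming lists-of-lists for the matrices; the function name and layout are illustrative, not the actual C++ interface):

```python
# Sketch of the chopper correction applied after the base-class GEMV.
# A_stored[i][j] = c_d[i] * c_x[j] * A_true[i][j], with chopper signs +/-1.

def reduce_to_weights(A_stored, C, c_d, c_x, gamma):
    rows, cols = len(A_stored), len(A_stored[0])
    # Base-class reduction: W = gamma * A_stored + C (wrong for gamma != 0,
    # because the chopper signs are still baked into A_stored).
    W = [[gamma * A_stored[i][j] + C[i][j] for j in range(cols)]
         for i in range(rows)]
    # Correction: add the missing gamma * (c_d*c_x - 1) * A_stored term.
    for i in range(rows):
        for j in range(cols):
            W[i][j] += gamma * (c_d[i] * c_x[j] - 1) * A_stored[i][j]
    return W
```

Because c_d[i]·c_x[j] squares to 1, the corrected result equals gamma·A_true + C, i.e. the chopper signs cancel exactly.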

Motivation: Non-zero gamma implements the residual learning mechanism described in Wu et al. (2025) [15] and Li et al. [18]: A acts as a residual correction on top of C, compensating for C's quantisation errors and device non-idealities cycle-by-cycle, while C accumulates the long-term gradient signal via discrete transfer pulses. The two-array decomposition W = gamma·A + C also enables bit-slicing (precision enhancement): A can represent finer-grained updates than C's native conductance step, reducing the effective weight granularity without modifying the underlying analog device.

Files: src/rpucuda/rpu_chopped_transfer_device.{cpp,h},
src/rpucuda/cuda/rpucuda_chopped_transfer_device.{cu,h}


2. scale_fast_lr parameter

scale_fast_lr is introduced as the analogue of the existing scale_transfer_lr: just as scale_transfer_lr controls whether transfer_lr is multiplied by the current optimizer LR, scale_fast_lr controls the same behaviour for fast_lr.

The parameter is added to TransferRPUDeviceMetaParameter (C++ base struct, default True) and exposed to the Python bindings and to the TransferCompound dataclass. ChoppedTransferCompound overrides the default to False, consistent with the existing convention for that device class.
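A minimal sketch of the intended semantics (the helper name is hypothetical; the real logic sits in the getPulseCountLearningRate implementations listed below):

```python
def effective_fast_lr(fast_lr: float, optimizer_lr: float,
                      scale_fast_lr: bool) -> float:
    """Mirror of scale_transfer_lr, applied to the fast device.

    When scale_fast_lr is True, the fast-device learning rate tracks the
    current optimizer LR; otherwise fast_lr is used as-is.
    """
    return fast_lr * optimizer_lr if scale_fast_lr else fast_lr
```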

The corresponding logic is implemented in:

  • TransferRPUDevice<T>::getPulseCountLearningRate (CPU)
  • TransferRPUDeviceCuda<T>::getPulseCountLearningRate (CUDA)
  • ChoppedTransferRPUDevice[Cuda]<T>::getPulseCountLearningRate,
    auto_scale branch (CPU + CUDA)

Files: src/rpucuda/rpu_transfer_device.{cpp,h},
src/rpucuda/cuda/rpucuda_transfer_device.cu,
src/rpucuda/rpu_chopped_transfer_device.cpp,
src/rpucuda/cuda/rpucuda_chopped_transfer_device.cu,
src/aihwkit/simulator/rpu_base_src/rpu_base_devices.cpp,
src/aihwkit/simulator/configs/compounds.py


3. Documentation

compounds.py — ChoppedTransferCompound docstring

A detailed pseudocode block is added to the ChoppedTransferCompound docstring to improve readability and serve as the authoritative reference for the internal LR-scaling logic. The block covers:

  • base_buffer_granularity — threshold calculation from buffer_granularity, dw_min_A, and the optional auto_granularity period scaling
  • final_fast_lr — derivation from fast_lr, scale_fast_lr, the fast_lr=0 fallback, and the auto_scale formula
    (base_fast_lr * desired_BL * dw_min_A / (x_max * d_max))
  • final_transfer_lr — both the default and correct_gradient_magnitudes branches
  • Recursion — numbered step-by-step listing of the full update: gradient site (W = gamma·A + C), chopper application, H accumulation, threshold test, pulse dispatch (C += n_steps·dw_min_C), and the forget_buffer / momentum interaction
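The recursion for a single weight element can be loosely sketched as follows (names such as transfer_step and chopper_sign are illustrative, not the actual C++ identifiers, and forget_buffer is reduced here to a simple reset of the transferred portion):

```python
def transfer_step(A, C, H, chopper_sign, buffer_granularity, dw_min_C,
                  transfer_lr, forget_buffer=True):
    # 1. Chopper-modulated read of the fast array into the buffer H.
    H += transfer_lr * chopper_sign * A
    # 2. Threshold test: count whole granularity steps accumulated in H
    #    (int() truncates toward zero, so negative H dispatches negative
    #    pulses).
    n_steps = int(H / buffer_granularity)
    if n_steps != 0:
        # 3. Pulse dispatch onto the slow array C in discrete steps.
        C += n_steps * dw_min_C
        # 4. forget_buffer: remove the transferred portion from H
        #    (simplified; the real option interacts with momentum).
        if forget_buffer:
            H -= n_steps * buffer_granularity
    return C, H
```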

docs/source/analog_update.rst

The algorithm overview is extended from three methods (Plain SGD, Mixed Precision, Tiki-taka) to the full TTv1-TTv4 family. New sections:

  • TTv1 formulation with residual-learning and bit-slicing discussion (non-zero gamma, Wu et al. [15], Li et al. [18])
  • TTv2 (Buffered Transfer) — floating-point H buffer between A and C
  • TTv3 / c-TTv2 (Chopped Buffered Transfer) — chopper-modulated reads
  • TTv4 / AGAD (Dynamic Chopped Transfer) — on-the-fly symmetric-point estimation

docs/source/using_simulator.rst

  • Compound device table extended with BufferedTransferCompound, ChoppedTransferCompound, and DynamicTransferCompound entries
  • Buffered Transfer Compound Device (TTv2) subsection added
  • Transfer Compound Device (TTv1) section extended with a residual-learning discussion and math formulas for non-zero gamma
  • Chopped Transfer Compound Device (TTv3) expanded with buffer-strategy comparison and a residual-learning subsection with code examples
  • Dynamic Transfer Compound Device (TTv4 / AGAD) subsection added

docs/source/paper_references.rst

Five new references added:

[15] Wu et al., NeurIPS 2025 — analog training on non-ideal devices
[16] Gokmen, Frontiers in Artificial Intelligence 2021 — TTv2 (buffered transfer)
[17] Rasch et al., Nature Communications 2024 — TTv3/TTv4 (c-TTv2 / AGAD)
[18] Li et al., AISTATS 2026 — residual learning with multi-array training

Testing

  • Non-zero gamma: set gamma=0.1 in a ChoppedTransferCompound config; verify that the tile's visible weight equals gamma·chop_corrected_A + C rather than gamma·A_stored + C.
  • scale_fast_lr: train with fast_lr > 0, scale_fast_lr=True, and a LR scheduler; verify the effective pulse-count LR tracks the optimizer LR on both CPU and CUDA, including with auto_scale=True.

Signed-off-by: Zhaoxian Wu <wuzhaoxian97@gmail.com>
@PabloCarmona
Collaborator

Hey @Zhaoxian-Wu, can you please address the lint errors and update the PR with a new commit, so we can check everything is right? Thanks!

Check out the errors here: https://github.com/IBM/aihwkit/actions/runs/23523015289/job/69153947890?pr=764

@PabloCarmona
Collaborator

Hello @maljoras @maljoras-sony can you take a look and help us here?

Bug fixes:
- CPU TransferRPUDevice::getPulseCountLearningRate now honours
  scale_fast_lr (was always returning raw fast_lr)
- CPU and CUDA ChoppedTransferRPUDevice::getPulseCountLearningRate
  now applies scale_fast_lr in the auto_scale branch (was missing)
- Remove duplicate auto_momentum line in printToStream

Feature: reduceToWeights for non-zero gamma
- Add ChoppedTransferRPUDevice::reduceToWeights (CPU) and
  ChoppedTransferRPUDeviceCuda::reduceToWeights (CUDA) that apply
  per-element chopper correction when gamma != 0:
    W[i,j] += gamma * (c_d[i]*c_x[j] - 1) * A_stored[i,j]
  Enables residual-learning configurations with ChoppedTransfer.
- New CUDA kernel: kernelApplyChopperCorrectionToWeights

Cleanup:
- Remove partial buffer_as_momentum field and its CUDA kernel
- Expose scale_fast_lr to Python TransferCompound (default True;
  ChoppedTransferCompound keeps its existing default False)

Docs:
- Rewrite ChoppedTransferCompound docstring: corrected
  base_buffer_granularity / final_fast_lr / final_transfer_lr
  formulas and full numbered recursion pseudocode
- analog_update.rst: add TTv2 / TTv3 / TTv4 / RL-v2 sections,
  residual-learning and bit-slicing discussion
- using_simulator.rst: add per-algorithm subsections (TTv2, TTv3,
  TTv4) and gamma residual-learning explanation with code examples
- paper_references.rst: add references [15]-[19]

Signed-off-by: Zhaoxian Wu <wuzhaoxian97@gmail.com>
@Zhaoxian-Wu
Author

> Hey @Zhaoxian-Wu, can you please address the lint errors and update the PR with a new commit, so we can check everything is right? Thanks!
>
> Check out the errors here: https://github.com/IBM/aihwkit/actions/runs/23523015289/job/69153947890?pr=764

Thanks for the reminder! I have fixed the lint and style errors in both PRs. Feel free to let me know of any other improvements.

@maljoras-sony
Contributor

@Zhaoxian-Wu looks like a great addition, many thanks. @PabloCarmona, Let me find some time over the weekend to take a closer look.

@Zhaoxian-Wu
Author

Hi @PabloCarmona ,

I noticed that the CI lint check is failing with the following mypy errors in src/aihwkit/simulator/tiles/periphery.py:

src/aihwkit/simulator/tiles/periphery.py:983: error: Expected iterable as variadic argument  [misc]
src/aihwkit/simulator/tiles/periphery.py:1009: error: Expected iterable as variadic argument  [misc]
src/aihwkit/simulator/tiles/periphery.py:1011: error: Expected iterable as variadic argument  [misc]

Oddly, these errors seem to exist on master as well — they are not introduced by this PR. Could you run the following on your end to confirm?

mypy --show-error-codes src/

Here is my speculation. The issue is in add_quant_periphery_bias(), where tensor_view: Optional[Tuple] is used as a variadic argument (*tensor_view) without a None guard. This was silently accepted by older torch type stubs, but appears to have been exposed by a recent torch release that ships stricter .pyi stubs for Tensor.view().

@PabloCarmona
Collaborator

Thanks @Zhaoxian-Wu, I will take a closer look and fix it in master. I'll let you know when I finish. In the meantime, let's also give time to @maljoras-sony to look at the PR and review. Thanks again to both!
