
feat(ctx): define data attribution for AnalogContext with read-only protection #765

Open
Zhaoxian-Wu wants to merge 4 commits into IBM:master from Zhaoxian-Wu:feat/AnalogCtx-attribution

Conversation


@Zhaoxian-Wu Zhaoxian-Wu commented Mar 26, 2026

Note: Supersedes #717. The original PR was developed on the fork's master branch, which made it difficult to keep clean and rebase. This PR continues the same work on a dedicated feature branch (feat/AnalogCtx-attribution) for a cleaner history. All reviewer feedback from #717 has been addressed — see the detailed response below.

Problem

AnalogContext is exposed as an nn.Parameter, but its .data is a dummy scalar tensor. This means standard tensor operations produce wrong or meaningless results:

from aihwkit.nn import AnalogLinear
model = AnalogLinear(6, 4, False)

ctx = next(model.parameters())
ctx.size()      # torch.Size([])     — expected (out, in)
ctx.norm()      # tensor(1.)         — meaningless
ctx > 0         # tensor(False)      — wrong
ctx.nonzero()   # tensor([], size=(1, 0)) — wrong

This makes it impossible to inspect analog weights through the standard PyTorch parameter interface, which breaks compatibility with many training tools and analysis scripts.

Solution

Bind analog_ctx.data to the actual tile weights, so all read operations work naturally. At the same time, block in-place mutations by default to respect the physical constraints of analog devices.

Key changes

  1. as_ref parameter for get_weights() — Python tiles return a direct reference when as_ref=True; default (False) preserves the existing convention (detached CPU copy).

  2. _bind_shared_weights() for C++ tiles — Allocates a shared torch.Tensor and passes it to the C++ tile via set_shared_weights(), so both Python and C++ operate on the same memory with no explicit sync needed.

  3. ReadOnlyWeightView — A torch.Tensor subclass that blocks all in-place ops using PyTorch's trailing-underscore naming convention (future-proof, zero maintenance).

  4. Three-level readonly control — rpu_config.mapping.readonly_weights (per-layer), convert_to_analog(readonly=) (global), ctx.writable() (runtime).

Test results

  • test_analog_ctx: 123 passed, 27 skipped (pre-existing)
  • Full suite: 3763 passed, 0 regressions
  • A few pre-existing failures on master (Conv3d, RNN ~1e-4 numerical mismatches) are not introduced by this PR. Environment: 2x NVIDIA RTX PRO 6000 Blackwell (driver 580.95), CUDA 12.0, PyTorch 2.8.0+cu128, cuDNN 91002.

Response to #717 review

Hi @maljoras-sony and @PabloCarmona, thanks for your patience — I know it's been a while since the original review, and I really appreciate you coming back to this discussion. I've taken the time to carefully address all the concerns raised.

@maljoras-sony, your review raised three key concerns, and this update addresses all of them. Feel free to let me know if any concerns remain.

Concern 1: Out-of-sync weights for C++ tiles

"Note that this will only be a copy of the current weights. So if you update the weights (using RPUCuda) the analog_ctx.data will not be synchronized correctly with the actual weight."

Root cause: C++ tiles store weights in C++ memory. get_weights() can only return a copy — any as_ref=True approach that works for Python tiles does not apply here.

Solution — _bind_shared_weights(): At tile construction time, we allocate a torch.Tensor on the Python side and pass it to the C++ tile via set_shared_weights(). After this call, both Python and C++ operate on the same memory:

Python:  analog_ctx.data ──→ _shared_weight_tensor ←── C++ tile internal storage
                              (same data_ptr)
  • tile.update() / tile.set_weights() modify the tensor in-place — no explicit sync needed during normal training
  • Device moves (cpu() / cuda()) invalidate the old pointer and rebind via _bind_shared_weights()
  • __getstate__ / __setstate__ skip the shared tensor (rebuilt on load)
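The shared-storage idea behind _bind_shared_weights() can be illustrated framework-free. In this sketch, FakeTile and its methods are illustrative names standing in for the C++ tile (not the aihwkit API): both sides hold the same buffer, so an in-place update by the "tile" is immediately visible on the "Python" side with no explicit synchronization.

```python
# Framework-free sketch of shared weight storage (illustrative, assumes
# nothing about aihwkit internals beyond the shared-memory idea).
import array

class FakeTile:
    """Stands in for the C++ tile; it keeps a reference, not a copy."""
    def __init__(self):
        self._weights = None

    def set_shared_weights(self, buf):
        self._weights = buf          # same object, same memory

    def update(self):
        for i in range(len(self._weights)):
            self._weights[i] += 1.0  # in-place, as a training step would be

shared = array.array("d", [0.0, 0.0, 0.0])  # allocated on the "Python side"
tile = FakeTile()
tile.set_shared_weights(shared)

tile.update()                # the "tile" mutates the buffer in place
print(list(shared))          # → [1.0, 1.0, 1.0]; Python sees it immediately
```

Device moves are then just a matter of allocating a fresh buffer and calling set_shared_weights() again, which is what the rebinding step above does.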

Concern 2: Unintended weight modification

"we do not encourage users to fiddle with the analog weights, which are not accessible directly in reality."

Solution — ReadOnlyWeightView: A torch.Tensor subclass that blocks all in-place mutations. Instead of a fragile blocklist, we use PyTorch's own naming convention — all in-place ops end with _ — so the guard is simply func_name.endswith('_'). This automatically covers any new ops PyTorch adds in the future.

Three levels of control (all optional, default is read-only):

| Level | API | Use case |
| --- | --- | --- |
| Per-layer | rpu_config.mapping.readonly_weights | Fine-grained via specific_rpu_config_fun |
| Global | convert_to_analog(readonly=False) | Quick toggle for research |
| Runtime | ctx.writable() context manager | Temporary access in a code block |

from aihwkit.nn import AnalogLinear
model = AnalogLinear(6, 4, False)
ctx = next(model.parameters())
ctx.data.norm()        # ✅ reads work
ctx.data.add_(1.0)     # ❌ RuntimeError

with ctx.writable():
    ctx.data.add_(1.0) # ✅ explicit opt-in
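The writable() switch can be sketched as a small context manager that flips a readonly flag and restores it on exit. CtxSketch and its attribute are illustrative names, not the actual AnalogContext implementation:

```python
# Minimal sketch of a writable()-style switch (names are assumptions,
# not aihwkit API): opt in to mutation only inside the with-block.
from contextlib import contextmanager

class CtxSketch:
    def __init__(self):
        self.readonly = True       # default: weights are read-only

    @contextmanager
    def writable(self):
        prev = self.readonly
        self.readonly = False      # explicit opt-in to mutation
        try:
            yield self
        finally:
            self.readonly = prev   # restored even if the block raises

ctx = CtxSketch()
with ctx.writable():
    assert not ctx.readonly        # mutations allowed here
assert ctx.readonly                # back to read-only afterwards
```

The try/finally guarantees the flag is restored even when the body raises, so a failed experiment cannot accidentally leave the weights writable.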

Concern 3: Breaking the get_weights convention

"the current convention is that get_weights always returns CPU weights... Moreover, get_weights will always produce a copy (without the backward trace) by design, to avoid implicit things that cannot be done with analog weights."

We preserve this convention fully. The public get_weights() API still returns a detached CPU copy by default (as_ref=False). The as_ref=True path is strictly internal — used only by _get_tile_weights_ref() and _bind_shared_weights() to set up the shared storage between analog_ctx.data and the tile. Users calling tile.get_weights() see no behavior change.

# Public API — unchanged, returns detached CPU copy
w, b = analog_tile.get_weights()     # as_ref=False (default)
w.add_(1.0)                          # safe — this is a copy, tile is unaffected

# Internal only — used to bind analog_ctx.data
ref = tile.tile.get_weights(as_ref=True)  # direct reference to tile storage
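The copy-versus-reference split can be sketched as follows; TileSketch and _weight are hypothetical names used only to illustrate the dispatch, not the aihwkit tile implementation:

```python
# Illustrative sketch of the as_ref dispatch for a Python tile (assumed
# names, not aihwkit code): default preserves the copy convention,
# as_ref=True hands back the internal storage itself.
class TileSketch:
    def __init__(self, weight):
        self._weight = list(weight)    # internal weight storage

    def get_weights(self, as_ref=False):
        if as_ref:
            return self._weight        # direct reference (internal use only)
        return list(self._weight)      # independent copy (public convention)

tile = TileSketch([1.0, 2.0])
copy_w = tile.get_weights()
copy_w[0] = 99.0                       # mutating the copy leaves the tile intact
ref_w = tile.get_weights(as_ref=True)
ref_w[0] = 99.0                        # mutating the reference changes the tile
```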

@PabloCarmona, could you take a look at these changes when you have a chance? I'd really appreciate any comments or suggestions. If the direction looks good, it would be great to move toward merging — happy to make any further adjustments needed. Thanks!

@Zhaoxian-Wu Zhaoxian-Wu force-pushed the feat/AnalogCtx-attribution branch 2 times, most recently from d4695cb to da36bc9 Compare March 29, 2026 03:09

PabloCarmona commented Mar 30, 2026

@Zhaoxian-Wu, please check the lint errors here: https://github.com/IBM/aihwkit/actions/runs/23700170869/job/69154808771?pr=765 and push again. Run the related make commands in your local environment first so you don't hit those failures after pushing, thanks!

And don't forget to sign off your commits to pass the DCO check! 😉

@PabloCarmona

Hello @maljoras @maljoras-sony, can you take a look at this one? Thanks!

Zhaoxian-Wu and others added 4 commits April 1, 2026 18:38
Signed-off-by: Zhaoxian Wu <wuzhaoxian97@gmail.com>
…r AnalogContext

- weight data management
- support checkpoint loading for old version toolkit
- sync analog_ctx and tile when recreating

Signed-off-by: Zhaoxian Wu <wuzhaoxian97@gmail.com>
…nalogCtx tests

- Fix TileModuleArray.get_weights() returning self.bias (Parameter) instead of
  self.bias.data (Tensor), which caused TypeError when Conv layers re-assign
  bias as a bool in reset_parameters.
- Add test_analog_ctx.py verifying PR IBM#717 AnalogContext data attribution:
  correct shape, norm, nonzero, comparison ops, CUDA support, backward
  compatibility with old checkpoints, and convert_to_analog.

Signed-off-by: Zhaoxian-Wu <wuzhaoxian97@gmail.com>
Signed-off-by: Zhaoxian Wu <wuzhaoxian97@gmail.com>
…weights

Add as_ref parameter to tile-level get_weights() across all simulator
tiles. When as_ref=True, return a direct reference to internal weight
storage; when False (default), return an independent copy.

For C++ tiles (RPUCuda) where get_weights() can only return a copy,
add _bind_shared_weights() to allocate a torch.Tensor and pass it to
the C++ tile via set_shared_weights() — both Python and C++ then
operate on the same memory with no explicit sync needed.

Add ReadOnlyWeightView (Tensor subclass) to prevent accidental
in-place modification of analog weights. Uses PyTorch's trailing-
underscore naming convention to block all in-place ops (future-proof).
Configurable via AnalogContext.readonly flag, writable() context
manager, MappingParameter.readonly_weights, and convert_to_analog
readonly parameter.

Signed-off-by: Zhaoxian Wu <wuzhaoxian97@gmail.com>
@Zhaoxian-Wu Zhaoxian-Wu force-pushed the feat/AnalogCtx-attribution branch from da36bc9 to abd3578 Compare April 1, 2026 22:38

Zhaoxian-Wu commented Apr 6, 2026

Hi @PabloCarmona, oddly, my local test results differ from the CI results — all tests and code-style checks pass on my end. Could you provide a short guide on how to reproduce the CI run locally? Ideally it would start with the environment setup, since I suspect a version mismatch.


PabloCarmona commented Apr 6, 2026

@Zhaoxian-Wu, as far as I can see, this PR changes code related to our custom simulator tiles, touching the CUDA, Python, and C++ code, and that can lead to failures in the tests covering them, like the one pointed out here in the logs:

FAILED tests/test_torch_tiles.py::test_discretization_behavior[BoundManagementType.NONE--1--1-10.0-10.0] - assert False
 +  where False = allclose(tensor([[10.]], grad_fn=<ClampBackward1>), tensor([[-8.9111]], grad_fn=<AnalogFunctionBackward>), atol=1e-05)

Could you review whether that might be the case, since you are introducing changes to the related method arguments and/or behavior? If you have any more doubts we can help with, let me know, thanks!
