feat(ctx): define data attribution for AnalogContext with read-only protection #765
Zhaoxian-Wu wants to merge 4 commits into IBM:master
Conversation
Force-pushed from d4695cb to da36bc9
@Zhaoxian-Wu, please check the lint errors here: https://github.com/IBM/aihwkit/actions/runs/23700170869/job/69154808771?pr=765 and push again. Run the related make commands in your local environment so you don't hit these errors after pushing, thanks! And don't forget to sign off your commits to pass the DCO check! 😉
Hello @maljoras @maljoras-sony, can you look at this one? Thanks!
Signed-off-by: Zhaoxian Wu <wuzhaoxian97@gmail.com>
…r AnalogContext
- weight data management
- support checkpoint loading from older toolkit versions
- sync analog_ctx and tile when recreating

Signed-off-by: Zhaoxian Wu <wuzhaoxian97@gmail.com>
…nalogCtx tests
- Fix TileModuleArray.get_weights() returning self.bias (a Parameter) instead of self.bias.data (a Tensor), which caused a TypeError when Conv layers re-assign bias as a bool in reset_parameters.
- Add test_analog_ctx.py verifying PR IBM#717 AnalogContext data attribution: correct shape, norm, nonzero, comparison ops, CUDA support, backward compatibility with old checkpoints, and convert_to_analog.

Signed-off-by: Zhaoxian-Wu <wuzhaoxian97@gmail.com>
Signed-off-by: Zhaoxian Wu <wuzhaoxian97@gmail.com>
…weights
Add an as_ref parameter to tile-level get_weights() across all simulator tiles. When as_ref=True, return a direct reference to the internal weight storage; when False (the default), return an independent copy.

For C++ tiles (RPUCuda), where get_weights() can only return a copy, add _bind_shared_weights() to allocate a torch.Tensor and pass it to the C++ tile via set_shared_weights(); both Python and C++ then operate on the same memory with no explicit sync needed.

Add ReadOnlyWeightView (a Tensor subclass) to prevent accidental in-place modification of analog weights. It uses PyTorch's trailing-underscore naming convention to block all in-place ops (future-proof). Configurable via the AnalogContext.readonly flag, the writable() context manager, MappingParameter.readonly_weights, and the convert_to_analog readonly parameter.

Signed-off-by: Zhaoxian Wu <wuzhaoxian97@gmail.com>
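The shared-weights binding described in this commit can be sketched with a small NumPy stand-in (the `CppTile` class and its methods below are illustrative assumptions, not the real torch/RPUCuda machinery):

```python
import numpy as np

class CppTile:
    """Hypothetical stand-in for an RPUCuda C++ tile that can adopt
    externally allocated storage via set_shared_weights()."""

    def __init__(self, out_size, in_size):
        self._weights = np.zeros((out_size, in_size))  # "C++-side" storage

    def set_shared_weights(self, buf):
        # Adopt the caller's buffer as the tile's weight storage.
        buf[...] = self._weights   # copy current weights into the buffer
        self._weights = buf        # from now on, same memory on both sides

    def update(self, delta):
        self._weights += delta     # in-place: visible through the shared buffer

def bind_shared_weights(tile, out_size, in_size):
    """Sketch of _bind_shared_weights(): allocate Python-side storage and
    hand it to the 'C++' tile so no explicit sync is ever needed."""
    shared = np.zeros((out_size, in_size))
    tile.set_shared_weights(shared)
    return shared

tile = CppTile(2, 3)
analog_ctx_data = bind_shared_weights(tile, 2, 3)
tile.update(np.ones((2, 3)))       # "training" updates on the C++ side
print(analog_ctx_data[0, 0])       # Python side sees 1.0 without any copy
```

The key design point this mimics: after binding, updates happen in-place on one buffer, so neither side ever calls a sync method.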
Force-pushed from da36bc9 to abd3578
Hi @PabloCarmona, it's weird that my local test results differ from the online ones. All tests and code-style checks pass on my end. Could you please provide a short guide on how to reproduce the tests myself? It would be best if the guide started with the environment setup, since I suspect a version mismatch.
@Zhaoxian-Wu, as far as I can see, this PR changes code related to our custom simulator tiles involving the CUDA, Python, and C code, and that can lead to failures in the tests related to them, like the one pointed out here in the logs. Can you check whether that could be the case, since you are introducing changes to the method arguments and/or behavior? If you have any more doubts we can help with, let me know again, thanks!
Problem
`AnalogContext` is exposed as an `nn.Parameter`, but its `.data` is a dummy scalar tensor. This means standard tensor operations produce wrong or meaningless results. It is therefore impossible to inspect analog weights through the standard PyTorch parameter interface, which breaks compatibility with many training tools and analysis scripts.
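A minimal NumPy sketch of the failure mode (the `AnalogCtx` class is an illustrative stand-in, not aihwkit's real class): a parameter-like object whose `data` is a dummy scalar reports nothing useful about the real weights, while binding `data` to the weight storage makes inspection work.

```python
import numpy as np

class AnalogCtx:
    """Illustrative stand-in for AnalogContext (not the real aihwkit class)."""
    def __init__(self, data):
        self.data = data

real_weights = np.arange(6.0).reshape(2, 3)  # weights live inside the tile

# Before: .data is a dummy scalar -- every query is meaningless.
ctx = AnalogCtx(np.zeros(()))
print(ctx.data.shape)              # () -- not (2, 3)
print(np.linalg.norm(ctx.data))    # 0.0 -- says nothing about the weights

# After: bind .data to the actual weight storage -- reads work naturally.
ctx.data = real_weights
print(ctx.data.shape)              # (2, 3)
print(np.count_nonzero(ctx.data))  # 5
```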
Solution
Bind `analog_ctx.data` to the actual tile weights, so all read operations work naturally. At the same time, block in-place mutations by default to respect the physical constraints of analog devices.

Key changes

- `as_ref` parameter for `get_weights()`: Python tiles return a direct reference when `as_ref=True`; the default (`False`) preserves the existing convention (detached CPU copy).
- `_bind_shared_weights()` for C++ tiles: allocates a shared `torch.Tensor` and passes it to the C++ tile via `set_shared_weights()`, so both Python and C++ operate on the same memory with no explicit sync needed.
- `ReadOnlyWeightView`: a `torch.Tensor` subclass that blocks all in-place ops using PyTorch's trailing-underscore naming convention (future-proof, zero maintenance).
- Three-level `readonly` control: `rpu_config.mapping.readonly_weights` (per-layer), `convert_to_analog(readonly=)` (global), `ctx.writable()` (runtime).

Test results
test_analog_ctx: 123 passed, 27 skipped. Pre-existing failures on master (Conv3d, RNN ~1e-4 numerical mismatches) are not introduced by this PR. Environment: 2x NVIDIA RTX PRO 6000 Blackwell (driver 580.95), CUDA 12.0, PyTorch 2.8.0+cu128, cuDNN 91002.

Response to #717 review
Hi @maljoras-sony and @PabloCarmona, thanks for your patience — I know it's been a while since the original review, and I really appreciate you coming back to this discussion. I've taken the time to carefully address all the concerns raised.
@maljoras-sony, your review raised three key concerns, and this update addresses all of them. Feel free to let me know if any concerns remain.
Concern 1: Out-of-sync weights for C++ tiles
Root cause: C++ tiles store weights in C++ memory. `get_weights()` can only return a copy, so any `as_ref=True` approach that works for Python tiles does not apply here.

Solution: `_bind_shared_weights()`. At tile construction time, we allocate a `torch.Tensor` on the Python side and pass it to the C++ tile via `set_shared_weights()`. After this call, both Python and C++ operate on the same memory:

- `tile.update()` / `tile.set_weights()` modify the tensor in-place, so no explicit sync is needed during normal training
- Device moves (`cpu()` / `cuda()`) invalidate the old pointer and rebind via `_bind_shared_weights()`
- `__getstate__` / `__setstate__` skip the shared tensor (rebuilt on load)

Concern 2: Unintended weight modification
Solution: `ReadOnlyWeightView`, a `torch.Tensor` subclass that blocks all in-place mutations. Instead of a fragile blocklist, we use PyTorch's own naming convention (all in-place ops end with `_`), so the guard is simply `func_name.endswith('_')`. This automatically covers any new ops PyTorch adds in the future.

Three levels of control (all optional, the default is read-only):

- `rpu_config.mapping.readonly_weights` / `specific_rpu_config_fun` (per-layer)
- `convert_to_analog(readonly=False)` (global)
- `ctx.writable()` context manager (runtime)
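The trailing-underscore guard and the `writable()` escape hatch can be sketched in plain Python (without PyTorch; `GuardedWeights` and its methods are illustrative, not aihwkit's real classes):

```python
from contextlib import contextmanager

class GuardedWeights:
    """Illustrative read-only wrapper: blocks any method whose name ends
    with '_', mirroring PyTorch's in-place naming convention."""

    def __init__(self, values):
        self._values = list(values)
        self.readonly = True

    def add_(self, x):          # in-place op (trailing underscore)
        self._check_writable("add_")
        self._values = [v + x for v in self._values]

    def add(self, x):           # out-of-place op: always allowed
        return [v + x for v in self._values]

    def _check_writable(self, func_name):
        # The whole guard: one naming-convention check, no blocklist.
        if self.readonly and func_name.endswith("_"):
            raise RuntimeError(f"{func_name} blocked: weights are read-only")

    @contextmanager
    def writable(self):
        """Temporarily allow in-place ops, like ctx.writable()."""
        self.readonly = False
        try:
            yield self
        finally:
            self.readonly = True

w = GuardedWeights([1.0, 2.0])
print(w.add(1.0))      # [2.0, 3.0] -- reads and copies always work
try:
    w.add_(1.0)        # blocked by default
except RuntimeError as e:
    print(e)
with w.writable():
    w.add_(1.0)        # explicitly allowed inside the context
print(w._values)       # [2.0, 3.0]
```

Because the check is a single `endswith('_')` test, any new in-place op that follows the naming convention is covered without touching the guard.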
get_weightsconventionWe preserve this convention fully. The public
get_weights()API still returns a detached CPU copy by default (as_ref=False). Theas_ref=Truepath is strictly internal — used only by_get_tile_weights_ref()and_bind_shared_weights()to set up the shared storage betweenanalog_ctx.dataand the tile. Users callingtile.get_weights()see no behavior change.@PabloCarmona, could you take a look at these changes when you have a chance? I'd really appreciate any comments or suggestions. If the direction looks good, it would be great to move toward merging — happy to make any further adjustments needed. Thanks!