Standardize Backbone API for SSL compatibility and structural consistency#30
Standardize Backbone API for SSL compatibility and structural consistency#30
Conversation
|
You have reached your Codex usage limits for code reviews. You can see your limits in the Codex usage dashboard. |
There was a problem hiding this comment.
Pull request overview
This PR standardizes the BACKBONES interface across torch_ecg to better support SSL-style feature extraction by adding a consistent forward_features entry point and a corresponding feature-shape inference method.
Changes:
- Added
SizeMixin.compute_features_output_shape(...)as a default feature-shape inference hook (delegating tocompute_output_shape). - Implemented
forward_features(...)(+ per-backbonecompute_features_output_shape(...)) across the registered CNN backbones (ResNet, RegNet, VGG16, Xception, DenseNet, MobileNetV1/2/3, MultiScopicCNN). - Added a new unit test to validate the standardized backbone API.
Reviewed changes
Copilot reviewed 12 out of 12 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
| torch_ecg/utils/utils_nn.py | Adds default compute_features_output_shape on SizeMixin. |
| torch_ecg/models/cnn/xception.py | Adds forward_features and feature-shape inference method. |
| torch_ecg/models/cnn/vgg.py | Imports Tensor; adds forward_features and feature-shape inference method. |
| torch_ecg/models/cnn/resnet.py | Adds forward_features and feature-shape inference method. |
| torch_ecg/models/cnn/regnet.py | Imports Tensor; adds forward_features and feature-shape inference method. |
| torch_ecg/models/cnn/multi_scopic.py | Adds forward_features and feature-shape inference method. |
| torch_ecg/models/cnn/mobilenet.py | Adds forward_features and feature-shape inference method to V1/V2/V3. |
| torch_ecg/models/cnn/ho_resnet.py | Adds typed method stubs to align with standardized interface (still unimplemented). |
| torch_ecg/models/cnn/efficientnet.py | Adds typed method stubs to align with standardized interface (still unimplemented). |
| torch_ecg/models/cnn/densenet.py | Adds forward_features and feature-shape inference method. |
| torch_ecg/models/cnn/darknet.py | Adjusts base-class order; adds typed method stubs to align with standardized interface (still unimplemented). |
| test/test_models/test_backbone_api.py | Introduces a new parametrized test for the standardized backbone API. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
You can also share your feedback on Copilot code review. Take the survey.
torch_ecg/utils/utils_nn.py
Outdated
| def compute_features_output_shape( | ||
| self, seq_len: Optional[int] = None, batch_size: Optional[int] = None | ||
| ) -> Sequence[Union[int, None]]: | ||
| """Compute the output shape of the features. | ||
|
|
||
| By default, this is the same as the output shape of the model. | ||
| For backbones with pooling and classification heads, this should be | ||
| overridden to return the shape of the features before global pooling. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| seq_len : int, optional | ||
| Length of the input signal tensor. | ||
| batch_size : int, optional | ||
| Batch size of the input signal tensor. | ||
|
|
||
| Returns | ||
| ------- | ||
| output_shape : sequence | ||
| Output shape of the features. | ||
|
|
||
| """ | ||
| return self.compute_output_shape(seq_len, batch_size) | ||
|
|
| # Get default config if available in ECG_CRNN_CONFIG | ||
| config = None | ||
| for k, v in ECG_CRNN_CONFIG.cnn.items(): | ||
| if k.lower() == backbone_name.lower(): | ||
| config = deepcopy(v) | ||
| break | ||
|
|
||
| if config is None: | ||
| # Some backbones might not be in ECG_CRNN_CONFIG, skip for now | ||
| # or provide a minimal dummy config if known | ||
| pytest.skip(f"No default config found for backbone: {backbone_name}") | ||
|
|
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## master #30 +/- ##
==========================================
- Coverage 93.43% 93.36% -0.08%
==========================================
Files 139 139
Lines 18413 18472 +59
==========================================
+ Hits 17205 17247 +42
- Misses 1208 1225 +17 ☔ View full report in Codecov by Sentry. |
This PR completes the standardization of the Backbone API across
torch_ecg. Every model registered in theBACKBONESregistry now implements a consistent interface for feature extraction, which is a critical prerequisite for Self-Supervised Learning (SSL) architectures likeMAEandSimCLR.Key Changes
forward_features(self, x: Tensor) -> Tensor:Standardized method to extract high-dimensional feature maps before global pooling/classification heads.compute_features_output_shape(self, seq_len, batch_size):Allows downstream modules to statically infer feature dimensions.SizeMixinintorch_ecg.utils.utils_nnto support the new feature shape inference logic.