Description
Hello,
While trying to reproduce your results, I've run into an issue that I haven't been able to solve.
I followed these steps precisely:
Train the model: I trained with a minimal command and left every other parameter at the script's default:
```bash
python scripts/segmentation_train.py --data_dir data/BraTS2021 --batch_size 8
```
Training Output: Training completed successfully and produced model checkpoints such as:
`results/emasavedmodel_0.9999_020000.pt`
Attempt Sampling: I then ran inference with the command below, which failed. It passes every parameter I could infer from the code and documentation, so that the sampling model should match the training configuration:
```bash
python scripts/segmentation_sample.py \
  --model_path results/emasavedmodel_0.9999_020000.pt \
  --data_dir data/BraTS2021 \
  --out_dir ./BraTS2021_validation_results \
  --data_name BRATS \
  --in_ch 5 \
  --image_size 256 \
  --num_channels 256 \
  --num_res_blocks 2 \
  --num_heads 1 \
  --attention_resolutions 16 \
  --learn_sigma True \
  --use_scale_shift_norm False \
  --diffusion_steps 1000 \
  --noise_schedule linear \
  --rescale_learned_sigmas False \
  --rescale_timesteps False
```
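One way to keep the training and sampling commands from drifting apart is to render the shared model and diffusion flags from a single dictionary. The sketch below is a hypothetical helper, not part of the repository: `SHARED_FLAGS`, `to_cli_args` and `sample_command` are illustrative names, and the flag values are simply the ones passed in the sampling command, not values verified against the training script's defaults.

```python
# Hypothetical helper (not part of the repository): keep the model/diffusion
# flags in one dictionary and render them for each command line.
# Flag names and values are copied from the sampling command; they are NOT
# verified against segmentation_train.py's defaults.
import shlex
from typing import Dict, List, Union

Value = Union[str, int, float, bool]

SHARED_FLAGS: Dict[str, Value] = {
    "data_name": "BRATS",
    "in_ch": 5,
    "image_size": 256,
    "num_channels": 256,
    "num_res_blocks": 2,
    "num_heads": 1,
    "attention_resolutions": "16",
    "learn_sigma": True,
    "use_scale_shift_norm": False,
    "diffusion_steps": 1000,
    "noise_schedule": "linear",
    "rescale_learned_sigmas": False,
    "rescale_timesteps": False,
}


def to_cli_args(flags: Dict[str, Value]) -> List[str]:
    """Render {name: value} as ["--name", "value", ...], preserving order."""
    argv: List[str] = []
    for name, value in flags.items():
        # str(True) == "True", the same spelling used in the sampling command.
        argv += [f"--{name}", str(value)]
    return argv


def sample_command(model_path: str) -> str:
    """Build the sampling command line from the shared flags."""
    base = [
        "python", "scripts/segmentation_sample.py",
        "--model_path", model_path,
        "--data_dir", "data/BraTS2021",
        "--out_dir", "./BraTS2021_validation_results",
    ]
    return shlex.join(base + to_cli_args(SHARED_FLAGS))


if __name__ == "__main__":
    print(sample_command("results/emasavedmodel_0.9999_020000.pt"))
```

Appending the same `to_cli_args(SHARED_FLAGS)` output to a fresh training run would make both scripts build the model from identical values, assuming `segmentation_train.py` accepts the same flags.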
Error Log: After running the sampling command, I received the following complete error message:
[Please paste the very long RuntimeError log here]
The issue seems to be a model architecture mismatch between the model created by the training script and the one created by the sampling script. Any help you could provide would be greatly appreciated.
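To narrow down which hyperparameter differs, one option is to compare the parameter names and shapes stored in the checkpoint with those of the model the sampling script builds. The sketch below collects the same categories that PyTorch's `load_state_dict` reports (missing keys, unexpected keys, size mismatches) side by side. The function names are hypothetical, and the key names and shapes in the demo are made up for illustration.

```python
# Sketch: compare the parameter names/shapes stored in a checkpoint with those
# of a freshly built model. load_state_dict reports the same categories
# (missing keys, unexpected keys, size mismatches); listing them together makes
# it easier to tell which hyperparameter differs.
from typing import Any, Dict, List, Mapping, Tuple

Shape = Tuple[int, ...]


def shapes_from_state_dict(state: Mapping[str, Any]) -> Dict[str, Shape]:
    """Map each entry that has a .shape (e.g. a torch.Tensor) to its shape tuple."""
    return {k: tuple(v.shape) for k, v in state.items() if hasattr(v, "shape")}


def diff_state_shapes(ckpt: Dict[str, Shape], model: Dict[str, Shape]) -> Dict[str, List]:
    """Return keys missing from the checkpoint, unexpected extra keys, and shape mismatches."""
    shared = ckpt.keys() & model.keys()
    return {
        "missing": sorted(model.keys() - ckpt.keys()),
        "unexpected": sorted(ckpt.keys() - model.keys()),
        "mismatched": sorted((k, ckpt[k], model[k]) for k in shared if ckpt[k] != model[k]),
    }


if __name__ == "__main__":
    # Made-up key names and shapes, for illustration only.
    ckpt = {"input_blocks.0.0.weight": (128, 5, 3, 3), "out.2.weight": (2, 128, 3, 3)}
    model = {"input_blocks.0.0.weight": (256, 5, 3, 3), "out.2.weight": (2, 256, 3, 3),
             "extra.bias": (256,)}
    for kind, items in diff_state_shapes(ckpt, model).items():
        print(kind, items)
```

With PyTorch, `ckpt` could come from `shapes_from_state_dict(torch.load(path, map_location="cpu"))`, assuming the `.pt` file holds a plain state dict, and `model` from `shapes_from_state_dict(model.state_dict())` once the sampling script has constructed its model.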
Thank you!