Skip to content

add PyTorchModelHubMixin to TabFM#33

Merged
erzel merged 6 commits into
google-research:mainfrom
kashif:add-pytorch-hub-mixin
Jul 2, 2026
Merged

add PyTorchModelHubMixin to TabFM#33
erzel merged 6 commits into
google-research:mainfrom
kashif:add-pytorch-hub-mixin

Conversation

@kashif

@kashif kashif commented Jul 1, 2026

Copy link
Copy Markdown
Contributor

Makes TabFM extend PyTorchModelHubMixin so it gets from_pretrained,
save_pretrained, and push_to_hub for free.

  • load() now calls TabFM.from_pretrained(HF_REPO_ID, subfolder=model_type)
    instead of snapshot_download + manual torch.load + load_state_dict
  • save_pretrained writes model.safetensors (preferred over .bin) plus a
    proper config.json with all init params including is_classifier
  • _from_pretrained translates the legacy hub task: "classification" field to
    is_classifier: bool so existing weights load without any config update
  • Removes the Config/ClassificationConfig/RegressionConfig dataclasses and
    all the manual json/bin saving from convert_and_upload.py

Users can now also do:

from tabfm.src.pytorch.model import TabFM
model = TabFM.from_pretrained("google/tabfm-1.0.0-pytorch", subfolder="classification")
model.push_to_hub("my-org/my-tabfm-fork")

TabFM now extends PyTorchModelHubMixin giving it from_pretrained,
save_pretrained, and push_to_hub. The load() helper uses
TabFM.from_pretrained() instead of manual snapshot_download + torch.load.
save_pretrained writes model.safetensors which is the preferred format.
Remove redundant config dataclasses and manual json/bin saving.

@erzel erzel left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the contribution. Please see my comments below.

Comment thread tabfm/src/pytorch/model.py Outdated
Comment thread tabfm/src/pytorch/model.py Outdated
Comment thread tabfm/src/pytorch/model.py Outdated
Comment thread tabfm/src/pytorch/tabfm_v1_0_0.py Outdated
Comment thread tabfm/src/pytorch/model.py Outdated
Comment thread tabfm/src/pytorch/tabfm_v1_0_0.py Outdated
@kashif

kashif commented Jul 2, 2026

Copy link
Copy Markdown
Contributor Author

@erzel addressed all the review points: TabFM_HF now lives in tabfm_v1_0_0.py, subclasses TabFM and PyTorchModelHubMixin, and always delegates to the superclass _from_pretrained (only the config fetching differs by resolving to a local dir first). Duplicated config translation logic is merged into one helper, the pytype annotation is back, added a logging.warning when config.json is missing, and the local path check in load() raises immediately. Verified hub subfolder load, local dir load, and load() all work.

# Conflicts:
#	tabfm/src/pytorch/tabfm_v1_0_0.py

@erzel erzel left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for making the changes.

@erzel erzel merged commit 5df7def into google-research:main Jul 2, 2026
7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants