diff --git a/tabfm/src/hugging_face/convert_and_upload.py b/tabfm/src/hugging_face/convert_and_upload.py index 949d7cd..5ad6be4 100644 --- a/tabfm/src/hugging_face/convert_and_upload.py +++ b/tabfm/src/hugging_face/convert_and_upload.py @@ -15,7 +15,6 @@ """Utility to convert JAX TabFM checkpoints to PyTorch and upload to Hugging Face.""" import os -import json import logging from typing import Any, Dict, Optional, Tuple, Literal from absl import app @@ -23,11 +22,11 @@ import jax.numpy as jnp import numpy as np import torch -from torch import Tensor -from jaxtyping import Float, Int, Bool +from jaxtyping import Float, Int from tabfm.src.jax import tabfm_v1_0_0 from tabfm.src.pytorch import model as MT +from tabfm.src.pytorch.tabfm_v1_0_0 import TabFM_HF from tabfm.src.hugging_face.torch_convert import jax_params, convert # Architecture config of the v1.0.0 checkpoint. @@ -82,7 +81,7 @@ def convert_model( model_type: Literal["classification", "regression"], checkpoint_path: Optional[str] = None, -) -> Tuple[MT.TabFM, float]: +) -> Tuple[TabFM_HF, float]: """Converts JAX checkpoint of model_type to PyTorch TabFM and runs parity verification.""" logging.info("Loading JAX %s model...", model_type) is_classifier = (model_type == "classification") @@ -99,7 +98,7 @@ def convert_model( decoder_hidden = jp["icl_predictor.decoder.layers.0.kernel"].shape[1] logging.info("Instantiating PyTorch model...") - torch_model = MT.TabFM( + torch_model = TabFM_HF( decoder_hidden=decoder_hidden, is_classifier=is_classifier, **V1_0_0_CONFIG, @@ -152,29 +151,16 @@ def verify_parity( def save_checkpoint( - model: MT.TabFM, + model: TabFM_HF, output_dir: str, model_type: str, ) -> str: - """Saves PyTorch state dict and config metadata to output_dir.""" + """Saves model weights and config to output_dir using save_pretrained.""" model_dir = os.path.join(output_dir, model_type) os.makedirs(model_dir, exist_ok=True) - - weight_path = os.path.join(model_dir, "pytorch_model.bin") - logging.info("Saving PyTorch weights to %s...", weight_path) - torch.save(model.state_dict(), weight_path) - - config_path = os.path.join(model_dir, "config.json") - config_data = { - "model_type": "tabfm", - "version": "1.0.0", - "task": model_type, - "framework": "pytorch", - **V1_0_0_CONFIG, - } - with open(config_path, "w") as f: - json.dump(config_data, f, indent=2) - logging.info("Created config file at %s", config_path) + logging.info("Saving model to %s...", model_dir) + model.save_pretrained(model_dir) + logging.info("Saved model to %s", model_dir) return model_dir @@ -204,28 +190,10 @@ def main(argv): if FLAGS.repo_id: if not FLAGS.token: raise ValueError("Hugging Face token is required when repo_id is provided.") - + from huggingface_hub import HfApi # pylint: disable=g-import-not-at-top api = HfApi(token=FLAGS.token) - - # Upload root config - tmp_root_config = "/tmp/tabfm_root_config_pytorch.json" - root_config = { - "model_type": "tabfm", - "version": "1.0.0", - "framework": "pytorch", - } - with open(tmp_root_config, "w") as f: - json.dump(root_config, f, indent=2) - - logging.info("Uploading root config.json to %s...", FLAGS.repo_id) - api.upload_file( - path_or_fileobj=tmp_root_config, - path_in_repo="config.json", - repo_id=FLAGS.repo_id, - repo_type="model", - ) - + for mtype, sdir in local_dirs.items(): logging.info("Uploading %s folder to %s...", mtype, FLAGS.repo_id) api.upload_folder( diff --git a/tabfm/src/pytorch/tabfm_v1_0_0.py b/tabfm/src/pytorch/tabfm_v1_0_0.py index 88298d1..cfcbf0d 100644 --- a/tabfm/src/pytorch/tabfm_v1_0_0.py +++ b/tabfm/src/pytorch/tabfm_v1_0_0.py @@ -12,62 +12,99 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass +import json import os import threading from typing import Any, Dict, Optional from absl import logging import torch +from huggingface_hub import PyTorchModelHubMixin, constants, snapshot_download from tabfm.src.pytorch.model import TabFM HF_REPO_ID = "google/tabfm-1.0.0-pytorch" -@dataclass(frozen=True) -class Config: - max_classes: int = 10 - embed_dim: int = 256 - col_num_blocks: int = 3 - col_nhead: int = 4 - col_num_inds: int = 256 - row_num_blocks: int = 3 - row_nhead: int = 8 - row_num_cls: int = 8 - icl_num_blocks: int = 24 - icl_nhead: int = 8 - ff_factor: int = 4 - feature_group_size: int = 3 - num_freq: int = 32 - is_classifier: bool = True - - def to_dict(self) -> Dict[str, Any]: - return { - "max_classes": self.max_classes, - "embed_dim": self.embed_dim, - "col_num_blocks": self.col_num_blocks, - "col_nhead": self.col_nhead, - "col_num_inds": self.col_num_inds, - "row_num_blocks": self.row_num_blocks, - "row_nhead": self.row_nhead, - "row_num_cls": self.row_num_cls, - "icl_num_blocks": self.icl_num_blocks, - "icl_nhead": self.icl_nhead, - "ff_factor": self.ff_factor, - "feature_group_size": self.feature_group_size, - "num_freq": self.num_freq, - "is_classifier": self.is_classifier, - } - -@dataclass(frozen=True) -class ClassificationConfig(Config): - is_classifier: bool = True - -@dataclass(frozen=True) -class RegressionConfig(Config): - is_classifier: bool = False - _LOAD_CACHE_LOCK = threading.Lock() -_LOAD_CACHE: Dict[Any, TabFM] = {} +_LOAD_CACHE: Dict[Any, "TabFM_HF"] = {} + + +class TabFM_HF( + TabFM, + PyTorchModelHubMixin, + repo_url="https://github.com/google-research/tabfm", + license="other", +): + """PyTorch TabFM model with HuggingFace Hub support. + + Subclasses TabFM directly (rather than wrapping it) and mixes in + PyTorchModelHubMixin, keeping the Hugging Face specific loading logic out + of the plain model class. + """ + + @classmethod + def _from_pretrained( + cls, + *, + model_id, + revision, + cache_dir, + force_download, + local_files_only, + token, + map_location="cpu", + strict=True, + **model_kwargs, + ): + subfolder = model_kwargs.pop("subfolder", None) + + def _apply_config(cfg): + if "is_classifier" not in model_kwargs and "task" in cfg: + model_kwargs["is_classifier"] = cfg.pop("task") == "classification" + for key in ("model_type", "version", "framework"): + cfg.pop(key, None) + for k, v in cfg.items(): + if k not in model_kwargs: + model_kwargs[k] = v + + # translate config keys already merged into model_kwargs by from_pretrained() + _apply_config(model_kwargs) + + if subfolder is None: + local_id = model_id + elif os.path.isdir(model_id): + local_id = os.path.join(model_id, subfolder) + else: + base_path = snapshot_download( + repo_id=model_id, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + token=token, + allow_patterns=[f"{subfolder}/**"], + ) + local_id = os.path.join(base_path, subfolder) + + if subfolder is not None: + cfg_path = os.path.join(local_id, constants.CONFIG_NAME) + if os.path.exists(cfg_path): + with open(cfg_path) as f: + _apply_config(json.load(f)) + else: + logging.warning("No config.json found in %s", local_id) + + return super()._from_pretrained( + model_id=local_id, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + token=token, + map_location=map_location, + strict=strict, + **model_kwargs, + ) + def load( model_type: str = "classification", @@ -76,7 +113,7 @@ def load( device: Optional[str] = None, dtype: Any = torch.bfloat16, use_cache: bool = True, -) -> TabFM: +) -> "TabFM_HF": """Loads the PyTorch TabFM v1.0.0 model with pre-trained weights. The checkpoint is stored in float32, but the model is designed to run in @@ -86,57 +123,62 @@ def load( ``dtype`` is provided for float32 debugging / quality comparison; the model is designed for bfloat16 and this option may be removed in a future release. + + Args: + model_type: 'classification' or 'regression'. + checkpoint_path: Local directory or weights file. If None, downloads from + Hugging Face (google/tabfm-1.0.0-pytorch). + device: Target device (e.g. 'cuda', 'cpu'). Defaults to 'cpu'. + dtype: Compute dtype to cast the model to after loading. Defaults to + bfloat16; pass None to keep the float32 weights. + use_cache: Reuse a process-wide cached model for identical settings. + + Returns: + An eval-mode TabFM_HF model with pre-trained weights loaded. """ + if model_type not in ("classification", "regression"): + raise ValueError( + f"Unsupported model_type: {model_type!r}. " + "Must be 'classification' or 'regression'." + ) + cache_key = (model_type, checkpoint_path, device, str(dtype)) - if use_cache: _LOAD_CACHE_LOCK.acquire() - try: if use_cache and cache_key in _LOAD_CACHE: return _LOAD_CACHE[cache_key] - if model_type == "classification": - config = ClassificationConfig() - elif model_type == "regression": - config = RegressionConfig() - else: - raise ValueError(f"Unsupported model_type: {model_type}.") - - model = TabFM(**config.to_dict()) - if checkpoint_path is None: - try: - from huggingface_hub import snapshot_download - logging.info("Downloading TabFM v1.0.0 PyTorch %s weights...", model_type) - base_path = snapshot_download(repo_id=HF_REPO_ID) - checkpoint_file = os.path.join(base_path, model_type, "pytorch_model.bin") - except ImportError as e: - raise ImportError("huggingface_hub is required to download weights.") from e + logging.info( + "Downloading TabFM v1.0.0 PyTorch %s weights from Hugging Face...", + model_type, + ) + model = TabFM_HF.from_pretrained(HF_REPO_ID, subfolder=model_type) else: - checkpoint_file = checkpoint_path - if os.path.isdir(checkpoint_file): - for sub in [os.path.join(checkpoint_file, model_type, "pytorch_model.bin"), - os.path.join(checkpoint_file, "pytorch_model.bin")]: - if os.path.exists(sub): - checkpoint_file = sub - break - - if not os.path.exists(checkpoint_file): - raise FileNotFoundError(f"Weights not found at: {checkpoint_file}") - - logging.info("Loading PyTorch state dict from %s...", checkpoint_file) - state_dict = torch.load(checkpoint_file, map_location="cpu") - model.load_state_dict(state_dict, strict=True) + local_dir = checkpoint_path + if not os.path.isdir(local_dir): + raise FileNotFoundError(f"Local checkpoint path not found: {local_dir}") + sub = os.path.join(local_dir, model_type) + if os.path.isdir(sub): + local_dir = sub + + if os.path.exists(os.path.join(local_dir, "config.json")): + model = TabFM_HF.from_pretrained(local_dir) + else: + # no config.json: pass is_classifier explicitly + model = TabFM_HF.from_pretrained( + local_dir, + is_classifier=(model_type == "classification"), + ) if dtype is not None: model = model.to(dtype) # engage the bf16 compute design (see docstring) if device is not None: model = model.to(device) - model.eval() - + if use_cache: _LOAD_CACHE[cache_key] = model return model