Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 11 additions & 43 deletions tabfm/src/hugging_face/convert_and_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,18 @@
"""Utility to convert JAX TabFM checkpoints to PyTorch and upload to Hugging Face."""

import os
import json
import logging
from typing import Any, Dict, Optional, Tuple, Literal
from absl import app
from absl import flags
import jax.numpy as jnp
import numpy as np
import torch
from torch import Tensor
from jaxtyping import Float, Int, Bool
from jaxtyping import Float, Int

from tabfm.src.jax import tabfm_v1_0_0
from tabfm.src.pytorch import model as MT
from tabfm.src.pytorch.tabfm_v1_0_0 import TabFM_HF
from tabfm.src.hugging_face.torch_convert import jax_params, convert

# Architecture config of the v1.0.0 checkpoint.
Expand Down Expand Up @@ -82,7 +81,7 @@
def convert_model(
model_type: Literal["classification", "regression"],
checkpoint_path: Optional[str] = None,
) -> Tuple[MT.TabFM, float]:
) -> Tuple[TabFM_HF, float]:
"""Converts JAX checkpoint of model_type to PyTorch TabFM and runs parity verification."""
logging.info("Loading JAX %s model...", model_type)
is_classifier = (model_type == "classification")
Expand All @@ -99,7 +98,7 @@ def convert_model(
decoder_hidden = jp["icl_predictor.decoder.layers.0.kernel"].shape[1]

logging.info("Instantiating PyTorch model...")
torch_model = MT.TabFM(
torch_model = TabFM_HF(
decoder_hidden=decoder_hidden,
is_classifier=is_classifier,
**V1_0_0_CONFIG,
Expand Down Expand Up @@ -152,29 +151,16 @@ def verify_parity(


def save_checkpoint(
model: MT.TabFM,
model: TabFM_HF,
output_dir: str,
model_type: str,
) -> str:
"""Saves PyTorch state dict and config metadata to output_dir."""
"""Saves model weights and config to output_dir using save_pretrained."""
model_dir = os.path.join(output_dir, model_type)
os.makedirs(model_dir, exist_ok=True)

weight_path = os.path.join(model_dir, "pytorch_model.bin")
logging.info("Saving PyTorch weights to %s...", weight_path)
torch.save(model.state_dict(), weight_path)

config_path = os.path.join(model_dir, "config.json")
config_data = {
"model_type": "tabfm",
"version": "1.0.0",
"task": model_type,
"framework": "pytorch",
**V1_0_0_CONFIG,
}
with open(config_path, "w") as f:
json.dump(config_data, f, indent=2)
logging.info("Created config file at %s", config_path)
logging.info("Saving model to %s...", model_dir)
model.save_pretrained(model_dir)
logging.info("Saved model to %s", model_dir)
return model_dir


Expand Down Expand Up @@ -204,28 +190,10 @@ def main(argv):
if FLAGS.repo_id:
if not FLAGS.token:
raise ValueError("Hugging Face token is required when repo_id is provided.")

from huggingface_hub import HfApi # pylint: disable=g-import-not-at-top
api = HfApi(token=FLAGS.token)

# Upload root config
tmp_root_config = "/tmp/tabfm_root_config_pytorch.json"
root_config = {
"model_type": "tabfm",
"version": "1.0.0",
"framework": "pytorch",
}
with open(tmp_root_config, "w") as f:
json.dump(root_config, f, indent=2)

logging.info("Uploading root config.json to %s...", FLAGS.repo_id)
api.upload_file(
path_or_fileobj=tmp_root_config,
path_in_repo="config.json",
repo_id=FLAGS.repo_id,
repo_type="model",
)


for mtype, sdir in local_dirs.items():
logging.info("Uploading %s folder to %s...", mtype, FLAGS.repo_id)
api.upload_folder(
Expand Down
202 changes: 122 additions & 80 deletions tabfm/src/pytorch/tabfm_v1_0_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,62 +12,99 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
import json
import os
import threading
from typing import Any, Dict, Optional
from absl import logging
import torch
from huggingface_hub import PyTorchModelHubMixin, constants, snapshot_download

from tabfm.src.pytorch.model import TabFM

HF_REPO_ID = "google/tabfm-1.0.0-pytorch"

@dataclass(frozen=True)
class Config:
max_classes: int = 10
embed_dim: int = 256
col_num_blocks: int = 3
col_nhead: int = 4
col_num_inds: int = 256
row_num_blocks: int = 3
row_nhead: int = 8
row_num_cls: int = 8
icl_num_blocks: int = 24
icl_nhead: int = 8
ff_factor: int = 4
feature_group_size: int = 3
num_freq: int = 32
is_classifier: bool = True

def to_dict(self) -> Dict[str, Any]:
return {
"max_classes": self.max_classes,
"embed_dim": self.embed_dim,
"col_num_blocks": self.col_num_blocks,
"col_nhead": self.col_nhead,
"col_num_inds": self.col_num_inds,
"row_num_blocks": self.row_num_blocks,
"row_nhead": self.row_nhead,
"row_num_cls": self.row_num_cls,
"icl_num_blocks": self.icl_num_blocks,
"icl_nhead": self.icl_nhead,
"ff_factor": self.ff_factor,
"feature_group_size": self.feature_group_size,
"num_freq": self.num_freq,
"is_classifier": self.is_classifier,
}

@dataclass(frozen=True)
class ClassificationConfig(Config):
is_classifier: bool = True

@dataclass(frozen=True)
class RegressionConfig(Config):
is_classifier: bool = False

_LOAD_CACHE_LOCK = threading.Lock()
_LOAD_CACHE: Dict[Any, TabFM] = {}
_LOAD_CACHE: Dict[Any, "TabFM_HF"] = {}


class TabFM_HF(
TabFM,
PyTorchModelHubMixin,
repo_url="https://github.com/google-research/tabfm",
license="other",
):
"""PyTorch TabFM model with HuggingFace Hub support.

Subclasses TabFM directly (rather than wrapping it) and mixes in
PyTorchModelHubMixin, keeping the Hugging Face specific loading logic out
of the plain model class.
"""

@classmethod
def _from_pretrained(
cls,
*,
model_id,
revision,
cache_dir,
force_download,
local_files_only,
token,
map_location="cpu",
strict=True,
**model_kwargs,
):
subfolder = model_kwargs.pop("subfolder", None)

def _apply_config(cfg):
if "is_classifier" not in model_kwargs and "task" in cfg:
model_kwargs["is_classifier"] = cfg.pop("task") == "classification"
for key in ("model_type", "version", "framework"):
cfg.pop(key, None)
for k, v in cfg.items():
if k not in model_kwargs:
model_kwargs[k] = v

# translate config keys already merged into model_kwargs by from_pretrained()
_apply_config(model_kwargs)

if subfolder is None:
local_id = model_id
elif os.path.isdir(model_id):
local_id = os.path.join(model_id, subfolder)
else:
base_path = snapshot_download(
repo_id=model_id,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
token=token,
allow_patterns=[f"{subfolder}/**"],
)
local_id = os.path.join(base_path, subfolder)

if subfolder is not None:
cfg_path = os.path.join(local_id, constants.CONFIG_NAME)
if os.path.exists(cfg_path):
with open(cfg_path) as f:
_apply_config(json.load(f))
else:
logging.warning("No config.json found in %s", local_id)

return super()._from_pretrained(
model_id=local_id,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
token=token,
map_location=map_location,
strict=strict,
**model_kwargs,
)


def load(
model_type: str = "classification",
Expand All @@ -76,7 +113,7 @@ def load(
device: Optional[str] = None,
dtype: Any = torch.bfloat16,
use_cache: bool = True,
) -> TabFM:
) -> "TabFM_HF":
"""Loads the PyTorch TabFM v1.0.0 model with pre-trained weights.

The checkpoint is stored in float32, but the model is designed to run in
Expand All @@ -86,57 +123,62 @@ def load(

``dtype`` is provided for float32 debugging / quality comparison; the model is
designed for bfloat16 and this option may be removed in a future release.

Args:
model_type: 'classification' or 'regression'.
checkpoint_path: Local directory or weights file. If None, downloads from
Hugging Face (google/tabfm-1.0.0-pytorch).
device: Target device (e.g. 'cuda', 'cpu'). Defaults to 'cpu'.
dtype: Compute dtype to cast the model to after loading. Defaults to
bfloat16; pass None to keep the float32 weights.
use_cache: Reuse a process-wide cached model for identical settings.

Returns:
An eval-mode TabFM_HF model with pre-trained weights loaded.
"""
if model_type not in ("classification", "regression"):
raise ValueError(
f"Unsupported model_type: {model_type!r}. "
"Must be 'classification' or 'regression'."
)

cache_key = (model_type, checkpoint_path, device, str(dtype))

if use_cache:
_LOAD_CACHE_LOCK.acquire()

try:
if use_cache and cache_key in _LOAD_CACHE:
return _LOAD_CACHE[cache_key]

if model_type == "classification":
config = ClassificationConfig()
elif model_type == "regression":
config = RegressionConfig()
else:
raise ValueError(f"Unsupported model_type: {model_type}.")

model = TabFM(**config.to_dict())

if checkpoint_path is None:
try:
from huggingface_hub import snapshot_download
logging.info("Downloading TabFM v1.0.0 PyTorch %s weights...", model_type)
base_path = snapshot_download(repo_id=HF_REPO_ID)
checkpoint_file = os.path.join(base_path, model_type, "pytorch_model.bin")
except ImportError as e:
raise ImportError("huggingface_hub is required to download weights.") from e
logging.info(
"Downloading TabFM v1.0.0 PyTorch %s weights from Hugging Face...",
model_type,
)
model = TabFM_HF.from_pretrained(HF_REPO_ID, subfolder=model_type)
else:
checkpoint_file = checkpoint_path
if os.path.isdir(checkpoint_file):
for sub in [os.path.join(checkpoint_file, model_type, "pytorch_model.bin"),
os.path.join(checkpoint_file, "pytorch_model.bin")]:
if os.path.exists(sub):
checkpoint_file = sub
break

if not os.path.exists(checkpoint_file):
raise FileNotFoundError(f"Weights not found at: {checkpoint_file}")

logging.info("Loading PyTorch state dict from %s...", checkpoint_file)
state_dict = torch.load(checkpoint_file, map_location="cpu")
model.load_state_dict(state_dict, strict=True)
local_dir = checkpoint_path
if not os.path.isdir(local_dir):
raise FileNotFoundError(f"Local checkpoint path not found: {local_dir}")
sub = os.path.join(local_dir, model_type)
if os.path.isdir(sub):
local_dir = sub

if os.path.exists(os.path.join(local_dir, "config.json")):
model = TabFM_HF.from_pretrained(local_dir)
else:
# no config.json: pass is_classifier explicitly
model = TabFM_HF.from_pretrained(
local_dir,
is_classifier=(model_type == "classification"),
)

if dtype is not None:
model = model.to(dtype) # engage the bf16 compute design (see docstring)

if device is not None:
model = model.to(device)

model.eval()

if use_cache:
_LOAD_CACHE[cache_key] = model
return model
Expand Down
Loading