|
24 | 24 | import torch.nn.functional as F |
25 | 25 | from crossfit import op |
26 | 26 | from crossfit.backend.torch.hf.model import HFModel |
27 | | -from huggingface_hub import hf_hub_download |
| 27 | +from huggingface_hub import PyTorchModelHubMixin |
28 | 28 | from peft import PeftModel |
29 | | -from safetensors.torch import load_file |
30 | 29 | from torch.nn import Dropout, Linear |
31 | 30 | from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer |
32 | 31 |
|
@@ -75,7 +74,7 @@ class AegisConfig: |
75 | 74 | ] |
76 | 75 |
|
77 | 76 |
|
78 | | -class InstructionDataGuardNet(torch.nn.Module): |
| 77 | +class InstructionDataGuardNet(torch.nn.Module, PyTorchModelHubMixin): |
79 | 78 | def __init__(self, input_dim, dropout=0.7): |
80 | 79 | super().__init__() |
81 | 80 | self.input_dim = input_dim |
@@ -180,12 +179,14 @@ def load_model(self, device: str = "cuda"): |
180 | 179 | add_instruction_data_guard=self.config.add_instruction_data_guard, |
181 | 180 | ) |
182 | 181 | if self.config.add_instruction_data_guard: |
183 | | - weights_path = hf_hub_download( |
184 | | - repo_id=self.config.instruction_data_guard_path, |
185 | | - filename="model.safetensors", |
| 182 | + model.instruction_data_guard_net = ( |
| 183 | + model.instruction_data_guard_net.from_pretrained( |
| 184 | + self.config.instruction_data_guard_path |
| 185 | + ) |
| 186 | + ) |
| 187 | + model.instruction_data_guard_net = model.instruction_data_guard_net.to( |
| 188 | + device |
186 | 189 | ) |
187 | | - state_dict = load_file(weights_path) |
188 | | - model.instruction_data_guard_net.load_state_dict(state_dict) |
189 | 190 | model.instruction_data_guard_net.eval() |
190 | 191 |
|
191 | 192 | model = model.to(device) |
|
0 commit comments