31 changes: 31 additions & 0 deletions paddleformers/peft/lora/lora_layers.py
@@ -23,6 +23,9 @@
     ColumnParallelLinear,
     RowParallelLinear,
 )
+from paddle.distributed.flex_checkpoint.dcp.sharded_weight import (
+    build_sharded_state_dict,
+)
 
 from ...transformers import linear_utils
 
@@ -365,6 +368,13 @@ def __init__(
         self._use_quick_lora = use_quick_lora and lora_dropout == 0.0
         self.disable_lora = False
 
+    def sharded_state_dict(
+        self,
+        structured_name_prefix: str = "",
+    ):
+        state_dict = self.state_dict(structured_name_prefix="")
+        return build_sharded_state_dict(state_dict, {"weight": 0, "lora_A": 0}, structured_name_prefix)
+
     @property
     def use_quick_lora(self):
         return self._use_quick_lora and self.training and not self.merged
@@ -514,6 +524,13 @@ def __init__(
         self._use_quick_lora = use_quick_lora and lora_dropout == 0.0
         self.disable_lora = False
 
+    def sharded_state_dict(
+        self,
+        structured_name_prefix: str = "",
+    ):
+        state_dict = self.state_dict(structured_name_prefix="")
+        return build_sharded_state_dict(state_dict, {"weight": 0, "lora_A": 0}, structured_name_prefix)
+
     @property
     def use_quick_lora(self):
         # TODO(@gexiao): support qlora
@@ -629,6 +646,13 @@ def __init__(
         self._use_quick_lora = use_quick_lora and lora_dropout == 0.0
         self.disable_lora = False
 
+    def sharded_state_dict(
+        self,
+        structured_name_prefix: str = "",
+    ):
+        state_dict = self.state_dict(structured_name_prefix="")
+        return build_sharded_state_dict(state_dict, {"weight": 1, "bias": 0, "lora_B": 1}, structured_name_prefix)
+
     @property
     def use_quick_lora(self):
         return self._use_quick_lora and self.training and not self.merged
@@ -761,6 +785,13 @@ def __init__(
         self._use_quick_lora = use_quick_lora and lora_dropout == 0.0
         self.disable_lora = False
 
+    def sharded_state_dict(
+        self,
+        structured_name_prefix: str = "",
+    ):
+        state_dict = self.state_dict(structured_name_prefix="")
+        return build_sharded_state_dict(state_dict, {"weight": 1, "bias": 0, "lora_B": 1}, structured_name_prefix)
+
     @property
     def use_quick_lora(self):
         # TODO(@gexiao): support qlora
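The axis maps passed to `build_sharded_state_dict` above appear to record, for each parameter, which axis is split across tensor-parallel ranks: `{"weight": 0, "lora_A": 0}` for the row-parallel-style layers and `{"weight": 1, "bias": 0, "lora_B": 1}` for the column-parallel-style ones. As a reading aid only, here is a minimal NumPy sketch of that idea; `shard_along_axis` is a hypothetical helper, not part of Paddle or this PR.

```python
import numpy as np

def shard_along_axis(tensor: np.ndarray, axis: int, rank: int, world_size: int) -> np.ndarray:
    """Return this rank's slice of `tensor`, split evenly along `axis` (illustration only)."""
    return np.split(tensor, world_size, axis=axis)[rank]

# Full (unsharded) parameters of a LoRA linear layer: in=1024, out=4096, LoRA rank r=8.
full = {
    "weight": np.zeros((1024, 4096)),
    "lora_A": np.zeros((1024, 8)),
    "lora_B": np.zeros((8, 4096)),
}
axis_map = {"weight": 0, "lora_A": 0}  # row-parallel style: split the input dim; lora_B stays replicated

local = {
    name: shard_along_axis(t, axis_map[name], rank=0, world_size=2) if name in axis_map else t
    for name, t in full.items()
}
print({k: v.shape for k, v in local.items()})
# -> {'weight': (512, 4096), 'lora_A': (512, 8), 'lora_B': (8, 4096)}
```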
33 changes: 25 additions & 8 deletions paddleformers/peft/lora/lora_model.py
@@ -18,6 +18,7 @@
 import math
 import os
 import re
+import sys
 import tempfile
 from collections import OrderedDict
 from functools import partial
@@ -43,10 +44,12 @@
     clean_unrelated_safetensors,
     dtype_guard,
     load_state_dict,
+    prepare_safe_save_state_dict,
 )
 from ...transformers.utils import (
     dtype_byte_size,
     get_checkpoint_shard_files,
+    is_safetensors_available,
     weight_name_suffix,
 )
 from ...utils.distributed import distributed_allgather, distributed_gather
@@ -55,6 +58,14 @@
 from ...utils.tools import get_env_device
 from .lora_config import LoRAAutoConfig, LoRAConfig
 
+if is_safetensors_available():
+    from safetensors.numpy import save_file as safe_save_file
+
+    if sys.platform.startswith("win"):
+        from safetensors import safe_open
+    else:
+        from ...utils.safetensors import fast_safe_open as safe_open
+
 
 def get_lora_layers():
     try:
@@ -255,7 +266,6 @@ def get_tensor_parallel_split_mappings():
     @classmethod
     def from_pretrained(cls, model, lora_path, **kwargs):
         lora_config = kwargs.pop("lora_config", None)
-        load_checkpoint_format = kwargs.pop("load_checkpoint_format", None)
         # init lora config & lora model
         if not isinstance(lora_config, LoRAConfig):
             lora_config = LoRAConfig.from_pretrained(lora_path)
@@ -273,7 +283,7 @@ def from_pretrained(cls, model, lora_path, **kwargs):
             loaded_keys = sharded_metadata["all_checkpoint_keys"]
             expected_keys = set(lora_model.get_trainable_state_dict().keys())
             missing_keys = expected_keys - set(loaded_keys)
-            if len(missing_keys) > 0 and load_checkpoint_format != "flex_checkpoint":
+            if len(missing_keys) > 0:
                 raise ValueError(f"missing_keys: {missing_keys}")
 
             error_msgs = []
@@ -438,6 +448,13 @@ def _convert_tensor_parallel(self, lora_state_dict):
                 logger.warning(f"{name} not found in lora_state_dict!")
         return lora_state_dict
 
+    def sharded_state_dict(self, *args, **kwargs):
+        sharded_state_dict = self.model.sharded_state_dict()
+        lora_sharded_state_dict = {}
+        for name, weight in sharded_state_dict.items():
+            lora_sharded_state_dict[name] = weight
+        return lora_sharded_state_dict
+
     def save_pretrained(self, save_directory: str, merge_tensor_parallel: bool = False, **kwargs):
         save_model_config = kwargs.get("save_model_config", True)
         save_checkpoint_format = kwargs.get("save_checkpoint_format", None)
@@ -503,13 +520,14 @@ def save_pretrained(self, save_directory: str, merge_tensor_parallel: bool = False, **kwargs):
             weight_filename = os.path.join(save_directory, lora_weight_name)
             if total_size != 0:
                 logger.info(f"Saving LoRA weights to {weight_filename}")
-                paddle.save(tensor_state_dict, weight_filename, safetensors=safetensors)
+                tensor_state_dict, metadata = prepare_safe_save_state_dict(tensor_state_dict, save_to_hf=safetensors)
+                safe_save_file(tensor_state_dict, weight_filename, metadata=metadata)
         else:
             lora_weight_name = _add_variant(LORA_WEIGHTS_NAME, variant)
             weight_filename = os.path.join(save_directory, lora_weight_name)
             paddle.save(trainable_state_dict, weight_filename, safetensors=safetensors)
 
-        def replace_name_and_gen_index(path):
+        def replace_name_and_gen_index_lora(path):
             index_mapping = {}
             safetensor_files = [fname for fname in os.listdir(path) if fname.endswith(".pdparams")]
             total_files_num = len(safetensor_files)
@@ -520,12 +538,11 @@
                 cur_file_index += 1
                 file_path = os.path.join(path, file)
                 new_file_name = f"peft_model-{cur_file_index:05d}-of-{total_files_num:05d}.safetensors"
-                from safetensors.paddle import safe_open
 
-                with safe_open(file_path, framework="paddle") as f:
+                with safe_open(file_path, framework="np") as f:
                     for key in f.keys():
                         index_mapping[key] = new_file_name
-                        single_size += f.get_tensor(key).numel().item() * dtype_byte_size(f.get_tensor(key).dtype)
+                        single_size += f.get_tensor(key).nbytes
                 total_size += single_size
                 new_file_path = os.path.join(path, new_file_name)
                 os.rename(file_path, new_file_path)
@@ -554,7 +571,7 @@ def replace_name_and_gen_index(path):
             model_config_to_save.tensor_parallel_degree = -1
             model_config_to_save.save_pretrained(save_directory)
         if safetensors:
-            replace_name_and_gen_index(save_directory)
+            replace_name_and_gen_index_lora(save_directory)
 
     def _find_and_replace_module(self, model, module_name, lora_config):
         parent_module = model
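Two changes above are worth reading together: the LoRA save path now writes safetensors through `prepare_safe_save_state_dict` plus `safe_save_file`, and the renamed `replace_name_and_gen_index_lora` reads the converted shards with the numpy backend, summing sizes with `.nbytes` instead of `numel() * dtype_byte_size(...)`. Below is a small self-contained sketch of that index-building pattern using only the public `safetensors` API; the file and index names are illustrative, not the PR's.

```python
import json
import os

from safetensors import safe_open


def build_weight_index(path: str, index_name: str = "peft_model.safetensors.index.json") -> None:
    """Scan the .safetensors shards under `path` and write a weight_map index with a total_size."""
    weight_map, total_size = {}, 0
    for fname in sorted(os.listdir(path)):
        if not fname.endswith(".safetensors"):
            continue
        with safe_open(os.path.join(path, fname), framework="np") as f:
            for key in f.keys():
                weight_map[key] = fname
                total_size += f.get_tensor(key).nbytes  # numpy arrays expose their byte size directly
    with open(os.path.join(path, index_name), "w") as fp:
        json.dump({"metadata": {"total_size": total_size}, "weight_map": weight_map}, fp, indent=2)
```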
48 changes: 15 additions & 33 deletions paddleformers/trainer/trainer.py
Expand Up @@ -1296,12 +1296,7 @@ def train(
self.create_optimizer_and_scheduler(num_training_steps=max_steps)

if resume_from_checkpoint is not None:
if isinstance(self.model, LoRAModel):
self.model.from_pretrained(
self.model, resume_from_checkpoint, load_checkpoint_format=self.args.load_checkpoint_format
)
else:
self._load_flex_checkpoint(resume_from_checkpoint)
self._load_flex_checkpoint(resume_from_checkpoint)
else:
if delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)
@@ -3518,13 +3513,12 @@ def _save_checkpoint(self, model, metrics=None):
                     self.args.optim_shard_num,
                 )
             elif self.args.save_checkpoint_format == "flex_checkpoint":
-                if not isinstance(self.model, LoRAModel):
-                    self._save_flex_optimizer_state(output_dir)
-                    if self.args.should_save:
-                        if self.tokenizer is not None and self.args.save_tokenizer:
-                            self.tokenizer.save_pretrained(output_dir)
-                        # Good practice: save your training arguments together with the trained model
-                        paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+                self._save_flex_optimizer_state(output_dir)
+                if self.args.should_save:
+                    if self.tokenizer is not None and self.args.save_tokenizer:
+                        self.tokenizer.save_pretrained(output_dir)
+                    # Good practice: save your training arguments together with the trained model
+                    paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
             else:
                 if self.dp_group.rank > 0: # this should only work for MoE saving
                     self._save_ckpt_func(
@@ -3576,16 +3570,13 @@ def _save_checkpoint(self, model, metrics=None):
                     signal_dir,
                 )
             elif self.args.save_checkpoint_format == "flex_checkpoint":
-                if isinstance(self.model, LoRAModel):
-                    self.save_model(output_dir)
-                else:
-                    self._save_flex_model_state(output_dir)
-                    self._save_flex_optimizer_state(output_dir)
-                    if self.args.should_save:
-                        if self.tokenizer is not None and self.args.save_tokenizer:
-                            self.tokenizer.save_pretrained(output_dir)
-                        # Good practice: save your training arguments together with the trained model
-                        paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+                self._save_flex_model_state(output_dir)
+                self._save_flex_optimizer_state(output_dir)
+                if self.args.should_save:
+                    if self.tokenizer is not None and self.args.save_tokenizer:
+                        self.tokenizer.save_pretrained(output_dir)
+                    # Good practice: save your training arguments together with the trained model
+                    paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
             else:
                 if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel:
                     self._save_ckpt_func(
@@ -3812,16 +3803,7 @@ def _save(
                 output_dir, is_main_process, save_checkpoint_format=self.args.save_checkpoint_format
             )
         else:
-            is_main_process = paddle.distributed.get_rank() == 0
-            if isinstance(self.model, LoRAModel):
-                self.model.save_pretrained(
-                    output_dir,
-                    merge_tensor_parallel=True,
-                    variant=self.args.weight_name_suffix,
-                    save_checkpoint_format=self.args.save_checkpoint_format,
-                )
-            else:
-                self._save_flex_model_state(output_dir)
+            self._save_flex_model_state(output_dir)
 
         if self.tokenizer is not None and self.args.save_tokenizer:
             self.tokenizer.save_pretrained(output_dir)
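The `isinstance(self.model, LoRAModel)` branches can be dropped because the wrapper now exposes `sharded_state_dict()` (delegating to the wrapped model, per the lora_model.py hunk above), so the flex-checkpoint save and load paths treat LoRA and non-LoRA models uniformly. A minimal sketch of that duck-typing contract follows; the names are illustrative stand-ins, not the Trainer's internals.

```python
from typing import Any, Dict, Protocol


class SupportsShardedStateDict(Protocol):
    """Anything that exposes sharded_state_dict() can go through the flex-checkpoint path."""

    def sharded_state_dict(self) -> Dict[str, Any]: ...


def save_flex_model_state(model: SupportsShardedStateDict, output_dir: str) -> Dict[str, Any]:
    # No isinstance(model, LoRAModel) check needed: a LoRA wrapper that delegates
    # sharded_state_dict() to its base model satisfies the same contract.
    sharded = model.sharded_state_dict()
    # ... hand `sharded` and `output_dir` to the distributed checkpoint writer here ...
    return sharded
```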
2 changes: 1 addition & 1 deletion requirements.txt
@@ -9,7 +9,7 @@ sentencepiece
 huggingface_hub>=0.19.2
 protobuf>=3.20.2
 visualdl
-safetensors @ https://paddle-whl.bj.bcebos.com/nightly/cu126/safetensors/safetensors-0.6.2.dev0-cp38-abi3-linux_x86_64.whl
+safetensors
 fast_dataindex>=0.1.1 ; platform_system == "Linux"
 aistudio-sdk>=0.3.0
 jinja2