diff --git a/CMakeLists.txt b/CMakeLists.txt
index 57e97ddc..97ba5acf 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -176,7 +176,7 @@ add_executable(gpt2
     example/gpt2/main.cc
     example/common/tiny_shakespeare_dataset.cc
     example/common/utils.cc
-    example/gpt2/checkpoint_loader.cc
+    example/common/checkpoint_loader.cc
     example/common/tokenizer.cc
 )
 link_infini_train_exe(gpt2)
@@ -185,7 +185,7 @@ add_executable(llama3
     example/llama3/main.cc
     example/common/tiny_shakespeare_dataset.cc
     example/common/utils.cc
-    example/llama3/checkpoint_loader.cc
+    example/common/checkpoint_loader.cc
     example/common/tokenizer.cc
 )
 link_infini_train_exe(llama3)
diff --git a/example/common/checkpoint_loader.cc b/example/common/checkpoint_loader.cc
new file mode 100644
index 00000000..c5a456e5
--- /dev/null
+++ b/example/common/checkpoint_loader.cc
@@ -0,0 +1,1060 @@
+#include "example/common/checkpoint_loader.h"
+
+#include <algorithm>
+#include <chrono>
+#include <cstring>
+#include <filesystem>
+#include <format>
+#include <fstream>
+#include <random>
+#include <tuple>
+#include <vector>
+
+#include "glog/logging.h"
+
+#include "example/common/utils.h"
+#include "example/gpt2/config.h"
+#include "example/llama3/config.h"
+#include "infini_train/include/nn/modules/normalization.h"
+#include "infini_train/include/nn/modules/sparse.h"
+#include "infini_train/include/nn/modules/transformer/causal_self_attention.h"
+#include "infini_train/include/nn/modules/transformer/mlp.h"
+#include "infini_train/include/nn/modules/transformer/transformer.h"
+#include "infini_train/include/nn/parallel/global.h"
+#include "infini_train/include/nn/parallel/pp/pipeline_parallel.h"
+#include "infini_train/include/nn/parallel/tensor_parallel.h"
+#include "infini_train/include/tensor.h"
+
+using namespace infini_train;
+namespace nn = infini_train::nn;
+
+namespace {
+constexpr int kRandomSeed = 42;
+
+// TODO(dcj): make this rng generator compatible with torch later
+static std::mt19937 gen{kRandomSeed};
+} // namespace
+
+namespace {
+constexpr int32_t kGPT2Magic = 20240326;
+constexpr int32_t kGPT2FP32Version = 3;
+constexpr int32_t kGPT2BF16Version = 5;
+
+constexpr int32_t kLLaMA3Magic = 20240803;
+constexpr int32_t kLLaMA3FP32Version = 3;
+
+std::tuple<int32_t, infini_train::DataType> DetermineAndCheckVersion(const std::vector<uint8_t> &header,
+                                                                     size_t offset) {
+    const auto version = BytesToType<int32_t>(header, offset);
+    switch (version) {
+    case kGPT2FP32Version:
+        return {version, infini_train::DataType::kFLOAT32};
+    case kGPT2BF16Version:
+        return {version, infini_train::DataType::kBFLOAT16};
+    default:
+        LOG(FATAL) << "Unsupported version: " << version << " at " << __FILE__ << ":" << __LINE__;
+        return {}; // Unreachable, but keeps compiler happy
+    }
+}
+} // namespace
+
+namespace infini_train {
+namespace gpt2 {
+
+std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath) {
+    if (!std::filesystem::exists(filepath)) {
+        LOG(FATAL) << "File not found: " << filepath;
+    }
+
+    std::ifstream ifs(filepath, std::ios::binary);
+    const auto header = ReadSeveralBytesFromIfstream(256 * sizeof(int32_t), &ifs);
+
+    const auto magic = BytesToType<int32_t>(header, 0);
+    CHECK_EQ(magic, kGPT2Magic);
+    auto [version, dtype] = DetermineAndCheckVersion(header, 4);
+    CHECK_EQ(version, kGPT2FP32Version);
+
+    auto tp_size = nn::parallel::global::GetTensorParallelSize();
+
+    const auto block_size = BytesToType<int32_t>(header, 8);
+    const auto vocab_size = BytesToType<int32_t>(header, 12);
+    const auto n_layer = BytesToType<int32_t>(header, 16);
+    const auto n_head = BytesToType<int32_t>(header, 20);
+    const auto n_embd = BytesToType<int32_t>(header, 24);
+    const auto padded_vocab_size = BytesToType<int32_t>(header, 28);
+    // NOTE(zbl):
vocab_size needs to be padded to multiple of TP size + const auto model_vocab_size = tp_size > 1 ? padded_vocab_size : vocab_size; + + nn::TransformerConfig gpt2_config = infini_train::gpt2::GPT2Config(); + gpt2_config.block_size = block_size; + gpt2_config.vocab_size = model_vocab_size; + gpt2_config.original_vocab_size = vocab_size; + gpt2_config.n_layer = n_layer; + gpt2_config.n_head = n_head; + gpt2_config.n_embd = n_embd; + auto local_gpt2 = std::make_shared(gpt2_config); + + LOG(INFO) << "magic: " << magic << " version: " << version << " block_size: " << block_size + << " vocab_size: " << vocab_size << " n_layer: " << n_layer << " n_head: " << n_head + << " n_embd: " << n_embd << " padded_vocab_size: " << padded_vocab_size; + + CHECK_EQ(n_embd % tp_size, 0) << "n_embd must be divisible by TP world size."; + CHECK_EQ(n_embd % n_head, 0) << "n_embd must be divisible by n_head."; + CHECK_EQ(n_head % tp_size, 0) << "n_head must be divisible by TP world size."; + + // ========== pp_size:num_stages; vpp_size: num_chunks_per_stage ========== + int pp_size = nn::parallel::global::GetPipelineParallelSize(); + int vpp_size = nn::parallel::global::GetVirtualPipelineParallelSize(); + auto pp_rank = nn::parallel::pp_rank; + auto [is_first_stage, is_last_stage, layer_ranges_per_chunk] + = nn::parallel::PipelineParallel::GetStageInfo(n_layer, pp_size, pp_rank, vpp_size); + // ========== layer to chunk ========== + std::vector owned_layers(n_layer, false); + for (const auto &[start, end] : layer_ranges_per_chunk) { + for (int i = start; i < end; ++i) { owned_layers[i] = true; } + } + + auto tp_rank = nn::parallel::tp_rank; + // calculate xx_size_per_partition + const int64_t vpp = model_vocab_size / tp_size; + const int64_t v_start = static_cast(tp_rank) * vpp; + const int64_t v_end = v_start + vpp; + + const int64_t qkv_out = 3 * n_embd; + const int64_t qkv_pp = qkv_out / tp_size; + const int64_t qkv_start = static_cast(tp_rank) * qkv_pp; + + const int64_t fc_out = 4 * n_embd; + const int64_t fc_pp = fc_out / tp_size; + const int64_t fc_start = static_cast(tp_rank) * fc_pp; + + const int64_t in_pp = n_embd / tp_size; // for c_proj (row-parallel, shard on input) + const int64_t in4_pp = (4 * n_embd) / tp_size; // for mlp.c_proj (input shard) + + auto state_dict = local_gpt2->StateDict(); + + // transformer.wte.weight (also transformer.lm_head.weight) + // full: (model_vocab_size, n_embd) + // local: (vocab_size_per_partition, n_embd) + if (is_first_stage) { + auto &transformer_wte_weight = state_dict[std::format("{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerFirstStage::kWTELayerName, + nn::parallel::VocabParallelEmbedding::kParamWeightName)]; + ReadMatrixRowShardFloat(ifs, static_cast(transformer_wte_weight->DataPtr()), model_vocab_size, n_embd, + v_start, vpp); + } else if (pp_size > 1 && is_last_stage) { + auto &lm_head_weight = state_dict[std::format("{}.{}", nn::TransformerLastStage::kLMHeadLayerName, + nn::parallel::ColumnParallelLinear::kParamWeightName)]; + ReadMatrixRowShardFloat(ifs, static_cast(lm_head_weight->DataPtr()), model_vocab_size, n_embd, v_start, + vpp); + } else { + size_t wte_bytes = model_vocab_size * n_embd * sizeof(float); + ifs.seekg(wte_bytes, std::ios::cur); + } + + if (tp_size == 1) { + // Skip padded vocab part when TP is not enabled + ifs.ignore((padded_vocab_size - model_vocab_size) * n_embd * sizeof(float)); + } + + if (is_first_stage) { + // transformer.wpe.weight + auto &transformer_wpe_weight + = state_dict[std::format("{}.{}.{}", 
nn::TransformerModel::kTransformerModelName, + nn::TransformerFirstStage::kWPELayerName, nn::Embedding::kParamWeightName)]; + ReadMatrixAllFloat(ifs, static_cast(transformer_wpe_weight->DataPtr()), block_size, n_embd); + } else { + size_t wpe_bytes = block_size * n_embd * sizeof(float); + ifs.seekg(wpe_bytes, std::ios::cur); + } + + // transformer.h.{i}.ln_1.weight + int local_layer_index = 0; + for (int idx = 0; idx < n_layer; ++idx) { + if (owned_layers[idx]) { + auto &tensor + = state_dict[std::format("{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), + nn::TransformerLayer::kLn1LayerName, nn::LayerNorm::kParamWeightName)]; + ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + ++local_layer_index; + } else { + size_t ln_1_w_bytes = n_embd * sizeof(float); + ifs.seekg(ln_1_w_bytes, std::ios::cur); + } + } + + // transformer.h.{i}.ln_1.bias + local_layer_index = 0; + for (int idx = 0; idx < n_layer; ++idx) { + if (owned_layers[idx]) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), + nn::TransformerLayer::kLn1LayerName, nn::LayerNorm::kParamBiasName)]; + ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + ++local_layer_index; + } else { + size_t ln_1_b_bytes = n_embd * sizeof(float); + ifs.seekg(ln_1_b_bytes, std::ios::cur); + } + } + + // transformer.h.{i}.attn.c_attn.weight (ColumnParallelLinear, but actually applies on "rows") + local_layer_index = 0; + for (int idx = 0; idx < n_layer; ++idx) { + if (owned_layers[idx]) { + auto &tensor = state_dict[std::format( + "{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, nn::TransformerChunk::kHLayerName, + std::to_string(local_layer_index), nn::TransformerLayer::kAttnLayerName, + nn::CausalSelfAttention::kCAttnLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; + // NOTE(zbl): In the .bin model file, Q/K/V is concated along last dim, + // i.e. 
[Q|K|V].T = [q1|q2|...|qn|k1|k2|...|kn|v1|v2|...|vn].T + // However, each tp_rank needs to get [q_i|k_i|v_i].T, so we need to jump and read them + // respectively + float *dst = static_cast(tensor->DataPtr()); + const int64_t local_C = n_embd / tp_size; + const int64_t rows_all = 3 * n_embd; + const int64_t cols_all = n_embd; + const std::streampos base_pos = ifs.tellg(); + // Read q_i -> write to dst rows of [0 : local_C) + ifs.seekg(base_pos); + ReadMatrixRowShardFloat(ifs, + /*dst=*/dst + (0 * local_C) * cols_all, + /*rows=*/rows_all, /*cols=*/cols_all, + /*row_start=*/tp_rank * local_C, /*row_cnt=*/local_C); + // Read k_i -> write to dst rows of [local_C : 2*local_C) + ifs.seekg(base_pos); + ReadMatrixRowShardFloat(ifs, + /*dst=*/dst + (1 * local_C) * cols_all, + /*rows=*/rows_all, /*cols=*/cols_all, + /*row_start=*/n_embd + tp_rank * local_C, /*row_cnt=*/local_C); + // Read v_i -> write to dst rows of [2*local_C : 3*local_C) + ifs.seekg(base_pos); + ReadMatrixRowShardFloat(ifs, + /*dst=*/dst + (2 * local_C) * cols_all, + /*rows=*/rows_all, /*cols=*/cols_all, + /*row_start=*/2 * n_embd + tp_rank * local_C, /*row_cnt=*/local_C); + + ++local_layer_index; + } else { + size_t c_attn_w_bytes = qkv_out * n_embd * sizeof(float); + ifs.seekg(c_attn_w_bytes, std::ios::cur); + } + } + + // transformer.h.{i}.attn.c_attn.bias (ColumnParallelLinear) + local_layer_index = 0; + for (int idx = 0; idx < n_layer; ++idx) { + if (owned_layers[idx]) { + auto &tensor = state_dict[std::format( + "{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, nn::TransformerChunk::kHLayerName, + std::to_string(local_layer_index), nn::TransformerLayer::kAttnLayerName, + nn::CausalSelfAttention::kCAttnLayerName, nn::parallel::ColumnParallelLinear::kParamBiasName)]; + // NOTE(zbl): Same as c_attn.weight, the bias for Q/K/V is concated + // i.e. 
[Q|K|V] = [q1|q2|...|qn|k1|k2|...|kn|v1|v2|...|vn] + // However, each tp_rank needs to get [q_i|k_i|v_i], so we need to jump and read them + // respectively + float *dst = static_cast(tensor->DataPtr()); + const int64_t local_C = n_embd / tp_size; + const int64_t len_all = 3 * n_embd; + const std::streampos base_pos = ifs.tellg(); + // Read q_i + ifs.seekg(base_pos); + ReadVectorShardFloat(ifs, + /*dst=*/dst + (0 * local_C), + /*len=*/len_all, + /*start=*/tp_rank * local_C, /*cnt=*/local_C); + // Read k_i + ifs.seekg(base_pos); + ReadVectorShardFloat(ifs, + /*dst=*/dst + (1 * local_C), + /*len=*/len_all, + /*start=*/n_embd + tp_rank * local_C, /*cnt=*/local_C); + // Read v_i + ifs.seekg(base_pos); + ReadVectorShardFloat(ifs, + /*dst=*/dst + (2 * local_C), + /*len=*/len_all, + /*start=*/2 * n_embd + tp_rank * local_C, /*cnt=*/local_C); + + ++local_layer_index; + } else { + size_t c_attn_b_bytes = qkv_out * sizeof(float); + ifs.seekg(c_attn_b_bytes, std::ios::cur); + } + } + + // transformer.h.{i}.attn.c_proj.weight (RowParallelLinear, but actually applies on "columns") + local_layer_index = 0; + for (int idx = 0; idx < n_layer; ++idx) { + if (owned_layers[idx]) { + auto &tensor = state_dict[std::format( + "{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, nn::TransformerChunk::kHLayerName, + std::to_string(local_layer_index), nn::TransformerLayer::kAttnLayerName, + nn::CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)]; + ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), n_embd, n_embd, tp_rank * in_pp, + in_pp); + ++local_layer_index; + } else { + size_t c_proj_w_bytes = n_embd * n_embd * sizeof(float); + ifs.seekg(c_proj_w_bytes, std::ios::cur); + } + } + + // transformer.h.{i}.attn.c_proj.bias (RowParallelLinear, no shard on bias) + local_layer_index = 0; + for (int idx = 0; idx < n_layer; ++idx) { + if (owned_layers[idx]) { + auto &tensor = state_dict[std::format( + "{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, nn::TransformerChunk::kHLayerName, + std::to_string(local_layer_index), nn::TransformerLayer::kAttnLayerName, + nn::CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamBiasName)]; + ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + ++local_layer_index; + } else { + size_t c_proj_b_bytes = n_embd * sizeof(float); + ifs.seekg(c_proj_b_bytes, std::ios::cur); + } + } + + // transformer.h.{i}.ln_2.weight + local_layer_index = 0; + for (int idx = 0; idx < n_layer; ++idx) { + if (owned_layers[idx]) { + auto &tensor + = state_dict[std::format("{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), + nn::TransformerLayer::kLn2LayerName, nn::LayerNorm::kParamWeightName)]; + ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + ++local_layer_index; + } else { + size_t ln_2_w_bytes = n_embd * sizeof(float); + ifs.seekg(ln_2_w_bytes, std::ios::cur); + } + } + + // transformer.h.{i}.ln_2.bias + local_layer_index = 0; + for (int idx = 0; idx < n_layer; ++idx) { + if (owned_layers[idx]) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), + nn::TransformerLayer::kLn2LayerName, nn::LayerNorm::kParamBiasName)]; + ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + ++local_layer_index; + } else { + size_t ln_2_b_bytes = n_embd * sizeof(float); + 
ifs.seekg(ln_2_b_bytes, std::ios::cur); + } + } + + // transformer.h.{i}.mlp.c_fc.weight (ColumnParallelLinear, but actually applies on "rows") + local_layer_index = 0; + for (int idx = 0; idx < n_layer; ++idx) { + if (owned_layers[idx]) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), + nn::TransformerLayer::kMlpLayerName, nn::MLP::kCFcLayerName, + nn::parallel::ColumnParallelLinear::kParamWeightName)]; + ReadMatrixRowShardFloat(ifs, static_cast(tensor->DataPtr()), fc_out, n_embd, fc_start, fc_pp); + ++local_layer_index; + } else { + size_t c_fc_w_bytes = fc_out * n_embd * sizeof(float); + ifs.seekg(c_fc_w_bytes, std::ios::cur); + } + } + + // transformer.h.{i}.mlp.c_fc.bias (ColumnParallelLinear) + local_layer_index = 0; + for (int idx = 0; idx < n_layer; ++idx) { + if (owned_layers[idx]) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), + nn::TransformerLayer::kMlpLayerName, nn::MLP::kCFcLayerName, + nn::parallel::ColumnParallelLinear::kParamBiasName)]; + ReadVectorShardFloat(ifs, static_cast(tensor->DataPtr()), fc_out, fc_start, fc_pp); + ++local_layer_index; + } else { + size_t c_fc_b_bytes = fc_out * sizeof(float); + ifs.seekg(c_fc_b_bytes, std::ios::cur); + } + } + + // transformer.h.{i}.mlp.c_proj.weight (RowParallelLinear, but actually applies on "columns") + local_layer_index = 0; + for (int idx = 0; idx < n_layer; ++idx) { + if (owned_layers[idx]) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), + nn::TransformerLayer::kMlpLayerName, nn::MLP::kCProjLayerName, + nn::parallel::RowParallelLinear::kParamWeightName)]; + ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), n_embd, fc_out, tp_rank * in4_pp, + in4_pp); + ++local_layer_index; + } else { + size_t c_proj_w_bytes = fc_out * n_embd * sizeof(float); + ifs.seekg(c_proj_w_bytes, std::ios::cur); + } + } + + // transformer.h.{i}.mlp.c_proj.bias (RowParallelLinear, no shard on bias) + local_layer_index = 0; + for (int idx = 0; idx < n_layer; ++idx) { + if (owned_layers[idx]) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), + nn::TransformerLayer::kMlpLayerName, nn::MLP::kCProjLayerName, + nn::parallel::RowParallelLinear::kParamBiasName)]; + ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + ++local_layer_index; + } else { + size_t c_proj_b_bytes = n_embd * sizeof(float); + ifs.seekg(c_proj_b_bytes, std::ios::cur); + } + } + + if (is_last_stage) { + // transformer.ln_f.weight + auto &transformer_ln_f_weight + = state_dict[std::format("{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerLastStage::kLnFLayerName, nn::LayerNorm::kParamWeightName)]; + ReadVectorAllFloat(ifs, static_cast(transformer_ln_f_weight->DataPtr()), n_embd); + // transformer.ln_f.bias + auto &transformer_ln_f_bias + = state_dict[std::format("{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerLastStage::kLnFLayerName, nn::LayerNorm::kParamBiasName)]; + ReadVectorAllFloat(ifs, static_cast(transformer_ln_f_bias->DataPtr()), n_embd); + } else { + size_t ln_f_w_bytes = n_embd * 
sizeof(float); + size_t ln_f_b_bytes = n_embd * sizeof(float); + ifs.seekg(ln_f_w_bytes + ln_f_b_bytes, std::ios::cur); + } + + return local_gpt2; +} + +void SaveAsLLMC(const std::shared_ptr &model, const std::string &filepath) { + CHECK_EQ(nn::parallel::global::GetTensorParallelSize(), 1) << "SaveAsLLMC currently supports TP=1 only."; + CHECK_EQ(nn::parallel::global::GetPipelineParallelSize(), 1) << "SaveAsLLMC currently supports PP=1 only."; + + std::ofstream ofs(filepath, std::ios::binary); + CHECK(ofs.is_open()) << "Failed to open model file for write: " << filepath; + + auto config = model->Config(); + std::vector header(256, 0); + header[0] = kGPT2Magic; + header[1] = kGPT2FP32Version; + header[2] = static_cast(config.block_size); + header[3] = static_cast(config.original_vocab_size); + header[4] = static_cast(config.n_layer); + header[5] = static_cast(config.n_head); + header[6] = static_cast(config.n_embd); + header[7] = static_cast(config.vocab_size); + ofs.write(reinterpret_cast(header.data()), + static_cast(header.size() * sizeof(int32_t))); + + const auto state_dict = model->StateDict(); + auto get_tensor = [&](const std::string &name) -> std::shared_ptr { + CHECK(state_dict.contains(name)) << "Missing tensor in GPT2 state_dict: " << name; + return state_dict.at(name); + }; + + auto write_tensor_fp32 = [&](const std::shared_ptr &tensor) { + Tensor cpu = tensor->To(Device()); + if (cpu.Dtype() != DataType::kFLOAT32) { + cpu = cpu.To(DataType::kFLOAT32); + } + const auto bytes = static_cast(cpu.SizeInBytes()); + ofs.write(reinterpret_cast(cpu.DataPtr()), bytes); + }; + + // transformer.wte.weight + write_tensor_fp32(get_tensor(std::format("{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerFirstStage::kWTELayerName, + nn::parallel::VocabParallelEmbedding::kParamWeightName))); + + // transformer.wpe.weight + write_tensor_fp32( + get_tensor(std::format("{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerFirstStage::kWPELayerName, nn::Embedding::kParamWeightName))); + + for (int idx = 0; idx < config.n_layer; ++idx) { + write_tensor_fp32(get_tensor(std::format( + "{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, nn::TransformerChunk::kHLayerName, idx, + nn::TransformerLayer::kLn1LayerName, nn::LayerNorm::kParamWeightName))); + } + for (int idx = 0; idx < config.n_layer; ++idx) { + write_tensor_fp32(get_tensor(std::format("{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerChunk::kHLayerName, idx, + nn::TransformerLayer::kLn1LayerName, nn::LayerNorm::kParamBiasName))); + } + for (int idx = 0; idx < config.n_layer; ++idx) { + write_tensor_fp32(get_tensor(std::format( + "{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, nn::TransformerChunk::kHLayerName, idx, + nn::TransformerLayer::kAttnLayerName, nn::CausalSelfAttention::kCAttnLayerName, + nn::parallel::ColumnParallelLinear::kParamWeightName))); + } + for (int idx = 0; idx < config.n_layer; ++idx) { + write_tensor_fp32(get_tensor( + std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerChunk::kHLayerName, idx, nn::TransformerLayer::kAttnLayerName, + nn::CausalSelfAttention::kCAttnLayerName, nn::parallel::ColumnParallelLinear::kParamBiasName))); + } + for (int idx = 0; idx < config.n_layer; ++idx) { + write_tensor_fp32(get_tensor( + std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerChunk::kHLayerName, idx, nn::TransformerLayer::kAttnLayerName, + 
nn::CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName))); + } + for (int idx = 0; idx < config.n_layer; ++idx) { + write_tensor_fp32(get_tensor( + std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerChunk::kHLayerName, idx, nn::TransformerLayer::kAttnLayerName, + nn::CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamBiasName))); + } + for (int idx = 0; idx < config.n_layer; ++idx) { + write_tensor_fp32(get_tensor(std::format( + "{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, nn::TransformerChunk::kHLayerName, idx, + nn::TransformerLayer::kLn2LayerName, nn::LayerNorm::kParamWeightName))); + } + for (int idx = 0; idx < config.n_layer; ++idx) { + write_tensor_fp32(get_tensor(std::format("{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerChunk::kHLayerName, idx, + nn::TransformerLayer::kLn2LayerName, nn::LayerNorm::kParamBiasName))); + } + for (int idx = 0; idx < config.n_layer; ++idx) { + write_tensor_fp32( + get_tensor(std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerChunk::kHLayerName, idx, nn::TransformerLayer::kMlpLayerName, + nn::MLP::kCFcLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName))); + } + for (int idx = 0; idx < config.n_layer; ++idx) { + write_tensor_fp32( + get_tensor(std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerChunk::kHLayerName, idx, nn::TransformerLayer::kMlpLayerName, + nn::MLP::kCFcLayerName, nn::parallel::ColumnParallelLinear::kParamBiasName))); + } + for (int idx = 0; idx < config.n_layer; ++idx) { + write_tensor_fp32( + get_tensor(std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerChunk::kHLayerName, idx, nn::TransformerLayer::kMlpLayerName, + nn::MLP::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName))); + } + for (int idx = 0; idx < config.n_layer; ++idx) { + write_tensor_fp32( + get_tensor(std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerChunk::kHLayerName, idx, nn::TransformerLayer::kMlpLayerName, + nn::MLP::kCProjLayerName, nn::parallel::RowParallelLinear::kParamBiasName))); + } + + write_tensor_fp32( + get_tensor(std::format("{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerLastStage::kLnFLayerName, nn::LayerNorm::kParamWeightName))); + write_tensor_fp32(get_tensor(std::format("{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerLastStage::kLnFLayerName, nn::LayerNorm::kParamBiasName))); + + ofs.flush(); + CHECK(ofs.good()) << "Failed to flush model file: " << filepath; +} +} // namespace gpt2 + +namespace llama3 { + +std::shared_ptr LoadFromLLMC(const std::string &filepath) { + if (!std::filesystem::exists(filepath)) { + LOG(FATAL) << "File not found: " << filepath; + } + + std::ifstream ifs(filepath, std::ios::binary); + const auto header = ReadSeveralBytesFromIfstream(256 * sizeof(int32_t), &ifs); + + const auto magic = BytesToType(header, 0); + CHECK_EQ(magic, kLLaMA3Magic); + const auto version = BytesToType(header, 4); + CHECK_EQ(version, kLLaMA3FP32Version); + + const auto block_size = BytesToType(header, 8); + const auto vocab_size = BytesToType(header, 12); + const auto n_layer = BytesToType(header, 16); + const auto n_head = BytesToType(header, 20); + const auto n_kv_head = BytesToType(header, 24); + const auto n_embd = BytesToType(header, 28); + 
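+    // NOTE: every LLaMA3 header slot is 4 bytes. The float-valued fields read
+    // next (ffn_dim_multiplier, norm_eps, rope_theta) are stored bit-for-bit
+    // inside an int32 slot; pack_float in SaveAsLLMC further down mirrors this
+    // layout by std::memcpy-ing the float bits into an int32 before writing.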
const auto ffn_dim_multiplier = BytesToType(header, 32); + const auto multiple_of = BytesToType(header, 36); + const auto norm_eps = BytesToType(header, 40); + const auto rope_theta = BytesToType(header, 44); + const auto use_scaled_rope = BytesToType(header, 48); + const auto max_gen_bs = BytesToType(header, 52); + const auto version_major = BytesToType(header, 56); + const auto version_minor = BytesToType(header, 60); + + nn::TransformerConfig llama3_config = infini_train::llama3::LLaMA3Config(); + llama3_config.block_size = block_size; + llama3_config.vocab_size = vocab_size; + llama3_config.n_layer = n_layer; + llama3_config.n_head = n_head; + llama3_config.n_kv_head = n_kv_head; + llama3_config.n_embd = n_embd; + llama3_config.ffn_dim_multiplier = ffn_dim_multiplier; + llama3_config.multiple_of = multiple_of; + llama3_config.rope_theta = rope_theta; + llama3_config.use_scaled_rope = static_cast(use_scaled_rope); + llama3_config.norm_eps = norm_eps; + llama3_config.max_gen_batch_size = max_gen_bs; + auto llama3 = std::make_shared(llama3_config); + + // ========== pp_size:num_stages; vpp_size: num_chunks_per_stage ========== + int pp_size = nn::parallel::global::GetPipelineParallelSize(); + int vpp_size = nn::parallel::global::GetVirtualPipelineParallelSize(); + auto pp_rank = nn::parallel::pp_rank; + auto [is_first_stage, is_last_stage, layer_ranges_per_chunk] + = nn::parallel::PipelineParallel::GetStageInfo(n_layer, pp_size, pp_rank, vpp_size); + // ========== layer to chunk ========== + std::vector owned_layers(n_layer, false); + for (const auto &[start, end] : layer_ranges_per_chunk) { + for (int i = start; i < end; ++i) { owned_layers[i] = true; } + } + + const int tp_size = nn::parallel::global::GetTensorParallelSize(); + const int tp_rank = nn::parallel::tp_rank; + + CHECK_EQ(n_embd % tp_size, 0) << "n_embd must be divisible by TP world size."; + CHECK_EQ(n_head % tp_size, 0) << "n_head must be divisible by TP world size."; + CHECK_EQ(n_kv_head % tp_size, 0) << "n_kv_head must be divisible by TP world size."; + CHECK_EQ(vocab_size % tp_size, 0) << "vocab_size must be divisible by TP world size."; + + if (tp_rank == 0) { + LOG(INFO) << "Model Config:"; + LOG(INFO) << " block_size = " << block_size; + LOG(INFO) << " vocab_size = " << vocab_size; + LOG(INFO) << " n_layer = " << n_layer; + LOG(INFO) << " n_head = " << n_head; + LOG(INFO) << " n_kv_head = " << n_kv_head; + LOG(INFO) << " n_embd = " << n_embd; + LOG(INFO) << " ffn_dim_multiplier = " << ffn_dim_multiplier; + LOG(INFO) << " multiple_of = " << multiple_of; + LOG(INFO) << " norm_eps = " << norm_eps; + LOG(INFO) << " rope_theta = " << rope_theta; + LOG(INFO) << " use_scaled_rope = " << use_scaled_rope; + LOG(INFO) << " max_gen_bs = " << max_gen_bs; + LOG(INFO) << " version_major = " << version_major; + LOG(INFO) << " version_minor = " << version_minor; + + LOG(INFO) << "Pipeline Parallel Chunks:"; + for (size_t i = 0; i < layer_ranges_per_chunk.size(); ++i) { + LOG(INFO) << " Chunk " << i << ": layers " << layer_ranges_per_chunk[i].first << " to " + << layer_ranges_per_chunk[i].second; + } + } + + const int64_t head_dim = static_cast(n_embd) / static_cast(n_head); + + // nn::MLP hidden dim calculation in LLaMA-3 + auto round_up_to = [](int64_t x, int64_t m) { return (x + m - 1) / m * m; }; + int64_t hidden_dim = 4LL * static_cast(n_embd); + hidden_dim = (2LL * hidden_dim) / 3LL; + if (ffn_dim_multiplier > 0.0f) { + hidden_dim = static_cast( + std::llround(static_cast(ffn_dim_multiplier) * static_cast(hidden_dim))); + } + + 
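+    // Worked example with illustrative LLaMA-3-8B-style values (assumed here
+    // for clarity, not read from this checkpoint): n_embd = 4096,
+    // ffn_dim_multiplier = 1.3, multiple_of = 1024:
+    //   4 * 4096 = 16384; (2 * 16384) / 3 = 10922 (integer division);
+    //   llround(1.3 * 10922) = 14199; round_up_to(14199, 1024) = 14336,
+    // which becomes the final ffn_hidden computed below.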
int64_t ffn_hidden = round_up_to(hidden_dim, static_cast(multiple_of)); + + // ===== Per-rank sizes / offsets ===== + // vocab parallel + const int64_t vpp = static_cast(vocab_size) / tp_size; + const int64_t v_start = static_cast(tp_rank) * vpp; + + // attention Q/K/V packed as rows: [Q | K | V] + const int64_t q_out_rows = static_cast(n_embd); + const int64_t kv_out_rows = static_cast(n_kv_head) * head_dim; // for K or V (each) + const int64_t attn_rows_all = q_out_rows + 2 * kv_out_rows; + const int64_t attn_cols = static_cast(n_embd); + + // local Q/K/V rows per tp_rank + const int64_t q_local_rows = static_cast(n_embd) / tp_size; // = (n_head/world)*head_dim + const int64_t kv_head_local = static_cast(n_kv_head) / tp_size; + const int64_t kv_local_rows = kv_head_local * head_dim; // for K or V (each) + const int64_t attn_local_rows = q_local_rows + 2 * kv_local_rows; + + // RowParallel (proj) + const int64_t in_pp = static_cast(n_embd) / tp_size; + // nn::MLP: c_fc/c_fc2(shard along row),c_proj(shard along col) + const int64_t fc_out = ffn_hidden; + const int64_t fc_pp = fc_out / tp_size; + const int64_t in_fc_pp = ffn_hidden / tp_size; + + auto state_dict = llama3->StateDict(); + + // ========== Read Sharded Params ========== + // transformer.wte.weight : (vocab_size, n_embd) -> local tp_rank: rows of [v_start : v_start+vpp) + if (is_first_stage) { + auto &wte = state_dict[std::format("{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerFirstStage::kWTELayerName, + nn::parallel::VocabParallelEmbedding::kParamWeightName)]; + ReadMatrixRowShardFloat(ifs, static_cast(wte->DataPtr()), + /*rows=*/vocab_size, /*cols=*/n_embd, + /*row_start=*/v_start, /*row_cnt=*/vpp); + } else { + size_t wte_bytes = static_cast(vocab_size) * n_embd * sizeof(float); + ifs.seekg(wte_bytes, std::ios::cur); + } + + // transformer.h.{i}.ln_1.weight : Full version nn::RMSNorm + int local_layer_index = 0; + for (int i = 0; i < static_cast(n_layer); ++i) { + if (owned_layers[i]) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), + nn::TransformerLayer::kLn1LayerName, nn::RMSNorm::kParamWeightName)]; + ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + ++local_layer_index; + } else { + size_t ln_1_bytes = n_embd * sizeof(float); + ifs.seekg(ln_1_bytes, std::ios::cur); + } + } + + // transformer.h.{i}.attn.c_attn.weight : ColumnParallelLinear, but actually applies on "rows" + // W-qkv should be [Q(=n_embd) | K(=n_kv_head*head_dim) | V(=n_kv_head*head_dim)] × n_embd + local_layer_index = 0; + for (int i = 0; i < static_cast(n_layer); ++i) { + if (owned_layers[i]) { + auto &tensor = state_dict[std::format( + "{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, nn::TransformerChunk::kHLayerName, + std::to_string(local_layer_index), nn::TransformerLayer::kAttnLayerName, + nn::CausalSelfAttention::kCAttnLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; + + float *dst = static_cast(tensor->DataPtr()); + const std::streampos base_pos = ifs.tellg(); + + // Q block -> [0 : q_local_rows) + ifs.seekg(base_pos); + ReadMatrixRowShardFloat(ifs, + /*dst=*/dst + (0 * attn_cols), + /*rows=*/attn_rows_all, /*cols=*/attn_cols, + /*row_start=*/tp_rank * q_local_rows, /*row_cnt=*/q_local_rows); + + // K block -> [q_local_rows : q_local_rows + kv_local_rows) + ifs.seekg(base_pos); + ReadMatrixRowShardFloat(ifs, + /*dst=*/dst + (q_local_rows * 
attn_cols), + /*rows=*/attn_rows_all, /*cols=*/attn_cols, + /*row_start=*/q_out_rows + tp_rank * kv_local_rows, /*row_cnt=*/kv_local_rows); + + // V block -> [q_local_rows + kv_local_rows : q_local_rows + 2*kv_local_rows) + ifs.seekg(base_pos); + ReadMatrixRowShardFloat(ifs, + /*dst=*/dst + ((q_local_rows + kv_local_rows) * attn_cols), + /*rows=*/attn_rows_all, /*cols=*/attn_cols, + /*row_start=*/q_out_rows + kv_out_rows + tp_rank * kv_local_rows, + /*row_cnt=*/kv_local_rows); + ++local_layer_index; + } else { + size_t qkv_bytes = static_cast(attn_rows_all) * attn_cols * sizeof(float); + ifs.seekg(qkv_bytes, std::ios::cur); + } + } + + // transformer.h.{i}.attn.c_proj.weight : RowParallelLinear, but actually applies on "columns" + local_layer_index = 0; + for (int i = 0; i < static_cast(n_layer); ++i) { + if (owned_layers[i]) { + auto &tensor = state_dict[std::format( + "{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, nn::TransformerChunk::kHLayerName, + std::to_string(local_layer_index), nn::TransformerLayer::kAttnLayerName, + nn::CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)]; + ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), + /*rows=*/n_embd, /*cols=*/n_embd, + /*col_start=*/tp_rank * in_pp, /*col_cnt=*/in_pp); + ++local_layer_index; + } else { + size_t c_proj_bytes = static_cast(n_embd) * n_embd * sizeof(float); + ifs.seekg(c_proj_bytes, std::ios::cur); + } + } + + // transformer.h.{i}.ln_2.weight : Full version RMSNorm + local_layer_index = 0; + for (int i = 0; i < static_cast(n_layer); ++i) { + if (owned_layers[i]) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), + nn::TransformerLayer::kLn2LayerName, nn::RMSNorm::kParamWeightName)]; + ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); + ++local_layer_index; + } else { + size_t ln_2_bytes = static_cast(n_embd) * sizeof(float); + ifs.seekg(ln_2_bytes, std::ios::cur); + } + } + + // transformer.h.{i}.mlp.c_fc.weight : ColumnParallelLinear, but actually applies on "rows" + local_layer_index = 0; + for (int i = 0; i < static_cast(n_layer); ++i) { + if (owned_layers[i]) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), + nn::TransformerLayer::kMlpLayerName, nn::MLP::kCFcLayerName, + nn::parallel::ColumnParallelLinear::kParamWeightName)]; + ReadMatrixRowShardFloat(ifs, static_cast(tensor->DataPtr()), + /*rows=*/fc_out, /*cols=*/n_embd, + /*row_start=*/tp_rank * fc_pp, /*row_cnt=*/fc_pp); + ++local_layer_index; + } else { + size_t fc_bytes = static_cast(ffn_hidden) * n_embd * sizeof(float); + ifs.seekg(fc_bytes, std::ios::cur); + } + } + + // transformer.h.{i}.mlp.c_fc2.weight : ColumnParallelLinear, but actually applies on "rows" + local_layer_index = 0; + for (int i = 0; i < static_cast(n_layer); ++i) { + if (owned_layers[i]) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), + nn::TransformerLayer::kMlpLayerName, nn::MLP::kCFc2LayerName, + nn::parallel::ColumnParallelLinear::kParamWeightName)]; + ReadMatrixRowShardFloat(ifs, static_cast(tensor->DataPtr()), + /*rows=*/fc_out, /*cols=*/n_embd, + /*row_start=*/tp_rank * fc_pp, /*row_cnt=*/fc_pp); + ++local_layer_index; + } 
else { + size_t fc2_bytes = static_cast(ffn_hidden) * n_embd * sizeof(float); + ifs.seekg(fc2_bytes, std::ios::cur); + } + } + + // transformer.h.{i}.mlp.c_proj.weight : RowParallelLinear, but actually applies on "columns" + local_layer_index = 0; + for (int i = 0; i < static_cast(n_layer); ++i) { + if (owned_layers[i]) { + auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), + nn::TransformerLayer::kMlpLayerName, nn::MLP::kCProjLayerName, + nn::parallel::RowParallelLinear::kParamWeightName)]; + ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), + /*rows=*/n_embd, /*cols=*/fc_out, + /*col_start=*/tp_rank * in_fc_pp, /*col_cnt=*/in_fc_pp); + ++local_layer_index; + } else { + size_t c_proj_bytes = static_cast(n_embd) * ffn_hidden * sizeof(float); + ifs.seekg(c_proj_bytes, std::ios::cur); + } + } + + // transformer.ln_f.weight : Full version nn::RMSNorm + // lm_head.weight : (vocab_size, n_embd) -> ColumnParallelLinear, but actually applies on "rows" + { + if (is_last_stage) { + auto &ln_f + = state_dict[std::format("{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerLastStage::kLnFLayerName, nn::RMSNorm::kParamWeightName)]; + auto &lm_head = state_dict[std::format("{}.{}", nn::TransformerLastStage::kLMHeadLayerName, + nn::parallel::ColumnParallelLinear::kParamWeightName)]; + ReadVectorAllFloat(ifs, static_cast(ln_f->DataPtr()), n_embd); + ReadMatrixRowShardFloat(ifs, static_cast(lm_head->DataPtr()), + /*rows=*/vocab_size, /*cols=*/n_embd, + /*row_start=*/v_start, /*row_cnt=*/vpp); + } else { + size_t ln_f_bytes = static_cast(n_embd) * sizeof(float); + size_t lm_head_bytes = static_cast(vocab_size) * n_embd * sizeof(float); + ifs.seekg(ln_f_bytes + lm_head_bytes, std::ios::cur); + } + } + + return llama3; +} + +void SaveAsLLMC(const std::shared_ptr &model, const std::string &filepath) { + CHECK_EQ(nn::parallel::global::GetTensorParallelSize(), 1) << "SaveAsLLMC currently supports TP=1 only."; + CHECK_EQ(nn::parallel::global::GetPipelineParallelSize(), 1) << "SaveAsLLMC currently supports PP=1 only."; + + std::ofstream ofs(filepath, std::ios::binary); + CHECK(ofs.is_open()) << "Failed to open model file for write: " << filepath; + + auto pack_float = [](float value) -> int32_t { + int32_t bits = 0; + std::memcpy(&bits, &value, sizeof(float)); + return bits; + }; + + auto config = model->Config(); + std::vector header(256, 0); + header[0] = kLLaMA3Magic; + header[1] = kLLaMA3FP32Version; + header[2] = static_cast(config.block_size); + header[3] = static_cast(config.vocab_size); + header[4] = static_cast(config.n_layer); + header[5] = static_cast(config.n_head); + header[6] = static_cast(config.n_kv_head); + header[7] = static_cast(config.n_embd); + header[8] = pack_float(config.ffn_dim_multiplier.value_or(0.0f)); + header[9] = static_cast(config.multiple_of); + header[10] = pack_float(config.norm_eps); + header[11] = pack_float(config.rope_theta); + header[12] = static_cast(config.use_scaled_rope ? 
1 : 0); + header[13] = static_cast(config.max_gen_batch_size); + header[14] = 1; // version_major + header[15] = 0; // version_minor + ofs.write(reinterpret_cast(header.data()), + static_cast(header.size() * sizeof(int32_t))); + + const auto state_dict = model->StateDict(); + auto get_tensor = [&](const std::string &name) -> std::shared_ptr { + CHECK(state_dict.contains(name)) << "Missing tensor in LLaMA3 state_dict: " << name; + return state_dict.at(name); + }; + + auto write_tensor_fp32 = [&](const std::shared_ptr &tensor) { + Tensor cpu = tensor->To(Device()); + if (cpu.Dtype() != DataType::kFLOAT32) { + cpu = cpu.To(DataType::kFLOAT32); + } + const auto bytes = static_cast(cpu.SizeInBytes()); + ofs.write(reinterpret_cast(cpu.DataPtr()), bytes); + }; + + write_tensor_fp32(get_tensor(std::format("{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerFirstStage::kWTELayerName, + nn::parallel::VocabParallelEmbedding::kParamWeightName))); + + for (int idx = 0; idx < config.n_layer; ++idx) { + write_tensor_fp32(get_tensor(std::format("{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerChunk::kHLayerName, idx, + nn::TransformerLayer::kLn1LayerName, nn::RMSNorm::kParamWeightName))); + } + for (int idx = 0; idx < config.n_layer; ++idx) { + write_tensor_fp32(get_tensor(std::format( + "{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, nn::TransformerChunk::kHLayerName, idx, + nn::TransformerLayer::kAttnLayerName, nn::CausalSelfAttention::kCAttnLayerName, + nn::parallel::ColumnParallelLinear::kParamWeightName))); + } + for (int idx = 0; idx < config.n_layer; ++idx) { + write_tensor_fp32(get_tensor( + std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerChunk::kHLayerName, idx, nn::TransformerLayer::kAttnLayerName, + nn::CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName))); + } + for (int idx = 0; idx < config.n_layer; ++idx) { + write_tensor_fp32(get_tensor(std::format("{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerChunk::kHLayerName, idx, + nn::TransformerLayer::kLn2LayerName, nn::RMSNorm::kParamWeightName))); + } + for (int idx = 0; idx < config.n_layer; ++idx) { + write_tensor_fp32( + get_tensor(std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerChunk::kHLayerName, idx, nn::TransformerLayer::kMlpLayerName, + nn::MLP::kCFcLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName))); + } + for (int idx = 0; idx < config.n_layer; ++idx) { + write_tensor_fp32( + get_tensor(std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerChunk::kHLayerName, idx, nn::TransformerLayer::kMlpLayerName, + nn::MLP::kCFc2LayerName, nn::parallel::ColumnParallelLinear::kParamWeightName))); + } + for (int idx = 0; idx < config.n_layer; ++idx) { + write_tensor_fp32( + get_tensor(std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerChunk::kHLayerName, idx, nn::TransformerLayer::kMlpLayerName, + nn::MLP::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName))); + } + + write_tensor_fp32(get_tensor(std::format("{}.{}.{}", nn::TransformerModel::kTransformerModelName, + nn::TransformerLastStage::kLnFLayerName, nn::RMSNorm::kParamWeightName))); + write_tensor_fp32(get_tensor(std::format("{}.{}", nn::TransformerLastStage::kLMHeadLayerName, + nn::parallel::ColumnParallelLinear::kParamWeightName))); + + 
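+    // NOTE: the llm.c format carries no per-tensor metadata, so the write
+    // order above must stay in lockstep with the read order in
+    // llama3::LoadFromLLMC; reordering either side silently corrupts a
+    // save/load round trip.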
ofs.flush(); + CHECK(ofs.good()) << "Failed to flush model file: " << filepath; +} +} // namespace llama3 + // namespace infini_train + +ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs &args) { + ResumeFromCheckpointResult result; + int ddp_world_size = nn::parallel::global::GetDataParallelSize(); + + if (args.resume_root.empty()) { + LOG(INFO) << "No checkpoint specified for resume. Starting training from scratch."; + return result; + } + + std::filesystem::path resume_dir = args.resume_root; + if (args.rank.IsParallel()) { + const auto rank_dir = resume_dir / std::format("rank_{:06d}", args.rank.GlobalRank()); + if (std::filesystem::exists(rank_dir)) { + resume_dir = rank_dir; + } + } + + Checkpoint::Load(resume_dir, args.model.get(), args.optimizer.get(), &args.state, args.load_options); + + result.global_step = static_cast(args.state.global_step); + result.best_loss = args.state.best_loss; + if (args.state.data_batch_stride != static_cast(ddp_world_size)) { + LOG(FATAL) << std::format("Checkpoint data_batch_stride {} mismatches current ddp_world_size {}. " + "Proceeding with recorded data_batch_idx {}.", + args.state.data_batch_stride, ddp_world_size, args.state.data_batch_idx); + } + result.data_batch_idx = static_cast(std::max(args.state.data_batch_idx, 0)); + args.train_iter = args.train_loader.IteratorAtBatchIndex(result.data_batch_idx); + if (args.rank.IsMainRank()) { + LOG(INFO) << std::format( + "Resume training from step {} with best_loss {:.6f}, last_lr {:.3e}, data_batch_idx {}", + args.state.global_step, args.state.best_loss, args.state.last_lr, args.state.data_batch_idx); + } + + return result; +} + +void SaveCheckpoint(const SaveCheckpointArgs &args) { + const auto ckpt_start = std::chrono::high_resolution_clock::now(); + + TrainerState state; + state.global_step = args.global_step; + state.data_batch_idx = static_cast(args.data_batch_idx); + state.data_batch_stride = args.ddp_size; + state.best_loss = args.best_loss; + state.last_lr = args.last_lr; + state.optimizer_type = args.optimizer_type; + state.checkpoint_format = args.checkpoint_format; + state.ddp_size = args.ddp_size; + state.tp_size = args.tp_size; + state.sp_size = args.sp_size; + state.pp_size = args.pp_size; + + CheckpointOptions options; + options.format = args.checkpoint_format; + options.save_optimizer_state = args.save_optimizer_state; + options.model_bin_writer = args.model_bin_writer; + Checkpoint::Save(args.save_dir, args.model, args.optimizer, state, options); + + const auto ckpt_end = std::chrono::high_resolution_clock::now(); + const double ckpt_ms = std::chrono::duration(ckpt_end - ckpt_start).count(); + + if (!args.rank.IsMainRank()) { + return; + } + + LOG(INFO) << std::format("Checkpoint saved at: {} ({:.2f} ms)", args.save_dir.string(), ckpt_ms); + + if (!args.prune_step_checkpoints) { + return; + } + + std::vector ckpts; + if (std::filesystem::exists(args.checkpoint_root_dir)) { + for (const auto &entry : std::filesystem::directory_iterator(args.checkpoint_root_dir)) { + if (entry.is_directory() && entry.path().filename().string().starts_with("checkpoint_step_")) { + ckpts.push_back(entry.path()); + } + } + std::sort(ckpts.begin(), ckpts.end()); + while (ckpts.size() > args.max_checkpoint_keep) { + std::filesystem::remove_all(ckpts.front()); + ckpts.erase(ckpts.begin()); + } + } +} +} // namespace infini_train diff --git a/example/common/checkpoint_loader.h b/example/common/checkpoint_loader.h new file mode 100644 index 00000000..ee135171 --- /dev/null +++ 
b/example/common/checkpoint_loader.h @@ -0,0 +1,76 @@ +#pragma once + +#include "infini_train/include/checkpoint.h" +#include "infini_train/include/dataloader.h" +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/parallel/rank.h" +#include "infini_train/include/optimizer.h" + +#include "gflags/gflags.h" + +#include +#include +#include + +#include +#include +#include + +namespace infini_train { +namespace nn { +class TransformerModel; +} + +namespace gpt2 { +std::shared_ptr LoadFromLLMC(const std::string &filepath); +void SaveAsLLMC(const std::shared_ptr &model, const std::string &filepath); +} // namespace gpt2 +namespace llama3 { +std::shared_ptr LoadFromLLMC(const std::string &filepath); +void SaveAsLLMC(const std::shared_ptr &model, const std::string &filepath); +} // namespace llama3 + +struct ResumeFromCheckpointArgs { + fLS::clstring resume_root; + const nn::parallel::Rank &rank; + std::shared_ptr model; + std::shared_ptr optimizer; + DistributedDataLoader &train_loader; + TrainerState &state; + DataLoaderIterator &train_iter; + CheckpointLoadOptions load_options; +}; + +struct ResumeFromCheckpointResult { + int global_step = 0; + float best_loss = std::numeric_limits::infinity(); + size_t data_batch_idx = 0; +}; + +struct SaveCheckpointArgs { + std::filesystem::path save_dir; + int64_t global_step = 0; + size_t data_batch_idx = 0; + float best_loss = std::numeric_limits::infinity(); + double last_lr = 0.0; + std::string optimizer_type; + std::string checkpoint_format = "bin"; + int ddp_size = 1; + int tp_size = 1; + int sp_size = 1; + int pp_size = 1; + bool save_optimizer_state = true; + bool prune_step_checkpoints = false; + std::filesystem::path checkpoint_root_dir; + size_t max_checkpoint_keep = 0; + const nn::parallel::Rank &rank; + const nn::Module &model; + const Optimizer &optimizer; + std::function model_bin_writer; +}; + +ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs &args); + +void SaveCheckpoint(const SaveCheckpointArgs &args); + +} // namespace infini_train diff --git a/example/common/utils.cc b/example/common/utils.cc index 03cc7aa0..28a35172 100644 --- a/example/common/utils.cc +++ b/example/common/utils.cc @@ -1,5 +1,13 @@ #include "example/common/utils.h" +#include +#include + +#include "gflags/gflags.h" +#include "gflags/gflags_declare.h" +#include "glog/logging.h" +#include "infini_train/include/nn/parallel/global.h" + namespace infini_train { float ConvertBF16ToFloat(void *ptr) { @@ -60,5 +68,4 @@ void ReadVectorShardFloat(std::ifstream &ifs, float *dst, int64_t len, int64_t s ifs.read(reinterpret_cast(dst), static_cast(cnt * sizeof(float))); ifs.seekg(base + std::streamoff(len * sizeof(float))); } - } // namespace infini_train diff --git a/example/common/utils.h b/example/common/utils.h index 5bab3e97..ed4a8f7c 100644 --- a/example/common/utils.h +++ b/example/common/utils.h @@ -1,8 +1,20 @@ #pragma once +#include "infini_train/include/checkpoint.h" +#include "infini_train/include/dataloader.h" +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/parallel/rank.h" +#include "infini_train/include/optimizer.h" + +#include "gflags/gflags.h" + #include #include +#include #include +#include +#include +#include #include namespace infini_train { diff --git a/example/gpt2/checkpoint_loader.cc b/example/gpt2/checkpoint_loader.cc deleted file mode 100644 index 57064423..00000000 --- a/example/gpt2/checkpoint_loader.cc +++ /dev/null @@ -1,430 +0,0 @@ -#include 
"example/gpt2/checkpoint_loader.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "glog/logging.h" - -#include "example/common/utils.h" -#include "example/gpt2/config.h" -#include "infini_train/include/nn/modules/normalization.h" -#include "infini_train/include/nn/modules/sparse.h" -#include "infini_train/include/nn/modules/transformer/causal_self_attention.h" -#include "infini_train/include/nn/modules/transformer/mlp.h" -#include "infini_train/include/nn/modules/transformer/transformer.h" -#include "infini_train/include/nn/parallel/global.h" -#include "infini_train/include/nn/parallel/pp/pipeline_parallel.h" -#include "infini_train/include/nn/parallel/tensor_parallel.h" -#include "infini_train/include/tensor.h" - -using namespace infini_train; -namespace nn = infini_train::nn; - -namespace { -constexpr int kRandomSeed = 42; - -// TODO(dcj): make this rng generator compatible with torch later -static std::mt19937 gen{kRandomSeed}; -} // namespace - -namespace { -constexpr int32_t kHeaderMagic = 20240326; -constexpr int32_t kHeaderFP32Version = 3; -constexpr int32_t kHeaderBF16Version = 5; - -std::tuple DetermineAndCheckVersion(const std::vector &header, - size_t offset) { - const auto version = BytesToType(header, offset); - switch (version) { - case kHeaderBF16Version: - return {version, infini_train::DataType::kBFLOAT16}; - case kHeaderFP32Version: - return {version, infini_train::DataType::kFLOAT32}; - default: - LOG(FATAL) << "Unsupported version: " << version << " at " << __FILE__ << ":" << __LINE__; - return {}; // Unreachable, but keeps compiler happy - } -} -} // namespace - -namespace gpt2 { - -std::shared_ptr LoadFromLLMC(const std::string &filepath) { - if (!std::filesystem::exists(filepath)) { - LOG(FATAL) << "File not found: " << filepath; - } - - std::ifstream ifs(filepath, std::ios::binary); - const auto header = ReadSeveralBytesFromIfstream(256 * sizeof(int32_t), &ifs); - - const auto magic = BytesToType(header, 0); - CHECK_EQ(magic, kHeaderMagic); - auto [version, dtype] = DetermineAndCheckVersion(header, 4); - CHECK_EQ(version, kHeaderFP32Version); - - auto tp_size = nn::parallel::global::GetTensorParallelSize(); - - const auto block_size = BytesToType(header, 8); - const auto vocab_size = BytesToType(header, 12); - const auto n_layer = BytesToType(header, 16); - const auto n_head = BytesToType(header, 20); - const auto n_embd = BytesToType(header, 24); - const auto padded_vocab_size = BytesToType(header, 28); - // NOTE(zbl): vocab_size needs to be padded to multiple of TP size - const auto model_vocab_size = tp_size > 1 ? 
padded_vocab_size : vocab_size; - - nn::TransformerConfig gpt2_config = gpt2::GPT2Config(); - gpt2_config.block_size = block_size; - gpt2_config.vocab_size = model_vocab_size; - gpt2_config.original_vocab_size = vocab_size; - gpt2_config.n_layer = n_layer; - gpt2_config.n_head = n_head; - gpt2_config.n_embd = n_embd; - auto local_gpt2 = std::make_shared(gpt2_config); - - LOG(INFO) << "magic: " << magic << " version: " << version << " block_size: " << block_size - << " vocab_size: " << vocab_size << " n_layer: " << n_layer << " n_head: " << n_head - << " n_embd: " << n_embd << " padded_vocab_size: " << padded_vocab_size; - - CHECK_EQ(n_embd % tp_size, 0) << "n_embd must be divisible by TP world size."; - CHECK_EQ(n_embd % n_head, 0) << "n_embd must be divisible by n_head."; - CHECK_EQ(n_head % tp_size, 0) << "n_head must be divisible by TP world size."; - - // ========== pp_size:num_stages; vpp_size: num_chunks_per_stage ========== - int pp_size = nn::parallel::global::GetPipelineParallelSize(); - int vpp_size = nn::parallel::global::GetVirtualPipelineParallelSize(); - auto pp_rank = nn::parallel::pp_rank; - auto [is_first_stage, is_last_stage, layer_ranges_per_chunk] - = nn::parallel::PipelineParallel::GetStageInfo(n_layer, pp_size, pp_rank, vpp_size); - // ========== layer to chunk ========== - std::vector owned_layers(n_layer, false); - for (const auto &[start, end] : layer_ranges_per_chunk) { - for (int i = start; i < end; ++i) { owned_layers[i] = true; } - } - - auto tp_rank = nn::parallel::tp_rank; - // calculate xx_size_per_partition - const int64_t vpp = model_vocab_size / tp_size; - const int64_t v_start = static_cast(tp_rank) * vpp; - const int64_t v_end = v_start + vpp; - - const int64_t qkv_out = 3 * n_embd; - const int64_t qkv_pp = qkv_out / tp_size; - const int64_t qkv_start = static_cast(tp_rank) * qkv_pp; - - const int64_t fc_out = 4 * n_embd; - const int64_t fc_pp = fc_out / tp_size; - const int64_t fc_start = static_cast(tp_rank) * fc_pp; - - const int64_t in_pp = n_embd / tp_size; // for c_proj (row-parallel, shard on input) - const int64_t in4_pp = (4 * n_embd) / tp_size; // for mlp.c_proj (input shard) - - auto state_dict = local_gpt2->StateDict(); - - // transformer.wte.weight (also transformer.lm_head.weight) - // full: (model_vocab_size, n_embd) - // local: (vocab_size_per_partition, n_embd) - if (is_first_stage) { - auto &transformer_wte_weight = state_dict[std::format("{}.{}.{}", nn::TransformerModel::kTransformerModelName, - nn::TransformerFirstStage::kWTELayerName, - nn::parallel::VocabParallelEmbedding::kParamWeightName)]; - ReadMatrixRowShardFloat(ifs, static_cast(transformer_wte_weight->DataPtr()), model_vocab_size, n_embd, - v_start, vpp); - } else if (pp_size > 1 && is_last_stage) { - auto &lm_head_weight = state_dict[std::format("{}.{}", nn::TransformerLastStage::kLMHeadLayerName, - nn::parallel::ColumnParallelLinear::kParamWeightName)]; - ReadMatrixRowShardFloat(ifs, static_cast(lm_head_weight->DataPtr()), model_vocab_size, n_embd, v_start, - vpp); - } else { - size_t wte_bytes = model_vocab_size * n_embd * sizeof(float); - ifs.seekg(wte_bytes, std::ios::cur); - } - - if (tp_size == 1) { - // Skip padded vocab part when TP is not enabled - ifs.ignore((padded_vocab_size - model_vocab_size) * n_embd * sizeof(float)); - } - - if (is_first_stage) { - // transformer.wpe.weight - auto &transformer_wpe_weight - = state_dict[std::format("{}.{}.{}", nn::TransformerModel::kTransformerModelName, - nn::TransformerFirstStage::kWPELayerName, 
nn::Embedding::kParamWeightName)]; - ReadMatrixAllFloat(ifs, static_cast(transformer_wpe_weight->DataPtr()), block_size, n_embd); - } else { - size_t wpe_bytes = block_size * n_embd * sizeof(float); - ifs.seekg(wpe_bytes, std::ios::cur); - } - - // transformer.h.{i}.ln_1.weight - int local_layer_index = 0; - for (int idx = 0; idx < n_layer; ++idx) { - if (owned_layers[idx]) { - auto &tensor - = state_dict[std::format("{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, - nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), - nn::TransformerLayer::kLn1LayerName, nn::LayerNorm::kParamWeightName)]; - ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); - ++local_layer_index; - } else { - size_t ln_1_w_bytes = n_embd * sizeof(float); - ifs.seekg(ln_1_w_bytes, std::ios::cur); - } - } - - // transformer.h.{i}.ln_1.bias - local_layer_index = 0; - for (int idx = 0; idx < n_layer; ++idx) { - if (owned_layers[idx]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, - nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), - nn::TransformerLayer::kLn1LayerName, nn::LayerNorm::kParamBiasName)]; - ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); - ++local_layer_index; - } else { - size_t ln_1_b_bytes = n_embd * sizeof(float); - ifs.seekg(ln_1_b_bytes, std::ios::cur); - } - } - - // transformer.h.{i}.attn.c_attn.weight (ColumnParallelLinear, but actually applies on "rows") - local_layer_index = 0; - for (int idx = 0; idx < n_layer; ++idx) { - if (owned_layers[idx]) { - auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, nn::TransformerChunk::kHLayerName, - std::to_string(local_layer_index), nn::TransformerLayer::kAttnLayerName, - nn::CausalSelfAttention::kCAttnLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; - // NOTE(zbl): In the .bin model file, Q/K/V is concated along last dim, - // i.e. 
[Q|K|V].T = [q1|q2|...|qn|k1|k2|...|kn|v1|v2|...|vn].T - // However, each tp_rank needs to get [q_i|k_i|v_i].T, so we need to jump and read them - // respectively - float *dst = static_cast(tensor->DataPtr()); - const int64_t local_C = n_embd / tp_size; - const int64_t rows_all = 3 * n_embd; - const int64_t cols_all = n_embd; - const std::streampos base_pos = ifs.tellg(); - // Read q_i -> write to dst rows of [0 : local_C) - ifs.seekg(base_pos); - ReadMatrixRowShardFloat(ifs, - /*dst=*/dst + (0 * local_C) * cols_all, - /*rows=*/rows_all, /*cols=*/cols_all, - /*row_start=*/tp_rank * local_C, /*row_cnt=*/local_C); - // Read k_i -> write to dst rows of [local_C : 2*local_C) - ifs.seekg(base_pos); - ReadMatrixRowShardFloat(ifs, - /*dst=*/dst + (1 * local_C) * cols_all, - /*rows=*/rows_all, /*cols=*/cols_all, - /*row_start=*/n_embd + tp_rank * local_C, /*row_cnt=*/local_C); - // Read v_i -> write to dst rows of [2*local_C : 3*local_C) - ifs.seekg(base_pos); - ReadMatrixRowShardFloat(ifs, - /*dst=*/dst + (2 * local_C) * cols_all, - /*rows=*/rows_all, /*cols=*/cols_all, - /*row_start=*/2 * n_embd + tp_rank * local_C, /*row_cnt=*/local_C); - - ++local_layer_index; - } else { - size_t c_attn_w_bytes = qkv_out * n_embd * sizeof(float); - ifs.seekg(c_attn_w_bytes, std::ios::cur); - } - } - - // transformer.h.{i}.attn.c_attn.bias (ColumnParallelLinear) - local_layer_index = 0; - for (int idx = 0; idx < n_layer; ++idx) { - if (owned_layers[idx]) { - auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, nn::TransformerChunk::kHLayerName, - std::to_string(local_layer_index), nn::TransformerLayer::kAttnLayerName, - nn::CausalSelfAttention::kCAttnLayerName, nn::parallel::ColumnParallelLinear::kParamBiasName)]; - // NOTE(zbl): Same as c_attn.weight, the bias for Q/K/V is concated - // i.e. 
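// To make the jump-and-read above concrete, here is the offset arithmetic with
// illustrative GPT-2-small-like numbers (not read from any checkpoint):
//   n_embd = 768, tp_size = 2, tp_rank = 1  =>  local_C = 768 / 2 = 384
// Source rows inside the packed [Q|K|V] matrix (3 * 768 = 2304 rows total):
//   q_1: rows [ 384,  768)   row_start = tp_rank * local_C
//   k_1: rows [1152, 1536)   row_start = n_embd + tp_rank * local_C
//   v_1: rows [1920, 2304)   row_start = 2 * n_embd + tp_rank * local_C
// Destination on this rank is a (3 * local_C, n_embd) block laid out as
//   q_1 -> rows [0, 384), k_1 -> rows [384, 768), v_1 -> rows [768, 1152).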
[Q|K|V] = [q1|q2|...|qn|k1|k2|...|kn|v1|v2|...|vn] - // However, each tp_rank needs to get [q_i|k_i|v_i], so we need to jump and read them - // respectively - float *dst = static_cast(tensor->DataPtr()); - const int64_t local_C = n_embd / tp_size; - const int64_t len_all = 3 * n_embd; - const std::streampos base_pos = ifs.tellg(); - // Read q_i - ifs.seekg(base_pos); - ReadVectorShardFloat(ifs, - /*dst=*/dst + (0 * local_C), - /*len=*/len_all, - /*start=*/tp_rank * local_C, /*cnt=*/local_C); - // Read k_i - ifs.seekg(base_pos); - ReadVectorShardFloat(ifs, - /*dst=*/dst + (1 * local_C), - /*len=*/len_all, - /*start=*/n_embd + tp_rank * local_C, /*cnt=*/local_C); - // Read v_i - ifs.seekg(base_pos); - ReadVectorShardFloat(ifs, - /*dst=*/dst + (2 * local_C), - /*len=*/len_all, - /*start=*/2 * n_embd + tp_rank * local_C, /*cnt=*/local_C); - - ++local_layer_index; - } else { - size_t c_attn_b_bytes = qkv_out * sizeof(float); - ifs.seekg(c_attn_b_bytes, std::ios::cur); - } - } - - // transformer.h.{i}.attn.c_proj.weight (RowParallelLinear, but actually applies on "columns") - local_layer_index = 0; - for (int idx = 0; idx < n_layer; ++idx) { - if (owned_layers[idx]) { - auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, nn::TransformerChunk::kHLayerName, - std::to_string(local_layer_index), nn::TransformerLayer::kAttnLayerName, - nn::CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)]; - ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), n_embd, n_embd, tp_rank * in_pp, - in_pp); - ++local_layer_index; - } else { - size_t c_proj_w_bytes = n_embd * n_embd * sizeof(float); - ifs.seekg(c_proj_w_bytes, std::ios::cur); - } - } - - // transformer.h.{i}.attn.c_proj.bias (RowParallelLinear, no shard on bias) - local_layer_index = 0; - for (int idx = 0; idx < n_layer; ++idx) { - if (owned_layers[idx]) { - auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, nn::TransformerChunk::kHLayerName, - std::to_string(local_layer_index), nn::TransformerLayer::kAttnLayerName, - nn::CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamBiasName)]; - ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); - ++local_layer_index; - } else { - size_t c_proj_b_bytes = n_embd * sizeof(float); - ifs.seekg(c_proj_b_bytes, std::ios::cur); - } - } - - // transformer.h.{i}.ln_2.weight - local_layer_index = 0; - for (int idx = 0; idx < n_layer; ++idx) { - if (owned_layers[idx]) { - auto &tensor - = state_dict[std::format("{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, - nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), - nn::TransformerLayer::kLn2LayerName, nn::LayerNorm::kParamWeightName)]; - ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); - ++local_layer_index; - } else { - size_t ln_2_w_bytes = n_embd * sizeof(float); - ifs.seekg(ln_2_w_bytes, std::ios::cur); - } - } - - // transformer.h.{i}.ln_2.bias - local_layer_index = 0; - for (int idx = 0; idx < n_layer; ++idx) { - if (owned_layers[idx]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, - nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), - nn::TransformerLayer::kLn2LayerName, nn::LayerNorm::kParamBiasName)]; - ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); - ++local_layer_index; - } else { - size_t ln_2_b_bytes = n_embd * sizeof(float); - 
ifs.seekg(ln_2_b_bytes, std::ios::cur); - } - } - - // transformer.h.{i}.mlp.c_fc.weight (ColumnParallelLinear, but actually applies on "rows") - local_layer_index = 0; - for (int idx = 0; idx < n_layer; ++idx) { - if (owned_layers[idx]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, - nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), - nn::TransformerLayer::kMlpLayerName, nn::MLP::kCFcLayerName, - nn::parallel::ColumnParallelLinear::kParamWeightName)]; - ReadMatrixRowShardFloat(ifs, static_cast(tensor->DataPtr()), fc_out, n_embd, fc_start, fc_pp); - ++local_layer_index; - } else { - size_t c_fc_w_bytes = fc_out * n_embd * sizeof(float); - ifs.seekg(c_fc_w_bytes, std::ios::cur); - } - } - - // transformer.h.{i}.mlp.c_fc.bias (ColumnParallelLinear) - local_layer_index = 0; - for (int idx = 0; idx < n_layer; ++idx) { - if (owned_layers[idx]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, - nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), - nn::TransformerLayer::kMlpLayerName, nn::MLP::kCFcLayerName, - nn::parallel::ColumnParallelLinear::kParamBiasName)]; - ReadVectorShardFloat(ifs, static_cast(tensor->DataPtr()), fc_out, fc_start, fc_pp); - ++local_layer_index; - } else { - size_t c_fc_b_bytes = fc_out * sizeof(float); - ifs.seekg(c_fc_b_bytes, std::ios::cur); - } - } - - // transformer.h.{i}.mlp.c_proj.weight (RowParallelLinear, but actually applies on "columns") - local_layer_index = 0; - for (int idx = 0; idx < n_layer; ++idx) { - if (owned_layers[idx]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, - nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), - nn::TransformerLayer::kMlpLayerName, nn::MLP::kCProjLayerName, - nn::parallel::RowParallelLinear::kParamWeightName)]; - ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), n_embd, fc_out, tp_rank * in4_pp, - in4_pp); - ++local_layer_index; - } else { - size_t c_proj_w_bytes = fc_out * n_embd * sizeof(float); - ifs.seekg(c_proj_w_bytes, std::ios::cur); - } - } - - // transformer.h.{i}.mlp.c_proj.bias (RowParallelLinear, no shard on bias) - local_layer_index = 0; - for (int idx = 0; idx < n_layer; ++idx) { - if (owned_layers[idx]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, - nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), - nn::TransformerLayer::kMlpLayerName, nn::MLP::kCProjLayerName, - nn::parallel::RowParallelLinear::kParamBiasName)]; - ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); - ++local_layer_index; - } else { - size_t c_proj_b_bytes = n_embd * sizeof(float); - ifs.seekg(c_proj_b_bytes, std::ios::cur); - } - } - - if (is_last_stage) { - // transformer.ln_f.weight - auto &transformer_ln_f_weight - = state_dict[std::format("{}.{}.{}", nn::TransformerModel::kTransformerModelName, - nn::TransformerLastStage::kLnFLayerName, nn::LayerNorm::kParamWeightName)]; - ReadVectorAllFloat(ifs, static_cast(transformer_ln_f_weight->DataPtr()), n_embd); - // transformer.ln_f.bias - auto &transformer_ln_f_bias - = state_dict[std::format("{}.{}.{}", nn::TransformerModel::kTransformerModelName, - nn::TransformerLastStage::kLnFLayerName, nn::LayerNorm::kParamBiasName)]; - ReadVectorAllFloat(ifs, static_cast(transformer_ln_f_bias->DataPtr()), n_embd); - } else { - size_t ln_f_w_bytes = n_embd * 
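// Worked shard sizes for the MLP reads above, with illustrative numbers
// (n_embd = 768, tp_size = 2):
//   c_fc   (ColumnParallelLinear): full (3072, 768);  fc_pp  = 3072 / 2 = 1536 rows per rank
//   c_fc bias:                     full (3072,);      fc_pp  = 1536 entries per rank
//   c_proj (RowParallelLinear):    full (768, 3072);  in4_pp = 3072 / 2 = 1536 columns per rank
//   c_proj bias: read in full on every rank (no shard); it is applied once after
//   the tensor-parallel reduce, which is the usual Megatron-style convention.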
sizeof(float); - size_t ln_f_b_bytes = n_embd * sizeof(float); - ifs.seekg(ln_f_w_bytes + ln_f_b_bytes, std::ios::cur); - } - - return local_gpt2; -} -} // namespace gpt2 diff --git a/example/gpt2/checkpoint_loader.h b/example/gpt2/checkpoint_loader.h deleted file mode 100644 index e80c356e..00000000 --- a/example/gpt2/checkpoint_loader.h +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once - -#include -#include - -namespace infini_train::nn { -class TransformerModel; -} // namespace infini_train::nn - -namespace gpt2 { -std::shared_ptr LoadFromLLMC(const std::string &filepath); -} // namespace gpt2 diff --git a/example/gpt2/config.h b/example/gpt2/config.h index 978c0ca6..27ab8828 100644 --- a/example/gpt2/config.h +++ b/example/gpt2/config.h @@ -2,7 +2,7 @@ #include "infini_train/include/nn/modules/transformer/transformer_config.h" -namespace nn = infini_train::nn; +namespace infini_train { namespace gpt2 { inline nn::TransformerConfig GPT2Config() { return {.block_size = 1024, @@ -22,5 +22,5 @@ inline nn::TransformerConfig GPT2Config() { .ffn_dim_multiplier = std::nullopt, .multiple_of = 1}; } - } // namespace gpt2 +} // namespace infini_train diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index f69736f5..cd7ae940 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -1,6 +1,9 @@ +#include #include #include +#include #include +#include #include #include #include @@ -10,6 +13,7 @@ #include "glog/logging.h" #include "infini_train/include/autocast.h" +#include "infini_train/include/checkpoint.h" #include "infini_train/include/core/runtime/device_guard.h" #include "infini_train/include/dataloader.h" #include "infini_train/include/device.h" @@ -34,9 +38,9 @@ #include "infini_train/include/utils/precision_check_config.h" #include "infini_train/include/utils/precision_checker.h" +#include "example/common/checkpoint_loader.h" #include "example/common/tiny_shakespeare_dataset.h" #include "example/common/tokenizer.h" -#include "example/gpt2/checkpoint_loader.h" #include "example/gpt2/config.h" // I/O @@ -77,6 +81,15 @@ DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage."); // precision DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)"); +DEFINE_uint32(save_steps, 0, "save checkpoint every N steps; 0 disables saving"); +DEFINE_string(resume_from, "", "checkpoint directory to resume from"); +DEFINE_string(checkpoint_dir, "./checkpoints", "root directory used to store checkpoints"); +DEFINE_uint32(max_checkpoint_keep, 3, "max number of checkpoint steps to keep"); +DEFINE_bool(save_optimizer_state, true, "whether optimizer state is persisted in checkpoints"); +DEFINE_string(checkpoint_format, "pth", + "checkpoint format: bin|pth. " + "'bin' generates model.bin/optimizer.bin (bin supports LLMC model format via callbacks); " + "'pth' generates model.pth/optimizer.pth (native StateDict binary)."); // precision check DEFINE_string( precision_check, "", @@ -192,6 +205,8 @@ void Train(const nn::parallel::Rank &rank) { model_config = kModelToConfigs.at(FLAGS_model); model = std::make_shared(model_config); } + auto llmc_model = std::dynamic_pointer_cast(model); + CHECK(llmc_model != nullptr) << "Failed to cast model to GPT2 for LLMC checkpoint I/O."; model->To(device); @@ -305,6 +320,7 @@ void Train(const nn::parallel::Rank &rank) { } auto train_iter = train_loader.begin(); + size_t saved_data_batch_idx = train_iter.BatchIndex(); std::shared_ptr loss_fn = (tp_world_size > 1) ? 
std::static_pointer_cast( std::make_shared(model_config.original_vocab_size)) @@ -314,9 +330,57 @@ void Train(const nn::parallel::Rank &rank) { auto impl = core::GetDeviceGuardImpl(device.type()); - LOG(INFO) << "start training"; - - for (int step = 0; step < FLAGS_num_iteration + 1; ++step) { + int start_step = 0; + float best_loss = std::numeric_limits::infinity(); + TrainerState state; + CheckpointLoadOptions load_options; + load_options.load_optimizer_state = true; + load_options.model_bin_loader = [](nn::Module *target_model, const std::filesystem::path &model_path) { + auto loaded_model = gpt2::LoadFromLLMC(model_path.string()); + target_model->LoadStateDict(loaded_model->StateDict()); + }; + const auto resume_result = infini_train::ResumeFromCheckpoint({ + .resume_root = FLAGS_resume_from, + .rank = rank, + .model = model, + .optimizer = optimizer, + .train_loader = train_loader, + .state = state, + .train_iter = train_iter, + .load_options = load_options, + }); + start_step = resume_result.global_step; + best_loss = resume_result.best_loss; + saved_data_batch_idx = resume_result.data_batch_idx; + + auto save_checkpoint + = [&](const std::filesystem::path &save_dir, int64_t global_step, bool prune_step_checkpoints) { + infini_train::SaveCheckpoint({ + .save_dir = save_dir, + .global_step = global_step, + .data_batch_idx = saved_data_batch_idx, + .best_loss = best_loss, + .last_lr = FLAGS_learning_rate, + .optimizer_type = "SGD", + .checkpoint_format = FLAGS_checkpoint_format, + .ddp_size = ddp_world_size, + .tp_size = tp_world_size, + .sp_size = sp_world_size, + .pp_size = pp_world_size, + .save_optimizer_state = FLAGS_save_optimizer_state, + .prune_step_checkpoints = prune_step_checkpoints, + .checkpoint_root_dir = FLAGS_checkpoint_dir, + .max_checkpoint_keep = FLAGS_max_checkpoint_keep, + .rank = rank, + .model = *model, + .optimizer = *optimizer, + .model_bin_writer + = [&](const nn::Module &, + const std::filesystem::path &model_path) { gpt2::SaveAsLLMC(llmc_model, model_path.string()); }, + }); + }; + + for (int step = start_step; step < FLAGS_num_iteration + 1; ++step) { // Reset precision check counters at start of each iteration for file overwrite utils::PrecisionChecker::ResetCounters(); @@ -366,6 +430,7 @@ void Train(const nn::parallel::Rank &rank) { // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below // TODO(dcj): support dataloader.reset() later ++train_iter; + saved_data_batch_idx = train_iter.BatchIndex(); x = std::make_shared(x->To(device)); y = std::make_shared(y->To(device)); @@ -396,6 +461,7 @@ void Train(const nn::parallel::Rank &rank) { // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below // TODO(dcj): support dataloader.reset() later ++train_iter; + saved_data_batch_idx = train_iter.BatchIndex(); x = std::make_shared(x->To(device)); y = std::make_shared(y->To(device)); @@ -408,6 +474,8 @@ void Train(const nn::parallel::Rank &rank) { lossf = static_cast(lossf_tensor->To(Device()).DataPtr())[0]; } + best_loss = std::min(best_loss, lossf); + const auto iter_end = std::chrono::high_resolution_clock::now(); const double duration_us = std::chrono::duration(iter_end - iter_start).count(); const double tps = FLAGS_total_batch_size / (duration_us / 1e6); @@ -430,6 +498,15 @@ void Train(const nn::parallel::Rank &rank) { } } } + + if (FLAGS_save_steps > 0 && (step + 1) % FLAGS_save_steps == 0) { + std::filesystem::path step_dir + = 
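// Assuming SaveCheckpoint forwards save_dir to Checkpoint::Save unchanged (the
// checkpoint code appears later in this diff), the save_checkpoint lambda plus
// the per-step trigger below produce a tree like this (paths illustrative):
//   checkpoints/
//     checkpoint_step_000100/
//       rank_000000/          <- extra level only when rank.IsParallel()
//         model.pth           <- or model.bin with --checkpoint_format bin
//         optimizer.pth       <- written when --save_optimizer_state is true
//         trainer_state.json
//     checkpoint_final/       <- written once after the last iteration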
std::filesystem::path(FLAGS_checkpoint_dir) / std::format("checkpoint_step_{:06d}", step + 1); + if (rank.IsParallel()) { + step_dir /= std::format("rank_{:06d}", rank.GlobalRank()); + } + save_checkpoint(step_dir, step + 1, true); + } } // Save LoRA weights if enabled and path specified @@ -438,6 +515,12 @@ void Train(const nn::parallel::Rank &rank) { nn::lora::SaveLoRAWeights(model, FLAGS_lora_save_path); } + std::filesystem::path final_dir = std::filesystem::path(FLAGS_checkpoint_dir) / "checkpoint_final"; + if (rank.IsParallel()) { + final_dir /= std::format("rank_{:06d}", rank.GlobalRank()); + } + save_checkpoint(final_dir, FLAGS_num_iteration, false); + #ifdef PROFILE_MODE Profiler::Instance().Report("gpt2.report", Profiler::SortBy::DeviceTimePercentage); Profiler::Instance().PrintRecords("gpt2.records.log"); diff --git a/example/llama3/checkpoint_loader.cc b/example/llama3/checkpoint_loader.cc deleted file mode 100644 index a31d1748..00000000 --- a/example/llama3/checkpoint_loader.cc +++ /dev/null @@ -1,347 +0,0 @@ -#include "example/llama3/checkpoint_loader.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "glog/logging.h" - -#include "example/common/utils.h" -#include "example/llama3/config.h" -#include "infini_train/include/nn/modules/normalization.h" -#include "infini_train/include/nn/modules/transformer/causal_self_attention.h" -#include "infini_train/include/nn/modules/transformer/mlp.h" -#include "infini_train/include/nn/modules/transformer/transformer.h" -#include "infini_train/include/nn/parallel/global.h" -#include "infini_train/include/nn/parallel/tensor_parallel.h" -#include "infini_train/include/tensor.h" - -using namespace infini_train; -namespace nn = infini_train::nn; - -namespace { -constexpr int kRandomSeed = 42; - -// TODO(zbl): make this rng generator compatible with torch later -static std::mt19937 gen{kRandomSeed}; -} // namespace - -namespace { -constexpr int32_t kLLaMA3Magic = 20240803; -constexpr int32_t kLLaMA3FP32Version = 3; -} // namespace - -namespace llama3 { - -std::shared_ptr LoadFromLLMC(const std::string &filepath) { - if (!std::filesystem::exists(filepath)) { - LOG(FATAL) << "File not found: " << filepath; - } - - std::ifstream ifs(filepath, std::ios::binary); - const auto header = ReadSeveralBytesFromIfstream(256 * sizeof(int32_t), &ifs); - - const auto magic = BytesToType(header, 0); - CHECK_EQ(magic, kLLaMA3Magic); - const auto version = BytesToType(header, 4); - CHECK_EQ(version, kLLaMA3FP32Version); - - const auto block_size = BytesToType(header, 8); - const auto vocab_size = BytesToType(header, 12); - const auto n_layer = BytesToType(header, 16); - const auto n_head = BytesToType(header, 20); - const auto n_kv_head = BytesToType(header, 24); - const auto n_embd = BytesToType(header, 28); - const auto ffn_dim_multiplier = BytesToType(header, 32); - const auto multiple_of = BytesToType(header, 36); - const auto norm_eps = BytesToType(header, 40); - const auto rope_theta = BytesToType(header, 44); - const auto use_scaled_rope = BytesToType(header, 48); - const auto max_gen_bs = BytesToType(header, 52); - const auto version_major = BytesToType(header, 56); - const auto version_minor = BytesToType(header, 60); - - nn::TransformerConfig llama3_config = llama3::LLaMA3Config(); - llama3_config.block_size = block_size; - llama3_config.vocab_size = vocab_size; - llama3_config.n_layer = n_layer; - llama3_config.n_head = n_head; - llama3_config.n_kv_head = n_kv_head; - llama3_config.n_embd = 
n_embd; - llama3_config.ffn_dim_multiplier = ffn_dim_multiplier; - llama3_config.multiple_of = multiple_of; - llama3_config.rope_theta = rope_theta; - llama3_config.use_scaled_rope = static_cast(use_scaled_rope); - llama3_config.norm_eps = norm_eps; - llama3_config.max_gen_batch_size = max_gen_bs; - auto llama3 = std::make_shared(llama3_config); - - // ========== pp_size:num_stages; vpp_size: num_chunks_per_stage ========== - int pp_size = nn::parallel::global::GetPipelineParallelSize(); - int vpp_size = nn::parallel::global::GetVirtualPipelineParallelSize(); - auto pp_rank = nn::parallel::pp_rank; - auto [is_first_stage, is_last_stage, layer_ranges_per_chunk] - = nn::parallel::PipelineParallel::GetStageInfo(n_layer, pp_size, pp_rank, vpp_size); - // ========== layer to chunk ========== - std::vector owned_layers(n_layer, false); - for (const auto &[start, end] : layer_ranges_per_chunk) { - for (int i = start; i < end; ++i) { owned_layers[i] = true; } - } - - const int tp_size = nn::parallel::global::GetTensorParallelSize(); - const int tp_rank = nn::parallel::tp_rank; - - CHECK_EQ(n_embd % tp_size, 0) << "n_embd must be divisible by TP world size."; - CHECK_EQ(n_head % tp_size, 0) << "n_head must be divisible by TP world size."; - CHECK_EQ(n_kv_head % tp_size, 0) << "n_kv_head must be divisible by TP world size."; - CHECK_EQ(vocab_size % tp_size, 0) << "vocab_size must be divisible by TP world size."; - - if (tp_rank == 0) { - LOG(INFO) << "Model Config:"; - LOG(INFO) << " block_size = " << block_size; - LOG(INFO) << " vocab_size = " << vocab_size; - LOG(INFO) << " n_layer = " << n_layer; - LOG(INFO) << " n_head = " << n_head; - LOG(INFO) << " n_kv_head = " << n_kv_head; - LOG(INFO) << " n_embd = " << n_embd; - LOG(INFO) << " ffn_dim_multiplier = " << ffn_dim_multiplier; - LOG(INFO) << " multiple_of = " << multiple_of; - LOG(INFO) << " norm_eps = " << norm_eps; - LOG(INFO) << " rope_theta = " << rope_theta; - LOG(INFO) << " use_scaled_rope = " << use_scaled_rope; - LOG(INFO) << " max_gen_bs = " << max_gen_bs; - LOG(INFO) << " version_major = " << version_major; - LOG(INFO) << " version_minor = " << version_minor; - - LOG(INFO) << "Pipeline Parallel Chunks:"; - for (size_t i = 0; i < layer_ranges_per_chunk.size(); ++i) { - LOG(INFO) << " Chunk " << i << ": layers " << layer_ranges_per_chunk[i].first << " to " - << layer_ranges_per_chunk[i].second; - } - } - - const int64_t head_dim = static_cast(n_embd) / static_cast(n_head); - - // nn::MLP hidden dim calculation in LLaMA-3 - auto round_up_to = [](int64_t x, int64_t m) { return (x + m - 1) / m * m; }; - int64_t hidden_dim = 4LL * static_cast(n_embd); - hidden_dim = (2LL * hidden_dim) / 3LL; - if (ffn_dim_multiplier > 0.0f) { - hidden_dim = static_cast( - std::llround(static_cast(ffn_dim_multiplier) * static_cast(hidden_dim))); - } - - int64_t ffn_hidden = round_up_to(hidden_dim, static_cast(multiple_of)); - - // ===== Per-rank sizes / offsets ===== - // vocab parallel - const int64_t vpp = static_cast(vocab_size) / tp_size; - const int64_t v_start = static_cast(tp_rank) * vpp; - - // attention Q/K/V packed as rows: [Q | K | V] - const int64_t q_out_rows = static_cast(n_embd); - const int64_t kv_out_rows = static_cast(n_kv_head) * head_dim; // for K or V (each) - const int64_t attn_rows_all = q_out_rows + 2 * kv_out_rows; - const int64_t attn_cols = static_cast(n_embd); - - // local Q/K/V rows per tp_rank - const int64_t q_local_rows = static_cast(n_embd) / tp_size; // = (n_head/world)*head_dim - const int64_t kv_head_local = 
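// The hidden-dim recipe above reproduces the usual LLaMA-3 FFN width. A worked
// example with 8B-scale values (illustrative, not read from this checkpoint):
//   n_embd = 4096, ffn_dim_multiplier = 1.3, multiple_of = 1024
//   hidden     = 4 * 4096             = 16384
//   hidden     = (2 * 16384) / 3      = 10922      (integer division)
//   hidden     = llround(1.3 * 10922) = 14199
//   ffn_hidden = round_up_to(14199, 1024) = 14336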
static_cast(n_kv_head) / tp_size; - const int64_t kv_local_rows = kv_head_local * head_dim; // for K or V (each) - const int64_t attn_local_rows = q_local_rows + 2 * kv_local_rows; - - // RowParallel (proj) - const int64_t in_pp = static_cast(n_embd) / tp_size; - // nn::MLP: c_fc/c_fc2(shard along row),c_proj(shard along col) - const int64_t fc_out = ffn_hidden; - const int64_t fc_pp = fc_out / tp_size; - const int64_t in_fc_pp = ffn_hidden / tp_size; - - auto state_dict = llama3->StateDict(); - - // ========== Read Sharded Params ========== - // transformer.wte.weight : (vocab_size, n_embd) -> local tp_rank: rows of [v_start : v_start+vpp) - if (is_first_stage) { - auto &wte = state_dict[std::format("{}.{}.{}", nn::TransformerModel::kTransformerModelName, - nn::TransformerFirstStage::kWTELayerName, - nn::parallel::VocabParallelEmbedding::kParamWeightName)]; - ReadMatrixRowShardFloat(ifs, static_cast(wte->DataPtr()), - /*rows=*/vocab_size, /*cols=*/n_embd, - /*row_start=*/v_start, /*row_cnt=*/vpp); - } else { - size_t wte_bytes = static_cast(vocab_size) * n_embd * sizeof(float); - ifs.seekg(wte_bytes, std::ios::cur); - } - - // transformer.h.{i}.ln_1.weight : Full version nn::RMSNorm - int local_layer_index = 0; - for (int i = 0; i < static_cast(n_layer); ++i) { - if (owned_layers[i]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, - nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), - nn::TransformerLayer::kLn1LayerName, nn::RMSNorm::kParamWeightName)]; - ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); - ++local_layer_index; - } else { - size_t ln_1_bytes = n_embd * sizeof(float); - ifs.seekg(ln_1_bytes, std::ios::cur); - } - } - - // transformer.h.{i}.attn.c_attn.weight : ColumnParallelLinear, but actually applies on "rows" - // W-qkv should be [Q(=n_embd) | K(=n_kv_head*head_dim) | V(=n_kv_head*head_dim)] × n_embd - local_layer_index = 0; - for (int i = 0; i < static_cast(n_layer); ++i) { - if (owned_layers[i]) { - auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, nn::TransformerChunk::kHLayerName, - std::to_string(local_layer_index), nn::TransformerLayer::kAttnLayerName, - nn::CausalSelfAttention::kCAttnLayerName, nn::parallel::ColumnParallelLinear::kParamWeightName)]; - - float *dst = static_cast(tensor->DataPtr()); - const std::streampos base_pos = ifs.tellg(); - - // Q block -> [0 : q_local_rows) - ifs.seekg(base_pos); - ReadMatrixRowShardFloat(ifs, - /*dst=*/dst + (0 * attn_cols), - /*rows=*/attn_rows_all, /*cols=*/attn_cols, - /*row_start=*/tp_rank * q_local_rows, /*row_cnt=*/q_local_rows); - - // K block -> [q_local_rows : q_local_rows + kv_local_rows) - ifs.seekg(base_pos); - ReadMatrixRowShardFloat(ifs, - /*dst=*/dst + (q_local_rows * attn_cols), - /*rows=*/attn_rows_all, /*cols=*/attn_cols, - /*row_start=*/q_out_rows + tp_rank * kv_local_rows, /*row_cnt=*/kv_local_rows); - - // V block -> [q_local_rows + kv_local_rows : q_local_rows + 2*kv_local_rows) - ifs.seekg(base_pos); - ReadMatrixRowShardFloat(ifs, - /*dst=*/dst + ((q_local_rows + kv_local_rows) * attn_cols), - /*rows=*/attn_rows_all, /*cols=*/attn_cols, - /*row_start=*/q_out_rows + kv_out_rows + tp_rank * kv_local_rows, - /*row_cnt=*/kv_local_rows); - ++local_layer_index; - } else { - size_t qkv_bytes = static_cast(attn_rows_all) * attn_cols * sizeof(float); - ifs.seekg(qkv_bytes, std::ios::cur); - } - } - - // transformer.h.{i}.attn.c_proj.weight : RowParallelLinear, but 
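// Worked GQA shard sizes for the reads above, with illustrative numbers
// (n_embd = 4096, n_head = 32, n_kv_head = 8, tp_size = 4):
//   head_dim        = 4096 / 32 = 128
//   q_out_rows      = 4096;  kv_out_rows = 8 * 128 = 1024
//   attn_rows_all   = 4096 + 2 * 1024 = 6144 packed [Q|K|V] rows in the file
//   q_local_rows    = 4096 / 4 = 1024        (8 query heads per rank)
//   kv_local_rows   = (8 / 4) * 128 = 256    (2 KV heads per rank)
//   attn_local_rows = 1024 + 2 * 256 = 1536 rows landing on each rank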
actually applies on "columns" - local_layer_index = 0; - for (int i = 0; i < static_cast(n_layer); ++i) { - if (owned_layers[i]) { - auto &tensor = state_dict[std::format( - "{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, nn::TransformerChunk::kHLayerName, - std::to_string(local_layer_index), nn::TransformerLayer::kAttnLayerName, - nn::CausalSelfAttention::kCProjLayerName, nn::parallel::RowParallelLinear::kParamWeightName)]; - ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), - /*rows=*/n_embd, /*cols=*/n_embd, - /*col_start=*/tp_rank * in_pp, /*col_cnt=*/in_pp); - ++local_layer_index; - } else { - size_t c_proj_bytes = static_cast(n_embd) * n_embd * sizeof(float); - ifs.seekg(c_proj_bytes, std::ios::cur); - } - } - - // transformer.h.{i}.ln_2.weight : Full version RMSNorm - local_layer_index = 0; - for (int i = 0; i < static_cast(n_layer); ++i) { - if (owned_layers[i]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, - nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), - nn::TransformerLayer::kLn2LayerName, nn::RMSNorm::kParamWeightName)]; - ReadVectorAllFloat(ifs, static_cast(tensor->DataPtr()), n_embd); - ++local_layer_index; - } else { - size_t ln_2_bytes = static_cast(n_embd) * sizeof(float); - ifs.seekg(ln_2_bytes, std::ios::cur); - } - } - - // transformer.h.{i}.mlp.c_fc.weight : ColumnParallelLinear, but actually applies on "rows" - local_layer_index = 0; - for (int i = 0; i < static_cast(n_layer); ++i) { - if (owned_layers[i]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, - nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), - nn::TransformerLayer::kMlpLayerName, nn::MLP::kCFcLayerName, - nn::parallel::ColumnParallelLinear::kParamWeightName)]; - ReadMatrixRowShardFloat(ifs, static_cast(tensor->DataPtr()), - /*rows=*/fc_out, /*cols=*/n_embd, - /*row_start=*/tp_rank * fc_pp, /*row_cnt=*/fc_pp); - ++local_layer_index; - } else { - size_t fc_bytes = static_cast(ffn_hidden) * n_embd * sizeof(float); - ifs.seekg(fc_bytes, std::ios::cur); - } - } - - // transformer.h.{i}.mlp.c_fc2.weight : ColumnParallelLinear, but actually applies on "rows" - local_layer_index = 0; - for (int i = 0; i < static_cast(n_layer); ++i) { - if (owned_layers[i]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, - nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), - nn::TransformerLayer::kMlpLayerName, nn::MLP::kCFc2LayerName, - nn::parallel::ColumnParallelLinear::kParamWeightName)]; - ReadMatrixRowShardFloat(ifs, static_cast(tensor->DataPtr()), - /*rows=*/fc_out, /*cols=*/n_embd, - /*row_start=*/tp_rank * fc_pp, /*row_cnt=*/fc_pp); - ++local_layer_index; - } else { - size_t fc2_bytes = static_cast(ffn_hidden) * n_embd * sizeof(float); - ifs.seekg(fc2_bytes, std::ios::cur); - } - } - - // transformer.h.{i}.mlp.c_proj.weight : RowParallelLinear, but actually applies on "columns" - local_layer_index = 0; - for (int i = 0; i < static_cast(n_layer); ++i) { - if (owned_layers[i]) { - auto &tensor = state_dict[std::format("{}.{}.{}.{}.{}.{}", nn::TransformerModel::kTransformerModelName, - nn::TransformerChunk::kHLayerName, std::to_string(local_layer_index), - nn::TransformerLayer::kMlpLayerName, nn::MLP::kCProjLayerName, - nn::parallel::RowParallelLinear::kParamWeightName)]; - ReadMatrixColShardFloat(ifs, static_cast(tensor->DataPtr()), - /*rows=*/n_embd, 
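// Why the MLP carries two column-parallel fc weights: the LLaMA MLP is gated.
// The module names come from this diff; the silu composition below is the
// standard SwiGLU formulation and is an assumption here, not shown in the diff:
//   y = c_proj( silu(c_fc(x)) * c_fc2(x) )
// Sharding c_fc/c_fc2 along output rows and c_proj along input columns keeps
// the elementwise gate product local to each TP rank, so only c_proj's output
// needs the tensor-parallel all-reduce.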
/*cols=*/fc_out, - /*col_start=*/tp_rank * in_fc_pp, /*col_cnt=*/in_fc_pp); - ++local_layer_index; - } else { - size_t c_proj_bytes = static_cast(n_embd) * ffn_hidden * sizeof(float); - ifs.seekg(c_proj_bytes, std::ios::cur); - } - } - - // transformer.ln_f.weight : Full version nn::RMSNorm - // lm_head.weight : (vocab_size, n_embd) -> ColumnParallelLinear, but actually applies on "rows" - { - if (is_last_stage) { - auto &ln_f - = state_dict[std::format("{}.{}.{}", nn::TransformerModel::kTransformerModelName, - nn::TransformerLastStage::kLnFLayerName, nn::RMSNorm::kParamWeightName)]; - auto &lm_head = state_dict[std::format("{}.{}", nn::TransformerLastStage::kLMHeadLayerName, - nn::parallel::ColumnParallelLinear::kParamWeightName)]; - ReadVectorAllFloat(ifs, static_cast(ln_f->DataPtr()), n_embd); - ReadMatrixRowShardFloat(ifs, static_cast(lm_head->DataPtr()), - /*rows=*/vocab_size, /*cols=*/n_embd, - /*row_start=*/v_start, /*row_cnt=*/vpp); - } else { - size_t ln_f_bytes = static_cast(n_embd) * sizeof(float); - size_t lm_head_bytes = static_cast(vocab_size) * n_embd * sizeof(float); - ifs.seekg(ln_f_bytes + lm_head_bytes, std::ios::cur); - } - } - - return llama3; -} -} // namespace llama3 diff --git a/example/llama3/checkpoint_loader.h b/example/llama3/checkpoint_loader.h deleted file mode 100644 index d4aea3d0..00000000 --- a/example/llama3/checkpoint_loader.h +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once - -#include -#include - -namespace infini_train::nn { -class TransformerModel; -} // namespace infini_train::nn - -namespace llama3 { -std::shared_ptr LoadFromLLMC(const std::string &filepath); -} // namespace llama3 diff --git a/example/llama3/config.h b/example/llama3/config.h index a9eef863..b2aad4b6 100644 --- a/example/llama3/config.h +++ b/example/llama3/config.h @@ -2,7 +2,7 @@ #include "infini_train/include/nn/modules/transformer/transformer_config.h" -namespace nn = infini_train::nn; +namespace infini_train { namespace llama3 { inline nn::TransformerConfig LLaMA3Config() { return {.block_size = 8192, @@ -23,3 +23,4 @@ inline nn::TransformerConfig LLaMA3Config() { .multiple_of = 256}; } } // namespace llama3 +} // namespace infini_train diff --git a/example/llama3/main.cc b/example/llama3/main.cc index da9a1027..d57586d1 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -1,13 +1,18 @@ +#include #include +#include #include +#include #include #include #include +#include "example/common/utils.h" #include "gflags/gflags.h" #include "glog/logging.h" #include "infini_train/include/autocast.h" +#include "infini_train/include/checkpoint.h" #include "infini_train/include/core/runtime/device_guard.h" #include "infini_train/include/dataloader.h" #include "infini_train/include/device.h" @@ -33,9 +38,9 @@ #include "infini_train/include/profiler.h" #endif +#include "example/common/checkpoint_loader.h" #include "example/common/tiny_shakespeare_dataset.h" #include "example/common/tokenizer.h" -#include "example/llama3/checkpoint_loader.h" #include "example/llama3/config.h" // I/O @@ -75,6 +80,15 @@ DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specified the DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks in PP stage."); // precision DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)"); +DEFINE_uint32(save_steps, 0, "save checkpoint every N steps; 0 disables saving"); +DEFINE_string(resume_from, "", "checkpoint directory to resume from"); +DEFINE_string(checkpoint_dir, "./checkpoints", "root directory used to store 
checkpoints"); +DEFINE_uint32(max_checkpoint_keep, 3, "max number of checkpoint steps to keep"); +DEFINE_bool(save_optimizer_state, true, "whether optimizer state is persisted in checkpoints"); +DEFINE_string(checkpoint_format, "pth", + "checkpoint format: bin|pth. " + "'bin' generates model.bin/optimizer.bin (bin supports LLMC model format via callbacks); " + "'pth' generates model.pth/optimizer.pth (native StateDict binary)."); // precision check DEFINE_string( precision_check, "", @@ -176,6 +190,8 @@ void Train(const nn::parallel::Rank &rank) { } else { model = std::make_shared(model_config); } + auto llmc_model = std::dynamic_pointer_cast(model); + CHECK(llmc_model != nullptr) << "Failed to cast model to LLaMA3 for LLMC checkpoint I/O."; model->To(device); @@ -284,6 +300,7 @@ void Train(const nn::parallel::Rank &rank) { } auto train_iter = train_loader.begin(); + size_t saved_data_batch_idx = train_iter.BatchIndex(); std::shared_ptr loss_fn = (tp_world_size > 1) ? std::static_pointer_cast(std::make_shared()) : std::static_pointer_cast(std::make_shared()); @@ -292,7 +309,57 @@ void Train(const nn::parallel::Rank &rank) { auto impl = core::GetDeviceGuardImpl(device.type()); - for (int step = 0; step < FLAGS_num_iteration + 1; ++step) { + int start_step = 0; + float best_loss = std::numeric_limits::infinity(); + TrainerState state; + CheckpointLoadOptions load_options; + load_options.load_optimizer_state = true; + load_options.model_bin_loader = [](nn::Module *target_model, const std::filesystem::path &model_path) { + auto loaded_model = llama3::LoadFromLLMC(model_path.string()); + target_model->LoadStateDict(loaded_model->StateDict()); + }; + const auto resume_result = infini_train::ResumeFromCheckpoint({ + .resume_root = FLAGS_resume_from, + .rank = rank, + .model = model, + .optimizer = optimizer, + .train_loader = train_loader, + .state = state, + .train_iter = train_iter, + .load_options = load_options, + }); + start_step = resume_result.global_step; + best_loss = resume_result.best_loss; + saved_data_batch_idx = resume_result.data_batch_idx; + + auto save_checkpoint = [&](const std::filesystem::path &save_dir, int64_t global_step, + bool prune_step_checkpoints) { + infini_train::SaveCheckpoint({ + .save_dir = save_dir, + .global_step = global_step, + .data_batch_idx = saved_data_batch_idx, + .best_loss = best_loss, + .last_lr = FLAGS_learning_rate, + .optimizer_type = "Adam", + .checkpoint_format = FLAGS_checkpoint_format, + .ddp_size = ddp_world_size, + .tp_size = tp_world_size, + .sp_size = sp_world_size, + .pp_size = pp_world_size, + .save_optimizer_state = FLAGS_save_optimizer_state, + .prune_step_checkpoints = prune_step_checkpoints, + .checkpoint_root_dir = FLAGS_checkpoint_dir, + .max_checkpoint_keep = FLAGS_max_checkpoint_keep, + .rank = rank, + .model = *model, + .optimizer = *optimizer, + .model_bin_writer + = [&](const nn::Module &, + const std::filesystem::path &model_path) { llama3::SaveAsLLMC(llmc_model, model_path.string()); }, + }); + }; + + for (int step = start_step; step < FLAGS_num_iteration + 1; ++step) { // Reset precision check counters at start of each iteration for file overwrite utils::PrecisionChecker::ResetCounters(); @@ -342,6 +409,7 @@ void Train(const nn::parallel::Rank &rank) { // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below // TODO(dcj): support dataloader.reset() later ++train_iter; + saved_data_batch_idx = train_iter.BatchIndex(); x = std::make_shared(x->To(device)); y = 
@@ -371,6 +439,7 @@ void Train(const nn::parallel::Rank &rank) {
             // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
             // TODO(dcj): support dataloader.reset() later
             ++train_iter;
+            saved_data_batch_idx = train_iter.BatchIndex();

             x = std::make_shared<Tensor>(x->To(device));
             y = std::make_shared<Tensor>(y->To(device));
@@ -383,6 +452,8 @@ void Train(const nn::parallel::Rank &rank) {
             lossf = static_cast<const float *>(lossf_tensor->To(Device()).DataPtr())[0];
         }

+        best_loss = std::min(best_loss, lossf);
+
         const auto iter_end = std::chrono::high_resolution_clock::now();
         const double duration_us = std::chrono::duration<double, std::micro>(iter_end - iter_start).count();
         const double tps = FLAGS_total_batch_size / (duration_us / 1e6);
@@ -405,6 +476,15 @@
                 }
             }
         }
+
+        if (FLAGS_save_steps > 0 && (step + 1) % FLAGS_save_steps == 0) {
+            std::filesystem::path step_dir
+                = std::filesystem::path(FLAGS_checkpoint_dir) / std::format("checkpoint_step_{:06d}", step + 1);
+            if (rank.IsParallel()) {
+                step_dir /= std::format("rank_{:06d}", rank.GlobalRank());
+            }
+            save_checkpoint(step_dir, step + 1, true);
+        }
     }

     // Save LoRA weights if enabled and path specified
@@ -413,6 +493,12 @@ void Train(const nn::parallel::Rank &rank) {
         nn::lora::SaveLoRAWeights(model, FLAGS_lora_save_path);
     }

+    std::filesystem::path final_dir = std::filesystem::path(FLAGS_checkpoint_dir) / "checkpoint_final";
+    if (rank.IsParallel()) {
+        final_dir /= std::format("rank_{:06d}", rank.GlobalRank());
+    }
+    save_checkpoint(final_dir, FLAGS_num_iteration, false);
+
 #ifdef PROFILE_MODE
     Profiler::Instance().Report("llama3.report", Profiler::SortBy::DeviceTimePercentage);
     Profiler::Instance().PrintRecords("llama3.records.log");
diff --git a/infini_train/include/checkpoint.h b/infini_train/include/checkpoint.h
new file mode 100644
index 00000000..32dab4ec
--- /dev/null
+++ b/infini_train/include/checkpoint.h
@@ -0,0 +1,63 @@
+#pragma once
+
+#include <cstdint>
+#include <filesystem>
+#include <functional>
+#include <memory>
+#include <string>
+#include <unordered_map>
+
+namespace infini_train {
+class Optimizer;
+class Tensor;
+namespace nn {
+class Module;
+}
+
+struct TrainerState {
+    int64_t global_step = 0;
+    int64_t data_batch_idx = 0;
+    int64_t data_batch_stride = 1;
+    float best_loss = 0.0f;
+    double last_lr = 0.0;
+    std::string optimizer_type = "unknown";
+    std::string checkpoint_format = "bin";
+
+    int ddp_size = 1;
+    int tp_size = 1;
+    int sp_size = 1;
+    int pp_size = 1;
+};
+
+struct CheckpointOptions {
+    std::string format = "bin";
+    bool save_optimizer_state = true;
+    std::function<void(const nn::Module &, const std::filesystem::path &)> model_bin_writer;
+};
+
+struct CheckpointLoadOptions {
+    bool load_optimizer_state = true;
+    std::function<void(nn::Module *, const std::filesystem::path &)> model_bin_loader;
+};
+
+class Checkpoint {
+public:
+    static void Save(const std::filesystem::path &checkpoint_dir, const nn::Module &model, const Optimizer &optimizer,
+                     const TrainerState &state, const CheckpointOptions &options = {});
+
+    static void Load(const std::filesystem::path &checkpoint_dir, nn::Module *model, Optimizer *optimizer,
+                     TrainerState *state, const CheckpointLoadOptions &options = {});
+
+private:
+    static void SaveStateDictBinary(const std::filesystem::path &path,
+                                    const std::unordered_map<std::string, std::shared_ptr<Tensor>> &state_dict);
+
+    static std::unordered_map<std::string, std::shared_ptr<Tensor>>
+    LoadStateDictBinary(const std::filesystem::path &path);
+
+    static void SaveTrainerState(const std::filesystem::path &path, const TrainerState &state);
+    static TrainerState LoadTrainerState(const std::filesystem::path &path);
+    static std::string InferFormat(const std::filesystem::path &checkpoint_dir);
+};
+
+} // namespace infini_train
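A minimal usage sketch of the Checkpoint API declared above (the values and the checkpoint directory are illustrative, and a constructed model/optimizer are assumed):

#include "infini_train/include/checkpoint.h"

void SaveAndRestore(infini_train::nn::Module &model, infini_train::Optimizer &optimizer) {
    using namespace infini_train;
    TrainerState state;
    state.global_step = 100;
    state.optimizer_type = "Adam";
    // "pth" writes the native StateDict binary; "bin" would route the model
    // through CheckpointOptions::model_bin_writer when a callback is supplied.
    Checkpoint::Save("ckpt/step_000100", model, optimizer, state, {.format = "pth"});

    TrainerState restored;
    Checkpoint::Load("ckpt/step_000100", &model, &optimizer, &restored);
    // restored.global_step is now 100.
}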
diff --git a/infini_train/include/dataloader.h b/infini_train/include/dataloader.h
index ad7fbcda..3c8bc0b5 100644
--- a/infini_train/include/dataloader.h
+++ b/infini_train/include/dataloader.h
@@ -24,6 +24,8 @@ class DataLoaderIterator {
     friend bool operator!=(const DataLoaderIterator &lhs, const DataLoaderIterator &rhs);
     friend bool operator==(const DataLoaderIterator &lhs, const DataLoaderIterator &rhs);

+    size_t BatchIndex() const;
+
 private:
     const Dataset *dataset_ = nullptr; // not owned
     size_t batch_size_ = 0;
@@ -39,6 +41,7 @@ class DataLoader {

     virtual DataLoaderIterator begin() const;
     virtual DataLoaderIterator end() const;
+    virtual DataLoaderIterator IteratorAtBatchIndex(size_t batch_idx) const;

 protected:
     std::shared_ptr<Dataset> dataset_;
@@ -53,6 +56,7 @@ class DistributedDataLoader : public DataLoader {

     DataLoaderIterator begin() const override;
     DataLoaderIterator end() const override;
+    DataLoaderIterator IteratorAtBatchIndex(size_t batch_idx) const override;

 private:
     size_t ddp_rank_ = 0;
diff --git a/infini_train/include/nn/modules/module.h b/infini_train/include/nn/modules/module.h
index f366661b..7cddaeed 100644
--- a/infini_train/include/nn/modules/module.h
+++ b/infini_train/include/nn/modules/module.h
@@ -61,6 +61,8 @@ class Module : public std::enable_shared_from_this<Module> {

     std::unordered_map<std::string, std::shared_ptr<Tensor>> StateDict() const;

+    virtual void LoadStateDict(const std::unordered_map<std::string, std::shared_ptr<Tensor>> &state_dict);
+
     // operator() calls hooks and Forward
     std::vector<std::shared_ptr<Tensor>> operator()(const std::vector<std::shared_ptr<Tensor>> &input_tensors);

diff --git a/infini_train/include/nn/parallel/ddp/distributed_optimizer.h b/infini_train/include/nn/parallel/ddp/distributed_optimizer.h
index bc31442e..559c4312 100644
--- a/infini_train/include/nn/parallel/ddp/distributed_optimizer.h
+++ b/infini_train/include/nn/parallel/ddp/distributed_optimizer.h
@@ -28,6 +28,10 @@ class DistributedOptimizer final : public infini_train::Optimizer {

     void ZeroGrad(bool set_to_none = true) override;

+    std::unordered_map<std::string, std::shared_ptr<Tensor>> StateDict() const override;
+
+    void LoadStateDict(const std::unordered_map<std::string, std::shared_ptr<Tensor>> &state_dict) override;
+
     void StartGradSync();
     void FinishGradSync();

diff --git a/infini_train/include/optimizer.h b/infini_train/include/optimizer.h
index fb0ae2d5..9a0d00ab 100644
--- a/infini_train/include/optimizer.h
+++ b/infini_train/include/optimizer.h
@@ -3,6 +3,8 @@
 #include
 #include
 #include
+#include <string>
+#include <unordered_map>
 #include

 namespace infini_train {
@@ -21,6 +23,10 @@ class Optimizer {

     virtual void Step() = 0;

+    virtual std::unordered_map<std::string, std::shared_ptr<Tensor>> StateDict() const { return {}; }
+
+    virtual void LoadStateDict(const std::unordered_map<std::string, std::shared_ptr<Tensor>> &state_dict) {}
+
 protected:
     std::vector<std::shared_ptr<Tensor>> params_;
 };
@@ -49,6 +55,10 @@ class Adam : public Optimizer {

     void Step() override;

+    std::unordered_map<std::string, std::shared_ptr<Tensor>> StateDict() const override;
+
+    void LoadStateDict(const std::unordered_map<std::string, std::shared_ptr<Tensor>> &state_dict) override;
+
     static OptimizerCreator Create(float learning_rate = 1e-3, float beta1 = 0.9, float beta2 = 0.999,
                                    float eps = 1e-8) {
         return [=](const std::vector<std::shared_ptr<Tensor>> &params) {
diff --git a/infini_train/src/checkpoint.cc b/infini_train/src/checkpoint.cc
new file mode 100644
index 00000000..fc6b73a4
--- /dev/null
+++ b/infini_train/src/checkpoint.cc
@@ -0,0 +1,277 @@
+#include "infini_train/include/checkpoint.h"
+
+#include <cstdint>
+#include <fstream>
+#include <iterator>
+#include <limits>
+#include <sstream>
+#include <string>
+
+#include "glog/logging.h"
+
+#include "infini_train/include/nn/modules/module.h"
+#include "infini_train/include/optimizer.h"
+#include "infini_train/include/tensor.h"
+
+namespace infini_train {
+namespace {
+constexpr uint32_t kCkptMagic = 0x54504B43; // CKPT
+constexpr uint32_t kCkptVersion = 1;
+
+uint32_t PeekMagic(const std::filesystem::path &path) {
+    std::ifstream ifs(path, std::ios::binary);
+    CHECK(ifs.is_open()) << "Failed to open checkpoint file: " << path;
+    uint32_t magic = 0;
+    ifs.read(reinterpret_cast<char *>(&magic), sizeof(magic));
+    return magic;
+}
+
+void WriteString(std::ofstream *ofs, const std::string &value) {
+    uint32_t len = static_cast<uint32_t>(value.size());
+    ofs->write(reinterpret_cast<const char *>(&len), sizeof(len));
+    ofs->write(value.data(), len);
+}
+
+std::string ReadString(std::ifstream *ifs) {
+    uint32_t len = 0;
+    ifs->read(reinterpret_cast<char *>(&len), sizeof(len));
+    std::string s(len, '\0');
+    ifs->read(s.data(), len);
+    return s;
+}
+
+std::string ExtractStringField(const std::string &content, const std::string &key, const std::string &fallback) {
+    const auto token = std::string("\"") + key + "\"";
+    const auto key_pos = content.find(token);
+    if (key_pos == std::string::npos) {
+        return fallback;
+    }
+    const auto colon_pos = content.find(':', key_pos);
+    if (colon_pos == std::string::npos) {
+        return fallback;
+    }
+    const auto first_quote = content.find('"', colon_pos + 1);
+    const auto second_quote = content.find('"', first_quote + 1);
+    if (first_quote == std::string::npos || second_quote == std::string::npos) {
+        return fallback;
+    }
+    return content.substr(first_quote + 1, second_quote - first_quote - 1);
+}
+
+template <typename T> T ExtractNumberField(const std::string &content, const std::string &key, T fallback) {
+    const auto token = std::string("\"") + key + "\"";
+    const auto key_pos = content.find(token);
+    if (key_pos == std::string::npos) {
+        return fallback;
+    }
+    const auto colon_pos = content.find(':', key_pos);
+    if (colon_pos == std::string::npos) {
+        return fallback;
+    }
+    size_t value_start = colon_pos + 1;
+    while (value_start < content.size() && (content[value_start] == ' ' || content[value_start] == '\n')) {
+        ++value_start;
+    }
+    size_t value_end = value_start;
+    while (value_end < content.size() && content[value_end] != ',' && content[value_end] != '\n'
+           && content[value_end] != '}') {
+        ++value_end;
+    }
+    std::stringstream ss(content.substr(value_start, value_end - value_start));
+    T value = fallback;
+    ss >> value;
+    if (ss.fail()) {
+        return fallback;
+    }
+    return value;
+}
+} // namespace
+
+void Checkpoint::Save(const std::filesystem::path &checkpoint_dir, const nn::Module &model, const Optimizer &optimizer,
+                      const TrainerState &state, const CheckpointOptions &options) {
+    CHECK(options.format == "bin" || options.format == "pth") << "Unsupported checkpoint format: " << options.format;
+    std::filesystem::create_directories(checkpoint_dir);
+    LOG(ERROR) << "[CKPT] Save begin: dir=" << checkpoint_dir << ", format=" << options.format
+               << ", global_step=" << state.global_step;
+
+    const auto model_path = checkpoint_dir / (options.format == "pth" ? "model.pth" : "model.bin");
+    if (options.format == "bin" && options.model_bin_writer) {
+        options.model_bin_writer(model, model_path);
+    } else {
+        SaveStateDictBinary(model_path, model.StateDict());
+    }
+
+    if (options.save_optimizer_state) {
+        auto opt_state = optimizer.StateDict();
+        if (!opt_state.empty()) {
+            const auto opt_path = checkpoint_dir / (options.format == "pth" ? "optimizer.pth" : "optimizer.bin");
"optimizer.pth" : "optimizer.bin"); + SaveStateDictBinary(opt_path, opt_state); + } + } + + SaveTrainerState(checkpoint_dir / "trainer_state.json", state); + LOG(ERROR) << "[CKPT] Save done: dir=" << checkpoint_dir; +} + +void Checkpoint::Load(const std::filesystem::path &checkpoint_dir, nn::Module *model, Optimizer *optimizer, + TrainerState *state, const CheckpointLoadOptions &options) { + CHECK(model != nullptr); + CHECK(state != nullptr); + + const std::string format = InferFormat(checkpoint_dir); + const auto model_path = checkpoint_dir / (format == "pth" ? "model.pth" : "model.bin"); + LOG(ERROR) << "[CKPT] Load begin: dir=" << checkpoint_dir << ", format=" << format; + LOG(ERROR) << "[CKPT] Loading model: " << model_path; + if (format == "bin" && options.model_bin_loader) { + const uint32_t magic = PeekMagic(model_path); + if (magic == kCkptMagic) { + LOG(ERROR) << "[CKPT] Model format detected: native checkpoint binary."; + model->LoadStateDict(LoadStateDictBinary(model_path)); + } else { + LOG(ERROR) << "[CKPT] Model format detected: external model.bin (magic=" << magic + << "), use model_bin_loader callback."; + options.model_bin_loader(model, model_path); + } + } else { + model->LoadStateDict(LoadStateDictBinary(model_path)); + } + + if (optimizer != nullptr && options.load_optimizer_state) { + const auto opt_path = checkpoint_dir / (format == "pth" ? "optimizer.pth" : "optimizer.bin"); + if (std::filesystem::exists(opt_path)) { + LOG(ERROR) << "[CKPT] Loading optimizer: " << opt_path; + optimizer->LoadStateDict(LoadStateDictBinary(opt_path)); + } else { + LOG(ERROR) << "[CKPT] Optimizer state not found, skip: " << opt_path; + } + } else if (optimizer == nullptr) { + LOG(ERROR) << "[CKPT] No optimizer instance, skip optimizer state loading."; + } else { + LOG(ERROR) << "[CKPT] load_optimizer_state=false, skip optimizer state loading."; + } + + *state = LoadTrainerState(checkpoint_dir / "trainer_state.json"); + LOG(ERROR) << "[CKPT] Load done: global_step=" << state->global_step << ", data_batch_idx=" << state->data_batch_idx + << ", data_batch_stride=" << state->data_batch_stride << ", best_loss=" << state->best_loss + << ", last_lr=" << state->last_lr << ", optimizer_type=" << state->optimizer_type + << ", topology(ddp,tp,sp,pp)=(" << state->ddp_size << "," << state->tp_size << "," << state->sp_size + << "," << state->pp_size << ")"; +} + +void Checkpoint::SaveStateDictBinary(const std::filesystem::path &path, + const std::unordered_map> &state_dict) { + std::ofstream ofs(path, std::ios::binary); + CHECK(ofs.is_open()) << "Failed to open checkpoint file: " << path; + + uint32_t magic = kCkptMagic; + uint32_t version = kCkptVersion; + uint32_t count = static_cast(state_dict.size()); + ofs.write(reinterpret_cast(&magic), sizeof(magic)); + ofs.write(reinterpret_cast(&version), sizeof(version)); + ofs.write(reinterpret_cast(&count), sizeof(count)); + + for (const auto &[name, tensor] : state_dict) { + WriteString(&ofs, name); + + const int8_t dtype = static_cast(tensor->Dtype()); + ofs.write(reinterpret_cast(&dtype), sizeof(dtype)); + + const auto &dims = tensor->Dims(); + uint32_t ndim = static_cast(dims.size()); + ofs.write(reinterpret_cast(&ndim), sizeof(ndim)); + for (const auto dim : dims) { ofs.write(reinterpret_cast(&dim), sizeof(dim)); } + + Tensor cpu_tensor = tensor->To(Device()); + uint64_t bytes = static_cast(cpu_tensor.SizeInBytes()); + ofs.write(reinterpret_cast(&bytes), sizeof(bytes)); + ofs.write(reinterpret_cast(cpu_tensor.DataPtr()), static_cast(bytes)); + } +} + 
+std::unordered_map<std::string, std::shared_ptr<Tensor>>
+Checkpoint::LoadStateDictBinary(const std::filesystem::path &path) {
+    std::ifstream ifs(path, std::ios::binary);
+    CHECK(ifs.is_open()) << "Failed to open checkpoint file: " << path;
+
+    uint32_t magic = 0;
+    uint32_t version = 0;
+    uint32_t count = 0;
+    ifs.read(reinterpret_cast<char *>(&magic), sizeof(magic));
+    ifs.read(reinterpret_cast<char *>(&version), sizeof(version));
+    ifs.read(reinterpret_cast<char *>(&count), sizeof(count));
+
+    CHECK_EQ(magic, kCkptMagic) << "Invalid checkpoint magic: " << path;
+    CHECK_EQ(version, kCkptVersion) << "Unsupported checkpoint version: " << path;
+
+    std::unordered_map<std::string, std::shared_ptr<Tensor>> state;
+    for (uint32_t i = 0; i < count; ++i) {
+        const std::string name = ReadString(&ifs);
+
+        int8_t dtype_raw = 0;
+        ifs.read(reinterpret_cast<char *>(&dtype_raw), sizeof(dtype_raw));
+        DataType dtype = static_cast<DataType>(dtype_raw);
+
+        uint32_t ndim = 0;
+        ifs.read(reinterpret_cast<char *>(&ndim), sizeof(ndim));
+        std::vector<int64_t> dims(ndim);
+        for (uint32_t d = 0; d < ndim; ++d) { ifs.read(reinterpret_cast<char *>(&dims[d]), sizeof(dims[d])); }
+
+        uint64_t bytes = 0;
+        ifs.read(reinterpret_cast<char *>(&bytes), sizeof(bytes));
+
+        auto tensor = std::make_shared<Tensor>(dims, dtype, Device());
+        CHECK_EQ(bytes, tensor->SizeInBytes()) << "Tensor bytes mismatch for key: " << name;
+        ifs.read(reinterpret_cast<char *>(tensor->DataPtr()), static_cast<std::streamsize>(bytes));
+        state.emplace(name, tensor);
+    }
+
+    return state;
+}
+
+void Checkpoint::SaveTrainerState(const std::filesystem::path &path, const TrainerState &state) {
+    std::ofstream ofs(path);
+    CHECK(ofs.is_open()) << "Failed to open trainer state file: " << path;
+    ofs << "{\n";
+    ofs << "  \"global_step\": " << state.global_step << ",\n";
+    ofs << "  \"data_batch_idx\": " << state.data_batch_idx << ",\n";
+    ofs << "  \"data_batch_stride\": " << state.data_batch_stride << ",\n";
+    ofs << "  \"best_loss\": " << state.best_loss << ",\n";
+    ofs << "  \"last_lr\": " << state.last_lr << ",\n";
+    ofs << "  \"optimizer_type\": \"" << state.optimizer_type << "\",\n";
+    ofs << "  \"checkpoint_format\": \"" << state.checkpoint_format << "\",\n";
+    ofs << "  \"ddp_size\": " << state.ddp_size << ",\n";
+    ofs << "  \"tp_size\": " << state.tp_size << ",\n";
+    ofs << "  \"sp_size\": " << state.sp_size << ",\n";
+    ofs << "  \"pp_size\": " << state.pp_size << "\n";
+    ofs << "}\n";
+}
+
+TrainerState Checkpoint::LoadTrainerState(const std::filesystem::path &path) {
+    std::ifstream ifs(path);
+    CHECK(ifs.is_open()) << "Failed to open trainer state file: " << path;
+    const std::string content((std::istreambuf_iterator<char>(ifs)), std::istreambuf_iterator<char>());
+
+    TrainerState state;
+    state.global_step = ExtractNumberField(content, "global_step", state.global_step);
+    state.data_batch_idx = ExtractNumberField(content, "data_batch_idx", state.data_batch_idx);
+    state.data_batch_stride = ExtractNumberField(content, "data_batch_stride", state.data_batch_stride);
+    state.best_loss = ExtractNumberField(content, "best_loss", std::numeric_limits<float>::infinity());
+    state.last_lr = ExtractNumberField(content, "last_lr", 0.0);
+    state.optimizer_type = ExtractStringField(content, "optimizer_type", "unknown");
+    state.checkpoint_format = ExtractStringField(content, "checkpoint_format", "bin");
+    state.ddp_size = ExtractNumberField(content, "ddp_size", 1);
+    state.tp_size = ExtractNumberField(content, "tp_size", 1);
+    state.sp_size = ExtractNumberField(content, "sp_size", 1);
+    state.pp_size = ExtractNumberField(content, "pp_size", 1);
+    return state;
+}
+
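// Concretely, SaveTrainerState emits (and the extractors above re-read) a file
// like this, with illustrative values:
// {
//   "global_step": 100,
//   "data_batch_idx": 100,
//   "data_batch_stride": 1,
//   "best_loss": 4.4375,
//   "last_lr": 0.0001,
//   "optimizer_type": "Adam",
//   "checkpoint_format": "pth",
//   "ddp_size": 1,
//   "tp_size": 1,
//   "sp_size": 1,
//   "pp_size": 1
// }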
+ return "pth"; + } + if (std::filesystem::exists(checkpoint_dir / "model.bin")) { + return "bin"; + } + LOG(FATAL) << "Failed to infer checkpoint format from path: " << checkpoint_dir; + return "bin"; +} + +} // namespace infini_train diff --git a/infini_train/src/dataloader.cc b/infini_train/src/dataloader.cc index 322df553..1fe9de88 100644 --- a/infini_train/src/dataloader.cc +++ b/infini_train/src/dataloader.cc @@ -78,6 +78,8 @@ bool operator==(const DataLoaderIterator &lhs, const DataLoaderIterator &rhs) { return lhs.batch_idx_ == rhs.batch_idx_; } +size_t DataLoaderIterator::BatchIndex() const { return batch_idx_; } + DataLoader::DataLoader(const std::shared_ptr &dataset, size_t batch_size) : dataset_(dataset), batch_size_(batch_size), max_batch_idx_((dataset_->Size() + batch_size_ - 1) / batch_size_) {} @@ -87,6 +89,10 @@ DataLoaderIterator DataLoader::end() const { return DataLoaderIterator(*dataset_, batch_size_, max_batch_idx_, max_batch_idx_); } +DataLoaderIterator DataLoader::IteratorAtBatchIndex(size_t batch_idx) const { + return DataLoaderIterator(*dataset_, batch_size_, std::min(batch_idx, max_batch_idx_), max_batch_idx_); +} + DistributedDataLoader::DistributedDataLoader(const std::shared_ptr &dataset, size_t batch_size, size_t ddp_rank, size_t ddp_world_size) : DataLoader(dataset, batch_size), ddp_rank_(ddp_rank), ddp_world_size_(ddp_world_size) {} @@ -98,4 +104,9 @@ DataLoaderIterator DistributedDataLoader::begin() const { DataLoaderIterator DistributedDataLoader::end() const { return DataLoaderIterator(*dataset_, batch_size_, max_batch_idx_, max_batch_idx_, ddp_rank_, ddp_world_size_); } + +DataLoaderIterator DistributedDataLoader::IteratorAtBatchIndex(size_t batch_idx) const { + return DataLoaderIterator(*dataset_, batch_size_, std::min(batch_idx, max_batch_idx_), max_batch_idx_, ddp_rank_, + ddp_world_size_); +} } // namespace infini_train diff --git a/infini_train/src/nn/modules/module.cc b/infini_train/src/nn/modules/module.cc index 6d48dcab..150626ac 100644 --- a/infini_train/src/nn/modules/module.cc +++ b/infini_train/src/nn/modules/module.cc @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -21,6 +22,21 @@ namespace infini_train::nn { +namespace { +std::string DimsToString(const std::vector &dims) { + std::ostringstream oss; + oss << "["; + for (size_t i = 0; i < dims.size(); ++i) { + if (i > 0) { + oss << ", "; + } + oss << dims[i]; + } + oss << "]"; + return oss.str(); +} +} // namespace + Module::Module() : Module(kUndefinedType) {} Module::Module(const std::string &type) : type_(type), device_(Device()) {} @@ -147,6 +163,21 @@ std::unordered_map> Module::StateDict() con return state; } +void Module::LoadStateDict(const std::unordered_map> &state_dict) { + auto expected = StateDict(); + for (const auto &[name, dst] : expected) { + CHECK(state_dict.contains(name)) << "Missing tensor in state dict: " << name; + const auto &src = state_dict.at(name); + CHECK(dst->Dims() == src->Dims()) + << "Shape mismatch for tensor: " << name << ", expected=" << DimsToString(dst->Dims()) + << ", got=" << DimsToString(src->Dims()); + CHECK(dst->Dtype() == src->Dtype()) + << "Dtype mismatch for tensor: " << name << ", expected=" << kDataTypeToDesc.at(dst->Dtype()) + << ", got=" << kDataTypeToDesc.at(src->Dtype()); + dst->CopyFrom(src); + } +} + std::vector> Module::Forward(const std::vector> &input_tensors) { LOG(FATAL) << "Forward function not implemented for this module"; return {}; diff --git a/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc 
diff --git a/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc b/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc
index 55e5800b..ae9bb1c5 100644
--- a/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc
+++ b/infini_train/src/nn/parallel/ddp/distributed_optimizer.cc
@@ -128,4 +128,14 @@ void DistributedOptimizer::Step() {
     FinishParamSync(/*skip_next_bucket_dispatch=*/true);
 }
 
+std::unordered_map<std::string, std::shared_ptr<Tensor>> DistributedOptimizer::StateDict() const {
+    CHECK(base_optimizer_) << "DistributedOptimizer: base optimizer is null.";
+    return base_optimizer_->StateDict();
+}
+
+void DistributedOptimizer::LoadStateDict(const std::unordered_map<std::string, std::shared_ptr<Tensor>> &state_dict) {
+    CHECK(base_optimizer_) << "DistributedOptimizer: base optimizer is null.";
+    base_optimizer_->LoadStateDict(state_dict);
+}
+
 } // namespace infini_train::nn::parallel
diff --git a/infini_train/src/optimizer.cc b/infini_train/src/optimizer.cc
index d5589b01..d5a14333 100644
--- a/infini_train/src/optimizer.cc
+++ b/infini_train/src/optimizer.cc
@@ -1,5 +1,6 @@
 #include "infini_train/include/optimizer.h"
 
+#include <format>
 #include <memory>
 
 #include "infini_train/include/core/runtime/device_guard.h"
@@ -62,5 +63,33 @@ void Adam::Step() {
         kernel.Call(grad, param, m, v, learning_rate_, beta1_, beta2_, eps_, t_);
     }
 }
+
+std::unordered_map<std::string, std::shared_ptr<Tensor>> Adam::StateDict() const {
+    std::unordered_map<std::string, std::shared_ptr<Tensor>> state;
+    for (size_t i = 0; i < m_.size(); ++i) {
+        state.emplace(std::format("adam.m.{}", i), m_[i]);
+        state.emplace(std::format("adam.v.{}", i), v_[i]);
+    }
+
+    auto t_tensor = std::make_shared<Tensor>(std::vector<int64_t>{}, DataType::kINT64, Device());
+    *static_cast<int64_t *>(t_tensor->DataPtr()) = t_;
+    state.emplace("adam.t", t_tensor);
+    return state;
+}
+
+void Adam::LoadStateDict(const std::unordered_map<std::string, std::shared_ptr<Tensor>> &state_dict) {
+    for (size_t i = 0; i < m_.size(); ++i) {
+        const auto m_key = std::format("adam.m.{}", i);
+        const auto v_key = std::format("adam.v.{}", i);
+        CHECK(state_dict.contains(m_key)) << "Missing optimizer state: " << m_key;
+        CHECK(state_dict.contains(v_key)) << "Missing optimizer state: " << v_key;
+        m_[i]->CopyFrom(state_dict.at(m_key));
+        v_[i]->CopyFrom(state_dict.at(v_key));
+    }
+
+    CHECK(state_dict.contains("adam.t")) << "Missing optimizer state: adam.t";
+    const Tensor t_cpu = state_dict.at("adam.t")->To(Device());
+    t_ = *static_cast<const int64_t *>(t_cpu.DataPtr());
+}
 } // namespace optimizers
 } // namespace infini_train
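
Adam's state-dict keys are positional ("adam.m.<i>", "adam.v.<i>", plus the scalar step
counter "adam.t"), so a restore only round-trips when parameters are registered in the
same order as at save time; DistributedOptimizer simply forwards to its base optimizer,
so DDP wrapping does not change the keys. A sketch of the round-trip, with an assumed
Adam constructor signature (the real one is not shown in this diff):

    // Sketch only: constructor arguments are illustrative.
    infini_train::optimizers::Adam adam(model->Parameters(), /*learning_rate=*/3e-4f);
    auto opt_state = adam.StateDict();  // "adam.m.0", "adam.v.0", ..., "adam.t"
    // ... persist opt_state with the same binary writer, then on resume:
    adam.LoadStateDict(opt_state);      // restores moments and the step counter t_
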
diff --git a/scripts/run_models_and_profile.bash b/scripts/run_models_and_profile.bash
index 06589904..c85d157f 100755
--- a/scripts/run_models_and_profile.bash
+++ b/scripts/run_models_and_profile.bash
@@ -206,15 +206,30 @@ move_profile_logs() {
     done
 }
 
-# Build "--key value" arg string from test_groups[gi].tests[ti].args (shell-escaped)
+# Build "--key value" arg string from tests[i].args.
+# For checkpoint-related args, automatically isolate by model and run mode
+# (resume/no_resume) to avoid cross-test overwrites in one-click runs.
 args_string_for_test() {
-    local group_idx="$1"
-    local test_idx="$2"
-    jq -r --argjson g "$group_idx" --argjson t "$test_idx" '
-        .test_groups[$g].tests[$t].args
-        | to_entries[]
-        | "--\(.key)=\(.value|tostring)"
-    ' "$CONFIG_FILE" | paste -sd' ' -
+    local idx="$1"
+    local model_name="$2"
+    jq -r --argjson i "$idx" --arg model "$model_name" '
+        def namespaced_path($p; $model; $mode):
+            if ($p | test("/checkpoint_step_[0-9]+($|/)")) then
+                ($p | capture("^(?<prefix>.*)/(?<step>checkpoint_step_[0-9]+(?:/.*)?)$")) as $m
+                | ($m.prefix + "/" + $model + "/" + $mode + "/" + $m.step)
+            else
+                ($p + "/" + $model + "/" + $mode)
+            end;
+
+        .tests[$i].args as $args
+        | (if ($args | has("resume_from")) then "resume" else "no_resume" end) as $run_mode
+        | (if (($args.resume_from // "") | test("(^|/)no_resume(/|$)")) then "no_resume" else "resume" end) as $resume_src_mode
+        | $args
+        | (if has("checkpoint_dir") then .checkpoint_dir = namespaced_path(.checkpoint_dir; $model; $run_mode) else . end)
+        | (if has("resume_from") then .resume_from = namespaced_path(.resume_from; $model; $resume_src_mode) else . end)
+        | to_entries[]
+        | "--\(.key) \(.value|tostring)"
+    ' "$CONFIG_FILE" | paste -sd' ' -
 }
 
 # Run tests
@@ -253,27 +268,18 @@ for ((id=0; id
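
A worked example of the namespacing above, with illustrative paths: for a test whose args
contain checkpoint_dir "out/ckpt" and resume_from "out/ckpt/checkpoint_step_500", run
against model gpt2, the emitted flags become
--checkpoint_dir out/ckpt/gpt2/resume and
--resume_from out/ckpt/gpt2/resume/checkpoint_step_500.
namespaced_path splices "<model>/<mode>" in front of a trailing checkpoint_step_N path
component when one is present and appends it otherwise, so per-model and
resume/no_resume runs never overwrite each other's checkpoints.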