Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ add_executable(gpt2
example/gpt2/main.cc
example/common/tiny_shakespeare_dataset.cc
example/common/utils.cc
example/gpt2/checkpoint_loader.cc
example/common/checkpoint_loader.cc
example/common/tokenizer.cc
)
link_infini_train_exe(gpt2)
Expand All @@ -185,7 +185,7 @@ add_executable(llama3
example/llama3/main.cc
example/common/tiny_shakespeare_dataset.cc
example/common/utils.cc
example/llama3/checkpoint_loader.cc
example/common/checkpoint_loader.cc
example/common/tokenizer.cc
)
link_infini_train_exe(llama3)
Expand Down
1,060 changes: 1,060 additions & 0 deletions example/common/checkpoint_loader.cc

Large diffs are not rendered by default.

76 changes: 76 additions & 0 deletions example/common/checkpoint_loader.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#pragma once

#include "infini_train/include/checkpoint.h"
#include "infini_train/include/dataloader.h"
#include "infini_train/include/nn/modules/module.h"
#include "infini_train/include/nn/parallel/rank.h"
#include "infini_train/include/optimizer.h"

#include "gflags/gflags.h"

#include <cstdint>
#include <cstring>
#include <filesystem>

#include <functional>
#include <limits>
#include <string>

namespace infini_train {
namespace nn {
// Forward declaration only: within this header the type appears solely as the
// pointee of std::shared_ptr in function signatures, so its full definition is
// not required here.
class TransformerModel;
}

// GPT-2 checkpoint entry points. Only the declarations live in this header.
namespace gpt2 {
// Takes a file path and returns a shared nn::TransformerModel.
// NOTE(review): presumably parses a llm.c-style ("LLMC") binary checkpoint —
// confirm against the implementation file.
std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath);
// Takes a model and a destination path; returns nothing.
// NOTE(review): presumably the inverse of LoadFromLLMC — confirm against the
// implementation file.
void SaveAsLLMC(const std::shared_ptr<nn::TransformerModel> &model, const std::string &filepath);
} // namespace gpt2
// LLaMA-3 checkpoint entry points. Signatures mirror the gpt2 namespace; only
// the declarations live in this header.
namespace llama3 {
// Takes a file path and returns a shared nn::TransformerModel.
// NOTE(review): presumably parses a llm.c-style ("LLMC") binary checkpoint —
// confirm against the implementation file.
std::shared_ptr<nn::TransformerModel> LoadFromLLMC(const std::string &filepath);
// Takes a model and a destination path; returns nothing.
// NOTE(review): presumably the inverse of LoadFromLLMC — confirm against the
// implementation file.
void SaveAsLLMC(const std::shared_ptr<nn::TransformerModel> &model, const std::string &filepath);
} // namespace llama3

// Inputs handed to ResumeFromCheckpoint() as a single aggregate.
//
// Four members (rank, train_loader, state, train_iter) are references: the
// struct does not own them, so an instance must not outlive the objects it was
// built from. Because of those reference members the struct has no default
// constructor and is not assignable; build it in one expression at the call
// site.
struct ResumeFromCheckpointArgs {
    // Spelled std::string rather than gflags' fLS::clstring. The latter is a
    // typedef of std::string inside a gflags helper namespace, so the type is
    // unchanged for callers (a FLAGS_* string still binds directly) while this
    // struct no longer depends on that alias.
    std::string resume_root;
    // Read-only view of the caller's rank descriptor.
    const nn::parallel::Rank &rank;
    // Shared handles to the model and optimizer (shared ownership with the caller).
    std::shared_ptr<nn::Module> model;
    std::shared_ptr<Optimizer> optimizer;
    // Non-const references: the callee is permitted to modify these objects.
    // NOTE(review): presumably they are repositioned/overwritten to match the
    // checkpoint — confirm in the implementation.
    DistributedDataLoader &train_loader;
    TrainerState &state;
    DataLoaderIterator &train_iter;
    // Options object passed by value.
    CheckpointLoadOptions load_options;
};

// Values returned by ResumeFromCheckpoint(). A default-constructed instance
// holds step 0, batch index 0 and a best loss of +infinity.
struct ResumeFromCheckpointResult {
    // Step counter; zero by default.
    int global_step{0};
    // Starts at +infinity so that any finite loss compares lower.
    float best_loss{std::numeric_limits<float>::infinity()};
    // Batch index within the data stream; zero by default.
    std::size_t data_batch_idx{0};
};

// Inputs handed to SaveCheckpoint() as a single aggregate.
//
// rank, model and optimizer are const references: the struct only observes
// them and must not outlive them. Field order is part of the interface for
// positional aggregate initialization and is therefore kept as-is.
struct SaveCheckpointArgs {
    // --- destination ---
    std::filesystem::path save_dir;

    // --- training progress recorded with the checkpoint ---
    int64_t global_step{0};
    std::size_t data_batch_idx{0};
    float best_loss{std::numeric_limits<float>::infinity()};
    double last_lr{0.0};

    // --- descriptors ---
    std::string optimizer_type;
    std::string checkpoint_format{"bin"};

    // --- parallel sizes; each defaults to 1 ---
    int ddp_size{1};
    int tp_size{1};
    int sp_size{1};
    int pp_size{1};

    // --- behaviour switches ---
    bool save_optimizer_state{true};
    bool prune_step_checkpoints{false};
    // NOTE(review): presumably only consulted when prune_step_checkpoints is
    // set — confirm in the implementation.
    std::filesystem::path checkpoint_root_dir;
    std::size_t max_checkpoint_keep{0};

    // --- non-owning, read-only views of caller objects ---
    const nn::parallel::Rank &rank;
    const nn::Module &model;
    const Optimizer &optimizer;

    // Caller-supplied callback receiving a module and a path; empty by default.
    std::function<void(const nn::Module &, const std::filesystem::path &)> model_bin_writer;
};

// Takes the bundled resume inputs and returns a ResumeFromCheckpointResult
// (global step, best loss, data batch index).
// NOTE(review): the restore logic itself is defined outside this header; args
// carries non-const references, so expect the referenced objects to be
// modified — confirm in the implementation.
ResumeFromCheckpointResult ResumeFromCheckpoint(const ResumeFromCheckpointArgs &args);

// Takes the bundled save inputs; returns nothing.
// NOTE(review): presumably writes a checkpoint under args.save_dir — confirm
// in the implementation.
void SaveCheckpoint(const SaveCheckpointArgs &args);

} // namespace infini_train
9 changes: 8 additions & 1 deletion example/common/utils.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
#include "example/common/utils.h"

#include <algorithm>
#include <chrono>

#include "gflags/gflags.h"
#include "gflags/gflags_declare.h"
#include "glog/logging.h"
#include "infini_train/include/nn/parallel/global.h"

namespace infini_train {

float ConvertBF16ToFloat(void *ptr) {
Expand Down Expand Up @@ -60,5 +68,4 @@ void ReadVectorShardFloat(std::ifstream &ifs, float *dst, int64_t len, int64_t s
ifs.read(reinterpret_cast<char *>(dst), static_cast<std::streamsize>(cnt * sizeof(float)));
ifs.seekg(base + std::streamoff(len * sizeof(float)));
}

} // namespace infini_train
12 changes: 12 additions & 0 deletions example/common/utils.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
#pragma once

#include "infini_train/include/checkpoint.h"
#include "infini_train/include/dataloader.h"
#include "infini_train/include/nn/modules/module.h"
#include "infini_train/include/nn/parallel/rank.h"
#include "infini_train/include/optimizer.h"

#include "gflags/gflags.h"

#include <cstdint>
#include <cstring>
#include <filesystem>
#include <fstream>
#include <functional>
#include <limits>
#include <string>
#include <vector>

namespace infini_train {
Expand Down
Loading
Loading