2121import json
2222import os
2323
24- from datasets import load_from_disk , Dataset
24+ from datasets import load_dataset , Dataset
2525from datasets .distributed import split_dataset_by_node
2626from peft import LoraConfig , get_peft_model
2727import transformers
@@ -71,28 +71,26 @@ def setup_model_and_tokenizer(model_uri, transformer_type, model_dir):
7171 return model , tokenizer
7272
7373# This function is a modified version of the original.
74- def load_and_preprocess_data (dataset_dir , transformer_type , tokenizer ):
74+ def load_and_preprocess_data (dataset_file , transformer_type , tokenizer ):
7575 # Load and preprocess the dataset
7676 logger .info ("Load and preprocess dataset" )
7777
78- file_path = os .path .realpath (dataset_dir )
78+ file_path = os .path .realpath (dataset_file )
7979
80- if transformer_type != AutoModelForImageClassification :
81- dataset = load_from_disk (file_path )
80+ dataset = load_dataset ('json' ,data_files = file_path )
8281
82+ if transformer_type != AutoModelForImageClassification :
8383 logger .info (f"Dataset specification: { dataset } " )
8484 logger .info ("-" * 40 )
8585
8686 logger .info ("Tokenize dataset" )
8787 # TODO (andreyvelich): Discuss how user should set the tokenizer function.
8888 num_cores = os .cpu_count ()
8989 dataset = dataset .map (
90- lambda x : tokenizer (x ["text " ], padding = True , truncation = True , max_length = 128 ),
90+ lambda x : tokenizer (x ["output " ], padding = True , truncation = True , max_length = 128 ),
9191 batched = True ,
9292 num_proc = num_cores
9393 )
94- else :
95- dataset = load_from_disk (file_path )
9694
9795 # Check if dataset contains `train` key. Otherwise, load full dataset to train_data.
9896 if "train" in dataset :
@@ -175,7 +173,7 @@ def parse_arguments():
175173 parser .add_argument ("--model_uri" , help = "model uri" )
176174 parser .add_argument ("--transformer_type" , help = "model transformer type" )
177175 parser .add_argument ("--model_dir" , help = "directory containing model" )
178- parser .add_argument ("--dataset_dir " , help = "directory containing dataset " )
176+ parser .add_argument ("--dataset_file " , help = "dataset file path " )
179177 parser .add_argument ("--lora_config" , help = "lora_config" )
180178 parser .add_argument (
181179 "--training_parameters" , help = "hugging face training parameters"
@@ -197,7 +195,7 @@ def parse_arguments():
197195
198196 logger .info ("Preprocess dataset" )
199197 train_data , eval_data = load_and_preprocess_data (
200- args .dataset_dir , transformer_type , tokenizer
198+ args .dataset_file , transformer_type , tokenizer
201199 )
202200
203201 logger .info ("Setup LoRA config for model" )
0 commit comments