
Commit c3f7947

upload
1 parent b4bc4c5 commit c3f7947

14 files changed: 4821 additions & 0 deletions

llm_wrapper/LLM_finetuning_script.py

Lines changed: 571 additions & 0 deletions
Large diffs are not rendered by default.

llm_wrapper/LLM_inference_script.py

Lines changed: 1422 additions & 0 deletions
Large diffs are not rendered by default.

llm_wrapper/LLM_utils.py

Lines changed: 269 additions & 0 deletions
@@ -0,0 +1,269 @@
import json
import os
import random
import time

import torch

# Functions for LLM loading


def get_HF_model_id(HF_LLM_name):

    if "Mixtral_8x7" in HF_LLM_name or "Mixtral-8x7" in HF_LLM_name:
        model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
        base_model_name = "Mixtral-8x7B-Instruct-v0.1"

    elif "Mixtral_8x22" in HF_LLM_name or "Mixtral-8x22" in HF_LLM_name:
        model_id = "mistralai/Mixtral-8x22B-Instruct-v0.1"
        base_model_name = "Mixtral-8x22B-Instruct-v0.1"

    elif "Llama-3-8B" in HF_LLM_name or "Llama3_8" in HF_LLM_name:
        model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
        base_model_name = "Meta-Llama-3-8B-Instruct"

    elif "Llama-3-70B" in HF_LLM_name or "Llama3_70" in HF_LLM_name:
        model_id = "meta-llama/Meta-Llama-3-70B-Instruct"
        base_model_name = "Meta-Llama-3-70B-Instruct"

    elif "gemma-2-9b-it" in HF_LLM_name:
        model_id = "google/gemma-2-9b-it"
        base_model_name = "gemma-2-9b-it"

    elif "gemma-2-2b-it" in HF_LLM_name:
        model_id = "google/gemma-2-2b-it"
        base_model_name = "gemma-2-2b-it"

    elif "gpt-neo-2.7B" in HF_LLM_name:
        model_id = "EleutherAI/gpt-neo-2.7B"
        base_model_name = "gpt-neo-2.7B"

    elif "gpt-neo-1.3B" in HF_LLM_name:
        model_id = "EleutherAI/gpt-neo-1.3B"
        base_model_name = "gpt-neo-1.3B"

    elif "gpt-neo-125m" in HF_LLM_name:
        model_id = "EleutherAI/gpt-neo-125m"
        base_model_name = "gpt-neo-125m"

    else:
        raise Exception(f"not implemented for the LLM ({HF_LLM_name})")

    return model_id, base_model_name
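A quick illustration of the mapping above (not part of the original file): any name containing one of the recognized substrings resolves to the matching Hugging Face repo id. The run name and the package-style import below are assumptions.

from llm_wrapper.LLM_utils import get_HF_model_id  # assumes llm_wrapper is importable as a package

model_id, base_model_name = get_HF_model_id("Llama3_8_finetuned_v2")  # hypothetical run name
print(model_id)         # "meta-llama/Meta-Llama-3-8B-Instruct"
print(base_model_name)  # "Meta-Llama-3-8B-Instruct"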
def load_HF_model_tok(args, HF_LLM_name, eval_mode=True, FT_mode=False, timing=True, quantization=True):

    if timing:
        s = time.time()

    with open(args.HF_token_path) as f:
        HF_TOKEN = json.load(f)

    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    if HF_LLM_name == "Mixtral-8x22B-Instruct-v0.1":
        from accelerate import init_empty_weights, load_checkpoint_and_dispatch  # noqa: F401

    # Find base model
    model_id, base_model_name = get_HF_model_id(HF_LLM_name=HF_LLM_name)

    # Define model / tokenizer's cache folders
    cache_dir = os.path.join(args.main_folder, "LLM_cache", f"cache_{base_model_name}")
    cache_dir_tok = os.path.join(args.main_folder, "LLM_cache", f"cache_{base_model_name}_tokenizer")
    print(f"(Down)loading base model {base_model_name} from/to {args.main_folder}/LLM_cache/")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir_tok, token=HF_TOKEN, force_download=False)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    # Load base model
    use_cache = True if FT_mode is False else False
    attn_implem = "eager" if "gemma" in HF_LLM_name else "flash_attention_2"

    if quantization:

        nf4_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            cache_dir=cache_dir,
            device_map="auto",
            quantization_config=nf4_config,
            use_cache=use_cache,
            attn_implementation=attn_implem,
            token=HF_TOKEN,
        )

    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            cache_dir=cache_dir,
            device_map="auto",
            # quantization_config=nf4_config,
            use_cache=use_cache,
            attn_implementation=attn_implem,
            token=HF_TOKEN,
        )

    # add FT lora weights to base model
    if HF_LLM_name not in [
        "Mixtral-8x7B-Instruct-v0.1",
        "Mixtral-8x22B-Instruct-v0.1",
        "Meta-Llama-3-8B-Instruct",
        "Meta-Llama-3-70B-Instruct",
        "gemma-2-2b-it",
        "gemma-2-9b-it",
        "gpt-neo-2.7B",
        "gpt-neo-1.3B",
        "gpt-neo-125m",
    ]:

        # First loading option => instance directly FT model but it redownloads shards/base model
        """
        model = AutoModelForCausalLM.from_pretrained(FT_model_save_path,
                                                     device_map='auto',
                                                     quantization_config=nf4_config,
                                                     use_cache=use_cache,
                                                     attn_implementation=attn_implem,
                                                     token=HF_TOKEN)
        """

        # Second loading option => instance FT model efficiently on top of base model
        ft_dataset = args.ft_dataset if args.ft_dataset is not None else args.dataset
        path_to_folder = os.path.join(args.main_folder, ft_dataset, "my_FT_models")
        assert os.path.exists(os.path.join(path_to_folder, HF_LLM_name))
        model = load_FT_model_via_base_model(
            base_model=model,
            complete_FT_path=os.path.join(path_to_folder, HF_LLM_name),
            timing=timing,
        )

    if eval_mode:
        model.eval()

    if timing:
        e = time.time()
        print("HF LLM / tokenizer loading time:", round(e - s, 1), "secs")

    return model, tokenizer
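A minimal usage sketch for load_HF_model_tok: the args object only needs the attributes the function actually reads (HF_token_path, main_folder, plus ft_dataset/dataset when a fine-tuned checkpoint is attached). The paths and model choice below are hypothetical.

from types import SimpleNamespace

from llm_wrapper.LLM_utils import load_HF_model_tok  # assumes llm_wrapper is importable

# Hypothetical paths; HF_token_path must point to a JSON file holding the Hugging Face token.
args = SimpleNamespace(
    HF_token_path="secrets/hf_token.json",
    main_folder="/data/experiments",
    ft_dataset=None,
    dataset="my_dataset",
)

# Load the 4-bit quantized base model and tokenizer in eval mode (no LoRA adapters attached,
# since "gemma-2-2b-it" is in the base-model list above).
model, tokenizer = load_HF_model_tok(args, HF_LLM_name="gemma-2-2b-it", eval_mode=True)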
def load_FT_model_via_base_model(base_model, complete_FT_path, timing):
    if timing:
        s = time.time()

    # new loading way (load base model + give checkpoint path)
    from peft import PeftModel

    print(f"add lora adapters to base model from: {complete_FT_path}")
    FT_model = PeftModel.from_pretrained(base_model, complete_FT_path)
    if timing:
        e = time.time()
        print("Time to load lora modules on base model:", round(e - s, 2), "secs")
    return FT_model
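For reference, PeftModel.from_pretrained expects the folder at complete_FT_path to contain a LoRA adapter config and weights, as written by save_pretrained on a PEFT model; the actual fine-tuning presumably happens in LLM_finetuning_script.py. A sketch with made-up hyperparameters and paths:

from peft import LoraConfig, get_peft_model

# Hypothetical LoRA settings; the real values live in the fine-tuning script.
lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, task_type="CAUSAL_LM")
peft_model = get_peft_model(base_model, lora_config)
# ... fine-tune peft_model ...
peft_model.save_pretrained("/data/experiments/my_dataset/my_FT_models/gemma-2-2b-it_FT")  # hypothetical path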
# Functions for text generation


def gen_step_3_with_HF(HF_LLM_name, model, tokenizer, prompt, verbose=False, max_gen_tok=None):

    encoded_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    model_inputs = encoded_input.to("cuda")
    max_new_tokens = 128 if max_gen_tok is None else max_gen_tok

    if "Mixtral" in HF_LLM_name or "gemma" in HF_LLM_name or "gpt-neo" in HF_LLM_name:
        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

    elif "Llama" in HF_LLM_name or "llama" in HF_LLM_name:
        model.generation_config.temperature = None
        model.generation_config.top_p = None
        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=max_new_tokens,
            eos_token_id=tokenizer.eos_token_id,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens (everything after the prompt)
    stripped_result = tokenizer.batch_decode(generated_ids[0][encoded_input["input_ids"][0].shape[0] :].unsqueeze(0))[0]
    if "Mixtral" in HF_LLM_name:
        stripped_result = stripped_result.replace("</s>", "").strip()
    elif "Llama" in HF_LLM_name or "llama" in HF_LLM_name:
        stripped_result = stripped_result.replace("<|eot_id|>", "").strip()
    elif "gemma" in HF_LLM_name:
        if "gemma-2-2b-it" in HF_LLM_name:
            print("gemma : before strip:", stripped_result)
        # note: str.strip treats its argument as a set of characters, not a substring
        stripped_result = (
            stripped_result.strip("<eos>").strip().strip("\n").strip().strip("<end_of_turn>").strip().strip("\n").strip()
        )
        if "gemma-2-2b-it" in HF_LLM_name:
            print("gemma : after strip:", stripped_result)
    elif "gpt-neo" in HF_LLM_name:
        stripped_result = stripped_result.strip()
    else:
        print("WARNING stripping on generation by this LLM not implemented")

    if verbose:
        print("LLM generated (stripped) result:", stripped_result)

    return stripped_result
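Because gen_step_3_with_HF tokenizes the prompt with add_special_tokens=False, the prompt string is expected to already carry the model's chat formatting. One way to build such a prompt (a sketch; the real prompt construction presumably lives in LLM_inference_script.py, and the question text below is made up):

# Hypothetical question; apply_chat_template adds the model-specific special tokens itself.
messages = [{"role": "user", "content": "Which numbered box contains the red mug? Answer with a single integer."}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

answer = gen_step_3_with_HF("gemma-2-2b-it", model, tokenizer, prompt, verbose=True, max_gen_tok=16)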
def extract_int_result_from_LLM_gen(result, HF_LLM_name, return_error_type, nb_detected_boxes=None, fixed_seed=None):

    # Clean up: strip quotes; for gemma-2-2b-it also drop enumeration prefixes like "3. " (digits 0-8)
    result = result.strip().strip("'").strip('"')
    if HF_LLM_name == "gemma-2-2b-it":
        for i in range(9):
            result = result.replace(str(i) + ". ", "")

    # Int extraction: keep the first contiguous run of digits
    new_result = ""
    started_extraction = False
    for char in result:
        if char in ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]:
            new_result += char
            started_extraction = True
        else:
            if started_extraction is True:
                break

    # Error analysis if specified
    if return_error_type:

        error_type = None  # indicate type of error if any

        if new_result != "":

            new_result = int(new_result)

            if new_result >= nb_detected_boxes:
                new_result = None
                error_type = "gen_out_of_scope"
            else:
                if fixed_seed is not None:
                    # map the generated index through the same seeded shuffle of box ids
                    list_id = list(range(nb_detected_boxes))
                    random.seed(fixed_seed)
                    random.shuffle(list_id)
                    new_result = list_id[new_result]

        else:
            new_result = None
            error_type = "no_int_generated"

        return new_result, error_type
    else:
        return new_result
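A short illustration of the extraction behaviour in error-reporting mode (inputs are made up):

idx, err = extract_int_result_from_LLM_gen("Box 3 is the best match.", "gemma-2-9b-it",
                                           return_error_type=True, nb_detected_boxes=5)
# idx == 3, err is None

idx, err = extract_int_result_from_LLM_gen("I am not sure.", "gemma-2-9b-it",
                                           return_error_type=True, nb_detected_boxes=5)
# idx is None, err == "no_int_generated"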

0 commit comments
