
Commit e07056a

Authored by yuanwu
[Gaudi] Remove optimum-habana (#3261)
Signed-off-by: yuanwu <[email protected]>
1 parent 25fdc5f commit e07056a

20 files changed: +23 −5995 lines changed

Dockerfile_gaudi

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ ARG PYTORCH_VERSION
 
 FROM vault.habana.ai/gaudi-docker/${HABANA_VERSION}/ubuntu22.04/habanalabs/pytorch-installer-${PYTORCH_VERSION}:latest AS base
 
-ENV ATTENTION=default
+ENV ATTENTION=paged
 ENV PREFIX_CACHING=0
 ENV PREFILL_CHUNKING=0
 ENV PT_HPU_LAZY_MODE=1
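
Note: the image default for the attention backend changes from "default" to "paged", in line with the removal of the non-paged (optimum-habana) code paths below. A minimal sketch of how such container defaults can be read on the Python side, assuming an env-driven config module along the lines of the globals.py this backend imports elsewhere (not the repo's actual code):

# Minimal sketch, not repo code: reading the defaults baked into Dockerfile_gaudi.
import os

ATTENTION = os.environ.get("ATTENTION", "paged")  # image default is now "paged"
PREFIX_CACHING = os.environ.get("PREFIX_CACHING", "0") == "1"
PREFILL_CHUNKING = os.environ.get("PREFILL_CHUNKING", "0") == "1"

print(f"attention={ATTENTION} prefix_caching={PREFIX_CACHING} prefill_chunking={PREFILL_CHUNKING}")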

backends/gaudi/server/pyproject.toml

Lines changed: 2 additions & 3 deletions
@@ -22,10 +22,9 @@ opentelemetry-instrumentation-grpc = "^0.53b0"
 hf-transfer = "^0.1.9"
 sentencepiece = "^0.2.0"
 peft = "^0.15"
-optimum-habana = "1.17"
-transformers = "^4.49"
+transformers = "^4.52.4"
 numpy = "^1.26"
-accelerate = "^0.33"
+accelerate = "^1.7.0"
 outlines= { version = "^0.0.36", optional = true }
 prometheus-client = "^0.21.1"
 py-cpuinfo = "^9.0.0"

backends/gaudi/server/requirements.txt

Lines changed: 2 additions & 3 deletions
@@ -1,4 +1,4 @@
-accelerate==0.33.0 ; python_version >= "3.9" and python_version < "3.13"
+accelerate==1.7.0 ; python_version >= "3.9" and python_version < "3.13"
 annotated-types==0.7.0 ; python_version >= "3.9" and python_version < "3.13"
 attrs==25.3.0 ; python_version >= "3.9" and python_version < "3.13"
 certifi==2025.1.31 ; python_version >= "3.9" and python_version < "3.13"
@@ -46,7 +46,6 @@ opentelemetry-instrumentation==0.53b0 ; python_version >= "3.9" and python_versi
 opentelemetry-proto==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-sdk==1.32.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-semantic-conventions==0.53b0 ; python_version >= "3.9" and python_version < "3.13"
-optimum-habana==1.17.0 ; python_version >= "3.9" and python_version < "3.13"
 optimum==1.24.0 ; python_version >= "3.9" and python_version < "3.13"
 outlines==0.0.36 ; python_version >= "3.9" and python_version < "3.13"
 packaging==24.2 ; python_version >= "3.9" and python_version < "3.13"
@@ -76,7 +75,7 @@ sympy==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
 threadpoolctl==3.6.0 ; python_version >= "3.9" and python_version < "3.13"
 tokenizers==0.21.1 ; python_version >= "3.9" and python_version < "3.13"
 tqdm==4.67.1 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.49.0 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.52.4 ; python_version >= "3.9" and python_version < "3.13"
 triton==3.2.0 ; python_version >= "3.9" and python_version < "3.13" and platform_system == "Linux" and platform_machine == "x86_64"
 typer==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
 typing-extensions==4.13.2 ; python_version >= "3.9" and python_version < "3.13"
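
Together with the pyproject.toml change above, the lock file drops optimum-habana entirely and moves to transformers 4.52.4 and accelerate 1.7.0. A quick, illustrative way to confirm an installed environment matches the new pins (standalone sketch, not part of the repo):

# Illustrative sanity check for the updated pins; not repo code.
from importlib.metadata import PackageNotFoundError, version

print("transformers:", version("transformers"))  # expected 4.52.4 per requirements.txt
print("accelerate:", version("accelerate"))      # expected 1.7.0

try:
    version("optimum-habana")
    print("warning: optimum-habana is still installed")
except PackageNotFoundError:
    print("optimum-habana absent, as expected after this commit")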

backends/gaudi/server/text_generation_server/cli.py

Lines changed: 13 additions & 76 deletions
@@ -1,6 +1,4 @@
 import os
-import psutil
-import signal
 import sys
 import typer
 
@@ -115,80 +113,19 @@ def serve(
         raise RuntimeError(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
         )
-
-    logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype))
-
-    if sharded and os.getenv("ATTENTION", "default") not in {"paged"}:
-        tgi_file = Path(__file__).resolve().parent / "tgi_service.py"
-        num_shard = int(os.getenv("WORLD_SIZE", "1"))
-        logger.info("CLI SHARDED = {}".format(num_shard))
-        import subprocess
-
-        cmd = (
-            f"deepspeed --num_nodes 1 --num_gpus {num_shard} --no_local_rank {tgi_file}"
-        )
-        cmd += f" --model_id {model_id} --revision {revision} --sharded {sharded}"
-        cmd += f" --dtype {dtype} --trust_remote_code {trust_remote_code} --uds_path {uds_path}"
-        cmd += f" --quantize {quantize} --max_input_tokens {max_input_tokens}"
-        if speculate is not None:
-            cmd += f"--speculate {speculate}"
-        logger.info("CLI server start deepspeed ={} ".format(cmd))
-        sys.stdout.flush()
-        sys.stderr.flush()
-        with subprocess.Popen(cmd, shell=True, executable="/bin/bash") as proc:
-            do_terminate = False
-            current_handler = signal.getsignal(signal.SIGTERM)
-
-            def terminate_handler(sig, frame):
-                nonlocal do_terminate
-                do_terminate = True
-                if callable(current_handler):
-                    current_handler(sig, frame)
-
-            signal.signal(signal.SIGTERM, terminate_handler)
-
-            finished = False
-            while not finished:
-                try:
-                    if do_terminate:
-                        parent = psutil.Process(proc.pid)
-                        all_procs = parent.children(recursive=True) + [parent]
-                        for p in all_procs:
-                            try:
-                                p.terminate()
-                            except psutil.NoSuchProcess:
-                                pass
-                        _, alive = psutil.wait_procs(all_procs, timeout=30)
-                        for p in alive:
-                            p.kill()
-
-                        do_terminate = False
-
-                    proc.wait(timeout=3)
-                except subprocess.TimeoutExpired:
-                    pass
-                else:
-                    finished = True
-
-            sys.stdout.flush()
-            sys.stderr.flush()
-            if proc.returncode != 0:
-                logger.error(f"{cmd} exited with status = {proc.returncode}")
-            return proc.returncode
-    else:
-        server.serve(
-            model_id,
-            lora_adapters,
-            revision,
-            sharded,
-            quantize,
-            speculate,
-            dtype,
-            kv_cache_dtype,
-            trust_remote_code,
-            uds_path,
-            max_input_tokens,
-        )
+    server.serve(
+        model_id,
+        lora_adapters,
+        revision,
+        sharded,
+        quantize,
+        speculate,
+        dtype,
+        kv_cache_dtype,
+        trust_remote_code,
+        uds_path,
+        max_input_tokens,
+    )
 
 
 @app.command()
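
With the non-paged path gone, serve no longer shells out to a DeepSpeed launcher or babysits its process tree; it always calls server.serve(...) directly, which is why the psutil and signal imports could be dropped. For reference, the cleanup the deleted branch performed amounts to this generic pattern (standalone sketch, not code from this repo; psutil assumed installed):

# Generic sketch of the process-tree shutdown the removed DeepSpeed branch did:
# terminate a subprocess and all of its children, then kill any stragglers.
import subprocess

import psutil


def terminate_tree(proc: subprocess.Popen, timeout: float = 30.0) -> None:
    parent = psutil.Process(proc.pid)
    procs = parent.children(recursive=True) + [parent]
    for p in procs:
        try:
            p.terminate()
        except psutil.NoSuchProcess:
            pass  # already exited
    _, alive = psutil.wait_procs(procs, timeout=timeout)
    for p in alive:
        p.kill()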

backends/gaudi/server/text_generation_server/habana_quantization_env.py

Lines changed: 0 additions & 53 deletions
This file was deleted.

backends/gaudi/server/text_generation_server/models/__init__.py

Lines changed: 1 addition & 72 deletions
@@ -5,7 +5,6 @@
 
 from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
-from transformers.models.auto import modeling_auto
 from huggingface_hub import hf_hub_download, HfApi
 from typing import Optional
 from pathlib import Path
@@ -36,14 +35,10 @@
     "Seq2SeqLM",
     "get_model_with_lora_adapters",
 ]
-from text_generation_server.models.globals import ATTENTION
 
 VLM_BATCH_TYPES = set()
-FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
 
-FLASH_ATTENTION = False
-if ATTENTION == "paged":
-    FLASH_ATTENTION = True
+FLASH_ATTENTION = True
 
 try:
     from text_generation_server.models.flash_causal_lm import FlashCausalLM
@@ -883,72 +878,6 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    from text_generation_server.models.causal_lm import CausalLM
-    from text_generation_server.models.vlm_causal_lm import VlmCausalLM
-    from text_generation_server.models.custom_modeling.mllama import (
-        MllamaForConditionalGeneration,
-    )
-    from text_generation_server.models.custom_modeling.llava_next import (
-        LlavaNextForConditionalGeneration,
-    )
-    from text_generation_server.models.vlm_causal_lm import (
-        VlmCausalLMBatch,
-    )
-
-    VLM_BATCH_TYPES.add(VlmCausalLMBatch)
-
-    from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
-
-    adapt_transformers_to_gaudi()
-    if SDP_ON_BF16 == 1:
-        torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
-    if model_type == "gpt_bigcode":
-        from text_generation_server.models.starcoder import StarCoder
-
-        return StarCoder(model_id=model_id, revision=revision, dtype=dtype)
-    if model_type == "bloom":
-        from text_generation_server.models.bloom import BLOOM
-
-        return BLOOM(
-            model_id=model_id,
-            revision=revision,
-            speculator=speculator,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
-
-    if model_type == "llava_next":
-        return VlmCausalLM(
-            model_class=LlavaNextForConditionalGeneration,
-            model_id=model_id,
-            revision=revision,
-            quantize=None,
-            speculator=speculator,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
-
-    if model_type == "mllama":
-        return VlmCausalLM(
-            model_class=MllamaForConditionalGeneration,
-            model_id=model_id,
-            revision=revision,
-            quantize=None,
-            speculator=speculator,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
-
-    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-        return CausalLM(
-            model_id,
-            revision,
-            quantize=quantize,
-            speculator=speculator,
-            dtype=dtype,
-            trust_remote_code=trust_remote_code,
-        )
-
     raise ValueError(f"Unsupported model type {model_type}")
 
 
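FLASH_ATTENTION is now unconditionally True, and the optimum-habana fallback (adapt_transformers_to_gaudi() plus the non-flash CausalLM/VlmCausalLM branches) is gone, so an architecture without a flash implementation falls straight through to the final ValueError. An illustration of what a caller should now expect (sketch only; get_model's full parameter list is not shown in this diff, so model_kwargs is a stand-in):

# Illustrative only, not repo code.
from text_generation_server.models import get_model


def load_or_fail(**model_kwargs):
    # model_kwargs stands in for get_model's real parameters, which this diff elides.
    try:
        return get_model(**model_kwargs)
    except ValueError as err:
        # e.g. ValueError("Unsupported model type <model_type>")
        raise SystemExit(f"Gaudi backend cannot serve this model: {err}")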

backends/gaudi/server/text_generation_server/models/bloom.py

Lines changed: 0 additions & 52 deletions
This file was deleted.
