Commit 4586479

Enable phi-4 with vision and audio (#13203)
* add phi4
* update
* enable audio
* update and add readme
1 parent e032156 commit 4586479

File tree

4 files changed: 183 additions, 0 deletions
Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
+import os
+from dataclasses import asdict
+from typing import NamedTuple, Optional
+
+from huggingface_hub import snapshot_download
+from transformers import AutoTokenizer
+
+from vllm import LLM, EngineArgs, SamplingParams
+from ipex_llm.vllm.xpu.engine import IPEXLLMClass as LLM
+from vllm.assets.audio import AudioAsset
+from vllm.utils import FlexibleArgumentParser
+
+audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
+question_per_audio_count = {
+    0: "What is 1+1?",
+    1: "What is recited in the audio?",
+    2: "What sport and what nursery rhyme are referenced?"
+}
+
+model_path = "/llm/models/whisper-large-v3-turbo"
+# model_path = "/llm/models/whisper-medium"
+# model_path = "/llm/models/Phi-4-multimodal-instruct"
+
+
+# Phi-4-multimodal-instruct
+def run_phi4mm(question: str, audio_count: int):
+    placeholders = "".join([f"<|audio_{i+1}|>" for i in range(audio_count)])
+
+    prompt = f"<|user|>{placeholders}{question}<|end|><|assistant|>"
+
+    return prompt
+
+
+# Whisper
+def run_whisper(question: str, audio_count: int):
+    assert audio_count == 1, (
+        "Whisper only supports a single audio input per prompt")
+
+    prompt = "<|startoftranscript|>"
+
+    return prompt
+
+
+model_example_map = {
+    "phi4mm": run_phi4mm,
+    "whisper": run_whisper,
+}
+
+
+if "whisper" in model_path:
+    model_len = 448
+    low_bit = "fp16"
+else:
+    model_len = 5500
+    low_bit = "sym_int4"
+
+
+def main(args):
+    audio_count = args.num_audios
+
+    llm = LLM(
+        model=model_path,
+        device="xpu",
+        dtype="float16",
+        limit_mm_per_prompt={"audio": audio_count},
+        enforce_eager=True,
+        mm_processor_kwargs=None,
+        load_in_low_bit=low_bit,
+        tensor_parallel_size=1,
+        max_num_seqs=8,
+        gpu_memory_utilization=0.95,
+        disable_async_output_proc=True,
+        distributed_executor_backend="ray",
+        max_model_len=model_len,
+        trust_remote_code=True,
+        block_size=8,
+        max_num_batched_tokens=model_len)
+
+    model = llm.llm_engine.model_config.hf_config.model_type
+    if model not in model_example_map:
+        raise ValueError(f"Model type {model} is not supported.")
+
+    prompt = model_example_map[model](question_per_audio_count[audio_count], audio_count)
+
+    sampling_params = SamplingParams(temperature=0.1,
+                                     top_p=0.001,
+                                     repetition_penalty=1.05,
+                                     max_tokens=128,
+                                     skip_special_tokens=False)
+
+    mm_data = {}
+    if audio_count > 0:
+        mm_data = {
+            "audio": [
+                asset.audio_and_sample_rate
+                for asset in audio_assets[:audio_count]
+            ]
+        }
+
+    assert args.num_prompts > 0
+    inputs = {"prompt": prompt, "multi_modal_data": mm_data}
+    if args.num_prompts > 1:
+        # Batch inference
+        inputs = [inputs] * args.num_prompts
+
+    outputs = llm.generate(inputs, sampling_params=sampling_params)
+
+    for o in outputs:
+        generated_text = o.outputs[0].text
+        print(generated_text)
+
+
+if __name__ == "__main__":
+    parser = FlexibleArgumentParser(
+        description='Demo on using vLLM for offline inference with '
+                    'audio language models')
+    parser.add_argument('--num-prompts',
+                        type=int,
+                        default=1,
+                        help='Number of prompts to run.')
+    parser.add_argument("--num-audios",
+                        type=int,
+                        default=1,
+                        choices=[0, 1, 2],
+                        help="Number of audio items per prompt.")
+    parser.add_argument("--seed",
+                        type=int,
+                        default=None,
+                        help="Set the seed when initializing `vllm.LLM`.")
+
+    args = parser.parse_args()
+    main(args)
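
Two notes on the script, plus a small illustration that is not part of the commit: `IPEXLLMClass` is imported as `LLM` after the stock vLLM import, so the `LLM(...)` call above goes through IPEX-LLM's engine and accepts `load_in_low_bit`; and `run_phi4mm` builds one `<|audio_i|>` placeholder per clip, while Whisper takes exactly one clip and a fixed transcription prompt. The sketch below just reproduces the prompt construction so the expected strings are visible.

```python
# Illustration only: the prompts the helpers above produce.
question = "What sport and what nursery rhyme are referenced?"

# Phi-4-multimodal-instruct with two audio clips (--num-audios 2):
placeholders = "".join(f"<|audio_{i+1}|>" for i in range(2))
print(f"<|user|>{placeholders}{question}<|end|><|assistant|>")
# -> <|user|><|audio_1|><|audio_2|>What sport and what nursery rhyme are referenced?<|end|><|assistant|>

# Whisper transcribes a single clip, so the prompt is just the task token:
print("<|startoftranscript|>")
```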

docker/llm/serving/xpu/docker/vllm_offline_inference_vision_language.py

Lines changed: 14 additions & 0 deletions
@@ -10,6 +10,7 @@
 model_path = "/llm/models/InternVL2-8B"
 model_path = "/llm/models/gemma-3-12b-it"
 model_path = "/llm/models/Qwen2.5-VL-7B-Instruct"
+model_path = "/llm/models/Phi-4-multimodal-instruct"

 prompt = "What is in the image?"

@@ -77,6 +78,18 @@ def run_qwen2_vl(question, modality):
     stop_token_ids = None
     return prompt, stop_token_ids

+# Phi-4-multimodal-instruct
+def run_phi4mm(question, modality):
+    """
+    Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
+    show how to process image inputs.
+    """
+    assert modality == "image"
+
+    prompt = f"<|user|><|image_1|>{question}<|end|><|assistant|>"
+    stop_token_ids = None
+    return prompt, stop_token_ids
+
 model_example_map = {
     "minicpmv": run_minicpmv,
     "qwen2_vl": run_qwen2_vl,
@@ -85,6 +98,7 @@ def run_qwen2_vl(question, modality):
     "chatglm": run_glm4v,
     "internvl_chat": run_internvl,
     "gemma3": run_gemma3,
+    "phi4mm": run_phi4mm,
 }

 if "glm-4v" in model_path:

docs/mddocs/DockerGuides/vllm_docker_quickstart.md

Lines changed: 36 additions & 0 deletions
@@ -438,6 +438,8 @@ docker logs CONTAINER_NAME
 ## 8. Advanced Features

 #### Multi-modal Model
+
+##### Vision model
 <details>
 vLLM serving with IPEX-LLM supports multi-modal models, such as [MiniCPM-V-2_6](https://huggingface.co/openbmb/MiniCPM-V-2_6), which can accept image and text input at the same time and respond.

@@ -478,6 +480,40 @@ curl http://localhost:8000/v1/chat/completions \
 ```
 </details>

+##### Audio model
+<details>
+vLLM serving with IPEX-LLM supports multi-modal models such as [Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct) (offline only for now) and the Whisper series ([whisper-large-v3-turbo](https://huggingface.co/openai/whisper-large-v3-turbo) and [whisper-medium](https://huggingface.co/openai/whisper-medium)), which accept audio input and respond with text output.
+
+Offline test:
+```bash
+export VLLM_USE_V1=0
+python3 audio_language.py
+```
+
+Online test:
+1. Start the vLLM service: change the `model` and `served_model_name` values in `/llm/start-vllm-service.sh`.
+
+2. Download or prepare an audio file first.
+```python
+# python3 load.py
+from vllm.assets.audio import AudioAsset
+import soundfile as sf
+
+audio, sr = AudioAsset("winning_call").audio_and_sample_rate
+
+sf.write("output.wav", audio, sr)
+```
+
+3. Send a request with the audio file and, optionally, prompt text.
+
+```bash
+curl http://localhost:8000/v1/audio/transcriptions \
+  -H "Content-Type: multipart/form-data" \
+  -F file="@/llm/models/test/output.wav" \
+  -F model="whisper-large-v3-turbo"
+```
+</details>
+
 #### Prefix Caching
 <details>
 Automatic Prefix Caching (APC in short) caches the KV cache of existing queries, so that a new query can directly reuse the KV cache if it shares the same prefix with one of the existing queries, allowing the new query to skip the computation of the shared part.
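
As a companion to the transcription `curl` example added above (not part of the commit), the same endpoint can usually be called with the OpenAI Python client; the base URL, API key, and served model name below are assumptions matching that example.

```python
# Sketch only: assumes the vLLM service is reachable at localhost:8000 and the
# served model name is "whisper-large-v3-turbo", as in the curl example above.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

with open("output.wav", "rb") as audio_file:
    transcript = client.audio.transcriptions.create(
        model="whisper-large-v3-turbo",
        file=audio_file,
    )

print(transcript.text)
```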

python/llm/src/ipex_llm/vllm/xpu/model_convert.py

Lines changed: 2 additions & 0 deletions
@@ -129,6 +129,8 @@ def _ipex_llm_load_model(self) -> None:
     if "glm-4v" in self.vllm_config.model_config.model.lower() and \
             low_bit in ("sym_int4", "woq_int4"):
         modules = ["dense_4h_to_h"]
+    if "phi4mm" in self.vllm_config.model_config.hf_config.model_type:
+        modules = ["vision_encoder", "embed_tokens_extend"]
     if low_bit == "fp16":
         # to fix qwen2.5-vl and glm-4v
         modules = ["vision", "visual"]
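
For intuition only (this is not IPEX-LLM's actual conversion code): a module-name list like `["vision_encoder", "embed_tokens_extend"]` is typically consumed by skipping low-bit conversion for any submodule whose qualified name matches, keeping Phi-4's vision encoder and extended embeddings in higher precision. A generic sketch:

```python
# Generic sketch of applying a "do not convert" module list; the real logic lives
# inside IPEX-LLM's conversion routines and may differ in detail.
import torch.nn as nn


def should_skip_low_bit(name: str, excluded: list[str]) -> bool:
    return any(part in name for part in excluded)


def convert_linear_layers(model: nn.Module, excluded: list[str]) -> None:
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and not should_skip_low_bit(name, excluded):
            # replace `module` with its low-bit counterpart here
            ...


# e.g. convert_linear_layers(model, ["vision_encoder", "embed_tokens_extend"])
```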
