Skip to content

Commit 24ee40d

Browse files
authored
feat: support max_image_fetch_size to limit (#3339)
* feat: support max_image_fetch_size to limit
* fix: update model path for test
* fix: adjust model repo id for test again
* fix: apply clippy lints
* fix: clippy fix
* fix: avoid torch build isolation in docker
* fix: bump repo id in flash llama tests
* fix: temporarily avoid problematic repos in tests
1 parent 85790a1 commit 24ee40d

File tree

12 files changed

+78
-8
lines changed

12 files changed

+78
-8
lines changed

backends/llamacpp/src/main.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,10 @@ struct Args {
157157
/// Maximum payload size in bytes.
158158
#[clap(default_value = "2000000", long, env)]
159159
payload_limit: usize,
160+
161+
/// Maximum image fetch size in bytes.
162+
#[clap(default_value = "1073741824", long, env)]
163+
max_image_fetch_size: usize,
160164
}
161165

162166
#[tokio::main]
@@ -320,6 +324,7 @@ async fn main() -> Result<(), RouterError> {
320324
args.max_client_batch_size,
321325
args.usage_stats,
322326
args.payload_limit,
327+
args.max_image_fetch_size,
323328
args.prometheus_port,
324329
)
325330
.await?;

backends/trtllm/src/main.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ struct Args {
6767
usage_stats: UsageStatsLevel,
6868
#[clap(default_value = "2000000", long, env)]
6969
payload_limit: usize,
70+
#[clap(default_value = "1073741824", long, env)]
71+
max_image_fetch_size: usize,
7072
}
7173

7274
async fn get_tokenizer(tokenizer_name: &str, revision: Option<&str>) -> Option<Tokenizer> {
@@ -244,6 +246,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
244246
executor_worker,
245247
usage_stats,
246248
payload_limit,
249+
max_image_fetch_size,
247250
} = args;
248251

249252
// Launch Tokio runtime
@@ -325,6 +328,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
325328
max_client_batch_size,
326329
usage_stats,
327330
payload_limit,
331+
max_image_fetch_size,
328332
prometheus_port,
329333
)
330334
.await?;

backends/v2/src/main.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ struct Args {
7474
usage_stats: usage_stats::UsageStatsLevel,
7575
#[clap(default_value = "2000000", long, env)]
7676
payload_limit: usize,
77+
#[clap(default_value = "1073741824", long, env)]
78+
max_image_fetch_size: usize,
7779
}
7880

7981
#[derive(Debug, Subcommand)]
@@ -120,6 +122,7 @@ async fn main() -> Result<(), RouterError> {
120122
max_client_batch_size,
121123
usage_stats,
122124
payload_limit,
125+
max_image_fetch_size,
123126
} = args;
124127

125128
if let Some(Commands::PrintSchema) = command {
@@ -201,6 +204,7 @@ async fn main() -> Result<(), RouterError> {
201204
max_client_batch_size,
202205
usage_stats,
203206
payload_limit,
207+
max_image_fetch_size,
204208
prometheus_port,
205209
)
206210
.await?;

backends/v3/src/main.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ struct Args {
7474
usage_stats: usage_stats::UsageStatsLevel,
7575
#[clap(default_value = "2000000", long, env)]
7676
payload_limit: usize,
77+
#[clap(default_value = "1073741824", long, env)]
78+
max_image_fetch_size: usize,
7779
}
7880

7981
#[derive(Debug, Subcommand)]
@@ -120,6 +122,7 @@ async fn main() -> Result<(), RouterError> {
120122
max_client_batch_size,
121123
usage_stats,
122124
payload_limit,
125+
max_image_fetch_size,
123126
} = args;
124127

125128
if let Some(Commands::PrintSchema) = command {
@@ -217,6 +220,7 @@ async fn main() -> Result<(), RouterError> {
217220
max_client_batch_size,
218221
usage_stats,
219222
payload_limit,
223+
max_image_fetch_size,
220224
prometheus_port,
221225
)
222226
.await?;

integration-tests/models/test_flash_llama.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
@pytest.fixture(scope="module")
55
def flash_llama_handle(launcher):
6-
with launcher("huggingface/llama-7b", num_shard=2) as handle:
6+
with launcher("huggyllama/llama-7b", num_shard=2) as handle:
77
yield handle
88

99

integration-tests/models/test_flash_llama_fp8.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ async def flash_llama_fp8(flash_llama_fp8_handle):
1313
return flash_llama_fp8_handle.client
1414

1515

16+
@pytest.mark.skip(reason="Issue with the model access")
1617
@pytest.mark.release
1718
@pytest.mark.asyncio
1819
@pytest.mark.private
@@ -26,6 +27,7 @@ async def test_flash_llama_fp8(flash_llama_fp8, response_snapshot):
2627
assert response == response_snapshot
2728

2829

30+
@pytest.mark.skip(reason="Issue with the model access")
2931
@pytest.mark.release
3032
@pytest.mark.asyncio
3133
@pytest.mark.private
@@ -49,6 +51,7 @@ async def test_flash_llama_fp8_all_params(flash_llama_fp8, response_snapshot):
4951
assert response == response_snapshot
5052

5153

54+
@pytest.mark.skip(reason="Issue with the model access")
5255
@pytest.mark.release
5356
@pytest.mark.asyncio
5457
@pytest.mark.private

integration-tests/models/test_flash_llama_marlin_24.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ async def flash_llama_marlin(flash_llama_marlin24_handle):
1515
return flash_llama_marlin24_handle.client
1616

1717

18+
@pytest.mark.skip(reason="Issue with the model access")
1819
@pytest.mark.release
1920
@pytest.mark.asyncio
2021
@pytest.mark.private
@@ -27,6 +28,7 @@ async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
2728
assert response == response_snapshot
2829

2930

31+
@pytest.mark.skip(reason="Issue with the model access")
3032
@pytest.mark.release
3133
@pytest.mark.asyncio
3234
@pytest.mark.private
@@ -50,6 +52,7 @@ async def test_flash_llama_marlin24_all_params(flash_llama_marlin, response_snap
5052
assert response == response_snapshot
5153

5254

55+
@pytest.mark.skip(reason="Issue with the model access")
5356
@pytest.mark.release
5457
@pytest.mark.asyncio
5558
@pytest.mark.private

router/src/chat.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -673,7 +673,7 @@ mod tests {
673673
let (name, arguments) = get_tool_call_content(&events[0]);
674674
if let Some(name) = name {
675675
assert_eq!(name, "get_current_weather");
676-
output_name.push_str(&name);
676+
output_name.push_str(name);
677677
}
678678
output.push_str(arguments);
679679
} else {

router/src/server.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1523,6 +1523,7 @@ pub async fn run(
15231523
max_client_batch_size: usize,
15241524
usage_stats_level: usage_stats::UsageStatsLevel,
15251525
payload_limit: usize,
1526+
max_image_fetch_size: usize,
15261527
prometheus_port: u16,
15271528
) -> Result<(), WebServerError> {
15281529
// CORS allowed origins
@@ -1827,6 +1828,7 @@ pub async fn run(
18271828
compat_return_full_text,
18281829
allow_origin,
18291830
payload_limit,
1831+
max_image_fetch_size,
18301832
prometheus_port,
18311833
)
18321834
.await;
@@ -1889,6 +1891,7 @@ async fn start(
18891891
compat_return_full_text: bool,
18901892
allow_origin: Option<AllowOrigin>,
18911893
payload_limit: usize,
1894+
max_image_fetch_size: usize,
18921895
prometheus_port: u16,
18931896
) -> Result<(), WebServerError> {
18941897
// Determine the server port based on the feature and environment variable.
@@ -1920,6 +1923,7 @@ async fn start(
19201923
max_input_tokens,
19211924
max_total_tokens,
19221925
disable_grammar_support,
1926+
max_image_fetch_size,
19231927
);
19241928

19251929
let infer = Infer::new(

router/src/validation.rs

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use rand::{thread_rng, Rng};
1212
use serde_json::Value;
1313
/// Payload validation logic
1414
use std::cmp::min;
15-
use std::io::Cursor;
15+
use std::io::{Cursor, Read};
1616
use std::iter;
1717
use std::sync::Arc;
1818
use thiserror::Error;
@@ -51,6 +51,7 @@ impl Validation {
5151
max_input_length: usize,
5252
max_total_tokens: usize,
5353
disable_grammar_support: bool,
54+
max_image_fetch_size: usize,
5455
) -> Self {
5556
let workers = if let Tokenizer::Python { .. } = &tokenizer {
5657
1
@@ -78,6 +79,7 @@ impl Validation {
7879
config_clone,
7980
preprocessor_config_clone,
8081
tokenizer_receiver,
82+
max_image_fetch_size,
8183
)
8284
});
8385
}
@@ -480,6 +482,7 @@ fn tokenizer_worker(
480482
config: Option<Config>,
481483
preprocessor_config: Option<HubPreprocessorConfig>,
482484
mut receiver: mpsc::UnboundedReceiver<TokenizerRequest>,
485+
max_image_fetch_size: usize,
483486
) {
484487
match tokenizer {
485488
Tokenizer::Python {
@@ -503,6 +506,7 @@ fn tokenizer_worker(
503506
&tokenizer,
504507
config.as_ref(),
505508
preprocessor_config.as_ref(),
509+
max_image_fetch_size,
506510
))
507511
.unwrap_or(())
508512
})
@@ -524,6 +528,7 @@ fn tokenizer_worker(
524528
&tokenizer,
525529
config.as_ref(),
526530
preprocessor_config.as_ref(),
531+
max_image_fetch_size,
527532
))
528533
.unwrap_or(())
529534
})
@@ -562,10 +567,35 @@ fn format_to_mimetype(format: ImageFormat) -> String {
562567
.to_string()
563568
}
564569

565-
fn fetch_image(input: &str) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
570+
fn fetch_image(
571+
input: &str,
572+
max_image_fetch_size: usize,
573+
) -> Result<(Vec<u8>, String, usize, usize), ValidationError> {
566574
if input.starts_with("![](http://") || input.starts_with("![](https://") {
567575
let url = &input["![](".len()..input.len() - 1];
568-
let data = reqwest::blocking::get(url)?.bytes()?;
576+
let response = reqwest::blocking::get(url)?;
577+
578+
// Check Content-Length header if present
579+
if let Some(content_length) = response.content_length() {
580+
if content_length as usize > max_image_fetch_size {
581+
return Err(ValidationError::ImageTooLarge(
582+
content_length as usize,
583+
max_image_fetch_size,
584+
));
585+
}
586+
}
587+
588+
// Read the body with size limit to prevent unbounded memory allocation
589+
let mut data = Vec::new();
590+
let mut limited_reader = response.take((max_image_fetch_size + 1) as u64);
591+
limited_reader.read_to_end(&mut data)?;
592+
593+
if data.len() > max_image_fetch_size {
594+
return Err(ValidationError::ImageTooLarge(
595+
data.len(),
596+
max_image_fetch_size,
597+
));
598+
}
569599

570600
let format = image::guess_format(&data)?;
571601
// TODO Remove this clone
@@ -787,6 +817,7 @@ fn prepare_input<T: TokenizerTrait>(
787817
tokenizer: &T,
788818
config: Option<&Config>,
789819
preprocessor_config: Option<&HubPreprocessorConfig>,
820+
max_image_fetch_size: usize,
790821
) -> Result<(tokenizers::Encoding, Vec<Chunk>), ValidationError> {
791822
use Config::*;
792823
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
@@ -805,7 +836,8 @@ fn prepare_input<T: TokenizerTrait>(
805836
input_chunks.push(Chunk::Text(inputs[start..chunk_start].to_string()));
806837
tokenizer_query.push_str(&inputs[start..chunk_start]);
807838
}
808-
let (data, mimetype, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
839+
let (data, mimetype, height, width) =
840+
fetch_image(&inputs[chunk_start..chunk_end], max_image_fetch_size)?;
809841
input_chunks.push(Chunk::Image(Image { data, mimetype }));
810842
tokenizer_query.push_str(&image_tokens(config, preprocessor_config, height, width));
811843
start = chunk_end;
@@ -990,6 +1022,10 @@ pub enum ValidationError {
9901022
InvalidImageContent(String),
9911023
#[error("Could not fetch image: {0}")]
9921024
FailedFetchImage(#[from] reqwest::Error),
1025+
#[error("Image size {0} bytes exceeds maximum allowed size of {1} bytes")]
1026+
ImageTooLarge(usize, usize),
1027+
#[error("Failed to read image data: {0}")]
1028+
ImageReadError(#[from] std::io::Error),
9931029
#[error("{0} modality is not supported")]
9941030
UnsupportedModality(&'static str),
9951031
}
@@ -1023,6 +1059,7 @@ mod tests {
10231059
max_input_length,
10241060
max_total_tokens,
10251061
disable_grammar_support,
1062+
1024 * 1024 * 1024, // 1GB
10261063
);
10271064

10281065
let max_new_tokens = 10;
@@ -1058,6 +1095,7 @@ mod tests {
10581095
max_input_length,
10591096
max_total_tokens,
10601097
disable_grammar_support,
1098+
1024 * 1024 * 1024, // 1GB
10611099
);
10621100

10631101
let max_new_tokens = 10;
@@ -1092,6 +1130,7 @@ mod tests {
10921130
max_input_length,
10931131
max_total_tokens,
10941132
disable_grammar_support,
1133+
1024 * 1024 * 1024, // 1GB
10951134
);
10961135
match validation
10971136
.validate(GenerateRequest {
@@ -1132,6 +1171,7 @@ mod tests {
11321171
max_input_length,
11331172
max_total_tokens,
11341173
disable_grammar_support,
1174+
1024 * 1024 * 1024, // 1GB
11351175
);
11361176
match validation
11371177
.validate(GenerateRequest {
@@ -1203,6 +1243,7 @@ mod tests {
12031243
max_input_length,
12041244
max_total_tokens,
12051245
disable_grammar_support,
1246+
1024 * 1024 * 1024, // 1GB
12061247
);
12071248
match validation
12081249
.validate(GenerateRequest {
@@ -1293,6 +1334,7 @@ mod tests {
12931334
max_input_length,
12941335
max_total_tokens,
12951336
disable_grammar_support,
1337+
1024 * 1024 * 1024, // 1GB
12961338
);
12971339

12981340
let chunks = match validation
@@ -1349,6 +1391,7 @@ mod tests {
13491391
max_input_length,
13501392
max_total_tokens,
13511393
disable_grammar_support,
1394+
1024 * 1024 * 1024, // 1GB
13521395
);
13531396

13541397
let (encoding, chunks) = match validation

0 commit comments

Comments (0)