Commit 0a89902

[TRTLLM] Expose finish reason (#2841)
* feat(trtllm): expose finish reason to Rust
* misc(llamacpp): fix typo
* misc(backend): update deps
1 parent 4e17202 commit 0a89902
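
Before this change, the Rust side of the TensorRT-LLM backend hard-coded the finish reason of every completed request to `FinishReason::EndOfSequenceToken` (note the `TODO` removed in `looper.rs` below). This commit adds a `finish_reason` field to each `GenerationStep` crossing the `cxx` FFI boundary, so the router can report why generation actually stopped. A minimal sketch of the shared types on the bridge, assuming variant names that mirror TensorRT-LLM's `executor::FinishReason`; the actual `#[cxx::bridge]` declaration is outside this diff and may differ:

```rust
// Hypothetical sketch of the shared FFI types; the real declarations live in
// the backend's #[cxx::bridge] module and may differ from what is shown here.
#[cxx::bridge(namespace = "huggingface::tgi::backends::trtllm")]
mod ffi {
    // Assumed to mirror tensorrt_llm::executor::FinishReason on the C++ side.
    enum FinishReason {
        /// The request is not finished.
        kNOT_FINISHED = 0,
        /// The request finished because the end id was generated.
        kEND_ID = 1,
        /// The request finished because a stop word was generated.
        kSTOP_WORDS = 2,
        /// The request finished because the maximum token count was reached.
        kLENGTH = 3,
    }

    struct GenerationStep {
        request_id: u64,
        token_id: u32,
        log_prob: f32,
        is_final: bool,
        finish_reason: FinishReason, // new in this commit
        has_error: bool,
        error_msg: String,
    }
}
```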

File tree

4 files changed: +15 -36 lines changed

backends/trtllm/CMakeLists.txt

Lines changed: 1 addition & 1 deletion

```diff
@@ -30,7 +30,7 @@ option(TGI_TRTLLM_BACKEND_BUILD_TESTS "Enable building the unittests suite" OFF)
 option(TGI_TRTLLM_BACKEND_BUILD_EXAMPLES "Enable building the examples suite" OFF)
 option(TGI_TRTLLM_BACKEND_BUILD_USE_LLD "Enable lld linker instead of ld" OFF)
 set(TGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST "89-real" CACHE STRING "List of CUDA architectures to support")
-set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path rgo where TensorRT libraries and headers are located")
+set(TGI_TRTLLM_BACKEND_TRT_ROOT "/usr/local/tensorrt" CACHE STRING "Path where TensorRT libraries and headers are located")
 set(TGI_TRTLLM_BACKEND_TRT_INCLUDE_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/include" CACHE STRING "Path where TensorRT headers are located")
 set(TGI_TRTLLM_BACKEND_TRT_LIB_DIR "${TGI_TRTLLM_BACKEND_TRT_ROOT}/lib" CACHE STRING "Path where TensorRT libraries are located")
```
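
For context, these are user-facing CMake cache variables, so the corrected description string is what a consumer sees when configuring the build. A hypothetical configure invocation overriding them (the source directory and override values are illustrative; the defaults come straight from the file above):

```bash
cmake -S backends/trtllm -B build \
    -DTGI_TRTLLM_BACKEND_TARGET_CUDA_ARCH_LIST="90-real" \
    -DTGI_TRTLLM_BACKEND_TRT_ROOT=/opt/tensorrt
```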

backends/trtllm/Cargo.toml

Lines changed: 3 additions & 7 deletions

```diff
@@ -7,20 +7,16 @@ homepage.workspace = true
 
 [dependencies]
 async-trait = "0.1"
-#async-stream = "0.3"
 clap = { version = "4.5", features = ["derive"] }
 cxx = "1.0"
-hashbrown = "0.14"
+hashbrown = "0.15"
 hf-hub = { workspace = true }
-#log = { version = "0.4", features = [] }
 text-generation-router = { path = "../../router" }
 tokenizers = { workspace = true }
-tokio = { version = "1.39", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
-tokio-stream = "0.1.15"
+tokio = { version = "1.43.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
+tokio-stream = "0.1.17"
 thiserror = "1.0.63"
 tracing = "0.1"
-#tracing-opentelemetry = "0.25"
-#tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] }
 pyo3 = { workspace = true }
 
 [build-dependencies]
```

backends/trtllm/src/looper.rs

Lines changed: 7 additions & 3 deletions

```diff
@@ -18,10 +18,12 @@ use text_generation_router::validation::ValidationError::{
     EmptyInput, Grammar, TopNTokensDisabled, UnsupportedModality,
 };
 use text_generation_router::validation::{Chunk, ValidGenerateRequest};
-use text_generation_router::{FinishReason, Token};
+use text_generation_router::Token;
 
 use crate::errors::TensorRtLlmBackendError;
-use crate::ffi::{create_backend_from_engine_folder, GenerationStep, TensorRtLlmBackendImpl};
+use crate::ffi::{
+    create_backend_from_engine_folder, FinishReason, GenerationStep, TensorRtLlmBackendImpl,
+};
 use crate::utils::first_line;
 
 type InferResult<T> = Result<T, InferError>;
@@ -40,6 +42,7 @@ struct DecodedToken {
     id: u32,
     log_prob: f32,
     is_final: bool,
+    finish_reason: FinishReason,
 }
 
 impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
@@ -51,6 +54,7 @@ impl<'step> TryFrom<&'step GenerationStep> for DecodedToken {
             id: step.token_id,
             log_prob: step.log_prob,
             is_final: step.is_final,
+            finish_reason: step.finish_reason,
         })
     } else {
         Err(GenerationError(step.error_msg.clone()))
@@ -192,7 +196,7 @@ fn post_process_decoded_token(
     let generated_text = GeneratedText {
         text: text.unwrap(),
        generated_tokens: ctx.tokens.len() as u32,
-        finish_reason: FinishReason::EndOfSequenceToken, // TODO : Map FinishReason
+        finish_reason: decoded_token.finish_reason.into(),
         seed: None,
     };
 
```
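
The `decoded_token.finish_reason.into()` call in `post_process_decoded_token` implies a `From<ffi::FinishReason>` impl for the router's `FinishReason`. That impl is outside this diff; a minimal sketch of the mapping, reusing the assumed variant names from the bridge sketch above:

```rust
use text_generation_router::FinishReason;

use crate::ffi;

// Hypothetical mapping from the bridge enum to the router enum; cxx shared
// enums are open (any repr value can arrive from C++), so a catch-all arm
// is required.
impl From<ffi::FinishReason> for FinishReason {
    fn from(reason: ffi::FinishReason) -> Self {
        match reason {
            // Generation stopped on the model's end-of-sequence token.
            ffi::FinishReason::kEND_ID => FinishReason::EndOfSequenceToken,
            // Generation stopped on a user-provided stop sequence.
            ffi::FinishReason::kSTOP_WORDS => FinishReason::StopSequence,
            // Generation hit the requested token budget.
            ffi::FinishReason::kLENGTH => FinishReason::Length,
            // kNOT_FINISHED (or an unknown repr) should never be seen on a
            // final token; fall back to the previous hard-coded behavior.
            _ => FinishReason::EndOfSequenceToken,
        }
    }
}
```

Since `finish_reason` presumably only carries a meaningful value on the step where `is_final` is true, the fallback arm keeps the old hard-coded default rather than panicking on an unexpected value.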

backends/trtllm/src/main.rs

Lines changed: 4 additions & 25 deletions

```diff
@@ -67,11 +67,7 @@ struct Args {
     payload_limit: usize,
 }
 
-async fn get_tokenizer(
-    tokenizer_name: &str,
-    _tokenizer_config_path: Option<&str>,
-    revision: Option<&str>,
-) -> Option<Tokenizer> {
+async fn get_tokenizer(tokenizer_name: &str, revision: Option<&str>) -> Option<Tokenizer> {
     // Parse Huggingface hub token
     let authorization_token = std::env::var("HF_TOKEN")
         .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
@@ -182,19 +178,6 @@ async fn get_tokenizer(
         }
     };
 
-    // Read the JSON contents of the file as an instance of 'HubTokenizerConfig'.
-    // let tokenizer_config: Option<HubTokenizerConfig> = if let Some(filename) = tokenizer_config_path
-    // {
-    //     HubTokenizerConfig::from_file(filename)
-    // } else {
-    //     tokenizer_config_filename.and_then(HubTokenizerConfig::from_file)
-    // };
-
-    // let tokenizer_config = tokenizer_config.unwrap_or_else(|| {
-    //     tracing::warn!("Could not find tokenizer config locally and no API specified");
-    //     HubTokenizerConfig::default()
-    // });
-
     let tokenizer: Tokenizer = {
         use pyo3::prelude::*;
         pyo3::Python::with_gil(|py| -> PyResult<()> {
@@ -292,13 +275,9 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
     }
 
     // Create the backend
-    match get_tokenizer(
-        &tokenizer_name,
-        tokenizer_config_path.as_deref(),
-        revision.as_deref(),
-    )
-    .await
-    .expect("Failed to retrieve tokenizer implementation")
+    match get_tokenizer(&tokenizer_name, revision.as_deref())
+        .await
+        .expect("Failed to retrieve tokenizer implementation")
     {
         Tokenizer::Python { .. } => Err(TensorRtLlmBackendError::Tokenizer(
             "Failed to retrieve Rust based tokenizer".to_string(),
```