
Commit b3094f2

Generation with detection integration tests (foundation-model-stack#345)
Signed-off-by: Mateus Devino <[email protected]>
1 parent 54d9340 commit b3094f2

File tree

3 files changed: +444 -23 lines changed


tests/classification_with_text_gen.rs

Lines changed: 16 additions & 22 deletions
@@ -157,7 +157,7 @@ async fn no_detections() -> Result<(), anyhow::Error> {
     let mock_detector_server =
         MockServer::new(DETECTOR_NAME_ANGLE_BRACKETS_SENTENCE).with_mocks(detector_mocks);
     let mock_generation_server = MockServer::new("nlp").grpc().with_mocks(generation_mocks);
-    let mock_chunker_server = MockServer::new(CHUNKER_NAME_SENTENCE.into())
+    let mock_chunker_server = MockServer::new(CHUNKER_NAME_SENTENCE)
         .grpc()
         .with_mocks(chunker_mocks);
 
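Note on this hunk: dropping the `.into()` implies `MockServer::new` accepts any value convertible into its internal name type, so a `&str` constant can be passed directly. A minimal sketch of that signature pattern, with a simplified `MockServer` (the struct body and constant value below are illustrative, not the repo's actual implementation):

struct MockServer {
    name: String,
}

impl MockServer {
    // Taking `impl Into<String>` lets callers pass &str or String
    // without writing `.into()` at the call site.
    fn new(name: impl Into<String>) -> Self {
        Self { name: name.into() }
    }
}

fn main() {
    const CHUNKER_NAME_SENTENCE: &str = "sentence_chunker"; // hypothetical value
    let server = MockServer::new(CHUNKER_NAME_SENTENCE); // no `.into()` needed
    println!("mock server: {}", server.name);
}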
@@ -257,7 +257,7 @@ async fn no_detections() -> Result<(), anyhow::Error> {
 #[test(tokio::test)]
 async fn input_detector_detections() -> Result<(), anyhow::Error> {
     // Add tokenization results mock for input detections scenarios
-    let tokenization_results = vec![
+    let tokenization_results = [
         Token {
             start: 0,
             end: 40,
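The `vec![...]` to `[...]` changes in this hunk and the next swap a heap-allocated `Vec` for a fixed-size array: these fixtures are never resized, so an array supports the same indexing and iteration without the allocation (clippy flags the old form as `useless_vec`). A minimal sketch of the pattern, with placeholder span data:

fn main() {
    // Fixed-size array instead of vec![...]: no heap allocation,
    // same indexing and iteration for data that never grows.
    let token_spans = [(0_u32, 40_u32), (41, 61)]; // placeholder spans
    assert_eq!(token_spans[0].0, 0);
    for (start, end) in &token_spans {
        println!("token: {start}..{end}");
    }
}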
@@ -276,7 +276,7 @@ async fn input_detector_detections() -> Result<(), anyhow::Error> {
     ];
 
     // Add tokenization mock responses for input detections
-    let mock_tokenization_responses = vec![
+    let mock_tokenization_responses = [
         TokenizationResults {
             results: Vec::new(),
             token_count: 61,
@@ -320,7 +320,6 @@ async fn input_detector_detections() -> Result<(), anyhow::Error> {
         when.path(GENERATION_NLP_TOKENIZATION_ENDPOINT)
             .pb(TokenizationTaskRequest {
                 text: "This sentence does not have a detection. But <this one does>.".into(),
-                ..Default::default()
             });
         then.pb(mock_tokenization_responses[0].clone());
     });
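This hunk and the later ones like it delete `..Default::default()` from struct literals that already set every field; the functional-update spread is redundant there (clippy's `needless_update` lint). A minimal sketch, assuming a request struct whose only field is `text` (a simplification; the real protobuf type may carry more fields):

#[derive(Default)]
struct TokenizationTaskRequest {
    text: String, // assumption: treat this as the only field
}

fn main() {
    // Every field is set explicitly, so `..Default::default()`
    // would add nothing (clippy: needless_update).
    let req = TokenizationTaskRequest {
        text: "This sentence does not have a detection.".into(),
    };
    println!("{}", req.text);
}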
@@ -330,7 +329,6 @@ async fn input_detector_detections() -> Result<(), anyhow::Error> {
         when.path(GENERATION_NLP_TOKENIZATION_ENDPOINT)
             .pb(TokenizationTaskRequest {
                 text: "This sentence does not have a detection. But <this one does>. Also <this other one>.".into(),
-                ..Default::default()
             });
         then.pb(mock_tokenization_responses[1].clone());
     });
@@ -443,13 +441,13 @@ async fn input_detector_detections() -> Result<(), anyhow::Error> {
         results.token_classification_results,
         TextGenTokenClassificationResults {
             input: Some(vec![TokenClassificationResult {
-                start: 46 as u32,
-                end: 59 as u32,
+                start: 46_u32,
+                end: 59_u32,
                 word: expected_detections[0].text.clone(),
                 entity: expected_detections[0].detection.clone(),
                 entity_group: expected_detections[0].detection_type.clone(),
                 detector_id: expected_detections[0].detector_id.clone(),
-                score: expected_detections[0].score.clone(),
+                score: expected_detections[0].score,
                 token_count: None
             }]),
             output: None
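Two idiom fixes repeat through the remaining assertion hunks: literal suffixes (`46_u32`) replace `as` casts on literals (clippy: `unnecessary_cast`), and `.clone()` is dropped from `score` because `Copy` values are duplicated implicitly (clippy: `clone_on_copy`). A minimal sketch, assuming `score` is an `f64` (the diff does not show its type):

fn main() {
    // A typed literal is clearer than a cast: `46_u32` vs `46 as u32`.
    let start = 46_u32;
    let end = 59_u32;

    // f64 is Copy, so plain assignment duplicates it; `.clone()` is noise.
    let score: f64 = 0.97; // placeholder value
    let copied = score;
    println!("span {start}..{end}, score {copied}");
}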
@@ -494,8 +492,8 @@ async fn input_detector_detections() -> Result<(), anyhow::Error> {
         TextGenTokenClassificationResults {
             input: Some(vec![
                 TokenClassificationResult {
-                    start: 46 as u32,
-                    end: 59 as u32,
+                    start: 46_u32,
+                    end: 59_u32,
                     word: expected_detections[0].text.clone(),
                     entity: expected_detections[0].detection.clone(),
                     entity_group: expected_detections[0].detection_type.clone(),
@@ -504,8 +502,8 @@ async fn input_detector_detections() -> Result<(), anyhow::Error> {
                     token_count: None
                 },
                 TokenClassificationResult {
-                    start: 68 as u32,
-                    end: 82 as u32,
+                    start: 68_u32,
+                    end: 82_u32,
                     word: expected_detections[1].text.clone(),
                     entity: expected_detections[1].detection.clone(),
                     entity_group: expected_detections[1].detection_type.clone(),
@@ -566,7 +564,6 @@ async fn input_detector_client_error() -> Result<(), anyhow::Error> {
         when.path(GENERATION_NLP_TOKENIZATION_ENDPOINT)
             .pb(TokenizationTaskRequest {
                 text: generation_server_error_input.into(),
-                ..Default::default()
             });
         then.internal_server_error();
     });
@@ -926,8 +923,8 @@ async fn output_detector_detections() -> Result<(), anyhow::Error> {
         TextGenTokenClassificationResults {
             input: None,
             output: Some(vec![TokenClassificationResult {
-                start: 46 as u32,
-                end: 59 as u32,
+                start: 46_u32,
+                end: 59_u32,
                 word: expected_detections[0].text.clone(),
                 entity: expected_detections[0].detection.clone(),
                 entity_group: expected_detections[0].detection_type.clone(),
@@ -971,8 +968,8 @@ async fn output_detector_detections() -> Result<(), anyhow::Error> {
             input: None,
             output: Some(vec![
                 TokenClassificationResult {
-                    start: 46 as u32,
-                    end: 59 as u32,
+                    start: 46_u32,
+                    end: 59_u32,
                     word: expected_detections[0].text.clone(),
                     entity: expected_detections[0].detection.clone(),
                     entity_group: expected_detections[0].detection_type.clone(),
@@ -981,8 +978,8 @@ async fn output_detector_detections() -> Result<(), anyhow::Error> {
                     token_count: None
                 },
                 TokenClassificationResult {
-                    start: 68 as u32,
-                    end: 82 as u32,
+                    start: 68_u32,
+                    end: 82_u32,
                     word: expected_detections[1].text.clone(),
                     entity: expected_detections[1].detection.clone(),
                     entity_group: expected_detections[1].detection_type.clone(),
@@ -1045,7 +1042,6 @@ async fn output_detector_client_error() -> Result<(), anyhow::Error> {
             .header(GENERATION_NLP_MODEL_ID_HEADER_NAME, MODEL_ID)
             .pb(TokenizationTaskRequest {
                 text: generation_server_error_input.into(),
-                ..Default::default()
             });
         then.internal_server_error();
     });
@@ -1056,7 +1052,6 @@ async fn output_detector_client_error() -> Result<(), anyhow::Error> {
             .header(GENERATION_NLP_MODEL_ID_HEADER_NAME, MODEL_ID)
             .pb(TokenizationTaskRequest {
                 text: detector_error_input.into(),
-                ..Default::default()
             });
         then.pb(GeneratedTextResult {
             generated_text: detector_error_input.into(),
@@ -1070,7 +1065,6 @@ async fn output_detector_client_error() -> Result<(), anyhow::Error> {
             .header(GENERATION_NLP_MODEL_ID_HEADER_NAME, MODEL_ID)
             .pb(TokenizationTaskRequest {
                 text: chunker_error_input.into(),
-                ..Default::default()
             });
         then.pb(GeneratedTextResult {
             generated_text: chunker_error_input.into(),

tests/common/orchestrator.rs

Lines changed: 4 additions & 1 deletion
@@ -38,9 +38,12 @@ use url::Url;
 pub const ORCHESTRATOR_CONFIG_FILE_PATH: &str = "tests/test_config.yaml";
 
 // Endpoints
+pub const ORCHESTRATOR_UNARY_ENDPOINT: &str = "/api/v1/task/classification-with-text-generation";
 pub const ORCHESTRATOR_STREAMING_ENDPOINT: &str =
     "/api/v1/task/server-streaming-classification-with-text-generation";
-pub const ORCHESTRATOR_UNARY_ENDPOINT: &str = "/api/v1/task/classification-with-text-generation";
+pub const ORCHESTRATOR_GENERATION_WITH_DETECTION_ENDPOINT: &str =
+    "/api/v2/text/generation-detection";
+
 pub const ORCHESTRATOR_CONTENT_DETECTION_ENDPOINT: &str = "/api/v2/text/detection/content";
 pub const ORCHESTRATOR_DETECTION_ON_GENERATION_ENDPOINT: &str = "/api/v2/text/detection/generated";
 pub const ORCHESTRATOR_CONTEXT_DOCS_DETECTION_ENDPOINT: &str = "/api/v2/text/detection/context";
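The new `ORCHESTRATOR_GENERATION_WITH_DETECTION_ENDPOINT` constant is the v2 route the new integration tests exercise. A minimal sketch of a client call against it, assuming a locally running orchestrator; the base URL, port, and payload field names below are illustrative assumptions, not the repo's actual test harness or request schema:

use serde_json::json;

async fn call_generation_with_detection() -> Result<(), reqwest::Error> {
    const ORCHESTRATOR_GENERATION_WITH_DETECTION_ENDPOINT: &str =
        "/api/v2/text/generation-detection";
    // Hypothetical local orchestrator address.
    let url = format!("http://localhost:8080{ORCHESTRATOR_GENERATION_WITH_DETECTION_ENDPOINT}");
    let response = reqwest::Client::new()
        .post(&url)
        // Illustrative payload shape; the real schema may differ.
        .json(&json!({
            "model_id": "my-model",
            "prompt": "This sentence does not have a detection.",
            "detectors": { "example_detector": {} }
        }))
        .send()
        .await?;
    println!("status: {}", response.status());
    Ok(())
}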
