Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 32 additions & 144 deletions pgdog/src/backend/pool/connection/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,7 @@ mod test {
Decoder,
};
use bytes::Bytes;
use pg_query::{protobuf::SelectStmt, NodeEnum};
use std::collections::VecDeque;

#[test]
Expand Down Expand Up @@ -787,22 +788,30 @@ mod test {
}
}

fn select(stmt: &str) -> SelectStmt {
let stmt = pg_query::parse(stmt)
.unwrap()
.protobuf
.stmts
.remove(0)
.stmt
.unwrap();
match stmt.node.unwrap() {
NodeEnum::SelectStmt(stmt) => *stmt,
_ => panic!("not a select"),
}
}

fn parse(stmt: &str) -> Aggregate {
Aggregate::parse(&select(stmt))
}

#[test]
fn aggregate_count_with_int_typecast() {
// Regression test for https://github.com/pgdogdev/pgdog/issues/861
// SELECT COUNT(*)::int returns int4 from each shard; the accumulator
// must merge the per-shard values and preserve the requested type.
let stmt = pg_query::parse("SELECT COUNT(*)::int FROM users")
.unwrap()
.protobuf
.stmts
.first()
.cloned()
.unwrap();
let aggregate = match stmt.stmt.unwrap().node.unwrap() {
pg_query::NodeEnum::SelectStmt(stmt) => Aggregate::parse(&stmt),
_ => panic!("expected select stmt"),
};
let aggregate = parse("SELECT COUNT(*)::int FROM users");

let rd = RowDescription::new(&[integer_field("count")]);
let decoder = Decoder::from(&rd);
Expand Down Expand Up @@ -830,17 +839,7 @@ mod test {
#[test]
fn aggregate_count_default_bigint() {
// SELECT COUNT(*) (no cast) should still merge correctly and stay bigint.
let stmt = pg_query::parse("SELECT COUNT(*) FROM users")
.unwrap()
.protobuf
.stmts
.first()
.cloned()
.unwrap();
let aggregate = match stmt.stmt.unwrap().node.unwrap() {
pg_query::NodeEnum::SelectStmt(stmt) => Aggregate::parse(&stmt),
_ => panic!("expected select stmt"),
};
let aggregate = parse("SELECT COUNT(*) FROM users");

let rd = RowDescription::new(&[Field::bigint("count")]);
let decoder = Decoder::from(&rd);
Expand All @@ -866,17 +865,7 @@ mod test {

#[test]
fn aggregate_merges_avg_with_count() {
let stmt = pg_query::parse("SELECT COUNT(price), AVG(price) FROM menu")
.unwrap()
.protobuf
.stmts
.first()
.cloned()
.unwrap();
let aggregate = match stmt.stmt.unwrap().node.unwrap() {
pg_query::NodeEnum::SelectStmt(stmt) => Aggregate::parse(&stmt),
_ => panic!("expected select stmt"),
};
let aggregate = parse("SELECT COUNT(price), AVG(price) FROM menu");

let rd = RowDescription::new(&[Field::bigint("count"), Field::double("avg")]);
let decoder = Decoder::from(&rd);
Expand Down Expand Up @@ -906,17 +895,7 @@ mod test {

#[test]
fn aggregate_avg_without_count_passthrough() {
let stmt = pg_query::parse("SELECT AVG(price) FROM menu")
.unwrap()
.protobuf
.stmts
.first()
.cloned()
.unwrap();
let aggregate = match stmt.stmt.unwrap().node.unwrap() {
pg_query::NodeEnum::SelectStmt(stmt) => Aggregate::parse(&stmt),
_ => panic!("expected select stmt"),
};
let aggregate = parse("SELECT AVG(price) FROM menu");

let rd = RowDescription::new(&[Field::double("avg")]);
let decoder = Decoder::from(&rd);
Expand All @@ -943,17 +922,7 @@ mod test {

#[test]
fn aggregate_avg_with_rewrite_helper() {
let stmt = pg_query::parse("SELECT AVG(price) FROM menu")
.unwrap()
.protobuf
.stmts
.first()
.cloned()
.unwrap();
let aggregate = match stmt.stmt.unwrap().node.unwrap() {
pg_query::NodeEnum::SelectStmt(stmt) => Aggregate::parse(&stmt),
_ => panic!("expected select stmt"),
};
let aggregate = parse("SELECT AVG(price) FROM menu");

let rd = RowDescription::new(&[
Field::double("avg"),
Expand Down Expand Up @@ -992,17 +961,7 @@ mod test {

#[test]
fn aggregate_multiple_avg_with_helpers() {
let stmt = pg_query::parse("SELECT AVG(price), AVG(discount) FROM menu")
.unwrap()
.protobuf
.stmts
.first()
.cloned()
.unwrap();
let aggregate = match stmt.stmt.unwrap().node.unwrap() {
pg_query::NodeEnum::SelectStmt(stmt) => Aggregate::parse(&stmt),
_ => panic!("expected select stmt"),
};
let aggregate = parse("SELECT AVG(price), AVG(discount) FROM menu");

let rd = RowDescription::new(&[
Field::double("avg_price"),
Expand Down Expand Up @@ -1055,17 +1014,7 @@ mod test {

#[test]
fn aggregate_stddev_samp_with_helpers() {
let stmt = pg_query::parse("SELECT STDDEV(price) FROM menu")
.unwrap()
.protobuf
.stmts
.first()
.cloned()
.unwrap();
let aggregate = match stmt.stmt.unwrap().node.unwrap() {
pg_query::NodeEnum::SelectStmt(stmt) => Aggregate::parse(&stmt),
_ => panic!("expected select stmt"),
};
let aggregate = parse("SELECT STDDEV(price) FROM menu");

let rd = RowDescription::new(&[
Field::double("stddev_price"),
Expand Down Expand Up @@ -1132,17 +1081,7 @@ mod test {

#[test]
fn aggregate_var_pop_with_helpers() {
let stmt = pg_query::parse("SELECT VAR_POP(price) FROM menu")
.unwrap()
.protobuf
.stmts
.first()
.cloned()
.unwrap();
let aggregate = match stmt.stmt.unwrap().node.unwrap() {
pg_query::NodeEnum::SelectStmt(stmt) => Aggregate::parse(&stmt),
_ => panic!("expected select stmt"),
};
let aggregate = parse("SELECT VAR_POP(price) FROM menu");

let rd = RowDescription::new(&[
Field::double("var_price"),
Expand Down Expand Up @@ -1201,17 +1140,7 @@ mod test {

#[test]
fn aggregate_distinct_count_not_paired() {
let stmt = pg_query::parse("SELECT COUNT(DISTINCT price), AVG(price) FROM menu")
.unwrap()
.protobuf
.stmts
.first()
.cloned()
.unwrap();
let aggregate = match stmt.stmt.unwrap().node.unwrap() {
pg_query::NodeEnum::SelectStmt(stmt) => Aggregate::parse(&stmt),
_ => panic!("expected select stmt"),
};
let aggregate = parse("SELECT COUNT(DISTINCT price), AVG(price) FROM menu");

let rd = RowDescription::new(&[Field::bigint("count"), Field::double("avg")]);
let decoder = Decoder::from(&rd);
Expand Down Expand Up @@ -1239,17 +1168,7 @@ mod test {

#[test]
fn aggregate_errors_when_helper_alias_missing() {
let stmt = pg_query::parse("SELECT AVG(price) FROM menu")
.unwrap()
.protobuf
.stmts
.first()
.cloned()
.unwrap();
let aggregate = match stmt.stmt.unwrap().node.unwrap() {
pg_query::NodeEnum::SelectStmt(stmt) => Aggregate::parse(&stmt),
_ => panic!("expected select stmt"),
};
let aggregate = parse("SELECT AVG(price) FROM menu");

let rd = RowDescription::new(&[Field::double("avg")]);
let decoder = Decoder::from(&rd);
Expand Down Expand Up @@ -1282,17 +1201,7 @@ mod test {

#[test]
fn aggregate_group_by_merges_rows() {
let stmt = pg_query::parse("SELECT price, SUM(quantity) FROM menu GROUP BY 1")
.unwrap()
.protobuf
.stmts
.first()
.cloned()
.unwrap();
let aggregate = match stmt.stmt.unwrap().node.unwrap() {
pg_query::NodeEnum::SelectStmt(stmt) => Aggregate::parse(&stmt),
_ => panic!("expected select stmt"),
};
let aggregate = parse("SELECT price, SUM(quantity) FROM menu GROUP BY 1");

let rd = RowDescription::new(&[Field::double("price"), Field::bigint("sum")]);
let decoder = Decoder::from(&rd);
Expand Down Expand Up @@ -1333,17 +1242,7 @@ mod test {

#[test]
fn aggregate_group_by_multidimensional_arrays_uses_raw_bytes() {
let stmt = pg_query::parse("SELECT matrix, COUNT(*) FROM samples GROUP BY 1")
.unwrap()
.protobuf
.stmts
.first()
.cloned()
.unwrap();
let aggregate = match stmt.stmt.unwrap().node.unwrap() {
pg_query::NodeEnum::SelectStmt(stmt) => Aggregate::parse(&stmt),
_ => panic!("expected select stmt"),
};
let aggregate = parse("SELECT matrix, COUNT(*) FROM samples GROUP BY 1");

let rd = RowDescription::new(&[integer_array_field("matrix"), Field::bigint("count")]);
let decoder = Decoder::from(&rd);
Expand Down Expand Up @@ -1389,18 +1288,7 @@ mod test {

#[test]
fn aggregate_group_by_interval_arrays_preserves_postgres_text_output() {
let stmt =
pg_query::parse("SELECT sample_interval_array, COUNT(*) FROM samples GROUP BY 1")
.unwrap()
.protobuf
.stmts
.first()
.cloned()
.unwrap();
let aggregate = match stmt.stmt.unwrap().node.unwrap() {
pg_query::NodeEnum::SelectStmt(stmt) => Aggregate::parse(&stmt),
_ => panic!("expected select stmt"),
};
let aggregate = parse("SELECT sample_interval_array, COUNT(*) FROM samples GROUP BY 1");

let rd = RowDescription::new(&[
interval_array_field("sample_interval_array"),
Expand Down
Loading
Loading