Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::frontend::router::parser::aggregate::{Aggregate, AggregateFunction};
use pg_query::protobuf::{
a_const::Val, AConst, FuncCall, Integer, Node, ParseResult, ResTarget, String as PgString,
a_const::Val, AConst, FuncCall, Integer, Node, ResTarget, SelectStmt, String as PgString,
};
use pg_query::NodeEnum;

Expand All @@ -14,23 +14,11 @@ pub struct AggregatesRewrite;

impl AggregatesRewrite {
/// Rewrite a SELECT query in-place, adding helper aggregates when necessary.
pub fn rewrite_select(&self, ast: &mut ParseResult, aggregate: &Aggregate) -> RewriteOutput {
pub fn rewrite_select(&self, ast: &mut SelectStmt, aggregate: &Aggregate) -> RewriteOutput {
self.rewrite_parsed(ast, aggregate)
}

fn rewrite_parsed(&self, parsed: &mut ParseResult, aggregate: &Aggregate) -> RewriteOutput {
let Some(raw_stmt) = parsed.stmts.first() else {
return RewriteOutput::default();
};

let Some(stmt) = raw_stmt.stmt.as_ref() else {
return RewriteOutput::default();
};

let Some(NodeEnum::SelectStmt(select)) = stmt.node.as_ref() else {
return RewriteOutput::default();
};

fn rewrite_parsed(&self, select: &mut SelectStmt, aggregate: &Aggregate) -> RewriteOutput {
let mut plan = AggregateRewritePlan::new();
let mut helper_nodes: Vec<Node> = Vec::new();
let mut planned_aliases: Vec<String> = Vec::new();
Expand Down Expand Up @@ -133,18 +121,6 @@ impl AggregatesRewrite {
return RewriteOutput::default();
}

let Some(raw_stmt) = parsed.stmts.first_mut() else {
return RewriteOutput::default();
};

let Some(stmt) = raw_stmt.stmt.as_mut() else {
return RewriteOutput::default();
};

let Some(NodeEnum::SelectStmt(select)) = stmt.node.as_mut() else {
return RewriteOutput::default();
};

select.target_list.extend(helper_nodes);

RewriteOutput::new(plan)
Expand Down Expand Up @@ -308,39 +284,38 @@ struct HelperSpec {
mod tests {
use super::*;
use crate::frontend::router::parser::aggregate::Aggregate;
use pg_query::protobuf::ParseResult;

fn select(ast: &ParseResult) -> &pg_query::protobuf::SelectStmt {
fn select(ast: &mut ParseResult) -> &mut pg_query::protobuf::SelectStmt {
match ast
.stmts
.first()
.and_then(|stmt| stmt.stmt.as_ref())
.and_then(|stmt| stmt.node.as_ref())
.first_mut()
.and_then(|stmt| stmt.stmt.as_mut())
.and_then(|stmt| stmt.node.as_mut())
{
Some(NodeEnum::SelectStmt(select)) => select,
Some(NodeEnum::SelectStmt(select)) => &mut *select,
_ => panic!("not a select"),
}
}

fn rewrite(sql: &str) -> (ParseResult, RewriteOutput) {
let mut parsed = pg_query::parse(sql).unwrap().protobuf;
let aggregate = {
let stmt = select(&parsed);
Aggregate::parse(stmt)
};
let output = AggregatesRewrite.rewrite_select(&mut parsed, &aggregate);
let stmt = select(&mut parsed);
let aggregate = Aggregate::parse(stmt);
let output = AggregatesRewrite.rewrite_select(stmt, &aggregate);
(parsed, output)
}

#[test]
fn rewrite_engine_noop() {
let (ast, output) = rewrite("SELECT COUNT(price) FROM menu");
let (mut ast, output) = rewrite("SELECT COUNT(price) FROM menu");
assert!(output.plan.is_noop());
assert_eq!(select(&ast).target_list.len(), 1);
assert_eq!(select(&mut ast).target_list.len(), 1);
}

#[test]
fn rewrite_engine_adds_helper() {
let (ast, output) = rewrite("SELECT AVG(price) FROM menu");
let (mut ast, output) = rewrite("SELECT AVG(price) FROM menu");
assert!(!output.plan.is_noop());
assert_eq!(output.plan.drop_columns(), &[1]);
assert_eq!(output.plan.helpers().len(), 1);
Expand All @@ -351,7 +326,7 @@ mod tests {
assert!(!helper.distinct);
assert!(matches!(helper.kind, HelperKind::Count));

let aggregate = Aggregate::parse(select(&ast));
let aggregate = Aggregate::parse(select(&mut ast));
assert_eq!(aggregate.targets().len(), 2);
assert!(aggregate
.targets()
Expand All @@ -362,14 +337,14 @@ mod tests {
#[test]
fn rewrite_engine_skips_when_count_exists() {
let sql = "SELECT COUNT(price), AVG(price) FROM menu";
let (ast, output) = rewrite(sql);
let (mut ast, output) = rewrite(sql);
assert!(output.plan.is_noop());
assert_eq!(select(&ast).target_list.len(), 2);
assert_eq!(select(&mut ast).target_list.len(), 2);
}

#[test]
fn rewrite_engine_handles_mismatched_pair() {
let (ast, output) = rewrite("SELECT COUNT(price::numeric), AVG(price) FROM menu");
let (mut ast, output) = rewrite("SELECT COUNT(price::numeric), AVG(price) FROM menu");
assert_eq!(output.plan.drop_columns(), &[2]);
assert_eq!(output.plan.helpers().len(), 1);
let helper = &output.plan.helpers()[0];
Expand All @@ -378,7 +353,7 @@ mod tests {
assert!(!helper.distinct);
assert!(matches!(helper.kind, HelperKind::Count));

let aggregate = Aggregate::parse(select(&ast));
let aggregate = Aggregate::parse(select(&mut ast));
assert_eq!(aggregate.targets().len(), 3);
assert!(
aggregate
Expand All @@ -392,7 +367,7 @@ mod tests {

#[test]
fn rewrite_engine_multiple_avg_helpers() {
let (ast, output) = rewrite("SELECT AVG(price), AVG(discount) FROM menu");
let (mut ast, output) = rewrite("SELECT AVG(price), AVG(discount) FROM menu");
assert_eq!(output.plan.drop_columns(), &[2, 3]);
assert_eq!(output.plan.helpers().len(), 2);

Expand All @@ -406,7 +381,7 @@ mod tests {
assert_eq!(helper_discount.helper_column, 3);
assert!(matches!(helper_discount.kind, HelperKind::Count));

let aggregate = Aggregate::parse(select(&ast));
let aggregate = Aggregate::parse(select(&mut ast));
assert_eq!(aggregate.targets().len(), 4);
assert_eq!(
aggregate
Expand All @@ -420,7 +395,7 @@ mod tests {

#[test]
fn rewrite_engine_stddev_helpers() {
let (ast, output) = rewrite("SELECT STDDEV(price) FROM menu");
let (mut ast, output) = rewrite("SELECT STDDEV(price) FROM menu");
assert!(!output.plan.is_noop());
assert_eq!(output.plan.drop_columns(), &[1, 2, 3]);
assert_eq!(output.plan.helpers().len(), 3);
Expand All @@ -440,6 +415,6 @@ mod tests {
assert!(kinds.contains(&HelperKind::SumSquares));

// Expect original STDDEV plus three helpers.
assert_eq!(select(&ast).target_list.len(), 4);
assert_eq!(select(&mut ast).target_list.len(), 4);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,24 @@ impl StatementRewrite<'_> {
return Ok(());
}

let Some(raw_stmt) = self.stmt.stmts.first() else {
let Some(raw_stmt) = self.stmt.stmts.first_mut() else {
return Ok(());
};

let Some(stmt) = raw_stmt.stmt.as_ref() else {
let Some(stmt) = raw_stmt.stmt.as_mut() else {
return Ok(());
};

let Some(NodeEnum::SelectStmt(select)) = stmt.node.as_ref() else {
let Some(NodeEnum::SelectStmt(select)) = stmt.node.as_mut() else {
return Ok(());
};

let aggregate = Aggregate::parse(select);
let aggregate = Aggregate::parse(&select);
if aggregate.is_empty() {
return Ok(());
}

let output = AggregatesRewrite.rewrite_select(self.stmt, &aggregate);
let output = AggregatesRewrite.rewrite_select(select, &aggregate);
if output.plan.is_noop() {
return Ok(());
}
Expand Down
Loading