Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion crates/toasty-core/src/schema/builder/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ impl BuildTableFromModels<'_> {

// Hax
for column in &mut self.table.columns {
if let stmt::Type::Enum(_) = column.ty {
if column.primary_key && matches!(column.ty, stmt::Type::Enum(_)) {
// this is a hack to support the internal representation of primary keys
// having this here blocks us from using enums as primary keys
column.ty = stmt::Type::String;
}
}
Expand Down Expand Up @@ -640,6 +642,7 @@ fn stmt_ty_to_table(ty: stmt::Type) -> stmt::Type {
stmt::Type::U64 => stmt::Type::U64,
stmt::Type::String => stmt::Type::String,
stmt::Type::Id(_) => stmt::Type::String,
stmt::Type::Enum(t) => stmt::Type::Enum(t),
_ => todo!("{ty:#?}"),
}
}
1 change: 1 addition & 0 deletions crates/toasty-core/src/schema/db/ty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ impl Type {
// TODO: not really correct, but we are getting rid of ID types
// most likely.
stmt::Type::Id(_) => Ok(db.default_string_type.clone()),
stmt::Type::Enum(_) => Ok(db.default_string_type.clone()),
_ => anyhow::bail!("unsupported type: {ty:?}"),
},
}
Expand Down
20 changes: 19 additions & 1 deletion crates/toasty-core/src/stmt/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,25 @@ impl Value {
_ => false,
},
Self::String(_) => ty.is_string(),
_ => todo!("value={self:#?}, ty={ty:#?}"),
Self::Enum(value) => match ty {
Type::Enum(type_enum) => {
if let Some(variant) = type_enum
.variants
.iter()
.find(|v| v.discriminant == value.variant)
{
variant.fields.len() == value.fields.len()
&& value
.fields
.iter()
.zip(variant.fields.iter())
.all(|(f, fty)| f.is_a(fty))
} else {
false
}
}
_ => false,
},
}
}

Expand Down
22 changes: 21 additions & 1 deletion crates/toasty-driver-sqlite/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,23 @@ fn value_from_param(value: &stmt::Value) -> rusqlite::types::ToSqlOutput<'_> {
}
}

fn sqlite_enum_to_value(value: &str, _ty: &stmt::Type) -> Result<stmt::Value> {
let Some((variant_str, fields_json)) = value.split_once("#") else {
todo!("value={value:#?}");
};

let variant = variant_str.parse::<usize>()?;

if !fields_json.eq("\"Null\"") {
todo!("value={value:#?}");
}

Ok(stmt::Value::Enum(stmt::ValueEnum {
variant,
fields: stmt::ValueRecord::default(),
}))
}

fn sqlite_to_toasty(row: &rusqlite::Row, index: usize, ty: &stmt::Type) -> stmt::Value {
use rusqlite::types::Value as SqlValue;

Expand All @@ -263,7 +280,10 @@ fn sqlite_to_toasty(row: &rusqlite::Row, index: usize, ty: &stmt::Type) -> stmt:
stmt::Type::U64 => stmt::Value::U64(value as u64),
_ => todo!("ty={ty:#?}"),
},
Some(SqlValue::Text(value)) => stmt::Value::String(value),
Some(SqlValue::Text(value)) => match ty {
stmt::Type::Enum(_) => sqlite_enum_to_value(&value, ty).unwrap(),
_ => stmt::Value::String(value),
},
None => stmt::Value::Null,
_ => todo!("value={value:#?}"),
}
Expand Down
1 change: 1 addition & 0 deletions crates/toasty-macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ toasty-codegen.workspace = true
proc-macro2.workspace = true
quote.workspace = true
syn.workspace = true
anyhow.workspace = true
67 changes: 67 additions & 0 deletions crates/toasty-macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
extern crate proc_macro;

use proc_macro::TokenStream;
use proc_macro2::Ident;
use quote::quote;
use syn::{Fields, Variant};

#[proc_macro_derive(
Model,
Expand All @@ -14,6 +16,71 @@ pub fn derive_model(input: TokenStream) -> TokenStream {
}
}

fn compute_type_variant(enum_ident: &Ident, variant: &Variant) -> proc_macro2::TokenStream {
let ident = &variant.ident;
let fields = match variant.fields {
Fields::Unit => quote!(Vec::new()),
_ => todo!("fields unsupported vor enum #{ident}"),
};
quote!(EnumVariant { discriminant: #enum_ident::#ident as usize, fields: #fields })
}

fn discriminant_match(enum_ident: &Ident, variant: &Variant) -> proc_macro2::TokenStream {
let ident = &variant.ident;
quote!(
if v == #enum_ident::#ident as usize {
return Ok(#enum_ident::#ident);
}
)
}

#[proc_macro_derive(ToastyEnum)]
pub fn derive_enum(input: TokenStream) -> TokenStream {
let item: syn::ItemEnum = syn::parse(input).unwrap();
let ident = &item.ident;
let compute_type_variants = item.variants.iter().map(|v| compute_type_variant(ident, v));
let discriminant_matches = item.variants.iter().map(|v| discriminant_match(ident, v));
quote! {
const _: () = {
use anyhow::Result;
use toasty::{self, stmt::{Primitive, IntoExpr, Expr}};
use toasty_core::{self, stmt::{self, EnumVariant, TypeEnum, ValueRecord}};
impl Primitive for #ident {
fn ty() -> stmt::Type {
stmt::Type::Enum(TypeEnum {
variants: vec![#(#compute_type_variants,)*]
})
}

fn load(value: stmt::Value) -> Result<Self> {
let stmt::Value::Enum(value_enum) = value else {
anyhow::bail!("not an enum: #{value:#?}");
};

let v = value_enum.variant;
#(#discriminant_matches)*
anyhow::bail!("not matching any discriminant: #{v}");
}
}

impl IntoExpr<#ident> for #ident {
fn into_expr(self) -> Expr<#ident> {
let variant = self as usize;
Expr::from_untyped(stmt::Expr::Value(stmt::Value::Enum(stmt::ValueEnum {
variant,
fields: ValueRecord { fields: Vec::new() }
})))
}

fn by_ref(&self) -> Expr<#ident> {
todo!()
}
}
};
}
.into()
}

#[proc_macro]
pub fn include_schema(_input: TokenStream) -> TokenStream {
todo!()
Expand Down
9 changes: 2 additions & 7 deletions crates/toasty/src/engine/planner/kv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,14 @@ impl Planner<'_> {
index_plan: &mut IndexPlan<'_>,
input: Option<plan::Input>,
) -> plan::Input {
let key_ty = self.index_key_ty(index_plan.index);
let key_ty = self.pk_ty_for_index(index_plan.index);
let pk_by_index_out = self
.var_table
.register_var(stmt::Type::list(key_ty.clone()));

// In this case, we have to flatten the returned record into a single value
let project_key = if index_plan.index.columns.len() == 1 {
let arg_ty = stmt::Type::Record(vec![self
.schema
.db
.column(index_plan.index.columns[0].column)
.ty
.clone()]);
let arg_ty = stmt::Type::Record(vec![key_ty.clone()]);

eval::Func::from_stmt_unchecked(
stmt::Expr::arg_project(0, [0]),
Expand Down
14 changes: 14 additions & 0 deletions crates/toasty/src/engine/planner/ty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,20 @@ impl Planner<'_> {
}
*/

pub(crate) fn pk_ty_for_index(&self, index: &Index) -> stmt::Type {
let table = self.schema.db.table(index.id.table);
if table.primary_key.columns.len() == 1 {
table.primary_key_column(0).ty.clone()
} else {
stmt::Type::Record(
table
.primary_key_columns()
.map(|id| id.ty.clone())
.collect(),
)
}
}

pub(crate) fn index_key_ty(&self, index: &Index) -> stmt::Type {
match &index.columns[..] {
[id] => self.schema.db.column(id.column).ty.clone(),
Expand Down
2 changes: 1 addition & 1 deletion crates/toasty/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub mod schema;
pub mod stmt;
pub use stmt::Statement;

pub use toasty_macros::{create, query, Model};
pub use toasty_macros::{create, query, Model, ToastyEnum};

pub use anyhow::{Error, Result};

Expand Down
1 change: 1 addition & 0 deletions tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ rusqlite.workspace = true
std-util.workspace = true
tempfile.workspace = true
trybuild.workspace = true
anyhow.workspace = true
env_logger = "0.11.8"

[dev-dependencies]
54 changes: 54 additions & 0 deletions tests/tests/index_enum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
use tests::{models, tests, DbTest};
use toasty::{self, stmt::Id, ToastyEnum};

async fn index_enum(test: &mut DbTest) {
#[derive(toasty::Model)]
struct LogEntry {
#[key]
#[auto]
request_id: Id<Self>,

#[index]
log_level: LogLevel,

message: String,
}

#[derive(Debug, PartialEq, ToastyEnum, Clone)]
enum LogLevel {
Debug,
Info,
Warn,
Error,
}

let db = test.setup_db(models!(LogEntry)).await;

{
use LogLevel::{Debug, Error, Warn};
for (log_level, message) in [
(Debug, "initializing"),
(Warn, "something fishy"),
(Error, "null pointer"),
] {
LogEntry::create()
.log_level(log_level)
.message(message)
.exec(&db)
.await
.expect("failed to create entry");
}

let res = LogEntry::filter_by_log_level(Warn).all(&db).await.unwrap();
let entries = res.collect::<Vec<LogEntry>>().await.unwrap();
assert_eq!(
vec![(Warn, "something fishy".to_string())],
entries
.iter()
.map(|le| (le.log_level.clone(), le.message.clone()))
.collect::<Vec<(LogLevel, String)>>()
);
}
}

tests!(index_enum);