From 1261db5f2293d5af93b5f3c506af37710a7740ba Mon Sep 17 00:00:00 2001 From: Ori Ziv Date: Sat, 13 Jun 2026 18:18:33 +0300 Subject: [PATCH] refactor(sqrt): Make sqrt libfuncs runtime-bindings based. Co-Authored-By: Claude Opus 4.8 (1M context) --- src/libfuncs/int.rs | 155 ++--------------- src/libfuncs/uint256.rs | 274 ++----------------------------- src/metadata/runtime_bindings.rs | 221 ++++++++++++++++--------- src/runtime.rs | 37 +++++ tests/tests/uint.rs | 31 ++++ 5 files changed, 235 insertions(+), 483 deletions(-) diff --git a/src/libfuncs/int.rs b/src/libfuncs/int.rs index c8f1e5e478..502c3e6d42 100644 --- a/src/libfuncs/int.rs +++ b/src/libfuncs/int.rs @@ -3,15 +3,14 @@ use crate::{ error::{panic::ToNativeAssertError, Result}, execution_result::BITWISE_BUILTIN_SIZE, libfuncs::{increment_builtin_counter, increment_builtin_counter_by}, - metadata::MetadataStorage, - native_panic, + metadata::{runtime_bindings::RuntimeBindingsMeta, MetadataStorage}, types::TypeBuilder, utils::{ProgramRegistryExt, PRIME}, }; use cairo_lang_sierra::{ extensions::{ bounded_int::BoundedIntDivRemAlgorithm, - core::{CoreLibfunc, CoreType, CoreTypeConcrete}, + core::{CoreLibfunc, CoreType}, int::{ signed::{SintConcrete, SintTraits}, signed128::Sint128Concrete, @@ -35,8 +34,8 @@ use melior::{ }, helpers::{ArithBlockExt, BuiltinBlockExt, LlvmBlockExt}, ir::{ - attribute::IntegerAttribute, operation::OperationBuilder, r#type::IntegerType, Block, - BlockLike, Location, Region, ValueLike, + operation::OperationBuilder, r#type::IntegerType, Block, BlockLike, Location, Region, + ValueLike, }, Context, }; @@ -689,12 +688,12 @@ fn build_operation<'ctx, 'this>( fn build_square_root<'ctx, 'this>( context: &'ctx Context, - registry: &ProgramRegistry, + _registry: &ProgramRegistry, entry: &'this Block<'ctx>, location: Location<'ctx>, helper: &LibfuncHelper<'ctx, 'this>, - _metadata: &mut MetadataStorage, - info: &SignatureOnlyConcreteLibfunc, + metadata: &mut MetadataStorage, + _info: &SignatureOnlyConcreteLibfunc, ) -> Result<()> { // The sierra-to-casm compiler uses the range_check builtin 4 times. // https://github.com/starkware-libs/cairo/blob/v2.12.0-dev.1/crates/cairo-lang-sierra-to-casm/src/invocations/int/unsigned.rs#L73 @@ -702,144 +701,10 @@ fn build_square_root<'ctx, 'this>( super::increment_builtin_counter_by(context, entry, location, entry.arg(0)?, 4)?; let input = entry.arg(1)?; - let (input_bits, value_bits) = - match registry.get_type(&info.signature.param_signatures[1].ty)? { - CoreTypeConcrete::Uint8(_) => (8, 8), - CoreTypeConcrete::Uint16(_) => (16, 8), - CoreTypeConcrete::Uint32(_) => (32, 16), - CoreTypeConcrete::Uint64(_) => (64, 32), - CoreTypeConcrete::Uint128(_) => (128, 64), - _ => native_panic!("invalid value type in int square root"), - }; - - let k1 = entry.const_int(context, location, 1, input_bits)?; - let is_small = entry.cmpi(context, CmpiPredicate::Ule, input, k1, location)?; - - let value = entry.append_op_result(scf::r#if( - is_small, - &[IntegerType::new(context, value_bits).into()], - { - let region = Region::new(); - let block = region.append_block(Block::new(&[])); - - let value = block.trunci( - input, - IntegerType::new(context, value_bits).into(), - location, - )?; - - block.append_operation(scf::r#yield(&[value], location)); - region - }, - { - let region = Region::new(); - let block = region.append_block(Block::new(&[])); - let leading_zeros = block.append_op_result( - ods::llvm::intr_ctlz( - context, - IntegerType::new(context, input_bits).into(), - input, - IntegerAttribute::new(IntegerType::new(context, 1).into(), 1), - location, - ) - .into(), - )?; - - let k_bits = block.const_int(context, location, input_bits, input_bits)?; - let num_bits = block.append_op_result(arith::subi(k_bits, leading_zeros, location))?; - let shift_amount = block.addi(num_bits, k1, location)?; - - let parity_mask = block.const_int(context, location, -2, input_bits)?; - let shift_amount = - block.append_op_result(arith::andi(shift_amount, parity_mask, location))?; - - let k0 = block.const_int(context, location, 0, input_bits)?; - let value = block.append_op_result(scf::r#while( - &[k0, shift_amount], - &[ - IntegerType::new(context, input_bits).into(), - IntegerType::new(context, input_bits).into(), - ], - { - let region = Region::new(); - let block = region.append_block(Block::new(&[ - (IntegerType::new(context, input_bits).into(), location), - (IntegerType::new(context, input_bits).into(), location), - ])); - - let value = block.shli(block.arg(0)?, k1, location)?; - let large_candidate = - block.append_op_result(arith::xori(value, k1, location))?; - let large_candidate_squared = - block.muli(large_candidate, large_candidate, location)?; - - let threshold = block.shrui(input, block.arg(1)?, location)?; - let threshold_is_poison = - block.cmpi(context, CmpiPredicate::Eq, block.arg(1)?, k_bits, location)?; - let threshold = block.append_op_result(arith::select( - threshold_is_poison, - k0, - threshold, - location, - ))?; - - let is_in_range = block.cmpi( - context, - CmpiPredicate::Ule, - large_candidate_squared, - threshold, - location, - )?; - let value = block.append_op_result(arith::select( - is_in_range, - large_candidate, - value, - location, - ))?; - - let k2 = block.const_int(context, location, 2, input_bits)?; - let shift_amount = - block.append_op_result(arith::subi(block.arg(1)?, k2, location))?; - - let should_continue = - block.cmpi(context, CmpiPredicate::Sge, shift_amount, k0, location)?; - block.append_operation(scf::condition( - should_continue, - &[value, shift_amount], - location, - )); - - region - }, - { - let region = Region::new(); - let block = region.append_block(Block::new(&[ - (IntegerType::new(context, input_bits).into(), location), - (IntegerType::new(context, input_bits).into(), location), - ])); - - block.append_operation(scf::r#yield(&[block.arg(0)?, block.arg(1)?], location)); - region - }, - location, - ))?; - - let value = if input_bits == value_bits { - value - } else { - block.trunci( - value, - IntegerType::new(context, value_bits).into(), - location, - )? - }; - - block.append_operation(scf::r#yield(&[value], location)); - region - }, - location, - ))?; + let runtime_bindings = metadata.get_or_insert_with(RuntimeBindingsMeta::default); + let value = + runtime_bindings.integer_square_root(context, helper.module, entry, location, input)?; helper.br(entry, 0, &[range_check, value], location) } diff --git a/src/libfuncs/uint256.rs b/src/libfuncs/uint256.rs index cddf6e9e9c..0e526cdbe3 100644 --- a/src/libfuncs/uint256.rs +++ b/src/libfuncs/uint256.rs @@ -21,14 +21,13 @@ use cairo_lang_sierra::{ use melior::{ dialect::{ arith::{self, CmpiPredicate}, - cf, llvm, ods, scf, + cf, llvm, }, helpers::{BuiltinBlockExt, LlvmBlockExt}, ir::{ attribute::{DenseI64ArrayAttribute, IntegerAttribute}, - operation::OperationBuilder, r#type::IntegerType, - Block, BlockLike, Location, Region, Value, + Block, BlockLike, Location, Value, }, Context, }; @@ -358,7 +357,7 @@ pub fn build_square_root<'ctx, 'this>( entry: &'this Block<'ctx>, location: Location<'ctx>, helper: &LibfuncHelper<'ctx, 'this>, - _metadata: &mut MetadataStorage, + metadata: &mut MetadataStorage, _info: &SignatureOnlyConcreteLibfunc, ) -> Result<()> { // The sierra-to-casm compiler uses the range check builtin a total of 7 times. @@ -367,7 +366,6 @@ pub fn build_square_root<'ctx, 'this>( super::increment_builtin_counter_by(context, entry, location, entry.arg(0)?, 7)?; let i128_ty = IntegerType::new(context, 128).into(); - let i256_ty = IntegerType::new(context, 256).into(); let arg_struct = entry.arg(1)?; let arg_lo = entry @@ -391,261 +389,17 @@ pub fn build_square_root<'ctx, 'this>( .result(0)? .into(); - let arg_lo = entry - .append_operation(arith::extui(arg_lo, i256_ty, location)) - .result(0)? - .into(); - let arg_hi = entry - .append_operation(arith::extui(arg_hi, i256_ty, location)) - .result(0)? - .into(); - - let k128 = entry - .append_operation(arith::constant( - context, - IntegerAttribute::new(i256_ty, 128).into(), - location, - )) - .result(0)? - .into(); - let arg_hi = entry - .append_operation(arith::shli(arg_hi, k128, location)) - .result(0)? - .into(); - - let arg_value = entry - .append_operation(arith::ori(arg_hi, arg_lo, location)) - .result(0)? - .into(); - - let k1 = entry - .append_operation(arith::constant( - context, - IntegerAttribute::new(i256_ty, 1).into(), - location, - )) - .result(0)? - .into(); - - let is_small = entry - .append_operation(arith::cmpi( - context, - CmpiPredicate::Ule, - arg_value, - k1, - location, - )) - .result(0)? - .into(); - - let result = entry - .append_operation(scf::r#if( - is_small, - &[i256_ty], - { - let region = Region::new(); - let block = region.append_block(Block::new(&[])); - - block.append_operation(scf::r#yield(&[arg_value], location)); - - region - }, - { - let region = Region::new(); - let block = region.append_block(Block::new(&[])); - - let k128 = entry - .append_operation(arith::constant( - context, - IntegerAttribute::new(i256_ty, 256).into(), - location, - )) - .result(0)? - .into(); - - let leading_zeros = block - .append_operation( - ods::llvm::intr_ctlz( - context, - i256_ty, - arg_value, - IntegerAttribute::new(IntegerType::new(context, 1).into(), 1), - location, - ) - .into(), - ) - .result(0)? - .into(); - - let num_bits = block - .append_operation(arith::subi(k128, leading_zeros, location)) - .result(0)? - .into(); - - let shift_amount = block - .append_operation(arith::addi(num_bits, k1, location)) - .result(0)? - .into(); - - let parity_mask = block - .append_operation(arith::constant( - context, - IntegerAttribute::new(i256_ty, -2).into(), - location, - )) - .result(0)? - .into(); - let shift_amount = block - .append_operation(arith::andi(shift_amount, parity_mask, location)) - .result(0)? - .into(); - - let k0 = block - .append_operation(arith::constant( - context, - IntegerAttribute::new(i256_ty, 0).into(), - location, - )) - .result(0)? - .into(); - let result = block - .append_operation(scf::r#while( - &[k0, shift_amount], - &[i256_ty, i256_ty], - { - let region = Region::new(); - let block = region.append_block(Block::new(&[ - (i256_ty, location), - (i256_ty, location), - ])); - - let result = block - .append_operation(arith::shli(block.arg(0)?, k1, location)) - .result(0)? - .into(); - let large_candidate = block - .append_operation(arith::xori(result, k1, location)) - .result(0)? - .into(); - - let large_candidate_squared = block - .append_operation(arith::muli( - large_candidate, - large_candidate, - location, - )) - .result(0)? - .into(); - - let threshold = block - .append_operation(arith::shrui(arg_value, block.arg(1)?, location)) - .result(0)? - .into(); - let threshold_is_poison = block - .append_operation(arith::cmpi( - context, - CmpiPredicate::Eq, - block.arg(1)?, - k128, - location, - )) - .result(0)? - .into(); - let threshold = block - .append_operation( - OperationBuilder::new("arith.select", location) - .add_operands(&[threshold_is_poison, k0, threshold]) - .add_results(&[i256_ty]) - .build()?, - ) - .result(0)? - .into(); - - let is_in_range = block - .append_operation(arith::cmpi( - context, - CmpiPredicate::Ule, - large_candidate_squared, - threshold, - location, - )) - .result(0)? - .into(); - - let result = block - .append_operation( - OperationBuilder::new("arith.select", location) - .add_operands(&[is_in_range, large_candidate, result]) - .add_results(&[i256_ty]) - .build()?, - ) - .result(0)? - .into(); - - let k2 = block - .append_operation(arith::constant( - context, - IntegerAttribute::new(i256_ty, 2).into(), - location, - )) - .result(0)? - .into(); - - let shift_amount = block - .append_operation(arith::subi(block.arg(1)?, k2, location)) - .result(0)? - .into(); - - let should_continue = block - .append_operation(arith::cmpi( - context, - CmpiPredicate::Sge, - shift_amount, - k0, - location, - )) - .result(0)? - .into(); - block.append_operation(scf::condition( - should_continue, - &[result, shift_amount], - location, - )); - - region - }, - { - let region = Region::new(); - let block = region.append_block(Block::new(&[ - (i256_ty, location), - (i256_ty, location), - ])); - - block.append_operation(scf::r#yield( - &[block.arg(0)?, block.argument(1)?.into()], - location, - )); - - region - }, - location, - )) - .result(0)? - .into(); - - block.append_operation(scf::r#yield(&[result], location)); - - region - }, - location, - )) - .result(0)? - .into(); - - let result = entry - .append_operation(arith::trunci(result, i128_ty, location)) - .result(0)? - .into(); + // The square root is computed by a Rust runtime function taking the low and + // high 128-bit limbs; it returns the result already fitting in 128 bits. + let runtime_bindings = metadata.get_or_insert_with(RuntimeBindingsMeta::default); + let result = runtime_bindings.u256_square_root( + context, + helper.module, + entry, + location, + arg_lo, + arg_hi, + )?; helper.br(entry, 0, &[range_check, result], location) } diff --git a/src/metadata/runtime_bindings.rs b/src/metadata/runtime_bindings.rs index 7a4fdac556..b9af83f230 100644 --- a/src/metadata/runtime_bindings.rs +++ b/src/metadata/runtime_bindings.rs @@ -21,7 +21,7 @@ use melior::{ operation::OperationBuilder, r#type::IntegerType, Attribute, Block, BlockLike, Identifier, Location, Module, OperationRef, Region, Type, - Value, + Value, ValueLike, }, Context, }; @@ -47,6 +47,12 @@ enum RuntimeBinding { BlakeCompress, DebugPrint, ExtendedEuclideanAlgorithm(ExtendedEuclideanWidth), + U8SquareRoot, + U16SquareRoot, + U32SquareRoot, + U64SquareRoot, + U128SquareRoot, + U256SquareRoot, CircuitArithOperation, DictIntoEntries, QM31Add, @@ -61,31 +67,35 @@ enum RuntimeBinding { impl RuntimeBinding { const fn symbol(self) -> &'static str { match self { - RuntimeBinding::DebugPrint => "cairo_native__libfunc__debug__print", - RuntimeBinding::Pedersen => "cairo_native__libfunc__pedersen", - RuntimeBinding::HadesPermutation => "cairo_native__libfunc__hades_permutation", - RuntimeBinding::EcStateTryFinalizeNz => { - "cairo_native__libfunc__ec__ec_state_try_finalize_nz" - } - RuntimeBinding::EcStateAddMul => "cairo_native__libfunc__ec__ec_state_add_mul", - RuntimeBinding::EcStateAdd => "cairo_native__libfunc__ec__ec_state_add", - RuntimeBinding::EcPointTryNewNz => "cairo_native__libfunc__ec__ec_point_try_new_nz", - RuntimeBinding::EcPointFromXNz => "cairo_native__libfunc__ec__ec_point_from_x_nz", - RuntimeBinding::DictNew => "cairo_native__dict_new", - RuntimeBinding::DictGet => "cairo_native__dict_get", - RuntimeBinding::DictSquash => "cairo_native__dict_squash", - RuntimeBinding::GetCostsBuiltin => "cairo_native__get_costs_builtin", - RuntimeBinding::BlakeCompress => "cairo_native__libfunc__blake_compress", - RuntimeBinding::ExtendedEuclideanAlgorithm(width) => width.symbol(), - RuntimeBinding::CircuitArithOperation => "cairo_native__circuit_arith_operation", - RuntimeBinding::DictIntoEntries => "cairo_native__dict_into_entries", - RuntimeBinding::QM31Add => "cairo_native__libfunc__qm31__qm31_add", - RuntimeBinding::QM31Sub => "cairo_native__libfunc__qm31__qm31_sub", - RuntimeBinding::QM31Mul => "cairo_native__libfunc__qm31__qm31_mul", - RuntimeBinding::QM31Div => "cairo_native__libfunc__qm31__qm31_div", - RuntimeBinding::ArenaAlloc => "cairo_native__arena_alloc", + Self::DebugPrint => "cairo_native__libfunc__debug__print", + Self::Pedersen => "cairo_native__libfunc__pedersen", + Self::HadesPermutation => "cairo_native__libfunc__hades_permutation", + Self::EcStateTryFinalizeNz => "cairo_native__libfunc__ec__ec_state_try_finalize_nz", + Self::EcStateAddMul => "cairo_native__libfunc__ec__ec_state_add_mul", + Self::EcStateAdd => "cairo_native__libfunc__ec__ec_state_add", + Self::EcPointTryNewNz => "cairo_native__libfunc__ec__ec_point_try_new_nz", + Self::EcPointFromXNz => "cairo_native__libfunc__ec__ec_point_from_x_nz", + Self::DictNew => "cairo_native__dict_new", + Self::DictGet => "cairo_native__dict_get", + Self::DictSquash => "cairo_native__dict_squash", + Self::GetCostsBuiltin => "cairo_native__get_costs_builtin", + Self::BlakeCompress => "cairo_native__libfunc__blake_compress", + Self::ExtendedEuclideanAlgorithm(width) => width.symbol(), + Self::U8SquareRoot => "cairo_native__u8_square_root", + Self::U16SquareRoot => "cairo_native__u16_square_root", + Self::U32SquareRoot => "cairo_native__u32_square_root", + Self::U64SquareRoot => "cairo_native__u64_square_root", + Self::U128SquareRoot => "cairo_native__u128_square_root", + Self::U256SquareRoot => "cairo_native__u256_square_root", + Self::CircuitArithOperation => "cairo_native__circuit_arith_operation", + Self::DictIntoEntries => "cairo_native__dict_into_entries", + Self::QM31Add => "cairo_native__libfunc__qm31__qm31_add", + Self::QM31Sub => "cairo_native__libfunc__qm31__qm31_sub", + Self::QM31Mul => "cairo_native__libfunc__qm31__qm31_mul", + Self::QM31Div => "cairo_native__libfunc__qm31__qm31_div", + Self::ArenaAlloc => "cairo_native__arena_alloc", #[cfg(feature = "with-cheatcode")] - RuntimeBinding::VtableCheatcode => "cairo_native__vtable_cheatcode", + Self::VtableCheatcode => "cairo_native__vtable_cheatcode", } } @@ -96,62 +106,38 @@ impl RuntimeBinding { /// - For internal bindings (implemented in MLIR), it returns `None`, since those /// functions are defined within MLIR and invoked by name const fn function_ptr(self) -> Option<*const ()> { + use crate::runtime::*; let function_ptr = match self { - RuntimeBinding::DebugPrint => { - crate::runtime::cairo_native__libfunc__debug__print as *const () - } - RuntimeBinding::Pedersen => { - crate::runtime::cairo_native__libfunc__pedersen as *const () - } - RuntimeBinding::HadesPermutation => { - crate::runtime::cairo_native__libfunc__hades_permutation as *const () - } - RuntimeBinding::EcStateTryFinalizeNz => { - crate::runtime::cairo_native__libfunc__ec__ec_state_try_finalize_nz as *const () - } - RuntimeBinding::EcStateAddMul => { - crate::runtime::cairo_native__libfunc__ec__ec_state_add_mul as *const () - } - RuntimeBinding::EcStateAdd => { - crate::runtime::cairo_native__libfunc__ec__ec_state_add as *const () - } - RuntimeBinding::EcPointTryNewNz => { - crate::runtime::cairo_native__libfunc__ec__ec_point_try_new_nz as *const () - } - RuntimeBinding::EcPointFromXNz => { - crate::runtime::cairo_native__libfunc__ec__ec_point_from_x_nz as *const () - } - RuntimeBinding::DictNew => crate::runtime::cairo_native__dict_new as *const (), - RuntimeBinding::DictGet => crate::runtime::cairo_native__dict_get as *const (), - RuntimeBinding::DictSquash => crate::runtime::cairo_native__dict_squash as *const (), - RuntimeBinding::GetCostsBuiltin => { - crate::runtime::cairo_native__get_costs_builtin as *const () - } - RuntimeBinding::DictIntoEntries => { - crate::runtime::cairo_native__dict_into_entries as *const () - } - RuntimeBinding::QM31Add => { - crate::runtime::cairo_native__libfunc__qm31__qm31_add as *const () - } - RuntimeBinding::QM31Sub => { - crate::runtime::cairo_native__libfunc__qm31__qm31_sub as *const () - } - RuntimeBinding::QM31Mul => { - crate::runtime::cairo_native__libfunc__qm31__qm31_mul as *const () - } - RuntimeBinding::QM31Div => { - crate::runtime::cairo_native__libfunc__qm31__qm31_div as *const () - } - RuntimeBinding::BlakeCompress => { - crate::runtime::cairo_native__libfunc__blake_compress as *const () + Self::DebugPrint => cairo_native__libfunc__debug__print as *const (), + Self::Pedersen => cairo_native__libfunc__pedersen as *const (), + Self::HadesPermutation => cairo_native__libfunc__hades_permutation as *const (), + Self::EcStateTryFinalizeNz => { + cairo_native__libfunc__ec__ec_state_try_finalize_nz as *const () } - RuntimeBinding::ExtendedEuclideanAlgorithm(_) => return None, - RuntimeBinding::CircuitArithOperation => return None, - RuntimeBinding::ArenaAlloc => crate::runtime::cairo_native__arena_alloc as *const (), + Self::EcStateAddMul => cairo_native__libfunc__ec__ec_state_add_mul as *const (), + Self::EcStateAdd => cairo_native__libfunc__ec__ec_state_add as *const (), + Self::EcPointTryNewNz => cairo_native__libfunc__ec__ec_point_try_new_nz as *const (), + Self::EcPointFromXNz => cairo_native__libfunc__ec__ec_point_from_x_nz as *const (), + Self::DictNew => cairo_native__dict_new as *const (), + Self::DictGet => cairo_native__dict_get as *const (), + Self::DictSquash => cairo_native__dict_squash as *const (), + Self::GetCostsBuiltin => cairo_native__get_costs_builtin as *const (), + Self::DictIntoEntries => cairo_native__dict_into_entries as *const (), + Self::QM31Add => cairo_native__libfunc__qm31__qm31_add as *const (), + Self::QM31Sub => cairo_native__libfunc__qm31__qm31_sub as *const (), + Self::QM31Mul => cairo_native__libfunc__qm31__qm31_mul as *const (), + Self::QM31Div => cairo_native__libfunc__qm31__qm31_div as *const (), + Self::BlakeCompress => cairo_native__libfunc__blake_compress as *const (), + Self::U8SquareRoot => cairo_native__u8_square_root as *const (), + Self::U16SquareRoot => cairo_native__u16_square_root as *const (), + Self::U32SquareRoot => cairo_native__u32_square_root as *const (), + Self::U64SquareRoot => cairo_native__u64_square_root as *const (), + Self::U128SquareRoot => cairo_native__u128_square_root as *const (), + Self::U256SquareRoot => cairo_native__u256_square_root as *const (), + Self::ArenaAlloc => cairo_native__arena_alloc as *const (), + Self::ExtendedEuclideanAlgorithm(_) | Self::CircuitArithOperation => return None, #[cfg(feature = "with-cheatcode")] - RuntimeBinding::VtableCheatcode => { - crate::starknet::cairo_native__vtable_cheatcode as *const () - } + Self::VtableCheatcode => crate::starknet::cairo_native__vtable_cheatcode as *const (), }; Some(function_ptr) } @@ -290,6 +276,79 @@ impl RuntimeBindingsMeta { .into()) } + /// Register if necessary, then invoke the integer square root runtime + /// function matching the width of `value` (`u8`..`u128`). + /// + /// Returns `floor(sqrt(value))` in the (smaller) output type used by the + /// libfunc. + pub fn integer_square_root<'c, 'a>( + &mut self, + context: &'c Context, + module: &Module, + block: &'a Block<'c>, + location: Location<'c>, + value: Value<'c, '_>, + ) -> Result> + where + 'c: 'a, + { + let value_type: IntegerType = value.r#type().try_into()?; + // Each width has its own runtime function whose result already fits the + // (smaller) output width used by the libfunc. + let (binding, output_bits) = match value_type.width() { + 8 => (RuntimeBinding::U8SquareRoot, 8), + 16 => (RuntimeBinding::U16SquareRoot, 8), + 32 => (RuntimeBinding::U32SquareRoot, 16), + 64 => (RuntimeBinding::U64SquareRoot, 32), + 128 => (RuntimeBinding::U128SquareRoot, 64), + _ => crate::native_panic!("invalid integer width in square root"), + }; + let output_type = IntegerType::new(context, output_bits).into(); + let function = self.build_function(context, module, block, location, binding)?; + Ok(block + .append_operation( + OperationBuilder::new("llvm.call", location) + .add_operands(&[function, value]) + .add_results(&[output_type]) + .build()?, + ) + .result(0)? + .into()) + } + + /// Register if necessary, then invoke `cairo_native__u256_square_root` on + /// the `u256` given by its low and high `u128` limbs, returning the `u128` + /// result. + pub fn u256_square_root<'c, 'a>( + &mut self, + context: &'c Context, + module: &Module, + block: &'a Block<'c>, + location: Location<'c>, + lo: Value<'c, '_>, + hi: Value<'c, '_>, + ) -> Result> + where + 'c: 'a, + { + let function = self.build_function( + context, + module, + block, + location, + RuntimeBinding::U256SquareRoot, + )?; + Ok(block + .append_operation( + OperationBuilder::new("llvm.call", location) + .add_operands(&[function, lo, hi]) + .add_results(&[IntegerType::new(context, 128).into()]) + .build()?, + ) + .result(0)? + .into()) + } + /// Builds, if necessary, the circuit operation function, used to perform /// circuit arithmetic operations. /// @@ -908,6 +967,12 @@ pub fn setup_runtime(find_symbol_ptr: impl Fn(&str) -> Option<*mut c_void>) { RuntimeBinding::QM31Mul, RuntimeBinding::QM31Div, RuntimeBinding::ArenaAlloc, + RuntimeBinding::U8SquareRoot, + RuntimeBinding::U16SquareRoot, + RuntimeBinding::U32SquareRoot, + RuntimeBinding::U64SquareRoot, + RuntimeBinding::U128SquareRoot, + RuntimeBinding::U256SquareRoot, #[cfg(feature = "with-cheatcode")] RuntimeBinding::VtableCheatcode, ] { diff --git a/src/runtime.rs b/src/runtime.rs index 61182c1c2e..b1380c02ac 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -11,6 +11,7 @@ use cairo_lang_sierra_gas::core_libfunc_cost::{ use itertools::Itertools; use lambdaworks_math::field::fields::mersenne31::extensions::Degree4ExtensionField; use lazy_static::lazy_static; +use num_bigint::BigUint; use num_traits::{ToPrimitive, Zero}; use starknet_curve::curve_params::BETA; use starknet_types_core::{ @@ -42,6 +43,42 @@ pub(crate) static EXECUTION_ARENA: RefCell = RefCell::new(Bump::new()); pub(crate) static DICT_REGISTRY: RefCell> = const { RefCell::new(Vec::new()) }; } +/// Compute `floor(sqrt(value))`. The result of each integer square root always +/// fits in the (smaller) output type used by the corresponding libfunc. +pub extern "C" fn cairo_native__u8_square_root(value: u8) -> u8 { + value.isqrt() +} + +/// Compute `floor(sqrt(value))`. See [`cairo_native__u8_square_root`]. +pub extern "C" fn cairo_native__u16_square_root(value: u16) -> u8 { + value.isqrt() as u8 +} + +/// Compute `floor(sqrt(value))`. See [`cairo_native__u8_square_root`]. +pub extern "C" fn cairo_native__u32_square_root(value: u32) -> u16 { + value.isqrt() as u16 +} + +/// Compute `floor(sqrt(value))`. See [`cairo_native__u8_square_root`]. +pub extern "C" fn cairo_native__u64_square_root(value: u64) -> u32 { + value.isqrt() as u32 +} + +/// Compute `floor(sqrt(value))`. See [`cairo_native__u8_square_root`]. +pub extern "C" fn cairo_native__u128_square_root(value: u128) -> u64 { + value.isqrt() as u64 +} + +/// Compute `floor(sqrt(value))` of the `u256` given by its low and high `u128` +/// limbs. The result always fits in a `u128`. +pub extern "C" fn cairo_native__u256_square_root(lo: u128, hi: u128) -> u128 { + let value = (BigUint::from(hi) << 128u32) + BigUint::from(lo); + value + .sqrt() + .to_u128() + .expect("the square root of a u256 always fits in a u128") +} + /// Allocate `size` bytes with `align` alignment from the per-execution arena. pub unsafe extern "C" fn cairo_native__arena_alloc(size: u64, align: u64) -> *mut u8 { EXECUTION_ARENA.with(|arena| { diff --git a/tests/tests/uint.rs b/tests/tests/uint.rs index 2ade72ed8b..4bab891af8 100644 --- a/tests/tests/uint.rs +++ b/tests/tests/uint.rs @@ -792,4 +792,35 @@ proptest! { &result_native, )?; } + + // u256 + + #[test] + fn u256_sqrt_proptest(lo in any::(), hi in any::()) { + let program = &load_program_and_runner("programs/libfuncs/u256_sqrt"); + let result_vm = run_vm_program( + program, + "run_test", + vec![Arg::Value(lo.into()), Arg::Value(hi.into())], + Some(DEFAULT_GAS as usize), + ) + .unwrap(); + let result_native = run_native_program( + program, + "run_test", + &[Value::Struct { + fields: vec![Value::Uint128(lo), Value::Uint128(hi)], + debug_name: None, + }], + Some(DEFAULT_GAS), + Option::::None, + ); + + compare_outputs( + &program.1, + &program.2.find_function("run_test").unwrap().id, + &result_vm, + &result_native, + )?; + } }