Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 10 additions & 145 deletions src/libfuncs/int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@ use crate::{
error::{panic::ToNativeAssertError, Result},
execution_result::BITWISE_BUILTIN_SIZE,
libfuncs::{increment_builtin_counter, increment_builtin_counter_by},
metadata::MetadataStorage,
native_panic,
metadata::{runtime_bindings::RuntimeBindingsMeta, MetadataStorage},
types::TypeBuilder,
utils::{ProgramRegistryExt, PRIME},
};
use cairo_lang_sierra::{
extensions::{
bounded_int::BoundedIntDivRemAlgorithm,
core::{CoreLibfunc, CoreType, CoreTypeConcrete},
core::{CoreLibfunc, CoreType},
int::{
signed::{SintConcrete, SintTraits},
signed128::Sint128Concrete,
Expand All @@ -35,8 +34,8 @@ use melior::{
},
helpers::{ArithBlockExt, BuiltinBlockExt, LlvmBlockExt},
ir::{
attribute::IntegerAttribute, operation::OperationBuilder, r#type::IntegerType, Block,
BlockLike, Location, Region, ValueLike,
operation::OperationBuilder, r#type::IntegerType, Block, BlockLike, Location, Region,
ValueLike,
},
Context,
};
Expand Down Expand Up @@ -689,157 +688,23 @@ fn build_operation<'ctx, 'this>(

fn build_square_root<'ctx, 'this>(
context: &'ctx Context,
registry: &ProgramRegistry<CoreType, CoreLibfunc>,
_registry: &ProgramRegistry<CoreType, CoreLibfunc>,
entry: &'this Block<'ctx>,
location: Location<'ctx>,
helper: &LibfuncHelper<'ctx, 'this>,
_metadata: &mut MetadataStorage,
info: &SignatureOnlyConcreteLibfunc,
metadata: &mut MetadataStorage,
_info: &SignatureOnlyConcreteLibfunc,
) -> Result<()> {
// The sierra-to-casm compiler uses the range_check builtin 4 times.
// https://github.com/starkware-libs/cairo/blob/v2.12.0-dev.1/crates/cairo-lang-sierra-to-casm/src/invocations/int/unsigned.rs#L73
let range_check =
super::increment_builtin_counter_by(context, entry, location, entry.arg(0)?, 4)?;

let input = entry.arg(1)?;
let (input_bits, value_bits) =
match registry.get_type(&info.signature.param_signatures[1].ty)? {
CoreTypeConcrete::Uint8(_) => (8, 8),
CoreTypeConcrete::Uint16(_) => (16, 8),
CoreTypeConcrete::Uint32(_) => (32, 16),
CoreTypeConcrete::Uint64(_) => (64, 32),
CoreTypeConcrete::Uint128(_) => (128, 64),
_ => native_panic!("invalid value type in int square root"),
};

let k1 = entry.const_int(context, location, 1, input_bits)?;
let is_small = entry.cmpi(context, CmpiPredicate::Ule, input, k1, location)?;

let value = entry.append_op_result(scf::r#if(
is_small,
&[IntegerType::new(context, value_bits).into()],
{
let region = Region::new();
let block = region.append_block(Block::new(&[]));

let value = block.trunci(
input,
IntegerType::new(context, value_bits).into(),
location,
)?;

block.append_operation(scf::r#yield(&[value], location));
region
},
{
let region = Region::new();
let block = region.append_block(Block::new(&[]));

let leading_zeros = block.append_op_result(
ods::llvm::intr_ctlz(
context,
IntegerType::new(context, input_bits).into(),
input,
IntegerAttribute::new(IntegerType::new(context, 1).into(), 1),
location,
)
.into(),
)?;

let k_bits = block.const_int(context, location, input_bits, input_bits)?;
let num_bits = block.append_op_result(arith::subi(k_bits, leading_zeros, location))?;
let shift_amount = block.addi(num_bits, k1, location)?;

let parity_mask = block.const_int(context, location, -2, input_bits)?;
let shift_amount =
block.append_op_result(arith::andi(shift_amount, parity_mask, location))?;

let k0 = block.const_int(context, location, 0, input_bits)?;
let value = block.append_op_result(scf::r#while(
&[k0, shift_amount],
&[
IntegerType::new(context, input_bits).into(),
IntegerType::new(context, input_bits).into(),
],
{
let region = Region::new();
let block = region.append_block(Block::new(&[
(IntegerType::new(context, input_bits).into(), location),
(IntegerType::new(context, input_bits).into(), location),
]));

let value = block.shli(block.arg(0)?, k1, location)?;
let large_candidate =
block.append_op_result(arith::xori(value, k1, location))?;
let large_candidate_squared =
block.muli(large_candidate, large_candidate, location)?;

let threshold = block.shrui(input, block.arg(1)?, location)?;
let threshold_is_poison =
block.cmpi(context, CmpiPredicate::Eq, block.arg(1)?, k_bits, location)?;
let threshold = block.append_op_result(arith::select(
threshold_is_poison,
k0,
threshold,
location,
))?;

let is_in_range = block.cmpi(
context,
CmpiPredicate::Ule,
large_candidate_squared,
threshold,
location,
)?;
let value = block.append_op_result(arith::select(
is_in_range,
large_candidate,
value,
location,
))?;

let k2 = block.const_int(context, location, 2, input_bits)?;
let shift_amount =
block.append_op_result(arith::subi(block.arg(1)?, k2, location))?;

let should_continue =
block.cmpi(context, CmpiPredicate::Sge, shift_amount, k0, location)?;
block.append_operation(scf::condition(
should_continue,
&[value, shift_amount],
location,
));

region
},
{
let region = Region::new();
let block = region.append_block(Block::new(&[
(IntegerType::new(context, input_bits).into(), location),
(IntegerType::new(context, input_bits).into(), location),
]));

block.append_operation(scf::r#yield(&[block.arg(0)?, block.arg(1)?], location));
region
},
location,
))?;

let value = if input_bits == value_bits {
value
} else {
block.trunci(
value,
IntegerType::new(context, value_bits).into(),
location,
)?
};

block.append_operation(scf::r#yield(&[value], location));
region
},
location,
))?;
let runtime_bindings = metadata.get_or_insert_with(RuntimeBindingsMeta::default);
let value =
runtime_bindings.integer_square_root(context, helper.module, entry, location, input)?;

helper.br(entry, 0, &[range_check, value], location)
}
Expand Down
Loading
Loading