Skip to content
3 changes: 2 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ curve25519-dalek = {version = "3.2.0", features = ["serde"], optional = true}
paste = "1.0"
im = "15"
once_cell = "1"
num-traits = "0.2"

[dev-dependencies]
quickcheck = "1"
Expand Down
125 changes: 122 additions & 3 deletions src/ir/term/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

use super::{
check, const_, extras, term, Array, BitVector, BoolNaryOp, BvBinOp, BvBinPred, BvNaryOp,
BvUnOp, FieldToBv, FxHashMap, IntBinOp, IntBinPred, IntNaryOp, IntUnOp, Integer, Node, Op,
PfNaryOp, PfUnOp, Sort, Term, TermMap, Value,
BvUnOp, FieldToBv, Float, FpBinOp, FpBinPred, FpUnOp, FpUnPred, FxHashMap, IntBinOp,
IntBinPred, IntNaryOp, IntUnOp, Integer, Node, Op, PfNaryOp, PfUnOp, Sort, Term, TermMap,
Value,
};
use crate::cfg::cfg_or_default;

Expand Down Expand Up @@ -165,6 +166,7 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap<String, Value>) ->
}
}),
Op::BoolToBv => Value::BitVector(BitVector::new(Integer::from(args[0].as_bool()), 1)),

Op::PfUnOp(o) => Value::Field({
let a = args[0].as_pf().clone();
match o {
Expand Down Expand Up @@ -215,7 +217,6 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap<String, Value>) ->
},
)
}),

Op::IntBinOp(o) => Value::Int({
let a = args[0].as_int().clone();
let b = args[1].as_int().clone();
Expand All @@ -239,6 +240,124 @@ pub fn eval_op(op: &Op, args: &[&Value], var_vals: &FxHashMap<String, Value>) ->
IntUnOp::Neg => -a,
}
}),
Op::FpBinOp(o) => {
// Promote to f64 if either operand is f64
let promote_to_f64 =
matches!(args[0], Value::F64(_)) || matches!(args[1], Value::F64(_));
Comment thread
lorenzorota marked this conversation as resolved.
Outdated
fn comp<T: Float>(a: T, b: T, op: FpBinOp) -> T {
match op {
FpBinOp::Add => a + b,
FpBinOp::Sub => a - b,
FpBinOp::Mul => a * b,
FpBinOp::Div => a / b,
FpBinOp::Rem => a % b,
FpBinOp::Max => a.max(b),
FpBinOp::Min => a.min(b),
}
}
if promote_to_f64 {
Value::F64(comp(args[0].as_f64(), args[1].as_f64(), *o))
} else {
Value::F32(comp(args[0].as_f32(), args[1].as_f32(), *o))
}
}
Op::FpBinPred(o) => Value::Bool({
// Promote to f64 if either operand is f64
let promote_to_f64 =
matches!(args[0], Value::F64(_)) || matches!(args[1], Value::F64(_));
fn comp<T: Float>(a: T, b: T, op: FpBinPred) -> bool {
match op {
FpBinPred::Le => a <= b,
FpBinPred::Lt => a < b,
FpBinPred::Eq => a == b,
FpBinPred::Ge => a >= b,
FpBinPred::Gt => a > b,
}
}
if promote_to_f64 {
comp(args[0].as_f64(), args[1].as_f64(), *o)
} else {
comp(args[0].as_f32(), args[1].as_f32(), *o)
}
}),
Op::FpUnPred(o) => Value::Bool({
fn comp<T: Float>(a: T, op: FpUnPred) -> bool {
match op {
FpUnPred::Normal => a.is_normal(),
FpUnPred::Subnormal => a.is_subnormal(),
FpUnPred::Zero => a == T::zero(),
FpUnPred::Infinite => a.is_infinite(),
FpUnPred::Nan => a.is_nan(),
FpUnPred::Negative => a.is_sign_negative(),
FpUnPred::Positive => a.is_sign_positive(),
}
}
match args[0] {
Value::F32(a) => comp(*a, *o),
Value::F64(a) => comp(*a, *o),
_ => panic!("Expected F32 or F64, got {}", args[0]),
}
}),
Op::FpUnOp(o) => {
fn comp<T: Float>(a: T, op: FpUnOp) -> T {
match op {
FpUnOp::Neg => -a,
FpUnOp::Abs => a.abs(),
FpUnOp::Sqrt => a.sqrt(),
FpUnOp::Round => a.round(),
}
}
match args[0] {
Value::F32(a) => Value::F32(comp(*a, *o)),
Value::F64(a) => Value::F64(comp(*a, *o)),
_ => panic!("Expected F32 or F64, got {}", args[0]),
}
}
Op::BvToFp => {
let bv = args[0].as_bv();
let val = bv.uint();
let w = bv.width();
match w {
32 => Value::F32(f32::from_bits(val.to_u32().unwrap())),
64 => Value::F64(f64::from_bits(val.to_u64().unwrap())),
_ => panic!("{} out of bounds for {} on {:?}", w, op, args),
}
}
Op::UbvToFp(w) => {
let val = args[0].as_bv().uint();
match w {
0..=32 => Value::F32(val.to_f32()),
33..=64 => Value::F64(val.to_f64()),
_ => panic!("{} out of bounds for {} on {:?}", w, op, args),
}
}
Op::SbvToFp(w) => {
let val = args[0].as_bv().as_sint();
match w {
0..=32 => Value::F32(val.to_f32()),
33..=64 => Value::F64(val.to_f64()),
_ => panic!("{} out of bounds for {} on {:?}", w, op, args),
}
}
Op::FpToFp(w) => {
match (args[0], w) {
(Value::F32(v), 64) => Value::F64(*v as f64), // Promote F32 to F64
(Value::F64(v), 32) => Value::F32(*v as f32), // Truncate F64 to F32
(Value::F32(_), 32) | (Value::F64(_), 64) => args[0].clone(),
_ => panic!("Invalid conversion width {} (expected 32 or 64)", w),
}
}
Op::PfToFp(w) => {
let val = args[0].as_pf().i();
match w {
32 => Value::F32(val.to_f32()),
64 => Value::F64(val.to_f64()),
_ => panic!(
"{} out of bounds for {} on {:?} (expected 32 or 64)",
w, op, args
),
}
}
Op::UbvToPf(fty) => Value::Field(fty.new_v(args[0].as_bv().uint())),
Op::PfChallenge(c) => Value::Field(eval_pf_challenge(&c.name, &c.field)),
Op::Witness(_) => args[0].clone(),
Expand Down
1 change: 1 addition & 0 deletions src/ir/term/fmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@ impl DisplayIr for Op {
Op::UbvToFp(a) => write!(f, "(ubv2fp {a})"),
Op::SbvToFp(a) => write!(f, "(sbv2fp {a})"),
Op::FpToFp(a) => write!(f, "(fp2fp {a})"),
Op::PfToFp(a) => write!(f, "(pf2fp {a})"),
Op::PfUnOp(a) => write!(f, "{a}"),
Op::PfNaryOp(a) => write!(f, "{a}"),
Op::PfDiv => write!(f, "/"),
Expand Down
40 changes: 40 additions & 0 deletions src/ir/term/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub use circ_hc::{Node, Table, Weak};
use circ_opt::FieldToBv;
use fxhash::{FxHashMap, FxHashSet};
use log::debug;
use num_traits::Float;
use rug::Integer;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::borrow::Borrow;
Expand Down Expand Up @@ -119,6 +120,9 @@ pub enum Op {
// dest width
/// translate the number represented by the argument to a floating-point value of this width.
FpToFp(usize),
/// translate the prime-field number represented by the argument to a floating-point value
/// of this width.
PfToFp(usize),

/// Prime-field unary operator
PfUnOp(PfUnOp),
Expand Down Expand Up @@ -333,6 +337,7 @@ pub const INT_LE: Op = Op::IntBinPred(IntBinPred::Le);
pub const INT_GT: Op = Op::IntBinPred(IntBinPred::Gt);
/// integer greater than or equal
pub const INT_GE: Op = Op::IntBinPred(IntBinPred::Ge);
// TODO: add floating-point operator abbreviations

impl Op {
/// Number of arguments for this operator. `None` if n-ary.
Expand Down Expand Up @@ -365,6 +370,7 @@ impl Op {
Op::UbvToFp(_) => Some(1),
Op::SbvToFp(_) => Some(1),
Op::FpToFp(_) => Some(1),
Op::PfToFp(_) => Some(1),
Op::PfUnOp(_) => Some(1),
Op::PfDiv => Some(2),
Op::PfNaryOp(_) => None,
Expand Down Expand Up @@ -1263,6 +1269,22 @@ impl Term {
None
}
}
/// Get the underlying 32-bit floating-point constant, if possible.
pub fn as_f32_opt(&self) -> Option<f32> {
if let Some(Value::F32(v)) = self.as_value_opt() {
Some(*v)
} else {
None
}
}
/// Get the underlying 64-bit floating-point constant, if possible.
pub fn as_f64_opt(&self) -> Option<f64> {
match self.as_value_opt()? {
Value::F64(v) => Some(*v),
Value::F32(v) => Some(*v as f64), // Floating-point promotion
_ => None,
}
}
/// Get the underlying prime field constant, if possible.
pub fn as_pf_opt(&self) -> Option<&FieldV> {
if let Some(Value::Field(b)) = self.as_value_opt() {
Expand Down Expand Up @@ -1381,6 +1403,24 @@ impl Value {
}
}
#[track_caller]
/// Get the underlying 32-bit floating-point constant, or panic!
pub fn as_f32(&self) -> f32 {
if let Value::F32(v) = self {
*v
} else {
panic!("Not a f32: {}", self)
}
}
#[track_caller]
/// Get the underlying 64-bit floating-point constant, or panic!
pub fn as_f64(&self) -> f64 {
match self {
Value::F32(v) => *v as f64, // Floating-point promotion
Value::F64(v) => *v,
_ => panic!("Not a f64 or f32: {}", self),
}
}
#[track_caller]
/// Get the underlying prime field constant, if possible.
pub fn as_pf(&self) -> &FieldV {
if let Value::Field(b) = self {
Expand Down
Loading