Skip to content

Commit 0bcab0a

Browse files
authored
refactor!: use enum op traits for floats + conversions (#755)
BREAKING CHANGES: extension() function replaced with EXTENSION static ref for float_ops and conversions
1 parent 268f120 commit 0bcab0a

File tree

7 files changed

+245
-159
lines changed

7 files changed

+245
-159
lines changed

src/ops/constant.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ mod test {
148148
use super::*;
149149

150150
fn test_registry() -> ExtensionRegistry {
151-
ExtensionRegistry::try_new([PRELUDE.to_owned(), float_types::extension()]).unwrap()
151+
ExtensionRegistry::try_new([PRELUDE.to_owned(), float_types::EXTENSION.to_owned()]).unwrap()
152152
}
153153

154154
#[test]

src/std_extensions/arithmetic/conversions.rs

+116-48
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,131 @@
11
//! Conversions between integer and floating-point values.
22
3+
use smol_str::SmolStr;
4+
use strum_macros::{EnumIter, EnumString, IntoStaticStr};
5+
36
use crate::{
4-
extension::{prelude::sum_with_error, ExtensionId, ExtensionSet},
7+
extension::{
8+
prelude::sum_with_error,
9+
simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError},
10+
ExtensionId, ExtensionRegistry, ExtensionSet, OpDef, SignatureError, SignatureFunc,
11+
},
12+
ops::{custom::ExtensionOp, OpName},
513
type_row,
6-
types::{FunctionType, PolyFuncType},
14+
types::{FunctionType, PolyFuncType, TypeArg},
715
Extension,
816
};
917

1018
use super::int_types::int_tv;
1119
use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM};
20+
use lazy_static::lazy_static;
1221

1322
/// The extension identifier.
1423
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions");
1524

16-
/// Extension for basic arithmetic operations.
17-
pub fn extension() -> Extension {
18-
let ftoi_sig = PolyFuncType::new(
19-
vec![LOG_WIDTH_TYPE_PARAM],
20-
FunctionType::new(type_row![FLOAT64_TYPE], vec![sum_with_error(int_tv(0))]),
21-
);
22-
23-
let itof_sig = PolyFuncType::new(
24-
vec![LOG_WIDTH_TYPE_PARAM],
25-
FunctionType::new(vec![int_tv(0)], type_row![FLOAT64_TYPE]),
26-
);
27-
28-
let mut extension = Extension::new_with_reqs(
29-
EXTENSION_ID,
30-
ExtensionSet::from_iter(vec![
31-
super::int_types::EXTENSION_ID,
32-
super::float_types::EXTENSION_ID,
33-
]),
34-
);
35-
extension
36-
.add_op(
37-
"trunc_u".into(),
38-
"float to unsigned int".to_owned(),
39-
ftoi_sig.clone(),
40-
)
41-
.unwrap();
42-
extension
43-
.add_op("trunc_s".into(), "float to signed int".to_owned(), ftoi_sig)
44-
.unwrap();
45-
extension
46-
.add_op(
47-
"convert_u".into(),
48-
"unsigned int to float".to_owned(),
49-
itof_sig.clone(),
50-
)
51-
.unwrap();
52-
extension
53-
.add_op(
54-
"convert_s".into(),
55-
"signed int to float".to_owned(),
56-
itof_sig,
57-
)
58-
.unwrap();
59-
60-
extension
25+
/// Extensiop for conversions between floats and integers.
26+
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)]
27+
#[allow(missing_docs, non_camel_case_types)]
28+
pub enum ConvertOpDef {
29+
trunc_u,
30+
trunc_s,
31+
convert_u,
32+
convert_s,
33+
}
34+
35+
impl MakeOpDef for ConvertOpDef {
36+
fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError> {
37+
crate::extension::simple_op::try_from_name(op_def.name())
38+
}
39+
40+
fn signature(&self) -> SignatureFunc {
41+
use ConvertOpDef::*;
42+
match self {
43+
trunc_s | trunc_u => PolyFuncType::new(
44+
vec![LOG_WIDTH_TYPE_PARAM],
45+
FunctionType::new(type_row![FLOAT64_TYPE], vec![sum_with_error(int_tv(0))]),
46+
),
47+
48+
convert_s | convert_u => PolyFuncType::new(
49+
vec![LOG_WIDTH_TYPE_PARAM],
50+
FunctionType::new(vec![int_tv(0)], type_row![FLOAT64_TYPE]),
51+
),
52+
}
53+
.into()
54+
}
55+
56+
fn description(&self) -> String {
57+
use ConvertOpDef::*;
58+
match self {
59+
trunc_u => "float to unsigned int",
60+
trunc_s => "float to signed int",
61+
convert_u => "unsigned int to float",
62+
convert_s => "signed int to float",
63+
}
64+
.to_string()
65+
}
66+
}
67+
68+
/// Concrete convert operation with integer width set.
69+
#[derive(Debug, Clone, PartialEq)]
70+
pub struct ConvertOpType {
71+
def: ConvertOpDef,
72+
width: u64,
73+
}
74+
75+
impl OpName for ConvertOpType {
76+
fn name(&self) -> SmolStr {
77+
self.def.name()
78+
}
79+
}
80+
81+
impl MakeExtensionOp for ConvertOpType {
82+
fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError> {
83+
let def = ConvertOpDef::from_def(ext_op.def())?;
84+
let width = match *ext_op.args() {
85+
[TypeArg::BoundedNat { n }] => n,
86+
_ => return Err(SignatureError::InvalidTypeArgs.into()),
87+
};
88+
Ok(Self { def, width })
89+
}
90+
91+
fn type_args(&self) -> Vec<crate::types::TypeArg> {
92+
vec![TypeArg::BoundedNat { n: self.width }]
93+
}
94+
}
95+
96+
lazy_static! {
97+
/// Extension for conversions between integers and floats.
98+
pub static ref EXTENSION: Extension = {
99+
let mut extension = Extension::new_with_reqs(
100+
EXTENSION_ID,
101+
ExtensionSet::from_iter(vec![
102+
super::int_types::EXTENSION_ID,
103+
super::float_types::EXTENSION_ID,
104+
]),
105+
);
106+
107+
ConvertOpDef::load_all_ops(&mut extension).unwrap();
108+
109+
extension
110+
};
111+
112+
/// Registry of extensions required to validate integer operations.
113+
pub static ref CONVERT_OPS_REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([
114+
super::int_types::EXTENSION.to_owned(),
115+
super::float_types::EXTENSION.to_owned(),
116+
EXTENSION.to_owned(),
117+
])
118+
.unwrap();
119+
}
120+
121+
impl MakeRegisteredOp for ConvertOpType {
122+
fn extension_id(&self) -> ExtensionId {
123+
EXTENSION_ID.to_owned()
124+
}
125+
126+
fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry {
127+
&CONVERT_OPS_REGISTRY
128+
}
61129
}
62130

63131
#[cfg(test)]
@@ -66,7 +134,7 @@ mod test {
66134

67135
#[test]
68136
fn test_conversions_extension() {
69-
let r = extension();
137+
let r = &EXTENSION;
70138
assert_eq!(r.name() as &str, "arithmetic.conversions");
71139
assert_eq!(r.types().count(), 0);
72140
for (name, _) in r.operations() {

0 commit comments

Comments
 (0)