Skip to content

Commit 9a43f85

Browse files
authored
feat: Implement std::unsafe::zeroed (#1048)
* Add zeroed builtin * Implement std::unsafe::zeroed * Fix merge conflict
1 parent a67e8c5 commit 9a43f85

File tree

3 files changed

+100
-34
lines changed

3 files changed

+100
-34
lines changed

crates/noirc_frontend/src/monomorphization/mod.rs

+94-34
Original file line numberDiff line numberDiff line change
@@ -700,12 +700,8 @@ impl<'interner> Monomorphizer<'interner> {
700700
let return_type = Self::convert_type(&return_type);
701701
let location = call.location;
702702

703-
self.try_evaluate_call(&func, &call.arguments).unwrap_or(ast::Expression::Call(ast::Call {
704-
func,
705-
arguments,
706-
return_type,
707-
location,
708-
}))
703+
self.try_evaluate_call(&func, &call.arguments, &return_type)
704+
.unwrap_or(ast::Expression::Call(ast::Call { func, arguments, return_type, location }))
709705
}
710706

711707
/// Try to evaluate certain builtin functions (currently only 'array_len' and field modulus methods)
@@ -715,50 +711,47 @@ impl<'interner> Monomorphizer<'interner> {
715711
/// To fix this we need to evaluate on the identifier instead, which
716712
/// requires us to evaluate to a Lambda value which isn't in noir yet.
717713
fn try_evaluate_call(
718-
&self,
714+
&mut self,
719715
func: &ast::Expression,
720716
arguments: &[node_interner::ExprId],
717+
result_type: &ast::Type,
721718
) -> Option<ast::Expression> {
722-
match func {
723-
ast::Expression::Ident(ident) => match &ident.definition {
724-
Definition::Builtin(opcode) if opcode == "array_len" => {
719+
if let ast::Expression::Ident(ident) = func {
720+
if let Definition::Builtin(opcode) = &ident.definition {
721+
if opcode == "array_len" {
725722
let typ = self.interner.id_type(arguments[0]);
726723
let len = typ.evaluate_to_u64().unwrap();
727-
Some(ast::Expression::Literal(ast::Literal::Integer(
724+
return Some(ast::Expression::Literal(ast::Literal::Integer(
728725
(len as u128).into(),
729726
ast::Type::Field,
730-
)))
731-
}
732-
Definition::Builtin(opcode) if opcode == "modulus_num_bits" => {
733-
Some(ast::Expression::Literal(ast::Literal::Integer(
727+
)));
728+
} else if opcode == "modulus_num_bits" {
729+
return Some(ast::Expression::Literal(ast::Literal::Integer(
734730
(FieldElement::max_num_bits() as u128).into(),
735731
ast::Type::Field,
736-
)))
732+
)));
733+
} else if opcode == "zeroed" {
734+
return Some(self.zeroed_value_of_type(result_type));
737735
}
738-
Definition::Builtin(opcode) if opcode == "modulus_le_bits" => {
739-
let modulus = FieldElement::modulus();
736+
737+
let modulus = FieldElement::modulus();
738+
739+
if opcode == "modulus_le_bits" {
740740
let bits = modulus.to_radix_le(2);
741-
Some(self.modulus_array_literal(bits, 1))
742-
}
743-
Definition::Builtin(opcode) if opcode == "modulus_be_bits" => {
744-
let modulus = FieldElement::modulus();
741+
return Some(self.modulus_array_literal(bits, 1));
742+
} else if opcode == "modulus_be_bits" {
745743
let bits = modulus.to_radix_be(2);
746-
Some(self.modulus_array_literal(bits, 1))
747-
}
748-
Definition::Builtin(opcode) if opcode == "modulus_be_bytes" => {
749-
let modulus = FieldElement::modulus();
744+
return Some(self.modulus_array_literal(bits, 1));
745+
} else if opcode == "modulus_be_bytes" {
750746
let bytes = modulus.to_bytes_be();
751-
Some(self.modulus_array_literal(bytes, 8))
752-
}
753-
Definition::Builtin(opcode) if opcode == "modulus_le_bytes" => {
754-
let modulus = FieldElement::modulus();
747+
return Some(self.modulus_array_literal(bytes, 8));
748+
} else if opcode == "modulus_le_bytes" {
755749
let bytes = modulus.to_bytes_le();
756-
Some(self.modulus_array_literal(bytes, 8))
750+
return Some(self.modulus_array_literal(bytes, 8));
757751
}
758-
_ => None,
759-
},
760-
_ => None,
752+
}
761753
}
754+
None
762755
}
763756

764757
fn modulus_array_literal(&self, bytes: Vec<u8>, arr_elem_bits: u32) -> ast::Expression {
@@ -919,6 +912,73 @@ impl<'interner> Monomorphizer<'interner> {
919912
typ,
920913
})
921914
}
915+
916+
/// Implements std::unsafe::zeroed by returning an appropriate zeroed
917+
/// ast literal or collection node for the given type. Note that for functions
918+
/// there is no obvious zeroed value so this should be considered unsafe to use.
919+
fn zeroed_value_of_type(&mut self, typ: &ast::Type) -> ast::Expression {
920+
match typ {
921+
ast::Type::Field | ast::Type::Integer(..) => {
922+
ast::Expression::Literal(ast::Literal::Integer(0_u128.into(), typ.clone()))
923+
}
924+
ast::Type::Bool => ast::Expression::Literal(ast::Literal::Bool(false)),
925+
// There is no unit literal currently. Replace it with 'false' since it should be ignored
926+
// anyway.
927+
ast::Type::Unit => ast::Expression::Literal(ast::Literal::Bool(false)),
928+
ast::Type::Array(length, element_type) => {
929+
let element = self.zeroed_value_of_type(element_type.as_ref());
930+
ast::Expression::Literal(ast::Literal::Array(ast::ArrayLiteral {
931+
contents: vec![element; *length as usize],
932+
element_type: element_type.as_ref().clone(),
933+
}))
934+
}
935+
ast::Type::String(length) => {
936+
ast::Expression::Literal(ast::Literal::Str("\0".repeat(*length as usize)))
937+
}
938+
ast::Type::Tuple(fields) => {
939+
ast::Expression::Tuple(vecmap(fields, |field| self.zeroed_value_of_type(field)))
940+
}
941+
ast::Type::Function(parameter_types, ret_type) => {
942+
self.create_zeroed_function(parameter_types, ret_type)
943+
}
944+
}
945+
}
946+
947+
// Creating a zeroed function value is almost always an error if it is used later,
948+
// Hence why std::unsafe::zeroed is unsafe.
949+
//
950+
// To avoid confusing later passes, we arbitrarily choose to construct a function
951+
// that satisfies the input type by discarding all its parameters and returning a
952+
// zeroed value of the result type.
953+
fn create_zeroed_function(
954+
&mut self,
955+
parameter_types: &[ast::Type],
956+
ret_type: &ast::Type,
957+
) -> ast::Expression {
958+
let lambda_name = "zeroed_lambda";
959+
960+
let parameters = vecmap(parameter_types, |parameter_type| {
961+
(self.next_local_id(), false, "_".into(), parameter_type.clone())
962+
});
963+
964+
let body = self.zeroed_value_of_type(ret_type);
965+
966+
let id = self.next_function_id();
967+
let return_type = ret_type.clone();
968+
let name = lambda_name.to_owned();
969+
970+
let unconstrained = false;
971+
let function = ast::Function { id, name, parameters, body, return_type, unconstrained };
972+
self.push_function(id, function);
973+
974+
ast::Expression::Ident(ast::Ident {
975+
definition: Definition::Function(id),
976+
mutable: false,
977+
location: None,
978+
name: lambda_name.to_owned(),
979+
typ: ast::Type::Function(parameter_types.to_owned(), Box::new(ret_type.clone())),
980+
})
981+
}
922982
}
923983

924984
fn unwrap_tuple_type(typ: &HirType) -> Vec<HirType> {

noir_stdlib/src/lib.nr

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ mod sha256;
88
mod sha512;
99
mod field;
1010
mod ec;
11+
mod unsafe;
1112

1213
#[builtin(println)]
1314
fn println<T>(_input : T) {}

noir_stdlib/src/unsafe.nr

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
/// For any type, return an instance of that type by initializing
2+
/// all of its fields to 0. This is considered to be unsafe since there
3+
/// is no guarantee that all zeroes is a valid bit pattern for every type.
4+
#[builtin(zeroed)]
5+
fn zeroed<T>() -> T {}

0 commit comments

Comments
 (0)