Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: impl Hash and Eq on more comptime types #6022

Merged
merged 4 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 68 additions & 102 deletions compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
use std::{
hash::{Hash, Hasher},
rc::Rc,
};
use std::rc::Rc;

use acvm::{AcirField, FieldElement};
use builtin_helpers::{
Expand Down Expand Up @@ -43,7 +40,7 @@ use crate::{
Kind, QuotedType, ResolvedGeneric, Shared, Type, TypeVariable,
};

use self::builtin_helpers::{get_array, get_str, get_u8};
use self::builtin_helpers::{eq_item, get_array, get_str, get_u8, hash_item};
use super::Interpreter;

pub(crate) mod builtin_helpers;
Expand Down Expand Up @@ -104,9 +101,11 @@ impl<'local, 'context> Interpreter<'local, 'context> {
"fresh_type_variable" => fresh_type_variable(interner),
"function_def_add_attribute" => function_def_add_attribute(self, arguments, location),
"function_def_body" => function_def_body(interner, arguments, location),
"function_def_eq" => function_def_eq(arguments, location),
"function_def_has_named_attribute" => {
function_def_has_named_attribute(interner, arguments, location)
}
"function_def_hash" => function_def_hash(arguments, location),
"function_def_is_unconstrained" => {
function_def_is_unconstrained(interner, arguments, location)
}
Expand All @@ -126,21 +125,24 @@ impl<'local, 'context> Interpreter<'local, 'context> {
function_def_set_unconstrained(self, arguments, location)
}
"module_add_item" => module_add_item(self, arguments, location),
"module_eq" => module_eq(arguments, location),
"module_functions" => module_functions(self, arguments, location),
"module_has_named_attribute" => module_has_named_attribute(self, arguments, location),
"module_hash" => module_hash(arguments, location),
"module_is_contract" => module_is_contract(self, arguments, location),
"module_name" => module_name(interner, arguments, location),
"module_structs" => module_structs(self, arguments, location),
"modulus_be_bits" => modulus_be_bits(interner, arguments, location),
"modulus_be_bytes" => modulus_be_bytes(interner, arguments, location),
"modulus_le_bits" => modulus_le_bits(interner, arguments, location),
"modulus_le_bytes" => modulus_le_bytes(interner, arguments, location),
"modulus_num_bits" => modulus_num_bits(interner, arguments, location),
"modulus_be_bits" => modulus_be_bits(arguments, location),
"modulus_be_bytes" => modulus_be_bytes(arguments, location),
"modulus_le_bits" => modulus_le_bits(arguments, location),
"modulus_le_bytes" => modulus_le_bytes(arguments, location),
"modulus_num_bits" => modulus_num_bits(arguments, location),
"quoted_as_expr" => quoted_as_expr(arguments, return_type, location),
"quoted_as_module" => quoted_as_module(self, arguments, return_type, location),
"quoted_as_trait_constraint" => quoted_as_trait_constraint(self, arguments, location),
"quoted_as_type" => quoted_as_type(self, arguments, location),
"quoted_eq" => quoted_eq(arguments, location),
"quoted_hash" => quoted_hash(arguments, location),
"quoted_tokens" => quoted_tokens(arguments, location),
"slice_insert" => slice_insert(interner, arguments, location),
"slice_pop_back" => slice_pop_back(interner, arguments, location, call_stack),
Expand All @@ -152,22 +154,24 @@ impl<'local, 'context> Interpreter<'local, 'context> {
"struct_def_add_attribute" => struct_def_add_attribute(interner, arguments, location),
"struct_def_add_generic" => struct_def_add_generic(interner, arguments, location),
"struct_def_as_type" => struct_def_as_type(interner, arguments, location),
"struct_def_eq" => struct_def_eq(arguments, location),
"struct_def_fields" => struct_def_fields(interner, arguments, location),
"struct_def_generics" => struct_def_generics(interner, arguments, location),
"struct_def_has_named_attribute" => {
struct_def_has_named_attribute(interner, arguments, location)
}
"struct_def_hash" => struct_def_hash(arguments, location),
"struct_def_module" => struct_def_module(self, arguments, location),
"struct_def_name" => struct_def_name(interner, arguments, location),
"struct_def_set_fields" => struct_def_set_fields(interner, arguments, location),
"to_le_radix" => to_le_radix(arguments, return_type, location),
"trait_constraint_eq" => trait_constraint_eq(interner, arguments, location),
"trait_constraint_hash" => trait_constraint_hash(interner, arguments, location),
"trait_constraint_eq" => trait_constraint_eq(arguments, location),
"trait_constraint_hash" => trait_constraint_hash(arguments, location),
"trait_def_as_trait_constraint" => {
trait_def_as_trait_constraint(interner, arguments, location)
}
"trait_def_eq" => trait_def_eq(interner, arguments, location),
"trait_def_hash" => trait_def_hash(interner, arguments, location),
"trait_def_eq" => trait_def_eq(arguments, location),
"trait_def_hash" => trait_def_hash(arguments, location),
"trait_impl_methods" => trait_impl_methods(interner, arguments, location),
"trait_impl_trait_generic_args" => {
trait_impl_trait_generic_args(interner, arguments, location)
Expand All @@ -183,6 +187,7 @@ impl<'local, 'context> Interpreter<'local, 'context> {
"type_get_trait_impl" => {
type_get_trait_impl(interner, arguments, return_type, location)
}
"type_hash" => type_hash(arguments, location),
"type_implements" => type_implements(interner, arguments, location),
"type_is_bool" => type_is_bool(arguments, location),
"type_is_field" => type_is_field(arguments, location),
Expand Down Expand Up @@ -428,6 +433,14 @@ fn struct_def_generics(
Ok(Value::Slice(generics.collect(), typ))
}

fn struct_def_hash(arguments: Vec<(Value, Location)>, location: Location) -> IResult<Value> {
hash_item(arguments, location, get_struct)
}

fn struct_def_eq(arguments: Vec<(Value, Location)>, location: Location) -> IResult<Value> {
eq_item(arguments, location, get_struct)
}

// fn has_named_attribute(self, name: Quoted) -> bool
fn struct_def_has_named_attribute(
interner: &NodeInterner,
Expand Down Expand Up @@ -904,12 +917,12 @@ where

// fn type_eq(_first: Type, _second: Type) -> bool
fn type_eq(arguments: Vec<(Value, Location)>, location: Location) -> IResult<Value> {
let (self_type, other_type) = check_two_arguments(arguments, location)?;

let self_type = get_type(self_type)?;
let other_type = get_type(other_type)?;
eq_item(arguments, location, get_type)
}

Ok(Value::Bool(self_type == other_type))
// fn type_hash(_t: Type) -> Field
fn type_hash(arguments: Vec<(Value, Location)>, location: Location) -> IResult<Value> {
hash_item(arguments, location, get_type)
}

// fn get_trait_impl(self, constraint: TraitConstraint) -> Option<TraitImpl>
Expand Down Expand Up @@ -978,65 +991,23 @@ fn type_of(arguments: Vec<(Value, Location)>, location: Location) -> IResult<Val
}

// fn constraint_hash(constraint: TraitConstraint) -> Field
fn trait_constraint_hash(
_interner: &mut NodeInterner,
arguments: Vec<(Value, Location)>,
location: Location,
) -> IResult<Value> {
let argument = check_one_argument(arguments, location)?;

let bound = get_trait_constraint(argument)?;

let mut hasher = std::collections::hash_map::DefaultHasher::new();
bound.hash(&mut hasher);
let hash = hasher.finish();

Ok(Value::Field((hash as u128).into()))
fn trait_constraint_hash(arguments: Vec<(Value, Location)>, location: Location) -> IResult<Value> {
hash_item(arguments, location, get_trait_constraint)
}

// fn constraint_eq(constraint_a: TraitConstraint, constraint_b: TraitConstraint) -> bool
fn trait_constraint_eq(
_interner: &mut NodeInterner,
arguments: Vec<(Value, Location)>,
location: Location,
) -> IResult<Value> {
let (value_a, value_b) = check_two_arguments(arguments, location)?;

let constraint_a = get_trait_constraint(value_a)?;
let constraint_b = get_trait_constraint(value_b)?;

Ok(Value::Bool(constraint_a == constraint_b))
fn trait_constraint_eq(arguments: Vec<(Value, Location)>, location: Location) -> IResult<Value> {
eq_item(arguments, location, get_trait_constraint)
}

// fn trait_def_hash(def: TraitDefinition) -> Field
fn trait_def_hash(
_interner: &mut NodeInterner,
arguments: Vec<(Value, Location)>,
location: Location,
) -> IResult<Value> {
let argument = check_one_argument(arguments, location)?;

let id = get_trait_def(argument)?;

let mut hasher = std::collections::hash_map::DefaultHasher::new();
id.hash(&mut hasher);
let hash = hasher.finish();

Ok(Value::Field((hash as u128).into()))
fn trait_def_hash(arguments: Vec<(Value, Location)>, location: Location) -> IResult<Value> {
hash_item(arguments, location, get_trait_def)
}

// fn trait_def_eq(def_a: TraitDefinition, def_b: TraitDefinition) -> bool
fn trait_def_eq(
_interner: &mut NodeInterner,
arguments: Vec<(Value, Location)>,
location: Location,
) -> IResult<Value> {
let (id_a, id_b) = check_two_arguments(arguments, location)?;

let id_a = get_trait_def(id_a)?;
let id_b = get_trait_def(id_b)?;

Ok(Value::Bool(id_a == id_b))
fn trait_def_eq(arguments: Vec<(Value, Location)>, location: Location) -> IResult<Value> {
eq_item(arguments, location, get_trait_def)
}

// fn methods(self) -> [FunctionDefinition]
Expand Down Expand Up @@ -2005,6 +1976,14 @@ fn function_def_has_named_attribute(
Ok(Value::Bool(has_named_attribute(&name, attributes, location)))
}

fn function_def_hash(arguments: Vec<(Value, Location)>, location: Location) -> IResult<Value> {
hash_item(arguments, location, get_function_def)
}

fn function_def_eq(arguments: Vec<(Value, Location)>, location: Location) -> IResult<Value> {
eq_item(arguments, location, get_function_def)
}

// fn is_unconstrained(self) -> bool
fn function_def_is_unconstrained(
interner: &NodeInterner,
Expand Down Expand Up @@ -2271,6 +2250,14 @@ fn module_add_item(
Ok(Value::Unit)
}

fn module_hash(arguments: Vec<(Value, Location)>, location: Location) -> IResult<Value> {
hash_item(arguments, location, get_module)
}

fn module_eq(arguments: Vec<(Value, Location)>, location: Location) -> IResult<Value> {
eq_item(arguments, location, get_module)
}

// fn functions(self) -> [FunctionDefinition]
fn module_functions(
interpreter: &Interpreter,
Expand Down Expand Up @@ -2361,11 +2348,7 @@ fn module_name(
Ok(Value::Quoted(tokens))
}

fn modulus_be_bits(
_interner: &mut NodeInterner,
arguments: Vec<(Value, Location)>,
location: Location,
) -> IResult<Value> {
fn modulus_be_bits(arguments: Vec<(Value, Location)>, location: Location) -> IResult<Value> {
check_argument_count(0, &arguments, location)?;

let bits = FieldElement::modulus().to_radix_be(2);
Expand All @@ -2376,11 +2359,7 @@ fn modulus_be_bits(
Ok(Value::Slice(bits_vector, typ))
}

fn modulus_be_bytes(
_interner: &mut NodeInterner,
arguments: Vec<(Value, Location)>,
location: Location,
) -> IResult<Value> {
fn modulus_be_bytes(arguments: Vec<(Value, Location)>, location: Location) -> IResult<Value> {
check_argument_count(0, &arguments, location)?;

let bytes = FieldElement::modulus().to_bytes_be();
Expand All @@ -2391,55 +2370,42 @@ fn modulus_be_bytes(
Ok(Value::Slice(bytes_vector, typ))
}

fn modulus_le_bits(
interner: &mut NodeInterner,
arguments: Vec<(Value, Location)>,
location: Location,
) -> IResult<Value> {
let Value::Slice(bits, typ) = modulus_be_bits(interner, arguments, location)? else {
fn modulus_le_bits(arguments: Vec<(Value, Location)>, location: Location) -> IResult<Value> {
let Value::Slice(bits, typ) = modulus_be_bits(arguments, location)? else {
unreachable!("modulus_be_bits must return slice")
};
let reversed_bits = bits.into_iter().rev().collect();
Ok(Value::Slice(reversed_bits, typ))
}

fn modulus_le_bytes(
interner: &mut NodeInterner,
arguments: Vec<(Value, Location)>,
location: Location,
) -> IResult<Value> {
let Value::Slice(bytes, typ) = modulus_be_bytes(interner, arguments, location)? else {
fn modulus_le_bytes(arguments: Vec<(Value, Location)>, location: Location) -> IResult<Value> {
let Value::Slice(bytes, typ) = modulus_be_bytes(arguments, location)? else {
unreachable!("modulus_be_bytes must return slice")
};
let reversed_bytes = bytes.into_iter().rev().collect();
Ok(Value::Slice(reversed_bytes, typ))
}

fn modulus_num_bits(
_interner: &mut NodeInterner,
arguments: Vec<(Value, Location)>,
location: Location,
) -> IResult<Value> {
fn modulus_num_bits(arguments: Vec<(Value, Location)>, location: Location) -> IResult<Value> {
check_argument_count(0, &arguments, location)?;
let bits = FieldElement::max_num_bits().into();
Ok(Value::U64(bits))
}

// fn quoted_eq(_first: Quoted, _second: Quoted) -> bool
fn quoted_eq(arguments: Vec<(Value, Location)>, location: Location) -> IResult<Value> {
let (self_value, other_value) = check_two_arguments(arguments, location)?;

let self_quoted = get_quoted(self_value)?;
let other_quoted = get_quoted(other_value)?;
eq_item(arguments, location, get_quoted)
}

Ok(Value::Bool(self_quoted == other_quoted))
fn quoted_hash(arguments: Vec<(Value, Location)>, location: Location) -> IResult<Value> {
hash_item(arguments, location, get_quoted)
}

fn trait_def_as_trait_constraint(
interner: &mut NodeInterner,
arguments: Vec<(Value, Location)>,
location: Location,
) -> Result<Value, InterpreterError> {
) -> IResult<Value> {
let argument = check_one_argument(arguments, location)?;

let trait_id = get_trait_def(argument)?;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::rc::Rc;
use std::hash::Hash;
use std::{hash::Hasher, rc::Rc};

use acvm::FieldElement;
use noirc_errors::Location;
Expand Down Expand Up @@ -471,3 +472,28 @@ pub(super) fn has_named_attribute<'a>(

false
}

pub(super) fn hash_item<T: Hash>(
arguments: Vec<(Value, Location)>,
location: Location,
get_item: impl FnOnce((Value, Location)) -> IResult<T>,
) -> IResult<Value> {
let argument = check_one_argument(arguments, location)?;
let item = get_item(argument)?;

let mut hasher = std::collections::hash_map::DefaultHasher::new();
item.hash(&mut hasher);
let hash = hasher.finish();
Ok(Value::Field((hash as u128).into()))
}

pub(super) fn eq_item<T: Eq>(
arguments: Vec<(Value, Location)>,
location: Location,
mut get_item: impl FnMut((Value, Location)) -> IResult<T>,
) -> IResult<Value> {
let (self_arg, other_arg) = check_two_arguments(arguments, location)?;
let self_arg = get_item(self_arg)?;
let other_arg = get_item(other_arg)?;
Ok(Value::Bool(self_arg == other_arg))
}
11 changes: 11 additions & 0 deletions docs/docs/noir/standard_library/meta/function_def.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,14 @@ This means any functions called at compile-time are invalid targets for this met
Mutates the function to be unconstrained (if `true` is given) or not (if `false` is given).
This is only valid on functions in the current crate which have not yet been resolved.
This means any functions called at compile-time are invalid targets for this method.

## Trait Implementations

```rust
impl Eq for FunctionDefinition
impl Hash for FunctionDefinition
```

Note that each function is assigned a unique ID internally and this is what is used for
equality and hashing. So even functions with identical signatures and bodies may not
be equal in this sense if they were originally different items in the source program.
11 changes: 11 additions & 0 deletions docs/docs/noir/standard_library/meta/module.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,14 @@ Returns the name of the module.
#include_code structs noir_stdlib/src/meta/module.nr rust

Returns each struct defined in the module.

## Trait Implementations

```rust
impl Eq for Module
impl Hash for Module
```

Note that each module is assigned a unique ID internally and this is what is used for
equality and hashing. So even modules with identical names and contents may not
be equal in this sense if they were originally different items in the source program.
Loading
Loading