Skip to content

Commit b0554a6

Browse files
authored
Merge 5b805bf into 8dec847
2 parents 8dec847 + 5b805bf commit b0554a6

File tree

7 files changed

+441
-183
lines changed

7 files changed

+441
-183
lines changed

noir_stdlib/src/hash/sha256.nr

+146-132
Original file line numberDiff line numberDiff line change
@@ -3,101 +3,63 @@ use crate::runtime::is_unconstrained;
33
// Implementation of SHA-256 mapping a byte array of variable length to
44
// 32 bytes.
55

6+
// A message block is up to 64 bytes taken from the input.
7+
global BLOCK_SIZE = 64;
8+
9+
// The first index in the block where the 8 byte message size will be written.
10+
global MSG_SIZE_PTR = 56;
11+
12+
// Size of the message block when packed as an array of 4-byte integers.
13+
global INT_BLOCK_SIZE = 16;
14+
15+
// Index of a byte in a 64-byte block; i.e. 0..=63
16+
type BLOCK_BYTE_PTR = u32;
17+
18+
// The foreign function to compress blocks works on 16 4-byte integers instead of 64 individual bytes.
19+
type INT_BLOCK = [u32; INT_BLOCK_SIZE];
20+
21+
// A message block is a slice of the original message of a fixed size,
22+
// potentially padded with zeroes.
23+
type MSG_BLOCK = [u8; BLOCK_SIZE];
24+
25+
// The hash is 32 bytes.
26+
type HASH = [u8; 32];
27+
28+
// The state accumulates the blocks.
29+
// Its overall size is the same as the `HASH`.
30+
type STATE = [u32; 8];
31+
632
// Deprecated in favour of `sha256_var`
733
// docs:start:sha256
8-
pub fn sha256<let N: u32>(input: [u8; N]) -> [u8; 32]
34+
pub fn sha256<let N: u32>(input: [u8; N]) -> HASH
935
// docs:end:sha256
1036
{
1137
digest(input)
1238
}
1339

1440
#[foreign(sha256_compression)]
15-
pub fn sha256_compression(_input: [u32; 16], _state: [u32; 8]) -> [u32; 8] {}
41+
pub fn sha256_compression(_input: INT_BLOCK, _state: STATE) -> STATE {}
1642

1743
// SHA-256 hash function
1844
#[no_predicates]
19-
pub fn digest<let N: u32>(msg: [u8; N]) -> [u8; 32] {
45+
pub fn digest<let N: u32>(msg: [u8; N]) -> HASH {
2046
sha256_var(msg, N as u64)
2147
}
2248

23-
// Convert 64-byte array to array of 16 u32s
24-
fn msg_u8_to_u32(msg: [u8; 64]) -> [u32; 16] {
25-
let mut msg32: [u32; 16] = [0; 16];
26-
27-
for i in 0..16 {
28-
let mut msg_field: Field = 0;
29-
for j in 0..4 {
30-
msg_field = msg_field * 256 + msg[64 - 4 * (i + 1) + j] as Field;
31-
}
32-
msg32[15 - i] = msg_field as u32;
33-
}
34-
35-
msg32
36-
}
37-
38-
unconstrained fn build_msg_block_iter<let N: u32>(
39-
msg: [u8; N],
40-
message_size: u32,
41-
msg_start: u32,
42-
) -> ([u8; 64], u32) {
43-
let mut msg_block: [u8; BLOCK_SIZE] = [0; BLOCK_SIZE];
44-
// We insert `BLOCK_SIZE` bytes (or up to the end of the message)
45-
let block_input = if msg_start + BLOCK_SIZE > message_size {
46-
if message_size < msg_start {
47-
// This function is sometimes called with `msg_start` past the end of the message.
48-
// In this case we return an empty block and zero pointer to signal that the result should be ignored.
49-
0
50-
} else {
51-
message_size - msg_start
52-
}
53-
} else {
54-
BLOCK_SIZE
55-
};
56-
for k in 0..block_input {
57-
msg_block[k] = msg[msg_start + k];
58-
}
59-
(msg_block, block_input)
60-
}
61-
62-
// Verify the block we are compressing was appropriately constructed
63-
fn verify_msg_block<let N: u32>(
64-
msg: [u8; N],
65-
message_size: u32,
66-
msg_block: [u8; 64],
67-
msg_start: u32,
68-
) -> u32 {
69-
let mut msg_byte_ptr: u32 = 0; // Message byte pointer
70-
let mut msg_end = msg_start + BLOCK_SIZE;
71-
if msg_end > N {
72-
msg_end = N;
73-
}
74-
75-
for k in msg_start..msg_end {
76-
if k < message_size {
77-
assert_eq(msg_block[msg_byte_ptr], msg[k]);
78-
msg_byte_ptr = msg_byte_ptr + 1;
79-
}
80-
}
81-
82-
msg_byte_ptr
83-
}
84-
85-
global BLOCK_SIZE = 64;
86-
8749
// Variable size SHA-256 hash
88-
pub fn sha256_var<let N: u32>(msg: [u8; N], message_size: u64) -> [u8; 32] {
50+
pub fn sha256_var<let N: u32>(msg: [u8; N], message_size: u64) -> HASH {
8951
let message_size = message_size as u32;
9052
let num_blocks = N / BLOCK_SIZE;
91-
let mut msg_block: [u8; BLOCK_SIZE] = [0; BLOCK_SIZE];
92-
let mut h: [u32; 8] = [
53+
let mut msg_block: MSG_BLOCK = [0; BLOCK_SIZE];
54+
let mut h: STATE = [
9355
1779033703, 3144134277, 1013904242, 2773480762, 1359893119, 2600822924, 528734635,
9456
1541459225,
9557
]; // Intermediate hash, starting with the canonical initial value
9658
let mut msg_byte_ptr = 0; // Pointer into msg_block
9759
for i in 0..num_blocks {
9860
let msg_start = BLOCK_SIZE * i;
9961
let (new_msg_block, new_msg_byte_ptr) =
100-
unsafe { build_msg_block_iter(msg, message_size, msg_start) };
62+
unsafe { build_msg_block(msg, message_size, msg_start) };
10163
if msg_start < message_size {
10264
msg_block = new_msg_block;
10365
}
@@ -126,7 +88,7 @@ pub fn sha256_var<let N: u32>(msg: [u8; N], message_size: u64) -> [u8; 32] {
12688
if modulo != 0 {
12789
let msg_start = BLOCK_SIZE * num_blocks;
12890
let (new_msg_block, new_msg_byte_ptr) =
129-
unsafe { build_msg_block_iter(msg, message_size, msg_start) };
91+
unsafe { build_msg_block(msg, message_size, msg_start) };
13092

13193
if msg_start < message_size {
13294
msg_block = new_msg_block;
@@ -136,116 +98,168 @@ pub fn sha256_var<let N: u32>(msg: [u8; N], message_size: u64) -> [u8; 32] {
13698
let new_msg_byte_ptr = verify_msg_block(msg, message_size, msg_block, msg_start);
13799
if msg_start < message_size {
138100
msg_byte_ptr = new_msg_byte_ptr;
101+
verify_msg_block_padding(msg_block, msg_byte_ptr);
139102
}
140103
} else if msg_start < message_size {
141104
msg_byte_ptr = new_msg_byte_ptr;
142105
}
143106
}
144107

108+
// If we had modulo == 0 then it means the last block was full,
109+
// and we can reset the pointer to zero to overwrite it.
145110
if msg_byte_ptr == BLOCK_SIZE {
146111
msg_byte_ptr = 0;
147112
}
148113

149-
// This variable is used to get around the compiler under-constrained check giving a warning.
150-
// We want to check against a constant zero, but if it does not come from the circuit inputs
151-
// or return values the compiler check will issue a warning.
152-
let zero = msg_block[0] - msg_block[0];
153-
154114
// Pad the rest such that we have a [u32; 2] block at the end representing the length
155115
// of the message, and a block of 1 0 ... 0 following the message (i.e. [1 << 7, 0, ..., 0]).
116+
// Here we rely on the fact that everything beyond the available input is set to 0.
156117
msg_block[msg_byte_ptr] = 1 << 7;
157118
let last_block = msg_block;
158119
msg_byte_ptr = msg_byte_ptr + 1;
159120

160-
unsafe {
161-
let (new_msg_block, new_msg_byte_ptr) = pad_msg_block(msg_block, msg_byte_ptr);
162-
msg_block = new_msg_block;
163-
if crate::runtime::is_unconstrained() {
164-
msg_byte_ptr = new_msg_byte_ptr;
165-
}
121+
// If we don't have room to write the size, compress the block and reset it.
122+
if msg_byte_ptr > MSG_SIZE_PTR {
123+
h = sha256_compression(msg_u8_to_u32(msg_block), h);
124+
// `attach_len_to_msg_block` will zero out everything from `msg_byte_ptr` up to `MSG_SIZE_PTR`.
125+
msg_byte_ptr = 0;
166126
}
167127

168-
if !crate::runtime::is_unconstrained() {
169-
for i in 0..BLOCK_SIZE {
170-
assert_eq(msg_block[i], last_block[i]);
171-
}
128+
msg_block = unsafe { attach_len_to_msg_block(msg_block, msg_byte_ptr, message_size) };
172129

173-
// If i >= 57, there aren't enough bits in the current message block to accomplish this, so
174-
// the 1 and 0s fill up the current block, which we then compress accordingly.
175-
// Not enough bits (64) to store length. Fill up with zeros.
176-
for _i in 57..BLOCK_SIZE {
177-
if msg_byte_ptr <= 63 & msg_byte_ptr >= 57 {
178-
assert_eq(msg_block[msg_byte_ptr], zero);
179-
msg_byte_ptr += 1;
180-
}
181-
}
130+
if !crate::runtime::is_unconstrained() {
131+
verify_msg_len(msg_block, last_block, msg_byte_ptr, message_size);
182132
}
183133

184-
if msg_byte_ptr >= 57 {
185-
h = sha256_compression(msg_u8_to_u32(msg_block), h);
134+
hash_final_block(msg_block, h)
135+
}
186136

187-
msg_byte_ptr = 0;
137+
// Convert 64-byte array to array of 16 u32s
138+
fn msg_u8_to_u32(msg: MSG_BLOCK) -> INT_BLOCK {
139+
let mut msg32: INT_BLOCK = [0; INT_BLOCK_SIZE];
140+
141+
for i in 0..INT_BLOCK_SIZE {
142+
let mut msg_field: Field = 0;
143+
for j in 0..4 {
144+
msg_field = msg_field * 256 + msg[64 - 4 * (i + 1) + j] as Field;
145+
}
146+
msg32[15 - i] = msg_field as u32;
188147
}
189148

190-
msg_block = unsafe { attach_len_to_msg_block(msg_block, msg_byte_ptr, message_size) };
149+
msg32
150+
}
191151

192-
if !crate::runtime::is_unconstrained() {
193-
for i in 0..56 {
194-
let predicate = (i < msg_byte_ptr) as u8;
195-
let expected_byte = predicate * last_block[i];
196-
assert_eq(msg_block[i], expected_byte);
152+
// Take `BLOCK_SIZE` number of bytes from `msg` starting at `msg_start`.
153+
// Returns the block and the length that has been copied rather than padded with zeroes.
154+
unconstrained fn build_msg_block<let N: u32>(
155+
msg: [u8; N],
156+
message_size: u32,
157+
msg_start: u32,
158+
) -> (MSG_BLOCK, BLOCK_BYTE_PTR) {
159+
let mut msg_block: MSG_BLOCK = [0; BLOCK_SIZE];
160+
// We insert `BLOCK_SIZE` bytes (or up to the end of the message)
161+
let block_input = if msg_start + BLOCK_SIZE > message_size {
162+
if message_size < msg_start {
163+
// This function is sometimes called with `msg_start` past the end of the message.
164+
// In this case we return an empty block and zero pointer to signal that the result should be ignored.
165+
0
166+
} else {
167+
message_size - msg_start
197168
}
169+
} else {
170+
BLOCK_SIZE
171+
};
172+
for k in 0..block_input {
173+
msg_block[k] = msg[msg_start + k];
174+
}
175+
(msg_block, block_input)
176+
}
177+
178+
// Verify the block we are compressing was appropriately constructed by `build_msg_block`
179+
// and matches the input data. Returns the index of the first unset item.
180+
fn verify_msg_block<let N: u32>(
181+
msg: [u8; N],
182+
message_size: u32,
183+
msg_block: MSG_BLOCK,
184+
msg_start: u32,
185+
) -> BLOCK_BYTE_PTR {
186+
let mut msg_byte_ptr: u32 = 0; // Message byte pointer
187+
let mut msg_end = msg_start + BLOCK_SIZE;
188+
if msg_end > N {
189+
msg_end = N;
190+
}
198191

199-
// We verify the message length was inserted correctly by reversing the byte decomposition.
200-
let len = 8 * message_size;
201-
let mut reconstructed_len: Field = 0;
202-
for i in 56..64 {
203-
reconstructed_len = 256 * reconstructed_len + msg_block[i] as Field;
192+
for k in msg_start..msg_end {
193+
if k < message_size {
194+
assert_eq(msg_block[msg_byte_ptr], msg[k]);
195+
msg_byte_ptr = msg_byte_ptr + 1;
204196
}
205-
assert_eq(reconstructed_len, len as Field);
206197
}
207198

208-
hash_final_block(msg_block, h)
199+
msg_byte_ptr
209200
}
210201

211-
unconstrained fn pad_msg_block(
212-
mut msg_block: [u8; 64],
213-
mut msg_byte_ptr: u32,
214-
) -> ([u8; BLOCK_SIZE], u32) {
215-
// If i >= 57, there aren't enough bits in the current message block to accomplish this, so
216-
// the 1 and 0s fill up the current block, which we then compress accordingly.
217-
if msg_byte_ptr >= 57 {
218-
// Not enough bits (64) to store length. Fill up with zeros.
219-
for i in msg_byte_ptr..BLOCK_SIZE {
220-
msg_block[i] = 0;
202+
// Verify the block we are compressing was appropriately padded with zeroes by `build_msg_block`.
203+
// This is only relevant for the last, potentially partially filled block.
204+
fn verify_msg_block_padding(msg_block: MSG_BLOCK, msg_byte_ptr: BLOCK_BYTE_PTR) {
205+
// This variable is used to get around the compiler under-constrained check giving a warning.
206+
// We want to check against a constant zero, but if it does not come from the circuit inputs
207+
// or return values the compiler check will issue a warning.
208+
let zero = msg_block[0] - msg_block[0];
209+
210+
for i in 0..BLOCK_SIZE {
211+
if i >= msg_byte_ptr {
212+
assert_eq(msg_block[i], zero);
221213
}
222-
(msg_block, BLOCK_SIZE)
223-
} else {
224-
(msg_block, msg_byte_ptr)
225214
}
226215
}
227216

217+
// Zero out all bytes between the end of the message and where the length is appended,
218+
// then write the length into the last 8 bytes of the block.
228219
unconstrained fn attach_len_to_msg_block(
229-
mut msg_block: [u8; BLOCK_SIZE],
230-
msg_byte_ptr: u32,
220+
mut msg_block: MSG_BLOCK,
221+
msg_byte_ptr: BLOCK_BYTE_PTR,
231222
message_size: u32,
232-
) -> [u8; BLOCK_SIZE] {
223+
) -> MSG_BLOCK {
233224
// We assume that `msg_byte_ptr` is less than 57 because if not then it is reset to zero before calling this function.
234225
// In any case, fill the block with zeros up to the last 64 bits (i.e. until msg_byte_ptr = 56).
235-
for i in msg_byte_ptr..56 {
226+
for i in msg_byte_ptr..MSG_SIZE_PTR {
236227
msg_block[i] = 0;
237228
}
238229

239230
let len = 8 * message_size;
240231
let len_bytes: [u8; 8] = (len as Field).to_be_bytes();
241232
for i in 0..8 {
242-
msg_block[56 + i] = len_bytes[i];
233+
msg_block[MSG_SIZE_PTR + i] = len_bytes[i];
243234
}
244235
msg_block
245236
}
246237

247-
fn hash_final_block(msg_block: [u8; BLOCK_SIZE], mut state: [u32; 8]) -> [u8; 32] {
248-
let mut out_h: [u8; 32] = [0; 32]; // Digest as sequence of bytes
238+
// Verify that the message length was correctly written by `attach_len_to_msg_block`.
239+
fn verify_msg_len(
240+
msg_block: MSG_BLOCK,
241+
last_block: MSG_BLOCK,
242+
msg_byte_ptr: BLOCK_BYTE_PTR,
243+
message_size: u32,
244+
) {
245+
for i in 0..MSG_SIZE_PTR {
246+
let predicate = (i < msg_byte_ptr) as u8;
247+
let expected_byte = predicate * last_block[i];
248+
assert_eq(msg_block[i], expected_byte);
249+
}
250+
251+
// We verify the message length was inserted correctly by reversing the byte decomposition.
252+
let len = 8 * message_size;
253+
let mut reconstructed_len: Field = 0;
254+
for i in MSG_SIZE_PTR..BLOCK_SIZE {
255+
reconstructed_len = 256 * reconstructed_len + msg_block[i] as Field;
256+
}
257+
assert_eq(reconstructed_len, len as Field);
258+
}
259+
260+
// Perform the final compression, then transform the `STATE` into `HASH`.
261+
fn hash_final_block(msg_block: MSG_BLOCK, mut state: STATE) -> HASH {
262+
let mut out_h: HASH = [0; 32]; // Digest as sequence of bytes
249263
// Hash final padded block
250264
state = sha256_compression(msg_u8_to_u32(msg_block), state);
251265

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
[package]
2+
name = "bench_sha256_long"
3+
version = "0.1.0"
4+
type = "bin"
5+
authors = [""]
6+
7+
[dependencies]

0 commit comments

Comments
 (0)