Skip to content

Commit 06d744a

Browse files
committed
Rework internals of transform!(::SHA3_CTX) to use tuples
1 parent 4451e13 commit 06d744a

File tree

4 files changed

+64
-92
lines changed

4 files changed

+64
-92
lines changed

src/constants.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -138,13 +138,13 @@ const SHA3_ROUND_CONSTS = UInt64[
138138
]
139139

140140
# Rotation constants for SHA3 rounds
141-
const SHA3_ROTC = UInt64[
142-
1, 3, 6, 10, 15, 21, 28, 36, 45, 55, 2, 14,
143-
27, 41, 56, 8, 25, 43, 62, 18, 39, 61, 20, 44
144-
]
141+
const SHA3_ROTC = (
142+
0, 1, 62, 28, 27, 36, 44, 6, 55, 20, 3, 10, 43,
143+
25, 39, 41, 45, 15, 21, 8, 18, 2, 61, 56, 14,
144+
)
145145

146146
# Permutation indices for SHA3 rounds (+1'ed so as to work with julia's 1-based indexing)
147-
const SHA3_PILN = Int[
148-
11, 8, 12, 18, 19, 4, 6, 17, 9, 22, 25, 5,
149-
16, 24, 20, 14, 13, 3, 21, 15, 23, 10, 7, 2
150-
]
147+
const SHA3_PILN = (
148+
1, 7, 13, 19, 25, 4, 10, 11, 17, 23, 2, 8, 14,
149+
20, 21, 5, 6, 12, 18, 24, 3, 9, 15, 16, 22,
150+
)

src/sha3.jl

Lines changed: 35 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,47 @@
1+
@inline function keccak_theta(state::NTuple{25,UInt64})
2+
C = ntuple(i -> state[i] state[i + 5] state[i + 10] state[i + 15] state[i + 20], Val(5))
3+
D = ntuple(i -> C[rem(i + 3, 5) + 1] L64(1, C[rem(i, 5) + 1]), Val(5))
4+
return ntuple(k -> state[k] D[rem(k - 1, 5) + 1], Val(25))
5+
end
6+
7+
@inline keccak_rho(state::NTuple{25,UInt64}) =
8+
ntuple(k -> bitrotate(state[k], SHA3_ROTC[k]), Val(25))
9+
10+
@inline keccak_pi(state::NTuple{25,UInt64}) =
11+
ntuple(k -> state[SHA3_PILN[k]], Val(25))
12+
13+
@inline function keccak_chi(state::NTuple{25,UInt64})
14+
return ntuple(
15+
k -> let j = k - rem(k - 1, 5)
16+
state[k] (~state[rem(k, 5) + j] & state[rem(k + 1, 5) + j])
17+
end,
18+
Val(25)
19+
)
20+
end
21+
22+
@inline keccak_iota(round, state::NTuple{25,UInt64}) =
23+
(state[1] SHA3_ROUND_CONSTS[round+1], state[2:end]...)
24+
125
function transform!(context::T) where {T<:SHA3_CTX}
226
# First, update state with buffer
327
pbuf = Ptr{eltype(context.state)}(pointer(context.buffer))
428
for idx in 1:div(blocklen(T),8)
529
context.state[idx] = context.state[idx] unsafe_load(pbuf, idx)
630
end
7-
bc = context.bc
8-
state = context.state
9-
10-
# We always assume 24 rounds
11-
@inbounds for round in 0:23
12-
# Theta function
13-
for i in 1:5
14-
bc[i] = state[i] state[i + 5] state[i + 10] state[i + 15] state[i + 20]
15-
end
16-
17-
for i in 0:4
18-
temp = bc[rem(i + 4, 5) + 1] L64(1, bc[rem(i + 1, 5) + 1])
19-
j = 0
20-
while j <= 20
21-
state[Int(i + j + 1)] = state[i + j + 1] temp
22-
j += 5
23-
end
24-
end
2531

26-
# Rho Pi
27-
temp = state[2]
28-
for i in 1:24
29-
j = SHA3_PILN[i]
30-
bc[1] = state[j]
31-
state[j] = L64(SHA3_ROTC[i], temp)
32-
temp = bc[1]
33-
end
32+
state = let s = context.state; ntuple(i -> s[i], Val(25)); end
3433

35-
# Chi
36-
j = 0
37-
while j <= 20
38-
for i in 1:5
39-
bc[i] = state[i + j]
40-
end
41-
for i in 0:4
42-
state[j + i + 1] = state[j + i + 1] (~bc[rem(i + 1, 5) + 1] & bc[rem(i + 2, 5) + 1])
43-
end
44-
j += 5
45-
end
34+
# We always assume 24 rounds
35+
for round in 0:23
36+
state = keccak_theta(state)
37+
state = keccak_rho(state)
38+
state = keccak_pi(state)
39+
state = keccak_chi(state)
40+
state = keccak_iota(round, state)
41+
end
4642

47-
# Iota
48-
state[1] = state[1] SHA3_ROUND_CONSTS[round+1]
43+
for k in 1:25
44+
context.state[k] = state[k]
4945
end
5046

5147
return context.state

src/shake.jl

Lines changed: 16 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,12 @@ mutable struct SHAKE_128_CTX <: SHAKE
44
state::Vector{UInt64}
55
bytecount::UInt128
66
buffer::Vector{UInt8}
7-
bc::Vector{UInt64}
87
used::Bool
98
end
109
mutable struct SHAKE_256_CTX <: SHAKE
1110
state::Vector{UInt64}
1211
bytecount::UInt128
1312
buffer::Vector{UInt8}
14-
bc::Vector{UInt64}
1513
used::Bool
1614
end
1715

@@ -22,8 +20,8 @@ blocklen(::Type{SHAKE_256_CTX}) = UInt64(25*8 - 2*digestlen(SHAKE_256_CTX))
2220
buffer_pointer(ctx::T) where {T<:SHAKE} = Ptr{state_type(T)}(pointer(ctx.buffer))
2321

2422
# construct an empty SHA context
25-
SHAKE_128_CTX() = SHAKE_128_CTX(zeros(UInt64, 25), 0, zeros(UInt8, blocklen(SHAKE_128_CTX)), Vector{UInt64}(undef, 5), false)
26-
SHAKE_256_CTX() = SHAKE_256_CTX(zeros(UInt64, 25), 0, zeros(UInt8, blocklen(SHAKE_256_CTX)), Vector{UInt64}(undef, 5), false)
23+
SHAKE_128_CTX() = SHAKE_128_CTX(zeros(UInt64, 25), 0, zeros(UInt8, blocklen(SHAKE_128_CTX)), false)
24+
SHAKE_256_CTX() = SHAKE_256_CTX(zeros(UInt64, 25), 0, zeros(UInt8, blocklen(SHAKE_256_CTX)), false)
2725

2826
function transform!(context::T) where {T<:SHAKE}
2927
# First, update state with buffer
@@ -34,40 +32,22 @@ function transform!(context::T) where {T<:SHAKE}
3432
context.state[idx] = context.state[idx] unsafe_load(pbuf, idx)
3533
end
3634
end
37-
bc = context.bc
38-
state = context.state
35+
36+
state = let s = context.state; ntuple(i -> s[i], Val(25)); end
37+
3938
# We always assume 24 rounds
40-
@inbounds for round in 0:23
41-
# Theta function
42-
for i in 1:5
43-
bc[i] = state[i] state[i + 5] state[i + 10] state[i + 15] state[i + 20]
44-
end
45-
for i in 0:4
46-
temp = bc[rem(i + 4, 5) + 1] L64(1, bc[rem(i + 1, 5) + 1])
47-
for j in 0:5:20
48-
state[Int(i + j + 1)] = state[i + j + 1] temp
49-
end
50-
end
51-
# Rho Pi
52-
temp = state[2]
53-
for i in 1:24
54-
j = SHA3_PILN[i]
55-
bc[1] = state[j]
56-
state[j] = L64(SHA3_ROTC[i], temp)
57-
temp = bc[1]
58-
end
59-
# Chi
60-
for j in 0:5:20
61-
for i in 1:5
62-
bc[i] = state[i + j]
63-
end
64-
for i in 0:4
65-
state[j + i + 1] = state[j + i + 1] (~bc[rem(i + 1, 5) + 1] & bc[rem(i + 2, 5) + 1])
66-
end
67-
end
68-
# Iota
69-
state[1] = state[1] SHA3_ROUND_CONSTS[round+1]
39+
for round in 0:23
40+
state = keccak_theta(state)
41+
state = keccak_rho(state)
42+
state = keccak_pi(state)
43+
state = keccak_chi(state)
44+
state = keccak_iota(round, state)
45+
end
46+
47+
for k in 1:25
48+
context.state[k] = state[k]
7049
end
50+
7151
return context.state
7252
end
7353
function digest!(context::T,d::UInt,p::Ptr{UInt8}) where {T<:SHAKE}

src/types.jl

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -72,28 +72,24 @@ mutable struct SHA3_224_CTX <: SHA3_CTX
7272
state::Vector{UInt64}
7373
bytecount::UInt128
7474
buffer::Vector{UInt8}
75-
bc::Vector{UInt64}
7675
used::Bool
7776
end
7877
mutable struct SHA3_256_CTX <: SHA3_CTX
7978
state::Vector{UInt64}
8079
bytecount::UInt128
8180
buffer::Vector{UInt8}
82-
bc::Vector{UInt64}
8381
used::Bool
8482
end
8583
mutable struct SHA3_384_CTX <: SHA3_CTX
8684
state::Vector{UInt64}
8785
bytecount::UInt128
8886
buffer::Vector{UInt8}
89-
bc::Vector{UInt64}
9087
used::Bool
9188
end
9289
mutable struct SHA3_512_CTX <: SHA3_CTX
9390
state::Vector{UInt64}
9491
bytecount::UInt128
9592
buffer::Vector{UInt8}
96-
bc::Vector{UInt64}
9793
used::Bool
9894
end
9995

@@ -189,25 +185,25 @@ SHA2_512_256_CTX() = SHA2_512_256_CTX(copy(SHA2_512_256_initial_hash_value), 0,
189185
190186
Construct an empty SHA3_224 context.
191187
"""
192-
SHA3_224_CTX() = SHA3_224_CTX(zeros(UInt64, 25), 0, zeros(UInt8, blocklen(SHA3_224_CTX)), Vector{UInt64}(undef, 5), false)
188+
SHA3_224_CTX() = SHA3_224_CTX(zeros(UInt64, 25), 0, zeros(UInt8, blocklen(SHA3_224_CTX)), false)
193189
"""
194190
SHA3_256_CTX()
195191
196192
Construct an empty SHA3_256 context.
197193
"""
198-
SHA3_256_CTX() = SHA3_256_CTX(zeros(UInt64, 25), 0, zeros(UInt8, blocklen(SHA3_256_CTX)), Vector{UInt64}(undef, 5), false)
194+
SHA3_256_CTX() = SHA3_256_CTX(zeros(UInt64, 25), 0, zeros(UInt8, blocklen(SHA3_256_CTX)), false)
199195
"""
200196
SHA3_384_CTX()
201197
202198
Construct an empty SHA3_384 context.
203199
"""
204-
SHA3_384_CTX() = SHA3_384_CTX(zeros(UInt64, 25), 0, zeros(UInt8, blocklen(SHA3_384_CTX)), Vector{UInt64}(undef, 5), false)
200+
SHA3_384_CTX() = SHA3_384_CTX(zeros(UInt64, 25), 0, zeros(UInt8, blocklen(SHA3_384_CTX)), false)
205201
"""
206202
SHA3_512_CTX()
207203
208204
Construct an empty SHA3_512 context.
209205
"""
210-
SHA3_512_CTX() = SHA3_512_CTX(zeros(UInt64, 25), 0, zeros(UInt8, blocklen(SHA3_512_CTX)), Vector{UInt64}(undef, 5), false)
206+
SHA3_512_CTX() = SHA3_512_CTX(zeros(UInt64, 25), 0, zeros(UInt8, blocklen(SHA3_512_CTX)), false)
211207

212208
# SHA1 is special; he needs extra workspace
213209
"""
@@ -221,7 +217,7 @@ SHA1_CTX() = SHA1_CTX(copy(SHA1_initial_hash_value), 0, zeros(UInt8, blocklen(SH
221217
# Copy functions
222218
copy(ctx::T) where {T<:SHA1_CTX} = T(copy(ctx.state), ctx.bytecount, copy(ctx.buffer), copy(ctx.W), ctx.used)
223219
copy(ctx::T) where {T<:SHA2_CTX} = T(copy(ctx.state), ctx.bytecount, copy(ctx.buffer), ctx.used)
224-
copy(ctx::T) where {T<:SHA3_CTX} = T(copy(ctx.state), ctx.bytecount, copy(ctx.buffer), Vector{UInt64}(undef, 5), ctx.used)
220+
copy(ctx::T) where {T<:SHA3_CTX} = T(copy(ctx.state), ctx.bytecount, copy(ctx.buffer), ctx.used)
225221

226222

227223
# Make printing these types a little friendlier

0 commit comments

Comments
 (0)