Skip to content

Commit cacdf74

Browse files
Merge pull request #322 from SciML/fm/show
feat: add nice show for layers and models
2 parents bbfe771 + 397001b commit cacdf74

File tree

8 files changed

+182
-16
lines changed

8 files changed

+182
-16
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ReservoirComputing"
22
uuid = "7c2d2b1e-3dd4-11ea-355a-8f6a8116e294"
33
authors = ["Francesco Martinuzzi"]
4-
version = "0.12.2"
4+
version = "0.12.3"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/layers/basic.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,32 @@ function (dl::DelayLayer)(inp::AbstractVecOrMat, ps, st::NamedTuple)
338338
return inp_with_delay, (history = history, clock = clock, rng = st.rng)
339339
end
340340

341+
function Base.show(io::IO, dl::DelayLayer)
342+
print(io, "DelayLayer(", dl.in_dims, "; num_delays=", dl.num_delays)
343+
344+
if dl.stride != 1
345+
print(io, ", stride=", dl.stride)
346+
end
347+
348+
inits = dl.init_delay
349+
if !isempty(inits)
350+
all_same = all(f -> f === first(inits), inits)
351+
if all_same && first(inits) === zeros32
352+
elseif all_same
353+
print(io, ", init_delay=", first(inits))
354+
else
355+
print(io, ", init_delay=(")
356+
for (i, f) in enumerate(inits)
357+
i > 1 && print(io, ", ")
358+
show(io, f)
359+
end
360+
print(io, ")")
361+
end
362+
end
363+
364+
print(io, ")")
365+
end
366+
341367
@doc raw"""
342368
NonlinearFeaturesLayer(features...; include_input=true)
343369
@@ -417,3 +443,27 @@ function (nfl::NonlinearFeaturesLayer)(inp::AbstractVector, ps, st)
417443

418444
return out, st
419445
end
446+
447+
function Base.show(io::IO, nfl::NonlinearFeaturesLayer)
448+
print(io, "NonlinearFeaturesLayer(")
449+
450+
if isempty(nfl.features)
451+
print(io, "features=()")
452+
else
453+
print(io, "features=(")
454+
for (i, f) in enumerate(nfl.features)
455+
i > 1 && print(io, ", ")
456+
show(io, f)
457+
end
458+
print(io, ")")
459+
end
460+
461+
inc = ReservoirComputing.known(nfl.include_input)
462+
if inc === true
463+
print(io, ", include_input=true")
464+
elseif inc === false
465+
print(io, ", include_input=false")
466+
end
467+
468+
print(io, ")")
469+
end

src/layers/esn_cell.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ end
8181
function ESNCell((in_dims, out_dims)::Pair{<:IntegerType, <:IntegerType},
8282
activation = tanh; use_bias::BoolType = False(), init_bias = zeros32,
8383
init_reservoir = rand_sparse, init_input = scaled_rand,
84-
init_state = randn32, leak_coefficient = 1.0)
84+
init_state = randn32, leak_coefficient::AbstractFloat = 1.0)
8585
return ESNCell(activation, in_dims, out_dims, init_bias, init_reservoir,
8686
init_input, init_state, leak_coefficient, use_bias)
8787
end

src/layers/lux_layers.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ function applyrecurrentcell(sl::AbstractReservoirRecurrentCell, inp, ps, st, ::N
4545
return apply(sl, inp, ps, st)
4646
end
4747

48+
function Base.show(io::IO, sl::StatefulLayer)
49+
print(io, "StatefulLayer(")
50+
show(io, sl.cell)
51+
print(io, ")")
52+
end
53+
4854
@doc raw"""
4955
ReservoirChain(layers...; name=nothing)
5056
ReservoirChain(xs::AbstractVector; name=nothing)

src/models/esn.jl

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,5 +102,32 @@ function ESN(in_dims::IntegerType, res_dims::IntegerType,
102102
Tuple(state_modifiers) : (state_modifiers,)
103103
mods = _wrap_layers(mods_tuple)
104104
ro = LinearReadout(res_dims => out_dims, readout_activation)
105-
return ReservoirComputer(cell, mods, ro)
105+
return ESN(cell, mods, ro)
106+
end
107+
108+
function Base.show(io::IO, esn::ESN)
109+
print(io, "ESN(\n")
110+
111+
print(io, " reservoir = ")
112+
show(io, esn.reservoir)
113+
print(io, ",\n")
114+
115+
print(io, " state_modifiers = ")
116+
if isempty(esn.states_modifiers)
117+
print(io, "()")
118+
else
119+
print(io, "(")
120+
for (i, m) in enumerate(esn.states_modifiers)
121+
i > 1 && print(io, ", ")
122+
show(io, m)
123+
end
124+
print(io, ")")
125+
end
126+
print(io, ",\n")
127+
128+
print(io, " readout = ")
129+
show(io, esn.readout)
130+
print(io, "\n)")
131+
132+
return
106133
end

src/models/esn_delay.jl

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,5 +111,32 @@ function DelayESN(
111111
ro_in_dims = res_dims * (num_delays + 1)
112112
ro = LinearReadout(ro_in_dims => out_dims, readout_activation)
113113

114-
return ReservoirComputer(cell, mods, ro)
114+
return DelayESN(cell, mods, ro)
115+
end
116+
117+
function Base.show(io::IO, esn::DelayESN)
118+
print(io, "DelayESN(\n")
119+
120+
print(io, " reservoir = ")
121+
show(io, esn.reservoir)
122+
print(io, ",\n")
123+
124+
print(io, " state_modifiers = ")
125+
if isempty(esn.states_modifiers)
126+
print(io, "()")
127+
else
128+
print(io, "(")
129+
for (i, m) in enumerate(esn.states_modifiers)
130+
i > 1 && print(io, ", ")
131+
show(io, m)
132+
end
133+
print(io, ")")
134+
end
135+
print(io, ",\n")
136+
137+
print(io, " readout = ")
138+
show(io, esn.readout)
139+
print(io, "\n)")
140+
141+
return
115142
end

src/models/esn_hybrid.jl

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ throughout the reservoir and readout computations.
5656
## Parameters
5757
5858
- `knowledge_model` — parameters of the knowledge model `km`.
59-
- `cell` — parameters of the internal [`ESNCell`](@ref), including:
59+
- `reservoir` — parameters of the internal [`ESNCell`](@ref), including:
6060
- `input_matrix :: (res_dims × (in_dims + km_dims))` — `W_in`
6161
- `reservoir_matrix :: (res_dims × res_dims)` — `W_res`
6262
- `bias :: (res_dims,)` — present only if `use_bias=true`
@@ -73,13 +73,13 @@ throughout the reservoir and readout computations.
7373
Created by `initialstates(rng, hesn)`:
7474
7575
- `knowledge_model` — states for the internal knowledge model.
76-
- `cell` — states for the internal [`ESNCell`](@ref).
76+
- `reservoir` — states for the internal [`ESNCell`](@ref).
7777
- `states_modifiers` — a `Tuple` with states for each modifier layer.
7878
- `readout` — states for [`LinearReadout`](@ref).
7979
"""
8080
@concrete struct HybridESN <: AbstractEchoStateNetwork{(
81-
:cell, :states_modifiers, :readout, :knowledge_model)}
82-
cell
81+
:reservoir, :states_modifiers, :readout, :knowledge_model)}
82+
reservoir
8383
knowledge_model
8484
states_modifiers
8585
readout
@@ -94,46 +94,78 @@ function HybridESN(km,
9494
include_collect::BoolType = True(),
9595
kwargs...)
9696
esn_inp_size = in_dims + km_dims
97-
cell = StatefulLayer(ESNCell(esn_inp_size => res_dims, activation; kwargs...))
97+
reservoir = StatefulLayer(ESNCell(esn_inp_size => res_dims, activation; kwargs...))
9898
mods_tuple = state_modifiers isa Tuple || state_modifiers isa AbstractVector ?
9999
Tuple(state_modifiers) : (state_modifiers,)
100100
mods = _wrap_layers(mods_tuple)
101101
ro = LinearReadout(res_dims + km_dims => out_dims, readout_activation;
102102
include_collect = static(include_collect))
103103
km_layer = km isa WrappedFunction ? km : WrappedFunction(km)
104-
return HybridESN(cell, km_layer, mods, ro)
104+
return HybridESN(reservoir, km_layer, mods, ro)
105105
end
106106

107107
function initialparameters(rng::AbstractRNG, hesn::HybridESN)
108-
ps_cell = initialparameters(rng, hesn.cell)
108+
ps_reservoir = initialparameters(rng, hesn.reservoir)
109109
ps_km = initialparameters(rng, hesn.knowledge_model)
110110
ps_mods = map(l -> initialparameters(rng, l), hesn.states_modifiers) |> Tuple
111111
ps_ro = initialparameters(rng, hesn.readout)
112-
return (cell = ps_cell, knowledge_model = ps_km,
112+
return (reservoir = ps_reservoir, knowledge_model = ps_km,
113113
states_modifiers = ps_mods, readout = ps_ro)
114114
end
115115

116116
function initialstates(rng::AbstractRNG, hesn::HybridESN)
117-
st_cell = initialstates(rng, hesn.cell)
117+
st_reservoir = initialstates(rng, hesn.reservoir)
118118
st_km = initialstates(rng, hesn.knowledge_model)
119119
st_mods = map(l -> initialstates(rng, l), hesn.states_modifiers) |> Tuple
120120
st_ro = initialstates(rng, hesn.readout)
121-
return (cell = st_cell, knowledge_model = st_km,
121+
return (reservoir = st_reservoir, knowledge_model = st_km,
122122
states_modifiers = st_mods, readout = st_ro)
123123
end
124124

125125
function _partial_apply(hesn::HybridESN, inp, ps, st)
126126
k_t, st_km = hesn.knowledge_model(inp, ps.knowledge_model, st.knowledge_model)
127127
xin = vcat(k_t, inp)
128-
r, st_cell = apply(hesn.cell, xin, ps.cell, st.cell)
128+
r, st_reservoir = apply(hesn.reservoir, xin, ps.reservoir, st.reservoir)
129129
rstar,
130130
st_mods = _apply_seq(hesn.states_modifiers, r, ps.states_modifiers, st.states_modifiers)
131131
feats = vcat(k_t, rstar)
132-
return feats, (cell = st_cell, states_modifiers = st_mods, knowledge_model = st_km)
132+
return feats,
133+
(reservoir = st_reservoir, states_modifiers = st_mods, knowledge_model = st_km)
133134
end
134135

135136
function (hesn::HybridESN)(inp, ps, st)
136137
feats, new_st = _partial_apply(hesn, inp, ps, st)
137138
y, st_ro = apply(hesn.readout, feats, ps.readout, st.readout)
138139
return y, merge(new_st, (readout = st_ro,))
139140
end
141+
142+
function Base.show(io::IO, esn::HybridESN)
143+
print(io, "HybridESN(\n")
144+
145+
print(io, " reservoir = ")
146+
show(io, esn.reservoir)
147+
print(io, ",\n")
148+
149+
print(io, " knowledge_model = ")
150+
show(io, esn.knowledge_model)
151+
print(io, ",\n")
152+
153+
print(io, " state_modifiers = ")
154+
if isempty(esn.states_modifiers)
155+
print(io, "()")
156+
else
157+
print(io, "(")
158+
for (i, m) in enumerate(esn.states_modifiers)
159+
i > 1 && print(io, ", ")
160+
show(io, m)
161+
end
162+
print(io, ")")
163+
end
164+
print(io, ",\n")
165+
166+
print(io, " readout = ")
167+
show(io, esn.readout)
168+
print(io, "\n)")
169+
170+
return
171+
end

src/reservoircomputer.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,30 @@ function addreadout!(::AbstractReservoirComputer, output_matrix::AbstractMatrix,
9797
return merge(ps, (readout = new_readout,)), st
9898
end
9999

100+
function Base.show(io::IO, rc::ReservoirComputer)
101+
print(io, "ReservoirComputer(")
102+
103+
print(io, "reservoir = ")
104+
show(io, rc.reservoir)
105+
106+
nmods = length(rc.states_modifiers)
107+
if nmods == 0
108+
print(io, ", state_modifiers = ()")
109+
else
110+
print(io, ", state_modifiers = (")
111+
for (i, m) in enumerate(rc.states_modifiers)
112+
i > 1 && print(io, ", ")
113+
show(io, m)
114+
end
115+
print(io, ")")
116+
end
117+
118+
print(io, ", readout = ")
119+
show(io, rc.readout)
120+
121+
print(io, ")")
122+
end
123+
100124
@doc raw"""
101125
resetcarry!(rng, rc::ReservoirComputer, st; init_carry=nothing)
102126
resetcarry!(rng, rc::ReservoirComputer, ps, st; init_carry=nothing)

0 commit comments

Comments
 (0)