@@ -104,27 +104,13 @@ end
104104 # By linearity: 0.5*(gtmp1+gtmp2)*W.dW == 0.5*(gtmp2*W.dW) + 0.5*(gtmp1*W.dW).
105105 # Avoid forming (gtmp1 + gtmp2), which would allocate a temporary SparseMatrixCSC.
106106 # Use 5-arg mul! to accumulate directly into the cached vector (allocation-free).
107- _eh_accum_stage2 ! (nrtmp, gtmp2, W. dW)
107+ mul ! (nrtmp, gtmp2, W. dW, convert ( eltype (nrtmp), 0.5 ), convert ( eltype (nrtmp), 0.5 ) )
108108 end
109109
110110 dto2 = dt / 2
111111 @. . u = uprev + dto2 * (ftmp1 + ftmp2) + nrtmp
112112end
113113
114- @inline function _eh_accum_stage2! (y, g2, dW)
115- mul! (y, g2, dW, convert (eltype (y), 0.5 ), convert (eltype (y), 0.5 ))
116- return nothing
117- end
118-
119- @inline function _eh_accum_stage2! (
120- y:: StridedVector{T} ,
121- g2:: StridedMatrix{T} ,
122- dW:: StridedVector{T}
123- ) where {T <: LinearAlgebra.BlasFloat }
124- LinearAlgebra. BLAS. gemv! (' N' , T (0.5 ), g2, dW, T (0.5 ), y)
125- return nothing
126- end
127-
128114@muladd function perform_step! (integrator, cache:: RandomEMConstantCache )
129115 @unpack t, dt, uprev, u, W, p, f = integrator
130116 u = uprev .+ dt .* integrator. f (uprev, p, t, W. curW)
0 commit comments