Channel: First steps - JuliaLang

Gradient and update of custom struct with Flux


Hi everyone,

I have created a struct that I need to optimise with SGD. My struct is:

using SparseArrays
struct tdmat
    d
    dv
    dh
    s
end

function tdmat(d::AbstractVector, dv::AbstractVector, dh::AbstractVector, s::Int64, m::Int64)
    if length(dv) == m-1
        dvzeros = dv
    else
        dvzeros = zeros(m) 
        for j = 1 : Int64(m/s)
            dvzeros[s*(j-1)+1:s*j-1] = dv[(s-1)*(j-1)+1:(s-1)*j]
        end
        dvzeros = dvzeros[1:end-1]
    end

    tdmat(d, dvzeros, dh, s)
end

function constructTdmat(mat::tdmat)
    m = length(mat.d)

    rows = vcat(1:m, 1:m-1, 1:m-mat.s)
    cols = vcat(1:m, 2:m, mat.s+1:m)
    vals = vcat(mat.d, mat.dv[1:end], mat.dh[1:end])

    return sparse(rows, cols, vals)
end

It represents a sparse matrix with a special structure. I also had to define some operations:

function tdmat_mul(mat::tdmat, x::AbstractVector)
    m = length(mat.d)

    mx = mat.d.*x
    mx += vcat(mat.dv.*x[2:end], [0])
    mx += vcat(mat.dh.*x[mat.s+1:end], zeros(mat.s))
    return mx
end

import Base: *

function *(mat::tdmat, x::AbstractVector)
    @assert size(x,1)==length(mat.d)
    tdmat_mul(mat,x)
end

function *(a::Real, mat::tdmat)
    tdmat(a*mat.d, a*mat.dv, a*mat.dh, mat.s)
end

import Base: -
function -(mat1::tdmat, mat2::tdmat)
    if length(mat1.d) != length(mat2.d)
        error("Matrices have different size")
    end
    if mat1.s == mat2.s # same s -> the result is a tdmat
        return tdmat(mat1.d-mat2.d, mat1.dv-mat2.dv, mat1.dh - mat2.dh, mat1.s)
    else # different s -> the result is a sparse matrix
        return constructTdmat(mat1) - constructTdmat(mat2)
    end
end
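
For illustration, the layout can be checked on a small case. The sketch below is self-contained and uses plain vectors rather than the tdmat struct; the values are chosen to match T2 in the loss example further down (m = 6, s = 3), and the variable names are only for this example. It builds the same sparse matrix as constructTdmat and compares it with the diagonal-by-diagonal product used in tdmat_mul.

```julia
using SparseArrays

# Hypothetical stand-alone values: m = 6, s = 3
m, s = 6, 3
d  = collect(1.0:6.0)            # main diagonal, length m
dv = [2.0, 4.0, 0.0, 6.0, 8.0]   # first superdiagonal, length m-1
dh = [3.0, 6.0, 9.0]             # s-th superdiagonal, length m-s
x  = collect(1.0:6.0)

# Sparse matrix with the same (row, column, value) layout as constructTdmat
A = sparse(vcat(1:m, 1:m-1, 1:m-s), vcat(1:m, 2:m, s+1:m), vcat(d, dv, dh))

# Structured product, mirroring tdmat_mul one diagonal at a time
y = d .* x
y[1:m-1] .+= dv .* x[2:end]
y[1:m-s] .+= dh .* x[s+1:end]

# Both ways of computing the product should agree
@assert y ≈ A * x
```

The structured product touches only the three stored diagonals, so it avoids building the sparse matrix at all.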

I am using ChainRules to define the derivative of the product of my special matrix and a vector.

using ChainRules
function ChainRules.rrule(::typeof(*),A::tdmat,x::AbstractVector)
  function tdmat_multiply_pb(Δ)
    ΔA = Δ*x'
    Δx = Matrix(constructTdmat(A))'*Δ
    return (NO_FIELDS, ΔA, Δx)
  end
  
  return A*x, tdmat_multiply_pb
end

The loss function for testing is

T2 = tdmat(1:6, 2*(1:4), 3*(1:3), 3, 6)
u = 1:6
Topt = tdmat(1:6, (1:4), (1:3), 3, 6)
v = T2*u

loss(Topt)=sum(v.^2-Topt*u)

I also use

using Flux
Flux.@functor tdmat
function Flux.trainable(mat::tdmat)
    ps = (mat.d, mat.dv, mat.dh)
end

but I am not sure if it is necessary.
Then when I run

opts = Descent(0.1)
pars = Flux.params([Topt.d, Topt.dv, Topt.dh])
gs = Flux.gradient(()->loss(Topt),pars)
Flux.Optimise.update!(opts, pars, gs)

I get an error

ERROR: Only reference types can be differentiated with `Params`.
Stacktrace:
 [1] error(::String) at .\error.jl:33
 [2] getindex at C:\Users\anton\.julia\packages\Zygote\YeCEW\src\compiler\interface.jl:142 [inlined]
 [3] update!(::Descent, ::Zygote.Params, ::Zygote.Grads) at C:\Users\anton\.julia\packages\Flux\NpkMm\src\optimise\train.jl:28
 [4] top-level scope at REPL[330]:1

although the gradients were created and seem OK. What does it mean? I guess it has something to do with the s parameter of tdmat, but how should I fix it?
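
A possibly relevant observation, offered as an assumption rather than a confirmed diagnosis: the fields d and dh of Topt are built from ranges (1:6 and 1:3), which are immutable bits types, whereas Flux.params is normally given mutable arrays. The following check uses only Base and shows the difference; collect is one way to turn a range into an ordinary Vector.

```julia
# Ranges such as the ones passed to tdmat are immutable bits types
r = 1:6
@show isbits(r)              # true for UnitRange{Int64}

# collect produces a mutable Vector, which is not a bits type
v = collect(Float64, r)
@show isbits(v)              # false
@show v isa Vector{Float64}

# An in-place update, like the one an optimiser step performs,
# works on the Vector but would not be possible on the range
v .-= 0.1 .* ones(length(v))
```

Whether converting the fields this way is enough to make update! succeed with the custom rrule above is not verified here.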

Moreover, I am totally unable to figure out how to approach gradients of a structure. Why does gs[Topt] not work as usual?

Also, I noticed that when I run it with the struct xy (instance yx below), which does not give this error, the values are updated only if I use Flux.params([yx.x, yx.y]), not Flux.params(yx). Is there an easy way to update it without having to write out all the fields of the struct?
The code for the struct xy is

struct xy
    x
    y
end

function XY()
    x = [1.0]
    y = [25.0]
    xy(x, y)
end

f(p::xy) = sum(sin.(p.x) .+ p.y.^2)

yx = XY()

opts = Descent(0.1)
pars2 = Flux.params([yx.x, yx.y])
gs2 = Flux.gradient(() -> f(yx), pars2)
Flux.Optimise.update!(opts, pars2, gs2)

Apart from that, I noticed that when I use a different loss function, f(p::xy) = sum(sin.(p.x) .+ p.y), where p.y isn't squared, gradient returns

:(Main.yx) => (x = [0.540302], y = 1-element FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}} = 1.0)

which cannot be updated; the update fails with the error

ERROR: ArgumentError: Cannot setindex! to 0.1 for an AbstractFill with value 1.0.
Stacktrace:
 [1] setindex! at C:\Users\anton\.julia\packages\FillArrays\NjFh2\src\FillArrays.jl:47 [inlined]
 [2] copyto!(::FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}}, ::FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}}) at .\multidimensional.jl:962
 [3] copyto! at .\broadcast.jl:905 [inlined]
 [4] copyto! at .\broadcast.jl:864 [inlined]
 [5] materialize! at .\broadcast.jl:826 [inlined]
 [6] apply!(::Descent, ::Array{Float64,1}, ::FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}}) at C:\Users\anton\.julia\packages\Flux\NpkMm\src\optimise\optimisers.jl:39
 [7] update!(::Descent, ::Array{Float64,1}, ::FillArrays.Fill{Float64,1,Tuple{Base.OneTo{Int64}}}) at C:\Users\anton\.julia\packages\Flux\NpkMm\src\optimise\train.jl:23
 [8] update!(::Descent, ::Zygote.Params, ::Zygote.Grads) at C:\Users\anton\.julia\packages\Flux\NpkMm\src\optimise\train.jl:29
 [9] top-level scope at REPL[350]:1

What should the function look like then?

It is a bit lengthy, but I will appreciate help with any of these questions.
