
Performance issues?


As a learning project I reimplemented a simple Bayesian NN in Julia from scratch (hopefully without mistakes!).
Benchmarking against my PyTorch code on the same CPU, the Julia version is about 30% slower (measured with BenchmarkTools)…

While I’ve tried to follow the performance tips as closely as I could, a gap this large makes me think I’ve missed something major.

Hopefully someone can see at a glance something stupid I’ve done, because after profiling, benchmarking, and reading tutorials I’ve hit a wall.
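(For reference, the comparison was along these lines; this is a minimal sketch rather than my exact harness, and it assumes the code below is already loaded:)

using BenchmarkTools

# Time one full training run (slow: this repeats the whole 10000-epoch loop);
# @btime reports the minimum time over several samples.
@btime main()

# Or isolate a single gradient evaluation for a pre-built model:
model = BayesNN(1, 1, 20, 3, 100)
x = rand(Float32, 1, 100); y = rand(Float32, 1, 100)
@btime gradient(() -> -log_likelihood($x, $y, 1.0f0, $model) + kl_divergence($model), $(model.θ))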

using Flux
using Flux: glorot_uniform
using Zygote
using Flux.Optimise: update!, ADAM
using Plots

# Make `*` do batched multiplication on 3-d arrays
# (note: overloading Base.:* on AbstractArrays like this is type piracy)
Base.:*(x::AbstractArray{T,3}, y::AbstractArray{T,3}) where T = batched_mul(x, y)
Base.:*(x::AbstractArray{T,3}, y::AbstractArray{T,2}) where T = batched_mul_loop(x, y)

"""
Multiplies a batch of matrices with a 
"""

function batched_mul_loop(A::AbstractArray{T,3}, x::AbstractArray{T,2}) where T 
    B = [view(A, :, :, k) * x for k in axes(A, 3)]
    # B = [Ai * x for Ai in eachslice(A, dims=3)]
    return cat(B..., dims=3)
end

# Reparameterisation trick: draw `m` samples from N(μ, exp(logσ)^2)
function reparamSample(μ::AbstractArray, logσ::AbstractArray, m::Integer)
    r = randn(Float32, size(μ)..., m)
    return @. μ + r * exp(logσ)  
end
"""
Variational linear layer for a Bayesian neural network.
"""
mutable struct VariationalLinear{F <: Function, S <: AbstractArray, T <: AbstractArray,
                                 U <: AbstractArray, V <: AbstractArray, N <: Integer}
    in::N
    out::N
    W::S
    logσW::S
    b::T
    logσb::T
    σ::F
    W_sample::U
    b_sample::V
    η::N
end

function VariationalLinear(in::Integer, out::Integer, η::Integer, σ=identity)
    W = glorot_uniform(out, in)
    logσW = -6.0f0 * ones(Float32, out, in)
    b = zeros(Float32, out)
    logσb = -6.0f0 * ones(Float32, out)
    return VariationalLinear(in, out, W, logσW, b, logσb, σ, 
                            reparamSample(W, logσW, η), 
                            reparamSample(b, logσb, η),η)
end

function update_samples(a::VariationalLinear)
    a.W_sample = reparamSample(a.W, a.logσW, a.η)
    a.b_sample = reparamSample(a.b, a.logσb, a.η)
end

function (a::VariationalLinear)(x::AbstractArray{T}) where T
    update_samples(a)
    out = a.σ.(a.W_sample * x .+ reshape(a.b_sample, a.out, 1, a.η))
    return out 
end

"""
Bayesian neural network: a stack of `VariationalLinear` layers with all variational parameters collected in `θ`.
"""
struct BayesNN{T <: Array{VariationalLinear,1},P <: Zygote.Params, N<:Integer}
    layers::T
    θ::P
    η::N
end
function BayesNN(in::Integer, out::Integer, η::Integer, num_layers::Integer, num_hidden::Integer, σ=relu)
    # putting layers into array
    layers = VariationalLinear{<:Function}[VariationalLinear(in, num_hidden, η, σ),]
    append!(layers, [VariationalLinear(num_hidden, num_hidden, η, σ) for i in 1:(num_layers - 1)])
    append!(layers, [VariationalLinear(num_hidden, out, η)])
    # collecting into parameter array
    P = [layers[1].W, layers[1].b, layers[1].logσW, layers[1].logσb]
    (L -> append!(P, [L.W, L.b, L.logσW, L.logσb])).(layers[2:end])
    return BayesNN(layers, Flux.params(P), η)
end
(b::BayesNN)(x) = foldl((x, b) -> b(x), b.layers, init=x)

function log_likelihood(x, y, noise, model)
    return sum(-0.5f0 / noise^2 .* (model(x) .- y).^2) / model.η
end 

function log_normal(x, μ, logσ)
    -0.5f0 * (x .- μ).^2 ./ exp.(logσ).^2  .- logσ .- log(sqrt(2.0f0 * π))
end

function mean_log_prob(m)
    sum(sum(m, dims=3))
end

function kl_divergence(model)
    logq = 0.0f0
    logp = 0.0f0
    for layer in model.layers
        logq += mean_log_prob(log_normal(layer.W_sample, layer.W, layer.logσW))/model.η
        logp += mean_log_prob(log_normal(layer.W_sample, 0.0f0, 1.0f0))/model.η
    end
    return logq - logp
end

function main()
    @info "Start"

    model = BayesNN(1, 1, 20, 3, 100)

    # making dummy-data 
    bs = 100;
    epochs = 10000;
    period = 2.0f0;
    noise = 1.0f0; 
    x = reshape(collect(LinRange(0.0, 10.0, bs)), (1, bs));
    y = x .* 2 .* sin.(x .* (2 * π / period)) 
    yt = y .+ randn(size(x)...) * noise

    # converting to float32
    x = convert(Array{Float32}, x)
    y = convert(Array{Float32}, y)
    yt = convert(Array{Float32}, yt)

    θ = model.θ

    opt = Flux.Optimise.ADAM(1e-3)

    for i in 1:epochs
        gs = gradient(θ) do 
            return -log_likelihood(x, y, noise, model) + kl_divergence(model)
        end
        Flux.Optimise.update!(opt, θ, gs)
        if i % 1000 == 0 # print progress every 1000 epochs
            @info "Epoch $i| log-likelihood $(-log_likelihood(x, y, noise, model)) | kl-div $(kl_divergence(model))"
        end
    end

    @info "Done!"
    theme(:dark) 
    plot(x', yt', seriestype=:scatter, alpha=0.4)
    plot!(x', y')
    plot!(x', model(x)[1,:,:], legend=false, alpha=0.8)
end
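
(Also for reference, the kind of checks I ran following the performance tips — again a rough sketch, not my exact commands:)

using InteractiveUtils   # for @code_warntype
using Profile

model = BayesNN(1, 1, 20, 3, 100)
x = rand(Float32, 1, 100)

# Check the forward pass for type instabilities (red `Any` entries in the output):
@code_warntype model(x)

# Profile a handful of gradient evaluations to see where the time goes:
Profile.clear()
@profile for _ in 1:50
    gradient(() -> kl_divergence(model), model.θ)
end
Profile.print()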
