Channel: First steps - JuliaLang

PermutedDimsArray slower than permutedims?

I’m struggling to understand why one function for multiplying a batch of matrices by a single matrix is significantly slower (roughly 8× in the benchmark below) than another. The only difference between the two functions is that the slower one uses `PermutedDimsArray` in place of `permutedims`.
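For reference, a batched multiply computes `A[:, :, k] * x` for every slice `k`. The sketch below spells that out with an explicit loop (`batched_mul_loop` is a hypothetical name, not from the post) and checks it against the stack-the-slices reshape trick that both functions rely on, using the same array sizes as the benchmark.

```julia
# Hypothetical reference implementation: one matrix product per slice.
function batched_mul_loop(A::AbstractArray{T,3}, x::AbstractMatrix{T}) where T
    a1, a2, bs = size(A)
    size(x, 1) == a2 || throw(DimensionMismatch("size(x, 1) must equal size(A, 2)"))
    C = similar(A, a1, size(x, 2), bs)   # output is a1 × b2 × bs
    for k in 1:bs
        C[:, :, k] = A[:, :, k] * x
    end
    return C
end

A = randn(10, 20, 200)
x = randn(20, 50)
C = batched_mul_loop(A, x)

# Reshape trick: stack the 200 slices into a 2000 × 20 matrix and call `*` once,
# then unstack the 2000 × 50 result back into 10 × 50 × 200.
stacked = reshape(permutedims(A, (1, 3, 2)), (10 * 200, 20)) * x
C2 = permutedims(reshape(stacked, (10, 200, 50)), (1, 3, 2))

size(C)   # (10, 50, 200)
C ≈ C2    # true (equal up to floating-point rounding)
```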

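I expected `permutedims` to be the slower one, since it copies the array, whereas `PermutedDimsArray` only creates a lazy view of it. The memory numbers agree with that (782.25 KiB allocated versus 1.83 MiB), yet the view-based version takes far longer. This makes me think I’ve misunderstood something fundamental about how Julia works.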
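A small sketch of the copy-versus-view distinction, using a made-up toy array `A`: `permutedims` returns a new `Array` that no longer depends on its input, while `PermutedDimsArray` wraps the original array and reads through to it.

```julia
A = reshape(collect(1.0:24.0), (2, 3, 4))   # toy 2 × 3 × 4 array

copied = permutedims(A, (1, 3, 2))       # new Array: allocates and copies
lazy   = PermutedDimsArray(A, (1, 3, 2)) # wrapper: no copy, indexes into A

A[1, 2, 3] = -1.0                        # mutate the original (was 15.0)

copied[1, 3, 2]   # 15.0: the copy is independent of A
lazy[1, 3, 2]     # -1.0: the view reads through to A
parent(lazy) === A   # true: the view holds a reference to A itself
```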

```julia
using BenchmarkTools  # provides @btime

# Copying version: permutedims materializes a permuted Array
function batched_mul_m(A::AbstractArray{T,3}, x::AbstractArray{T,2}) where T
    a1, a2, bs = size(A)
    b1, b2 = size(x)
    # stack the bs slices of A into a (bs*a1) × a2 matrix and multiply once
    C = reshape(permutedims(A, (1, 3, 2)), (bs * a1, a2)) * x
    # unstack into an a1 × b2 × bs array
    return permutedims(reshape(C, (a1, bs, b2)), (1, 3, 2))
end

# View version: identical except that PermutedDimsArray replaces permutedims
function batched_mul_tk2(A::AbstractArray{T,3}, x::AbstractArray{T,2}) where T
    a1, a2, bs = size(A)
    b1, b2 = size(x)
    C = reshape(PermutedDimsArray(A, (1, 3, 2)), (bs * a1, a2)) * x
    return PermutedDimsArray(reshape(C, (a1, bs, b2)), (1, 3, 2))
end

M = randn(10,20,200)
x = randn(20,50)

@btime batched_mul_m(M,x);
# 169.902 μs (10 allocations: 1.83 MiB)

@btime batched_mul_tk2(M,x);
# 1.408 ms (21 allocations: 782.25 KiB)
```

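One way to investigate, offered as a hedged suggestion rather than an answer taken from the thread, is to inspect the left operand that each function passes to `*`. Reshaping the `permutedims` copy gives a plain `Matrix`, whereas reshaping a `PermutedDimsArray` gives a lazy `Base.ReshapedArray` that is not a `StridedArray`. Assuming LinearAlgebra hands `*` to BLAS only for strided operands (as is commonly the case), the view-based version would fall back to a slower generic matrix multiply.

```julia
M = randn(10, 20, 200)

# Left operand of `*` in batched_mul_m: reshape of a freshly copied Array
from_copy = reshape(permutedims(M, (1, 3, 2)), (200 * 10, 20))
# Left operand of `*` in batched_mul_tk2: reshape of a lazy permuted view
from_view = reshape(PermutedDimsArray(M, (1, 3, 2)), (200 * 10, 20))

# Assumption: LinearAlgebra's BLAS-backed `*` methods target StridedArray operands.
from_copy isa Matrix{Float64}      # true: contiguous memory
from_copy isa StridedArray         # true
from_view isa Base.ReshapedArray   # true: a wrapper around a wrapper
from_view isa StridedArray         # false
from_copy == from_view             # true: same values, different array types
```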