I’m having a hard time figuring out why this works:
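```julia
using Zygote

# W and b are arrays; predict closes over them as globals
W, b = rand(2, 3), rand(2)
predict(x) = W*x .+ b

# implicit-mode gradient with respect to the tracked parameters W and b
g = gradient(() -> sum(predict([1,2,3])), Params([W, b]))
g[W], g[b]
```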
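A hedged aside on what differs between the two snippets: `W` and `b` are `Array`s, which are mutable heap-allocated objects with their own identity, whereas `a = 2` is a plain `Int`, an `isbits` value. Zygote's `Params` appears to key gradients by object identity and, judging by the error message below, refuses `isbits` values. The check here uses only Base Julia and no Zygote calls:

```julia
W, b = rand(2, 3), rand(2)
a = 2

# Arrays are mutable objects with a stable identity,
# so they can serve as keys in an identity-based collection like Params.
println(isbits(W))   # false
println(isbits(b))   # false

# An Int is an immutable bits value: every 2 is indistinguishable from
# every other 2, so there is no object identity to attach a gradient to.
println(isbits(a))   # true
```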
but this doesn’t:
```julia
using Zygote

a = 2
x = 2
f(x) = x^a

# attempt to differentiate with respect to the global a
gp = gradient(() -> f(x), Params(a))
gp[a]
```
I get the error:
```
ERROR: Only reference types can be differentiated with `Params`.
```
Can anyone help?
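A minimal sketch of two possible workarounds, assuming a Zygote version that still supports implicit `Params`: wrap the exponent in a one-element array so it becomes a reference type that `Params` can track, or switch to explicit mode and pass the exponent as an argument. The one-element wrapping and the helper name `g` are illustrative choices, not something from the original post.

```julia
using Zygote

x = 2.0

# Option 1 (implicit mode): store the exponent in a one-element array.
# The array is a reference type, so Params can track it by identity.
a = [2.0]
f(x) = x^a[1]
gp = gradient(() -> f(x), Params([a]))
gp[a]              # one-element array holding d/da x^a = x^a * log(x)

# Option 2 (explicit mode): make the exponent an argument instead of a
# captured global; explicit gradients work on plain numbers.
g(a) = x^a
gradient(g, 2.0)   # one-element tuple holding x^2 * log(x)
```

The explicit form avoids the identity question entirely, since Zygote differentiates with respect to the function's arguments rather than with respect to tracked globals.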