I ran into a case where this function takes ~25s to compile:
```julia
function infers_slowly(x, chosen_indices, indices)
    return [
        [x.values[i] for i in indices if i == ci] for
        ci in chosen_indices
    ]
end
```
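The sketch below makes the function's semantics concrete on a deliberately simple input. `simple_x` is a hypothetical NamedTuple standing in for `x`, with none of the Turing.jl/Optimization.jl types attached. For each chosen index, the inner filtered comprehension keeps the entries of `x.values` whose position equals that index. `Base.return_types` is Base's reflection utility for asking inference what a call returns without running it.

```julia
# Same function as above.
function infers_slowly(x, chosen_indices, indices)
    return [
        [x.values[i] for i in indices if i == ci] for
        ci in chosen_indices
    ]
end

# Hypothetical simple stand-in for `x`: only the `values` field is accessed.
simple_x = (values = [10.0, 20.0, 30.0],)

# For ci = 1 the inner comprehension keeps values[1]; for ci = 3 it keeps values[3].
result = infers_slowly(simple_x, (1, 3), [1, 2, 3])
println(result)  # [[10.0], [30.0]]

# Run type inference for this signature without executing the function.
rts = Base.return_types(infers_slowly, (typeof(simple_x), Tuple{Int,Int}, Vector{Int}))
println(rts)
```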
The type of `x` is quite complicated, coming from a bunch of Turing.jl/Optimization.jl stuff, and that complexity is essential to the slow inference. I tried to boil it down to something more minimal, but only partially succeeded. Here's where I got to:
```julia
import ADTypes
import SciMLBase
using Turing: @model, arraydist, LogNormal, MvNormal, DefaultContext
using Turing.Optimisation: ModeResult, OptimLogDensity, OptimizationContext
using LinearAlgebra: Diagonal
import OptimizationBase
import Optimization
using OptimizationOptimJL: OptimizationOptimJL

struct OptimizationThingie{F,RC,O}
    optimization_function::F
    reinit_cache::RC
    optimization_algorithm::O
end

function OptimizationThingie(prob, alg)
    reinit_cache = OptimizationBase.ReInitCache(prob.u0, prob.p)
    f = OptimizationBase.instantiate_function(prob.f, reinit_cache, prob.f.adtype, 0)
    return OptimizationThingie(f, reinit_cache, alg)
end

function make_a_thing_with_a_complicated_type(N)
    @model function demo_model(N)
        a ~ arraydist(LogNormal.(fill(0, N), 1))
        covariance_matrix = Matrix(Diagonal(a))
        x ~ MvNormal(covariance_matrix)
        return nothing
    end
    m = demo_model(N) | (x=randn(N),)
    initial_params = fill(1.0, 2 * N)
    solver = OptimizationOptimJL.LBFGS()
    adtype = ADTypes.AutoForwardDiff()
    log_density = OptimLogDensity(m, OptimizationContext(DefaultContext()))
    f = Optimization.OptimizationFunction(log_density, adtype)
    prob = Optimization.OptimizationProblem(f, initial_params)
    thingie = OptimizationThingie(prob, solver)
    return thingie
end

struct X{V,O}
    values::V
    a_complicated_thing_that_is_never_touched::O
end

function infers_slowly(x, chosen_indices, indices)
    return [
        [x.values[i] for i in indices if i == ci] for
        ci in chosen_indices
    ]
end

N = 6  # Needs to be ~6 or more to trigger the slowdown.
complicated_thingie = make_a_thing_with_a_complicated_type(N)
indices = collect(1:N)
values = randn(N)
x = X(values, complicated_thingie)
@time infers_slowly(x, (1,), indices)  # first call: includes inference and compilation
@time infers_slowly(x, (1,), indices)  # second call: already compiled
```
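To separate the compile cost from the run cost, one option is to time `precompile`, which compiles a function for given argument types without executing it, and then time the call on its own. The sketch below assumes a hypothetical `StandInX` with the same field layout as `X`, but with a cheap tuple in the never-touched field so that it runs without Turing.jl. Passing the reproducer's `x` instead of `stand_in` would presumably put most of the ~25 s under the compile timing.

```julia
# Hypothetical stand-in: same layout as `X` above, but the second field is a
# cheap tuple rather than the Optimization.jl object.
struct StandInX{V,O}
    values::V
    a_complicated_thing_that_is_never_touched::O
end

function infers_slowly(x, chosen_indices, indices)
    return [
        [x.values[i] for i in indices if i == ci] for
        ci in chosen_indices
    ]
end

N = 6
stand_in = StandInX(collect(1.0:N), (1, 2.0, "never touched"))
indices = collect(1:N)

# Compile for these argument types without running the function.
argtypes = (typeof(stand_in), Tuple{Int}, Vector{Int})
compile_stats = @timed precompile(infers_slowly, argtypes)

# Then time the call itself, which reuses the compiled method.
run_stats = @timed infers_slowly(stand_in, (1,), indices)
result = run_stats.value

println("precompile returned: ", compile_stats.value)
println("compile: ", compile_stats.time, " s; run: ", run_stats.time, " s")
println(result)  # [[1.0]]
```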
To be clear, I'm not looking to change my code to make it faster; I know how to do that. I'm reporting this because it looks like a failure of the type inference machinery, which I hope can be fixed for the general benefit of mankind.