Slow type inference #54879

Open

mhauru opened this issue Jun 21, 2024 · 0 comments
mhauru commented Jun 21, 2024

I ran into a case where this function takes ~25s to compile:

function infers_slowly(x, chosen_indices, indices)
    return [
        [x.values[i] for i in indices if i == ci] for
        ci in chosen_indices
    ]
end

The type of x is quite complicated, coming from a bunch of Turing.jl/Optimization.jl machinery, and that complexity is what makes inference take so long. I tried to boil it down to something more minimal, but only partially succeeded. Here's where I got to:

import ADTypes
import SciMLBase
using Turing: @model, arraydist, LogNormal, MvNormal, DefaultContext
using Turing.Optimisation: ModeResult, OptimLogDensity, OptimizationContext
using LinearAlgebra: Diagonal
import OptimizationBase
import Optimization
using OptimizationOptimJL: OptimizationOptimJL

struct OptimizationThingie{F,RC,O}
    optimization_function::F
    reinit_cache::RC
    optimization_algorithm::O
end

function OptimizationThingie(prob, alg)
    reinit_cache = OptimizationBase.ReInitCache(prob.u0, prob.p)
    f = OptimizationBase.instantiate_function(prob.f, reinit_cache, prob.f.adtype, 0)
    return OptimizationThingie(f, reinit_cache, alg)
end

function make_a_thing_with_a_complicated_type(N)

    @model function demo_model(N)
        a ~ arraydist(LogNormal.(fill(0, N), 1))
        covariance_matrix = Matrix(Diagonal(a))
        x ~ MvNormal(covariance_matrix)
        return nothing
    end

    m = demo_model(N) | (x=randn(N),)
    initial_params = fill(1.0, 2 * N)
    solver = OptimizationOptimJL.LBFGS()
    adtype = ADTypes.AutoForwardDiff()
    log_density = OptimLogDensity(m, OptimizationContext(DefaultContext()))
    f = Optimization.OptimizationFunction(log_density, adtype)
    prob = Optimization.OptimizationProblem(f, initial_params)
    thingie = OptimizationThingie(prob, solver)
    return thingie
end

struct X{V,O}
    values::V
    a_complicated_thing_that_is_never_touched::O
end

function infers_slowly(x, chosen_indices, indices)
    return [
        [x.values[i] for i in indices if i == ci] for
        ci in chosen_indices
    ]
end

N = 6  # This needs to be around ~6 or more for slowness.
complicated_thingie = make_a_thing_with_a_complicated_type(N)
indices = collect(1:N)
values = randn(N)
x = X(values, complicated_thingie)

@time infers_slowly(x, (1,), indices)
@time infers_slowly(x, (1,), indices)

The output for that is, for example:

 25.468606 seconds (58.04 k allocations: 3.971 MiB, 100.00% compilation time)
  0.000013 seconds (4 allocations: 272 bytes)
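The second call shows the compiled method itself is fast; all of the cost is in compilation. As a cross-check that the time goes to type inference rather than native code generation, one could time inference alone in a fresh session, after loading the definitions above but before running the @time calls (inference results are cached once the method has been compiled). A minimal sketch, not part of the original report:

# Runs type inference for these argument types without generating native code.
@time Base.return_types(infers_slowly, (typeof(x), Tuple{Int}, Vector{Int}))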

Julia version:

julia> versioninfo()
Julia Version 1.10.4
Commit 48d4fd48430 (2024-06-04 10:41 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (arm64-apple-darwin22.4.0)
  CPU: 10 × Apple M1 Pro
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, apple-m1)
Threads: 1 default, 0 interactive, 1 GC (on 8 virtual cores)

Package versions:

  [47edcb42] + ADTypes v1.4.0
  [7f7a1694] + Optimization v3.26.1
  [bca83a33] + OptimizationBase v1.2.0
  [36348300] + OptimizationOptimJL v0.3.2
  [0bca4576] + SciMLBase v2.41.3
  [fce5fe82] + Turing v0.33.0

To be clear, I'm not looking to change my code to make it faster; I know how to do that. I'm reporting this because it seems like a failure of the type inference machinery, which I hope could be fixed for the general benefit of mankind.
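For completeness, here is a minimal sketch of the kind of source-level change being alluded to (my illustration, not something proposed in the report): pass the values vector directly, so that inference only ever has to deal with a plain Vector{Float64} instead of the complicated wrapper type that the comprehension's closure would otherwise capture.

function infers_quickly(values, chosen_indices, indices)
    # Same logic as infers_slowly, but specialised only on typeof(values),
    # which here is a plain Vector{Float64}.
    return [
        [values[i] for i in indices if i == ci] for
        ci in chosen_indices
    ]
end

@time infers_quickly(x.values, (1,), indices)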
