How to save a model / fit and load it? Issue with JLD2 for "reconstructing" #2309

Open
DominiqueMakowski opened this issue Aug 25, 2024 · 14 comments

@DominiqueMakowski
Contributor

(this issue is somewhat related to #2308)

I'm trying to save models to disk and then load and use them in a new session.

Here's an MWE, starting with making and saving a model:

using Turing
using JLD2

@model function mymodel(y)
    μ ~ Normal(0, 2)
    σ ~ truncated(Normal(0, 3), 0.0, Inf)
    for i in 1:length(y)
        y[i] ~ Normal(μ, σ)
    end
end
fit = mymodel([1, 2, 3, 4, 5])

jldsave("model.jld2";  model=mymodel, fit=fit)

Now, in a new session, if I do the following it errors:

using Turing
using JLD2

loaded = jldopen("model.jld2", "r+")
loaded["model"]
┌ Warning: type Main.#mymodel does not exist in workspace; reconstructing
└ @ JLD2 C:\Users\domma\.julia\packages\JLD2\twZ5D\src\data\reconstructing_datatypes.jl:492
loaded["model"]([1, 2, 3, 4, 5])
ERROR: MethodError: objects of type JLD2.ReconstructedSingleton{Symbol("#mymodel")} are not callable
Stacktrace:
 [1] top-level scope
   @ c:\Users\domma\Dropbox\RECHERCHE\Studies\DoggoNogo\study1\analysis\1_models_make.jl:154

How to correctly save/load Turing models?

@yebai
Member

yebai commented Aug 25, 2024

@devmotion IIRC, we can't serialise Turing models due to a DynamicPPL limitation. Is that still the case, and if so, is that fixable?

@devmotion
Member

Models can be serialized, we even have a test for it: https://github.com/TuringLang/DynamicPPL.jl/blob/138bd40acdfc47d7b00e25a2adaf9fec986f9646/test/serialization.jl The serialization issues in DynamicPPL should have been fixed by TuringLang/DynamicPPL.jl#134. I haven't checked the MWE above but I wonder if it's rather a JLD2 than a Turing/DynamicPPL issue.

@DominiqueMakowski
Contributor Author

Thanks for looking into this. Is there a more robust alternative to JLD2 for saving and loading models? I initially picked JLD2 for saving the chains (note that it works for that), following this thread.

@DominiqueMakowski
Contributor Author

kind bump in case someone has any good suggestions on how to save & load models

@devmotion
Member

Serialization should work (and is tested). Can you check whether Serialization.serialize and Serialization.deserialize work for you?

@DominiqueMakowski
Contributor Author

Unfortunately it doesn't seem like it works:

using Turing
using Serialization

@model function mymodel(y)
    μ ~ Normal(0, 2)
    σ ~ truncated(Normal(0, 3), 0.0, Inf)
    for i in 1:length(y)
        y[i] ~ Normal(μ, σ)
    end
end
fit = mymodel([1, 2, 3, 4, 5])

Serialization.serialize("model.turing", Dict("model" => mymodel, "fit" => fit))

Restart session:

using Turing
using Serialization

loaded = Serialization.deserialize("model.turing")
loaded["model"]([1, 2, 3, 4, 5])
ERROR: UndefVarError: `#mymodel` not defined
Stacktrace:
  [1] deserialize_datatype(s::Serializer{IOStream}, full::Bool)
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:1399
  [2] handle_deserialize(s::Serializer{IOStream}, b::Int32)
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:867
  [3] deserialize(s::Serializer{IOStream})
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:814
  [4] handle_deserialize(s::Serializer{IOStream}, b::Int32)
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:874
  [5] deserialize
    @ C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:814 [inlined]
  [6] deserialize_dict(s::Serializer{IOStream}, T::Type{Dict{String, Any}})
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:1529
  [7] deserialize(s::Serializer{IOStream}, T::Type{Dict{String, Any}})
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:1536
  [8] handle_deserialize(s::Serializer{IOStream}, b::Int32)
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:883
  [9] deserialize(s::Serializer{IOStream})
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:814
 [10] handle_deserialize(s::Serializer{IOStream}, b::Int32)
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:920
 [11] deserialize
    @ C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:814 [inlined]
 [12] deserialize(s::IOStream)
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:801
 [13] open(f::typeof(deserialize), args::String; kwargs::@Kwargs{})
    @ Base .\io.jl:396
 [14] open
    @ .\io.jl:393 [inlined]
 [15] deserialize(filename::String)
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:811
 [16] top-level scope
    @ c:\Users\domma\Dropbox\RECHERCHE\Studies\DoggoNogo\study1\analysis\activate.jl:29

@devmotion
Member

mymodel is a regular Julia function, so it suffers from the same (de)serialization limitations as any other Julia function, whereas fit is an object of type DynamicPPL.Model and behaves differently.
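To make the distinction concrete, here is a minimal sketch using the model from this thread. It only inspects the two objects. Note that `fit` is an instantiated but not yet sampled model, and `Turing.DynamicPPL` is used to reach the type without assuming that `DynamicPPL` is exported into the session.

```julia
using Turing

@model function mymodel(y)
    μ ~ Normal(0, 2)
    σ ~ truncated(Normal(0, 3), 0.0, Inf)
    for i in 1:length(y)
        y[i] ~ Normal(μ, σ)
    end
end

# Calling the generated function binds the data and returns a model instance.
fit = mymodel([1, 2, 3, 4, 5])

println(mymodel isa Function)             # true: the model "constructor" is a plain Julia function
println(fit isa Turing.DynamicPPL.Model)  # true: the instance is a DynamicPPL.Model
println(typeof(mymodel))                  # the function's own type, which must exist when loading
```

The function's type is what the `#mymodel` in the error messages above refers to: the loading session does not know it until the function has been defined there.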

You can (de)serialize mymodel e.g. in the following way:

using Turing
using Serialization

@model function mymodel(y)
    μ ~ Normal(0, 2)
    σ ~ truncated(Normal(0, 3), 0.0, Inf)
    for i in 1:length(y)
        y[i] ~ Normal(μ, σ)
    end
end

Serialization.serialize("model.turing", methods(mymodel))

Then, in a new session:

using Turing
using Serialization

function mymodel end # this is required
Serialization.deserialize("model.turing")

mymodel([1, 2, 3, 4])

@DominiqueMakowski
Contributor Author

DominiqueMakowski commented Sep 4, 2024

Unfortunately, that didn't do the trick either, and a new error crept in that I couldn't make sense of, even after googling what AccessorsImpl is:

julia> loaded = Serialization.deserialize("model.turing")
ERROR: UndefVarError: `AccessorsImpl` not defined
Stacktrace:
  [1] deserialize_module(s::Serializer{IOStream})
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:997
  [2] handle_deserialize(s::Serializer{IOStream}, b::Int32)
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:896
  [3] deserialize(s::Serializer{IOStream})
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:814
  [4] deserialize_datatype(s::Serializer{IOStream}, full::Bool)
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:1398
  [5] handle_deserialize(s::Serializer{IOStream}, b::Int32)
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:867
  [6] deserialize(s::Serializer{IOStream})
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:814
  [7] handle_deserialize(s::Serializer{IOStream}, b::Int32)
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:874
  [8] deserialize(s::Serializer{IOStream})
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:814
  [9] deserialize_expr(s::Serializer{IOStream}, len::Int64)
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:1291
 [10] handle_deserialize(s::Serializer{IOStream}, b::Int32)
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:894
 [11] deserialize_fillarray!(A::Vector{Any}, s::Serializer{IOStream})
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:1281
 [12] deserialize_array(s::Serializer{IOStream})
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:1273
 [13] handle_deserialize(s::Serializer{IOStream}, b::Int32)
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:865
 [14] deserialize
    @ C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:814 [inlined]
 [15] deserialize(s::Serializer{IOStream}, ::Type{Core.CodeInfo})
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:1133
 [16] handle_deserialize(s::Serializer{IOStream}, b::Int32)
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:960
 [17] deserialize
    @ C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:814 [inlined]
 [18] deserialize(s::Serializer{IOStream}, ::Type{Method})
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:1044
 [19] handle_deserialize(s::Serializer{IOStream}, b::Int32)
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:960
 [20] deserialize_fillarray!(A::Vector{Method}, s::Serializer{IOStream})
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:1281
 [21] deserialize_array(s::Serializer{IOStream})
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:1273
 [22] handle_deserialize(s::Serializer{IOStream}, b::Int32)
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:865
 [23] deserialize(s::Serializer{IOStream}, t::DataType)
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:1501
 [24] handle_deserialize(s::Serializer{IOStream}, b::Int32)
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:883
 [25] deserialize
    @ C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:814 [inlined]
 [26] deserialize_dict(s::Serializer{IOStream}, T::Type{Dict{String, Any}})
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:1529
 [27] deserialize(s::Serializer{IOStream}, T::Type{Dict{String, Any}})
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:1536
 [28] handle_deserialize(s::Serializer{IOStream}, b::Int32)
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:883
 [29] deserialize(s::Serializer{IOStream})
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:814
 [30] handle_deserialize(s::Serializer{IOStream}, b::Int32)
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:920
 [31] deserialize
    @ C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:814 [inlined]
 [32] deserialize(s::IOStream)
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:801
 [33] open(f::typeof(deserialize), args::String; kwargs::@Kwargs{})
    @ Base .\io.jl:396
 [34] open
    @ .\io.jl:393 [inlined]
 [35] deserialize(filename::String)
    @ Serialization C:\Users\domma\.julia\juliaup\julia-1.10.2+0.x64.w64.mingw32\share\julia\stdlib\v1.10\Serialization\src\Serialization.jl:811
 [36] top-level scope
    @ c:\Users\domma\Dropbox\RECHERCHE\Studies\DoggoNogo\study1\analysis\activate.jl:31

I really appreciate your help though.

The main reason I'm saving the model itself is to be able to refit it later on new data. This surely must be a common use case: in R, it is common to save, share, download, and re-use big fitted models.

As doing it this way - assuming it is even possible - is clunky and unwieldy, the alternatives I see are:

  • Having an update() method (discussed in #2308, "update() method for updating a fitted model with new data")
  • Is it possible to extract the model object/method from the fitted object? In other words, as far as I understand, a Turing model is defined as a function (which is hard to serialize) through the @model macro, and calling that function produces a DynamicPPL object. Can we recover/reconstruct the function from the fitted version?

Thanks again @devmotion

@devmotion
Member

devmotion commented Sep 4, 2024

The example above works fine for me, I don't get this error. Did you try a more complicated example?

My setup:

(jl_L5QqA0) pkg> st
Status `/private/var/folders/n6/98_7bm0j0hb57zv3l3tj8sxh0000gn/T/jl_L5QqA0/Project.toml`
⌃ [fce5fe82] Turing v0.33.3
Info Packages marked with ⌃ have new versions available and may be upgradable.

julia> versioninfo()
Julia Version 1.10.5
Commit 6f3fdf7b362 (2024-08-27 14:19 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (arm64-apple-darwin22.4.0)
  CPU: 10 × Apple M2 Pro
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, apple-m1)
Threads: 1 default, 0 interactive, 1 GC (on 6 virtual cores)
Environment:
  JULIA_PKG_USE_CLI_GIT = true
  JULIA_PKG_PRESERVE_TIERED_INSTALLED = true

@DominiqueMakowski
Contributor Author

It does work, the culprit was that I was reloading it using a different Turing version 🤦

Great, cheers, I'll close this!

(but I would still suggest that making that process more convenient would be a nice feature ☺️)
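Since the failure came from a version mismatch, one way to catch it early is to store the versions next to the serialized object and compare them before deserializing the rest. The sketch below uses only the Serialization standard library. The helper names `save_versioned` and `load_versioned` are hypothetical and not part of Turing or Serialization. With Turing loaded, an entry such as `"Turing" => pkgversion(Turing)` could be added to the dictionary (Julia 1.9 or later).

```julia
using Serialization

# Hypothetical helper: write the version table first, then the payload,
# so the versions can be checked before the payload is touched.
function save_versioned(path, payload; versions::Dict{String,VersionNumber})
    open(path, "w") do io
        serialize(io, versions)
        serialize(io, payload)
    end
end

# Hypothetical helper: read the version table, compare it with the current
# session, and only then deserialize the payload.
function load_versioned(path; versions::Dict{String,VersionNumber})
    open(path, "r") do io
        saved = deserialize(io)
        for (name, current) in versions
            stored = get(saved, name, nothing)
            stored == current ||
                error("$name version mismatch: saved with $stored, loading with $current")
        end
        return deserialize(io)
    end
end

current = Dict("julia" => VERSION)
path = joinpath(mktempdir(), "payload.bin")
save_versioned(path, [1, 2, 3]; versions = current)
restored = load_versioned(path; versions = current)
```

Writing the version table as a separate, earlier record means a mismatch produces a readable error instead of a deep `deserialize` stack trace.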

@DominiqueMakowski
Contributor Author

One last shot, I know I'm asking for a lot, but the above solution is not very convenient for programmatic usage.
In my case, I define and save a lot of models, and then I load them and use them in a loop. The code is made to work on an arbitrary number of models with arbitrary names.

Hence I would like to be able to call the model directly from the dict, loaded["model"]([1, 2, 3, 4, 5]), without having to use the original function name, i.e., rather than calling the re-defined function mymodel([1, 2, 3, 4]) (because that would require me to write bespoke code for each model).
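One possible workaround, sketched here under the assumption that the model definitions can be shipped as source code and loaded in every session (for example via `include`), is to keep a name-to-function registry. The loop then never refers to a specific model name. The models `m1` and `m2` and the `MODELS` dictionary are hypothetical and for illustration only. This sidesteps serializing the functions rather than solving it.

```julia
using Turing

# Model definitions; in practice these would live in a shared file
# that both the saving and the loading session include.
@model function m1(y)
    μ ~ Normal(0, 2)
    σ ~ truncated(Normal(0, 3), 0.0, Inf)
    for i in 1:length(y)
        y[i] ~ Normal(μ, σ)
    end
end

@model function m2(y)
    μ ~ Normal(0, 10)
    for i in 1:length(y)
        y[i] ~ Normal(μ, 1)
    end
end

# Hypothetical registry: maps an arbitrary name to the model function.
const MODELS = Dict("m1" => m1, "m2" => m2)

newdata = [1.0, 2.0, 3.0, 4.0]
chains = Dict{String,Any}()
for (name, constructor) in MODELS
    model = constructor(newdata)                    # bind the new data
    chains[name] = sample(model, NUTS(), 100; progress = false)
end
```

Adding a model then only requires one new definition and one new registry entry.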

@yebai
Member

yebai commented Sep 4, 2024

@penelopeysm can you help add the trick to docs / FAQ?

@yebai yebai reopened this Sep 4, 2024
@sunxd3
Collaborator

sunxd3 commented Sep 4, 2024

@yebai yebai added the doc label Sep 4, 2024
@DominiqueMakowski
Contributor Author

the above solution is not very convenient for programmatic usage

The problem boils down to being able to fit a model on data without having to (re)define the original function name, to allow for workflows such as:

(pseudocode)

# 1. Define models
@model function m1(y, x)
    ...
end
fit1 = m1(data)

@model function m2(y, x)
    ...
end
fit2 = m2(data)

# 2. Save models
save(fit1, "fit1")
save(fit2, "fit2")

In a new script

for m in ["fit1", "fit2"]
    fit = load(m)
    fit(newdata) 
    predict(fit, ...)
end

Do you think it might be solved by implementing an update() method (#2308)?


Context: one use case where this flexibility is, IMO, critical is when models are fit or sampled on external machines (e.g. high-performance clusters): the output is ideally saved and then downloaded by researchers, who can then work with these models independently (for reporting, postprocessing, predictions, analysis, etc.).
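For the cluster scenario, a partial workflow that avoids serializing the model function is to save only the sampled chain and rebuild the model from source on the receiving side. The following is a rough sketch, assuming both sides use the same package versions and that the model definition is available as code. The two "sides" are simulated in a single script, and the file name is arbitrary.

```julia
using Turing
using Serialization

@model function mymodel(y)
    μ ~ Normal(0, 2)
    σ ~ truncated(Normal(0, 3), 0.0, Inf)
    for i in 1:length(y)
        y[i] ~ Normal(μ, σ)
    end
end

# "Cluster" side: sample and store only the chain.
chain = sample(mymodel([1.0, 2.0, 3.0, 4.0, 5.0]), NUTS(), 200; progress = false)
path = joinpath(mktempdir(), "chain.jls")
serialize(path, chain)

# "Local" side: the model definition comes from source, the chain from disk.
loaded_chain = deserialize(path)

# Passing missing values for y asks predict to draw them from the posterior.
y_missing = Vector{Union{Missing,Float64}}(missing, 3)
predictions = predict(mymodel(y_missing), loaded_chain)
```

This covers reporting and prediction from a downloaded chain. It does not cover refitting without access to the model source, which is the part the request above is about.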
