Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Astmodels-dev #318

Draft
wants to merge 179 commits into
base: master
Choose a base branch
from
Draft
Changes from 1 commit
Commits
Show all changes
179 commits
Select commit Hold shift + click to select a range
efedfaf
update Project.toml
cscherrer Nov 30, 2020
ecd5ff4
s/logpdf/logdensity/
cscherrer Dec 1, 2020
b6b9a8b
Merge branch 'master' into cs-measuretheory
cscherrer Dec 1, 2020
17b7832
merge master
cscherrer Mar 8, 2021
33fd391
dynamichmc bugfix
cscherrer Mar 9, 2021
dda1623
scratchpad
cscherrer Mar 11, 2021
1dfc821
gps
cscherrer Mar 11, 2021
2943628
predict(d, ::AbstractVector)
cscherrer Mar 11, 2021
c080f59
scratchpad
cscherrer Mar 11, 2021
c2e3327
Rename using SampleChainsDynamicHMC.jl to SampleChainsDynamicHMC-demo.jl
cscherrer Mar 12, 2021
e989ea4
WIP: add basemeasure for models (#239)
mschauer Mar 15, 2021
975692f
drop accidental ZigZagBoomerang dependency
cscherrer Mar 15, 2021
ec1e0da
move stuff to scratchpad
cscherrer Mar 16, 2021
3af9fc7
skip failing tests for now
cscherrer Mar 18, 2021
63275a0
Merge branch 'master' into dev
cscherrer Mar 18, 2021
570e75a
tests
cscherrer Mar 18, 2021
20867cc
version bump
cscherrer Mar 18, 2021
74a3bd2
deps
cscherrer Mar 20, 2021
afaa488
oops, need `noted.jl`
cscherrer Mar 20, 2021
de00a9b
drop KeywordDispatch dependency
cscherrer Mar 24, 2021
1cb2461
Merge remote-tracking branch 'origin/master' into dev
cscherrer Mar 28, 2021
89cfd4e
updatre MeaureTheory dependency version
cscherrer Mar 27, 2021
407d2df
qualify some names to avoid collisions
cscherrer Mar 28, 2021
48001da
swap name order for` simulate`
cscherrer Mar 28, 2021
3c71f3f
rand -> testvalue
cscherrer Apr 6, 2021
7bd16e8
update dependencies
cscherrer Apr 6, 2021
f57d618
update .gitignore
cscherrer Apr 6, 2021
e0fe51e
remove `For` tests (in MeasureTheory now)
cscherrer Apr 6, 2021
50654fa
update nested models test
cscherrer Apr 6, 2021
a7b5852
update tests
cscherrer Apr 6, 2021
7762b6b
oops typo
cscherrer Apr 6, 2021
77103b8
Drop old Markov chain implementation (use MeasureTheory's)
cscherrer Apr 6, 2021
87bfe14
drop example for now
cscherrer Apr 6, 2021
3b3f0d4
move simulate to primitives
cscherrer Apr 6, 2021
dc9950e
iid
cscherrer Apr 9, 2021
ce38253
update `simulate`
cscherrer Apr 9, 2021
33df3d6
rm MaskArrays
cscherrer Apr 9, 2021
f071ef9
Merge branch 'master' into dev
cscherrer Apr 9, 2021
697b0d3
rewind version
cscherrer Apr 9, 2021
3f803c8
start on ASTModels
cscherrer Apr 9, 2021
bfb1f67
move astmodels
cscherrer Apr 9, 2021
596d43e
move astmodel
cscherrer Apr 9, 2021
a41edfb
type2model
cscherrer Apr 9, 2021
6f07d9d
remove dead lines
cscherrer Apr 9, 2021
31623ab
astmodels
cscherrer Apr 11, 2021
a3ecb54
interpret
cscherrer Apr 11, 2021
cb40f8c
include interpret
cscherrer Apr 11, 2021
c7f6ece
cal2nt
cscherrer Apr 11, 2021
d4b22b0
make interpret use Vals
cscherrer Apr 11, 2021
5e44d6b
asmodel
cscherrer Apr 12, 2021
f00e083
move astmodel.jl
cscherrer Apr 12, 2021
f06d413
commit renames
cscherrer Apr 12, 2021
537d898
fix local dependencies
cscherrer Apr 18, 2021
e9b07a7
bugfix
cscherrer Apr 18, 2021
4198f69
Merge branch 'cs-scoping' into cs-astmodels
cscherrer Apr 18, 2021
5945095
Merge branch 'master' into cs-astmodels
cscherrer May 1, 2021
c2a8509
Drop redundant code
cscherrer May 1, 2021
4484112
Model => DAGModel
cscherrer May 1, 2021
3610113
Merge branch 'master' into cs-astmodels
cscherrer May 10, 2021
9a79d5a
drop some outdated files
cscherrer May 10, 2021
9932ff2
splitting up ASTModels from DAGModels
cscherrer May 10, 2021
1adb196
moving things around
cscherrer May 10, 2021
0a880b5
bugfix
cscherrer May 10, 2021
cea420f
update deps
cscherrer May 28, 2021
b76bcd3
scratchpad
cscherrer May 28, 2021
f16154f
version bump
cscherrer Jun 3, 2021
9f90cb7
representative => rootmeasure
cscherrer Jun 10, 2021
ff2aad1
update dependency versions
cscherrer Jun 15, 2021
f7ac9be
reduce dependencies
cscherrer Jun 15, 2021
13ce8fc
update deps
cscherrer May 28, 2021
2c18845
scratchpad
cscherrer May 28, 2021
1f889e5
representative => rootmeasure
cscherrer Jun 10, 2021
bc0b681
update dependency versions
cscherrer Jun 15, 2021
d300f6a
reduce dependencies
cscherrer Jun 15, 2021
fb354d2
Merge branch 'cs-keywordcalls' of github.com:cscherrer/Soss.jl into c…
cscherrer Jun 15, 2021
5a7c5d7
`as` methods for `xform`
cscherrer Jun 17, 2021
6e0e0d7
cleanup
cscherrer Jun 17, 2021
105891e
require latest MeasureTheory
cscherrer Jun 17, 2021
96a9ae2
Merge remote-tracking branch 'origin/master' into cs-measuretheory
cscherrer Jun 27, 2021
70b3d5f
Merge branch 'cs-keywordcalls' into dev
cscherrer Jul 2, 2021
4893430
dorp old distributions code
cscherrer Jul 2, 2021
98e3716
drop old iid code
cscherrer Jul 2, 2021
050d41c
drop extra space
cscherrer Jul 2, 2021
0c639ca
limit deps to three newest releases
cscherrer Jul 3, 2021
b2ecf89
update dynamichmc
cscherrer Jul 6, 2021
024fdc4
add Aqua
cscherrer Jul 11, 2021
6ec7131
bump version
cscherrer Jul 11, 2021
aee7bf7
minor updates
cscherrer Jul 14, 2021
10c959a
Merge remote-tracking branch 'origin/master' into cs-astmodels
cscherrer Jul 14, 2021
083b76a
Merge branch 'cs-astmodels' of github.com:cscherrer/Soss.jl into cs-a…
cscherrer Jul 14, 2021
52ac4fd
add test
cscherrer Jul 15, 2021
faae51b
Merge branch 'master' into cs-astmodels
cscherrer Jul 15, 2021
30d1302
Merge branch 'cs-astmodels' of github.com:cscherrer/Soss.jl into cs-a…
cscherrer Jul 15, 2021
621bb0f
updates to `interpret`
cscherrer Jul 15, 2021
0e94bcd
small bugfix
cscherrer Jul 15, 2021
2d860ee
trying out `rand`
cscherrer Jul 15, 2021
3fc0d18
Better `predict` method
cscherrer Jul 15, 2021
a5fe38d
withmeasures(::ConditionalModel)
cscherrer Jul 16, 2021
c83c670
Merge branch 'master' into dev
cscherrer Jul 16, 2021
80e773b
minor cleanup
cscherrer Jul 16, 2021
0a02614
refactoring
cscherrer Jul 16, 2021
da060f0
refactor
cscherrer Jul 16, 2021
b1b213b
make it faster
cscherrer Jul 16, 2021
74a8012
_runtime_args
cscherrer Jul 16, 2021
12956de
more refactoring
cscherrer Jul 17, 2021
f0eb9e2
Update interpret to match _interpret
cscherrer Jul 18, 2021
6b7fc8b
drop temporary `rand2`
cscherrer Jul 18, 2021
67f9452
refactoring
cscherrer Jul 18, 2021
f9e648f
drop old `Base.rand`
cscherrer Jul 18, 2021
095b428
add Accessors and BangBang
cscherrer Jul 18, 2021
a6ff69e
Merge branch 'cs-astmodels' of github.com:cscherrer/Soss.jl into cs-a…
cscherrer Jul 18, 2021
ad0c588
Merge branch 'cs-astmodels' of github.com:cscherrer/Soss.jl into cs-a…
cscherrer Jul 18, 2021
6ee03b5
reset
cscherrer Jul 18, 2021
d314c5e
Break ConditionalModel into ModelClosure and ModelPosterior
cscherrer Jul 19, 2021
14e0f15
kwargs for `rnad`
cscherrer Jul 19, 2021
1842009
remove the Val
cscherrer Jul 19, 2021
9c2ca86
bugfix
cscherrer Jul 19, 2021
47a1b17
add _retn
cscherrer Jul 19, 2021
b76dbd6
refactoring
cscherrer Jul 19, 2021
1977ffc
allow nested models
cscherrer Jul 19, 2021
7948cfc
Commit to the approach (for this branch anyway)
cscherrer Jul 19, 2021
d33f211
nested models
cscherrer Jul 20, 2021
fdf45c2
some cleanup
cscherrer Jul 20, 2021
249b1d3
add inargs and inobs
cscherrer Jul 20, 2021
4632e6f
cleanup
cscherrer Jul 22, 2021
3fa0c62
ReTest
cscherrer Jul 22, 2021
a46eef2
Merge remote-tracking branch 'origin/master' into cs-astmodels
cscherrer Aug 13, 2021
fcca308
`using Soss` in tests
cscherrer Aug 16, 2021
b7e30f4
Merge branch 'master' into cs-astmodels
cscherrer Aug 26, 2021
25ba251
update dependencies
cscherrer Oct 22, 2021
84ba2aa
updating symbolics
cscherrer Oct 22, 2021
8c9b96c
Merge branch 'master' into dev
cscherrer Oct 22, 2021
3382325
some updates to symbolics
cscherrer Oct 26, 2021
020e236
Merge branch 'dev' into cs-astmodels
cscherrer Oct 26, 2021
5fdc1d6
toposort(::DAGModel)
cscherrer Oct 26, 2021
968fa66
cleanup
cscherrer Oct 27, 2021
29d725e
moar
cscherrer Oct 27, 2021
cb77c70
hmm example
cscherrer Nov 2, 2021
54a07b1
merge
cscherrer Nov 2, 2021
f48bc8e
Merge branch 'dev' of github.com:cscherrer/Soss.jl into dev
cscherrer Nov 2, 2021
d25d00a
update MeasureBase bound
cscherrer Nov 3, 2021
1cb1db0
better dispatch for `predict`
cscherrer Nov 3, 2021
466a5a4
drop redundant method
cscherrer Nov 3, 2021
408e00f
remove whitespace
cscherrer Nov 3, 2021
28753dc
bump version
cscherrer Nov 3, 2021
4261028
Merge remote-tracking branch 'origin/master' into cs-astmodels
cscherrer Nov 5, 2021
34be839
gotos
cscherrer Nov 23, 2021
314cfe9
Merge branch 'master' into dev
cscherrer Dec 29, 2021
38fdd05
update for upcoming MeasureTheory release
cscherrer Dec 29, 2021
95754ea
small change to toposort
cscherrer Dec 29, 2021
e7bd0d3
have `iid` use powermeasure
cscherrer Dec 29, 2021
7672e5a
tests passing!
cscherrer Jan 21, 2022
e12356d
iid(n::Integer...)
cscherrer Jan 21, 2022
bb49335
Merge branch 'dev' into astmodels-dev
cscherrer Jan 21, 2022
450998c
abstract model stuff
cscherrer Jan 21, 2022
57e85ce
fix some merging issues
cscherrer Jan 22, 2022
d363e50
update dependencies
cscherrer Jan 24, 2022
83d3713
start logdensity with partialstatic
cscherrer Jan 24, 2022
2da7f83
update test to account for partialstatic
cscherrer Jan 24, 2022
dbce8d1
add `insupport`
cscherrer Jan 24, 2022
3b1f36f
clean up insupport
cscherrer Jan 24, 2022
b391ec5
speed up model building
cscherrer Jan 24, 2022
bd04858
updates
cscherrer Jan 28, 2022
ba27c93
Merge remote-tracking branch 'origin/dev' into astmodels-dev
cscherrer Jan 28, 2022
1efa66e
work on codegen
cscherrer Jan 31, 2022
bc6b4ae
workin on it
cscherrer Jan 31, 2022
7dd2643
working on astmodels
cscherrer Feb 1, 2022
ac06391
another update
cscherrer Feb 1, 2022
7f80e99
edits
cscherrer Feb 3, 2022
6323998
bugfix
cscherrer Feb 3, 2022
e7ed8b1
update `logdensity`
cscherrer Feb 4, 2022
fe45f8e
updates
cscherrer Feb 9, 2022
863f2c4
strip things down for now
cscherrer Feb 9, 2022
4dc0bba
fix-JET-infer (#321)
thautwarm Feb 15, 2022
497da02
refactor TildeArgs
cscherrer Feb 16, 2022
58d860d
Working out the AST thing
cscherrer Feb 25, 2022
f84281d
think it's working
cscherrer Feb 25, 2022
2fe89bd
update
cscherrer Feb 25, 2022
6f841e3
merge
cscherrer Mar 29, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
working on astmodels
cscherrer committed Feb 1, 2022

Unverified

This user has not yet uploaded their public signing key.
commit 7dd2643d476083239b22cb04ae244ab02ec7e9ad
2 changes: 1 addition & 1 deletion src/core/models/abstractmodel.jl
Original file line number Diff line number Diff line change
@@ -39,7 +39,7 @@ obstype(::Type{<:AbstractModel}) = NamedTuple{(), Tuple{}}

(m::AbstractModel)(;argvals...)= m((;argvals...))

(m::AbstractModel)(args...) = m(NamedTuple{Tuple(m.args)}(args...))
(m::AbstractModel{A})(args...) where {A} = m(A(args))

body(::AbstractModel{A,B}) where {A,B} = from_type(B)
body(::Type{AM}) where {A,B,AM<:AbstractModel{A,B}} = from_type(B)
13 changes: 12 additions & 1 deletion src/core/utils.jl
Original file line number Diff line number Diff line change
@@ -359,6 +359,9 @@ function locally_bound(ex, optic)
setdiff(globals(isolated), globals(in_context))
end

"""
Given a JuliaVariables "solved" expression, convert back to a standard expression
"""
function unsolve(ex)
ex = unwrap_scoped(ex)
@match ex begin
@@ -368,6 +371,14 @@ function unsolve(ex)
end
end


"""
Return the set of local variable names from a *solved* expression (using JuliaVariables)
"""
function locals(ex)

Tuple(@match ex begin
v::JuliaVariables.Var => ifelse(v.is_global, Set{Symbol}(), Set((v.name,)))
Expr(head, args...) => union(map(locals, args)...)
x => Set{Symbol}()
end)
end
27 changes: 14 additions & 13 deletions src/primitives/interpret.jl
Original file line number Diff line number Diff line change
@@ -2,26 +2,25 @@ export interpret

function interpret(m::ASTModel{A,B,M}, tilde, ctx0) where {A,B,M}
theModule = getmodule(m)
mk_function(theModule, _interpret(m.body, tilde, ctx0))
mk_function(theModule, _interpret(theModule, m.body, tilde, ctx0))
end

# abstract type Maybe{T} end


function _interpret(ast::Expr, _tilde, _args, _obs)
function _interpret(M, ast::Expr, _tilde, _args, _obs)
function go(ex)
@match ex begin
:($x ~ $d) => begin
x = x.name
qx = QuoteNode(x)
xname = to_type(x)
measure = to_type(d)
inargs = static(x getntkeys(_args))
inobs = static(x getntkeys(_obs))
varnames = locals(d)
varvals = Expr(:tuple, varnames...)
quote
# _x_oldval = ifelse($(Expr(:isdefined, $qx)), Just($x), None())
_x_oldval = nothing
_targs = TildeArgs(_ctx, _cfg, _x_oldval, _vars, $inargs, $inobs)
($x, _ctx, _retn) = $_tilde($xname, $measure, _targs)
($x, _ctx, _retn) = let targs = Soss.TildeArgs($xname, $measure, NamedTuple{$varnames}($varvals), $inargs, $inobs)
$_tilde($qx, $d, _cfg, _ctx, targs)
end
end
end

@@ -31,10 +30,10 @@ function _interpret(ast::Expr, _tilde, _args, _obs)
end
end


body = go(@q let
$ast
end)

$(solve_scope(ast)).args[2]
end) |> unsolve

body
end
@@ -49,7 +48,7 @@ end
tilde = T.instance

body = _m.body |> loadvals(_args, _obs)
body = _interpret(body, tilde, _args, _obs)
body = _interpret(M, body, tilde, _args, _obs)

q = MacroTools.flatten(@q let M
@inline function(_cfg, _ctx)
@@ -61,4 +60,6 @@ end
_retn
end
end)

@under_global M q
end
16 changes: 9 additions & 7 deletions src/tildeargs.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
struct TildeArgs{Ctx,Cfg,Xold,Vars,inArgs,inObs}
context::Ctx # A value (often a NamedTuple) that typically evolves throughout inference
config::Cfg # A NamedTuple holding configuration parameters, e.g. RNG
x_oldval::Xold # The previous value for the LHS. Because it may not be defined, this must be a `Maybe`
vars::Vars # A named tuple of local variables used in the measure
inargs::inArgs # A StaticBool indicating whether the current LHS is in the arguments
inobs::inObs # A StaticBool indicating whether the current LHS is in the observations
struct TildeArgs{XName,M,Vars,inArgs,inObs}
x_name::XName # The name of LHS variable, represented at the type level.
measure::M # A type-level representation of the RHS expression
# x_oldval::Xold # The previous value for the LHS. Because it may not be defined, this must be a `Maybe`
# ctx::Ctx # A value (often a NamedTuple) that typically evolves throughout inference
# cfg::Cfg # A NamedTuple holding configuration parameters, e.g. RNG
vars::Vars # A named tuple of local variables used in the measure
inargs::inArgs # A StaticBool indicating whether the current LHS is in the arguments
inobs::inObs # A StaticBool indicating whether the current LHS is in the observations
end