-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathrepresentation_minex.jl
197 lines (165 loc) · 7.2 KB
/
representation_minex.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
using Revise
using BetaZero
using Plots; default(fontfamily="Computer Modern", framestyle=:box)
using ParticleFilters
using POMDPs
using POMDPTools
using LinearAlgebra
using ParticleBeliefs
using StatsBase
using MinEx
# Default MinEx problem instance for use in this script.
pomdp = MinExPOMDP()
# Belief updater: a `BootstrapFilter` with `pomdp.n_particles` particles, wrapped in a
# `ParticleHistoryBeliefUpdater` (presumably yielding the `ParticleHistoryBelief`s that
# the functions below accept — confirm).
up = ParticleHistoryBeliefUpdater(BootstrapFilter(pomdp, pomdp.n_particles))
# Episode accuracy for MinEx: `true` when the last action taken equals the oracle decision
# for the true initial state `s0` — `:mine` if `MinEx.calc_massive(pomdp, s0)` exceeds
# `pomdp.extraction_cost`, otherwise `:abandon`. `b0`, `states` and `returns` are unused.
function BetaZero.accuracy(pomdp::MinExPOMDP, b0, s0, states, actions, returns)
    oracle_decision = if MinEx.calc_massive(pomdp, s0) > pomdp.extraction_cost
        :mine
    else
        :abandon
    end
    final_action = actions[end]
    return final_action == oracle_decision
end
# Belief-level reward: the average of `reward(pomdp, s, a)` over the states returned by
# `MinEx.particles(b.particles)`. The next belief `bp` is not used.
function simple_minex_belief_reward(pomdp::POMDP, b, a, bp)
    # Lazy generator (not an array) so the mean accumulates in iteration order.
    per_state_rewards = (reward(pomdp, s, a) for s in MinEx.particles(b.particles))
    return mean(per_state_rewards)
end
# Oracle return for a known state `s`: the larger of 0 and `extraction_reward(pomdp, s)`
# (i.e. the better of abandoning, taken here as worth 0, and extracting).
function BetaZero.optimal_return(pomdp::MinExPOMDP, s)
    payoff = extraction_reward(pomdp, s)
    return max(0, payoff)
end
# Oracle return statistics starting from the POMDP's own initial-state distribution.
# If `initialstate(pomdp)` is a `Vector`, a cached particle set is instead loaded from
# `generated_states.h5` (with a warning); otherwise the distribution is passed on as-is.
# All keywords are forwarded to the method that handles the resulting states.
function compute_optimal_return_minex(pomdp::MinExPOMDP; kwargs...)
    initial_dist = initialstate(pomdp)
    if !(initial_dist isa Vector)
        return compute_optimal_return_minex(pomdp, initial_dist; kwargs...)
    end
    # discrete (cached) particle set (Note, avoid using this)
    @warn "Computing using discrete, generated particle set"
    h5_path = joinpath(@__DIR__, "..", "submodules", "MinEx", "src", "generated_states.h5")
    ore_matrix = load_states(h5_path)
    cached_states = [MinExState(ore_matrix[:, :, k]) for k in axes(ore_matrix, 3)]
    return compute_optimal_return_minex(pomdp, cached_states; kwargs...)
end
# Draw `n` states from the initial-state distribution `ds0` and compute the oracle return
# statistics on that sample; the remaining keywords are forwarded.
function compute_optimal_return_minex(pomdp, ds0::MinEx.MinExStateDistribution; n=10_000, kwargs...)
    # Named so it does not shadow the `particles` function used elsewhere in this file.
    sampled_states = rand(ds0, n)
    return compute_optimal_return_minex(pomdp, sampled_states; kwargs...)
end
# Oracle return statistics over a collection of states `ds0`.
#
# Each state contributes `BetaZero.optimal_return(pomdp, s)`. With `include_drills=true`,
# the discounted cost of drilling every location first is added to each return. That cost
# is exactly `compute_lowerbound_return_minex(pomdp)`, so it is reused rather than
# re-derived here.
# Returns `mean_and_stderr` of the per-state returns.
function compute_optimal_return_minex(pomdp, ds0; include_drills=false)
    discounted_full_drill_cost = if include_drills
        # compute optimal returns if fully drilled all locations, then made the oracle mine/abandon decision
        compute_lowerbound_return_minex(pomdp)
    else
        # make oracle decision without drilling (float zero keeps both branches the same type)
        0.0
    end
    optimal_r = [discounted_full_drill_cost + BetaZero.optimal_return(pomdp, s) for s in ds0]
    return mean_and_stderr(optimal_r)
end
# Lower-bound return: drill every location, then abandon. The value is the discounted sum
# of `-pomdp.drill_cost` over the drill steps.
# The drill-step count is `length(actions(pomdp)) - 2`, presumably every action except the
# two terminal decisions (`:mine`/`:abandon`) — confirm.
function compute_lowerbound_return_minex(pomdp)
    γ = discount(pomdp)
    n_drills = length(actions(pomdp)) - 2
    # `init` makes the result zero (instead of an ArgumentError from reducing an empty
    # collection) when there are no drill actions; for n_drills ≥ 1 the value is unchanged.
    return sum((γ^(t-1) * -pomdp.drill_cost for t in 1:n_drills); init=zero(γ * pomdp.drill_cost))
end
# Array form of a MinEx particle belief: `BetaZero.input_representation(b)` converted to
# Float32, with a singleton 4th dimension added (presumably the batch axis — confirm).
# NOTE(review): `Flux` is not in this file's `using` lines; this relies on it being
# reachable here (e.g. re-exported by BetaZero) — confirm.
function POMDPs.convert_s(::Type{A}, b::ParticleHistoryBelief{MinExState}, m::BetaZero.BeliefMDP) where A<:AbstractArray
    features = Float32.(BetaZero.input_representation(b))
    return Flux.unsqueeze(features; dims=4)
end
# POMDPs.convert_s(::Type{ParticleHistoryBelief{MinExState}}, b::A, m::BetaZero.BeliefMDP) where A<:AbstractArray = ParticleHistoryBelief(particles=ParticleCollection(rand(LDNormalStateDist(b[1], b[2]), 500)))
# Replace NaN by a zero of the same type; any other value is returned unchanged.
# `zero(x)` (rather than the literal `0`) keeps the return type equal to `typeof(x)`.
# Comprehensions over this function (e.g. `data_skewness`) therefore get a concrete
# element type instead of widening to `Real`.
zeroifnan(x) = isnan(x) ? zero(x) : x
# Per-cell skewness of `D` along its third axis, using every slice except the last.
# NOTE(review): the reason the final slice is excluded is not visible here — confirm.
# NaN results (e.g. a constant cell) are mapped to zero via `zeroifnan`.
function data_skewness(D)
    cell_skewness(x, y) = zeroifnan(skewness(D[x, y, 1:end-1]))
    return [cell_skewness(x, y) for x in axes(D, 1), y in axes(D, 2)]
end
# Per-cell kurtosis of `D` along its third axis, using every slice except the last.
# NOTE(review): the reason the final slice is excluded is not visible here — confirm.
# NaN results (e.g. a constant cell) are mapped to zero via `zeroifnan`.
function data_kurtosis(D)
    cell_kurtosis(x, y) = zeroifnan(kurtosis(D[x, y, 1:end-1]))
    return [cell_kurtosis(x, y) for x in axes(D, 1), y in axes(D, 2)]
end
# Summarize a MinEx particle belief as a stack of per-cell statistics of the particles'
# ore grids: channel 1 is the mean and channel 2 is the standard deviation, both from
# `mean_and_std` along the particle axis. They are concatenated along dimension 3.
# The ore grids are copied into a Float32 array with one slice per particle along the
# 3rd axis; the `[:, :, k]` fill assumes a 2-D `ore` grid.
# NOTE(review): `use_higher_orders` and `include_obs` are accepted but currently have no
# effect on the result — confirm whether they should add channels.
function BetaZero.input_representation(b::ParticleHistoryBelief{MinExState};
                                       use_higher_orders::Bool=false, include_obs::Bool=false)
    belief_states::Vector{MinExState} = particles(b)
    grid_dims::Tuple = size(first(belief_states).ore)
    ore_stack = Array{Float32}(undef, grid_dims..., length(belief_states))
    for (k, st) in enumerate(belief_states)
        ore_stack[:, :, k] = st.ore
    end
    mean_map, std_map = mean_and_std(ore_stack, 3)
    return cat(mean_map, std_map; dims=3)
end
## Plotting
# Two-panel plot of a MinEx particle belief: the per-cell mean (left) and standard
# deviation (right), i.e. channels 1 and 2 of `BetaZero.input_representation(b)`.
#
# Arguments:
# - `b`: the particle belief to summarize.
# - `s`: optional state. When given and `s.drill_locations` is non-empty, the drill
#   locations are drawn as markers on both panels. On the mean panel they are also joined
#   by arrows in drilling order.
# Keywords:
# - `a`: optional action; if it is a `Symbol`, its name is annotated at (16, 30) on both panels.
# - `cmap`: colormap for both heatmaps.
# - `linecmap`: colormap sampled for the arrow colors along the drill path.
# - `scale`: both channels are raised to the power `1/scale` before plotting.
# - `applyclims`: if true, fixes color limits to (0, 1) for the mean and (0, 0.2) for the std.
# Returns the combined plot (`layout=2`).
function plot_belief(b::ParticleHistoryBelief{MinExState}, s=nothing; a=nothing, cmap=cgrad(["#44342a", "#e1e697"]), linecmap=cgrad([:white, :white]), scale=1, applyclims=false)
    # cmap = cgrad(:turbid, rev=true)
    # linecmap = cgrad([:white, :black])
    # Annotate the current subplot with the action name, if a Symbol action was given.
    function show_decision()
        if !isnothing(a) && a isa Symbol
            annotate!(16, 30, ("$a", 16, :white, :center))
        end
    end
    b̃ = BetaZero.input_representation(b)
    μ = b̃[:,:,1] .^ (1/scale)
    σ = b̃[:,:,2] .^ (1/scale)

    # ---- Mean panel ----
    plt_mean = heatmap(μ, ratio=1, c=cmap, title="\$\\mu(b)\$")
    if applyclims
        plot!(clims=(0,1))
    end
    # Match the x-range to the y-range, and remember the y-limits so they can be restored
    # after the drill overlays (which may extend the axes) are drawn.
    current_ylims = ylims()
    xlims!(current_ylims...)
    drill_style = (label=false, c=:black, mc=:darkred, msc=:white, marker=:square, ms=4)
    if !isnothing(s) && !isempty(s.drill_locations)
        xloc = map(last, s.drill_locations) # Note y-first, x-last
        yloc = map(first, s.drill_locations)
        n = length(xloc)
        if n > 1
            # One arrow per consecutive pair of drill locations.
            for i in 1:n-1
                x = xloc[i:i+1]
                y = yloc[i:i+1]
                # Skip a segment whose endpoint has a -1 coordinate; presumably -1 marks the
                # final (non-drill) decision entry rather than a real location — confirm.
                if x[end] != -1 && y[end] != -1
                    # Arrow color runs along `linecmap` from 0 (first segment) to 1 (last);
                    # with a single segment the start color is used.
                    c = n == 2 ? get(linecmap, 0) : get(linecmap, (i-1)/(n-2))
                    plot!(x, y, arrow=:closed, lw=2, color=c, label=false)
                end
            end
        end
        scatter!(xloc, yloc; drill_style...)
    end
    ylims!(current_ylims...)
    show_decision()

    # ---- Standard-deviation panel (markers only, no arrows) ----
    plt_std = heatmap(σ, ratio=1, c=cmap, title="\$\\sigma(b)\$")
    if applyclims
        plot!(clims=(0.0, 0.2))
    end
    current_ylims = ylims()
    xlims!(current_ylims...)
    if !isnothing(s) && !isempty(s.drill_locations)
        xloc = map(last, s.drill_locations) # Note y-first, x-last
        yloc = map(first, s.drill_locations)
        scatter!(xloc, yloc; drill_style...)
    end
    ylims!(current_ylims...)
    show_decision()
    return plot(plt_mean, plt_std, layout=2, margin=4Plots.mm, size=(1000, 400))
end
# Heatmap of a single state's ore grid (color limits fixed to (0, 1)). The x-limits are
# then set equal to the y-limits; the result of that `xlims!` call is returned.
function plot_state(s::MinExState; cmap=cgrad(["#44342a", "#e1e697"]))
    heatmap(s.ore; ratio=1, c=cmap, title="state", clims=(0, 1), margin=4Plots.mm, size=(500, 400))
    y_lo, y_hi = ylims()
    return xlims!(y_lo, y_hi)
end
# Render `plot_belief` for every belief in `beliefs` and save each figure to `filename(i)`.
# If `states` is given, `states[i]` is overlaid on belief `i`. With `betterfig=true` the
# figure is written with `bettersavefig` instead of `savefig`.
function plot_trajectory(beliefs::Vector, states::Union{Vector,Nothing}=nothing; filename::Function=i->"belief$i.png", betterfig::Bool=false)
    for i in eachindex(beliefs)
        @info "Plotting belief $i/$(length(beliefs))"
        # `plot_belief`'s state argument defaults to `nothing`, so passing it explicitly
        # is the same as omitting it.
        overlay_state = isnothing(states) ? nothing : states[i]
        plot_belief(beliefs[i], overlay_state)
        save_fn = betterfig ? bettersavefig : savefig
        save_fn(filename(i))
    end
    return nothing
end
# Histogram (bin probabilities) of the economic volumes in `volume`, drawn over the bin
# edges `bins`. StatsBase ignores values that fall outside those edges, while the mean
# and std are computed over all values.
# Vertical lines mark the mean, mean ± one standard deviation, and `true_volume` if given.
# Returns the final plot.
function plot_volume(volume, true_volume=nothing; bins=[-200:10:200;])
    vol_mean, vol_std = mean_and_std(volume)
    hist = normalize(fit(Histogram, volume, bins), mode=:probability)
    rd = x -> round(x, digits=2)
    plot(hist, title="belief volumes (μ=$(rd(vol_mean)), σ=$(rd(vol_std)))", label="economic volume", c=:cadetblue)
    ylims!(0, maximum(hist.weights) * 1.05)
    isnothing(true_volume) || vline!([true_volume], c=:black, ls=:dash, lw=2, label="true volume")
    vline!([vol_mean], c=:crimson, lw=2, label="mean volume")
    # Only the lower ±σ line carries a legend entry.
    for (edge, legend_label) in ((vol_mean - vol_std, "standard deviation"), (vol_mean + vol_std, false))
        vline!([edge], c=:crimson, lw=2, alpha=0.5, ls=:dot, label=legend_label)
    end
    return plot!(size=(600,350), margin=10Plots.mm, top_margin=5Plots.mm, xlabel="economic volume", ylabel="probability")
end
# Call `plot_volume` for each entry of `volumes` (all sharing the same `true_volume`) and
# save each figure to `filename(i)`. With `betterfig=true`, `bettersavefig` is used
# instead of `savefig`.
function plot_volumes(volumes::Vector, true_volume=nothing; filename::Function=i->"volume$i.png", betterfig::Bool=false)
    for i in eachindex(volumes)
        @info "Plotting volume $i/$(length(volumes))"
        plot_volume(volumes[i], true_volume)
        save_fn = betterfig ? bettersavefig : savefig
        save_fn(filename(i))
    end
    return nothing
end