Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit cd77d25

Browse files
authored
Api change (#19)
* unify interfaces * update screen after interact * add minor comment * add version check * update docker * revert to julia v1.2 due to https://github.com/JuliaLang/julia/pull/32408#issuecomment-522168938 * update README
1 parent 82a21f3 commit cd77d25

14 files changed

+146
-93
lines changed

.travis.yml

+1-3
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@ language: julia
33
os:
44
- linux
55
julia:
6-
- 1.0
7-
- 1.1
8-
- nightly
6+
- 1.2
97
notifications:
108
email: false
119

Dockerfile

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
FROM julia:1.1
1+
FROM julia:1.2
22

33
# install dependencies
44
RUN set -eux; \

README.md

+3-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ This package serves as a one-stop place for different kinds of reinforcement lea
77
Install:
88

99
```julia
10-
(v1.1) pkg> add https://github.com/JuliaReinforcementLearning/ReinforcementLearningEnvironments.jl
10+
pkg> add ReinforcementLearningEnvironments
1111
```
1212

1313
## API
@@ -64,11 +64,11 @@ Take the `AtariEnv` for example:
6464

6565
1. Install this package by:
6666
```julia
67-
(v1.1) pkg> add ReinforcementLearningEnvironments
67+
pkg> add ReinforcementLearningEnvironments
6868
```
6969
2. Install corresponding dependent package by:
7070
```julia
71-
(v1.1) pkg> add ArcadeLearningEnvironment
71+
pkg> add ArcadeLearningEnvironment
7272
```
7373
3. Using the above two packages:
7474
```julia

src/ReinforcementLearningEnvironments.jl

+3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
module ReinforcementLearningEnvironments
22

3+
export RLEnvs
4+
const RLEnvs = ReinforcementLearningEnvironments
5+
36
using Reexport, Requires
47

58
include("abstractenv.jl")

src/abstractenv.jl

+24-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
export AbstractEnv, observe, reset!, interact!, action_space, observation_space, render
1+
export AbstractEnv, observe, reset!, interact!, action_space, observation_space, render, Observation, get_reward, get_terminal, get_state, get_legal_actions
22

33
abstract type AbstractEnv end
44

@@ -7,4 +7,26 @@ function reset! end
77
function interact! end
88
function action_space end
99
function observation_space end
10-
function render end
10+
function render end
11+
12+
struct Observation{R, T, S, M<:NamedTuple}
13+
reward::R
14+
terminal::T
15+
state::S
16+
meta::M
17+
end
18+
19+
Observation(;reward, terminal, state, kw...) = Observation(reward, terminal, state, merge(NamedTuple(), kw))
20+
21+
get_reward(obs::Observation) = obs.reward
22+
get_terminal(obs::Observation) = obs.terminal
23+
get_state(obs::Observation) = obs.state
24+
get_legal_actions(obs::Observation) = obs.meta.legal_actions
25+
26+
# !!! >= julia v1.3
27+
if VERSION >= v"1.3.0-rc1.0"
28+
(env::AbstractEnv)(a) = interact!(env, a)
29+
end
30+
31+
action_space(env::AbstractEnv) = env.action_space
32+
observation_space(env::AbstractEnv) = env.observation_space

src/environments/atari.jl

+12-12
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,17 @@ using ArcadeLearningEnvironment, GR
22

33
export AtariEnv
44

5-
struct AtariEnv{To,F} <: AbstractEnv
5+
mutable struct AtariEnv{To,F} <: AbstractEnv
66
ale::Ptr{Nothing}
77
screen::Array{UInt8, 1}
88
getscreen!::F
9-
actions::Array{Int32, 1}
9+
actions::Array{Int64, 1}
1010
action_space::DiscreteSpace{Int}
1111
observation_space::To
1212
noopmax::Int
13+
reward::Float32
1314
end
1415

15-
action_space(env::AtariEnv) = env.action_space
16-
observation_space(env::AtariEnv) = env.observation_space
17-
1816
"""
1917
AtariEnv(name; colorspace = "Grayscale", frame_skip = 4, noopmax = 20,
2018
color_averaging = true, repeat_action_probability = 0.)
@@ -51,24 +49,26 @@ function AtariEnv(name;
5149
end
5250
actions = actionset == :minimal ? getMinimalActionSet(ale) : getLegalActionSet(ale)
5351
action_space = DiscreteSpace(length(actions))
54-
AtariEnv(ale, screen, getscreen!, actions, action_space, observation_space, noopmax)
52+
AtariEnv(ale, screen, getscreen!, actions, action_space, observation_space, noopmax, 0.0f0)
5553
end
5654

5755
function interact!(env::AtariEnv, a)
58-
r = act(env.ale, env.actions[a])
56+
env.reward = act(env.ale, env.actions[a])
5957
env.getscreen!(env.ale, env.screen)
60-
(observation=env.screen, reward=r, isdone=game_over(env.ale))
58+
nothing
6159
end
6260

63-
function observe(env::AtariEnv)
64-
env.getscreen!(env.ale, env.screen)
65-
(observation=env.screen, isdone=game_over(env.ale))
66-
end
61+
observe(env::AtariEnv) = Observation(
62+
reward = env.reward,
63+
terminal = game_over(env.ale),
64+
state = env.screen
65+
)
6766

6867
function reset!(env::AtariEnv)
6968
reset_game(env.ale)
7069
for _ in 1:rand(0:env.noopmax) act(env.ale, Int32(0)) end
7170
env.getscreen!(env.ale, env.screen)
71+
env.reward = 0.0f0 # dummy
7272
nothing
7373
end
7474

src/environments/classic_control/cart_pole.jl

+6-5
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,6 @@ function CartPoleEnv(; T = Float64, gravity = T(9.8), masscart = T(1.),
4242
cp
4343
end
4444

45-
action_space(env::CartPoleEnv) = env.action_space
46-
observation_space(env::CartPoleEnv) = env.observation_space
47-
4845
function reset!(env::CartPoleEnv{T}) where T <: Number
4946
env.state[:] = T(.1) * rand(env.rng, T, 4) .- T(.05)
5047
env.t = 0
@@ -53,7 +50,11 @@ function reset!(env::CartPoleEnv{T}) where T <: Number
5350
nothing
5451
end
5552

56-
observe(env::CartPoleEnv) = (observation=env.state, isdone=env.done)
53+
observe(env::CartPoleEnv) = Observation(
54+
reward = env.done ? 0.0 : 1.0,
55+
terminal = env.done,
56+
state = env.state
57+
)
5758

5859
function interact!(env::CartPoleEnv{T}, a) where T <: Number
5960
env.action = a
@@ -76,7 +77,7 @@ function interact!(env::CartPoleEnv{T}, a) where T <: Number
7677
env.done = abs(env.state[1]) > env.params.xthreshold ||
7778
abs(env.state[3]) > env.params.thetathreshold ||
7879
env.t >= env.params.max_steps
79-
(observation=env.state, reward=1., isdone=env.done)
80+
nothing
8081
end
8182

8283
function plotendofepisode(x, y, d)

src/environments/classic_control/mdp.jl

+42-37
Original file line numberDiff line numberDiff line change
@@ -9,46 +9,52 @@ export MDPEnv, POMDPEnv, SimpleMDPEnv, absorbing_deterministic_tree_MDP, stochas
99
##### POMDPEnv
1010
#####
1111

12-
mutable struct POMDPEnv{T,Ts,Ta, R<:AbstractRNG}
12+
mutable struct POMDPEnv{T,Ts,Ta, R<:AbstractRNG} <: AbstractEnv
1313
model::T
1414
state::Ts
1515
actions::Ta
1616
action_space::DiscreteSpace
1717
observation_space::DiscreteSpace
18+
observation::Int
19+
reward::Float64
1820
rng::R
1921
end
2022

21-
POMDPEnv(model; rng=Random.GLOBAL_RNG) = POMDPEnv(
22-
model,
23-
initialstate(model, rng),
24-
actions(model),
25-
DiscreteSpace(n_actions(model)),
26-
DiscreteSpace(n_states(model)),
27-
rng)
23+
function POMDPEnv(model; rng=Random.GLOBAL_RNG)
24+
state = initialstate(model, rng)
25+
as = DiscreteSpace(n_actions(model))
26+
os = DiscreteSpace(n_states(model))
27+
actions_of_model = actions(model)
28+
s, o, r = generate_sor(model, state, actions_of_model[rand(as)], rng)
29+
obs = observationindex(model, o)
30+
POMDPEnv(model, state, actions_of_model, as, os, obs, 0., rng)
31+
end
2832

2933
function interact!(env::POMDPEnv, action)
3034
s, o, r = generate_sor(env.model, env.state, env.actions[action], env.rng)
3135
env.state = s
32-
(observation = observationindex(env.model, o),
33-
reward = r,
34-
isdone = isterminal(env.model, s))
36+
env.reward = r
37+
env.observation = observationindex(env.model, o)
38+
nothing
3539
end
3640

37-
function observe(env::POMDPEnv)
38-
(observation = observationindex(env.model, generate_o(env.model, env.state, env.rng)),
39-
isdone = isterminal(env.model, env.state))
40-
end
41+
observe(env::POMDPEnv) = Observation(
42+
reward = env.reward,
43+
terminal = isterminal(env.model, env.state),
44+
state = env.observation
45+
)
4146

4247
#####
4348
##### MDPEnv
4449
#####
4550

46-
mutable struct MDPEnv{T, Ts, Ta, R<:AbstractRNG}
51+
mutable struct MDPEnv{T, Ts, Ta, R<:AbstractRNG} <: AbstractEnv
4752
model::T
4853
state::Ts
4954
actions::Ta
5055
action_space::DiscreteSpace
5156
observation_space::DiscreteSpace
57+
reward::Float64
5258
rng::R
5359
end
5460

@@ -58,10 +64,9 @@ MDPEnv(model; rng=Random.GLOBAL_RNG) = MDPEnv(
5864
actions(model),
5965
DiscreteSpace(n_actions(model)),
6066
DiscreteSpace(n_states(model)),
61-
rng)
62-
63-
action_space(env::Union{MDPEnv, POMDPEnv}) = env.action_space
64-
observation_space(env::Union{MDPEnv, POMDPEnv}) = env.observation_space
67+
0.,
68+
rng
69+
)
6570

6671
observationindex(env, o) = Int(o) + 1
6772

@@ -74,15 +79,15 @@ function interact!(env::MDPEnv, action)
7479
s = rand(env.rng, transition(env.model, env.state, env.actions[action]))
7580
r = POMDPs.reward(env.model, env.state, env.actions[action])
7681
env.state = s
77-
(observation = stateindex(env.model, s),
78-
reward = r,
79-
isdone = isterminal(env.model, s))
82+
env.reward = r
83+
nothing
8084
end
8185

82-
function observe(env::MDPEnv)
83-
(observation = stateindex(env.model, env.state),
84-
isdone = isterminal(env.model, env.state))
85-
end
86+
observe(env::MDPEnv) = Observation(
87+
reward = env.reward,
88+
terminal = isterminal(env.model, env.state),
89+
state = stateindex(env.model, env.state)
90+
)
8691

8792
#####
8893
##### SimpleMDPEnv
@@ -107,14 +112,15 @@ probabilities) `reward` of type `R` (see [`DeterministicStateActionReward`](@ref
107112
[`NormalStateActionReward`](@ref)), array of initial states
108113
`initialstates`, and `ns` - array of 0/1 indicating if a state is terminal.
109114
"""
110-
mutable struct SimpleMDPEnv{T,R,S<:AbstractRNG}
115+
mutable struct SimpleMDPEnv{T,R,S<:AbstractRNG} <: AbstractEnv
111116
observation_space::DiscreteSpace
112117
action_space::DiscreteSpace
113118
state::Int
114119
trans_probs::Array{T, 2}
115120
reward::R
116121
initialstates::Array{Int, 1}
117122
isterminal::Array{Int, 1}
123+
score::Float64
118124
rng::S
119125
end
120126

@@ -125,12 +131,9 @@ function SimpleMDPEnv(ospace, aspace, state, trans_probs::Array{T, 2},
125131
reward = DeterministicStateActionReward(reward)
126132
end
127133
SimpleMDPEnv{T,typeof(reward),S}(ospace, aspace, state, trans_probs,
128-
reward, initialstates, isterminal, rng)
134+
reward, initialstates, isterminal, 0., rng)
129135
end
130136

131-
observation_space(env::SimpleMDPEnv) = env.observation_space
132-
action_space(env::SimpleMDPEnv) = env.action_space
133-
134137
# reward types
135138
"""
136139
struct DeterministicNextStateReward
@@ -208,13 +211,15 @@ run!(mdp::SimpleMDPEnv, policy::Array{Int, 1}) = run!(mdp, policy[mdp.state])
208211
function interact!(env::SimpleMDPEnv, action)
209212
oldstate = env.state
210213
run!(env, action)
211-
r = reward(env.rng, env.reward, oldstate, action, env.state)
212-
(observation = env.state, reward = r, isdone = env.isterminal[env.state] == 1)
214+
env.score = reward(env.rng, env.reward, oldstate, action, env.state)
215+
nothing
213216
end
214217

215-
function observe(env::SimpleMDPEnv)
216-
(observation = env.state, isdone = env.isterminal[env.state] == 1)
217-
end
218+
observe(env::SimpleMDPEnv) = Observation(
219+
reward = env.score,
220+
terminal = env.isterminal[env.state] == 1,
221+
state = env.state
222+
)
218223

219224
function reset!(env::SimpleMDPEnv)
220225
env.state = rand(env.rng, env.initialstates)

src/environments/classic_control/mountain_car.jl

+7-4
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,14 @@ function MountainCarEnv(; T = Float64, continuous = false,
5050
reset!(env)
5151
env
5252
end
53+
5354
ContinuousMountainCarEnv(; kwargs...) = MountainCarEnv(; continuous = true, kwargs...)
5455

55-
action_space(env::MountainCarEnv) = env.action_space
56-
observation_space(env::MountainCarEnv) = env.observation_space
57-
observe(env::MountainCarEnv) = (observation=env.state, isdone=env.done)
56+
observe(env::MountainCarEnv) = Observation(
57+
reward = env.done ? 0. : -1.,
58+
terminal = env.done,
59+
state = env.state
60+
)
5861

5962
function reset!(env::MountainCarEnv{A, T}) where {A, T}
6063
env.state[1] = .2 * rand(env.rng, T) - .6
@@ -78,7 +81,7 @@ function _interact!(env::MountainCarEnv, force)
7881
env.t >= env.params.max_steps
7982
env.state[1] = x
8083
env.state[2] = v
81-
(observation=env.state, reward=-1., isdone=env.done)
84+
nothing
8285
end
8386

8487
# adapted from https://github.com/JuliaML/Reinforce.jl/blob/master/src/envs/mountain_car.jl

0 commit comments

Comments
 (0)