import numpy as np
from collections import deque
import gym
from gym import spaces
import cv2

cv2.ocl.setUseOpenCL(False)


class NoopResetEnv(gym.Wrapper):
    def __init__(self, env, noop_max=30):
        """Sample initial states by taking a random number of no-ops on reset.
        No-op is assumed to be action 0.
        """
        gym.Wrapper.__init__(self, env)
        self.noop_max = noop_max
        self.override_num_noops = None
        self.noop_action = 0
        assert env.unwrapped.get_action_meanings()[0] == 'NOOP'

    def reset(self, **kwargs):
        """Do no-op action for a number of steps in [1, noop_max]."""
        self.env.reset(**kwargs)
        if self.override_num_noops is not None:
            noops = self.override_num_noops
        else:
            noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)  # pylint: disable=E1101
        assert noops > 0
        obs = None
        for _ in range(noops):
            obs, _, done, _ = self.env.step(self.noop_action)
            if done:
                obs = self.env.reset(**kwargs)
        return obs

    def step(self, ac):
        return self.env.step(ac)


class FireResetEnv(gym.Wrapper):
    def __init__(self, env):
        """Take action on reset for environments that are fixed until firing."""
        gym.Wrapper.__init__(self, env)
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def reset(self, **kwargs):
        self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(1)
        if done:
            self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(2)
        if done:
            self.env.reset(**kwargs)
        return obs

    def step(self, ac):
        return self.env.step(ac)


class EpisodicLifeEnv(gym.Wrapper):
    def __init__(self, env):
        """Make end-of-life == end-of-episode, but only reset on true game over.
        Done by DeepMind for the DQN and co. since it helps value estimation.
        """
        gym.Wrapper.__init__(self, env)
        self.lives = 0
        self.was_real_done = True

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self.was_real_done = done
        # check current lives, make loss of life terminal,
        # then update lives to handle bonus lives
        lives = self.env.unwrapped.ale.lives()
        if lives < self.lives and lives > 0:
            # for Qbert sometimes we stay in the lives == 0 condition for a few frames,
            # so it's important to keep lives > 0, so that we only reset once
            # the environment advertises done.
            done = True
        self.lives = lives
        return obs, reward, done, info

    def reset(self, **kwargs):
        """Reset only when lives are exhausted.
        This way all states are still reachable even though lives are episodic,
        and the learner need not know about any of this behind-the-scenes.
        """
        if self.was_real_done:
            obs = self.env.reset(**kwargs)
        else:
            # no-op step to advance from terminal/lost life state
            obs, _, _, _ = self.env.step(0)
        self.lives = self.env.unwrapped.ale.lives()
        return obs


class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env, skip=4):
        """Return only every `skip`-th frame."""
        gym.Wrapper.__init__(self, env)
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=np.uint8)
        self._skip = skip

    def step(self, action):
        """Repeat action, sum reward, and max over last observations."""
        total_reward = 0.0
        done = None
        for i in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            if i == self._skip - 2:
                self._obs_buffer[0] = obs
            if i == self._skip - 1:
                self._obs_buffer[1] = obs
            total_reward += reward
            if done:
                break
        # Note that the observation on the done=True frame doesn't matter
        max_frame = self._obs_buffer.max(axis=0)

        return max_frame, total_reward, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)


class ClipRewardEnv(gym.RewardWrapper):
    def __init__(self, env):
        gym.RewardWrapper.__init__(self, env)

    def reward(self, reward):
        """Bin reward to {+1, 0, -1} by its sign."""
        return np.sign(reward)


class WarpFrame(gym.ObservationWrapper):
    def __init__(self, env):
        """Warp frames to 84x84 as done in the Nature paper and later work."""
        gym.ObservationWrapper.__init__(self, env)
        self.width = 84
        self.height = 84
        self.observation_space = spaces.Box(low=0, high=255,
                                            shape=(self.height, self.width, 1), dtype=np.uint8)

    def observation(self, frame):
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
        return frame[:, :, None]


class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        """Stack k last frames.
        Returns lazy array, which is much more memory efficient.
        See Also
        --------
        baselines.common.atari_wrappers.LazyFrames
        """
        gym.Wrapper.__init__(self, env)
        self.k = k
        self.frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=np.uint8)

    def reset(self):
        ob = self.env.reset()
        for _ in range(self.k):
            self.frames.append(ob)
        return self._get_ob()

    def step(self, action):
        ob, reward, done, info = self.env.step(action)
        self.frames.append(ob)
        return self._get_ob(), reward, done, info

    def _get_ob(self):
        assert len(self.frames) == self.k
        return LazyFrames(list(self.frames))


class ScaledFloatFrame(gym.ObservationWrapper):
    def __init__(self, env):
        gym.ObservationWrapper.__init__(self, env)

    def observation(self, observation):
        # careful! This undoes the memory optimization, use
        # with smaller replay buffers only.
        return np.array(observation).astype(np.float32) / 255.0


class LazyFrames(object):
    def __init__(self, frames):
        """This object ensures that common frames between the observations are only stored once.
        It exists purely to optimize memory usage, which can be huge for DQN's 1M-frame replay
        buffers.
        This object should only be converted to a numpy array before being passed to the model.
        You'd not believe how complex the previous solution was."""
        self._frames = frames
        self._out = None

    def _force(self):
        if self._out is None:
            self._out = np.concatenate(self._frames, axis=2)
            self._frames = None
        return self._out

    def __array__(self, dtype=None):
        out = self._force()
        if dtype is not None:
            out = out.astype(dtype)
        return out

    def __len__(self):
        return len(self._force())

    def __getitem__(self, i):
        return self._force()[i]


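# Sketch (not part of the original wrappers): the LazyFrames docstring above
# says observations should only be converted to a numpy array right before
# being passed to the model, e.g. when sampling a minibatch from a replay
# buffer. The hypothetical helper below illustrates that pattern.
def _lazyframes_to_batch(observations, dtype=np.float32):
    # np.asarray triggers LazyFrames.__array__, concatenating the k stacked
    # frames into a single (H, W, C * k) array only at this point.
    return np.stack([np.asarray(obs, dtype=dtype) for obs in observations])

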
def make_atari(env_id):
    env = gym.make(env_id)
    assert 'NoFrameskip' in env.spec.id
    env = NoopResetEnv(env, noop_max=30)
    env = MaxAndSkipEnv(env, skip=4)
    return env


def wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=False):
    """Configure environment for DeepMind-style Atari."""
    if episode_life:
        env = EpisodicLifeEnv(env)
    if 'FIRE' in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)
    env = WarpFrame(env)
    if scale:
        env = ScaledFloatFrame(env)
    if clip_rewards:
        env = ClipRewardEnv(env)
    if frame_stack:
        env = FrameStack(env, 4)
    return env


class ImageToPyTorch(gym.ObservationWrapper):
    """
    Change image shape to num_channels x height x width (CHW), as expected by PyTorch.
    """

    def __init__(self, env):
        super(ImageToPyTorch, self).__init__(env)
        old_shape = self.observation_space.shape
        self.observation_space = gym.spaces.Box(low=0.0, high=1.0,
                                                shape=(old_shape[-1], old_shape[0], old_shape[1]),
                                                dtype=np.uint8)

    def observation(self, observation):
        # move the channel axis to the front: HWC -> CHW
        return np.moveaxis(observation, 2, 0)


def wrap_pytorch(env):
    return ImageToPyTorch(env)


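# Minimal usage sketch (not part of the original file): build a DeepMind-style
# Atari environment and convert observations for PyTorch. The environment id
# 'PongNoFrameskip-v4' is an assumption; any NoFrameskip Atari id with
# installed ROMs should work.
if __name__ == "__main__":
    env = make_atari('PongNoFrameskip-v4')
    env = wrap_deepmind(env, frame_stack=True)
    env = wrap_pytorch(env)

    obs = env.reset()
    # After wrap_pytorch the stacked LazyFrames observation has already been
    # converted to a numpy array of (num_channels, height, width) = (4, 84, 84).
    print(np.asarray(obs).shape)

    obs, reward, done, info = env.step(env.action_space.sample())
    print(reward, done)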