@@ -74,15 +74,15 @@ def test_networkbody_lstm():
74
74
def test_networkbody_visual ():
75
75
torch .manual_seed (1 )
76
76
vec_obs_size = 4
77
- obs_size = (84 , 84 , 3 )
77
+ obs_size = (3 , 84 , 84 )
78
78
network_settings = NetworkSettings ()
79
79
obs_shapes = [(vec_obs_size ,), obs_size ]
80
80
81
81
networkbody = NetworkBody (
82
82
create_observation_specs_with_shapes (obs_shapes ), network_settings
83
83
)
84
84
optimizer = torch .optim .Adam (networkbody .parameters (), lr = 3e-3 )
85
- sample_obs = 0.1 * torch .ones ((1 , 84 , 84 , 3 ), dtype = torch .float32 )
85
+ sample_obs = 0.1 * torch .ones ((1 , 3 , 84 , 84 ), dtype = torch .float32 )
86
86
sample_vec_obs = torch .ones ((1 , vec_obs_size ), dtype = torch .float32 )
87
87
obs = [sample_vec_obs ] + [sample_obs ]
88
88
loss = 1
@@ -200,7 +200,7 @@ def test_multinetworkbody_visual(with_actions):
200
200
act_size = 2
201
201
n_agents = 3
202
202
obs_size = 4
203
- vis_obs_size = (84 , 84 , 3 )
203
+ vis_obs_size = (3 , 84 , 84 )
204
204
network_settings = NetworkSettings ()
205
205
obs_shapes = [(obs_size ,), vis_obs_size ]
206
206
action_spec = ActionSpec (act_size , tuple (act_size for _ in range (act_size )))
@@ -209,7 +209,7 @@ def test_multinetworkbody_visual(with_actions):
209
209
)
210
210
optimizer = torch .optim .Adam (networkbody .parameters (), lr = 3e-3 )
211
211
sample_obs = [
212
- [0.1 * torch .ones ((1 , obs_size ))] + [0.1 * torch .ones ((1 , 84 , 84 , 3 ))]
212
+ [0.1 * torch .ones ((1 , obs_size ))] + [0.1 * torch .ones ((1 , 3 , 84 , 84 ))]
213
213
for _ in range (n_agents )
214
214
]
215
215
# simulate baseline in POCA
@@ -273,7 +273,7 @@ def test_valuenetwork():
273
273
@pytest .mark .parametrize ("lstm" , [True , False ])
274
274
def test_actor_critic (lstm , shared ):
275
275
obs_size = 4
276
- vis_obs_size = (84 , 84 , 3 )
276
+ vis_obs_size = (3 , 84 , 84 )
277
277
network_settings = NetworkSettings (
278
278
memory = NetworkSettings .MemorySettings () if lstm else None , normalize = True
279
279
)
@@ -291,14 +291,14 @@ def test_actor_critic(lstm, shared):
291
291
critic = ValueNetwork (stream_names , obs_spec , network_settings )
292
292
if lstm :
293
293
sample_vis_obs = torch .ones (
294
- (network_settings .memory .sequence_length , 84 , 84 , 3 ), dtype = torch .float32
294
+ (network_settings .memory .sequence_length , 3 , 84 , 84 ), dtype = torch .float32
295
295
)
296
296
sample_obs = torch .ones ((network_settings .memory .sequence_length , obs_size ))
297
297
memories = torch .ones (
298
298
(1 , network_settings .memory .sequence_length , actor .memory_size )
299
299
)
300
300
else :
301
- sample_vis_obs = 0.1 * torch .ones ((1 , 84 , 84 , 3 ), dtype = torch .float32 )
301
+ sample_vis_obs = 0.1 * torch .ones ((1 , 3 , 84 , 84 ), dtype = torch .float32 )
302
302
sample_obs = torch .ones ((1 , obs_size ))
303
303
memories = torch .tensor ([])
304
304
# memories isn't always set to None, the network should be able to
0 commit comments