
Commit c62e925

Add custom objects support + bug fix (DLR-RM#336)

* Add support for custom objects
* Add python 3.8 to the CI
* Bump version
* PyType fixes
* [ci skip] Fix typo
* Add note about slow-down + fix typos
* Minor edits to the doc
* Bug fix for DQN
* Update test
* Add test for custom objects

1 parent f13de5b, commit c62e925

27 files changed: +118 -60 lines changed

.github/workflows/ci.yml (+1 -1)

@@ -16,7 +16,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: [3.6, 3.7] # 3.8 not supported yet by pytype
+        python-version: [3.6, 3.7, 3.8]

     steps:
       - uses: actions/checkout@v2

Dockerfile (+1 -1)

@@ -14,7 +14,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
     libglib2.0-0 && \
     rm -rf /var/lib/apt/lists/*

-# Install anaconda abd dependencies
+# Install Anaconda and dependencies
 RUN curl -o ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \
     chmod +x ~/miniconda.sh && \
     ~/miniconda.sh -b -p /opt/conda && \

docs/guide/migration.rst (+7)

@@ -33,6 +33,13 @@ You can also take a look at the `rl-zoo3 <https://github.com/DLR-RM/rl-baselines
 to the `rl-zoo <https://github.com/araffin/rl-baselines-zoo>`_ of SB2 to have a concrete example of successful migration.


+.. note::
+
+  If you experience massive slow-down switching to PyTorch, you may need to play with the number of threads used,
+  using ``torch.set_num_threads(1)`` or ``OMP_NUM_THREADS=1``, see `issue #122 <https://github.com/DLR-RM/stable-baselines3/issues/122>`_
+  and `issue #90 <https://github.com/DLR-RM/stable-baselines3/issues/90>`_.
+
+
 Breaking Changes
 ================
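
As a minimal sketch of the workaround the new note describes (the single-thread setting is just the usual starting point; ``OMP_NUM_THREADS=1 python train.py`` achieves the same from the shell):

    import torch as th

    # Limit PyTorch intra-op parallelism; oversubscribed CPU threads are a
    # common cause of the slow-down discussed in issues #122 and #90
    th.set_num_threads(1)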

docs/guide/rl_tips.rst (+7 -7)

@@ -119,14 +119,14 @@ Discrete Actions
 Discrete Actions - Single Process
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-DQN with extensions (double DQN, prioritized replay, ...) are the recommended algorithms.
-We notably provide QR-DQN in our :ref:`contrib repo <sb3_contrib>`.
-DQN is usually slower to train (regarding wall clock time) but is the most sample efficient (because of its replay buffer).
+``DQN`` with extensions (double DQN, prioritized replay, ...) are the recommended algorithms.
+We notably provide ``QR-DQN`` in our :ref:`contrib repo <sb3_contrib>`.
+``DQN`` is usually slower to train (regarding wall clock time) but is the most sample efficient (because of its replay buffer).

 Discrete Actions - Multiprocessed
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-You should give a try to PPO or A2C.
+You should give a try to ``PPO`` or ``A2C``.


 Continuous Actions
@@ -142,7 +142,7 @@ Please use the hyperparameters in the `RL zoo <https://github.com/DLR-RM/rl-base
 Continuous Actions - Multiprocessed
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-Take a look at PPO, TRPO or A2C. Again, don't forget to take the hyperparameters from the `RL zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_
+Take a look at ``PPO`` or ``A2C``. Again, don't forget to take the hyperparameters from the `RL zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_
 for continuous actions problems (cf *Bullet* envs).

 .. note::
@@ -155,12 +155,12 @@ Goal Environment
 -----------------

 If your environment follows the ``GoalEnv`` interface (cf :ref:`HER <her>`), then you should use
-HER + (SAC/TD3/DDPG/DQN/TQC) depending on the action space.
+HER + (SAC/TD3/DDPG/DQN/QR-DQN/TQC) depending on the action space.


 .. note::

-  The number of workers is an important hyperparameters for experiments with HER
+  The ``batch_size`` is an important hyperparameter for experiments with :ref:`HER <her>`
docs/misc/changelog.rst (+10 -1)

@@ -3,17 +3,26 @@
 Changelog
 ==========

-Release 1.0rc0 (2021-02-28)
+Release 1.0rc1 (WIP)
 -------------------------------

 Breaking Changes:
 ^^^^^^^^^^^^^^^^^
 - Removed ``stable_baselines3.common.cmd_util`` (already deprecated), please use ``env_util`` instead

+New Features:
+^^^^^^^^^^^^^
+- Added support for ``custom_objects`` when loading models
+
+Bug Fixes:
+^^^^^^^^^^
+- Fixed a bug with ``DQN`` predict method when using ``deterministic=False`` with image space
+
 Documentation:
 ^^^^^^^^^^^^^^
 - Fixed examples
 - Added new project using SB3: rl_reach (@PierreExeter)
+- Added note about slow-down when switching to PyTorch
 - Add a note on continual learning and resetting environment

docs/modules/a2c.rst (+2 -3)

@@ -53,13 +53,12 @@ Train a A2C agent on ``CartPole-v1`` using 4 environments.
   import gym

   from stable_baselines3 import A2C
-  from stable_baselines3.a2c import MlpPolicy
   from stable_baselines3.common.env_util import make_vec_env

   # Parallel environments
-  env = make_vec_env('CartPole-v1', n_envs=4)
+  env = make_vec_env("CartPole-v1", n_envs=4)

-  model = A2C(MlpPolicy, env, verbose=1)
+  model = A2C("MlpPolicy", env, verbose=1)
   model.learn(total_timesteps=25000)
   model.save("a2c_cartpole")

docs/modules/ddpg.rst (+2 -2)

@@ -63,13 +63,13 @@ Example
   from stable_baselines3 import DDPG
   from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise

-  env = gym.make('Pendulum-v0')
+  env = gym.make("Pendulum-v0")

   # The noise objects for DDPG
   n_actions = env.action_space.shape[-1]
   action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

-  model = DDPG('MlpPolicy', env, action_noise=action_noise, verbose=1)
+  model = DDPG("MlpPolicy", env, action_noise=action_noise, verbose=1)
   model.learn(total_timesteps=10000, log_interval=10)
   model.save("ddpg_pendulum")
   env = model.get_env()

docs/modules/dqn.rst (+2 -3)

@@ -56,11 +56,10 @@ Example
   import numpy as np

   from stable_baselines3 import DQN
-  from stable_baselines3.dqn import MlpPolicy

-  env = gym.make('CartPole-v0')
+  env = gym.make("CartPole-v0")

-  model = DQN(MlpPolicy, env, verbose=1)
+  model = DQN("MlpPolicy", env, verbose=1)
   model.learn(total_timesteps=10000, log_interval=4)
   model.save("dqn_pendulum")

docs/modules/her.rst (+1 -1)

@@ -52,7 +52,7 @@ Notes
 Can I use?
 ----------

-Please refer to the used model (DQN, SAC, TD3 or DDPG) for that section.
+Please refer to the used model (DQN, QR-DQN, SAC, TQC, TD3, or DDPG) for that section.

 Example
 -------

docs/modules/ppo.rst (+2 -3)

@@ -54,13 +54,12 @@ Train a PPO agent on ``Pendulum-v0`` using 4 environments.
   import gym

   from stable_baselines3 import PPO
-  from stable_baselines3.ppo import MlpPolicy
   from stable_baselines3.common.env_util import make_vec_env

   # Parallel environments
-  env = make_vec_env('CartPole-v1', n_envs=4)
+  env = make_vec_env("CartPole-v1", n_envs=4)

-  model = PPO(MlpPolicy, env, verbose=1)
+  model = PPO("MlpPolicy", env, verbose=1)
   model.learn(total_timesteps=25000)
   model.save("ppo_cartpole")

docs/modules/sac.rst (+2 -3)

@@ -68,11 +68,10 @@ Example
   import numpy as np

   from stable_baselines3 import SAC
-  from stable_baselines3.sac import MlpPolicy

-  env = gym.make('Pendulum-v0')
+  env = gym.make("Pendulum-v0")

-  model = SAC(MlpPolicy, env, verbose=1)
+  model = SAC("MlpPolicy", env, verbose=1)
   model.learn(total_timesteps=10000, log_interval=4)
   model.save("sac_pendulum")

docs/modules/td3.rst (+2 -3)

@@ -61,16 +61,15 @@ Example
   import numpy as np

   from stable_baselines3 import TD3
-  from stable_baselines3.td3.policies import MlpPolicy
   from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise

-  env = gym.make('Pendulum-v0')
+  env = gym.make("Pendulum-v0")

   # The noise objects for TD3
   n_actions = env.action_space.shape[-1]
   action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

-  model = TD3(MlpPolicy, env, action_noise=action_noise, verbose=1)
+  model = TD3("MlpPolicy", env, action_noise=action_noise, verbose=1)
   model.learn(total_timesteps=10000, log_interval=10)
   model.save("td3_pendulum")
   env = model.get_env()

docs/spelling_wordlist.txt (+5)

@@ -119,3 +119,8 @@ cuda
 Polyak
 gSDE
 rollouts
+Pyro
+softmax
+stdout
+Contrib
+Quantile

stable_baselines3/common/base_class.py (+9 -2)

@@ -586,6 +586,7 @@ def load(
         path: Union[str, pathlib.Path, io.BufferedIOBase],
         env: Optional[GymEnv] = None,
         device: Union[th.device, str] = "auto",
+        custom_objects: Optional[Dict[str, Any]] = None,
         **kwargs,
     ) -> "BaseAlgorithm":
         """
@@ -596,9 +597,15 @@ def load(
         :param env: the new environment to run the loaded model on
             (can be None if you only need prediction from a trained model) has priority over any saved environment
         :param device: Device on which the code should run.
+        :param custom_objects: Dictionary of objects to replace
+            upon loading. If a variable is present in this dictionary as a
+            key, it will not be deserialized and the corresponding item
+            will be used instead. Similar to custom_objects in
+            ``keras.models.load_model``. Useful when you have an object in
+            file that can not be deserialized.
         :param kwargs: extra arguments to change the model when loading
         """
-        data, params, pytorch_variables = load_from_zip_file(path, device=device)
+        data, params, pytorch_variables = load_from_zip_file(path, device=device, custom_objects=custom_objects)

        # Remove stored device information and replace with ours
        if "policy_kwargs" in data:
@@ -625,7 +632,7 @@ def load(
             env = data["env"]

         # noinspection PyArgumentList
-        model = cls(
+        model = cls(  # pytype: disable=not-instantiable,wrong-keyword-args
             policy=data["policy_class"],
             env=env,
             device=device,
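
For context, a minimal sketch of how the new ``custom_objects`` argument can be used once this change is in place; the ``PPO`` model, file name, and dictionary keys below are illustrative, not taken from this commit:

    from stable_baselines3 import PPO

    # Entries in custom_objects are NOT deserialized from the archive;
    # the given values are used instead (as with keras.models.load_model).
    # Useful when a pickled object (e.g. a schedule) cannot be loaded.
    custom_objects = {
        "lr_schedule": lambda _: 3e-4,  # constant learning-rate schedule
        "clip_range": lambda _: 0.2,    # constant PPO clip range
    }

    model = PPO.load("ppo_cartpole", custom_objects=custom_objects)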

stable_baselines3/common/distributions.py (+1 -1)

@@ -623,7 +623,7 @@ def atanh(x: th.Tensor) -> th.Tensor:
     """
     Inverse of Tanh

-    Taken from pyro: https://github.com/pyro-ppl/pyro
+    Taken from Pyro: https://github.com/pyro-ppl/pyro
     0.5 * torch.log((1 + x ) / (1 - x))
     """
     return 0.5 * (x.log1p() - (-x).log1p())

stable_baselines3/common/evaluation.py (+1 -1)

@@ -41,7 +41,7 @@ def evaluate_policy(
         called after each step. Gets locals() and globals() passed as parameters.
     :param reward_threshold: Minimum expected reward per episode,
         this will raise an error if the performance is not met
-    :param return_episode_rewards: If True, a list of rewards and episde lengths
+    :param return_episode_rewards: If True, a list of rewards and episode lengths
         per episode will be returned instead of the mean.
     :param warn: If True (default), warns user about lack of a Monitor wrapper in the
         evaluation environment.

stable_baselines3/common/off_policy_algorithm.py (+1 -1)

@@ -174,7 +174,7 @@ def _setup_model(self) -> None:
             self.device,
             optimize_memory_usage=self.optimize_memory_usage,
         )
-        self.policy = self.policy_class(
+        self.policy = self.policy_class(  # pytype:disable=not-instantiable
             self.observation_space,
             self.action_space,
             self.lr_schedule,

stable_baselines3/common/on_policy_algorithm.py (+1 -1)

@@ -114,7 +114,7 @@ def _setup_model(self) -> None:
             gae_lambda=self.gae_lambda,
             n_envs=self.n_envs,
         )
-        self.policy = self.policy_class(
+        self.policy = self.policy_class(  # pytype:disable=not-instantiable
             self.observation_space,
             self.action_space,
             self.lr_schedule,

stable_baselines3/common/policies.py (+2 -13)

@@ -19,11 +19,10 @@
     StateDependentNoiseDistribution,
     make_proba_distribution,
 )
-from stable_baselines3.common.preprocessing import get_action_dim, is_image_space, preprocess_obs
+from stable_baselines3.common.preprocessing import get_action_dim, maybe_transpose, preprocess_obs
 from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, MlpExtractor, NatureCNN, create_mlp
 from stable_baselines3.common.type_aliases import Schedule
 from stable_baselines3.common.utils import get_device, is_vectorized_observation
-from stable_baselines3.common.vec_env import VecTransposeImage
 from stable_baselines3.common.vec_env.obs_dict_wrapper import ObsDictWrapper


@@ -266,17 +265,7 @@ def predict(

         # Handle the different cases for images
         # as PyTorch use channel first format
-        if is_image_space(self.observation_space):
-            if not (
-                observation.shape == self.observation_space.shape or observation.shape[1:] == self.observation_space.shape
-            ):
-                # Try to re-order the channels
-                transpose_obs = VecTransposeImage.transpose_image(observation)
-                if (
-                    transpose_obs.shape == self.observation_space.shape
-                    or transpose_obs.shape[1:] == self.observation_space.shape
-                ):
-                    observation = transpose_obs
+        observation = maybe_transpose(observation, self.observation_space)

         vectorized_env = is_vectorized_observation(observation, self.observation_space)

stable_baselines3/common/preprocessing.py (+20)

@@ -61,6 +61,26 @@ def is_image_space(observation_space: spaces.Space, channels_last: bool = True,
     return False


+def maybe_transpose(observation: np.ndarray, observation_space: spaces.Space) -> np.ndarray:
+    """
+    Handle the different cases for images as PyTorch use channel first format.
+
+    :param observation:
+    :param observation_space:
+    :return: channel first observation if observation is an image
+    """
+    # Avoid circular import
+    from stable_baselines3.common.vec_env import VecTransposeImage
+
+    if is_image_space(observation_space):
+        if not (observation.shape == observation_space.shape or observation.shape[1:] == observation_space.shape):
+            # Try to re-order the channels
+            transpose_obs = VecTransposeImage.transpose_image(observation)
+            if transpose_obs.shape == observation_space.shape or transpose_obs.shape[1:] == observation_space.shape:
+                observation = transpose_obs
+    return observation
+
+
 def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space, normalize_images: bool = True) -> th.Tensor:
     """
     Preprocess observation to be to a neural network.
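
To illustrate the behavior of the new helper, a small sketch; the 84x84 RGB shapes are illustrative, not from this commit:

    import numpy as np
    from gym import spaces

    from stable_baselines3.common.preprocessing import maybe_transpose

    # Channel-first image space (C, H, W), as PyTorch expects
    obs_space = spaces.Box(low=0, high=255, shape=(3, 84, 84), dtype=np.uint8)
    # A channel-last observation (H, W, C), e.g. straight from a render
    obs = np.zeros((84, 84, 3), dtype=np.uint8)

    # The channels are re-ordered to match the observation space
    assert maybe_transpose(obs, obs_space).shape == (3, 84, 84)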

stable_baselines3/common/save_util.py (+10 -3)

@@ -137,7 +137,7 @@ def json_to_data(json_string: str, custom_objects: Optional[Dict[str, Any]] = No
         upon loading. If a variable is present in this dictionary as a
         key, it will not be deserialized and the corresponding item
         will be used instead. Similar to custom_objects in
-        `keras.models.load_model`. Useful when you have an object in
+        ``keras.models.load_model``. Useful when you have an object in
         file that can not be deserialized.
     :return: Loaded class parameters.
     """
@@ -162,7 +162,7 @@ def json_to_data(json_string: str, custom_objects: Optional[Dict[str, Any]] = No
         try:
             base64_object = base64.b64decode(serialization.encode())
             deserialized_object = cloudpickle.loads(base64_object)
-        except RuntimeError:
+        except (RuntimeError, TypeError):
             warnings.warn(
                 f"Could not deserialize object {data_key}. "
                 + "Consider using `custom_objects` argument to replace "
@@ -359,6 +359,7 @@ def load_from_pkl(path: Union[str, pathlib.Path, io.BufferedIOBase], verbose: in
 def load_from_zip_file(
     load_path: Union[str, pathlib.Path, io.BufferedIOBase],
     load_data: bool = True,
+    custom_objects: Optional[Dict[str, Any]] = None,
     device: Union[th.device, str] = "auto",
     verbose: int = 0,
 ) -> (Tuple[Optional[Dict[str, Any]], Optional[TensorDict], Optional[TensorDict]]):
@@ -368,6 +369,12 @@ def load_from_zip_file(
     :param load_path: Where to load the model from
     :param load_data: Whether we should load and return data
         (class parameters). Mainly used by 'load_parameters' to only load model parameters (weights)
+    :param custom_objects: Dictionary of objects to replace
+        upon loading. If a variable is present in this dictionary as a
+        key, it will not be deserialized and the corresponding item
+        will be used instead. Similar to custom_objects in
+        ``keras.models.load_model``. Useful when you have an object in
+        file that can not be deserialized.
     :param device: Device on which the code should run.
     :return: Class parameters, model state_dicts (aka "params", dict of state_dict)
         and dict of pytorch variables
@@ -392,7 +399,7 @@ def load_from_zip_file(
         # Load class parameters that are stored
         # with either JSON or pickle (not PyTorch variables).
         json_data = archive.read("data").decode()
-        data = json_to_data(json_data)
+        data = json_to_data(json_data, custom_objects=custom_objects)

         # Check for all .pth files and load them using th.load.
         # "pytorch_variables.pth" stores PyTorch variables, and any other .pth
