Skip to content

Commit 02d2519

Browse files
svekarsvmoens
andauthored
Bump torchrl and tensordict to 0.7.2 (#3298)
* Bump torchrl and torchdict to 0.7.2 * Add devices * fix semi_structured_sparse.py with default device * Disable semi sparse tutorial --------- Co-authored-by: Vincent Moens <[email protected]>
1 parent 6053b2a commit 02d2519

File tree

5 files changed

+7
-4
lines changed

5 files changed

+7
-4
lines changed

.ci/docker/requirements.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ tensorboard
2828
jinja2==3.1.3
2929
pytorch-lightning
3030
torchx
31-
torchrl==0.6.0
32-
tensordict==0.6.0
31+
torchrl==0.7.2
32+
tensordict==0.7.2
3333
ax-platform>=0.4.0
3434
nbformat>=5.9.2
3535
datasets

.jenkins/validate_tutorials_built.py

+1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
"intermediate_source/flask_rest_api_tutorial",
5151
"intermediate_source/text_to_speech_with_torchaudio",
5252
"intermediate_source/tensorboard_profiler_tutorial", # reenable after 2.0 release.
53+
"advanced_source/semi_structured_sparse" # reenable after 3303 is fixed.
5354
]
5455

5556
def tutorial_source_dirs() -> List[Path]:

advanced_source/coding_ddpg.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1040,7 +1040,7 @@ def ceil_div(x, y):
10401040

10411041
###############################################################################
10421042
# let's use the TD(lambda) estimator!
1043-
loss_module.make_value_estimator(ValueEstimators.TDLambda, gamma=gamma, lmbda=lmbda)
1043+
loss_module.make_value_estimator(ValueEstimators.TDLambda, gamma=gamma, lmbda=lmbda, device=device)
10441044

10451045
###############################################################################
10461046
# .. note::

advanced_source/semi_structured_sparse.py

+2
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,8 @@
210210
SparseSemiStructuredTensor._FORCE_CUTLASS = True
211211
torch.manual_seed(100)
212212

213+
# Set default device to "cuda:0"
214+
torch.set_default_device(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
213215

214216
######################################################################
215217
# We’ll also need to define some helper functions that are specific to the

intermediate_source/reinforcement_ppo.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,7 @@
551551
#
552552

553553
advantage_module = GAE(
554-
gamma=gamma, lmbda=lmbda, value_network=value_module, average_gae=True
554+
gamma=gamma, lmbda=lmbda, value_network=value_module, average_gae=True, device=device,
555555
)
556556

557557
loss_module = ClipPPOLoss(

0 commit comments

Comments
 (0)