Skip to content

Commit 2045932

Browse files
Restructure python folders, add basic gym environment and MAPPO (#70)
* Functioning single-agent SB3 PPO implementation. * Add env * Render update and working pygame example. * Restructure * Add basic pygame rgb_array example * Render bug fixes * Rendering progress * Add PPO support for multiple worlds * Cleanup * minor * WIP: benchmark ppo speed * minor * Cleanup * Reformatting and filter out expert obs using boolean mask. * Bug fix: Specify squeeze dim to avoid reducing to a one-dim tensor when using a single world. * Delete data_10 directory * Updated gym env. * Add support for multi-agent PPO * Minor env changes * remove AbsoluteSelfObsTensor and remove whitespaces. * rmv white spaces * rmv * Update render func
1 parent 665c4d8 commit 2045932

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+1695
-4041
lines changed

Diff for: .gitignore

+223-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,229 @@
55
/CMakeSettings.json
66
/scripts/__pycache__
77
/train_src/madrona_escape_room_learn/__pycache__
8-
/nocturne_data*
8+
/waymo_data*
99
.vscode/launch.json
1010
.vscode/settings.json
1111
.vscode/tasks.json
12+
13+
/cloudpickle
14+
/cloudpickle-3.0.0.dist-info
15+
/bin
16+
/zipp*
17+
.python-version
18+
19+
# Data
20+
/formatted_json_v2_no_tl_train
21+
/data_10
22+
/data_100
23+
/data_1000
24+
25+
# Logging
26+
/wandb
27+
events.out.tfevents.*
28+
29+
### C++ ###
30+
# Prerequisites
31+
*.d
32+
33+
# Compiled Object files
34+
*.slo
35+
*.lo
36+
*.o
37+
*.obj
38+
39+
# Precompiled Headers
40+
*.gch
41+
*.pch
42+
43+
# Compiled Dynamic libraries
44+
*.so
45+
*.dylib
46+
*.dll
47+
48+
# Fortran module files
49+
*.mod
50+
*.smod
51+
52+
# Compiled Static libraries
53+
*.lai
54+
*.la
55+
*.a
56+
*.lib
57+
58+
# Executables
59+
*.exe
60+
*.out
61+
*.app
62+
63+
### Python ###
64+
# Byte-compiled / optimized / DLL files
65+
__pycache__/
66+
*.py[cod]
67+
*$py.class
68+
69+
# C extensions
70+
71+
# Distribution / packaging
72+
.Python
73+
build/
74+
develop-eggs/
75+
dist/
76+
downloads/
77+
eggs/
78+
.eggs/
79+
lib/
80+
lib64/
81+
parts/
82+
sdist/
83+
var/
84+
wheels/
85+
share/python-wheels/
86+
*.egg-info/
87+
.installed.cfg
88+
*.egg
89+
MANIFEST
90+
91+
# PyInstaller
92+
# Usually these files are written by a python script from a template
93+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
94+
*.manifest
95+
*.spec
96+
97+
# Installer logs
98+
pip-log.txt
99+
pip-delete-this-directory.txt
100+
101+
# Unit test / coverage reports
102+
htmlcov/
103+
.tox/
104+
.nox/
105+
.coverage
106+
.coverage.*
107+
.cache
108+
nosetests.xml
109+
coverage.xml
110+
*.cover
111+
*.py,cover
112+
.hypothesis/
113+
.pytest_cache/
114+
cover/
115+
116+
# Translations
117+
*.mo
118+
*.pot
119+
120+
# Django stuff:
121+
*.log
122+
local_settings.py
123+
db.sqlite3
124+
db.sqlite3-journal
125+
126+
# Flask stuff:
127+
instance/
128+
.webassets-cache
129+
130+
# Scrapy stuff:
131+
.scrapy
132+
133+
# Sphinx documentation
134+
docs/_build/
135+
136+
# PyBuilder
137+
.pybuilder/
138+
target/
139+
140+
# Jupyter Notebook
141+
.ipynb_checkpoints
142+
143+
# IPython
144+
profile_default/
145+
ipython_config.py
146+
147+
# pyenv
148+
# For a library or package, you might want to ignore these files since the code is
149+
# intended to run in multiple environments; otherwise, check them in:
150+
# .python-version
151+
152+
# pipenv
153+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
154+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
155+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
156+
# install all needed dependencies.
157+
#Pipfile.lock
158+
159+
# poetry
160+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
161+
# This is especially recommended for binary packages to ensure reproducibility, and is more
162+
# commonly ignored for libraries.
163+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
164+
#poetry.lock
165+
166+
# pdm
167+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
168+
#pdm.lock
169+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
170+
# in version control.
171+
# https://pdm.fming.dev/#use-with-ide
172+
.pdm.toml
173+
174+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
175+
__pypackages__/
176+
177+
# Celery stuff
178+
celerybeat-schedule
179+
celerybeat.pid
180+
181+
# SageMath parsed files
182+
*.sage.py
183+
184+
# Environments
185+
.env
186+
.venv
187+
venv/
188+
ENV/
189+
env.bak/
190+
venv.bak/
191+
192+
# Spyder project settings
193+
.spyderproject
194+
.spyproject
195+
196+
# Rope project settings
197+
.ropeproject
198+
199+
# mkdocs documentation
200+
/site
201+
202+
# mypy
203+
.mypy_cache/
204+
.dmypy.json
205+
dmypy.json
206+
207+
# Pyre type checker
208+
.pyre/
209+
210+
# pytype static type analyzer
211+
.pytype/
212+
213+
# Cython debug symbols
214+
cython_debug/
215+
216+
# PyCharm
217+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
218+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
219+
# and can be added to the global gitignore or merged into this file. For a more nuclear
220+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
221+
#.idea/
222+
223+
### Python Patch ###
224+
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
225+
poetry.toml
226+
227+
# ruff
228+
.ruff_cache/
229+
230+
# LSP config files
231+
pyrightconfig.json
232+
233+
# End of https://www.toptal.com/developers/gitignore/api/python,c++

Diff for: .pre-commit-config.yaml

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
repos:
2+
- repo: https://github.com/pre-commit/pre-commit-hooks
3+
rev: v4.0.1 # Use the latest version
4+
hooks:
5+
- id: trailing-whitespace
6+
- id: end-of-file-fixer
7+
- id: check-yaml
8+
- id: check-added-large-files
9+
10+
- repo: https://github.com/pycqa/flake8
11+
rev: 3.9.2 # Use the latest version
12+
hooks:
13+
- id: flake8
14+
- repo: https://github.com/psf/black
15+
rev: 22.3.0 # Use the latest version
16+
hooks:
17+
- id: black
18+
args: [--line-length, "79"]

Diff for: README.md

+1-7
Original file line numberDiff line numberDiff line change
@@ -169,12 +169,6 @@ Citation
169169
If you use Madrona in a research project, please cite our SIGGRAPH paper.
170170
171171
```
172-
@article{shacklett23madrona,
173-
title = {An Extensible, Data-Oriented Architecture for High-Performance, Many-World Simulation},
174-
author = {Brennan Shacklett and Luc Guy Rosenzweig and Zhiqiang Xie and Bidipta Sarkar and Andrew Szot and Erik Wijmans and Vladlen Koltun and Dhruv Batra and Kayvon Fatahalian},
175-
journal = {ACM Trans. Graph.},
176-
volume = {42},
177-
number = {4},
178-
year = {2023}
172+
@article{...,
179173
}
180174
```

Diff for: algorithms/__init__.py

Whitespace-only changes.

Diff for: algorithms/ppo/cleanrl/.gitkeep

Whitespace-only changes.

Diff for: algorithms/ppo/cleanrl/__init__.py

Whitespace-only changes.

Diff for: algorithms/ppo/sb3/__init__.py

Whitespace-only changes.

Diff for: algorithms/ppo/sb3/callbacks.py

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import logging
2+
import os
3+
import wandb
4+
import numpy as np
5+
import torch
6+
import torch.nn as nn
7+
from stable_baselines3.common.callbacks import BaseCallback
8+
from stable_baselines3.common.policies import ActorCriticPolicy
9+
10+
11+
class MultiAgentCallback(BaseCallback):
12+
"""SB3 callback for gpudrive."""
13+
def __init__(
14+
self,
15+
wandb_run=None,
16+
**kwargs,
17+
) -> None:
18+
super().__init__(**kwargs)
19+
self.wandb_run = wandb_run
20+
21+
def _on_training_start(self) -> None:
22+
"""
23+
This method is called before the first rollout starts.
24+
"""
25+
pass
26+
27+
def _on_rollout_start(self) -> None:
28+
"""
29+
A rollout is the collection of environment interaction
30+
using the current policy.
31+
This event is triggered before collecting new samples.
32+
"""
33+
pass
34+
35+
def _on_step(self) -> bool:
36+
"""
37+
This method will be called by the model after each call to `env.step()`.
38+
"""
39+
pass
40+
41+
def _on_rollout_end(self) -> None:
42+
"""
43+
This event is triggered before updating the policy.
44+
"""
45+
46+
rewards = self.locals["rollout_buffer"].rewards.cpu().detach().numpy().flatten()
47+
48+
self.logger.record("rollout/global_step", self.num_timesteps)
49+
self.logger.record("rollout/avg_reward", np.mean(rewards))
50+
self.logger.record("rollout/std_reward", np.std(rewards))
51+
52+
53+
54+
55+

0 commit comments

Comments
 (0)