diff --git a/.github/workflows/tests_suite.yml b/.github/workflows/tests_suite.yml
index 0fdefc7fec..a2e65af3c5 100644
--- a/.github/workflows/tests_suite.yml
+++ b/.github/workflows/tests_suite.yml
@@ -80,6 +80,7 @@ jobs:
           - simulation/human_model_generation
           - perception/facial_expression_recognition
           - control/single_demo_grasp
+          - planning/end_to_end_planning
           # - perception/object_tracking_3d
         include:
           - os: ubuntu-20.04
@@ -328,6 +329,7 @@ jobs:
           - control/mobile_manipulation
           - simulation/human_model_generation
           - control/single_demo_grasp
+          - planning/end_to_end_planning
           # - perception/object_tracking_3d
     runs-on: ubuntu-20.04
     steps:
diff --git a/.github/workflows/tests_suite_develop.yml b/.github/workflows/tests_suite_develop.yml
index 4f6ddaaaa3..8494c7550d 100644
--- a/.github/workflows/tests_suite_develop.yml
+++ b/.github/workflows/tests_suite_develop.yml
@@ -79,6 +79,7 @@ jobs:
           - simulation/human_model_generation
           - perception/facial_expression_recognition
           - control/single_demo_grasp
+          - planning/end_to_end_planning
           # - perception/object_tracking_3d
         include:
           - os: ubuntu-20.04
@@ -330,6 +331,7 @@ jobs:
           - control/mobile_manipulation
           - simulation/human_model_generation
           - control/single_demo_grasp
+          - planning/end_to_end_planning
           # - perception/object_tracking_3d
     runs-on: ubuntu-20.04
     steps:
diff --git a/CHANGELOG.md b/CHANGELOG.md
index ae0dc87ed7..6b2cea8f27 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,7 +4,7 @@ Released on XX, XXth, 2022.
 
 - New Features:
-  - None.
+  - Added end-to-end planning tool ([#223](https://github.com/opendr-eu/opendr/pull/223)).
 - Enhancements:
   - Added support for modular pip packages allowing tools to be installed separately ([#201](https://github.com/opendr-eu/opendr/pull/201)).
   - Simplified the installation process for pip by including the appropriate post-installation scripts ([#201](https://github.com/opendr-eu/opendr/pull/201)).
diff --git a/docs/reference/end-to-end-planning.md b/docs/reference/end-to-end-planning.md
new file mode 100644
index 0000000000..79ff34f7a0
--- /dev/null
+++ b/docs/reference/end-to-end-planning.md
@@ -0,0 +1,185 @@
+# end_to_end_planning module
+
+The *end_to_end_planning* module contains the *EndToEndPlanningRLLearner* class, which inherits from the abstract
+class *LearnerRL*.
+
+### Class EndToEndPlanningRLLearner
+Bases: `engine.learners.LearnerRL`
+
+The *EndToEndPlanningRLLearner* is an agent that can be used to train quadrotor robots equipped with a depth sensor to
+follow a provided trajectory while avoiding obstacles.
+
+The [EndToEndPlanningRLLearner](/src/opendr/planning/end_to_end_planning/e2e_planning_learner.py) class has the
+following public methods:
+
+#### `EndToEndPlanningRLLearner` constructor
+
+Constructor parameters:
+
+- **env**: *gym.Env*\
+  Reinforcement learning environment to train or evaluate the agent on.
+- **lr**: *float, default=3e-4*\
+  Specifies the initial learning rate to be used during training.
+- **n_steps**: *int, default=1024*\
+  Specifies the number of steps to run for each environment per update.
+- **iters**: *int, default=5e4*\
+  Specifies the number of steps the training should run for.
+- **batch_size**: *int, default=64*\
+  Specifies the batch size during training.
+- **checkpoint_after_iter**: *int, default=500*\
+  Specifies the number of training steps between two checkpoint saves.
+- **temp_path**: *str, default=''*\
+  Specifies a path where the algorithm stores log files and saves checkpoints.
+- **device**: *{'cpu', 'cuda'}, default='cuda'*\
+  Specifies the device to be used.
+
+#### `EndToEndPlanningRLLearner.fit`
+```python
+EndToEndPlanningRLLearner.fit(self, env, logging_path, silent, verbose)
+```
+
+Trains the agent on the environment.
+
+Parameters:
+
+- **env**: *gym.Env, default=None*\
+  If specified, this environment is used for training.
+- **logging_path**: *str, default=''*\
+  Path for logging and checkpointing.
+- **silent**: *bool, default=False*\
+  Disables verbosity.
+- **verbose**: *bool, default=True*\
+  Enables verbosity.
+
+
+#### `EndToEndPlanningRLLearner.eval`
+```python
+EndToEndPlanningRLLearner.eval(self, env)
+```
+Evaluates the agent on the specified environment.
+
+Parameters:
+
+- **env**: *gym.Env, default=None*\
+  Environment to evaluate on.
+
+
+#### `EndToEndPlanningRLLearner.save`
+```python
+EndToEndPlanningRLLearner.save(self, path)
+```
+Saves the model at the provided path.
+
+Parameters:
+
+- **path**: *str*\
+  Path to save the model, including the filename.
+
+
+#### `EndToEndPlanningRLLearner.load`
+```python
+EndToEndPlanningRLLearner.load(self, path)
+```
+Loads a model from the provided path.
+
+Parameters:
+
+- **path**: *str*\
+  Path of the model to be loaded.
+
+
+#### `EndToEndPlanningRLLearner.infer`
+```python
+EndToEndPlanningRLLearner.infer(self, batch, deterministic)
+```
+Performs inference on a single observation or a list of observations.
+
+Parameters:
+
+- **batch**: *dict or list of dict, default=None*\
+  Single observation or list of observations.
+- **deterministic**: *bool, default=True*\
+  Whether to use deterministic actions from the policy.
+
+### Simulation environment setup
+
+The environment includes an Ardupilot-controlled quadrotor in a Webots simulation.
+Instructions for installing Ardupilot are available [here](https://github.com/ArduPilot/ardupilot).
+
+The files required to complete the Ardupilot setup can be downloaded by running the [`download_ardupilot_files.py`](src/opendr/planning/end_to_end_planning/download_ardupilot_files.py) script.
+The downloaded files (zipped as `ardupilot.zip`) should replace the corresponding files under the Ardupilot installation.
+In order to run Ardupilot in Webots 2021a, the controller code has to be replaced; for older versions of Webots these files can be skipped.
+The world file for the environment, used for both training and testing, is provided under `/ardupilot/libraries/SITL/examples/webots/worlds/`.
+
+Install the `mavros` package for ROS communication with Ardupilot.
+Instructions are available [here](https://github.com/mavlink/mavros/blob/master/mavros/README.md#installation).
+Installation from source is recommended.
+
+### Running the environment
+
+The following steps establish the ROS communication between the Gym environment and the simulation.
+- Start Webots and open the provided world file.
+The simulation should stop at the first time step and wait for the Ardupilot software to run.
+- Run the following script from the Ardupilot directory: `./libraries/SITL/examples/Webots/dronePlus.sh`, which starts the software-in-the-loop execution of Ardupilot.
+- Run `roscore`.
+- Run `roslaunch mavros apm.launch`, which sets up the ROS communication with Ardupilot.
+- Run the following ROS nodes in `src/opendr/planning/end_to_end_planning/src`:
+  - `children_robot`, which activates the required sensors on the quadrotor and creates the ROS communication for them.
+  - `take_off`, which takes off the quadrotor.
+  - `range_image`, which converts the depth image into the array format expected by the learner.
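+
+To verify that the steps above succeeded, it can help to wait for the topics the environment relies on before constructing it.
+The snippet below is a minimal sanity-check sketch: run it as a separate script (since `AgiEnv` initializes its own ROS node); the topic names are simply the ones used by this module.
+
+```python
+import rospy
+from geometry_msgs.msg import PoseStamped
+from std_msgs.msg import Float32MultiArray
+
+rospy.init_node('e2e_planning_sanity_check', anonymous=True)
+# Pose stream provided by mavros after `roslaunch mavros apm.launch`.
+rospy.wait_for_message('/mavros/local_position/pose', PoseStamped, timeout=10)
+# Depth image array republished by the `range_image` node.
+rospy.wait_for_message('/range_image_raw', Float32MultiArray, timeout=10)
+print('Simulation, Ardupilot and the ROS bridge are up.')
+```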
+
+After these steps the [AgiEnv](src/opendr/planning/end_to_end_planning/envs/agi_env.py) gym environment can send action commands to the simulated drone and receive depth images and pose information from the simulation.
+
+### Examples
+
+Training in the Webots environment:
+
+```python
+from opendr.planning.end_to_end_planning import EndToEndPlanningRLLearner, AgiEnv
+
+env = AgiEnv()
+learner = EndToEndPlanningRLLearner(env, n_steps=1024)
+learner.fit(logging_path='./end_to_end_planning_tmp')
+```
+
+
+Running a pretrained model:
+
+```python
+from opendr.planning.end_to_end_planning import EndToEndPlanningRLLearner, AgiEnv
+
+env = AgiEnv()
+learner = EndToEndPlanningRLLearner(env)
+learner.load('{$OPENDR_HOME}/src/opendr/planning/end_to_end_planning/pretrained_model/saved_model.zip')
+obs = env.reset()
+sum_of_rew = 0
+number_of_timesteps = 20
+for i in range(number_of_timesteps):
+    action, _states = learner.infer(obs, deterministic=True)
+    obs, rewards, dones, info = env.step(action)
+    sum_of_rew += rewards
+    if dones:
+        obs = env.reset()
+print("Reward collected is:", sum_of_rew)
+```
+
+### Performance Evaluation
+
+TABLE 1: Speed (FPS) and energy consumption for inference on various platforms.
+
+|                 | TX2   | Xavier | RTX 2080 Ti |
+| --------------- | ----- | ------ | ----------- |
+| FPS Evaluation  | 153.5 | 201.6  | 973.6       |
+| Energy (Joules) | 0.12  | 0.051  | \-          |
+
+TABLE 2: Platform compatibility evaluation.
+
+| Platform                                     | Test results |
+| -------------------------------------------- | ------------ |
+| x86 - Ubuntu 20.04 (bare installation - CPU) | Pass         |
+| x86 - Ubuntu 20.04 (bare installation - GPU) | Pass         |
+| x86 - Ubuntu 20.04 (pip installation)        | Pass         |
+| x86 - Ubuntu 20.04 (CPU docker)              | Pass         |
+| x86 - Ubuntu 20.04 (GPU docker)              | Pass         |
+| NVIDIA Jetson TX2                            | Pass         |
+| NVIDIA Jetson Xavier AGX                     | Pass         |
\ No newline at end of file
diff --git a/src/opendr/planning/end_to_end_planning/README.md b/src/opendr/planning/end_to_end_planning/README.md
new file mode 100644
index 0000000000..a8e285992d
--- /dev/null
+++ b/src/opendr/planning/end_to_end_planning/README.md
@@ -0,0 +1,33 @@
+# End-to-end Planning
+
+This folder contains the OpenDR Learner class for end-to-end planning tasks.
+This method uses reinforcement learning to train an agent that is able to generate local motion plans for a quadrotor UAV equipped with a depth camera.
+
+### Simulation environment setup
+
+The environment includes an Ardupilot-controlled quadrotor in a Webots simulation.
+Instructions for installing Ardupilot are available [here](https://github.com/ArduPilot/ardupilot).
+
+The files required to complete the Ardupilot setup can be downloaded by running the [`download_ardupilot_files.py`](src/opendr/planning/end_to_end_planning/download_ardupilot_files.py) script.
+The downloaded files (zipped as `ardupilot.zip`) should replace the corresponding files under the Ardupilot installation.
+In order to run Ardupilot in Webots 2021a, the controller code has to be replaced; for older versions of Webots these files can be skipped.
+The world file for the environment, used for both training and testing, is provided under `/ardupilot/libraries/SITL/examples/webots/worlds/`.
+
+Install the `mavros` package for ROS communication with Ardupilot.
+Instructions are available [here](https://github.com/mavlink/mavros/blob/master/mavros/README.md#installation).
+Installation from source is recommended.
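+
+The Ardupilot support files mentioned above can also be fetched programmatically.
+The sketch below mirrors what `download_ardupilot_files.py` does and additionally unpacks the archive; the extraction directory is only an example, and the extracted contents still need to be copied over the Ardupilot installation.
+
+```python
+from pathlib import Path
+from urllib.request import urlretrieve
+from zipfile import ZipFile
+
+from opendr.engine.constants import OPENDR_SERVER_URL
+
+# Download ardupilot.zip from the OpenDR file server (same URL as the helper script uses).
+archive = Path('./ardupilot.zip')
+urlretrieve(url=OPENDR_SERVER_URL + 'planning/end_to_end_planning/ardupilot.zip', filename=str(archive))
+
+# Unpack the archive; copy the extracted files over the Ardupilot installation afterwards.
+with ZipFile(archive) as zf:
+    zf.extractall('./ardupilot_files')
+```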
+
+### Running the environment
+
+The following steps establish the ROS communication between the Gym environment and the simulation.
+- Start Webots and open the provided world file.
+The simulation should stop at the first time step and wait for the Ardupilot software to run.
+- Run the following script from the Ardupilot directory: `./libraries/SITL/examples/Webots/dronePlus.sh`, which starts the software-in-the-loop execution of Ardupilot.
+- Run `roscore`.
+- Run `roslaunch mavros apm.launch`, which sets up the ROS communication with Ardupilot.
+- Run the following ROS nodes in `src/opendr/planning/end_to_end_planning/src`:
+  - `children_robot`, which activates the required sensors on the quadrotor and creates the ROS communication for them.
+  - `take_off`, which takes off the quadrotor.
+  - `range_image`, which converts the depth image into the array format expected by the learner.
+
+After these steps the [AgiEnv](src/opendr/planning/end_to_end_planning/envs/agi_env.py) gym environment can send action commands to the simulated drone and receive depth images and pose information from the simulation.
diff --git a/src/opendr/planning/end_to_end_planning/__init__.py b/src/opendr/planning/end_to_end_planning/__init__.py
new file mode 100644
index 0000000000..3f5a5c45e9
--- /dev/null
+++ b/src/opendr/planning/end_to_end_planning/__init__.py
@@ -0,0 +1,4 @@
+from opendr.planning.end_to_end_planning.e2e_planning_learner import EndToEndPlanningRLLearner
+from opendr.planning.end_to_end_planning.envs.agi_env import AgiEnv
+
+__all__ = ['EndToEndPlanningRLLearner', 'AgiEnv']
diff --git a/src/opendr/planning/end_to_end_planning/dependencies.ini b/src/opendr/planning/end_to_end_planning/dependencies.ini
new file mode 100644
index 0000000000..81de02b377
--- /dev/null
+++ b/src/opendr/planning/end_to_end_planning/dependencies.ini
@@ -0,0 +1,16 @@
+[compilation]
+linux=libeigen3-dev
+python=vcstool
+       rosdep
+       rospkg
+       catkin_pkg
+       catkin_tools
+       roslibpy
+       empy
+       gym==0.20.0
+       stable-baselines3==1.1.0
+[runtime]
+# 'python' key expects a value using the Python requirements file format
+# https://pip.pypa.io/en/stable/reference/pip_install/#requirements-file-format
+python=stable-baselines3
+linux=ros-noetic-webots-ros
diff --git a/src/opendr/planning/end_to_end_planning/download_ardupilot_files.py b/src/opendr/planning/end_to_end_planning/download_ardupilot_files.py
new file mode 100644
index 0000000000..6a14266981
--- /dev/null
+++ b/src/opendr/planning/end_to_end_planning/download_ardupilot_files.py
@@ -0,0 +1,20 @@
+# Copyright 2020-2022 OpenDR European Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
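+
+# Note: this helper downloads `ardupilot.zip` from the OpenDR file server into the
+# current working directory. After downloading, extract the archive and copy its
+# contents over the Ardupilot installation, as described in the module README.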
+ +from urllib.request import urlretrieve +from opendr.engine.constants import OPENDR_SERVER_URL + +url = OPENDR_SERVER_URL + "planning/end_to_end_planning/ardupilot.zip" +file_destination = "./ardupilot.zip" +urlretrieve(url=url, filename=file_destination) diff --git a/src/opendr/planning/end_to_end_planning/e2e_planning_learner.py b/src/opendr/planning/end_to_end_planning/e2e_planning_learner.py new file mode 100644 index 0000000000..a0086a5ff2 --- /dev/null +++ b/src/opendr/planning/end_to_end_planning/e2e_planning_learner.py @@ -0,0 +1,176 @@ +# Copyright 2020-2022 OpenDR European Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import rospy +import gym +import os +from pathlib import Path +from urllib.request import urlretrieve + +from stable_baselines3 import PPO +from stable_baselines3.common.monitor import Monitor +from stable_baselines3.common.vec_env import DummyVecEnv +from stable_baselines3.common.results_plotter import load_results, ts2xy + +from opendr.engine.learners import LearnerRL +from opendr.engine.constants import OPENDR_SERVER_URL + +__all__ = ["rospy", ] + + +class EndToEndPlanningRLLearner(LearnerRL): + def __init__(self, env, lr=3e-4, n_steps=1024, iters=int(5e4), batch_size=64, checkpoint_after_iter=500, + temp_path='', device='cuda'): + """ + Specifies a proximal policy optimization (PPO) agent that can be trained for end to end planning for obstacle avoidance. + Internally uses Stable-Baselines (https://github.com/hill-a/stable-baselines). + """ + super(EndToEndPlanningRLLearner, self).__init__(lr=lr, iters=iters, batch_size=batch_size, optimizer='adam', + network_head='', temp_path=temp_path, + checkpoint_after_iter=checkpoint_after_iter, + device=device, threshold=0.0, scale=1.0) + self.env = env + if isinstance(self.env, DummyVecEnv): + self.env = self.env.envs[0] + self.env = DummyVecEnv([lambda: self.env]) + self.agent = PPO("MultiInputPolicy", self.env, learning_rate=self.lr, n_steps=n_steps, + batch_size=self.batch_size, verbose=1) + + def download(self, path=None, + url=OPENDR_SERVER_URL + "planning/end_to_end_planning"): + if path is None: + path = "./end_to_end_planning_tmp/" + filename = "ardupilot.zip" + file_destination = Path(path) / filename + if not file_destination.exists(): + file_destination.parent.mkdir(parents=True, exist_ok=True) + url = os.path.join(url, filename) + urlretrieve(url=url, filename=file_destination) + return file_destination + + def fit(self, env=None, logging_path='', silent=False, verbose=True): + """ + Train the agent on the environment. 
+ + :param env: gym.Env, optional, if specified use this env to train + :param logging_path: str, path for logging and checkpointing + :param silent: bool, disable verbosity + :param verbose: bool, enable verbosity + :return: + """ + if env is not None: + if isinstance(env, gym.Env): + self.env = env + else: + print('env should be gym.Env') + return + self.last_checkpoint_time_step = 0 + self.mean_reward = -10 + self.logdir = logging_path + if isinstance(self.env, DummyVecEnv): + self.env = self.env.envs[0] + if isinstance(self.env, Monitor): + self.env = self.env.env + self.env = Monitor(self.env, filename=self.logdir) + self.env = DummyVecEnv([lambda: self.env]) + self.agent.set_env(self.env) + self.agent.learn(total_timesteps=self.iters, callback=self.callback) + return {"last_20_episodes_mean_reward": self.mean_reward} + + def eval(self, env): + """ + Evaluate the agent on the specified environment. + + :param env: gym.Env, env to evaluate on + :return: sum of rewards through the episode + """ + if isinstance(env, DummyVecEnv): + env = env.envs[0] + if isinstance(env, Monitor): + env = env.env + # env = Monitor(env, filename=self.logdir) + env = DummyVecEnv([lambda: env]) + self.agent.set_env(env) + obs = env.reset() + sum_of_rewards = 0 + for i in range(50): + action, _states = self.agent.predict(obs, deterministic=True) + obs, rewards, dones, info = env.step(action) + sum_of_rewards += rewards + if dones: + break + return {"rewards_collected": sum_of_rewards} + + def save(self, path): + """ + Saves the model in the path provided. + + :param path: Path to save directory + :type path: str + :return: Whether save succeeded or not + :rtype: bool + """ + self.agent.save(path) + + def load(self, path): + """ + Loads a model from the path provided. + + :param path: Path to saved model + :type path: str + :return: Whether load succeeded or not + :rtype: bool + """ + self.agent = PPO.load(path) + self.agent.set_env(self.env) + + def infer(self, batch, deterministic: bool = True): + """ + Loads a model from the path provided. 
+ + :param batch: single or list of observations + :type batch: dict ot list of dict + :param deterministic: use deterministic actions from the policy + :type deterministic: bool + :return: the selected action + :rtype: int or list + """ + if isinstance(batch, dict): + return self.agent.predict(batch, deterministic=deterministic) + elif isinstance(batch, list) or isinstance(batch, np.ndarray): + return [self.agent.predict(obs, deterministic=deterministic) for obs in batch] + else: + raise ValueError() + + def reset(self): + raise NotImplementedError() + + def optimize(self, target_device): + raise NotImplementedError() + + def callback(self, _locals, _globals): + x, y = ts2xy(load_results(self.logdir), 'timesteps') + + if len(y) > 20: + self.mean_reward = np.mean(y[-20:]) + else: + return True + + if x[-1] - self.last_checkpoint_time_step > self.checkpoint_after_iter: + self.last_checkpoint_time_step = x[-1] + check_point_path = Path(self.logdir, + 'checkpoint_save' + str(x[-1]) + 'with_mean_rew' + str(self.mean_reward)) + self.save(str(check_point_path)) + + return True diff --git a/src/opendr/planning/end_to_end_planning/envs/__init__.py b/src/opendr/planning/end_to_end_planning/envs/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/opendr/planning/end_to_end_planning/envs/agi_env.py b/src/opendr/planning/end_to_end_planning/envs/agi_env.py new file mode 100644 index 0000000000..69e476e904 --- /dev/null +++ b/src/opendr/planning/end_to_end_planning/envs/agi_env.py @@ -0,0 +1,313 @@ +# Copyright 2020-2022 OpenDR European Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import gym +from gym import spaces +import numpy as np +import rospy +from geometry_msgs.msg import PoseStamped +from std_msgs.msg import Float32MultiArray +from std_msgs.msg import String +from nav_msgs.msg import Path +from webots_ros.msg import BoolStamped + + +def euler_from_quaternion(x, y, z, w): + """ + Convert a quaternion into euler angles (roll, pitch, yaw) + roll is rotation around x in radians (counterclockwise) + pitch is rotation around y in radians (counterclockwise) + yaw is rotation around z in radians (counterclockwise) + """ + t3 = +2.0 * (w * z + x * y) + t4 = +1.0 - 2.0 * (y * y + z * z) + yaw_z = np.atan2(t3, t4) + + return yaw_z / np.pi * 180 # in radians + + +class AgiEnv(gym.Env): + metadata = {'render.modes': ['human']} + + def __init__(self): + super(AgiEnv, self).__init__() + + # Gym elements + self.action_space = gym.spaces.Discrete(7) + self.observation_space = spaces.Dict( + {'depth_cam': spaces.Box(low=0, high=255, shape=(64, 64, 1), dtype=np.uint8), + 'moving_target': spaces.Box(low=-np.inf, high=np.inf, shape=(3,), dtype=np.float64)}) + + self.action_dictionary = {0: (np.cos(22.5 / 180 * np.pi), np.sin(22.5 / 180 * np.pi), 1), + 1: (np.cos(22.5 / 180 * np.pi), np.sin(22.5 / 180 * np.pi), 0), + 2: (1, 0, 1), + 3: (1, 0, 0), + 4: (1, 0, -1), + 5: (np.cos(22.5 / 180 * np.pi), -np.sin(22.5 / 180 * np.pi), 0), + 6: (np.cos(22.5 / 180 * np.pi), -np.sin(22.5 / 180 * np.pi), -1), + 7: (0, 0, 2), + 8: (0, 0, -2)} + self.step_length = 1 # meter + + self.current_position = PoseStamped().pose.position + self.current_yaw = 0 + self.range_image = np.ones((64, 64), dtype=np.float32) + self.collision_flag = False + self.model_name = "" + + # ROS connection + rospy.init_node('agi_gym_environment') + self.r = rospy.Rate(10) + self.ros_pub_pose = rospy.Publisher('mavros/setpoint_position/local', PoseStamped, queue_size=10) + self.ros_pub_target = rospy.Publisher('target_position', PoseStamped, queue_size=10) + self.ros_pub_trajectory = rospy.Publisher('uav_trajectory', Path, queue_size=10) + self.ros_pub_global_trajectory = rospy.Publisher('uav_global_trajectory', Path, queue_size=10) + self.global_traj = Path() + self.uav_trajectory = Path() + rospy.Subscriber("/mavros/local_position/pose", PoseStamped, self.pose_callback) + rospy.Subscriber("/range_image_raw", Float32MultiArray, self.range_image_callback) + rospy.Subscriber("/model_name", String, self.model_name_callback) + self.r.sleep() + rospy.Subscriber("/" + self.model_name + "/touch_sensor/value", BoolStamped, self.collision_callback) + + self.target_y = -22 + self.target_y_list = [-22, -16, -10, -4, 2, 7, 12] # evaluation map:[-22, -16, -10, -4, 2, 8, 14, 20, 26, 32] + self.target_z = 2.5 + self.start_x = -10 + self.forward_direction = True + self.parkour_length = 30 + self.episode_counter = 0 + + self.set_target() + self.r.sleep() + vo = self.difference_between_points(self.target_position, self.current_position) + self.vector_observation = np.array([vo[0] * np.cos(self.current_yaw * 22.5 / 180 * np.pi) - vo[1] * np.sin( + self.current_yaw * 22.5 / 180 * np.pi), + vo[0] * np.sin(self.current_yaw * 22.5 / 180 * np.pi) + vo[1] * np.cos( + self.current_yaw * 22.5 / 180 * np.pi), + vo[2]]) + self.observation = {'depth_cam': np.copy(self.range_image), 'moving_target': np.copy(self.vector_observation)} + self.r.sleep() + + self.image_count = 0 + + def step(self, discrete_action): + if self.current_position == PoseStamped().pose.position: + rospy.loginfo("Gym environment is not reading mavros position") + return 
self.observation_space.sample(), np.random.random(1), False, {} + action = self.action_dictionary[discrete_action] + action = (action[0] * self.step_length, action[1] * self.step_length, action[2]) + prev_x = self.current_position.x + if self.forward_direction: + self.go_position( + self.current_position.x + action[0] * np.cos(self.current_yaw * 22.5 / 180 * np.pi) - action[ + 1] * np.sin(self.current_yaw * 22.5 / 180 * np.pi), + self.current_position.y + action[0] * np.sin(self.current_yaw * 22.5 / 180 * np.pi) + action[ + 1] * np.cos(self.current_yaw * 22.5 / 180 * np.pi), self.target_z, yaw=self.current_yaw + action[2], + check_collision=True) + else: + self.go_position( + self.current_position.x - action[0] * np.cos(self.current_yaw * 22.5 / 180 * np.pi) + action[ + 1] * np.sin(self.current_yaw * 22.5 / 180 * np.pi), + self.current_position.y - action[0] * np.sin(self.current_yaw * 22.5 / 180 * np.pi) - action[ + 1] * np.cos(self.current_yaw * 22.5 / 180 * np.pi), self.target_z, yaw=self.current_yaw + action[2], + check_collision=True) + self.update_trajectory() + + dx = np.abs(self.current_position.x - prev_x) + dy = np.abs(self.current_position.y - self.target_position.y) + dyaw = np.abs(self.current_yaw) + reward = 2 * dx - 0.4 * dy - 0.3 * dyaw + + # set new observation + if self.forward_direction: + self.set_target() + vo = self.difference_between_points(self.target_position, self.current_position) + self.vector_observation = np.array([vo[0] * np.cos(self.current_yaw * 22.5 / 180 * np.pi) + vo[1] * np.sin( + self.current_yaw * 22.5 / 180 * np.pi), + -vo[0] * np.sin(self.current_yaw * 22.5 / 180 * np.pi) + vo[1] * np.cos( + self.current_yaw * 22.5 / 180 * np.pi), + vo[2]]) + self.observation = {'depth_cam': np.copy(self.range_image), + 'moving_target': np.copy(self.vector_observation)} + finish_passed = (self.current_position.x > self.parkour_length + self.start_x) + else: + self.set_target() + vo = self.difference_between_points(self.current_position, self.target_position) + self.vector_observation = np.array([vo[0] * np.cos(self.current_yaw * 22.5 / 180 * np.pi) + vo[1] * np.sin( + self.current_yaw * 22.5 / 180 * np.pi), + -vo[0] * np.sin(self.current_yaw * 22.5 / 180 * np.pi) + vo[1] * np.cos( + self.current_yaw * 22.5 / 180 * np.pi), + vo[2]]) + self.observation = {'depth_cam': np.copy(self.range_image), + 'moving_target': np.copy(self.vector_observation)} + finish_passed = (self.current_position.x < self.start_x - self.parkour_length) + + # check done + if finish_passed: + reward = 20 + done = True + elif abs(self.current_position.y - self.target_y) > 5: + reward = -10 + done = True + elif self.collision_flag: + reward = -20 + done = True + else: + done = False + + info = {"current_position": self.current_position, "finish_passed": finish_passed} + return self.observation, reward, done, info + + def reset(self): + if self.current_position == PoseStamped().pose.position: + rospy.loginfo("Gym environment is not reading mavros position") + return self.observation_space.sample() + self.target_y = np.random.choice(self.target_y_list) + self.go_position(self.current_position.x, self.current_position.y, 8) + self.go_position(self.start_x, self.current_position.y, 8) + self.go_position(self.start_x, self.target_y, self.target_z) + self.uav_trajectory.header.frame_id = "map" + self.update_trajectory() + self.publish_global_trajectory() + + self.collision_flag = False + self.set_target() + if self.forward_direction: + self.vector_observation = 
self.difference_between_points(self.target_position, self.current_position) + else: + self.vector_observation = self.difference_between_points(self.current_position, self.target_position) + self.observation = {'depth_cam': np.copy(self.range_image), 'moving_target': np.copy(self.vector_observation)} + return self.observation + + def set_target(self): + self.target_position = PoseStamped().pose.position + if self.forward_direction: + self.target_position.x = self.current_position.x + 5 + else: + self.target_position.x = self.current_position.x - 5 + self.target_position.y = self.target_y + self.target_position.z = self.target_z + self.publish_target() + + def render(self, mode='human', close=False): + pass + + def pose_callback(self, data): + self.current_position = data.pose.position + + def range_image_callback(self, data): + self.range_image = ((np.clip(np.array(data.data).reshape((64, 64, 1)), 0, 15) / 15.)*255).astype(np.uint8) + + def model_name_callback(self, data): + if data.data[:5] == "robot": + self.model_name = data.data + + def collision_callback(self, data): + if data.data: + self.collision_flag = True + # print("colliiiddeeee") + + def go_position(self, x, y, z, yaw=0, check_collision=False): + if yaw > 4: + yaw = 4 + if yaw < -4: + yaw = -4 + goal = PoseStamped() + + goal.header.seq = 1 + goal.header.stamp = rospy.Time.now() + # goal.header.frame_id = "map" + + goal.pose.position.x = x + goal.pose.position.y = y + goal.pose.position.z = z + + goal.pose.orientation.x = 0.0 + goal.pose.orientation.y = 0.0 + quat_z_yaw_dict = {-4: -0.7071068, -3: -0.5555702, -2: -0.3826834, -1: -0.1950903, 0: 0.0, 1: 0.1950903, + 2: 0.3826834, 3: 0.5555702, 4: 0.7071068} + quat_w_yaw_dict = {-4: 0.7071068, -3: 0.8314696, -2: 0.9238795, -1: 0.9807853, 0: 1.0, 1: 0.9807853, + 2: 0.9238795, 3: 0.8314696, 4: 0.7071068} + if self.forward_direction: + goal.pose.orientation.z = quat_z_yaw_dict[yaw] + goal.pose.orientation.w = quat_w_yaw_dict[yaw] + else: + goal.pose.orientation.z = -quat_w_yaw_dict[yaw] + goal.pose.orientation.w = quat_z_yaw_dict[yaw] + self.current_yaw = yaw + self.ros_pub_pose.publish(goal) + self.r.sleep() + while self.distance_between_points(goal.pose.position, self.current_position) > 0.1: + if check_collision and self.collision_flag: + return + self.ros_pub_pose.publish(goal) + self.r.sleep() + + def publish_target(self): + goal = PoseStamped() + + goal.header.seq = 1 + goal.header.stamp = rospy.Time.now() + goal.header.frame_id = "map" + + goal.pose.position = self.target_position + + goal.pose.orientation.x = 0.0 + goal.pose.orientation.y = 0.0 + goal.pose.orientation.z = 0.0 + goal.pose.orientation.w = 1.0 + self.ros_pub_target.publish(goal) + + def update_trajectory(self): + new_point = PoseStamped() + new_point.header.seq = 1 + new_point.header.stamp = rospy.Time.now() + new_point.header.frame_id = "map" + new_point.pose.position.x = self.current_position.x + new_point.pose.position.y = self.current_position.y + new_point.pose.position.z = self.current_position.z + self.uav_trajectory.poses.append(new_point) + self.ros_pub_trajectory.publish(self.uav_trajectory) + + def publish_global_trajectory(self): + self.global_traj.header.frame_id = "map" + new_point = PoseStamped() + new_point.header.seq = 1 + new_point.header.stamp = rospy.Time.now() + new_point.header.frame_id = "map" + new_point.pose.position.x = self.start_x + new_point.pose.position.y = self.target_y + new_point.pose.position.z = self.target_z + self.global_traj.poses.append(new_point) + new_point = 
PoseStamped() + new_point.header.seq = 1 + new_point.header.stamp = rospy.Time.now() + new_point.header.frame_id = "map" + if self.forward_direction: + new_point.pose.position.x = self.start_x + self.parkour_length + else: + new_point.pose.position.x = self.start_x - self.parkour_length + new_point.pose.position.y = self.target_y + new_point.pose.position.z = self.target_z + self.global_traj.poses.append(new_point) + self.ros_pub_global_trajectory.publish(self.global_traj) + + def distance_between_points(self, p1, p2): + x = p1.x - p2.x + y = p1.y - p2.y + z = p1.z - p2.z + return np.sqrt(x * x + y * y + z * z) + + def difference_between_points(self, p1, p2): + return np.array([p1.x - p2.x, p1.y - p2.y, p1.z - p2.z]) diff --git a/src/opendr/planning/end_to_end_planning/pretrained_model/__init__.py b/src/opendr/planning/end_to_end_planning/pretrained_model/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/opendr/planning/end_to_end_planning/pretrained_model/saved_model.zip b/src/opendr/planning/end_to_end_planning/pretrained_model/saved_model.zip new file mode 100644 index 0000000000..53c2522feb Binary files /dev/null and b/src/opendr/planning/end_to_end_planning/pretrained_model/saved_model.zip differ diff --git a/src/opendr/planning/end_to_end_planning/src/children_robot.cpp b/src/opendr/planning/end_to_end_planning/src/children_robot.cpp new file mode 100644 index 0000000000..482227afba --- /dev/null +++ b/src/opendr/planning/end_to_end_planning/src/children_robot.cpp @@ -0,0 +1,245 @@ +// Copyright 2020-2022 OpenDR European Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include +#include "ros/ros.h" + +#include +#include + +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#define TIME_STEP 32; +using namespace std; + +static int controllerCount; +static std::vector controllerList; +static std::vector imageRangeFinder; +static double touchSensorValues[3] = {0, 0, 0}; +static bool callbackCalled = false; + +ros::ServiceClient time_step_client; +webots_ros::set_int time_step_srv; +// catch names of the controllers availables on ROS network +void controllerNameCallback(const std_msgs::String::ConstPtr &name) { + controllerCount++; + controllerList.push_back(name->data); + ROS_INFO("Controller #%d: %s.", controllerCount, controllerList.back().c_str()); +} + +// get range image from the range-finder +void rangeFinderCallback(const sensor_msgs::Image::ConstPtr &image) { + int size = image->width * image->height; + imageRangeFinder.resize(size); + + const float *depth_data = reinterpret_cast(&image->data[0]); + for (int i = 0; i < size; ++i) + imageRangeFinder[i] = depth_data[i]; +} +// touch sensor +void touchSensorCallback(const webots_ros::Float64Stamped::ConstPtr &value) { + ROS_INFO("Touch sensor sent value %f (time: %d:%d).", value->data, value->header.stamp.sec, value->header.stamp.nsec); + callbackCalled = true; +} + +void touchSensorBumperCallback(const webots_ros::BoolStamped::ConstPtr &value) { + ROS_INFO("Touch sensor sent value %d (time: %d:%d).", value->data, value->header.stamp.sec, value->header.stamp.nsec); + callbackCalled = true; +} + +void touchSensor3DCallback(const geometry_msgs::WrenchStamped::ConstPtr &values) { + touchSensorValues[0] = values->wrench.force.x; + touchSensorValues[1] = values->wrench.force.y; + touchSensorValues[2] = values->wrench.force.z; + + ROS_INFO("Touch sensor values are x = %f, y = %f and z = %f (time: %d:%d).", touchSensorValues[0], touchSensorValues[1], + touchSensorValues[2], values->header.stamp.sec, values->header.stamp.nsec); + callbackCalled = true; +} + +void quit(int sig) { + time_step_srv.request.value = 0; + time_step_client.call(time_step_srv); + ROS_INFO("User stopped the 'catch_the_bird' node."); + ros::shutdown(); + exit(0); +} + +int main(int argc, char **argv) { + std::string controllerName; + std::vector deviceList; + std::string rangeFinderName; + std::string touchSensorName; + + int width, height; + float i, step; + + // create a node named 'range' on ROS network + ros::init(argc, argv, "range", ros::init_options::AnonymousName); + ros::NodeHandle n; + + signal(SIGINT, quit); + + // subscribe to the topic model_name to get the list of availables controllers + ros::Subscriber nameSub = n.subscribe("model_name", 100, controllerNameCallback); + while (controllerCount == 0 || controllerCount < nameSub.getNumPublishers()) { + ros::spinOnce(); + ros::spinOnce(); + ros::spinOnce(); + } + ros::spinOnce(); + + // if there is more than one controller available, let the user choose + if (controllerCount == 1) + controllerName = controllerList[0]; + else { + int wantedController = 0; + std::cout << "Choose the # of the controller you want to use:\n"; + std::cin >> wantedController; + if (1 <= wantedController && wantedController <= controllerCount) + controllerName = controllerList[wantedController - 1]; + else { + ROS_ERROR("Invalid number for controller choice."); + return 1; + } + } + // leave topic once it's not necessary anymore + nameSub.shutdown(); + // call device_list service to get the list of the devices 
available on the controller and print it the device_list_srv object + // contains 2 members request and response. Their fields are described in the corresponding .srv file + ros::ServiceClient deviceListClient = + n.serviceClient(controllerName + "/robot/get_device_list"); + webots_ros::robot_get_device_list deviceListSrv; + + if (deviceListClient.call(deviceListSrv)) + deviceList = deviceListSrv.response.list; + else + ROS_ERROR("Failed to call service device_list."); + rangeFinderName = deviceList[1]; + touchSensorName = deviceList[0]; + ros::ServiceClient rangeFinderGetInfoClient = + n.serviceClient(controllerName + '/' + rangeFinderName + "/get_info"); + webots_ros::range_finder_get_info rangeFinderGetInfoSrv; + if (rangeFinderGetInfoClient.call(rangeFinderGetInfoSrv)) { + width = rangeFinderGetInfoSrv.response.width; + height = rangeFinderGetInfoSrv.response.height; + ROS_INFO("Range-finder size is %d x %d.", width, height); + } else + ROS_ERROR("Failed to call service range_finder_get_info."); + + // enable the range-finder + ros::ServiceClient enableRangeFinderClient = + n.serviceClient(controllerName + '/' + rangeFinderName + "/enable"); + webots_ros::set_int enableRangeFinderSrv; + ros::Subscriber subRangeFinderRangeFinder; + + enableRangeFinderSrv.request.value = 2 * TIME_STEP; + if (enableRangeFinderClient.call(enableRangeFinderSrv) && enableRangeFinderSrv.response.success) { + ROS_INFO("Range-finder enabled with sampling period %d.", enableRangeFinderSrv.request.value); + subRangeFinderRangeFinder = n.subscribe(controllerName + '/' + rangeFinderName + "/range_image", 1, rangeFinderCallback); + + // wait for the topics to be initialized + while (subRangeFinderRangeFinder.getNumPublishers() == 0) { + } + } else { + ROS_ERROR("Failed to call service enable for %s.", rangeFinderName.c_str()); + } + // enable time_step + time_step_client = n.serviceClient(controllerName + "/robot/time_step"); + time_step_srv.request.value = TIME_STEP; + + /////////////////////////////// + // TOUCH SENSOR // + /////////////////////////////// + + ros::ServiceClient set_touch_sensor_client; + webots_ros::set_int touch_sensor_srv; + ros::Subscriber sub_touch_sensor_32; + set_touch_sensor_client = n.serviceClient(controllerName + "/touch_sensor/enable"); + + ros::ServiceClient sampling_period_touch_sensor_client; + webots_ros::get_int sampling_period_touch_sensor_srv; + sampling_period_touch_sensor_client = + n.serviceClient(controllerName + "/touch_sensor/get_sampling_period"); + + ros::ServiceClient touch_sensor_get_type_client; + webots_ros::get_int touch_sensor_get_type_srv; + touch_sensor_get_type_client = n.serviceClient(controllerName + "/touch_sensor/get_type"); + + touch_sensor_get_type_client.call(touch_sensor_get_type_srv); + ROS_INFO("Touch_sensor is of type %d.", touch_sensor_get_type_srv.response.value); + + touch_sensor_get_type_client.shutdown(); + time_step_client.call(time_step_srv); + + touch_sensor_srv.request.value = 32; + if (set_touch_sensor_client.call(touch_sensor_srv) && touch_sensor_srv.response.success) { + ROS_INFO("Touch_sensor enabled."); + if (touch_sensor_get_type_srv.response.value == 0) + sub_touch_sensor_32 = n.subscribe(controllerName + "/touch_sensor/value", 1, touchSensorBumperCallback); + else if (touch_sensor_get_type_srv.response.value == 1) + sub_touch_sensor_32 = n.subscribe(controllerName + "/touch_sensor/value", 1, touchSensorCallback); + else + sub_touch_sensor_32 = n.subscribe(controllerName + "/touch_sensor/values", 1, touchSensor3DCallback); + 
callbackCalled = false; + while (sub_touch_sensor_32.getNumPublishers() == 0 && !callbackCalled) { + ros::spinOnce(); + time_step_client.call(time_step_srv); + } + } else { + if (!touch_sensor_srv.response.success) + ROS_ERROR("Sampling period is not valid."); + ROS_ERROR("Failed to enable touch_sensor."); + return 1; + } + + sub_touch_sensor_32.shutdown(); + time_step_client.call(time_step_srv); + + sampling_period_touch_sensor_client.call(sampling_period_touch_sensor_srv); + ROS_INFO("Touch_sensor is enabled with a sampling period of %d.", sampling_period_touch_sensor_srv.response.value); + + time_step_client.call(time_step_srv); + + sampling_period_touch_sensor_client.call(sampling_period_touch_sensor_srv); + sampling_period_touch_sensor_client.shutdown(); + time_step_client.call(time_step_srv); + + // main loop + while (ros::ok()) { + if (!time_step_client.call(time_step_srv) || !time_step_srv.response.success) { + ROS_ERROR("Failed to call next step with time_step service."); + exit(1); + } + ros::spinOnce(); + while (imageRangeFinder.size() < (width * height)) + ros::spinOnce(); + } + time_step_srv.request.value = 0; + time_step_client.call(time_step_srv); + n.shutdown(); +} diff --git a/src/opendr/planning/end_to_end_planning/src/range_image.py b/src/opendr/planning/end_to_end_planning/src/range_image.py new file mode 100644 index 0000000000..8266c6c338 --- /dev/null +++ b/src/opendr/planning/end_to_end_planning/src/range_image.py @@ -0,0 +1,43 @@ +# Copyright 2020-2022 OpenDR European Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import rospy +from sensor_msgs.msg import Image +from std_msgs.msg import Float32MultiArray +from std_msgs.msg import String +from cv_bridge import CvBridge + + +class range_image_node(): + + def __init__(self): + rospy.init_node('listener', anonymous=True) + self.r = rospy.Rate(10) + self.raw_image_pub = rospy.Publisher('range_image_raw', Float32MultiArray, queue_size=10) + rospy.Subscriber("/model_name", String, self.model_name_callback) + self.r.sleep() + rospy.Subscriber("/" + self.model_name + "/range_finder/range_image", Image, self.range_callback) + rospy.spin() + + def range_callback(self, data): + bridge = CvBridge() + cv_image = bridge.imgmsg_to_cv2(data) + arr = Float32MultiArray() + arr.data = list(cv_image.reshape(4096)) + self.raw_image_pub.publish(arr) + + def model_name_callback(self, data): + self.model_name = data.data + + +node_class = range_image_node() diff --git a/src/opendr/planning/end_to_end_planning/src/take_off.cpp b/src/opendr/planning/end_to_end_planning/src/take_off.cpp new file mode 100644 index 0000000000..627ac3da7b --- /dev/null +++ b/src/opendr/planning/end_to_end_planning/src/take_off.cpp @@ -0,0 +1,94 @@ +// Copyright 2020-2022 OpenDR European Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include +#include +#include +#include +#include +#include "std_msgs/Float64.h" +#include "std_msgs/Header.h" +#include "std_msgs/String.h" + +mavros_msgs::State current_state; +void state_cb(const mavros_msgs::State::ConstPtr &msg) { + current_state = *msg; +} +float current_pose_x; +float current_pose_y; +float current_pose_z; + +void current_pose_callback(const geometry_msgs::PoseStamped::ConstPtr &msg) { + current_pose_x = msg->pose.position.x; + current_pose_y = msg->pose.position.y; + current_pose_z = msg->pose.position.z; +} + +int main(int argc, char **argv) { + ros::init(argc, argv, "way_point"); + ros::NodeHandle nh; + + ros::Subscriber state_sub = nh.subscribe("mavros/state", 10, state_cb); + + ros::Publisher local_pos_pub = nh.advertise("mavros/setpoint_position/local", 100); + ros::ServiceClient arming_client = nh.serviceClient("mavros/cmd/arming"); + ros::ServiceClient set_mode_client = nh.serviceClient("mavros/set_mode"); + // takeoff + ros::ServiceClient takeoff_client = nh.serviceClient("mavros/cmd/takeoff"); + // drone pos sub + ros::Subscriber drone_pos_sub = + nh.subscribe("mavros/local_position/pose", 10, current_pose_callback); + + // the setpoint publishing rate MUST be faster than 2Hz + ros::Rate rate(40.0); + + // wait for FCU connection + while (ros::ok() && !current_state.connected) { + ros::spinOnce(); + rate.sleep(); + } + geometry_msgs::PoseStamped pose; + pose.pose.position.x = 0; + pose.pose.position.y = 0; + pose.pose.position.z = 5; + local_pos_pub.publish(pose); + + // send a few setpoints before starting + for (int i = 100; ros::ok() && i > 0; --i) { + local_pos_pub.publish(pose); + ros::spinOnce(); + rate.sleep(); + } + + mavros_msgs::SetMode guided_set_mode; + + guided_set_mode.request.custom_mode = "GUIDED"; + + mavros_msgs::CommandBool arm_cmd; + arm_cmd.request.value = true; + + mavros_msgs::CommandTOL takeoff_cmd; + + takeoff_cmd.request.min_pitch = 0; + takeoff_cmd.request.yaw = 0; + takeoff_cmd.request.latitude = 0; + takeoff_cmd.request.longitude = 0; + takeoff_cmd.request.altitude = 5; + + set_mode_client.call(guided_set_mode); + arming_client.call(arm_cmd); + takeoff_client.call(takeoff_cmd); + + return 0; +} diff --git a/tests/sources/tools/planning/__init__.py b/tests/sources/tools/planning/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/sources/tools/planning/end_to_end_planning/__init__.py b/tests/sources/tools/planning/end_to_end_planning/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/sources/tools/planning/end_to_end_planning/test_end_to_end_planning.py b/tests/sources/tools/planning/end_to_end_planning/test_end_to_end_planning.py new file mode 100644 index 0000000000..67533e5c61 --- /dev/null +++ b/tests/sources/tools/planning/end_to_end_planning/test_end_to_end_planning.py @@ -0,0 +1,89 @@ +# Copyright 1996-2020 OpenDR European Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import unittest +from pathlib import Path + +from opendr.planning.end_to_end_planning import EndToEndPlanningRLLearner, AgiEnv +import opendr +import torch +import os + +device = os.getenv('TEST_DEVICE') if os.getenv('TEST_DEVICE') else 'cpu' + +TEST_ITERS = 3 +TEMP_SAVE_DIR = Path(__file__).parent / "end_to_end_planning_tmp/" +TEMP_SAVE_DIR.mkdir(parents=True, exist_ok=True) + + +def get_first_weight(learner): + return list(learner.stable_bl_agent.get_parameters()['policy'].values())[0].clone() + + +def isequal_dict_of_ndarray(first, second): + """Return whether two dicts of arrays are exactly equal""" + if first.keys() != second.keys(): + return False + return all(np.array_equal(first[key], second[key]) for key in first) + + +class EndToEndPlanningTest(unittest.TestCase): + learner = None + + @classmethod + def setUpClass(cls): + cls.env = AgiEnv() + cls.learner = EndToEndPlanningRLLearner(cls.env, device=device) + + @classmethod + def tearDownClass(cls): + del cls.learner + + def test_infer(self): + obs = self.env.observation_space.sample() + action = self.learner.infer(obs)[0] + self.assertTrue((action >= 0), "Actions below 0") + self.assertTrue((action < self.env.action_space.n), "Actions above discrete action space dimensions") + + def test_eval(self): + episode_reward = self.learner.eval(self.env)["rewards_collected"] + self.assertTrue((episode_reward > -100), "Episode reward cannot be lower than -100") + self.assertTrue((episode_reward < 100), "Episode reward cannot pass 100") + + def test_fit(self): + self.learner.__init__(self.env, n_steps=12, iters=15) + initial_weights = self.learner.agent.get_parameters() + self.learner.fit(logging_path=str(TEMP_SAVE_DIR)) + trained_weights = self.learner.agent.get_parameters() + self.assertFalse(isequal_dict_of_ndarray(initial_weights, trained_weights), + "Fit method did not change model weights") + + def test_save_load(self): + self.learner.__init__(self.env) + initial_weights = list(self.learner.agent.get_parameters()['policy'].values())[0].clone() + self.learner.save(str(TEMP_SAVE_DIR) + "/init_weights") + self.learner.load( + path=Path(opendr.__file__).parent / "planning/end_to_end_planning/pretrained_model/saved_model") + self.assertFalse( + torch.equal(initial_weights, list(self.learner.agent.get_parameters()['policy'].values())[0].clone()), + "Load method did not change model weights") + self.learner.load(str(TEMP_SAVE_DIR) + "/init_weights") + self.assertTrue( + torch.equal(initial_weights, list(self.learner.agent.get_parameters()['policy'].values())[0].clone()), + "Load method did not load the same model weights") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_clang_format.py b/tests/test_clang_format.py index d77e78176b..226361edb8 100755 --- a/tests/test_clang_format.py +++ b/tests/test_clang_format.py @@ -59,6 +59,7 @@ def test_sources_are_clang_format_compliant(self): ] skippedPaths = [ 'src/opendr/perception/panoptic_segmentation/efficient_ps/algorithm/EfficientPS', + 'src/opendr/planning/end_to_end_planning/ardupilot', ] skippedFiles = [ 
'src/opendr/perception/object_detection_2d/retinaface/algorithm/cython/gpu_nms.cpp',