-
Notifications
You must be signed in to change notification settings - Fork 4.3k
/
Copy pathRigidBodySensorTests.cs
142 lines (121 loc) · 5.22 KB
/
RigidBodySensorTests.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
using System.Collections.Generic;
using NUnit.Framework;
using Unity.MLAgents.Extensions.Sensors;
using Unity.MLAgents.Sensors;
using UnityEngine;
namespace Unity.MLAgents.Extensions.Tests.Sensors
{
/// <summary>
/// Assertion helpers that turn <see cref="SensorHelper"/> observation comparisons into NUnit failures.
/// </summary>
public static class SensorTestHelper
{
    /// <summary>
    /// Fails the current test unless the observations produced by <paramref name="sensor"/>
    /// match the one-dimensional <paramref name="expected"/> array. The failure message is the
    /// one reported by <see cref="SensorHelper"/>.
    /// </summary>
    /// <param name="sensor">Sensor whose observations are checked.</param>
    /// <param name="expected">Expected observation values.</param>
    public static void CompareObservation(ISensor sensor, float[] expected)
    {
        var matches = SensorHelper.CompareObservation(sensor, expected, out var failureMessage);
        Assert.IsTrue(matches, failureMessage);
    }

    /// <summary>
    /// Fails the current test unless the observations produced by <paramref name="sensor"/>
    /// match the three-dimensional <paramref name="expected"/> array. The failure message is the
    /// one reported by <see cref="SensorHelper"/>.
    /// </summary>
    /// <param name="sensor">Sensor whose observations are checked.</param>
    /// <param name="expected">Expected observation values.</param>
    public static void CompareObservation(ISensor sensor, float[,,] expected)
    {
        var matches = SensorHelper.CompareObservation(sensor, expected, out var failureMessage);
        Assert.IsTrue(matches, failureMessage);
    }
}
/// <summary>
/// Tests for <see cref="RigidBodySensorComponent"/> and the sensors it creates.
/// </summary>
public class RigidBodySensorTests
{
    // GameObjects created through CreateGameObject(). They are destroyed after each test so that
    // objects built by one test do not remain in the scene for the next one.
    readonly List<GameObject> m_CreatedObjects = new List<GameObject>();

    /// <summary>
    /// Creates an empty GameObject and records it so it is destroyed when the test finishes.
    /// </summary>
    /// <returns>The newly created GameObject.</returns>
    GameObject CreateGameObject()
    {
        var gameObj = new GameObject();
        m_CreatedObjects.Add(gameObj);
        return gameObj;
    }

    /// <summary>
    /// Destroys every GameObject created by the current test via CreateGameObject().
    /// </summary>
    [TearDown]
    public void DestroyCreatedObjects()
    {
        foreach (var gameObj in m_CreatedObjects)
        {
            // Destroying a parent also destroys its children, so a tracked child may already be
            // gone. Unity's overloaded != reports destroyed objects as null.
            if (gameObj != null)
            {
                UnityEngine.Object.DestroyImmediate(gameObj);
            }
        }
        m_CreatedObjects.Clear();
    }

    [Test]
    public void TestNullRootBody()
    {
        // No RootBody is assigned, so the sensor is expected to produce no observations.
        var gameObj = CreateGameObject();
        var sensorComponent = gameObj.AddComponent<RigidBodySensorComponent>();
        var sensor = sensorComponent.CreateSensors()[0];

        var expected = new float[0];
        Assert.AreEqual(expected.Length, sensor.GetObservationSpec().Shape[0]);
        SensorTestHelper.CompareObservation(sensor, expected);
    }

    [Test]
    public void TestSingleRigidbody()
    {
        var gameObj = CreateGameObject();
        var rootRb = gameObj.AddComponent<Rigidbody>();
        var sensorComponent = gameObj.AddComponent<RigidBodySensorComponent>();
        sensorComponent.RootBody = rootRb;
        sensorComponent.Settings = new PhysicsSensorSettings
        {
            UseModelSpaceLinearVelocity = true,
            UseLocalSpaceTranslations = true,
            UseLocalSpaceRotations = true
        };

        var sensor = sensorComponent.CreateSensors()[0];
        sensor.Update();

        // The root body is ignored since it always generates identity values
        // and there are no other bodies to generate observations.
        var expected = new float[0];
        Assert.AreEqual(expected.Length, sensor.GetObservationSpec().Shape[0]);
        SensorTestHelper.CompareObservation(sensor, expected);
    }

    [Test]
    public void TestBodiesWithJoint()
    {
        // Hierarchy: root -> middle (local x = 13.37) -> leaf (local x = 4.2),
        // with each child connected to its parent's Rigidbody by a ConfigurableJoint.
        var rootObj = CreateGameObject();
        var rootRb = rootObj.AddComponent<Rigidbody>();
        rootRb.velocity = new Vector3(1f, 0f, 0f);

        var middleGameObj = CreateGameObject();
        var middleRb = middleGameObj.AddComponent<Rigidbody>();
        middleRb.velocity = new Vector3(0f, 1f, 0f);
        middleGameObj.transform.SetParent(rootObj.transform);
        middleGameObj.transform.localPosition = new Vector3(13.37f, 0f, 0f);
        var joint = middleGameObj.AddComponent<ConfigurableJoint>();
        joint.connectedBody = rootRb;

        var leafGameObj = CreateGameObject();
        var leafRb = leafGameObj.AddComponent<Rigidbody>();
        leafRb.velocity = new Vector3(0f, 0f, 1f);
        leafGameObj.transform.SetParent(middleGameObj.transform);
        leafGameObj.transform.localPosition = new Vector3(4.2f, 0f, 0f);
        var joint2 = leafGameObj.AddComponent<ConfigurableJoint>();
        joint2.connectedBody = middleRb;

        var virtualRoot = CreateGameObject();

        var sensorComponent = rootObj.AddComponent<RigidBodySensorComponent>();
        sensorComponent.RootBody = rootRb;
        sensorComponent.Settings = new PhysicsSensorSettings
        {
            UseModelSpaceTranslations = true,
            UseLocalSpaceTranslations = true,
            UseLocalSpaceLinearVelocity = true
        };
        sensorComponent.VirtualRoot = virtualRoot;

        var sensor = sensorComponent.CreateSensors()[0];
        sensor.Update();

        // Note that the VirtualRoot itself is excluded from the observations.
        var expected = new[]
        {
            // Model space translations
            0f, 0f, 0f, // Root pos
            13.37f, 0f, 0f, // Middle pos
            // Leaf pos: read back from the transform instead of hard-coding 13.37f + 4.2f
            // (presumably to avoid a float-rounding mismatch).
            leafGameObj.transform.position.x, 0f, 0f,

            // Local space translations
            0f, 0f, 0f, // Root pos
            13.37f, 0f, 0f, // Middle pos
            4.2f, 0f, 0f, // Leaf pos

            // Local space linear velocities
            1f, 0f, 0f, // Root vel (relative to virtual root)
            -1f, 1f, 0f, // Middle vel (relative to root)
            0f, -1f, 1f // Leaf vel (relative to middle)
        };
        Assert.AreEqual(expected.Length, sensor.GetObservationSpec().Shape[0]);
        SensorTestHelper.CompareObservation(sensor, expected);

        // Update the settings to only process joint observations.
        sensorComponent.Settings = new PhysicsSensorSettings
        {
            UseJointPositionsAndAngles = true,
            UseJointForces = true,
        };
        sensor = sensorComponent.CreateSensors()[0];
        sensor.Update();

        // Only force/torque values are expected. Joint positions/angles presumably add no
        // observations for Rigidbody joints (TODO confirm against the joint extractor).
        expected = new[]
        {
            0f, 0f, 0f, // joint1.force
            0f, 0f, 0f, // joint1.torque
            0f, 0f, 0f, // joint2.force
            0f, 0f, 0f, // joint2.torque
        };
        Assert.AreEqual(expected.Length, sensor.GetObservationSpec().Shape[0]);
        SensorTestHelper.CompareObservation(sensor, expected);
    }
}
}