-
Notifications
You must be signed in to change notification settings - Fork 4.3k
/
Copy pathGridSensorComponent.cs
293 lines (267 loc) · 10.9 KB
/
GridSensorComponent.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
using System.Collections.Generic;
using UnityEngine;
namespace Unity.MLAgents.Sensors
{
/// <summary>
/// A SensorComponent that creates a <see cref="GridSensorBase"/>.
/// </summary>
[AddComponentMenu("ML Agents/Grid Sensor", (int)MenuGroup.Sensors)]
public class GridSensorComponent : SensorComponent
{
    // dummy sensor only used for debug gizmo
    GridSensorBase m_DebugSensor;

    // Unwrapped grid sensors produced by the most recent CreateSensors() call.
    // Kept separately from the (possibly StackingSensor-wrapped) array returned to the
    // caller so that UpdateSensor() can reach the GridSensorBase instances directly.
    List<GridSensorBase> m_Sensors;

    internal BoxOverlapChecker m_BoxOverlapChecker;

    [HideInInspector, SerializeField]
    protected internal string m_SensorName = "GridSensor";
    /// <summary>
    /// Name of the generated <see cref="GridSensor"/> object.
    /// Note that changing this at runtime does not affect how the Agent sorts the sensors.
    /// </summary>
    public string SensorName
    {
        get { return m_SensorName; }
        set { m_SensorName = value; }
    }

    [HideInInspector, SerializeField]
    internal Vector3 m_CellScale = new Vector3(1f, 0.01f, 1f);
    /// <summary>
    /// The scale of each grid cell.
    /// Note that changing this after the sensor is created has no effect.
    /// </summary>
    public Vector3 CellScale
    {
        get { return m_CellScale; }
        set { m_CellScale = value; }
    }

    [HideInInspector, SerializeField]
    internal Vector3Int m_GridSize = new Vector3Int(16, 1, 16);
    /// <summary>
    /// The number of grid cells along each axis. The y component is always stored as 1,
    /// whatever value is assigned.
    /// Note that changing this after the sensor is created has no effect.
    /// </summary>
    public Vector3Int GridSize
    {
        get { return m_GridSize; }
        // Forcing y to 1 unconditionally is equivalent to the former if/else on value.y.
        set { m_GridSize = new Vector3Int(value.x, 1, value.z); }
    }

    [HideInInspector, SerializeField]
    internal bool m_RotateWithAgent = true;
    /// <summary>
    /// Rotate the grid based on the direction the agent is facing.
    /// Assigning this pushes the new value to already-created sensors via UpdateSensor().
    /// </summary>
    public bool RotateWithAgent
    {
        get { return m_RotateWithAgent; }
        set { m_RotateWithAgent = value; UpdateSensor(); }
    }

    [HideInInspector, SerializeField]
    internal string[] m_DetectableTags;
    /// <summary>
    /// List of tags that are detected.
    /// Note that changing this after the sensor is created has no effect.
    /// </summary>
    public string[] DetectableTags
    {
        get { return m_DetectableTags; }
        set { m_DetectableTags = value; }
    }

    [HideInInspector, SerializeField]
    internal LayerMask m_ColliderMask;
    /// <summary>
    /// The layer mask.
    /// Assigning this pushes the new value to already-created sensors via UpdateSensor().
    /// </summary>
    public LayerMask ColliderMask
    {
        get { return m_ColliderMask; }
        set { m_ColliderMask = value; UpdateSensor(); }
    }

    [HideInInspector, SerializeField]
    internal int m_MaxColliderBufferSize = 500;
    /// <summary>
    /// The absolute max size of the Collider buffer used in the non-allocating Physics calls. In other words
    /// the Collider buffer will never grow beyond this number even if there are more Colliders in the Grid Cell.
    /// Note that changing this after the sensor is created has no effect.
    /// </summary>
    public int MaxColliderBufferSize
    {
        get { return m_MaxColliderBufferSize; }
        set { m_MaxColliderBufferSize = value; }
    }

    [HideInInspector, SerializeField]
    internal int m_InitialColliderBufferSize = 4;
    /// <summary>
    /// The Estimated Max Number of Colliders to expect per cell. This number is used to
    /// pre-allocate an array of Colliders in order to take advantage of the OverlapBoxNonAlloc
    /// Physics API. If the number of colliders found is >= InitialColliderBufferSize the array
    /// will be resized to double its current size, up to <see cref="MaxColliderBufferSize"/>
    /// (500 by default).
    /// Note that changing this after the sensor is created has no effect.
    /// </summary>
    public int InitialColliderBufferSize
    {
        get { return m_InitialColliderBufferSize; }
        set { m_InitialColliderBufferSize = value; }
    }

    [HideInInspector, SerializeField]
    internal Color[] m_DebugColors;
    /// <summary>
    /// Array of Colors used for the grid gizmos.
    /// </summary>
    public Color[] DebugColors
    {
        get { return m_DebugColors; }
        set { m_DebugColors = value; }
    }

    [HideInInspector, SerializeField]
    internal float m_GizmoYOffset = 0f;
    /// <summary>
    /// The height of the gizmos grid.
    /// </summary>
    public float GizmoYOffset
    {
        get { return m_GizmoYOffset; }
        set { m_GizmoYOffset = value; }
    }

    [HideInInspector, SerializeField]
    internal bool m_ShowGizmos = false;
    /// <summary>
    /// Whether to show gizmos or not.
    /// </summary>
    public bool ShowGizmos
    {
        get { return m_ShowGizmos; }
        set { m_ShowGizmos = value; }
    }

    [HideInInspector, SerializeField]
    internal SensorCompressionType m_CompressionType = SensorCompressionType.PNG;
    /// <summary>
    /// The compression type to use for the sensor.
    /// Assigning this pushes the new value to already-created sensors via UpdateSensor().
    /// </summary>
    public SensorCompressionType CompressionType
    {
        get { return m_CompressionType; }
        set { m_CompressionType = value; UpdateSensor(); }
    }

    [HideInInspector, SerializeField]
    [Range(1, 50)]
    [Tooltip("Number of frames of observations that will be stacked before being fed to the neural network.")]
    internal int m_ObservationStacks = 1;
    /// <summary>
    /// Number of observations to stack. Using 1 means no previous observations.
    /// Note that changing this after the sensor is created has no effect.
    /// </summary>
    public int ObservationStacks
    {
        get { return m_ObservationStacks; }
        set { m_ObservationStacks = value; }
    }

    /// <inheritdoc/>
    public override ISensor[] CreateSensors()
    {
        m_BoxOverlapChecker = new BoxOverlapChecker(
            m_CellScale,
            m_GridSize,
            m_RotateWithAgent,
            m_ColliderMask,
            gameObject,
            m_DetectableTags,
            m_InitialColliderBufferSize,
            m_MaxColliderBufferSize
        );

        // debug data is positive int value and will trigger data validation exception if SensorCompressionType is not None.
        m_DebugSensor = new GridSensorBase("DebugGridSensor", m_CellScale, m_GridSize, m_DetectableTags, SensorCompressionType.None);
        m_BoxOverlapChecker.RegisterDebugSensor(m_DebugSensor);

        var gridSensors = GetGridSensors();
        if (gridSensors == null || gridSensors.Length < 1)
        {
            throw new UnityAgentsException("GridSensorComponent received no sensors. Specify at least one observation type (OneHot/Counting) to use grid sensors. " +
                "If you're overriding GridSensorComponent.GetGridSensors(), return at least one grid sensor.");
        }

        // Store the unwrapped sensors in the FIELD (a local of the same name used to shadow
        // it, which left the field null and turned UpdateSensor() into a no-op).
        m_Sensors = new List<GridSensorBase>(gridSensors);

        var createdSensors = new ISensor[gridSensors.Length];
        for (var i = 0; i < gridSensors.Length; i++)
        {
            var gridSensor = gridSensors[i];
            m_BoxOverlapChecker.RegisterSensor(gridSensor);

            if (ObservationStacks != 1)
            {
                createdSensors[i] = new StackingSensor(gridSensor, ObservationStacks);
            }
            else
            {
                createdSensors[i] = gridSensor;
            }
        }

        // Only one sensor needs to reference the boxOverlapChecker, so that it gets updated exactly once.
        // Use the unwrapped grid sensor: when stacking is enabled the returned element is a
        // StackingSensor, and casting that to GridSensorBase would throw InvalidCastException.
        gridSensors[0].m_BoxOverlapChecker = m_BoxOverlapChecker;

        return createdSensors;
    }

    /// <summary>
    /// Get an array of GridSensors to be added in this component.
    /// Override this method and return custom GridSensor implementations.
    /// </summary>
    /// <returns>Array of grid sensors to be added to the component.</returns>
    protected virtual GridSensorBase[] GetGridSensors()
    {
        return new GridSensorBase[]
        {
            new OneHotGridSensor(m_SensorName + "-OneHot", m_CellScale, m_GridSize, m_DetectableTags, m_CompressionType)
        };
    }

    /// <summary>
    /// Update fields that are safe to change on the Sensor at runtime.
    /// Does nothing until <see cref="CreateSensors"/> has completed successfully.
    /// </summary>
    internal void UpdateSensor()
    {
        if (m_Sensors == null || m_BoxOverlapChecker == null)
        {
            return;
        }

        m_BoxOverlapChecker.RotateWithAgent = m_RotateWithAgent;
        m_BoxOverlapChecker.ColliderMask = m_ColliderMask;
        foreach (var gridSensor in m_Sensors)
        {
            gridSensor.CompressionType = m_CompressionType;
        }
    }

    /// <summary>
    /// Draws one semi-transparent cube per grid cell when <see cref="ShowGizmos"/> is enabled,
    /// colored from <see cref="DebugColors"/> according to the debug sensor's perception buffer.
    /// </summary>
    void OnDrawGizmos()
    {
        if (!m_ShowGizmos)
        {
            return;
        }

        // Nothing to draw until CreateSensors() has built the checker and the debug sensor.
        if (m_BoxOverlapChecker == null || m_DebugSensor == null)
        {
            return;
        }

        m_DebugSensor.ResetPerceptionBuffer();
        m_BoxOverlapChecker.UpdateGizmo();

        var cellColors = m_DebugSensor.PerceptionBuffer;
        var rotation = m_BoxOverlapChecker.GetGridRotation();
        var scale = new Vector3(m_CellScale.x, 1, m_CellScale.z);
        var gizmoYOffset = new Vector3(0, m_GizmoYOffset, 0);
        var oldGizmoMatrix = Gizmos.matrix;

        for (var i = 0; i < cellColors.Length; i++)
        {
            var cellPosition = m_BoxOverlapChecker.GetCellGlobalPosition(i);
            var cubeTransform = Matrix4x4.TRS(cellPosition + gizmoYOffset, rotation, scale);
            Gizmos.matrix = oldGizmoMatrix * cubeTransform;

            // Buffer values are offset by one relative to the color array index;
            // anything without a matching color falls back to white.
            var colorIndex = cellColors[i] - 1;
            var debugRayColor = Color.white;
            // m_DebugColors may be null when no colors were ever assigned to the component.
            if (m_DebugColors != null && colorIndex > -1 && m_DebugColors.Length > colorIndex)
            {
                debugRayColor = m_DebugColors[(int)colorIndex];
            }

            Gizmos.color = new Color(debugRayColor.r, debugRayColor.g, debugRayColor.b, .5f);
            Gizmos.DrawCube(Vector3.zero, Vector3.one);
        }

        // Restore the matrix so later gizmo drawing is unaffected.
        Gizmos.matrix = oldGizmoMatrix;
    }
}
}