-
Notifications
You must be signed in to change notification settings - Fork 4.3k
/
Copy pathModelRunner.cs
235 lines (205 loc) · 8.26 KB
/
ModelRunner.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
using System.Collections.Generic;
using Unity.Barracuda;
using UnityEngine.Profiling;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Policies;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Inference
{
/// <summary>
/// Couples the <see cref="AgentInfo"/> passed to <c>ModelRunner.PutObservations</c> with the
/// list of sensors supplied alongside it, so both travel together in the pending batch.
/// </summary>
internal struct AgentInfoSensorsPair
{
    /// <summary>The agent information supplied with the observation request.</summary>
    public AgentInfo agentInfo;

    /// <summary>The sensors supplied with the observation request.</summary>
    public List<ISensor> sensors;
}
/// <summary>
/// Queues agent observations and runs a Barracuda model over them in a single batch,
/// storing the resulting actions so they can be fetched per agent with <see cref="GetAction"/>.
/// </summary>
internal class ModelRunner
{
    // Observations queued by PutObservations; consumed and cleared by DecideBatch.
    readonly List<AgentInfoSensorsPair> m_Infos = new List<AgentInfoSensorsPair>();

    // Last actions per agent, keyed by the episodeId passed to PutObservations.
    // Passed to TensorApplier.ApplyTensors in DecideBatch (presumably filled there — confirm).
    readonly Dictionary<int, ActionBuffers> m_LastActionsReceived = new Dictionary<int, ActionBuffers>();

    // episodeIds in the same order as m_Infos, so outputs can be matched back to agents.
    readonly List<int> m_OrderedAgentsRequestingDecisions = new List<int>();

    readonly ITensorAllocator m_TensorAllocator;
    TensorGenerator m_TensorGenerator;
    TensorApplier m_TensorApplier;
    readonly NNModel m_Model;
    readonly string m_ModelName;
    readonly InferenceDevice m_InferenceDevice;
    readonly IWorker m_Engine;
    readonly bool m_Verbose = false;
    readonly string[] m_OutputNames;
    readonly IReadOnlyList<TensorProxy> m_InferenceInputs;
    readonly List<TensorProxy> m_InferenceOutputs;
    readonly Dictionary<string, Tensor> m_InputsByName;

    // Shared between the TensorGenerator and TensorApplier created in the constructor.
    readonly Dictionary<int, List<float>> m_Memories = new Dictionary<int, List<float>>();
    SensorShapeValidator m_SensorShapeValidator = new SensorShapeValidator();
    bool m_ObservationsInitialized;

    /// <summary>
    /// Initializes the Brain with the Model that it will use when selecting actions for
    /// the agents
    /// </summary>
    /// <param name="model"> The Barracuda model to load </param>
    /// <param name="actionSpec"> Description of the actions for the Agent.</param>
    /// <param name="inferenceDevice"> Inference execution device. CPU is the fastest
    /// option for most of ML Agents models. </param>
    /// <param name="seed"> The seed that will be used to initialize the RandomNormal
    /// and Multinomial objects used when running inference.</param>
    /// <exception cref="System.ArgumentNullException">Thrown when <paramref name="model"/> is null.
    /// </exception>
    public ModelRunner(
        NNModel model,
        ActionSpec actionSpec,
        InferenceDevice inferenceDevice,
        int seed = 0)
    {
        // Everything below (name, loading, input/output discovery, generator/applier setup)
        // needs a real model; previously a null model crashed with a NullReferenceException.
        if (model == null)
        {
            throw new System.ArgumentNullException(nameof(model));
        }

        m_Model = model;
        m_ModelName = model.name;
        m_InferenceDevice = inferenceDevice;
        m_TensorAllocator = new TensorCachingAllocator();

#if BARRACUDA_VERBOSE
        m_Verbose = true;
#endif
        D.logEnabled = m_Verbose;

        var barracudaModel = ModelLoader.Load(model);
        m_Engine = WorkerFactory.CreateWorker(GetWorkerType(inferenceDevice), barracudaModel, m_Verbose);

        m_InferenceInputs = barracudaModel.GetInputTensors();
        m_OutputNames = barracudaModel.GetOutputNames();
        m_TensorGenerator = new TensorGenerator(
            seed, m_TensorAllocator, m_Memories, barracudaModel);
        m_TensorApplier = new TensorApplier(
            actionSpec, seed, m_TensorAllocator, m_Memories, barracudaModel);
        m_InputsByName = new Dictionary<string, Tensor>();
        m_InferenceOutputs = new List<TensorProxy>();
    }

    /// <summary>
    /// Maps the requested inference device to the Barracuda worker type used to execute the model.
    /// Burst is used for <c>Default</c> and for any unrecognized value.
    /// </summary>
    /// <param name="inferenceDevice">The device requested by the caller.</param>
    /// <returns>The worker type passed to <c>WorkerFactory.CreateWorker</c>.</returns>
    static WorkerFactory.Type GetWorkerType(InferenceDevice inferenceDevice)
    {
        switch (inferenceDevice)
        {
            case InferenceDevice.CPU:
                return WorkerFactory.Type.CSharp;
            case InferenceDevice.GPU:
                return WorkerFactory.Type.ComputePrecompiled;
            case InferenceDevice.Burst:
                return WorkerFactory.Type.CSharpBurst;
            case InferenceDevice.Default: // fallthrough
            default:
                return WorkerFactory.Type.CSharpBurst;
        }
    }

    /// <summary>The inference device this runner was created with.</summary>
    public InferenceDevice InferenceDevice
    {
        get { return m_InferenceDevice; }
    }

    /// <summary>The model this runner was created with.</summary>
    public NNModel Model
    {
        get { return m_Model; }
    }

    /// <summary>
    /// Rebuilds the name-to-tensor map handed to the Barracuda worker from the generated inputs.
    /// </summary>
    void PrepareBarracudaInputs(IReadOnlyList<TensorProxy> infInputs)
    {
        m_InputsByName.Clear();
        for (var i = 0; i < infInputs.Count; i++)
        {
            var inp = infInputs[i];
            m_InputsByName[inp.name] = inp.data;
        }
    }

    /// <summary>
    /// Disposes the Barracuda worker and resets the tensor allocator.
    /// </summary>
    public void Dispose()
    {
        m_Engine?.Dispose();
        m_TensorAllocator?.Reset(false);
    }

    /// <summary>
    /// Reads each named output from the worker and wraps it as a <see cref="TensorProxy"/>
    /// in <c>m_InferenceOutputs</c>, preserving the order of <paramref name="names"/>.
    /// </summary>
    void FetchBarracudaOutputs(string[] names)
    {
        m_InferenceOutputs.Clear();
        foreach (var n in names)
        {
            var output = m_Engine.PeekOutput(n);
            m_InferenceOutputs.Add(TensorUtils.TensorProxyFromBarracuda(output, n));
        }
    }

    /// <summary>
    /// Queues an agent's observations for the next <see cref="DecideBatch"/> call.
    /// </summary>
    /// <param name="info">Agent information; its <c>episodeId</c> is used as the action key.</param>
    /// <param name="sensors">The sensors to generate observation tensors from.</param>
    public void PutObservations(AgentInfo info, List<ISensor> sensors)
    {
#if DEBUG
        m_SensorShapeValidator.ValidateSensors(sensors);
#endif
        m_Infos.Add(new AgentInfoSensorsPair
        {
            agentInfo = info,
            sensors = sensors
        });

        // We add the episodeId to this list to maintain the order in which the decisions were requested
        m_OrderedAgentsRequestingDecisions.Add(info.episodeId);

        if (info.done)
        {
            // A finished agent should take no action, so drop any stored actions for it.
            m_LastActionsReceived.Remove(info.episodeId);
        }
        else if (!m_LastActionsReceived.ContainsKey(info.episodeId))
        {
            m_LastActionsReceived[info.episodeId] = ActionBuffers.Empty;
        }
    }

    /// <summary>
    /// Runs the model once over every queued observation and applies the outputs.
    /// The queue is always cleared afterwards, even if inference throws, so a failed batch
    /// is not re-submitted together with the next one.
    /// </summary>
    public void DecideBatch()
    {
        var currentBatchSize = m_Infos.Count;
        if (currentBatchSize == 0)
        {
            return;
        }

        try
        {
            if (!m_ObservationsInitialized)
            {
                // Just grab the first agent in the collection (any will suffice, really).
                // We check for an empty Collection above, so this will always return successfully.
                var firstInfo = m_Infos[0];
                m_TensorGenerator.InitializeObservations(firstInfo.sensors, m_TensorAllocator);
                m_ObservationsInitialized = true;
            }

            Profiler.BeginSample("ModelRunner.DecideAction");
            Profiler.BeginSample(m_ModelName);

            Profiler.BeginSample("GenerateTensors");
            // Prepare the input tensors to be fed into the engine
            m_TensorGenerator.GenerateTensors(m_InferenceInputs, currentBatchSize, m_Infos);
            Profiler.EndSample();

            Profiler.BeginSample("PrepareBarracudaInputs");
            PrepareBarracudaInputs(m_InferenceInputs);
            Profiler.EndSample();

            // Execute the Model
            Profiler.BeginSample("ExecuteGraph");
            m_Engine.Execute(m_InputsByName);
            Profiler.EndSample();

            Profiler.BeginSample("FetchBarracudaOutputs");
            FetchBarracudaOutputs(m_OutputNames);
            Profiler.EndSample();

            Profiler.BeginSample("ApplyTensors");
            // Update the outputs
            m_TensorApplier.ApplyTensors(m_InferenceOutputs, m_OrderedAgentsRequestingDecisions, m_LastActionsReceived);
            Profiler.EndSample();

            Profiler.EndSample(); // end name
            Profiler.EndSample(); // end ModelRunner.DecideAction
        }
        finally
        {
            m_Infos.Clear();
            m_OrderedAgentsRequestingDecisions.Clear();
        }
    }

    /// <summary>
    /// Returns true when this runner was built for the same model and inference device.
    /// </summary>
    public bool HasModel(NNModel other, InferenceDevice otherInferenceDevice)
    {
        return m_Model == other && m_InferenceDevice == otherInferenceDevice;
    }

    /// <summary>
    /// Returns the stored actions for an agent, or <c>ActionBuffers.Empty</c> if none are stored.
    /// </summary>
    /// <param name="agentId">The key used in <see cref="PutObservations"/> (the agent's episodeId).</param>
    public ActionBuffers GetAction(int agentId)
    {
        ActionBuffers action;
        return m_LastActionsReceived.TryGetValue(agentId, out action) ? action : ActionBuffers.Empty;
    }
}
}