This repository was archived by the owner on May 11, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathReadoutLayer.cs
831 lines (757 loc) · 39.3 KB
/
ReadoutLayer.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
using RCNet.MathTools;
using RCNet.MiscTools;
using RCNet.Neural.Data;
using RCNet.Neural.Data.Filter;
using RCNet.Neural.Network.NonRecurrent;
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Text;
using System.Threading.Tasks;
namespace RCNet.Neural.Network.SM.Readout
{
/// <summary>
/// Implements the readout layer consisting of trained readout units.
/// </summary>
[Serializable]
public class ReadoutLayer
{
//Events
/// <summary>
/// This informative event occurs each time the progress of the build process takes a step forward.
/// </summary>
[field: NonSerialized]
public event RLBuildProgressChangedHandler RLBuildProgressChanged;
/// <summary>
/// The delegate of the RLBuildProgressChanged event handler.
/// </summary>
/// <param name="buildProgress">The current state of the build process.</param>
public delegate void RLBuildProgressChangedHandler(BuildProgress buildProgress);
//Attribute properties
/// <summary>
/// Indicates whether the readout layer is trained.
/// </summary>
public bool Trained { get; private set; }
/// <summary>
/// The readout layer configuration.
/// </summary>
public ReadoutLayerSettings ReadoutLayerCfg { get; }
//Attributes
private FeatureFilterBase[] _predictorFeatureFilterCollection;
private FeatureFilterBase[] _outputFeatureFilterCollection;
private PredictorsMapper _predictorsMapper;
private ReadoutUnit[] _readoutUnitCollection;
private OneTakesAllGroup[] _oneTakesAllGroupCollection;
//Progress tracking attributes
[field: NonSerialized]
private int _buildReadoutUnitIdx;
[field: NonSerialized]
private int _buildOTAGroupIdx;
//Constructor
/// <summary>
/// Creates an uninitialized instance.
/// </summary>
/// <param name="readoutLayerCfg">The readout layer configuration.</param>
public ReadoutLayer(ReadoutLayerSettings readoutLayerCfg)
{
    //Keep a private deep copy so later changes to the caller's configuration object cannot affect this instance
    ReadoutLayerCfg = (ReadoutLayerSettings)readoutLayerCfg.DeepClone();
    //Start in the untrained state
    Reset();
}
//Static properties
/// <summary>
/// Input and output data is normalized to this range.
/// </summary>
private static Interval InternalDataRange => Interval.IntN1P1;
//Properties
/// <summary>
/// Gets the cloned error statistics of the readout units.
/// </summary>
public List<TNRNetCluster.ClusterErrStatistics> ReadoutUnitErrStatCollection
{
    get
    {
        List<TNRNetCluster.ClusterErrStatistics> statistics = new List<TNRNetCluster.ClusterErrStatistics>(_readoutUnitCollection.Length);
        for (int unitIdx = 0; unitIdx < _readoutUnitCollection.Length; unitIdx++)
        {
            //GetErrStat hands out a copy, so callers cannot mutate internal statistics
            statistics.Add(_readoutUnitCollection[unitIdx].GetErrStat());
        }
        return statistics;
    }
}
//Methods
/// <summary>
/// Resets the indexes used for the build progress reporting.
/// </summary>
private void ResetProgressTracking()
{
    _buildReadoutUnitIdx = _buildOTAGroupIdx = 0;
}
/// <summary>
/// Wraps a single readout unit's build progress into the layer-level progress holder and raises the RLBuildProgressChanged event.
/// </summary>
/// <param name="unitBuildProgress">The current build progress of the readout unit being built.</param>
private void OnReadoutUnitBuildProgressChanged(ReadoutUnit.BuildProgress unitBuildProgress)
{
    int numOfUnits = ReadoutLayerCfg.ReadoutUnitsCfg.ReadoutUnitCfgCollection.Count;
    //Total number of "One Takes All" groups (zero when no groups are configured)
    int maxNumOfGroups = _oneTakesAllGroupCollection == null ? 0 : _oneTakesAllGroupCollection.Length;
    //Prepare readout layer version of the progress (the group part is empty here)
    BuildProgress buildProgress = new BuildProgress(Math.Min(_buildReadoutUnitIdx + 1, numOfUnits),
                                                    numOfUnits,
                                                    unitBuildProgress,
                                                    0,
                                                    maxNumOfGroups,
                                                    null
                                                    );
    //Notify subscribers, if any
    RLBuildProgressChanged?.Invoke(buildProgress);
}
/// <summary>
/// Wraps a "One Takes All" group's build progress into the layer-level progress holder and raises the RLBuildProgressChanged event.
/// </summary>
/// <param name="groupBuildProgress">The current build progress of the group being built.</param>
private void OnOTAGBuildProgressChanged(OneTakesAllGroup.BuildProgress groupBuildProgress)
{
    int numOfUnits = ReadoutLayerCfg.ReadoutUnitsCfg.ReadoutUnitCfgCollection.Count;
    int numOfGroups = _oneTakesAllGroupCollection.Length;
    //Prepare readout layer version of the progress (the unit part is empty here)
    BuildProgress buildProgress = new BuildProgress(Math.Min(_buildReadoutUnitIdx + 1, numOfUnits),
                                                    numOfUnits,
                                                    null,
                                                    Math.Min(_buildOTAGroupIdx + 1, numOfGroups),
                                                    numOfGroups,
                                                    groupBuildProgress
                                                    );
    //Notify subscribers, if any
    RLBuildProgressChanged?.Invoke(buildProgress);
}
/// <summary>
/// Resets the readout layer to its initial untrained state.
/// </summary>
public void Reset()
{
    //Forget the filters and the mapper created during the last build
    _predictorFeatureFilterCollection = null;
    _outputFeatureFilterCollection = null;
    _predictorsMapper = null;
    //Recreate the untrained readout units from the configuration
    int numOfUnits = ReadoutLayerCfg.ReadoutUnitsCfg.ReadoutUnitCfgCollection.Count;
    _readoutUnitCollection = new ReadoutUnit[numOfUnits];
    for (int unitIdx = 0; unitIdx < numOfUnits; unitIdx++)
    {
        _readoutUnitCollection[unitIdx] = new ReadoutUnit(unitIdx,
                                                          ReadoutLayerCfg.ReadoutUnitsCfg.ReadoutUnitCfgCollection[unitIdx],
                                                          ReadoutLayerCfg.TaskDefaultsCfg
                                                          );
    }
    //Recreate the "One Takes All" groups, when configured
    _oneTakesAllGroupCollection = null;
    if (ReadoutLayerCfg.OneTakesAllGroupsCfg != null)
    {
        int numOfGroups = ReadoutLayerCfg.OneTakesAllGroupsCfg.OneTakesAllGroupCfgCollection.Count;
        _oneTakesAllGroupCollection = new OneTakesAllGroup[numOfGroups];
        for (int groupIdx = 0; groupIdx < numOfGroups; groupIdx++)
        {
            OneTakesAllGroupSettings groupCfg = ReadoutLayerCfg.OneTakesAllGroupsCfg.OneTakesAllGroupCfgCollection[groupIdx];
            _oneTakesAllGroupCollection[groupIdx] = new OneTakesAllGroup(groupIdx, groupCfg, ReadoutLayerCfg.GetOneTakesAllGroupMemberRUnitIndexes(groupCfg.Name));
        }
    }
    Trained = false;
    ResetProgressTracking();
}
/// <summary>
/// Builds trained readout layer.
/// </summary>
/// <param name="dataBundle">The data to be used for training.</param>
/// <param name="predictorsMapper">The mapper of specific predictors to readout units (optional).</param>
/// <param name="controller">The build process controller (optional).</param>
/// <param name="randomizerSeek">Specifies the random number generator initial seek (optional). A value greater than or equal to 0 will always ensure the same initialization.</param>
/// <returns>The results of training.</returns>
/// <exception cref="InvalidOperationException">When the layer is already trained or the data dimensions do not match the configuration.</exception>
public RegressionOverview Build(VectorBundle dataBundle,
PredictorsMapper predictorsMapper = null,
TNRNetBuilder.BuildControllerDelegate controller = null,
int randomizerSeek = 0
)
{
//A layer instance can be built only once; Reset() must be called before a rebuild
if (Trained)
{
throw new InvalidOperationException("Readout layer is already built.");
}
//Basic checks
int numOfPredictors = dataBundle.InputVectorCollection[0].Length;
int numOfOutputs = dataBundle.OutputVectorCollection[0].Length;
if (numOfPredictors == 0)
{
throw new InvalidOperationException($"Number of predictors must be greater than 0.");
}
//Each output vector element feeds exactly one readout unit
if (numOfOutputs != ReadoutLayerCfg.ReadoutUnitsCfg.ReadoutUnitCfgCollection.Count)
{
throw new InvalidOperationException($"Incorrect length of output vectors.");
}
//Predictors mapper (specified or default)
_predictorsMapper = predictorsMapper ?? new PredictorsMapper(numOfPredictors);
//Allocation and preparation of feature filters
//Predictors: one filter per predictor, adjusted over all samples (each filter sees its own column, so parallelism is safe)
_predictorFeatureFilterCollection = new FeatureFilterBase[numOfPredictors];
Parallel.For(0, _predictorFeatureFilterCollection.Length, nrmIdx =>
{
_predictorFeatureFilterCollection[nrmIdx] = new RealFeatureFilter(InternalDataRange, true, true);
for (int pairIdx = 0; pairIdx < dataBundle.InputVectorCollection.Count; pairIdx++)
{
//Adjust filter
_predictorFeatureFilterCollection[nrmIdx].Update(dataBundle.InputVectorCollection[pairIdx][nrmIdx]);
}
});
//Output values: filter type depends on each unit's task configuration
_outputFeatureFilterCollection = new FeatureFilterBase[numOfOutputs];
Parallel.For(0, _outputFeatureFilterCollection.Length, nrmIdx =>
{
_outputFeatureFilterCollection[nrmIdx] = FeatureFilterFactory.Create(InternalDataRange, ReadoutLayerCfg.ReadoutUnitsCfg.ReadoutUnitCfgCollection[nrmIdx].TaskCfg.FeatureFilterCfg);
for (int pairIdx = 0; pairIdx < dataBundle.OutputVectorCollection.Count; pairIdx++)
{
//Adjust output normalizer
_outputFeatureFilterCollection[nrmIdx].Update(dataBundle.OutputVectorCollection[pairIdx][nrmIdx]);
}
});
//Data normalization
//Allocation
double[][] normalizedPredictorsCollection = new double[dataBundle.InputVectorCollection.Count][];
double[][] normalizedIdealOutputsCollection = new double[dataBundle.OutputVectorCollection.Count][];
//Normalization (per sample, parallel over samples)
Parallel.For(0, dataBundle.InputVectorCollection.Count, pairIdx =>
{
//Predictors
double[] predictors = new double[numOfPredictors];
for (int i = 0; i < numOfPredictors; i++)
{
if (_predictorsMapper.PredictorGeneralSwitchCollection[i])
{
predictors[i] = _predictorFeatureFilterCollection[i].ApplyFilter(dataBundle.InputVectorCollection[pairIdx][i]);
}
else
{
//Globally switched-off predictor is marked unusable
predictors[i] = double.NaN;
}
}
normalizedPredictorsCollection[pairIdx] = predictors;
//Outputs
double[] outputs = new double[numOfOutputs];
for (int i = 0; i < numOfOutputs; i++)
{
outputs[i] = _outputFeatureFilterCollection[i].ApplyFilter(dataBundle.OutputVectorCollection[pairIdx][i]);
}
normalizedIdealOutputsCollection[pairIdx] = outputs;
});
//Random object initialization (non-negative seek gives reproducible builds)
Random rand = (randomizerSeek < 0 ? new Random() : new Random(randomizerSeek));
//Create shuffled copy of the data
VectorBundle shuffledData = new VectorBundle(normalizedPredictorsCollection, normalizedIdealOutputsCollection);
shuffledData.Shuffle(rand);
//"One Takes All" groups input data space initialization
//(per-sample arrays of unit results, filled while units are built)
List<CompositeResult[]> allReadoutUnitResults = new List<CompositeResult[]>(shuffledData.InputVectorCollection.Count);
if (_oneTakesAllGroupCollection != null)
{
for (int i = 0; i < shuffledData.InputVectorCollection.Count; i++)
{
allReadoutUnitResults.Add(new CompositeResult[ReadoutLayerCfg.ReadoutUnitsCfg.ReadoutUnitCfgCollection.Count]);
}
}
ResetProgressTracking();
//Building of readout units, one at a time (the loop variable doubles as the progress index)
for (_buildReadoutUnitIdx = 0; _buildReadoutUnitIdx < ReadoutLayerCfg.ReadoutUnitsCfg.ReadoutUnitCfgCollection.Count; _buildReadoutUnitIdx++)
{
List<double[]> idealValueCollection = new List<double[]>(shuffledData.OutputVectorCollection.Count);
//Transformation of ideal vectors to a single value vectors (each unit learns one output element)
foreach (double[] idealVector in shuffledData.OutputVectorCollection)
{
double[] value = new double[1];
value[0] = idealVector[_buildReadoutUnitIdx];
idealValueCollection.Add(value);
}
List<double[]> readoutUnitInputVectorCollection = _predictorsMapper.CreateVectorCollection(ReadoutLayerCfg.ReadoutUnitsCfg.ReadoutUnitCfgCollection[_buildReadoutUnitIdx].Name, shuffledData.InputVectorCollection);
VectorBundle readoutUnitDataBundle = new VectorBundle(readoutUnitInputVectorCollection, idealValueCollection);
//NOTE(review): the handler is subscribed but never unsubscribed; harmless for a single Build (guarded by Trained), but a Reset+rebuild would subscribe again — confirm ReadoutUnit lifetime makes this safe
_readoutUnitCollection[_buildReadoutUnitIdx].ReadoutUnitBuildProgressChanged += OnReadoutUnitBuildProgressChanged;
_readoutUnitCollection[_buildReadoutUnitIdx].Build(readoutUnitDataBundle,
_outputFeatureFilterCollection[_buildReadoutUnitIdx],
rand,
controller
);
//Add unit's all computed results into the input data for "One Takes All" groups
if (_oneTakesAllGroupCollection != null)
{
for (int sampleIdx = 0; sampleIdx < readoutUnitDataBundle.InputVectorCollection.Count; sampleIdx++)
{
allReadoutUnitResults[sampleIdx][_buildReadoutUnitIdx] = _readoutUnitCollection[_buildReadoutUnitIdx].Compute(readoutUnitDataBundle.InputVectorCollection[sampleIdx]);
}
}
}//unitIdx
//One Takes All groups build
if (_oneTakesAllGroupCollection != null)
{
foreach (OneTakesAllGroup group in _oneTakesAllGroupCollection)
{
//Only the group having inner probabilistic cluster has to be built
if (group.DecisionMethod == OneTakesAllGroup.OneTakesAllDecisionMethod.ClusterChain)
{
//Collect the binary output filters of the group's member units
BinFeatureFilter[] groupFilters = new BinFeatureFilter[group.NumOfMemberClasses];
for (int i = 0; i < group.NumOfMemberClasses; i++)
{
groupFilters[i] = (BinFeatureFilter)_outputFeatureFilterCollection[group.MemberReadoutUnitIndexCollection[i]];
}
//NOTE(review): pre-incrementing here while the progress handler adds +1 again means the first group is reported as 2/N (units use the 0-based loop index instead) — looks like an off-by-one, confirm intended
++_buildOTAGroupIdx;
group.OTAGBuildProgressChanged += OnOTAGBuildProgressChanged;
group.Build(allReadoutUnitResults, shuffledData.OutputVectorCollection, groupFilters, rand, controller);
}
}
}
//Readout layer is trained and ready
Trained = true;
return new RegressionOverview(ReadoutUnitErrStatCollection);
}
/// <summary>
/// Creates the text report of predicted values.
/// </summary>
/// <param name="predictedValues">The computed vector.</param>
/// <param name="margin">Specifies the left margin of the text.</param>
/// <returns>The built text report.</returns>
public string GetForecastReport(double[] predictedValues, int margin)
{
    //A zero margin yields an empty prefix
    string leftMargin = new string(' ', margin);
    StringBuilder report = new StringBuilder();
    //One line per configured output field
    int numOfOutputs = ReadoutLayerCfg.ReadoutUnitsCfg.ReadoutUnitCfgCollection.Count;
    for (int outputIdx = 0; outputIdx < numOfOutputs; outputIdx++)
    {
        report.Append(leftMargin + $"Output field [{ReadoutLayerCfg.ReadoutUnitsCfg.ReadoutUnitCfgCollection[outputIdx].Name}]: {predictedValues[outputIdx].ToString(CultureInfo.InvariantCulture)}" + Environment.NewLine);
    }
    return report.ToString();
}
/// <summary>
/// Normalizes the vector of predictors.
/// </summary>
/// <param name="predictors">The predictors vector.</param>
/// <returns>The normalized predictors vector.</returns>
/// <exception cref="InvalidOperationException">When the vector length does not match the prepared filters.</exception>
private double[] NormalizePredictors(double[] predictors)
{
    //The vector must match the filters prepared during the build
    if (predictors.Length != _predictorFeatureFilterCollection.Length)
    {
        throw new InvalidOperationException($"Incorrect length of predictors vector.");
    }
    double[] normalized = new double[predictors.Length];
    for (int predictorIdx = 0; predictorIdx < normalized.Length; predictorIdx++)
    {
        normalized[predictorIdx] = _predictorFeatureFilterCollection[predictorIdx].ApplyFilter(predictors[predictorIdx]);
    }
    return normalized;
}
/// <summary>
/// Naturalizes the output values (reverses the output normalization).
/// </summary>
/// <param name="outputs">The normalized output values vector.</param>
/// <returns>The naturalized output values vector.</returns>
private double[] NaturalizeOutputs(double[] outputs)
{
    double[] naturalized = new double[outputs.Length];
    for (int outputIdx = 0; outputIdx < naturalized.Length; outputIdx++)
    {
        naturalized[outputIdx] = _outputFeatureFilterCollection[outputIdx].ApplyReverse(outputs[outputIdx]);
    }
    return naturalized;
}
/// <summary>
/// Computes all readout units on the given normalized predictors.
/// </summary>
/// <param name="predictors">The normalized predictors vector.</param>
/// <param name="outputVector">The normalized single output value of each unit.</param>
/// <returns>The composite results of the units.</returns>
private CompositeResult[] ComputeReadoutUnits(double[] predictors, out double[] outputVector)
{
    int numOfUnits = _readoutUnitCollection.Length;
    CompositeResult[] results = new CompositeResult[numOfUnits];
    outputVector = new double[numOfUnits];
    for (int unitIdx = 0; unitIdx < numOfUnits; unitIdx++)
    {
        //Pick only the predictors mapped to this particular unit
        double[] unitInput = _predictorsMapper.CreateVector(ReadoutLayerCfg.ReadoutUnitsCfg.ReadoutUnitCfgCollection[unitIdx].Name, predictors);
        results[unitIdx] = _readoutUnitCollection[unitIdx].Compute(unitInput);
        //Each unit produces exactly one output value
        outputVector[unitIdx] = results[unitIdx].Result[0];
    }
    return results;
}
/// <summary>
/// Computes the readout layer.
/// </summary>
/// <param name="predictors">The predictors.</param>
/// <param name="readoutData">The detailed computed data.</param>
/// <returns>An output vector of the computed and naturalized values.</returns>
/// <exception cref="InvalidOperationException">When the layer has not been built yet.</exception>
public double[] Compute(double[] predictors, out ReadoutData readoutData)
{
//Check readyness
if (!Trained)
{
throw new InvalidOperationException($"Readout layer is not trained. Build function has to be called before Compute function can be used.");
}
//Normalize predictors
double[] nrmPredictors = NormalizePredictors(predictors);
//Compute all readout units
CompositeResult[] unitsResults = ComputeReadoutUnits(nrmPredictors, out double[] nrmOutputVector);
//Build readout units results (Final value initially equals the raw value; it may be overridden by the groups below)
ReadoutData.ReadoutUnitData[] readoutUnitsData = new ReadoutData.ReadoutUnitData[unitsResults.Length];
for (int unitIdx = 0; unitIdx < readoutUnitsData.Length; unitIdx++)
{
readoutUnitsData[unitIdx] = new ReadoutData.ReadoutUnitData()
{
Name = _readoutUnitCollection[unitIdx].Name,
Index = _readoutUnitCollection[unitIdx].Index,
Task = _readoutUnitCollection[unitIdx].Task,
CompResult = unitsResults[unitIdx],
RawNrmDataValue = nrmOutputVector[unitIdx],
RawNatDataValue = _outputFeatureFilterCollection[unitIdx].ApplyReverse(nrmOutputVector[unitIdx]),
FinalNatDataValue = _outputFeatureFilterCollection[unitIdx].ApplyReverse(nrmOutputVector[unitIdx])
};
}
//Compute all "One Takes All" groups
ReadoutData.OneTakesAllGroupData[] groupsData = null;
if (_oneTakesAllGroupCollection != null)
{
groupsData = new ReadoutData.OneTakesAllGroupData[_oneTakesAllGroupCollection.Length];
for (int groupIdx = 0; groupIdx < _oneTakesAllGroupCollection.Length; groupIdx++)
{
//The group decides its inner winner; translate it to the layer-wide readout unit index
int groupInnerWinnerIdx = _oneTakesAllGroupCollection[groupIdx].Compute(unitsResults, out CompositeResult groupResult, out double[] groupOutputVector);
int layerWinnerIdx = _oneTakesAllGroupCollection[groupIdx].MemberReadoutUnitIndexCollection[groupInnerWinnerIdx];
groupsData[groupIdx] = new ReadoutData.OneTakesAllGroupData()
{
GroupName = _oneTakesAllGroupCollection[groupIdx].Name,
WinningReadoutUnitName = _readoutUnitCollection[layerWinnerIdx].Name,
WinningReadoutUnitIndex = layerWinnerIdx,
MemberWinningGroupIndex = groupInnerWinnerIdx,
MemberReadoutUnitIndexes = _oneTakesAllGroupCollection[groupIdx].MemberReadoutUnitIndexCollection.ToArray(),
CompResult = groupResult,
MemberProbabilities = groupOutputVector
};
//Update nrmOutputVector: the winner member is forced to the range maximum, all other members to the minimum
for (int i = 0; i < _oneTakesAllGroupCollection[groupIdx].MemberReadoutUnitIndexCollection.Count; i++)
{
if (i == groupInnerWinnerIdx)
{
nrmOutputVector[_oneTakesAllGroupCollection[groupIdx].MemberReadoutUnitIndexCollection[i]] = InternalDataRange.Max;
}
else
{
nrmOutputVector[_oneTakesAllGroupCollection[groupIdx].MemberReadoutUnitIndexCollection[i]] = InternalDataRange.Min;
}
//Alternative kept by the author: propagate the group probabilities instead of the hard winner-takes-all values
//nrmOutputVector[_oneTakesAllGroupCollection[groupIdx].MemberReadoutUnitIndexCollection[i]] = groupOutputVector[i];
}
}
}
//Output data finalization: naturalize the (possibly group-overridden) normalized values
double[] natOuputVector = NaturalizeOutputs(nrmOutputVector);
for (int unitIdx = 0; unitIdx < readoutUnitsData.Length; unitIdx++)
{
readoutUnitsData[unitIdx].FinalNatDataValue = natOuputVector[unitIdx];
}
readoutData = new ReadoutData(nrmOutputVector, natOuputVector, readoutUnitsData, groupsData);
return natOuputVector;
}
//Inner classes
/// <summary>
/// Implements the holder of detailed computed data.
/// </summary>
[Serializable]
public class ReadoutData
{
    /// <summary>
    /// The vector of normalized output values.
    /// </summary>
    public double[] NrmDataVector { get; }
    /// <summary>
    /// The vector of naturalized output values.
    /// </summary>
    public double[] NatDataVector { get; }
    /// <summary>
    /// The collection of readout units data.
    /// </summary>
    public List<ReadoutUnitData> ReadoutUnitDataCollection { get; }
    /// <summary>
    /// The collection of one-takes-all groups data.
    /// </summary>
    public List<OneTakesAllGroupData> OneTakesAllGroupDataCollection { get; }
    //Constructor
    /// <summary>
    /// Creates an initialized instance.
    /// </summary>
    /// <param name="nrmDataVector">The vector of normalized output values.</param>
    /// <param name="natDataVector">The vector of naturalized output values.</param>
    /// <param name="unitsResults">The collection of readout units data.</param>
    /// <param name="oneTakesAllGroupsResults">The collection of one-takes-all groups data (can be null).</param>
    public ReadoutData(double[] nrmDataVector,
                       double[] natDataVector,
                       ReadoutUnitData[] unitsResults,
                       OneTakesAllGroupData[] oneTakesAllGroupsResults
                       )
    {
        NrmDataVector = nrmDataVector;
        NatDataVector = natDataVector;
        //Readout units
        ReadoutUnitDataCollection = new List<ReadoutUnitData>(unitsResults);
        //One Takes All groups (an empty collection when no groups were computed)
        OneTakesAllGroupDataCollection = oneTakesAllGroupsResults == null
            ? new List<OneTakesAllGroupData>()
            : new List<OneTakesAllGroupData>(oneTakesAllGroupsResults);
    }
    /// <summary>
    /// Gets the one-takes-all group data by the group name.
    /// </summary>
    /// <param name="groupName">The name of the one-takes-all group.</param>
    /// <returns>The group data or null if not found.</returns>
    public OneTakesAllGroupData GetOneTakesAllGroupData(string groupName)
    {
        //Linear search; the number of groups is expected to be small
        for (int i = 0; i < OneTakesAllGroupDataCollection.Count; i++)
        {
            if (OneTakesAllGroupDataCollection[i].GroupName == groupName)
            {
                return OneTakesAllGroupDataCollection[i];
            }
        }
        return null;
    }
    /// <summary>
    /// Gets the readout unit data by the readout unit name.
    /// </summary>
    /// <param name="readoutUnitName">The name of the readout unit.</param>
    /// <returns>The readout unit data or null if not found.</returns>
    public ReadoutUnitData GetReadoutUnitData(string readoutUnitName)
    {
        //Linear search; the number of units is expected to be small
        for (int i = 0; i < ReadoutUnitDataCollection.Count; i++)
        {
            if (ReadoutUnitDataCollection[i].Name == readoutUnitName)
            {
                return ReadoutUnitDataCollection[i];
            }
        }
        return null;
    }
    //Inner classes
    /// <summary>
    /// Implements the holder of the readout unit results.
    /// </summary>
    [Serializable]
    public class ReadoutUnitData
    {
        /// <summary>
        /// The name of the readout unit.
        /// </summary>
        public string Name { get; set; }
        /// <summary>
        /// The zero-based index of the readout unit.
        /// </summary>
        public int Index { get; set; }
        /// <inheritdoc cref="ReadoutUnit.TaskType"/>
        public ReadoutUnit.TaskType Task { get; set; }
        /// <summary>
        /// The composite result.
        /// </summary>
        public CompositeResult CompResult { get; set; }
        /// <summary>
        /// The normalized output data value computed by the unit.
        /// </summary>
        public double RawNrmDataValue { get; set; }
        /// <summary>
        /// The naturalized output data value computed by the unit.
        /// </summary>
        public double RawNatDataValue { get; set; }
        /// <summary>
        /// The naturalized final output data value.
        /// </summary>
        public double FinalNatDataValue { get; set; }
    }//ReadoutUnitData
    /// <summary>
    /// Implements the holder of the "One Takes All" group result.
    /// </summary>
    [Serializable]
    public class OneTakesAllGroupData
    {
        /// <summary>
        /// The name of the "One Takes All" group.
        /// </summary>
        public string GroupName { get; set; }
        /// <summary>
        /// The name of the winning readout unit (class).
        /// </summary>
        public string WinningReadoutUnitName { get; set; }
        /// <summary>
        /// The zero-based index of the winning readout unit.
        /// </summary>
        public int WinningReadoutUnitIndex { get; set; }
        /// <summary>
        /// The zero-based index of the winning member within the group.
        /// </summary>
        public int MemberWinningGroupIndex { get; set; }
        /// <summary>
        /// The indexes of readout units belonging to the group.
        /// </summary>
        public int[] MemberReadoutUnitIndexes { get; set; }
        /// <summary>
        /// The composite result (always in 0...1 range).
        /// </summary>
        public CompositeResult CompResult { get; set; }
        /// <summary>
        /// The probabilities of the group members (probabilities are in common range).
        /// </summary>
        public double[] MemberProbabilities { get; set; }
    }//OneTakesAllGroupData
}//ReadoutData
/// <summary>
/// Implements the holder of the readout layer training (regression) results.
/// </summary>
[Serializable]
public class RegressionOverview
{
    /// <summary>
    /// The collection of error statistics of the readout units.
    /// </summary>
    public List<TNRNetCluster.ClusterErrStatistics> ReadoutUnitErrStatCollection { get; }
    /// <summary>
    /// Creates an initialized instance.
    /// </summary>
    /// <param name="readoutUnitErrStatCollection">The collection of error statistics of the readout units.</param>
    public RegressionOverview(List<TNRNetCluster.ClusterErrStatistics> readoutUnitErrStatCollection)
    {
        ReadoutUnitErrStatCollection = readoutUnitErrStatCollection;
    }
    //Methods
    /// <summary>
    /// Builds the per-unit part of the training report (classification or forecast form, depending on the statistics content).
    /// </summary>
    /// <param name="leftMargin">The left margin text prefix.</param>
    /// <param name="ces">The error statistics of one readout unit.</param>
    /// <returns>The built text.</returns>
    private string BuildErrStatReport(string leftMargin, TNRNetCluster.ClusterErrStatistics ces)
    {
        StringBuilder sb = new StringBuilder();
        //Local helper keeping the margin and line-ending handling in one place
        void Line(string text) => sb.Append(leftMargin + text + Environment.NewLine);
        if (ces.BinaryErrStat != null)
        {
            //Classification task report
            var negStat = ces.BinaryErrStat.BinValErrStat[0];
            Line("    Classification of negative samples");
            Line($"       Number of samples: {negStat.NumOfSamples}");
            Line($"       Number of errors: {negStat.Sum.ToString(CultureInfo.InvariantCulture)}");
            Line($"       Error rate: {negStat.ArithAvg.ToString(CultureInfo.InvariantCulture)}");
            Line($"       Accuracy: {(1 - negStat.ArithAvg).ToString(CultureInfo.InvariantCulture)}");
            var posStat = ces.BinaryErrStat.BinValErrStat[1];
            Line("    Classification of positive samples");
            Line($"       Number of samples: {posStat.NumOfSamples}");
            Line($"       Number of errors: {posStat.Sum.ToString(CultureInfo.InvariantCulture)}");
            Line($"       Error rate: {posStat.ArithAvg.ToString(CultureInfo.InvariantCulture)}");
            Line($"       Accuracy: {(1 - posStat.ArithAvg).ToString(CultureInfo.InvariantCulture)}");
            var totalStat = ces.BinaryErrStat.TotalErrStat;
            Line("    Overall classification results");
            Line($"       Number of samples: {totalStat.NumOfSamples}");
            Line($"       Number of errors: {totalStat.Sum.ToString(CultureInfo.InvariantCulture)}");
            Line($"       Error rate: {totalStat.ArithAvg.ToString(CultureInfo.InvariantCulture)}");
            Line($"       Accuracy: {(1 - totalStat.ArithAvg).ToString(CultureInfo.InvariantCulture)}");
        }
        else
        {
            //Forecast task report
            Line($"    Number of samples: {ces.NatPrecissionErrStat.NumOfSamples}");
            Line($"    Biggest error: {ces.NatPrecissionErrStat.Max.ToString(CultureInfo.InvariantCulture)}");
            Line($"    Smallest error: {ces.NatPrecissionErrStat.Min.ToString(CultureInfo.InvariantCulture)}");
            Line($"    Average error: {ces.NatPrecissionErrStat.ArithAvg.ToString(CultureInfo.InvariantCulture)}");
        }
        //Trailing blank line separating the units
        sb.Append(Environment.NewLine);
        return sb.ToString();
    }
    /// <summary>
    /// Returns the text report of the readout layer training (regression).
    /// </summary>
    /// <param name="margin">Specifies the left margin of the text.</param>
    /// <returns>The built text report.</returns>
    public string GetTrainingResultsReport(int margin)
    {
        //A zero margin yields an empty prefix
        string leftMargin = new string(' ', margin);
        StringBuilder report = new StringBuilder();
        //Training results of readout units
        foreach (TNRNetCluster.ClusterErrStatistics ces in ReadoutUnitErrStatCollection)
        {
            report.Append(leftMargin + $"Output field [{ces.ClusterName}]" + Environment.NewLine);
            report.Append(BuildErrStatReport(leftMargin, ces));
        }
        return report.ToString();
    }
}//RegressionOverview
/// <summary>
/// Implements the holder of the readout layer build progress information.
/// </summary>
public class BuildProgress : IBuildProgress
{
    //Attribute properties
    /// <summary>
    /// Information about the readout units processing progress.
    /// </summary>
    public ProgressTracker UnitsTracker { get; }
    /// <summary>
    /// Information about the current readout unit build progress (null when a group is being built).
    /// </summary>
    public ReadoutUnit.BuildProgress UnitBuildProgress { get; }
    /// <summary>
    /// Information about the One Takes All groups processing progress.
    /// </summary>
    public ProgressTracker GroupsTracker { get; }
    /// <summary>
    /// Information about the current One Takes All group build progress (null when a unit is being built).
    /// </summary>
    public OneTakesAllGroup.BuildProgress GroupBuildProgress { get; }
    //Constructor
    /// <summary>
    /// Creates an initialized instance.
    /// </summary>
    /// <param name="unitNum">The current readout unit number.</param>
    /// <param name="maxNumOfUnits">The maximum number of readout units.</param>
    /// <param name="unitBuildProgress">The holder of the readout unit build progress information.</param>
    /// <param name="groupNum">The current One Takes All group number.</param>
    /// <param name="maxNumOfGroups">The maximum number of One Takes All groups.</param>
    /// <param name="groupBuildProgress">The holder of the One Takes All group build progress information.</param>
    public BuildProgress(int unitNum,
                         int maxNumOfUnits,
                         ReadoutUnit.BuildProgress unitBuildProgress,
                         int groupNum,
                         int maxNumOfGroups,
                         OneTakesAllGroup.BuildProgress groupBuildProgress
                         )
    {
        UnitsTracker = new ProgressTracker((uint)maxNumOfUnits, (uint)unitNum);
        UnitBuildProgress = unitBuildProgress;
        GroupsTracker = new ProgressTracker((uint)maxNumOfGroups, (uint)groupNum);
        GroupBuildProgress = groupBuildProgress;
    }
    //Properties
    /// <summary>
    /// Indicates the readout units build progress is present.
    /// </summary>
    public bool ContainsReadoutUnitsProgressInfo => UnitBuildProgress != null;
    /// <summary>
    /// Indicates the One Takes All groups build progress is present.
    /// </summary>
    public bool ContainsOneTakesAllGroupsProgressInfo => GroupBuildProgress != null;
    /// <inheritdoc/>
    public bool NewEndNetwork => GetActiveBuildProgress().NewEndNetwork;
    /// <inheritdoc/>
    public bool ShouldBeReported => GetActiveBuildProgress().ShouldBeReported;
    /// <inheritdoc/>
    public int EndNetworkEpochNum => GetActiveBuildProgress().EndNetworkEpochNum;
    //Methods
    /// <summary>
    /// Picks whichever inner progress holder is present (the unit one takes precedence).
    /// </summary>
    private IBuildProgress GetActiveBuildProgress()
    {
        return UnitBuildProgress != null ? (IBuildProgress)UnitBuildProgress : GroupBuildProgress;
    }
    /// <inheritdoc/>
    public string GetInfoText(int margin = 0, bool includeName = true)
    {
        return GetActiveBuildProgress().GetInfoText(margin, includeName);
    }
}//BuildProgress
}//ReadoutLayer
}//Namespace