@@ -104,32 +104,131 @@ public interface Network extends Serializable {
     */
    void eval();

+    /**
+     * Performs network inference by taking an array of input tensors and computing an array of output tensors.
+     * Usually a network takes a single input and produces a single output; for this purpose the default
+     * implementation delegates to {@link #forward11(Tensor)}.
+     * <p>
+     * If this default scenario does not fit the purpose of the network, this method needs to be overridden.
+     * <p>
+     * The forward method contains operations on tensors. All operations on tensors are tracked by the
+     * computational graph, since each tensor operation leaves a trace which consists of backpropagation
+     * functions. When the {@link Autograd#backward(Tensor)} method is called on a tensor which has
+     * a scalar gradient, the computational graph starts to backpropagate gradients.
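+     * <p>
+     * A minimal usage sketch (the {@code net} instance and the loss computation are hypothetical,
+     * and {@code Autograd.backward} is assumed to be callable statically):
+     * <pre>{@code
+     * Tensor x = ...;                     // single input tensor
+     * Tensor[] ys = net.forward(x);       // delegates to forward11 for a single input
+     * Tensor loss = computeLoss(ys[0]);   // hypothetical scalar loss
+     * Autograd.backward(loss);            // backpropagates gradients through the graph
+     * }</pre>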
+     *
+     * @param xs input tensors
+     * @return computed output tensors
+     */
    default Tensor[] forward(Tensor... xs) {
        if (xs.length == 1) {
            return new Tensor[] {forward11(xs[0])};
        }
        throw new NotImplementedException();
    }

+    /**
+     * The default case of the {@link #forward(Tensor...)} method, which receives a single input tensor
+     * and outputs a single tensor.
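+     * <p>
+     * An implementation sketch (the {@code linear} and {@code activation} operations are
+     * hypothetical placeholders for the network's actual tensor operations):
+     * <pre>{@code
+     * public Tensor forward11(Tensor x) {
+     *     return activation(linear(x));
+     * }
+     * }</pre>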
+     *
+     * @param x input tensor
+     * @return computed output tensor
+     */
    default Tensor forward11(Tensor x) {
        throw new NotImplementedException();
    }

+    /**
+     * Improved forward method which trades memory for parallel batched execution of the forward pass.
+     * <p>
+     * The execution consists of splitting the input tensors into batches and executing those batches
+     * in parallel in the forward step. The tradeoff is that the whole computational graph will reside
+     * in memory, so this method should be used only if the dataset is small enough for the
+     * available memory.
+     * <p>
+     * The result consists of a list of batches. Each batch contains the input data and also the
+     * network output tensors computed for that specific batch.
+     * <p>
+     * The batch size is given as a parameter. Before splitting into batches the data from the dataset
+     * is shuffled, and all the batches are used for execution.
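+     * <p>
+     * A minimal usage sketch (the {@code net} and {@code features} names are hypothetical):
+     * <pre>{@code
+     * List<Batch> batches = net.batchForward(32, features);
+     * for (Batch batch : batches) {
+     *     // each batch holds its input data and the computed output tensors
+     * }
+     * }</pre>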
+     *
+     * @param batchSize the number of instances in each batch; the last batch might contain fewer instances
+     * @param inputs input tensors
+     * @return list of computed batches
+     */
    default List<Batch> batchForward(int batchSize, Tensor... inputs) {
        return batchForward(batchSize, true, false, inputs);
    }

+    /**
+     * Fully customizable version of {@link #batchForward(int, Tensor...)}.
+     * <p>
+     * Improved forward method which trades memory for parallel batched execution of the forward pass.
+     * <p>
+     * The execution consists of splitting the input tensors into batches and executing those batches
+     * in parallel in the forward step. The tradeoff is that the whole computational graph will reside
+     * in memory, so this method should be used only if the dataset is small enough for the
+     * available memory.
+     * <p>
+     * The result consists of a list of batches. Each batch contains the input data and also the
+     * network output tensors computed for that specific batch.
+     * <p>
+     * The batch size is given as a parameter. Before splitting into batches the instances are shuffled
+     * if the {@code shuffle} parameter is {@code true}. In some cases the last batch might contain
+     * fewer instances; if this is not desirable, one can set {@code skipLast} to {@code true} to skip
+     * the last batch.
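+     * <p>
+     * A usage sketch for deterministic evaluation over equally sized batches (the {@code net}
+     * and {@code features} names are hypothetical):
+     * <pre>{@code
+     * // no shuffling, and drop the last smaller batch so all batches have the same size
+     * List<Batch> batches = net.batchForward(32, false, true, features);
+     * }</pre>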
+     *
+     * @param batchSize the batch size
+     * @param shuffle if {@code true}, the data is shuffled before splitting into batches
+     * @param skipLast if {@code true}, the last batch, which might be smaller, is skipped for execution
+     * @param inputs input tensors
+     * @return list of computed batches
+     */
    List<Batch> batchForward(int batchSize, boolean shuffle, boolean skipLast, Tensor... inputs);

+    /**
+     * Saves the state of the network to an atom output stream.
+     *
+     * @param out atom output stream
+     * @throws IOException if an I/O error occurs while writing the state
+     */
    void saveState(AtomOutputStream out) throws IOException;

+    /**
+     * Loads the state of the network from an atom input stream.
+     *
+     * @param in atom input stream
+     * @throws IOException if an I/O error occurs while reading the state
+     */
    void loadState(AtomInputStream in) throws IOException;

+    /**
+     * Saves the network state to a file using the atom binary serialization protocol.
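+     * <p>
+     * A round-trip sketch (the file name and the {@code net} instance are hypothetical):
+     * <pre>{@code
+     * net.saveState(new File("network.atom"));   // persist the trained state
+     * net.loadState(new File("network.atom"));   // restore it later via loadState(File)
+     * }</pre>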
+     *
+     * @param file file which will store the network state
+     * @throws IOException if an I/O error occurs while writing the state
+     */
    void saveState(File file) throws IOException;

+    /**
+     * Saves the network state to a generic output stream using the atom binary serialization protocol.
+     *
+     * @param out output stream
+     * @throws IOException if an I/O error occurs while writing the state
+     */
    void saveState(OutputStream out) throws IOException;

+    /**
+     * Loads the network state from a file using the atom binary serialization protocol.
+     *
+     * @param file file which contains the serialized network state
+     * @throws IOException if an I/O error occurs while reading the state
+     */
    void loadState(File file) throws IOException;

+    /**
+     * Loads the network state from a generic input stream using the atom binary serialization protocol.
+     *
+     * @param in input stream
+     * @throws IOException if an I/O error occurs while reading the state
+     */
    void loadState(InputStream in) throws IOException;
}