
Commit 4cec9db

- improve Javadoc comments for network related source code
1 parent 4587aeb commit 4cec9db

8 files changed: +278 −18 lines changed

rapaio-core/src/rapaio/nn/Autograd.java

+1 −3
@@ -31,9 +31,7 @@
 import rapaio.darray.DArray;

 /**
- * Central place of automatic differentiation in reverse mode.
- * <p>
- * Object which allows differentiation must implement {@link Tensor}.
+ * Implementation of automatic differentiation in reverse mode. Objects which allow differentiation must implement {@link Tensor}.
  * <p>
  * The forward operations are performed when the computation is called using various operations
  * on {@link Tensor} or when new node are created with {@link TensorManager#var(DArray)}.
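The reverse-mode flow described above can be sketched as follows. This is only a hedged illustration: apart from Tensor, TensorManager#var(DArray) and Autograd#backward(Tensor), which appear in the Javadoc, the factory method ofFloat(), the operations sqr()/sum() and the grad() accessor are assumed names for illustration, not confirmed rapaio API.

// hedged sketch; names marked "assumed" are illustrative only
TensorManager tm = TensorManager.ofFloat();   // assumed factory for a tensor manager
Tensor x = tm.var(values);                    // wrap a DArray as a differentiable graph node
Tensor y = x.sqr().sum();                     // assumed forward ops; each op records its backpropagation function
Autograd.backward(y);                         // backpropagate gradients from the scalar result (assumed static call)
DArray<?> dx = x.grad();                      // assumed accessor for the gradient accumulated in x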

rapaio-core/src/rapaio/nn/Loss.java

+3
@@ -26,6 +26,9 @@

 import rapaio.nn.data.Batch;

+/**
+ * Loss function used to optimize a network during backpropagation.
+ */
 public interface Loss {

     enum Reduce {

rapaio-core/src/rapaio/nn/Network.java

+99
@@ -104,32 +104,131 @@ public interface Network extends Serializable {
      */
     void eval();

+    /**
+     * Performs network inference by taking an array of input tensors and computing an array of output tensors.
+     * Usually a network takes a single input and produces a single output. For this purpose, the default
+     * implementation delegates to {@link #forward11(Tensor)}.
+     * <p>
+     * If this default scenario does not fit the purpose of the network, this method needs to be overridden.
+     * <p>
+     * The forward method contains operations on tensors. All operations on tensors are tracked by the
+     * computational graph, since each tensor operation leaves a trace which consists of backpropagation
+     * functions. When the {@link Autograd#backward(Tensor)} method is called on some tensor which has
+     * a scalar gradient, the computational graph starts to backpropagate gradients.
+     *
+     * @param xs input tensors
+     * @return computed output tensors
+     */
     default Tensor[] forward(Tensor... xs) {
         if (xs.length == 1) {
             return new Tensor[] {forward11(xs[0])};
         }
         throw new NotImplementedException();
     }

+    /**
+     * The default case of the {@link #forward(Tensor...)} method, which receives a single input tensor
+     * and outputs a single tensor.
+     *
+     * @param x input tensor
+     * @return computed output tensor
+     */
     default Tensor forward11(Tensor x) {
         throw new NotImplementedException();
     }

+    /**
+     * Improved forward method which trades memory for parallel batched execution of the forward pass.
+     * <p>
+     * The execution consists of splitting the input tensors into batches and executing those batches
+     * in parallel in the forward step. The tradeoff is that the whole computational graph will reside
+     * in memory, thus one can use this method only if the dataset is small enough, depending on the
+     * available memory.
+     * <p>
+     * The result consists of a list of batches. Each batch contains the input data and also the
+     * network output tensors computed for that specific batch.
+     * <p>
+     * The batch size is given as a parameter. Before splitting into batches the data from the dataset
+     * is shuffled, and all the batches are used for execution.
+     *
+     * @param batchSize the number of instances in each batch; the last batch might contain fewer instances
+     * @param inputs    input tensors
+     * @return list of computed batches
+     */
     default List<Batch> batchForward(int batchSize, Tensor... inputs) {
         return batchForward(batchSize, true, false, inputs);
     }

+    /**
+     * Fully customizable version of {@link #batchForward(int, Tensor...)}.
+     * <p>
+     * Improved forward method which trades memory for parallel batched execution of the forward pass.
+     * <p>
+     * The execution consists of splitting the input tensors into batches and executing those batches
+     * in parallel in the forward step. The tradeoff is that the whole computational graph will reside
+     * in memory, thus one can use this method only if the dataset is small enough, depending on the
+     * available memory.
+     * <p>
+     * The result consists of a list of batches. Each batch contains the input data and also the
+     * network output tensors computed for that specific batch.
+     * <p>
+     * The batch size is given as a parameter. Before splitting into batches the instances are shuffled
+     * if the {@code shuffle} parameter is true. In some cases the last batch might contain fewer
+     * instances; if this is not desirable, one can set {@code skipLast} to {@code true} to skip the last batch.
+     *
+     * @param batchSize the batch size
+     * @param shuffle   whether the data is shuffled before splitting into batches
+     * @param skipLast  whether the last batch, which might be smaller, is skipped from execution
+     * @param inputs    input tensors
+     * @return list of computed batches
+     */
     List<Batch> batchForward(int batchSize, boolean shuffle, boolean skipLast, Tensor... inputs);

+    /**
+     * Saves the state of the network to an atom output stream.
+     *
+     * @param out atom output stream
+     * @throws IOException thrown if something goes wrong
+     */
     void saveState(AtomOutputStream out) throws IOException;

+    /**
+     * Loads the state of the network from an atom input stream.
+     *
+     * @param in atom input stream
+     * @throws IOException thrown if something goes wrong
+     */
     void loadState(AtomInputStream in) throws IOException;

+    /**
+     * Saves the network state to a file using the atom binary serialization protocol.
+     *
+     * @param file file which will store the network state
+     * @throws IOException thrown if something goes wrong
+     */
     void saveState(File file) throws IOException;

+    /**
+     * Saves the network state to a generic output stream using the atom binary serialization protocol.
+     *
+     * @param out output stream
+     * @throws IOException thrown if something goes wrong
+     */
     void saveState(OutputStream out) throws IOException;

+    /**
+     * Loads the network state from a file using the atom binary serialization protocol.
+     *
+     * @param file file which contains the serialized network state
+     * @throws IOException thrown if something goes wrong
+     */
     void loadState(File file) throws IOException;

+    /**
+     * Loads the network state from a generic input stream using the atom binary serialization protocol.
+     *
+     * @param in input stream
+     * @throws IOException thrown if something goes wrong
+     */
     void loadState(InputStream in) throws IOException;
 }
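As a usage note for the forward methods documented above, the single-input case might look like the hedged sketch below; net stands for a hypothetical Network implementation that overrides forward11, and xTrain for an input tensor built elsewhere.

// hedged sketch; net and xTrain are hypothetical placeholders
Tensor yHat = net.forward11(xTrain);   // single input, single output
Tensor[] ys = net.forward(xTrain);     // the default forward(Tensor...) delegates to forward11 for one input; ys[0] matches yHat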

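The batched forward pass documented above might be driven as in the sketch below; only the batchForward signatures come from the diff, while the Batch accessor used on the result is an assumption.

// hedged sketch of a batched forward pass: batch size 32, shuffled, keeping the (possibly smaller) last batch
List<Batch> batches = net.batchForward(32, true, false, xTrain);
for (Batch batch : batches) {
    Tensor[] outputs = batch.outputs();   // assumed accessor for the network outputs computed for this batch
    // feed the outputs into a Loss implementation or collect predictions here
}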
rapaio-core/src/rapaio/nn/NetworkState.java

+17
@@ -24,6 +24,23 @@
 import java.util.ArrayList;
 import java.util.List;

+/**
+ * Represents the network state. The network state is a container which holds
+ * all the tensors which are used for inference and learning in a network.
+ * <p>
+ * For serialization purposes the network code is not saved. This is in order
+ * to give enough freedom to the user to customize the network behavior
+ * with custom code at training and inference time. Instead, if a network
+ * has to be serialized, the following scenario can be followed:
+ *
+ * <ol>
+ * <li>create a network instance and do whatever is needed to be used later (including initialization, training, other customizations)</li>
+ * <li>save the network state into persistent storage using one of the {@code Network#saveState} methods</li>
+ * <li>for later usage, create a new instance of the network</li>
+ * <li>load the network state from persistent storage using one of the {@code Network#loadState} methods</li>
+ * <li>the new network is ready to be used like the old network instance, for inference or for other scenarios like further training</li>
+ * </ol>
+ */
 public final class NetworkState {

     private final ArrayList<Tensor> tensors;
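The serialization scenario enumerated above maps to the saveState/loadState methods declared in Network. A hedged sketch, where MyNetwork and the training code are hypothetical while saveState(File) and loadState(File) come from the diff:

Network net = new MyNetwork(tm);          // 1. create and customize a network instance (hypothetical class)
// initialize or train the network here
net.saveState(new File("net.atom"));      // 2. persist the state with the atom binary serialization protocol

Network restored = new MyNetwork(tm);     // 3. later, create a fresh instance of the same network
restored.loadState(new File("net.atom")); // 4. load the persisted state into it
// 5. the restored instance can now be used for inference or trained further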

rapaio-core/src/rapaio/nn/Optimizer.java

+29
@@ -26,17 +26,46 @@
 import rapaio.nn.optimizer.Adam;
 import rapaio.nn.optimizer.SGD;

+/**
+ * Defines the contract for optimization algorithms.
+ * <p>
+ * An optimization algorithm uses the computed gradients and updates the values of the tracked tensors
+ * according to its own strategy. The tracked tensors are received at creation time.
+ * <p>
+ * Even if a tensor is tracked for optimization, it can be skipped if it has {@code requiresGrad} set to {@code false}.
+ * This is useful for scenarios when one wants to freeze some parts of the network and update only the other parts.
+ */
 public interface Optimizer {

+    /**
+     * Creates a new instance of the Stochastic Gradient Descent optimizer. Further customization can be done on the returned instance.
+     *
+     * @param tm     tensor manager used for computation
+     * @param params collection of tracked tensors for optimization
+     * @return new optimizer instance
+     */
     static SGD SGD(TensorManager tm, Collection<Tensor> params) {
         return new SGD(tm, params);
     }

+    /**
+     * Creates a new instance of the Adam optimizer. Further customization can be done on the returned instance.
+     *
+     * @param tm     tensor manager used for computation
+     * @param params collection of tracked tensors for optimization
+     * @return new optimizer instance
+     */
     static Adam Adam(TensorManager tm, Collection<Tensor> params) {
         return new Adam(tm, params);
     }

+    /**
+     * Deletes all the computed gradients for the tracked tensors.
+     */
     void zeroGrad();

+    /**
+     * Performs an optimization step, which consists of updating the tensor values according to the computed gradients
+     * and the algorithm's own strategy.
+     */
     void step();
 }
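Putting the Optimizer contract above together with Autograd and Network, one training iteration might look like the hedged sketch below. Optimizer.Adam(tm, params), zeroGrad(), step(), Network#forward11 and Autograd#backward appear in this commit; net.parameters() and the computeLoss helper are assumptions used only to complete the picture.

Optimizer optimizer = Optimizer.Adam(tm, net.parameters());   // track the trainable tensors (parameters() is assumed)
for (int epoch = 0; epoch < 10; epoch++) {
    optimizer.zeroGrad();                        // delete gradients accumulated in the previous iteration
    Tensor yHat = net.forward11(xTrain);         // forward pass builds the computational graph
    Tensor loss = computeLoss(yHat, yTrain);     // hypothetical helper producing a scalar loss tensor
    Autograd.backward(loss);                     // backpropagate gradients to the tracked tensors
    optimizer.step();                            // update tensor values according to the optimizer's strategy
}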
