
Commit 659e982

Add comments to DLRM initializers to explain the hyperparameters.
1 parent f15a00f

2 files changed (+17, -1)

Models/Recommendation/DLRM.swift

+10 -1
@@ -46,12 +46,21 @@ public struct DLRM: Module {
     @noDerivative public let nDense: Int
     @noDerivative public let interaction: InteractionType
 
+    /// Randomly initialize a DLRM model from the given hyperparameters.
+    ///
+    /// - Parameters:
+    ///    - nDense: The number of continuous or dense inputs for each example.
+    ///    - mSpa: The "width" of all embedding tables.
+    ///    - lnEmb: Defines the "heights" of each embedding table.
+    ///    - lnBot: The sizes of the hidden layers in the bottom MLP.
+    ///    - lnTop: The sizes of the hidden layers in the top MLP.
+    ///    - interaction: The type of interactions between the hidden features.
     public init(nDense: Int, mSpa: Int, lnEmb: [Int], lnBot: [Int], lnTop: [Int],
                 interaction: InteractionType = .concatenate) {
         self.nDense = nDense
         mlpBottom = MLP(dims: [nDense] + lnBot)
         let topInput = lnEmb.count * mSpa + lnBot.last!
-        mlpTop = MLP(dims: [topInput] + lnTop + [1])
+        mlpTop = MLP(dims: [topInput] + lnTop + [1], sigmoidLastLayer: true)
         latentFactors = lnEmb.map { Embedding(vocabularySize: $0, embeddingSize: mSpa) }
         self.interaction = interaction
     }
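
For context, here is a hypothetical call to this initializer. The feature counts and layer sizes below are illustrative assumptions, not values taken from the commit:

import TensorFlow

// Illustrative hyperparameters (assumed for this sketch, not from the commit):
// 13 dense inputs, three categorical features with vocabularies of
// 10_000, 5_000, and 2_000 entries, and 16-wide embedding tables.
let model = DLRM(
    nDense: 13,                     // continuous inputs per example
    mSpa: 16,                       // "width" shared by every embedding table
    lnEmb: [10_000, 5_000, 2_000],  // one "height" per embedding table
    lnBot: [512, 256, 16],          // bottom MLP hidden-layer sizes
    lnTop: [256, 128],              // top MLP hidden-layer sizes
    interaction: .concatenate)      // concatenate the hidden features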

Models/Recommendation/MLP.swift

+7
@@ -17,6 +17,13 @@ import TensorFlow
 /// MLP is a multi-layer perceptron and is used as a component of the DLRM model
 public struct MLP: Layer {
     public var blocks: [Dense<Float>] = []
+
+    /// Randomly initializes a new multi-layer perceptron from the given hyperparameters.
+    ///
+    /// - Parameter dims: The sizes of the input, hidden layers, and output of the
+    ///   multi-layer perceptron.
+    /// - Parameter sigmoidLastLayer: if `true`, use a `sigmoid` activation function for
+    ///   the last layer, and `relu` otherwise.
     init(dims: [Int], sigmoidLastLayer: Bool = false) {
         for i in 0..<(dims.count-1) {
             if sigmoidLastLayer && i == dims.count - 2 {
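
To see what the new flag changes, a minimal usage sketch (the dimensions are assumed, not part of the diff): with the default `false` every `Dense` layer uses `relu`; with `true` only the final layer switches to `sigmoid`, matching the single-logit output requested by DLRM's top MLP above.

// Both Dense layers (32→64 and 64→16) use relu activations:
let bottom = MLP(dims: [32, 64, 16])

// The last Dense layer (32→1) uses sigmoid instead of relu,
// as DLRM's top MLP now requests via `sigmoidLastLayer: true`:
let top = MLP(dims: [64, 32, 1], sigmoidLastLayer: true)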
