
Commit 5676601

Committed Mar 23, 2020
Clean up tests.
As part of making the training test reliable, I have matched the reference implementation's initialization for the Embedding layers. Additionally, the DLRM model is susceptible to an occasional "bad initialization" from which it fails to perfectly memorize the single test minibatch. Although this is infrequent (~1 out of 50 test runs), I have modified the test to randomly re-initialize up to 5 times, making it flaky with a probability of only about 3.2e-9 while still maintaining the quality of the test (e.g. still exercising random initialization). Finally, instead of checking that the loss drops below a particular value, the test checks that accuracy on the minibatch is 100%. This is an earlier stopping condition, so the convergence test often runs in under 300 ms on a laptop.
1 parent 659e982 commit 5676601
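A quick back-of-the-envelope check of the flakiness figure quoted above (a sketch, assuming each of the five re-initializations independently hits a bad initialization with probability roughly 1/50; the variable names are mine, not from the commit):

import Foundation

// Five independent attempts only all fail together with probability (1/50)^5.
let badInitProbability = 1.0 / 50.0
let attempts = 5.0
let spuriousFailureProbability = pow(badInitProbability, attempts)
print(spuriousFailureProbability)  // ~3.2e-09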

File tree

2 files changed, +37 -24 lines


Models/Recommendation/DLRM.swift (+10 -3)
@@ -61,7 +61,14 @@ public struct DLRM: Module {
         mlpBottom = MLP(dims: [nDense] + lnBot)
         let topInput = lnEmb.count * mSpa + lnBot.last!
         mlpTop = MLP(dims: [topInput] + lnTop + [1], sigmoidLastLayer: true)
-        latentFactors = lnEmb.map { Embedding(vocabularySize: $0, embeddingSize: mSpa) }
+        latentFactors = lnEmb.map { embeddingSize -> Embedding<Float> in
+            // Use a random uniform initialization to match the reference implementation.
+            let weights = Tensor<Float>(
+                randomUniform: [embeddingSize, mSpa],
+                lowerBound: Tensor(Float(-1.0)/Float(embeddingSize)),
+                upperBound: Tensor(Float(1.0)/Float(embeddingSize)))
+            return Embedding(embeddings: weights)
+        }
         self.interaction = interaction
     }

@@ -80,8 +87,8 @@ public struct DLRM: Module {
         let denseEmbVec = mlpBottom(denseInput)
         let sparseEmbVecs = computeEmbeddings(sparseInputs: sparseInput,
                                               latentFactors: latentFactors)
-        let topInput = Tensor(concatenating: sparseEmbVecs + [denseEmbVec],
-                              alongAxis: 1)
+        let topInput = computeInteractions(
+            denseEmbVec: denseEmbVec, sparseEmbVecs: sparseEmbVecs)
         let prediction = mlpTop(topInput)

         // TODO: loss threshold clipping
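For context, the new embedding initialization can be exercised on its own. A minimal sketch (illustrative sizes, not taken from the model; it only reuses the Tensor and Embedding APIs that appear in the diff above):

import TensorFlow

// Build one embedding table the way the DLRM initializer now does,
// then sanity-check that every weight lies within the expected bounds.
let vocabularySize = 10
let embeddingWidth = 4
let bound = Float(1.0) / Float(vocabularySize)
let weights = Tensor<Float>(
    randomUniform: [vocabularySize, embeddingWidth],
    lowerBound: Tensor(-bound),
    upperBound: Tensor(bound))
let table = Embedding<Float>(embeddings: weights)
assert(table.embeddings.min().scalarized() >= -bound)
assert(table.embeddings.max().scalarized() <= bound)

Note that the bounds shrink as the vocabulary grows, so larger embedding tables start with proportionally smaller weights.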

Tests/RecommendationModelTests/DLRMTests.swift (+27 -21)
@@ -43,52 +43,58 @@ final class DLRMTests: XCTestCase {
     }

     func testDLRMTraining() {
-        let trainingSteps = 2000
+        let trainingSteps = 400
         let nDense = 9
        let dimEmbed = 4
         let bottomMLPSize = [8, 4]
         let topMLPSize = [11, 4]
         let batchSize = 10

-        var model = DLRM(
-            nDense: nDense,
-            mSpa: dimEmbed,
-            lnEmb: [10, 20],
-            lnBot: bottomMLPSize,
-            lnTop: topMLPSize)
-
         func lossFunc(predicted: Tensor<Float>, labels: Tensor<Float>) -> Tensor<Float> {
             let difference = predicted - labels
             let squared = difference * difference
             return squared.sum()
         }

-        let trainingData = DLRMInput(dense: Tensor(ones: [batchSize, nDense]),
+        let trainingData = DLRMInput(dense: Tensor(randomNormal: [batchSize, nDense]),
                                      sparse: [Tensor([7, 3, 1, 3, 1, 6, 7, 8, 9, 2]),
                                               Tensor([17, 13, 19, 0, 1, 6, 7, 8, 9, 10])])
         let labels = Tensor<Float>([1,0,0,1,1,1,0,1,0,1])

-        let optimizer = SGD(for: model, learningRate: 0.0015)
+        // Sometimes DLRM on such a small dataset can get "stuck" in a bad initialization.
+        // To ensure a reliable test, we give ourselves a few reinitializations.
+        for attempt in 1...5 {
+            var model = DLRM(
+                nDense: nDense,
+                mSpa: dimEmbed,
+                lnEmb: [10, 20],
+                lnBot: bottomMLPSize,
+                lnTop: topMLPSize)
+            let optimizer = SGD(for: model, learningRate: 0.1)

-        for step in 1...trainingSteps {
-            let (loss, grads) = valueWithGradient(at: model) { model in
-                lossFunc(predicted: model(trainingData), labels: labels)
-            }
-            if step % 100 == 0 {
-                print(step, loss)
-                if loss.scalarized() < 1e-7 {
-                    return // Success!
+            for step in 0...trainingSteps {
+                let (loss, grads) = valueWithGradient(at: model) { model in
+                    lossFunc(predicted: model(trainingData), labels: labels)
+                }
+                if step % 50 == 0 {
+                    print(step, loss)
+                    if round(model(trainingData)) == labels { return } // Success
+                }
+                if step > 300 && step % 50 == 0 {
+                    print("\n\n-----------------------------------------")
+                    print("Step: \(step), loss: \(loss)\nGrads:\n\(grads)\nModel:\n\(model)")
                 }
+                optimizer.update(&model, along: grads)
             }
-            optimizer.update(&model, along: grads)
+            print("Final model outputs (attempt: \(attempt)):\n\(model(trainingData))\nTarget:\n\(labels)")
         }
-        XCTFail("Could not perfectly fit a single mini-batch.")
+        XCTFail("Could not perfectly fit a single mini-batch after 5 reinitializations.")
     }
 }

 extension DLRMTests {
     static var allTests = [
         ("testDLRM", testDLRM),
-        ("testDLRMTraining", testDLRMTraining)
+        ("testDLRMTraining", testDLRMTraining),
     ]
 }
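The new stopping condition asks for 100% accuracy on the minibatch rather than a loss threshold: the sigmoid outputs are rounded to 0/1 and compared against the labels, and comparing two tensors with == yields a single Bool that is true only when every element matches. A small illustration (the prediction values are made up):

import TensorFlow

// True only when every rounded prediction equals its 0/1 label.
let predictions = Tensor<Float>([0.91, 0.08, 0.32, 0.77])
let labels = Tensor<Float>([1, 0, 0, 1])
let memorized = round(predictions) == labels
print(memorized)  // true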
