Skip to content
This repository was archived by the owner on Feb 13, 2025. It is now read-only.

Commit 713bb8d

Browse files
authored
Implement the DLRM model (#344)
* Implement the DLRM model * Respond to comments * Add dot interactions and MLP activations * Add comments to DLRM initializers to explain the hyperparameters. * Clean up tests. As part of ensuring the train test is reliable, I have matched the initialization of the reference implementation for the Embedding layers. Additionally, the DLRM model is susceptible to a "bad initialization" that doesn't perfectly memorize the single test minibatch. Although this is infrequent (~1 out of 50 test runs), I have modified the tests to randomly re-initialize 5 times, ensuring the test is approximately flaky with a probability of 3.2e-9 while still maintaining the quality of the test (e.g. testing random initialization, etc). Finally, instead of checking that loss drops below a particular value, the test checks that the accuracy is 100%. This results in a faster stopping condition, and thus the convergence test often runs in under 300ms on a laptop. * Add CMake build directives. * Switch from assert to precondition
1 parent 23b09c0 commit 713bb8d

File tree

10 files changed

+400
-0
lines changed

10 files changed

+400
-0
lines changed

Models/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
add_subdirectory(ImageClassification)
2+
add_subdirectory(Recommendation)
23
add_subdirectory(Text)

Models/Recommendation/CMakeLists.txt

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
add_library(RecommendationModels
2+
DLRM.swift
3+
MLP.swift)
4+
set_target_properties(RecommendationModels PROPERTIES
5+
INTERFACE_INCLUDE_DIRECTORIES ${CMAKE_Swift_MODULE_DIRECTORY})
6+
target_compile_options(RecommendationModels PRIVATE
7+
$<$<BOOL:${BUILD_TESTING}>:-enable-testing>)
8+
9+
10+
install(TARGETS RecommendationModels
11+
ARCHIVE DESTINATION lib/swift/$<LOWER_CASE:${CMAKE_SYSTEM_NAME}>
12+
LIBRARY DESTINATION lib/swift/$<LOWER_CASE:${CMAKE_SYSTEM_NAME}>
13+
RUNTIME DESTINATION bin)
14+
get_swift_host_arch(swift_arch)
15+
install(FILES
16+
$<TARGET_PROPERTY:RecommendationModels,Swift_MODULE_DIRECTORY>/RecommendationModels.swiftdoc
17+
$<TARGET_PROPERTY:RecommendationModels,Swift_MODULE_DIRECTORY>/RecommendationModels.swiftmodule
18+
DESTINATION lib/swift$<$<NOT:$<BOOL:${BUILD_SHARED_LIBS}>>:_static>/$<LOWER_CASE:${CMAKE_SYSTEM_NAME}>/${swift_arch})

Models/Recommendation/DLRM.swift

+197
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import TensorFlow
16+
17+
/// The DLRM model is parameterized to support multiple ways of combining the latent spaces of the inputs.
18+
public enum InteractionType {
19+
/// Concatenate the tensors representing the latent spaces of the inputs together.
20+
///
21+
/// This operation is the fastest, but does not encode any higher-order feature interactions.
22+
case concatenate
23+
24+
/// Compute the dot product of every input latent space with every other input latent space
25+
/// and concatenate the results.
26+
///
27+
/// This computation encodes 2nd-order feature interactions.
28+
///
29+
/// If `selfInteraction` is true, 2nd-order self-interactions occur. If false,
30+
/// self-interactions are excluded.
31+
case dot(selfInteraction: Bool)
32+
}
33+
34+
/// DLRM is the deep learning recommendation model and is used for recommendation tasks.
35+
///
36+
/// DLRM handles inputs that contain both sparse categorical data and numerical data.
37+
/// Original Paper:
38+
/// "Deep Learning Recommendation Model for Personalization and Recommendation Systems"
39+
/// Maxim Naumov et al.
40+
/// https://arxiv.org/pdf/1906.00091.pdf
41+
public struct DLRM: Module {
42+
43+
public var mlpBottom: MLP
44+
public var mlpTop: MLP
45+
public var latentFactors: [Embedding<Float>]
46+
@noDerivative public let nDense: Int
47+
@noDerivative public let interaction: InteractionType
48+
49+
/// Randomly initialize a DLRM model from the given hyperparameters.
50+
///
51+
/// - Parameters:
52+
/// - nDense: The number of continuous or dense inputs for each example.
53+
/// - mSpa: The "width" of all embedding tables.
54+
/// - lnEmb: Defines the "heights" of each of each embedding table.
55+
/// - lnBot: The size of the hidden layers in the bottom MLP.
56+
/// - lnTop: The size of the hidden layers in the top MLP.
57+
/// - interaction: The type of interactions between the hidden features.
58+
public init(nDense: Int, mSpa: Int, lnEmb: [Int], lnBot: [Int], lnTop: [Int],
59+
interaction: InteractionType = .concatenate) {
60+
self.nDense = nDense
61+
mlpBottom = MLP(dims: [nDense] + lnBot)
62+
let topInput = lnEmb.count * mSpa + lnBot.last!
63+
mlpTop = MLP(dims: [topInput] + lnTop + [1], sigmoidLastLayer: true)
64+
latentFactors = lnEmb.map { embeddingSize -> Embedding<Float> in
65+
// Use a random uniform initialization to match the reference implementation.
66+
let weights = Tensor<Float>(
67+
randomUniform: [embeddingSize, mSpa],
68+
lowerBound: Tensor(Float(-1.0)/Float(embeddingSize)),
69+
upperBound: Tensor(Float(1.0)/Float(embeddingSize)))
70+
return Embedding(embeddings: weights)
71+
}
72+
self.interaction = interaction
73+
}
74+
75+
@differentiable
76+
public func callAsFunction(_ input: DLRMInput) -> Tensor<Float> {
77+
callAsFunction(denseInput: input.dense, sparseInput: input.sparse)
78+
}
79+
80+
@differentiable(wrt: self)
81+
public func callAsFunction(
82+
denseInput: Tensor<Float>,
83+
sparseInput: [Tensor<Int32>]
84+
) -> Tensor<Float> {
85+
precondition(denseInput.shape.last! == nDense)
86+
precondition(sparseInput.count == latentFactors.count)
87+
let denseEmbVec = mlpBottom(denseInput)
88+
let sparseEmbVecs = computeEmbeddings(sparseInputs: sparseInput,
89+
latentFactors: latentFactors)
90+
let topInput = computeInteractions(
91+
denseEmbVec: denseEmbVec, sparseEmbVecs: sparseEmbVecs)
92+
let prediction = mlpTop(topInput)
93+
94+
// TODO: loss threshold clipping
95+
return prediction.reshaped(to: [-1])
96+
}
97+
98+
@differentiable(wrt: (denseEmbVec, sparseEmbVecs))
99+
public func computeInteractions(
100+
denseEmbVec: Tensor<Float>,
101+
sparseEmbVecs: [Tensor<Float>]
102+
) -> Tensor<Float> {
103+
switch self.interaction {
104+
case .concatenate:
105+
return Tensor(concatenating: sparseEmbVecs + [denseEmbVec], alongAxis: 1)
106+
case let .dot(selfInteraction):
107+
let batchSize = denseEmbVec.shape[0]
108+
let allEmbeddings = Tensor(
109+
concatenating: sparseEmbVecs + [denseEmbVec],
110+
alongAxis: 1).reshaped(to: [batchSize, -1, denseEmbVec.shape[1]])
111+
// Use matmul to efficiently compute all dot products
112+
let higherOrderInteractions = matmul(
113+
allEmbeddings, allEmbeddings.transposed(permutation: 0, 2, 1))
114+
// Gather relevant indices
115+
let flattenedHigherOrderInteractions = higherOrderInteractions.reshaped(
116+
to: [batchSize, -1])
117+
let desiredIndices = makeIndices(
118+
n: Int32(higherOrderInteractions.shape[1]),
119+
selfInteraction: selfInteraction)
120+
let desiredInteractions =
121+
flattenedHigherOrderInteractions.batchGathering(atIndices: desiredIndices)
122+
return Tensor(concatenating: [desiredInteractions, denseEmbVec], alongAxis: 1)
123+
}
124+
}
125+
}
126+
127+
/// DLRMInput represents the categorical and numerical input
128+
public struct DLRMInput {
129+
130+
/// dense represents a mini-batch of continuous inputs.
131+
///
132+
/// It should have shape `[batchSize, continuousCount]`
133+
public let dense: Tensor<Float>
134+
135+
/// sparse represents the categorical inputs to the mini-batch.
136+
///
137+
/// The array should be of length `numCategoricalInputs`.
138+
/// Each tensor within the array should be a vector of length `batchSize`.
139+
public let sparse: [Tensor<Int32>]
140+
}
141+
142+
// Work-around for lack of inout support
143+
fileprivate func computeEmbeddings(
144+
sparseInputs: [Tensor<Int32>],
145+
latentFactors: [Embedding<Float>]
146+
) -> [Tensor<Float>] {
147+
var sparseEmbVecs: [Tensor<Float>] = []
148+
for i in 0..<sparseInputs.count {
149+
sparseEmbVecs.append(latentFactors[i](sparseInputs[i]))
150+
}
151+
return sparseEmbVecs
152+
}
153+
154+
// TODO: remove computeEmbeddingsVJP once inout differentiation is supported!
155+
@derivative(of: computeEmbeddings)
156+
fileprivate func computeEmbeddingsVJP(
157+
sparseInput: [Tensor<Int32>],
158+
latentFactors: [Embedding<Float>]
159+
) -> (
160+
value: [Tensor<Float>],
161+
pullback: (Array<Tensor<Float>>.TangentVector) -> Array<Embedding<Float>>.TangentVector
162+
) {
163+
var sparseEmbVecs = [Tensor<Float>]()
164+
var pullbacks = [(Tensor<Float>.TangentVector) -> Embedding<Float>.TangentVector]()
165+
for i in 0..<sparseInput.count {
166+
let (fwd, pullback) = valueWithPullback(at: latentFactors[i]) { $0(sparseInput[i]) }
167+
sparseEmbVecs.append(fwd)
168+
pullbacks.append(pullback)
169+
}
170+
return (
171+
value: sparseEmbVecs,
172+
pullback: { v in
173+
let arr = zip(v, pullbacks).map { $0.1($0.0) }
174+
return Array.DifferentiableView(arr)
175+
}
176+
)
177+
}
178+
179+
/// Compute indices for the upper triangle (optionally including the diagonal) in a flattened representation.
180+
///
181+
/// - Parameter n: Size of the square matrix.
182+
/// - Parameter selfInteraction: Include the diagonal iff selfInteraction is true.
183+
fileprivate func makeIndices(n: Int32, selfInteraction: Bool) -> Tensor<Int32> {
184+
let interactionOffset: Int32
185+
if selfInteraction {
186+
interactionOffset = 0
187+
} else {
188+
interactionOffset = 1
189+
}
190+
var result = [Int32]()
191+
for i in 0..<n {
192+
for j in (i + interactionOffset)..<n {
193+
result.append(i*n + j)
194+
}
195+
}
196+
return Tensor(result)
197+
}

Models/Recommendation/MLP.swift

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import TensorFlow
16+
17+
/// MLP is a multi-layer perceptron and is used as a component of the DLRM model
18+
public struct MLP: Layer {
19+
public var blocks: [Dense<Float>] = []
20+
21+
/// Randomly initializes a new multilayer perceptron from the given hyperparameters.
22+
///
23+
/// - Parameter dims: Dims represents the size of the input, hidden layers, and output of the
24+
/// multi-layer perceptron.
25+
/// - Parameter sigmoidLastLayer: if `true`, use a `sigmoid` activation function for the last layer,
26+
/// `relu` otherwise.
27+
init(dims: [Int], sigmoidLastLayer: Bool = false) {
28+
for i in 0..<(dims.count-1) {
29+
if sigmoidLastLayer && i == dims.count - 2 {
30+
blocks.append(Dense(inputSize: dims[i], outputSize: dims[i+1], activation: sigmoid))
31+
} else {
32+
blocks.append(Dense(inputSize: dims[i], outputSize: dims[i+1], activation: relu))
33+
}
34+
}
35+
}
36+
37+
@differentiable
38+
public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
39+
let blocksReduced = blocks.differentiableReduce(input) { last, layer in
40+
layer(last)
41+
}
42+
return blocksReduced
43+
}
44+
45+
}
46+

Package.swift

+3
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ let package = Package(
1313
.library(name: "Datasets", targets: ["Datasets"]),
1414
.library(name: "ModelSupport", targets: ["ModelSupport"]),
1515
.library(name: "ImageClassificationModels", targets: ["ImageClassificationModels"]),
16+
.library(name: "RecommendationModels", targets: ["RecommendationModels"]),
1617
.library(name: "TextModels", targets: ["TextModels"]),
1718
.executable(name: "Benchmarks", targets: ["Benchmarks"]),
1819
.executable(name: "VGG-Imagewoof", targets: ["VGG-Imagewoof"]),
@@ -41,6 +42,7 @@ let package = Package(
4142
.target(name: "ModelSupport", dependencies: ["SwiftProtobuf"], path: "Support"),
4243
.target(name: "ImageClassificationModels", path: "Models/ImageClassification"),
4344
.target(name: "TextModels", dependencies: ["Datasets"], path: "Models/Text"),
45+
.target(name: "RecommendationModels", path: "Models/Recommendation"),
4446
.target(
4547
name: "Autoencoder1D", dependencies: ["Datasets", "ModelSupport"],
4648
path: "Autoencoder/Autoencoder1D"),
@@ -75,6 +77,7 @@ let package = Package(
7577
name: "MiniGoDemo", dependencies: ["MiniGo"], path: "MiniGo", sources: ["main.swift"]),
7678
.testTarget(name: "MiniGoTests", dependencies: ["MiniGo"]),
7779
.testTarget(name: "ImageClassificationTests", dependencies: ["ImageClassificationModels"]),
80+
.testTarget(name: "RecommendationModelTests", dependencies: ["RecommendationModels"]),
7881
.testTarget(name: "DatasetsTests", dependencies: ["Datasets", "TextModels"]),
7982
.target(
8083
name: "GPT2-Inference", dependencies: ["TextModels"],

Tests/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ add_subdirectory(DatasetsTests)
33
add_subdirectory(FastStyleTransferTests)
44
add_subdirectory(ImageClassificationTests)
55
add_subdirectory(MiniGoTests)
6+
add_subdirectory(RecommendationModelTests)
67
add_subdirectory(SupportTests)
78
add_subdirectory(TextTests)
89

@@ -14,6 +15,7 @@ target_link_libraries(ModelTests PRIVATE
1415
FastStyleTransferTests
1516
ImageClassificationTests
1617
MiniGoTests
18+
RecommendationModelTests
1719
SupportTests
1820
TextTests
1921
XCTest)

Tests/LinuxMain.swift

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import DatasetsTests
33
import FastStyleTransferTests
44
import ImageClassificationTests
55
import MiniGoTests
6+
import RecommendationModelTests
67
import SupportTests
78
import TextTests
89
import XCTest
@@ -13,6 +14,7 @@ tests += MiniGoTests.allTests()
1314
tests += FastStyleTransferTests.allTests()
1415
tests += DatasetsTests.allTests()
1516
tests += CheckpointTests.allTests()
17+
tests += RecommendationModelTests.allTests()
1618
tests += SupportTests.allTests()
1719
tests += TextTests.allTests()
1820
XCTMain(tests)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
add_library(RecommendationModelTests
2+
DLRMTests.swift
3+
XCTestManifests.swift)
4+
set_target_properties(RecommendationModelTests PROPERTIES
5+
RUNTIME_OUTPUT_DIRECTORY $<TARGET_FILE_DIR:ModelTests>
6+
LIBRARY_OUTPUT_DIRECTORY $<TARGET_FILE_DIR:ModelTests>)
7+
target_link_libraries(RecommendationModelTests PUBLIC
8+
RecommendationModels)

0 commit comments

Comments
 (0)