This repository was archived by the owner on Feb 13, 2025. It is now read-only.

Implement the DLRM model #344

Merged
merged 8 commits on Mar 25, 2020
1 change: 1 addition & 0 deletions Models/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(ImageClassification)
add_subdirectory(Recommendation)
add_subdirectory(Text)
18 changes: 18 additions & 0 deletions Models/Recommendation/CMakeLists.txt
@@ -0,0 +1,18 @@
add_library(RecommendationModels
DLRM.swift
MLP.swift)
set_target_properties(RecommendationModels PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES ${CMAKE_Swift_MODULE_DIRECTORY})
target_compile_options(RecommendationModels PRIVATE
$<$<BOOL:${BUILD_TESTING}>:-enable-testing>)


install(TARGETS RecommendationModels
ARCHIVE DESTINATION lib/swift/$<LOWER_CASE:${CMAKE_SYSTEM_NAME}>
LIBRARY DESTINATION lib/swift/$<LOWER_CASE:${CMAKE_SYSTEM_NAME}>
RUNTIME DESTINATION bin)
get_swift_host_arch(swift_arch)
install(FILES
$<TARGET_PROPERTY:RecommendationModels,Swift_MODULE_DIRECTORY>/RecommendationModels.swiftdoc
$<TARGET_PROPERTY:RecommendationModels,Swift_MODULE_DIRECTORY>/RecommendationModels.swiftmodule
DESTINATION lib/swift$<$<NOT:$<BOOL:${BUILD_SHARED_LIBS}>>:_static>/$<LOWER_CASE:${CMAKE_SYSTEM_NAME}>/${swift_arch})
197 changes: 197 additions & 0 deletions Models/Recommendation/DLRM.swift
@@ -0,0 +1,197 @@
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import TensorFlow

/// The DLRM model is parameterized to support multiple ways of combining the latent spaces of the inputs.
public enum InteractionType {
/// Concatenate the tensors representing the latent spaces of the inputs together.
///
/// This operation is the fastest, but does not encode any higher-order feature interactions.
case concatenate

/// Compute the dot product of every input latent space with every other input latent space
/// and concatenate the results.
///
/// This computation encodes 2nd-order feature interactions.
///
/// If `selfInteraction` is true, 2nd-order self-interactions occur. If false,
/// self-interactions are excluded.
case dot(selfInteraction: Bool)
}
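
// Illustrative sizing (hypothetical feature counts, assuming the bottom
// MLP's output width equals `mSpa`): with three sparse features plus the
// dense feature, each example carries m = 4 latent vectors, so
// `.concatenate` produces a combined width of 4 * mSpa, while
// `.dot(selfInteraction: false)` produces 4 * 3 / 2 = 6 dot products and
// `.dot(selfInteraction: true)` produces 4 * 5 / 2 = 10; the dot variants
// then re-append the dense latent vector (see `computeInteractions`).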

/// DLRM is a deep learning model for recommendation tasks.
///
/// DLRM handles inputs that contain both sparse categorical data and dense numerical data.
///
/// Original Paper:
/// "Deep Learning Recommendation Model for Personalization and Recommendation Systems"
/// Maxim Naumov et al.
/// https://arxiv.org/pdf/1906.00091.pdf
public struct DLRM: Module {

public var mlpBottom: MLP
public var mlpTop: MLP
public var latentFactors: [Embedding<Float>]
@noDerivative public let nDense: Int
@noDerivative public let interaction: InteractionType

/// Randomly initialize a DLRM model from the given hyperparameters.
///
/// - Parameters:
/// - nDense: The number of continuous or dense inputs for each example.
/// - mSpa: The "width" of all embedding tables.
/// - lnEmb: Defines the "height" of each embedding table.
/// - lnBot: The size of the hidden layers in the bottom MLP.
/// - lnTop: The size of the hidden layers in the top MLP.
/// - interaction: The type of interactions between the hidden features.
public init(nDense: Int, mSpa: Int, lnEmb: [Int], lnBot: [Int], lnTop: [Int],
interaction: InteractionType = .concatenate) {
self.nDense = nDense
mlpBottom = MLP(dims: [nDense] + lnBot)
let topInput = lnEmb.count * mSpa + lnBot.last!
mlpTop = MLP(dims: [topInput] + lnTop + [1], sigmoidLastLayer: true)
latentFactors = lnEmb.map { embeddingSize -> Embedding<Float> in
// Use a random uniform initialization to match the reference implementation.
let weights = Tensor<Float>(
randomUniform: [embeddingSize, mSpa],
lowerBound: Tensor(Float(-1.0)/Float(embeddingSize)),
upperBound: Tensor(Float(1.0)/Float(embeddingSize)))
return Embedding(embeddings: weights)
}
self.interaction = interaction
}

@differentiable
public func callAsFunction(_ input: DLRMInput) -> Tensor<Float> {
callAsFunction(denseInput: input.dense, sparseInput: input.sparse)
}

@differentiable(wrt: self)
public func callAsFunction(
denseInput: Tensor<Float>,
sparseInput: [Tensor<Int32>]
) -> Tensor<Float> {
precondition(denseInput.shape.last! == nDense)
precondition(sparseInput.count == latentFactors.count)
let denseEmbVec = mlpBottom(denseInput)
let sparseEmbVecs = computeEmbeddings(sparseInputs: sparseInput,
latentFactors: latentFactors)
let topInput = computeInteractions(
denseEmbVec: denseEmbVec, sparseEmbVecs: sparseEmbVecs)
let prediction = mlpTop(topInput)

// TODO: loss threshold clipping
return prediction.reshaped(to: [-1])
}

@differentiable(wrt: (denseEmbVec, sparseEmbVecs))
public func computeInteractions(
denseEmbVec: Tensor<Float>,
sparseEmbVecs: [Tensor<Float>]
) -> Tensor<Float> {
switch self.interaction {
case .concatenate:
return Tensor(concatenating: sparseEmbVecs + [denseEmbVec], alongAxis: 1)
case let .dot(selfInteraction):
let batchSize = denseEmbVec.shape[0]
let allEmbeddings = Tensor(
concatenating: sparseEmbVecs + [denseEmbVec],
alongAxis: 1).reshaped(to: [batchSize, -1, denseEmbVec.shape[1]])
// Use matmul to efficiently compute all dot products
let higherOrderInteractions = matmul(
allEmbeddings, allEmbeddings.transposed(permutation: 0, 2, 1))
// Gather relevant indices
let flattenedHigherOrderInteractions = higherOrderInteractions.reshaped(
to: [batchSize, -1])
let desiredIndices = makeIndices(
n: Int32(higherOrderInteractions.shape[1]),
selfInteraction: selfInteraction)
let desiredInteractions =
flattenedHigherOrderInteractions.batchGathering(atIndices: desiredIndices)
return Tensor(concatenating: [desiredInteractions, denseEmbVec], alongAxis: 1)
}
}
}

/// DLRMInput represents the categorical and numerical inputs to the DLRM model.
public struct DLRMInput {

/// dense represents a mini-batch of continuous inputs.
///
/// It should have shape `[batchSize, continuousCount]`.
public let dense: Tensor<Float>

/// sparse represents the categorical inputs to the mini-batch.
///
/// The array should be of length `numCategoricalInputs`.
/// Each tensor within the array should be a vector of length `batchSize`.
public let sparse: [Tensor<Int32>]
}

// Work-around for the lack of `inout` differentiation support.
fileprivate func computeEmbeddings(
sparseInputs: [Tensor<Int32>],
latentFactors: [Embedding<Float>]
) -> [Tensor<Float>] {
var sparseEmbVecs: [Tensor<Float>] = []
for i in 0..<sparseInputs.count {
sparseEmbVecs.append(latentFactors[i](sparseInputs[i]))
}
return sparseEmbVecs
}

// TODO: remove computeEmbeddingsVJP once inout differentiation is supported!
@derivative(of: computeEmbeddings)
fileprivate func computeEmbeddingsVJP(
sparseInput: [Tensor<Int32>],
latentFactors: [Embedding<Float>]
) -> (
value: [Tensor<Float>],
pullback: (Array<Tensor<Float>>.TangentVector) -> Array<Embedding<Float>>.TangentVector
) {
var sparseEmbVecs = [Tensor<Float>]()
var pullbacks = [(Tensor<Float>.TangentVector) -> Embedding<Float>.TangentVector]()
for i in 0..<sparseInput.count {
// Capture each embedding lookup's value and pullback so the custom
// derivative can route per-table cotangents in the backward pass.
let (fwd, pullback) = valueWithPullback(at: latentFactors[i]) { $0(sparseInput[i]) }
sparseEmbVecs.append(fwd)
pullbacks.append(pullback)
}
return (
value: sparseEmbVecs,
pullback: { v in
let arr = zip(v, pullbacks).map { $0.1($0.0) }
return Array.DifferentiableView(arr)
}
)
}

/// Compute indices for the upper triangle (optionally including the diagonal) in a flattened representation.
///
/// - Parameter n: Size of the square matrix.
/// - Parameter selfInteraction: Include the diagonal if and only if `selfInteraction` is true.
fileprivate func makeIndices(n: Int32, selfInteraction: Bool) -> Tensor<Int32> {
let interactionOffset: Int32
if selfInteraction {
interactionOffset = 0
} else {
interactionOffset = 1
}
var result = [Int32]()
for i in 0..<n {
for j in (i + interactionOffset)..<n {
result.append(i*n + j)
}
}
return Tensor(result)
}
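
For orientation, here is a minimal forward-pass sketch against the API added in this file (the hyperparameter values are hypothetical, not taken from this PR; the sketch uses the default `.concatenate` interaction, which is the case the initializer's `topInput` arithmetic is sized for):

import TensorFlow
import RecommendationModels

// Hypothetical sizes: 9 dense features and two categorical features with
// vocabularies of 100 and 1_000 entries, all embedded at width mSpa = 8.
let model = DLRM(
    nDense: 9, mSpa: 8, lnEmb: [100, 1_000],
    lnBot: [16, 8],   // bottom MLP: 9 -> 16 -> 8
    lnTop: [32, 16])  // top MLP: 24 -> 32 -> 16 -> 1, sigmoid output

// A batch of four examples: dense features plus one row index per table.
let input = DLRMInput(
    dense: Tensor<Float>(randomNormal: [4, 9]),
    sparse: [
        Tensor<Int32>([0, 5, 17, 42]),   // rows of the 100-entry table
        Tensor<Int32>([3, 999, 256, 7])  // rows of the 1_000-entry table
    ])
let scores = model(input)  // shape [4], values in (0, 1)

For the `.dot` path, `makeIndices(n: 3, selfInteraction: false)` would return `[1, 2, 5]`: the flattened positions of the strict upper triangle of a 3 × 3 interaction matrix.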
46 changes: 46 additions & 0 deletions Models/Recommendation/MLP.swift
@@ -0,0 +1,46 @@
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import TensorFlow

/// MLP is a multi-layer perceptron and is used as a component of the DLRM model.
public struct MLP: Layer {
public var blocks: [Dense<Float>] = []

/// Randomly initializes a new multilayer perceptron from the given hyperparameters.
///
/// - Parameter dims: The sizes of the input, hidden layers, and output of the
/// multi-layer perceptron.
/// - Parameter sigmoidLastLayer: If `true`, use a `sigmoid` activation function for the last layer,
/// `relu` otherwise.
init(dims: [Int], sigmoidLastLayer: Bool = false) {
for i in 0..<(dims.count-1) {
if sigmoidLastLayer && i == dims.count - 2 {
blocks.append(Dense(inputSize: dims[i], outputSize: dims[i+1], activation: sigmoid))
} else {
blocks.append(Dense(inputSize: dims[i], outputSize: dims[i+1], activation: relu))
}
}
}

@differentiable
public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
let blocksReduced = blocks.differentiableReduce(input) { last, layer in
layer(last)
}
return blocksReduced
}

}
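
As a sketch of the wiring (sizes hypothetical; note the initializer is internal, so callers outside the module would need `@testable import RecommendationModels`):

// dims = [9, 16, 8] builds two Dense layers: 9 -> 16 with relu and,
// because sigmoidLastLayer is true, 16 -> 8 with sigmoid.
let mlp = MLP(dims: [9, 16, 8], sigmoidLastLayer: true)
let activations = mlp(Tensor<Float>(randomNormal: [4, 9]))  // shape [4, 8]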

3 changes: 3 additions & 0 deletions Package.swift
@@ -13,6 +13,7 @@ let package = Package(
.library(name: "Datasets", targets: ["Datasets"]),
.library(name: "ModelSupport", targets: ["ModelSupport"]),
.library(name: "ImageClassificationModels", targets: ["ImageClassificationModels"]),
.library(name: "RecommendationModels", targets: ["RecommendationModels"]),
.library(name: "TextModels", targets: ["TextModels"]),
.executable(name: "Benchmarks", targets: ["Benchmarks"]),
.executable(name: "VGG-Imagewoof", targets: ["VGG-Imagewoof"]),
@@ -41,6 +42,7 @@ let package = Package(
.target(name: "ModelSupport", dependencies: ["SwiftProtobuf"], path: "Support"),
.target(name: "ImageClassificationModels", path: "Models/ImageClassification"),
.target(name: "TextModels", dependencies: ["Datasets"], path: "Models/Text"),
.target(name: "RecommendationModels", path: "Models/Recommendation"),
.target(
name: "Autoencoder1D", dependencies: ["Datasets", "ModelSupport"],
path: "Autoencoder/Autoencoder1D"),
@@ -75,6 +77,7 @@
name: "MiniGoDemo", dependencies: ["MiniGo"], path: "MiniGo", sources: ["main.swift"]),
.testTarget(name: "MiniGoTests", dependencies: ["MiniGo"]),
.testTarget(name: "ImageClassificationTests", dependencies: ["ImageClassificationModels"]),
.testTarget(name: "RecommendationModelTests", dependencies: ["RecommendationModels"]),
.testTarget(name: "DatasetsTests", dependencies: ["Datasets", "TextModels"]),
.target(
name: "TransformerDemo", dependencies: ["TextModels"],
2 changes: 2 additions & 0 deletions Tests/CMakeLists.txt
@@ -3,6 +3,7 @@ add_subdirectory(DatasetsTests)
add_subdirectory(FastStyleTransferTests)
add_subdirectory(ImageClassificationTests)
add_subdirectory(MiniGoTests)
add_subdirectory(RecommendationModelTests)
add_subdirectory(SupportTests)
add_subdirectory(TextTests)

@@ -14,6 +15,7 @@ target_link_libraries(ModelTests PRIVATE
FastStyleTransferTests
ImageClassificationTests
MiniGoTests
RecommendationModelTests
SupportTests
TextTests
XCTest)
2 changes: 2 additions & 0 deletions Tests/LinuxMain.swift
@@ -3,6 +3,7 @@ import DatasetsTests
import FastStyleTransferTests
import ImageClassificationTests
import MiniGoTests
import RecommendationModelTests
import SupportTests
import TextTests
import XCTest
@@ -13,6 +14,7 @@ tests += MiniGoTests.allTests()
tests += FastStyleTransferTests.allTests()
tests += DatasetsTests.allTests()
tests += CheckpointTests.allTests()
tests += RecommendationModelTests.allTests()
tests += SupportTests.allTests()
tests += TextTests.allTests()
XCTMain(tests)
8 changes: 8 additions & 0 deletions Tests/RecommendationModelTests/CMakeLists.txt
@@ -0,0 +1,8 @@
add_library(RecommendationModelTests
DLRMTests.swift
XCTestManifests.swift)
set_target_properties(RecommendationModelTests PROPERTIES
RUNTIME_OUTPUT_DIRECTORY $<TARGET_FILE_DIR:ModelTests>
LIBRARY_OUTPUT_DIRECTORY $<TARGET_FILE_DIR:ModelTests>)
target_link_libraries(RecommendationModelTests PUBLIC
RecommendationModels)