This repository was archived by the owner on Feb 13, 2025. It is now read-only.

[WIP] Add "Attention Is All You Need" transformer and translation example #422

Closed · wants to merge 35 commits

Commits
0ddb147
added base translation model and package
andr-ec Feb 18, 2020
847480b
WIP for tokenizers
andr-ec Mar 13, 2020
b1a20dd
added full text preprocessing and started creating training loop
andr-ec Mar 13, 2020
4d78760
working attention
andr-ec Mar 17, 2020
7793186
working model on forward pass
andr-ec Mar 17, 2020
7950233
working forwards pass
andr-ec Mar 18, 2020
046b55e
cleaning up code
andr-ec Mar 18, 2020
7488aed
comments
andr-ec Mar 18, 2020
12f3975
working training loop
andr-ec Mar 20, 2020
398a416
updated training step
andr-ec Mar 20, 2020
101a444
to gpu
andr-ec Mar 20, 2020
cac0e88
removed python import
andr-ec Mar 20, 2020
77554c5
added foundation import
andr-ec Mar 20, 2020
312bfda
fixed import in wrong file
andr-ec Mar 20, 2020
e7d38bb
reduced batch size
andr-ec Mar 20, 2020
4345737
added package to allow import of translation models
andr-ec Mar 20, 2020
6e0930d
updated batch and sequence length to defaults
andr-ec Mar 20, 2020
958531e
updated learning rate to that in paper
andr-ec Mar 20, 2020
444f53d
made required methods public, fixed vocab to lookup correct values
andr-ec Mar 20, 2020
d959992
added requirements for greedy decoding
andr-ec Mar 20, 2020
0f9039e
working greedy decoding, working ignoreIndex for padding, training lo…
andr-ec Mar 25, 2020
d6e1a57
moved custom crossentropy to utilities
andr-ec Mar 26, 2020
819e1ec
made softmax public
andr-ec Mar 26, 2020
3c3b674
cleaned up comments and code organization
andr-ec Mar 26, 2020
f7ba238
formatting
andr-ec Mar 26, 2020
a2c6787
added validation loop
andr-ec Mar 26, 2020
9234b64
reformatted to use dataset helpers, much more efficient with memory a…
andr-ec Mar 27, 2020
ee0e3dd
organized project structure and started using existing vocab
andr-ec Mar 27, 2020
0366a69
fix vocabulary loading and imports
andr-ec Mar 27, 2020
456bece
moved extensions, added <unk> token, added decode function
andr-ec Mar 27, 2020
c92bbdd
fixing encoding
andr-ec Mar 27, 2020
1250800
added initialization to many params
andr-ec Mar 27, 2020
0f018bc
fixing initializations
andr-ec Mar 27, 2020
2731c83
added init to activations
andr-ec Mar 27, 2020
9ac1672
removed init from attention
andr-ec Mar 27, 2020
134 changes: 134 additions & 0 deletions Examples/Transformer-Translation/main.swift
@@ -0,0 +1,134 @@
//
// main.swift
// TranslationTransformer
//
// Created by Andre Carrera on 2/7/20.
// Copyright © 2020 Lambdo. All rights reserved.
//


import TensorFlow
import TranslationModels
import Foundation
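
// This example leans on maybeDownload(from:to:) from the repository's shared
// support code. A minimal sketch of that helper's contract (download once,
// then reuse the cached file), assuming a synchronous Data(contentsOf:) fetch
// is acceptable for these corpus files; the real helper may differ:
func maybeDownload(from url: URL, to destination: URL) throws {
    // Skip the network round trip if an earlier run already saved the file.
    guard !FileManager.default.fileExists(atPath: destination.path) else { return }
    try FileManager.default.createDirectory(
        at: destination.deletingLastPathComponent(),
        withIntermediateDirectories: true)
    try Data(contentsOf: url).write(to: destination)
}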

struct WMTTranslationTask {
    // https://nlp.stanford.edu/projects/nmt/
    // WMT'14 English-German data
Contributor comment: WMT == 💪!

    private let trainGermanURL = URL(string: "https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/train.de")!
    private let trainEnglishURL = URL(string: "https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/train.en")!
    let directoryURL: URL
    var trainData: [TextBatch]
    var textProcessor: TextProcessor
    var sourceVocabSize: Int {
        textProcessor.sourceVocabulary.count
    }
    var targetVocabSize: Int {
        textProcessor.targetVocabulary.count
    }
    var trainDataSize: Int {
        trainData.count
    }
    init(taskDirectoryURL: URL, maxSequenceLength: Int, batchSize: Int) throws {
        self.directoryURL = taskDirectoryURL.appendingPathComponent("Translation")
        let dataURL = directoryURL.appendingPathComponent("data")

        let trainGermanDataPath = dataURL.appendingPathExtension("source")
        let trainEnglishDataPath = dataURL.appendingPathExtension("target")
        print("downloading datasets")
        try maybeDownload(from: trainGermanURL, to: trainGermanDataPath)
        try maybeDownload(from: trainEnglishURL, to: trainEnglishDataPath)
        print("loading datasets")
        let loadedGerman = try WMTTranslationTask.load(fromFile: trainGermanDataPath)
        let loadedEnglish = try WMTTranslationTask.load(fromFile: trainEnglishDataPath)

        let tokenizer = BasicTokenizer()

        print("preprocessing dataset")
        self.textProcessor = TextProcessor(tokenizer: tokenizer, sourceVocabulary: .init(), targetVocabulary: .init())
        self.trainData = textProcessor.preprocess(source: loadedGerman, target: loadedEnglish, maxSequenceLength: maxSequenceLength, batchSize: batchSize)
    }

    func getTrainIterator() -> IndexingIterator<[TextBatch]> {
        self.trainData.shuffled().makeIterator()
    }

    static func load(fromFile fileURL: URL) throws -> [String] {
        try Data(contentsOf: fileURL).withUnsafeBytes {
            $0.split(separator: UInt8(ascii: "\n"))
                .map { String(decoding: UnsafeRawBufferPointer(rebasing: $0), as: UTF8.self) }
        }
    }

    mutating func update(model: inout TransformerModel, using optimizer: inout Adam<TransformerModel>, for batch: TextBatch) -> Float {
        let labels = batch.targetTruth.reshaped(to: [-1])
        let resultSize = batch.targetTruth.shape.last! * batch.targetTruth.shape.first!
        let padIndex = Int32(textProcessor.targetVocabulary.id(forToken: "<blank>")!)
        let result = withLearningPhase(.training) { () -> Float in
            let (loss, grad) = valueWithGradient(at: model) {
                softmaxCrossEntropy(logits: $0.generate(input: batch).reshaped(to: [resultSize, -1]), labels: labels, ignoreIndex: padIndex)
            }
            optimizer.update(&model, along: grad)
            return loss.scalarized()
        }
        return result
    }
}
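
// update(model:using:for:) above calls a softmaxCrossEntropy overload that
// takes an ignoreIndex, so that <blank> padding positions contribute nothing
// to the loss; the commits move that custom loss into Utilities. This is only
// a sketch of the idea, not the PR's exact implementation:
@differentiable(wrt: logits)
func softmaxCrossEntropy(logits: Tensor<Float>, labels: Tensor<Int32>,
                         ignoreIndex: Int32) -> Tensor<Float> {
    // Negative log-likelihood of the true token at every position.
    let logProbs = logSoftmax(logits)
    let oneHot = Tensor<Float>(oneHotAtIndices: labels, depth: logits.shape[1])
    let perPosition = -(logProbs * oneHot).sum(squeezingAxes: -1)
    // Zero out padding positions, then average over the positions that remain.
    let mask = Tensor<Float>(labels .!= ignoreIndex)
    return (perPosition * mask).sum() / mask.sum()
}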

let workspaceURL = URL(fileURLWithPath: "transformer", isDirectory: true,
                       relativeTo: URL(fileURLWithPath: NSTemporaryDirectory(),
                                       isDirectory: true))

var translationTask = try WMTTranslationTask(taskDirectoryURL: workspaceURL, maxSequenceLength: 50, batchSize: 150)

var model = TransformerModel(sourceVocabSize: translationTask.sourceVocabSize, targetVocabSize: translationTask.targetVocabSize)

func greedyDecode(model: TransformerModel, input: TextBatch, maxLength: Int, startSymbol: Int32) -> Tensor<Int32> {
    let memory = model.encode(input: input)
    var ys = Tensor(repeating: startSymbol, shape: [1, 1])
    for _ in 0..<maxLength {
        let decoderInput = TextBatch(tokenIds: input.tokenIds,
                                     targetTokenIds: ys,
                                     mask: input.mask,
                                     targetMask: Tensor<Float>(subsequentMask(size: ys.shape[1])),
                                     targetTruth: input.targetTruth,
                                     tokenCount: input.tokenCount)
        let out = model.decode(input: decoderInput, memory: memory)
        let prob = model.generate(input: out[0..., -1])
        let nextWord = Int32(prob.argmax().scalarized())
        ys = Tensor(concatenating: [ys, Tensor(repeating: nextWord, shape: [1, 1])], alongAxis: 1)
    }
    return ys
}
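
// greedyDecode builds the target mask from subsequentMask(size:) in
// TranslationModels: the standard causal mask that keeps position i from
// attending to anything after i. A sketch of that contract, assuming a
// [1, size, size] lower-triangular mask of ones:
func subsequentMask(size: Int) -> Tensor<Int32> {
    // Keeping every subdiagonal (-1) and no superdiagonals leaves a
    // lower-triangular matrix of ones.
    Tensor<Int32>(ones: [1, size, size])
        .bandPart(subdiagonalCount: -1, superdiagonalCount: 0)
}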

let epochs = 3
var optimizer = Adam(for: model, learningRate: 5e-4)
for epoch in 0..<epochs {
    print("Start epoch \(epoch)")
    var iterator = translationTask.getTrainIterator()
    for step in 0..<translationTask.trainData.count {
        let batch = withDevice(.cpu) { iterator.next()! }
        let loss = translationTask.update(model: &model, using: &optimizer, for: batch)
        print("current loss at step \(step): \(loss)")
    }
}
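
// The commit history says the learning rate was "updated to that in paper";
// the paper pairs Adam with warmup followed by decay,
// lr = dModel^-0.5 * min(step^-0.5, step * warmupSteps^-1.5), rather than the
// fixed 5e-4 above. A hypothetical drop-in, not part of this diff:
func noamLearningRate(step: Int, dModel: Int = 512, warmupSteps: Int = 4000) -> Float {
    let s = Float(max(step, 1))
    let w = Float(warmupSteps)
    return (1 / Float(dModel).squareRoot()) * min(1 / s.squareRoot(), s / (w * w.squareRoot()))
}
// One could then set optimizer.learningRate = noamLearningRate(step: globalStep)
// at the top of each training step, with globalStep counting steps across epochs.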

// Test

let batch = translationTask.trainData[0]
let exampleIndex = 1
let source = TextBatch(tokenIds: batch.tokenIds[exampleIndex].expandingShape(at: 0),
                       targetTokenIds: batch.targetTokenIds[exampleIndex].expandingShape(at: 0),
                       mask: batch.mask[exampleIndex].expandingShape(at: 0),
                       targetMask: batch.targetMask[exampleIndex].expandingShape(at: 0),
                       targetTruth: batch.targetTruth[exampleIndex].expandingShape(at: 0),
                       tokenCount: batch.tokenCount)
let startSymbol = Int32(translationTask.textProcessor.targetVocabulary.id(forToken: "<s>")!)

Context.local.learningPhase = .inference

let out = greedyDecode(model: model, input: source, maxLength: 50, startSymbol: startSymbol)

func decode(tensor: Tensor<Float>, vocab: Vocabulary) -> String {
    // TODO: use a loop and break on </s>
    tensor.scalars.compactMap { vocab.token(forId: Int($0)) }.joined(separator: " ")
}
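
// The TODO above asks for decoding that stops at the end-of-sentence marker.
// A sketch of that variant, assuming "</s>" closes sentences in this vocabulary:
func decodeStoppingAtEnd(tensor: Tensor<Float>, vocab: Vocabulary) -> String {
    var words = [String]()
    for scalar in tensor.scalars {
        // Skip ids missing from the vocabulary; stop at the end marker.
        guard let token = vocab.token(forId: Int(scalar)) else { continue }
        if token == "</s>" { break }
        words.append(token)
    }
    return words.joined(separator: " ")
}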