This repository was archived by the owner on Feb 13, 2025. It is now read-only.
[WIP] Add Attention is All you need transformer and Translation example #422
Closed
Changes from 23 commits
Commits (35)
All commits authored by andr-ec:
0ddb147 added base translation model and package
847480b WIP for tokenizers
b1a20dd added full text preprocessing and started creating training loop
4d78760 working attention
7793186 working model on forward pass
7950233 working forwards pass
046b55e cleaning up code
7488aed comments
12f3975 working training loop
398a416 updated training step
101a444 to gpu
cac0e88 removed python import
77554c5 added foundation import
312bfda fixed import in wrong file
e7d38bb reduced batch size
4345737 added package to allow import of translation models
6e0930d updated batch and sequence length to defualts
958531e updated learning rate to that in paper
444f53d made required methods public, fixed vocab to lookup correct values
d959992 added requirements for greedy decoding
0f9039e working greedy decoding, working ignoreIndex for padding, training lo…
d6e1a57 moved custom crossentropy to utilities
819e1ec made softmax public
3c3b674 cleaned up comments and code organization
f7ba238 formatting
a2c6787 added validation loop
9234b64 reformatted to use dataset helpers, much more effecient with memory a…
ee0e3dd organized project structure and started using existing vocab
0366a69 fix vocabulary loading and imports
456bece moved extensions, added <unk> token, added decode function
c92bbdd fixing encoding
1250800 added initialization to many params
0f018bc fixing initializations
2731c83 added init to activations
9ac1672 removed init from attention
@@ -0,0 +1,134 @@
//
// main.swift
// TranslationTransformer
//
// Created by Andre Carrera on 2/7/20.
// Copyright © 2020 Lambdo. All rights reserved.
//

import TensorFlow
import TranslationModels
import Foundation

struct WMTTranslationTask {
    // https://nlp.stanford.edu/projects/nmt/
    // WMT'14 English-German data
    private let trainGermanURL = URL(string: "https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/train.de")!
    private let trainEnglishURL = URL(string: "https://nlp.stanford.edu/projects/nmt/data/wmt14.en-de/train.en")!
    let directoryURL: URL
    var trainData: [TextBatch]
    var textProcessor: TextProcessor
    var sourceVocabSize: Int {
        textProcessor.sourceVocabulary.count
    }
    var targetVocabSize: Int {
        textProcessor.targetVocabulary.count
    }
    var trainDataSize: Int {
        trainData.count
    }

    init(taskDirectoryURL: URL, maxSequenceLength: Int, batchSize: Int) throws {
        self.directoryURL = taskDirectoryURL.appendingPathComponent("Translation")
        let dataURL = directoryURL.appendingPathComponent("data")

        let trainGermanDataPath = dataURL.appendingPathExtension("source")
        let trainEnglishDataPath = dataURL.appendingPathExtension("target")
        print("downloading datasets")
        try maybeDownload(from: trainGermanURL, to: trainGermanDataPath)
        try maybeDownload(from: trainEnglishURL, to: trainEnglishDataPath)
        print("loading datasets")
        let loadedGerman = try WMTTranslationTask.load(fromFile: trainGermanDataPath)
        let loadedEnglish = try WMTTranslationTask.load(fromFile: trainEnglishDataPath)

        let tokenizer = BasicTokenizer()

        print("preprocessing dataset")
        self.textProcessor = TextProcessor(tokenizer: tokenizer, sourceVocabulary: .init(), targetVocabulary: .init())
        self.trainData = textProcessor.preprocess(source: loadedGerman, target: loadedEnglish, maxSequenceLength: maxSequenceLength, batchSize: batchSize)
    }

    func getTrainIterator() -> IndexingIterator<[TextBatch]> {
        self.trainData.shuffled().makeIterator()
    }

    static func load(fromFile fileURL: URL) throws -> [String] {
        try Data(contentsOf: fileURL).withUnsafeBytes {
            $0.split(separator: UInt8(ascii: "\n"))
                .map { String(decoding: UnsafeRawBufferPointer(rebasing: $0), as: UTF8.self) }
        }
    }

    mutating func update(model: inout TransformerModel, using optimizer: inout Adam<TransformerModel>, for batch: TextBatch) -> Float {
        let labels = batch.targetTruth.reshaped(to: [-1])
        let resultSize = batch.targetTruth.shape.last! * batch.targetTruth.shape.first!
        let padIndex = Int32(textProcessor.targetVocabulary.id(forToken: "<blank>")!)
        let result = withLearningPhase(.training) { () -> Float in
            let (loss, grad) = valueWithGradient(at: model) {
                softmaxCrossEntropy(logits: $0.generate(input: batch).reshaped(to: [resultSize, -1]), labels: labels, ignoreIndex: padIndex)
            }
            optimizer.update(&model, along: grad)
            return loss.scalarized()
        }
        return result
    }
}

let workspaceURL = URL(fileURLWithPath: "transformer", isDirectory: true,
                       relativeTo: URL(fileURLWithPath: NSTemporaryDirectory(),
                                       isDirectory: true))

var translationTask = try WMTTranslationTask(taskDirectoryURL: workspaceURL, maxSequenceLength: 50, batchSize: 150)

var model = TransformerModel(sourceVocabSize: translationTask.sourceVocabSize, targetVocabSize: translationTask.targetVocabSize)

func greedyDecode(model: TransformerModel, input: TextBatch, maxLength: Int, startSymbol: Int32) -> Tensor<Int32> {
    let memory = model.encode(input: input)
    var ys = Tensor(repeating: startSymbol, shape: [1, 1])
    for _ in 0..<maxLength {
        let decoderInput = TextBatch(tokenIds: input.tokenIds,
                                     targetTokenIds: ys,
                                     mask: input.mask,
                                     targetMask: Tensor<Float>(subsequentMask(size: ys.shape[1])),
                                     targetTruth: input.targetTruth,
                                     tokenCount: input.tokenCount)
        let out = model.decode(input: decoderInput, memory: memory)
        let prob = model.generate(input: out[0..., -1])
        let nextWord = Int32(prob.argmax().scalarized())
        ys = Tensor(concatenating: [ys, Tensor(repeating: nextWord, shape: [1, 1])], alongAxis: 1)
    }
    return ys
}

let epochs = 3
var optimizer = Adam(for: model, learningRate: 5e-4)
for epoch in 0..<epochs {
    print("Start epoch \(epoch)")
    var iterator = translationTask.getTrainIterator()
    for step in 0..<translationTask.trainData.count {
        let batch = withDevice(.cpu) { iterator.next()! }
        let loss = translationTask.update(model: &model, using: &optimizer, for: batch)
        print("current loss at step \(step): \(loss)")
    }
}

// Test

let batch = translationTask.trainData[0]
let exampleIndex = 1
let source = TextBatch(tokenIds: batch.tokenIds[exampleIndex].expandingShape(at: 0),
                       targetTokenIds: batch.targetTokenIds[exampleIndex].expandingShape(at: 0),
                       mask: batch.mask[exampleIndex].expandingShape(at: 0),
                       targetMask: batch.targetMask[exampleIndex].expandingShape(at: 0),
                       targetTruth: batch.targetTruth[exampleIndex].expandingShape(at: 0),
                       tokenCount: batch.tokenCount)
let startSymbol = Int32(translationTask.textProcessor.targetVocabulary.id(forToken: "<s>")!)

Context.local.learningPhase = .inference

let out = greedyDecode(model: model, input: source, maxLength: 50, startSymbol: startSymbol)

func decode(tensor: Tensor<Float>, vocab: Vocabulary) -> String {
    tensor.scalars.compactMap { vocab.token(forId: Int($0)) }.joined(separator: " ")
    // TODO: use a loop and break on </s>
}
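The file defines `decode(tensor:vocab:)` but never calls it. A minimal usage sketch, not part of the diff, that prints the greedy-decode output, assuming the `decode`, `Vocabulary`, and `greedyDecode` APIs shown above; `decodeUntilEnd` is a hypothetical variant implementing the TODO of stopping at `</s>`:

// Usage sketch (not part of the diff): turn the greedy-decode output back into text.
// The cast to Tensor<Float> is needed because greedyDecode returns Tensor<Int32>.
let translated = decode(tensor: Tensor<Float>(out), vocab: translationTask.textProcessor.targetVocabulary)
print("greedy translation: \(translated)")

// Hypothetical variant implementing the TODO above: stop at the first end-of-sentence
// token instead of decoding padding and everything after it.
func decodeUntilEnd(tensor: Tensor<Float>, vocab: Vocabulary) -> String {
    var words: [String] = []
    for id in tensor.scalars {
        guard let token = vocab.token(forId: Int(id)) else { continue }
        if token == "</s>" { break }
        words.append(token)
    }
    return words.joined(separator: " ")
}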
Review comment: WMT == 💪!