This repository was archived by the owner on Feb 13, 2025. It is now read-only.

Initial incorporation of a general training loop #586

Merged
14 commits merged on Jun 17, 2020

Conversation

BradLarson
Contributor

@BradLarson BradLarson commented Jun 6, 2020

This is the initial incorporation of a general callback-based training loop, originally designed by @sgugger and proposed as the DifferentiableStep option here. As a first step, the following models have been converted to use this new training loop in place of the previous custom loop:

  • LeNet-MNIST
  • ResNet-CIFAR10
  • MobileNetV1-Imagenette
  • MobileNetV2-Imagenette

An initial set of callbacks has been provided that draws an animated progress bar on the console during training and displays the average loss and top-1 classification accuracy. These metric updates can either be continuous during training and validation, or appear only at the end of an epoch (a performance option, because training currently slows by up to 30% when continuous updates are enabled). Which metrics to display, if any, is also configurable.

By default, X10 is used where available for training models, and this loop fully supports X10 or eager mode devices.
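For orientation, here is a rough sketch of what a converted example looks like with this loop. It is assembled from the diff excerpts quoted later in this thread and reflects the fit(&model, on:) form the loop settles on during review; the module names, the TrainingProgress type, and the exact parameter labels are assumptions rather than the final API.

import Datasets
import ImageClassificationModels
import TensorFlow
import TrainingLoop

let epochCount = 12
let batchSize = 128
let device = Device.defaultXLA

// The dataset must yield tensors on the same device the loop trains on.
let dataset = MNIST(batchSize: batchSize, on: device)
var model = LeNet()
let optimizer = SGD(for: model, learningRate: 0.1)

// Draws the animated console progress bar and reports average loss and accuracy.
var trainingProgress = TrainingProgress()
var trainingLoop = TrainingLoop(
  training: dataset.training,
  validation: dataset.validation,
  optimizer: optimizer,
  lossFunction: softmaxCrossEntropy,
  callbacks: [trainingProgress.update])

try! trainingLoop.fit(&model, epochs: epochCount, on: device)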

As a next step, all but one or two of the classification examples will be reworked to use this loop, and timing functionality will be added so that it can become the default loop within our benchmarks.

This pull request is now ready for review.

@BradLarson BradLarson marked this pull request as ready for review June 9, 2020 02:03
@BradLarson BradLarson requested a review from saeta June 9, 2020 02:03
@shabalind shabalind requested review from xihui-wu and dabrahams June 10, 2020 17:31
@saeta saeta self-assigned this Jun 10, 2020
Contributor

@saeta saeta left a comment


Thank you @BradLarson for getting this together. Looking at the diff for Examples/LeNet-MNIST/main.swift shows how much of a simplification and cleanup this will be. I'm so very excited. :-)

var model = ResNet(classCount: 10, depth: .resNet56, downsamplingInFirstStage: false)
model.move(to: device)
Contributor


Sorry I've taken so long to respond to this PR. In general, I really like the design that you @BradLarson, @sgugger and @dabrahams have put together. But the one thing that makes me unhappy is this: it seems way too easy to accidentally re-use the original model variable and get the un-trained model after calling fit. I've been meaning to play around with a few different alternative arrangements for the training loop API, but rather than delay responding any longer, I'll write some sketches here (and then it can be a race to see who gets to implement them first!) and/or folks can tell me why they're silly ideas. :-D

Alternative 1: Capture the model building process within the construction of the training loop. I'm imagining something like:

var dataset = // ...
var trainingLoop = TrainingLoop(training: dataset) { dataset in
  let model = Model(dataset.inputHeight, dataset.inputWidth)  // Or something...
  let optimizer = SGD(for: model)
  return (model, optimizer)
}

Advantages of something in this direction: the code is straightforward, type inference works nicely, and the order of operations is "right". (Models don't exist in a vacuum, but are instead applied to data, and this pattern allows users to easily derive hyperparameters for the model from the data.)

Things I don't like about this approach: (it's hard to put my finger on it, so I'm saying wishy-washy things here) if the TrainingLoop owns the model (instead of the user's code), this feels more like a framework than a library. (Of course, a built-in training loop is a framework in the sense that it's inversion of control. But I think we should be careful about taking only the "minimal" amount of control.)

Alternative 2: Take the model as inout (or use a mutating method). In this direction, I'm thinking something like:

try! trainingLoop.fit(&model, epochs: 10)

We don't even need to take model in the TrainingLoop initializer (even for type inference purposes), because we can get the model type from the optimizer's associated type. (Aside: right now optimizers are reference types, but I think we should revisit that design choice. When doing hierarchical learning or certain forms of meta-learning, you want to take derivatives through the optimizer too.)

Of course, an alternate variation of this could be model.train(using: &trainingLoop, epochs: 10) (or even model.train(using: trainingLoop, epochs: 10) if we decide that training loops should be reference types. I haven't thought this through, so this could be a silly idea). I think I somewhat lean against this, because it "reads backwards" in some sense, but wanted to mention it for completeness.

One thing to note: we will have to change the callback signature to take the model as inout in addition to the training loop itself, due to the training loop struct no longer having a copy of the model.
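To make that signature change concrete, here is a hypothetical sketch of the direction (not the API in this PR): the model type comes from the optimizer's associated Model type, and each callback receives both the loop and the model as inout. The event enum and all names here are invented for illustration.

import TensorFlow

// Hypothetical events at which callbacks would be invoked.
enum TrainingLoopEvent {
  case fitStart, epochStart, epochEnd, fitEnd
}

struct TrainingLoop<Opt: Optimizer> {
  // The model type is recovered from the optimizer, so the initializer
  // never needs to take (or store) the model itself.
  typealias Model = Opt.Model
  // Callbacks receive the loop and the model as inout, plus the current event.
  typealias Callback = (inout TrainingLoop, inout Model, TrainingLoopEvent) throws -> Void

  var optimizer: Opt
  var callbacks: [Callback] = []

  mutating func fit(_ model: inout Model, epochs: Int) throws {
    try invoke(&model, .fitStart)
    for _ in 0..<epochs {
      try invoke(&model, .epochStart)
      // ... the training and validation batches would run here ...
      try invoke(&model, .epochEnd)
    }
    try invoke(&model, .fitEnd)
  }

  private mutating func invoke(_ model: inout Model, _ event: TrainingLoopEvent) throws {
    for callback in callbacks {
      try callback(&self, &model, event)
    }
  }
}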

I know that @dabrahams and @sgugger thought about things for a while, and perhaps they have already explored these alternatives and have good reasons why these are silly ideas...

Contributor Author


That's a really good point, I'd totally missed that in the design. A training loop is somewhat useless if the model you have access to never changes.

I've changed the loop to thread the model through the fit() function instead of being a property of the TrainingLoop (your Alternative 2). This works fairly well, and makes sense to me when I see it in the model examples. I've verified that the local model ends up trained after fit() is done by running validation and running another fit() with model.

Contributor


I agree it's better this way. And no, @saeta, it was not something that had been thought through a lot ;-) I did suggest passing a modelInit instead of a model, which was your first option, but this is better.

Note, @BradLarson, that the model is no longer accessible in the callbacks, though, so it should either be stored in the training loop's temporary data (like lastLoss and the others) or passed along inout to each callback. I'd personally prefer the first way, but the second should work too.

Contributor


If the primary thing happening in the training loop is updating the model, shouldn't the interface be a mutating method on the model?

Contributor


If the primary thing happening in the training loop is updating the model, shouldn't the interface be a mutating method on the model?

Yeah, I was wondering about that previously (e.g. when I suggested model.train(using: &trainingLoop, epochs: 10)). (I said it "felt backwards", but perhaps I'm too used to the current design?) I think we should think carefully about whether we should call this model.fit instead of my proposed train. fit is well-known from Keras and Scikit-Learn, which is both good and bad. It's good in the sense that it's familiar to people, but my concern is that our proposed functionality is something quite different from (and more powerful than) Keras or Scikit-Learn.

the model is not accessible anymore in the callbacks though

@sgugger is absolutely right that we definitely need to solve this. If we store it in a temporary in the training loop, we should be careful to avoid forcing a copy. (We can make the model accessible from the training loop data structure by using withUnsafePointer(to:_:) if we don't want to pass it inout to each callback.)
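As a minimal sketch of that last idea (all of the names here are hypothetical): during fit, the loop exposes a pointer to the caller's model, so callbacks can reach the current model without the loop ever storing a copy of it.

struct TrainingLoop<Model> {
  // Valid only while fit() is running; the loop never stores the model itself.
  private(set) var modelPointer: UnsafeMutablePointer<Model>? = nil

  mutating func fit(_ model: inout Model, epochs: Int) {
    withUnsafeMutablePointer(to: &model) { pointer in
      modelPointer = pointer
      defer { modelPointer = nil }
      // ... run the epochs here; callbacks read (or mutate) the model
      // through modelPointer?.pointee instead of a stored copy ...
    }
  }
}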

As I mentioned previously, I do think that we should revisit our design of optimizers as reference types. Assuming we do switch optimizers to value types, we should ensure our proposed API works naturally for those as well. (A future extension could be making the training loop itself differentiable, but I think that would put undue constraints on the training loop for a somewhat niche meta-learning use case. Of course, this just underscores the need to ensure that S4TF is easy to use without the built-in training loop.)

Contributor


Keras, Scikit-Learn and SparkML all use the model.fit(dataset, training settings...) pattern. Literally, that means training to fit the model to the dataset with the given training settings.

So IMO: if we want to keep using 'fit', then do model.fit(using: trainingLoop); if we want to use 'train', then trainingLoop.train(model: &model). Also, why not incorporate 'epochs' into TrainingLoop?


let epochCount = 12
let batchSize = 128

- // Until https://github.com/tensorflow/swift-models/issues/588 is fixed, default to the eager-mode
+ // Until https://github.com/tensorflow/swift-apis/issues/993 is fixed, default to the eager-mode
Contributor


Aside: I think we should eventually move this behavior into a (default) training loop callback so users don't ever get tripped up on this. (Of course, the right long-term thing is to fix the underlying bug, but it's not the highest priority for now.)
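For reference, the per-example workaround being referred to looks roughly like this; Device.defaultTFEager and Device.defaultXLA are the swift-apis accessors, while the overall shape is a sketch rather than the exact code in this PR.

import TensorFlow

#if os(macOS)
// Until the underlying X10 issue is fixed, fall back to the eager-mode device on macOS.
let device = Device.defaultTFEager
#else
let device = Device.defaultXLA
#endif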

Contributor Author

@BradLarson BradLarson Jun 12, 2020


I've moved the migration of the model and optimizer onto a device into the training loop, and added an on: parameter to the fit() function, because it makes sense for where training occurs to be something associated with the loop. This further simplifies the user-facing training code.

However, the one potential complication this introduces is that the dataset needs to provide tensors on the same device that the model and optimizer will reside on, or you'll hit a runtime error. For the classification datasets, the device is provided as a parameter, but this is something the user could forget. The situation with these latest changes is better than it was, because users only need to specify the device on the dataset and the loop, rather than also remembering to copy the model and optimizer, but there might be more we can do to help.

As for the workaround of using XLA devices only on non-macOS platforms, we might still want that to be done on a per-model basis, because not all models trigger a crash on macOS with X10, and several are even faster with X10 on macOS. I've left that outside of the loop for now.

Contributor


So I moved the migration of the model and optimizer onto a device into the training loop, and added an on: parameter to the fit() function

Note: I think the ideal situation we should aim for is for users to not have to specify the device, and for the training loop to make the right thing happen automatically by default. (Of course, we should absolutely allow users to specify a device if they would like.) That said, I'm okay with this for now to let us continue to iterate & experiment.

Contributor


Note: I think the ideal situation we should aim for is for users to not have to specify the device

@saeta Sorry, I have a newbie question - would this allow for distributed/multi-GPU/TPU training or even hybrid (some on CPU, some on accelerators, if that’s a thing)? Or maybe it can be a separate option for production workloads, so the devops is not a concern for ML/data scientists 🤔

Contributor


That's a great question @8bitmp3 , thanks for raising it. I had previously only been thinking about how we could auto-detect CPU, local GPU(s), and TPUs (both local & slices), which only need local information to auto-detect. But you're absolutely right that distributed GPU requires some additional information to be provided in order to get this to work, so this would be insufficient. Thanks for raising this design question!

Contributor Author


@8bitmp3 - The multi-device training case is something that this definitely should encapsulate at some point, and @xihui-wu has some ideas about that. This initial design is oriented towards a single device, but it's our intention to iterate on that to eventually add multi-device support. We currently don't have multi-GPU support in the XLA backend (but it's being worked on) and have examples of multi-device training support for TPUs within swift-apis.

I like the idea of the standard training loop encapsulating best practices with automatic device selection, as long as we can easily override that device. Our current situation on macOS, where X10 fails for many models but is better for some, complicates automatic device selection.

Contributor

@8bitmp3 8bitmp3 left a comment


Thanks @BradLarson. One question from me below if that's Ok.

validation: dataset.validation,
optimizer: optimizer,
lossFunction: softmaxCrossEntropy,
callbacks: [trainingProgress.update])
Contributor


Thanks @BradLarson. This is a sight for sore eyes: more user-friendly and "Keras"-like. Although I do appreciate the longer, "original" training loops too.


- let dataset = Imagewoof(batchSize: batchSize, inputSize: .full, outputSize: 224)
+ let dataset = Imagewoof(batchSize: 32, inputSize: .resized320, outputSize: 224, on: device)
Contributor


@BradLarson Can you help understand why the batch size is hardcoded here? Great stuff btw ⚡️

Contributor Author


I just moved this from being defined on a line above to being provided at the call site. I figured it simplified the code while maintaining the same meaning as before. It seems more straightforward to me to have Imagewoof(batchSize: 32... than let batchSize = 32; Imagewoof(batchSize: batchSize...

If we had a number of other parameters that someone might want to configure, I can see grouping them at the top as variables, but in this case we just had batch size and epochs, so I was trying to simplify the code as much as possible.

What I'd really like to do is to combine the various classification examples into one central executable that uses ArgumentParser and command-line options to let you combine various permutations of models and datasets, with different parameters, so we don't have individual examples for each. Maybe one example to show a custom model and custom training loop, one to show a really simplified model and the standard training loop, and then the Swiss-army-knife program for training all classification models and datasets with arbitrary parameters.
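As a sketch of what that Swiss-army-knife executable could look like using swift-argument-parser (every name, option, and default below is hypothetical):

import ArgumentParser

// Hypothetical command-line front end for training any classification
// model/dataset combination with the standard training loop.
struct Classify: ParsableCommand {
  @Option(help: "The model to train, e.g. lenet, resnet56, mobilenetv2.")
  var model: String = "lenet"

  @Option(help: "The dataset to train on, e.g. mnist, cifar10, imagenette.")
  var dataset: String = "mnist"

  @Option(help: "The batch size to use.")
  var batchSize: Int = 128

  @Option(help: "The number of epochs to train for.")
  var epochs: Int = 12

  @Flag(help: "Train on the eager-mode device instead of X10.")
  var eager = false

  mutating func run() throws {
    // Look up the requested model and dataset here, build a TrainingLoop with
    // the standard callbacks, and call fit() just as the individual examples do.
    print("Training \(model) on \(dataset) for \(epochs) epochs (batch size \(batchSize)).")
  }
}

Classify.main()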
