Commit c80d0e2

[quantization] Initial commit (#280)
This example shows how to perform weight quantization on models trained in TensorFlow.js and the effects of the weight quantization on inference accuracy.
1 parent 4d74bec commit c80d0e2

22 files changed (+5736 −1 lines)

README.md (+12 −1)
```diff
@@ -91,7 +91,7 @@ to another project.
     <td></td>
     <td>Building a tf.data.Dataset using a generator</td>
     <td>Regression</td>
-    <td>Multilayer perceptron</td>
+    <td>Browser</td>
     <td>Browser</td>
     <td>Layers</td>
     <td></td>
@@ -250,6 +250,17 @@ to another project.
     <td>Core (Ops)</td>
     <td></td>
   </tr>
+  <tr>
+    <td><a href="./quantization">quantization</a></td>
+    <td></td>
+    <td>Various</td>
+    <td>Demonstrates the effect of post-training weight quantization</td>
+    <td>Various</td>
+    <td>Node.js</td>
+    <td>Node.js</td>
+    <td>Layers</td>
+    <td></td>
+  </tr>
   <tr>
     <td><a href="./sentiment">sentiment</a></td>
     <td><a href="https://storage.googleapis.com/tfjs-examples/sentiment/dist/index.html">🔗</a></td>
```

quantization/.babelrc (+18)

```json
{
  "presets": [
    [
      "env",
      {
        "esmodules": false,
        "targets": {
          "browsers": [
            "> 3%"
          ]
        }
      }
    ]
  ],
  "plugins": [
    "transform-runtime"
  ]
}
```

quantization/.gitignore (+4)

```
data-fashion-mnist/
data-mnist/
imagenet-1000-samples/
models/
```

quantization/README.md (+150)

# TensorFlow.js Example: Effects of Post-Training Weight Quantization

Post-training quantization is a model-size-reducing technique that is useful
for deploying models on the web and in storage-limited environments such as
mobile devices. TensorFlow.js's
[converter module](https://github.com/tensorflow/tfjs-converter)
supports reducing the numeric precision of weights to 16-bit and 8-bit
integers after model training is complete, which leads to approximately
50% and 75% reductions in model size, respectively.

The following figure provides an intuitive sense of the degree to which
weight values are discretized under the 16- and 8-bit quantization regimes.
The figure is based on a zoomed-in view of a sinusoidal wave.

![Weight quantization: 16-bit and 8-bit](./quantization.png)

This example focuses on how such quantization of the weights affects the
model's prediction accuracy.
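To make the discretization concrete, here is a minimal plain-JavaScript sketch of min/max uniform quantization followed by dequantization. It only illustrates the idea; the function name and weight values are made up, and this is not the converter's actual implementation.

```javascript
// Sketch of uniform min/max quantization of a weight array: 2^bits
// discrete levels span the [min, max] range of the weights; each weight
// is rounded to the nearest level, then mapped back to a float.
function quantizeDequantize(weights, bits) {
  const min = Math.min(...weights);
  const max = Math.max(...weights);
  const numLevels = 2 ** bits - 1;  // 255 for 8-bit, 65535 for 16-bit
  const scale = (max - min) / numLevels;
  return weights.map(w => {
    const code = Math.round((w - min) / scale);  // integer in [0, numLevels]
    return min + code * scale;                   // dequantized approximation
  });
}

const weights = [-0.51, -0.02, 0.03, 0.27, 0.49];
const approx8 = quantizeDequantize(weights, 8);
const approx16 = quantizeDequantize(weights, 16);
```

The maximum rounding error per weight is half of one quantization step, which is why 16-bit quantization is usually harmless while 8-bit quantization can visibly perturb the weights.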
## What's in this demo

This demo of quantization consists of four examples:

1. housing: this demo evaluates the effect of quantization on the accuracy
   of a multi-layer perceptron regression model.
2. mnist: this demo evaluates the effect of quantization on the accuracy
   of a relatively small deep convnet trained on the MNIST handwritten-digits
   dataset. Without quantization, the convnet can achieve close-to-perfect
   (i.e., ~99.5%) test accuracy.
3. fashion-mnist: this demo evaluates the effect of quantization on the
   accuracy of another small deep convnet trained on a problem slightly
   harder than MNIST. In particular, it is based on the Fashion-MNIST
   dataset. The original, non-quantized model has an accuracy of 92%–93%.
4. MobileNetV2: this demo evaluates quantized and non-quantized versions of
   MobileNetV2 (width = 1.0) on a sample of 1000 images from the
   [ImageNet](http://www.image-net.org/) dataset. This subset is based on
   the sampling done by https://github.com/ajschumacher/imagen.

In the first three demos, quantizing the weights to 16 or 8 bits has no
significant effect on accuracy. In the MobileNetV2 demo, however, quantizing
the weights to 8 bits leads to a significant deterioration in accuracy, as
measured by the top-1 and top-5 accuracies. See the example results in the
table below:

| Dataset and model | Original (no quantization) | 16-bit quantization | 8-bit quantization |
| ---------------------- | -------------------------- | ------------------- | ------------------ |
| housing: multi-layer regressor | MAE=0.311984 | MAE=0.311983 | MAE=0.312780 |
| MNIST: convnet | accuracy=0.9952 | accuracy=0.9952 | accuracy=0.9952 |
| Fashion MNIST: convnet | accuracy=0.922 | accuracy=0.922 | accuracy=0.9211 |
| MobileNetV2 | top-1 accuracy=0.618; top-5 accuracy=0.788 | top-1 accuracy=0.624; top-5 accuracy=0.789 | top-1 accuracy=0.280; top-5 accuracy=0.490 |

MAE stands for mean absolute error.

Together, these examples demonstrate the different effects the same
quantization technique can have on different problems.
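For reference, the MAE metric in the table can be written out directly. This is a minimal plain-JavaScript sketch; the function name is illustrative, not part of the example's code.

```javascript
// Mean absolute error (MAE): the average of |yTrue[i] - yPred[i]|,
// the metric reported above for the housing regression model.
function meanAbsoluteError(yTrue, yPred) {
  let sum = 0;
  for (let i = 0; i < yTrue.length; ++i) {
    sum += Math.abs(yTrue[i] - yPred[i]);
  }
  return sum / yTrue.length;
}

const mae = meanAbsoluteError([0.5, -0.2, 0.1], [0.4, -0.2, 0.4]);
```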
## Running the housing quantization demo

In preparation, do:

```sh
yarn
```

To train and save the model from scratch, do:

```sh
yarn train-housing
```

If you are running on a Linux system that is
[CUDA compatible](https://www.tensorflow.org/install/install_linux), try
training with the GPU:

```sh
yarn train-housing --gpu
```

To quantize the model saved by the `yarn train-housing` step and evaluate
the effects on the model's test accuracy, do:

```sh
yarn quantize-and-evaluate-housing
```
## Running the MNIST quantization demo

In preparation, do:

```sh
yarn
```

To train and save the model from scratch, do:

```sh
yarn train-mnist
```

or, with CUDA acceleration:

```sh
yarn train-mnist --gpu
```

To quantize the model saved by the `yarn train-mnist` step and evaluate
the effects on the model's test accuracy, do:

```sh
yarn quantize-and-evaluate-mnist
```
## Running the Fashion-MNIST quantization demo

In preparation, do:

```sh
yarn
```

To train and save the model from scratch, do:

```sh
yarn train-fashion-mnist
```

or, with CUDA acceleration:

```sh
yarn train-fashion-mnist --gpu
```

To quantize the model saved by the `yarn train-fashion-mnist` step and
evaluate the effects on the model's test accuracy, do:

```sh
yarn quantize-and-evaluate-fashion-mnist
```
## Running the MobileNetV2 quantization demo

Unlike the previous three demos, the MobileNetV2 demo doesn't involve a
model-training step. Instead, the model is loaded as a Keras application
and converted to the TensorFlow.js format for quantization and evaluation.

The non-quantized and quantized versions of MobileNetV2 are evaluated on a
sample of 1000 images from the [ImageNet](http://www.image-net.org/)
dataset. The image files are downloaded from their hosted location on the
web. This subset is based on the sampling done by
https://github.com/ajschumacher/imagen.

All these steps can be performed with a single command:

```sh
yarn quantize-and-evaluate-MobileNetV2
```
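The top-1 and top-5 accuracies reported for MobileNetV2 can be computed with a helper like the following plain-JavaScript sketch; the function name is hypothetical and not part of the example's code.

```javascript
// Fraction of samples whose true label appears among the k classes
// with the highest predicted probabilities (top-k accuracy).
function topKAccuracy(probabilities, labels, k) {
  let numCorrect = 0;
  probabilities.forEach((probs, i) => {
    const topClasses = probs
        .map((p, classIndex) => [p, classIndex])
        .sort((a, b) => b[0] - a[0])  // sort by descending probability
        .slice(0, k)
        .map(([, classIndex]) => classIndex);
    if (topClasses.includes(labels[i])) {
      numCorrect++;
    }
  });
  return numCorrect / probabilities.length;
}

const probs = [[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]];
const labels = [1, 2];
```

With k=1 this is ordinary classification accuracy; larger k gives the model credit when the true class is merely near the top of the ranking.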

quantization/data_housing.js (+176)
```js
/**
 * @license
 * Copyright 2019 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import * as tf from '@tensorflow/tfjs';

const HOUSING_CSV_URL = 'https://storage.googleapis.com/learnjs-data/csv-datasets/california_housing_train_10k.csv';

export const featureColumns = [
    'longitude', 'latitude', 'housing_median_age', 'total_rooms',
    'total_bedrooms', 'population', 'households', 'median_income'];
const labelColumn = 'median_house_value';

/**
 * Calculate the column-by-column statistics of the housing CSV dataset.
 *
 * @return An object consisting of the following fields:
 *   count {number} Number of data rows.
 *   featureMeans {number[]} Each element is the arithmetic mean over all values
 *     in a column. Ordered by the feature columns in the CSV dataset.
 *   featureStddevs {number[]} Each element is the standard deviation over all
 *     values in a column. Ordered by the columns in the CSV dataset.
 *   labelMean {number} The arithmetic mean of the label column.
 *   labelStddev {number} The standard deviation of the label column.
 */
export async function getDatasetStats() {
  const featureValues = {};
  featureColumns.forEach(feature => {
    featureValues[feature] = [];
  });
  const labelValues = [];

  const dataset = tf.data.csv(HOUSING_CSV_URL, {
    columnConfigs: {
      [labelColumn]: {
        isLabel: true
      }
    }
  });
  const iterator = await dataset.iterator();
  let count = 0;
  while (true) {
    const item = await iterator.next();
    if (item.done) {
      break;
    }
    featureColumns.forEach(feature => {
      if (item.value.xs[feature] == null) {
        throw new Error(`item #${count} lacks feature ${feature}`);
      }
      featureValues[feature].push(item.value.xs[feature]);
    });
    labelValues.push(item.value.ys[labelColumn]);
    count++;
  }

  return tf.tidy(() => {
    const featureMeans = {};
    const featureStddevs = {};
    featureColumns.forEach(feature => {
      const {mean, variance} = tf.moments(featureValues[feature]);
      featureMeans[feature] = mean.arraySync();
      featureStddevs[feature] = tf.sqrt(variance).arraySync();
    });

    const moments = tf.moments(labelValues);
    const labelMean = moments.mean.arraySync();
    const labelStddev = tf.sqrt(moments.variance).arraySync();
    return {
      count,
      featureMeans,
      featureStddevs,
      labelMean,
      labelStddev
    };
  });
}

/**
 * Get a dataset with the features and label z-normalized.
 *
 * The dataset is split into three xs-ys tensor pairs: for training,
 * validation and evaluation.
 *
 * @param {number} count Number of rows in the CSV dataset, computed beforehand.
 * @param {{[feature: string]: number}} featureMeans Arithmetic means of the
 *   features. Used for normalization.
 * @param {{[feature: string]: number}} featureStddevs Standard deviations of
 *   the features. Used for normalization.
 * @param {number} labelMean Arithmetic mean of the label. Used for
 *   normalization.
 * @param {number} labelStddev Standard deviation of the label. Used for
 *   normalization.
 * @param {number} validationSplit Validation split, must be >0 and <1.
 * @param {number} evaluationSplit Evaluation split, must be >0 and <1.
 * @returns An object consisting of the following keys:
 *   trainXs {tf.Tensor} training feature tensor
 *   trainYs {tf.Tensor} training label tensor
 *   valXs {tf.Tensor} validation feature tensor
 *   valYs {tf.Tensor} validation label tensor
 *   evalXs {tf.Tensor} evaluation feature tensor
 *   evalYs {tf.Tensor} evaluation label tensor.
 */
export async function getNormalizedDatasets(
    count, featureMeans, featureStddevs, labelMean, labelStddev,
    validationSplit, evaluationSplit) {
  tf.util.assert(
      validationSplit > 0 && validationSplit < 1,
      () => `validationSplit is expected to be >0 and <1, ` +
          `but got ${validationSplit}`);
  tf.util.assert(
      evaluationSplit > 0 && evaluationSplit < 1,
      () => `evaluationSplit is expected to be >0 and <1, ` +
          `but got ${evaluationSplit}`);
  tf.util.assert(
      validationSplit + evaluationSplit < 1,
      () => `The sum of validationSplit and evaluationSplit exceeds 1`);

  const dataset = tf.data.csv(HOUSING_CSV_URL, {
    columnConfigs: {
      [labelColumn]: {
        isLabel: true
      }
    }
  });

  const featureValues = [];
  const labelValues = [];
  const indices = [];
  const iterator = await dataset.iterator();
  for (let i = 0; i < count; ++i) {
    const {value, done} = await iterator.next();
    if (done) {
      break;
    }
    featureColumns.forEach(feature => {
      featureValues.push(
          (value.xs[feature] - featureMeans[feature]) /
          featureStddevs[feature]);
    });
    labelValues.push((value.ys[labelColumn] - labelMean) / labelStddev);
    indices.push(i);
  }

  const xs = tf.tensor2d(featureValues, [count, featureColumns.length]);
  const ys = tf.tensor2d(labelValues, [count, 1]);

  // Set the random seed to fix the shuffling order and therefore the
  // training, validation, and evaluation splits. `Math.seedrandom` is
  // provided by the seedrandom package.
  Math.seedrandom('1337');
  tf.util.shuffle(indices);

  const numTrain = Math.round(count * (1 - validationSplit - evaluationSplit));
  const numVal = Math.round(count * validationSplit);
  const trainXs = xs.gather(indices.slice(0, numTrain));
  const trainYs = ys.gather(indices.slice(0, numTrain));
  const valXs = xs.gather(indices.slice(numTrain, numTrain + numVal));
  const valYs = ys.gather(indices.slice(numTrain, numTrain + numVal));
  const evalXs = xs.gather(indices.slice(numTrain + numVal));
  const evalYs = ys.gather(indices.slice(numTrain + numVal));

  return {trainXs, trainYs, valXs, valYs, evalXs, evalYs};
}
```
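The per-column z-normalization that `getNormalizedDatasets` applies reduces to a few lines of plain JavaScript. This sketch has no tfjs dependency and its name is illustrative:

```javascript
// Z-normalize a column of values: subtract the column mean, divide by the
// column standard deviation. The result has mean ~0 and stddev ~1, which
// keeps features on comparable scales during training.
function zNormalize(values) {
  const mean = values.reduce((acc, v) => acc + v, 0) / values.length;
  const variance =
      values.reduce((acc, v) => acc + (v - mean) ** 2, 0) / values.length;
  return values.map(v => (v - mean) / Math.sqrt(variance));
}

const normalized = zNormalize([1, 2, 3]);
```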
