When TensorFlow 1.3 was released, the Estimator and related high-level APIs caught my eye. That was almost a year ago, and TensorFlow has had a few updates since, with 1.8 being the latest version at the time of writing. Time to revisit these APIs and see how they have evolved.
The Estimator and Dataset APIs have become more mature since TF 1.3, and the TensorFlow tutorials recommend using them when writing TensorFlow programs:
We strongly recommend writing TensorFlow programs with the following APIs:
Estimators, which represent a complete model. The Estimator API provides methods to train the model, to judge the model’s accuracy, and to generate predictions.
Datasets, which build a data input pipeline. The Dataset API has methods to load and manipulate data, and feed it into your model. The Dataset API meshes well with the Estimators API.
The Estimator API provides a top-level abstraction and integrates nicely with other APIs, such as the Dataset API for building input streams and the Layers API for building model architectures. It’s even possible to construct an estimator from a Keras model with a single function call (tf.keras.estimator.model_to_estimator).
In the following, I will give an overview of the API. An accompanying repository with example code is provided here.
The core of the Estimator API has remained stable; we can still create an estimator as follows:
return tf.estimator.Estimator(
    model_fn=model_fn,
    config=config,
    params=params,
)
After creating the Estimator, we can train and evaluate it with the train_and_evaluate function:
tf.estimator.train_and_evaluate(model_estimator, train_spec, eval_spec)
Note that this differs from the previous blog post, where we used the now-deprecated TensorFlow Experiment class to run the training. Getting rid of the Experiment class makes everything simpler: the training and evaluation input functions and hooks are now clearly separated into the TrainSpec and the EvalSpec, and you only have to call the train_and_evaluate function.
The Dataset API has matured as well: it moved from contrib into the core TensorFlow library (tf.data) and now lets you build complex input pipelines. The accompanying code uses it to build an input feeder that shuffles the data, repeats it for as long as needed, and batches it with the requested batch size:
# One dataset element per example (slices along the first dimension).
dataset = tf.data.Dataset.from_tensor_slices(mnist_data)
dataset = dataset.shuffle(
    buffer_size=1000,               # sample from a 1000-element shuffle buffer
    reshuffle_each_iteration=True,  # draw a new order on every pass over the data
).repeat(count=None).batch(batch_size)  # repeat indefinitely, then batch
Note that in this example I’m simply loading the MNIST data through the Keras API; the Dataset API lets you build much more complex input pipelines.
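To make the shuffle(...).repeat(...).batch(...) chain a bit more concrete, here is a pure-Python sketch of what it does to a small list. It is not tf.data code: shuffle, repeat_shuffled and batch are hypothetical helpers that only mimic the documented behaviour (shuffle fills a buffer of buffer_size elements, then emits randomly chosen buffered elements and refills the freed slot from the input; reshuffle_each_iteration=True gives every pass a fresh order; repeat(count=None) loops forever; batch groups consecutive elements).

```python
import random
from itertools import islice
from typing import Iterable, Iterator, List, Optional, Sequence, TypeVar

T = TypeVar("T")


def shuffle(elements: Iterable[T], buffer_size: int, rng: random.Random) -> Iterator[T]:
    """Buffered shuffle in the spirit of Dataset.shuffle (illustration only)."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be positive")
    buffer: List[T] = []
    for element in elements:
        if len(buffer) < buffer_size:
            # Still filling the buffer: nothing is emitted yet.
            buffer.append(element)
            continue
        # Buffer is full: emit a random buffered element, put the new one in its slot.
        index = rng.randrange(buffer_size)
        yield buffer[index]
        buffer[index] = element
    # Input exhausted: emit whatever is left in random order.
    rng.shuffle(buffer)
    yield from buffer


def repeat_shuffled(data: Sequence[T], buffer_size: int, rng: random.Random,
                    count: Optional[int] = None) -> Iterator[T]:
    """Mimics shuffle(..., reshuffle_each_iteration=True).repeat(count); None = forever."""
    epoch = 0
    while count is None or epoch < count:
        # Each pass draws a new order from the shared random generator.
        yield from shuffle(data, buffer_size, rng)
        epoch += 1


def batch(elements: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Groups consecutive elements; a final partial batch is kept."""
    iterator = iter(elements)
    while True:
        chunk = list(islice(iterator, batch_size))
        if not chunk:
            return
        yield chunk


# Shuffle 10 items with a 4-element buffer, repeat forever, batch by 5,
# and take the first 4 batches (two full passes over the data).
pipeline = batch(repeat_shuffled(list(range(10)), buffer_size=4, rng=random.Random(0)),
                 batch_size=5)
first_batches = list(islice(pipeline, 4))
```

One consequence worth keeping in mind: with buffer_size=1000 and the 60,000 MNIST training images, the shuffle is only local, because an element cannot be emitted more than buffer_size positions earlier than where it sits in the input.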
You can run the code from the repo locally by:
This should start a training and evaluation session. Because the code uses the Estimator API, default logging and checkpoint saving are set up automatically, and the results can be visualized with TensorBoard.
Because the run configuration is abstracted away, the Estimator API also makes it easy to train models on Google Cloud’s ML Engine:
With tf.estimator.train_and_evaluate, you can run the same code both locally and distributed in the cloud, on different devices and using different cluster configurations, and get consistent results without making any code changes.
I provided a minimal example of how to run the accompanying code on Google Cloud. For example, you can train the code on the cloud by running:
gcloud ml-engine jobs submit training mnist_estimator_`date +%s` \
    --project mnist-estimator \
    --runtime-version 1.8 \
    --python-version 3.5 \
    --job-dir gs://estimator-data/train \
    --scale-tier BASIC \
    --region europe-west1 \
    --module-name src.mnist_estimator \
    --package-path src/ \
    -- \
    --train-steps 6000 \
    --batch-size 128
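The module named by --module-name receives the flags listed after the bare `--`, and ML Engine also forwards the --job-dir value to it as a --job-dir argument. The repo's actual argument handling is not reproduced here, so the parser below is a hypothetical sketch with assumed defaults. It also works through the arithmetic behind --train-steps: since the dataset repeats indefinitely, the length of training is presumably set by a step count rather than by epochs, and 6000 steps of 128 images amount to 768,000 examples, about 12.8 passes over the 60,000 MNIST training images.

```python
import argparse
from typing import List, Optional

# Size of the MNIST training split returned by the Keras loader.
MNIST_TRAIN_EXAMPLES = 60_000


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Hypothetical trainer flags matching the gcloud command above."""
    parser = argparse.ArgumentParser(description="Train the MNIST estimator (sketch).")
    # ML Engine forwards gcloud's --job-dir value to the trainer as --job-dir.
    parser.add_argument("--job-dir", required=True,
                        help="Directory for checkpoints and summaries.")
    # Flags after the bare `--` in the gcloud command are passed through unchanged.
    parser.add_argument("--train-steps", type=int, default=1000)  # assumed default
    parser.add_argument("--batch-size", type=int, default=128)    # assumed default
    return parser.parse_args(argv)


def epochs_covered(train_steps: int, batch_size: int,
                   num_examples: int = MNIST_TRAIN_EXAMPLES) -> float:
    """Passes over the data implied by a step budget on an endlessly repeated dataset."""
    return train_steps * batch_size / num_examples


args = parse_args([
    "--job-dir", "gs://estimator-data/train",
    "--train-steps", "6000",
    "--batch-size", "128",
])
print(args.job_dir, args.train_steps, args.batch_size)
print(epochs_covered(args.train_steps, args.batch_size))  # 6000 * 128 / 60000 = 12.8
```

argparse turns the hyphenated flag names into attributes such as args.job_dir and args.train_steps, so the same entry point works whether it is launched locally or by ML Engine.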
In summary, the TensorFlow Estimator API, as well as the Dataset API, have matured a lot. They provide a nice abstraction layer to manage input data streams, models, and training/evaluation configurations.