# Model API
The model API provides a simplified way to train neural networks using common best practices.
It's a thin wrapper built on top of the [ndarray](../python/ndarray/ndarray.md) and [symbolic](../python/symbol/symbol.md)
modules that makes neural network training easy.

Topics:

* [Train the Model](#train-the-model)
* [Save the Model](#save-the-model)
* [Periodic Checkpointing](#periodic-checkpointing)
* [Use Multiple Devices](#use-multiple-devices)
* [Initializer API Reference](#initializer-api-reference)
* [Evaluation Metric API Reference](#evaluation-metric-api-reference)
* [Optimizer API Reference](#optimizer-api-reference)
* [Model API Reference](#model-api-reference)

## Train the Model

To train a model, perform two steps: configure the model using the symbol parameter,
then call `mx.model.FeedForward.create` to create the model.
The following example creates a two-layer neural network.

```python
import mxnet as mx

# configure a two-layer neural network
data = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(fc1, name='relu1', act_type='relu')
fc2 = mx.symbol.FullyConnected(act1, name='fc2', num_hidden=64)
softmax = mx.symbol.SoftmaxOutput(fc2, name='sm')

# create a model
model = mx.model.FeedForward.create(
    softmax,
    X=data_set,
    num_epoch=num_epoch,
    learning_rate=0.01)
```

You can also use the scikit-learn-style constructor and `fit` method to create a model.

```python
# create a model using the sklearn-style two-step pattern
model = mx.model.FeedForward(
    softmax,
    num_epoch=num_epoch,
    learning_rate=0.01)

model.fit(X=data_set)
```
For more information, see [Model API Reference](#model-api-reference).
## Save the Model
After training finishes, save your work.
To save the model, you can pickle it directly with Python.
We also provide `save` and `load` functions.

```python
# save the model to mymodel-symbol.json and mymodel-0100.params
prefix = 'mymodel'
iteration = 100
model.save(prefix, iteration)

# load the model back
model_loaded = mx.model.FeedForward.load(prefix, iteration)
```

The advantage of these `save` and `load` functions is that they are language agnostic.
You can also save to and load from cloud storage directly, such as Amazon S3 and HDFS.
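
As the example above shows, `save` writes two files: the network definition to `mymodel-symbol.json` and the parameters to `mymodel-0100.params`. A minimal sketch of this naming convention in plain Python (the helper name is illustrative, not part of MXNet):

```python
def checkpoint_filenames(prefix, iteration):
    """Names of the two files written for a given prefix and iteration:
    the symbol (network definition) and the zero-padded parameter file."""
    return '%s-symbol.json' % prefix, '%s-%04d.params' % (prefix, iteration)

print(checkpoint_filenames('mymodel', 100))
# ('mymodel-symbol.json', 'mymodel-0100.params')
```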
## Periodic Checkpointing
We recommend checkpointing your model after each iteration.
To do this, pass a checkpoint callback `do_checkpoint(path)` to the create function.
The training process automatically checkpoints to the specified location after
each iteration.

```python
prefix = 'models/chkpt'
model = mx.model.FeedForward.create(
    softmax,
    X=data_set,
    iter_end_callback=mx.callback.do_checkpoint(prefix),
    ...)
```

You can load the model checkpoint later using `FeedForward.load`.
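
If saving after every single iteration is too frequent, the same callback hook can thin out the checkpoints. A framework-agnostic sketch of such a periodic callback (plain Python; `save` stands in for `model.save`, and the helper name is hypothetical, not an MXNet API):

```python
def make_periodic_checkpoint(save, period):
    """Build a callback that calls save(iteration) every `period` iterations."""
    def callback(iteration):
        if iteration % period == 0:
            save(iteration)
    return callback

saved = []
cb = make_periodic_checkpoint(saved.append, period=10)
for it in range(1, 31):
    cb(it)
print(saved)  # [10, 20, 30]
```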
## Use Multiple Devices
Set `ctx` to the list of devices that you want to train on.

```python
devices = [mx.gpu(i) for i in range(num_device)]
model = mx.model.FeedForward.create(
    softmax,
    X=dataset,
    ctx=devices,
    ...)
```
Training occurs in parallel on the GPUs that you specify.
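
Conceptually, data-parallel training slices each batch across the devices in `ctx`, runs forward and backward on each slice, and aggregates the gradients. A minimal sketch of the batch-slicing step in plain Python (illustrative only; this is not MXNet's actual partitioning code):

```python
def split_batch(batch, num_device):
    """Split a batch into num_device contiguous, near-equal shards."""
    shard = (len(batch) + num_device - 1) // num_device  # ceiling division
    return [batch[i * shard:(i + 1) * shard] for i in range(num_device)]

print(split_batch(list(range(8)), 4))  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```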
```eval_rst
.. raw:: html

    <script type="text/javascript" src='../../_static/js/auto_module_index.js'></script>
```
## Initializer API Reference
```eval_rst
.. automodule:: mxnet.initializer
    :members:

.. raw:: html

    <script>auto_index("initializer-api-reference");</script>
```
## Evaluation Metric API Reference
```eval_rst
.. automodule:: mxnet.metric
    :members:

.. raw:: html

    <script>auto_index("evaluation-metric-api-reference");</script>
```
## Optimizer API Reference
2016-10-18 00:41:30 +05:30
2016-03-19 20:37:09 -04:00
```eval_rst
.. automodule:: mxnet.optimizer
    :members:

.. raw:: html

    <script>auto_index("optimizer-api-reference");</script>
```
## Model API Reference
```eval_rst
.. automodule:: mxnet.model
    :members:

.. raw:: html

    <script>auto_index("model-api-reference");</script>
```
## Next Steps
* See [Symbolic API](../python/symbol/symbol.md) for operations on NDArrays that assemble neural networks from layers.
* See [IO Data Loading API](../python/io/io.md) for parsing and loading data.
* See [NDArray API](../python/ndarray/ndarray.md) for vector/matrix/tensor operations.
* See [KVStore API](../python/kvstore/kvstore.md) for multi-GPU and multi-host distributed training.