2017-08-08 16:36:23 -07:00
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
2016-10-15 17:34:37 -04:00
import doctest
import logging
import mxnet
2016-10-16 21:10:46 -04:00
import numpy
2016-10-15 17:34:37 -04:00
def import_into(globs, module, names=None, error_on_overwrite=True):
    """Import names from module into the globs dict.

    Parameters
    ----------
    globs : dict
        Namespace to populate (e.g. a doctest ``globs`` mapping). It is
        updated in place and also returned.
    module : module or object
        Object whose attributes are copied into ``globs``.
    names : iterable of str, optional
        Attribute names to import. When None (the default), every name
        reported by ``dir(module)`` is imported. Any iterable is accepted,
        including a one-shot generator.
    error_on_overwrite : bool, optional
        When True (the default), raise RuntimeError if ``globs`` already
        binds one of the names to a different object. When False, log a
        warning and overwrite the existing binding.

    Returns
    -------
    dict
        ``globs``, after the update.

    Raises
    ------
    AssertionError
        If a requested name is not listed by ``dir(module)``.
    RuntimeError
        If ``error_on_overwrite`` is True and a name would be rebound to a
        different object. ``globs`` is left unmodified in that case.
    """
    mod_names = dir(module)
    if names is None:
        selected = mod_names
    else:
        # Materialise once: the names are walked twice (validate, then
        # import), and a generator would otherwise be exhausted by the
        # validation pass, silently importing nothing.
        selected = list(names)
        available = set(mod_names)
        for name in selected:
            if name not in available:
                # Explicit raise (not ``assert``) so the check survives -O.
                raise AssertionError(f'{name} not found in {module}')

    # Resolve every value and detect conflicts before touching globs, so a
    # RuntimeError leaves the caller's namespace unchanged. The dict also
    # collapses duplicate names while keeping first-seen order.
    values = {name: getattr(module, name) for name in selected}
    for name, value in values.items():
        if name in globs and globs[name] is not value:
            error_msg = f'Attempting to overwrite definition of {name}'
            if error_on_overwrite:
                raise RuntimeError(error_msg)
            logging.warning('%s', error_msg)
    globs.update(values)
    return globs
def test_symbols():
    """Run the doctests in ``mxnet.symbol_doc`` and fail if any example fails.

    The doctest namespace pre-binds ``np``, ``mx``, ``test_utils`` and
    ``SymbolDoc``, plus every attribute listed by ``dir(mxnet.symbol)``.
    """
    globs = {'np': numpy, 'mx': mxnet, 'test_utils': mxnet.test_utils, 'SymbolDoc': mxnet.symbol_doc.SymbolDoc}

    # make sure all the operators are available
    import_into(globs, mxnet.symbol)

    result = doctest.testmod(mxnet.symbol_doc, globs=globs, verbose=True)
    # testmod reports failures only through its return value; without this
    # check the test would pass regardless of how many examples fail.
    assert result.failed == 0, (
        f'{result.failed} of {result.attempted} doctests failed in mxnet.symbol_doc')
2016-10-15 17:34:37 -04:00
2017-03-22 20:16:55 -07:00
def test_ndarray():
    """Run the doctests in ``mxnet.ndarray`` and fail if any example fails.

    The doctest namespace pre-binds ``np`` (numpy) and ``mx`` (mxnet).
    """
    globs = {'np': numpy, 'mx': mxnet}

    result = doctest.testmod(mxnet.ndarray, globs=globs, verbose=True)
    # testmod reports failures only through its return value; without this
    # check the test would pass regardless of how many examples fail.
    assert result.failed == 0, (
        f'{result.failed} of {result.attempted} doctests failed in mxnet.ndarray')
2017-03-22 20:16:55 -07:00
2016-10-15 17:34:37 -04:00
if __name__ == '__main__':
    # Script entry point: run the symbol doctests first, then the ndarray
    # doctests; an exception in the first stops the run.
    for _doctest_suite in (test_symbols, test_ndarray):
        _doctest_suite()