2017-08-08 16:36:23 -07:00
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
2015-11-21 20:46:05 -08:00
# pylint: skip-file
2015-11-20 21:51:28 -08:00
import mxnet as mx
2015-11-21 20:46:05 -08:00
import numpy as np
2015-11-20 21:51:28 -08:00
from numpy . testing import assert_allclose
if __name__ == ' __main__ ' :
x = mx . nd . zeros ( ( 10 , ) , ctx = mx . gpu ( 0 ) )
x [ : ] = 1
y = mx . nd . zeros ( ( 10 , ) , ctx = mx . gpu ( 0 ) )
y [ : ] = 2
2015-11-21 20:46:05 -08:00
rtc = mx . rtc ( ' abc ' , [ ( ' x ' , x ) ] , [ ( ' y ' , y ) ] , """
__shared__ float s_rec[10];
s_rec[threadIdx.x] = x[threadIdx.x];
y[threadIdx.x] = expf(s_rec[threadIdx.x]*5.0); """ )
2015-11-20 21:51:28 -08:00
rtc . push ( [ x ] , [ y ] , ( 1 , 1 , 1 ) , ( 10 , 1 , 1 ) )
2017-08-08 16:36:23 -07:00
assert_allclose ( y . asnumpy ( ) , np . exp ( x . asnumpy ( ) * 5.0 ) )