2022-02-22 21:32:37 +08:00
|
|
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
|
//
|
|
|
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
// you may not use this file except in compliance with the License.
|
|
|
|
|
// You may obtain a copy of the License at
|
|
|
|
|
//
|
|
|
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
//
|
|
|
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
|
// limitations under the License.
|
|
|
|
|
|
|
|
|
|
#include "paddle/utils/array_ref.h"
|
|
|
|
|
|
2023-08-07 16:27:41 +08:00
|
|
|
#include <array>
|
2022-02-22 21:32:37 +08:00
|
|
|
#include <cstdlib>
|
|
|
|
|
#include <ctime>
|
|
|
|
|
|
|
|
|
|
#include "glog/logging.h"
|
|
|
|
|
#include "gtest/gtest.h"
|
2024-07-31 10:03:41 +08:00
|
|
|
#include "paddle/common/enforce.h"
|
2022-02-22 21:32:37 +08:00
|
|
|
|
|
|
|
|
TEST(array_ref, array_ref) {
|
2022-04-27 14:24:20 +08:00
|
|
|
paddle::array_ref<int> a;
|
2024-07-31 10:03:41 +08:00
|
|
|
PADDLE_ENFORCE_EQ(a.size(),
|
|
|
|
|
size_t(0),
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"Array's size is invalid, expected %d but received %d.",
|
|
|
|
|
size_t(0),
|
|
|
|
|
a.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(a.data(),
|
|
|
|
|
static_cast<int*>(nullptr),
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"Array's data is invalid, expected %d but received %d.",
|
|
|
|
|
static_cast<int*>(nullptr),
|
|
|
|
|
a.data()));
|
2022-02-22 21:32:37 +08:00
|
|
|
|
2022-04-27 14:24:20 +08:00
|
|
|
paddle::array_ref<int> b(paddle::none);
|
2024-07-31 10:03:41 +08:00
|
|
|
PADDLE_ENFORCE_EQ(b.size(),
|
|
|
|
|
size_t(0),
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"Array's size is invalid, expected %d but received %d.",
|
|
|
|
|
size_t(0),
|
|
|
|
|
b.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(b.data(),
|
|
|
|
|
static_cast<int*>(nullptr),
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"Array's data is invalid, expected %d but received %d.",
|
|
|
|
|
static_cast<int*>(nullptr),
|
|
|
|
|
b.data()));
|
2022-02-22 21:32:37 +08:00
|
|
|
|
|
|
|
|
int v = 1;
|
2022-04-27 14:24:20 +08:00
|
|
|
paddle::array_ref<int> c(v);
|
2024-07-31 10:03:41 +08:00
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
c.size(),
|
|
|
|
|
size_t(1),
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"Array's size is invalid, expected %d(size_t(1)) but received %d.",
|
|
|
|
|
size_t(1),
|
|
|
|
|
c.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
c.data(),
|
|
|
|
|
&v,
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"Array's data is invalid, expected %d(&v) but received %d.",
|
|
|
|
|
&v,
|
|
|
|
|
c.data()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(c.equals(paddle::make_array_ref(v)),
|
|
|
|
|
true,
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"The output of paddle::make_array_ref(v) is wrong."));
|
2022-02-22 21:32:37 +08:00
|
|
|
|
2023-08-07 16:27:41 +08:00
|
|
|
std::array<int, 5> v1 = {1, 2, 3, 4, 5};
|
|
|
|
|
paddle::array_ref<int> d(v1.data(), 5);
|
2024-07-31 10:03:41 +08:00
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
d.size(),
|
|
|
|
|
size_t(5),
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"Array's size is invalid, expected %d(size_t(5)) but received %d.",
|
|
|
|
|
size_t(5),
|
|
|
|
|
d.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(d.data(),
|
|
|
|
|
v1.data(),
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"Array's data is invalid, expected %d but received %d.",
|
|
|
|
|
v1.data(),
|
|
|
|
|
d.data()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
d.equals(paddle::make_array_ref(v1.data(), 5)),
|
|
|
|
|
true,
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"The output of paddle::make_array_ref(v1.data(), 5) is wrong."));
|
2022-02-22 21:32:37 +08:00
|
|
|
|
2022-04-27 14:24:20 +08:00
|
|
|
paddle::array_ref<int> e(&v1[0], &v1[4]);
|
2024-07-31 10:03:41 +08:00
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
e.size(),
|
|
|
|
|
size_t(4),
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"Array's size is invalid, expected %d(size_t(4)) but received %d.",
|
|
|
|
|
size_t(4),
|
|
|
|
|
e.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
e.data(),
|
|
|
|
|
v1.data(),
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"Array's data is invalid, expected %d(v1.data()) but received %d.",
|
|
|
|
|
v1.data(),
|
|
|
|
|
e.data()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
e.equals(paddle::make_array_ref(&v1[0], &v1[4])),
|
|
|
|
|
true,
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"The output of paddle::make_array_ref(&v1[0], &v1[4]) is wrong."));
|
2022-02-22 21:32:37 +08:00
|
|
|
|
2022-04-27 14:24:20 +08:00
|
|
|
paddle::small_vector<int, 3> small_vector{1, 2, 3};
|
|
|
|
|
paddle::array_ref<int> f(small_vector);
|
2024-07-31 10:03:41 +08:00
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
f.size(),
|
|
|
|
|
size_t(3),
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"Array's size is invalid, expected %d(size_t(3)) but received %d.",
|
|
|
|
|
size_t(3),
|
|
|
|
|
f.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(f.data(),
|
|
|
|
|
small_vector.data(),
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"Array's data is invalid, expected %d but received %d.",
|
|
|
|
|
small_vector.data(),
|
|
|
|
|
f.data()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
f.equals(paddle::make_array_ref(small_vector)),
|
|
|
|
|
true,
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"The output of paddle::make_array_ref(small_vector) is wrong."));
|
2022-02-22 21:32:37 +08:00
|
|
|
|
|
|
|
|
std::vector<int> vector{1, 2, 3};
|
2022-04-27 14:24:20 +08:00
|
|
|
paddle::array_ref<int> g(vector);
|
2024-07-31 10:03:41 +08:00
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
g.size(),
|
|
|
|
|
size_t(3),
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"Array's size is invalid, expected %d(size_t(3)) but received %d.",
|
|
|
|
|
size_t(3),
|
|
|
|
|
g.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(g.data(),
|
|
|
|
|
vector.data(),
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"Array's data is invalid, expected %d but received %d.",
|
|
|
|
|
vector.data(),
|
|
|
|
|
g.data()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
g.equals(paddle::make_array_ref(vector)),
|
|
|
|
|
true,
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"The output of paddle::make_array_ref(vector) is wrong."));
|
2022-02-22 21:32:37 +08:00
|
|
|
|
|
|
|
|
std::initializer_list<int> list = {1, 2, 3};
|
2022-04-27 14:24:20 +08:00
|
|
|
paddle::array_ref<int> h(list);
|
2024-07-31 10:03:41 +08:00
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
h.size(),
|
|
|
|
|
size_t(3),
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"Array's size is invalid, expected %d(size_t(3)) but received %d.",
|
|
|
|
|
size_t(3),
|
|
|
|
|
h.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(h.data(),
|
|
|
|
|
list.begin(),
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"Array's data is invalid, expected %d but received %d.",
|
|
|
|
|
list.begin(),
|
|
|
|
|
h.data()));
|
2022-02-22 21:32:37 +08:00
|
|
|
|
2022-04-27 14:24:20 +08:00
|
|
|
paddle::array_ref<int> i(h);
|
2024-07-31 10:03:41 +08:00
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
i.size(),
|
|
|
|
|
size_t(3),
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"Array's size is invalid, expected %d(size_t(3)) but received %d.",
|
|
|
|
|
size_t(3),
|
|
|
|
|
i.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(i.data(),
|
|
|
|
|
list.begin(),
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"Array's data is invalid, expected %d but received %d.",
|
|
|
|
|
list.begin(),
|
|
|
|
|
i.data()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
i.equals(h),
|
|
|
|
|
true,
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument("Array i(h) is not equal with h"));
|
2024-07-31 10:03:41 +08:00
|
|
|
PADDLE_ENFORCE_EQ(i.equals(paddle::make_array_ref(h)),
|
|
|
|
|
true,
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"i(h) is not equal with paddle::make_array_ref(h)"));
|
2022-02-22 21:32:37 +08:00
|
|
|
|
|
|
|
|
auto slice = i.slice(1, 2);
|
2024-07-31 10:03:41 +08:00
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
slice.size(),
|
|
|
|
|
size_t(2),
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"Slice's size is invalid, expected %d(size_t(2)) but received %d.",
|
|
|
|
|
size_t(2),
|
|
|
|
|
slice.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
slice[0],
|
|
|
|
|
2,
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"slice[0]'s value is invalid, expected 2 but received %d.",
|
|
|
|
|
slice[0]));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
slice[1],
|
|
|
|
|
3,
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"slice[1]'s value is invalid, expected 3 but received %d.",
|
|
|
|
|
slice[1]));
|
2022-02-22 21:32:37 +08:00
|
|
|
|
|
|
|
|
auto drop = i.drop_front(2);
|
2024-07-31 10:03:41 +08:00
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
drop.size(),
|
|
|
|
|
size_t(1),
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"Drop's size is invalid, expected %d(size_t(1)) but received %d.",
|
|
|
|
|
size_t(1),
|
|
|
|
|
drop.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
drop[0],
|
|
|
|
|
3,
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"drop[0]'s value is invalid, expected 3 but received %d.", drop[0]));
|
2022-02-22 21:32:37 +08:00
|
|
|
|
2023-04-03 11:17:47 +08:00
|
|
|
static paddle::array_ref<int> nums = {1, 2, 3, 4, 5, 6, 7, 8};
|
2022-02-22 21:32:37 +08:00
|
|
|
auto front = nums.take_front(3);
|
2024-07-31 10:03:41 +08:00
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
front.size(),
|
|
|
|
|
size_t(3),
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"Front Array's size is invalid, expected %d but received %d.",
|
|
|
|
|
size_t(3),
|
|
|
|
|
front.size()));
|
2022-02-22 21:32:37 +08:00
|
|
|
for (size_t i = 0; i < 3; ++i) {
|
2024-07-31 10:03:41 +08:00
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
front[i],
|
|
|
|
|
nums[i],
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"front[%d]'s value is invalid, expected %d but received %d.",
|
|
|
|
|
i,
|
|
|
|
|
nums[i],
|
|
|
|
|
front[i]));
|
2022-02-22 21:32:37 +08:00
|
|
|
}
|
|
|
|
|
auto back = nums.take_back(3);
|
2024-07-31 10:03:41 +08:00
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
back.size(),
|
|
|
|
|
size_t(3),
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"Back Array's size is invalid, expected %d but received %d.",
|
|
|
|
|
size_t(3),
|
|
|
|
|
back.size()));
|
2022-02-22 21:32:37 +08:00
|
|
|
for (size_t i = 0; i < 3; ++i) {
|
2024-07-31 10:03:41 +08:00
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
back[i],
|
|
|
|
|
nums[i + 5],
|
2024-08-06 10:28:58 +08:00
|
|
|
common::errors::InvalidArgument(
|
2024-07-31 10:03:41 +08:00
|
|
|
"back[%d]'s value is invalid, expected %d but received %d.",
|
|
|
|
|
i,
|
|
|
|
|
nums[i + 5],
|
|
|
|
|
back[i]));
|
2022-02-22 21:32:37 +08:00
|
|
|
}
|
|
|
|
|
}
|