2025-12-03 13:46:52 -08:00
|
|
|
"""Tests for timm pooling layers."""
|
|
|
|
|
import pytest
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
|
|
|
|
import importlib
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
# Optional module named by TORCH_BACKEND; it is imported before any test runs,
# presumably so it can register an extra torch device/backend (confirm per backend).
torch_backend = os.getenv('TORCH_BACKEND')
if torch_backend is not None:
    importlib.import_module(torch_backend)

# Device on which every test allocates its tensors and layers (CPU unless overridden).
torch_device = os.getenv('TORCH_DEVICE', 'cpu')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Adaptive Avg/Max Pooling Tests
|
|
|
|
|
|
|
|
|
|
class TestAdaptiveAvgMaxPool:
    """Checks for the adaptive avg / max / avgmax pooling functions and modules."""

    def test_adaptive_avgmax_pool2d(self):
        from timm.layers import adaptive_avgmax_pool2d

        feats = torch.randn(2, 64, 7, 7, device=torch_device)
        pooled = adaptive_avgmax_pool2d(feats, 1)
        assert pooled.shape == (2, 64, 1, 1)

        # avgmax pooling is the midpoint of the global average and the global max.
        avg_part = feats.mean(dim=(2, 3), keepdim=True)
        max_part = feats.amax(dim=(2, 3), keepdim=True)
        assert torch.allclose(pooled, 0.5 * (avg_part + max_part))

    def test_select_adaptive_pool2d(self):
        from timm.layers import select_adaptive_pool2d

        feats = torch.randn(2, 64, 7, 7, device=torch_device)
        # Reference reduction over the spatial (H, W) axes for each pool type.
        references = (
            ('avg', feats.mean(dim=(2, 3), keepdim=True)),
            ('max', feats.amax(dim=(2, 3), keepdim=True)),
        )
        for pool_type, expected in references:
            pooled = select_adaptive_pool2d(feats, pool_type=pool_type, output_size=1)
            assert pooled.shape == (2, 64, 1, 1)
            assert torch.allclose(pooled, expected)

    def test_adaptive_avgmax_pool2d_module(self):
        from timm.layers import AdaptiveAvgMaxPool2d

        feats = torch.randn(2, 64, 14, 14, device=torch_device)
        layer = AdaptiveAvgMaxPool2d(output_size=1).to(torch_device)
        assert layer(feats).shape == (2, 64, 1, 1)

    def test_select_adaptive_pool2d_module(self):
        from timm.layers import SelectAdaptivePool2d

        feats = torch.randn(2, 64, 7, 7, device=torch_device)
        for pool_type in ('avg', 'max', 'avgmax', 'catavgmax'):
            layer = SelectAdaptivePool2d(pool_type=pool_type, flatten=True).to(torch_device)
            # 'catavgmax' concatenates the avg and max features, doubling the channels.
            expected_channels = 128 if pool_type == 'catavgmax' else 64
            assert layer(feats).shape == (2, expected_channels)

    def test_select_adaptive_pool2d_fast(self):
        from timm.layers import SelectAdaptivePool2d

        feats = torch.randn(2, 64, 7, 7, device=torch_device)
        for pool_type in ('fast', 'fastavg', 'fastmax', 'fastavgmax', 'fastcatavgmax'):
            layer = SelectAdaptivePool2d(pool_type=pool_type, flatten=True).to(torch_device)
            # Any concatenating ('cat') variant doubles the channel count.
            expected_channels = 128 if 'cat' in pool_type else 64
            assert layer(feats).shape == (2, expected_channels)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Attention Pool Tests
|
|
|
|
|
|
|
|
|
|
class TestAttentionPool:
    """Checks for the attention-based pooling layers."""

    def test_attention_pool_latent_basic(self):
        from timm.layers import AttentionPoolLatent

        tokens = torch.randn(2, 49, 64, device=torch_device)
        layer = AttentionPoolLatent(in_features=64, num_heads=4).to(torch_device)
        assert layer(tokens).shape == (2, 64)

    def test_attention_pool_latent_multi_latent(self):
        from timm.layers import AttentionPoolLatent

        tokens = torch.randn(2, 49, 64, device=torch_device)
        # Four latent queries; with pool_type='avg' the output is still one
        # 64-dim vector per sample.
        config = dict(in_features=64, num_heads=4, latent_len=4, pool_type='avg')
        layer = AttentionPoolLatent(**config).to(torch_device)
        assert layer(tokens).shape == (2, 64)

    def test_attention_pool2d_basic(self):
        from timm.layers import AttentionPool2d

        feats = torch.randn(2, 64, 7, 7, device=torch_device)
        layer = AttentionPool2d(in_features=64, feat_size=7).to(torch_device)
        assert layer(feats).shape == (2, 64)

    def test_attention_pool2d_different_feat_size(self):
        from timm.layers import AttentionPool2d

        # Built for 7x7 maps; the 14x14 input requires pos_embed interpolation.
        layer = AttentionPool2d(in_features=64, feat_size=7).to(torch_device)
        for side in (7, 14):
            feats = torch.randn(2, 64, side, side, device=torch_device)
            assert layer(feats).shape == (2, 64)

    def test_rot_attention_pool2d_basic(self):
        from timm.layers import RotAttentionPool2d

        feats = torch.randn(2, 64, 7, 7, device=torch_device)
        layer = RotAttentionPool2d(in_features=64, ref_feat_size=7).to(torch_device)
        assert layer(feats).shape == (2, 64)

    def test_rot_attention_pool2d_different_sizes(self):
        from timm.layers import RotAttentionPool2d

        layer = RotAttentionPool2d(in_features=64, ref_feat_size=7).to(torch_device)
        # Equal to, larger than, and not a multiple of the 7x7 reference size.
        for side in (7, 14, 10):
            feats = torch.randn(2, 64, side, side, device=torch_device)
            assert layer(feats).shape == (2, 64)

    def test_rot_attention_pool2d_rope_types(self):
        from timm.layers import RotAttentionPool2d

        feats = torch.randn(2, 64, 7, 7, device=torch_device)
        for rope_type in ('base', 'cat', 'dinov3'):
            layer = RotAttentionPool2d(
                in_features=64,
                ref_feat_size=7,
                rope_type=rope_type,
            ).to(torch_device)
            assert layer(feats).shape == (2, 64)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# LSE Pool Tests
|
|
|
|
|
|
|
|
|
|
class TestLsePool:
    """Checks for the LogSumExp (LSE) pooling layers."""

    def test_lse_plus_2d_basic(self):
        from timm.layers import LsePlus2d

        x = torch.randn(2, 64, 7, 7, device=torch_device)
        pool = LsePlus2d().to(torch_device)
        out = pool(x)
        # Default is flatten=True, so the pooled spatial dims are dropped.
        assert out.shape == (2, 64)

    def test_lse_plus_2d_no_flatten(self):
        from timm.layers import LsePlus2d

        x = torch.randn(2, 64, 7, 7, device=torch_device)
        pool = LsePlus2d(flatten=False).to(torch_device)
        out = pool(x)
        # Without flattening, the pooled spatial dims remain as 1x1.
        assert out.shape == (2, 64, 1, 1)

    def test_lse_plus_1d_basic(self):
        from timm.layers import LsePlus1d

        x = torch.randn(2, 49, 64, device=torch_device)
        pool = LsePlus1d().to(torch_device)
        out = pool(x)
        assert out.shape == (2, 64)

    def test_lse_high_r_approximates_max(self):
        """A large, fixed r should make LSE pooling close to global max pooling."""
        from timm.layers import LsePlus2d

        x = torch.randn(2, 64, 7, 7, device=torch_device)
        pool = LsePlus2d(r=100.0, r_learnable=False).to(torch_device)
        out = pool(x)
        out_max = x.amax(dim=(2, 3))
        assert torch.allclose(out, out_max, atol=0.1)

    def test_lse_low_r_approximates_avg(self):
        """A small, fixed r should make LSE pooling close to global average pooling."""
        from timm.layers import LsePlus2d

        x = torch.randn(2, 64, 7, 7, device=torch_device)
        pool = LsePlus2d(r=0.01, r_learnable=False).to(torch_device)
        out = pool(x)
        out_avg = x.mean(dim=(2, 3))
        assert torch.allclose(out, out_avg, atol=0.1)

    def test_lse_learnable_r_gradient(self):
        """With r_learnable=True, the r parameter should receive a non-zero gradient."""
        from timm.layers import LsePlus2d

        x = torch.randn(2, 64, 7, 7, device=torch_device)
        pool = LsePlus2d(r=10.0, r_learnable=True).to(torch_device)
        out = pool(x).sum()
        out.backward()
        assert pool.r.grad is not None
        # Reduce before comparing. Truth-testing a multi-element tensor raises,
        # so the check stays valid whatever shape the r parameter has.
        assert pool.r.grad.abs().sum() > 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# SimPool Tests
|
|
|
|
|
|
|
|
|
|
class TestSimPool:
    """Checks for the SimPool attention-based pooling layers."""

    def test_simpool_2d_basic(self):
        from timm.layers import SimPool2d

        feats = torch.randn(2, 64, 7, 7, device=torch_device)
        layer = SimPool2d(dim=64).to(torch_device)
        assert layer(feats).shape == (2, 64)

    def test_simpool_1d_basic(self):
        from timm.layers import SimPool1d

        tokens = torch.randn(2, 49, 64, device=torch_device)
        layer = SimPool1d(dim=64).to(torch_device)
        assert layer(tokens).shape == (2, 64)

    def test_simpool_multi_head(self):
        from timm.layers import SimPool2d

        feats = torch.randn(2, 64, 7, 7, device=torch_device)
        for heads in (1, 2, 4, 8):
            layer = SimPool2d(dim=64, num_heads=heads).to(torch_device)
            assert layer(feats).shape == (2, 64)

    def test_simpool_with_gamma(self):
        from timm.layers import SimPool2d

        feats = torch.randn(2, 64, 7, 7, device=torch_device)
        layer = SimPool2d(dim=64, gamma=2.0).to(torch_device)
        pooled = layer(feats)
        assert pooled.shape == (2, 64)
        # The gamma path must not produce NaNs on random (signed) inputs.
        assert not pooled.isnan().any()

    def test_simpool_qk_norm(self):
        from timm.layers import SimPool2d

        feats = torch.randn(2, 64, 7, 7, device=torch_device)
        layer = SimPool2d(dim=64, qk_norm=True).to(torch_device)
        assert layer(feats).shape == (2, 64)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Common Tests (Gradient, JIT, dtype)
|
|
|
|
|
|
|
|
|
|
class TestPoolingCommon:
    """Checks shared by several pooling layers: gradients, TorchScript, determinism, input sizes."""

    # Each case is (layer name in timm.layers, constructor kwargs, input shape).
    _SCRIPT_CASES = [
        ('LsePlus2d', {}, (2, 64, 7, 7)),
        ('LsePlus1d', {}, (2, 49, 64)),
        ('SimPool2d', {'dim': 64}, (2, 64, 7, 7)),
        ('SimPool1d', {'dim': 64}, (2, 49, 64)),
        ('AttentionPool2d', {'in_features': 64, 'feat_size': 7}, (2, 64, 7, 7)),
        ('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)),
    ]

    # Gradient cases also cover SelectAdaptivePool2d and AttentionPoolLatent.
    _GRAD_CASES = [
        ('LsePlus2d', {}, (2, 64, 7, 7)),
        ('LsePlus1d', {}, (2, 49, 64)),
        ('SimPool2d', {'dim': 64}, (2, 64, 7, 7)),
        ('SimPool1d', {'dim': 64}, (2, 49, 64)),
        ('SelectAdaptivePool2d', {'pool_type': 'avg', 'flatten': True}, (2, 64, 7, 7)),
        ('AttentionPoolLatent', {'in_features': 64, 'num_heads': 4}, (2, 49, 64)),
        ('AttentionPool2d', {'in_features': 64, 'feat_size': 7}, (2, 64, 7, 7)),
        ('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)),
    ]

    # NCHW layers exercised with several spatial sizes.
    _SPATIAL_CASES = [
        ('LsePlus2d', {}, (2, 64, 7, 7)),
        ('SimPool2d', {'dim': 64}, (2, 64, 7, 7)),
        ('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)),
    ]

    @staticmethod
    def _build(pool_cls, kwargs):
        """Instantiate ``timm.layers.<pool_cls>(**kwargs)`` on the test device."""
        import timm.layers as layers

        return getattr(layers, pool_cls)(**kwargs).to(torch_device)

    @pytest.mark.parametrize('pool_cls,kwargs,input_shape', _GRAD_CASES)
    def test_gradient_flow(self, pool_cls, kwargs, input_shape):
        inputs = torch.randn(*input_shape, device=torch_device, requires_grad=True)
        layer = self._build(pool_cls, kwargs)
        layer(inputs).sum().backward()
        assert inputs.grad is not None
        assert inputs.grad.abs().sum() > 0

    @pytest.mark.parametrize('pool_cls,kwargs,input_shape', _SCRIPT_CASES)
    def test_torchscript(self, pool_cls, kwargs, input_shape):
        inputs = torch.randn(*input_shape, device=torch_device)
        layer = self._build(pool_cls, kwargs).eval()
        scripted = torch.jit.script(layer)
        eager_out = layer(inputs)
        script_out = scripted(inputs)
        assert torch.allclose(eager_out, script_out, atol=1e-5)

    @pytest.mark.parametrize('pool_cls,kwargs,input_shape', _SCRIPT_CASES)
    def test_eval_deterministic(self, pool_cls, kwargs, input_shape):
        inputs = torch.randn(*input_shape, device=torch_device)
        layer = self._build(pool_cls, kwargs).eval()
        with torch.no_grad():
            first, second = layer(inputs), layer(inputs)
        assert torch.allclose(first, second)

    @pytest.mark.parametrize('pool_cls,kwargs,input_shape', _SPATIAL_CASES)
    def test_different_spatial_sizes(self, pool_cls, kwargs, input_shape):
        batch, channels, _, _ = input_shape
        layer = self._build(pool_cls, kwargs)
        for height, width in ((7, 7), (14, 14), (1, 1), (3, 5)):
            inputs = torch.randn(batch, channels, height, width, device=torch_device)
            pooled = layer(inputs)
            assert pooled.shape[0] == batch
            assert pooled.shape[-1] == channels
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# BlurPool Tests
|
|
|
|
|
|
|
|
|
|
class TestBlurPool:
    """Checks for the BlurPool anti-aliased downsampling layer."""

    def test_blur_pool_2d_basic(self):
        from timm.layers import BlurPool2d

        feats = torch.randn(2, 64, 14, 14, device=torch_device)
        layer = BlurPool2d(channels=64).to(torch_device)
        # With default settings, each spatial dim is halved (14 -> 7).
        assert layer(feats).shape == (2, 64, 7, 7)

    def test_blur_pool_2d_stride(self):
        from timm.layers import BlurPool2d

        feats = torch.randn(2, 64, 28, 28, device=torch_device)
        layer = BlurPool2d(channels=64, stride=4).to(torch_device)
        # stride=4 on a 28x28 map is expected to give 8x8.
        assert layer(feats).shape == (2, 64, 8, 8)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Pool1d Tests
|
|
|
|
|
|
|
|
|
|
class TestPool1d:
    """Checks for the 1D (NLC token sequence) pooling utilities."""

    def test_global_pool_nlc(self):
        from timm.layers import global_pool_nlc

        tokens = torch.randn(2, 49, 64, device=torch_device)
        prefix, rest = tokens[:, 0], tokens[:, 1:]

        # By default, avg/max exclude the first token (num_prefix_tokens=1);
        # 'token' returns that first token itself.
        references = (
            ('avg', rest.mean(dim=1)),
            ('max', rest.amax(dim=1)),
            ('token', prefix),
        )
        for pool_type, expected in references:
            pooled = global_pool_nlc(tokens, pool_type=pool_type)
            assert pooled.shape == (2, 64)
            assert torch.allclose(pooled, expected)

        # With reduce_include_prefix=True, the average also covers the prefix token.
        pooled_all = global_pool_nlc(tokens, pool_type='avg', reduce_include_prefix=True)
        assert torch.allclose(pooled_all, tokens.mean(dim=1))
|