"""Tests for timm pooling layers."""
import pytest
import torch
import torch.nn as nn
import importlib
import os
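
# Optional out-of-tree backend: setting TORCH_BACKEND to an importable module
# name (e.g. a vendor extension that registers a new device type) loads it
# before any test runs; TORCH_DEVICE then selects where tensors are allocated.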
torch_backend = os.environ.get('TORCH_BACKEND')
if torch_backend is not None:
    importlib.import_module(torch_backend)
torch_device = os.environ.get('TORCH_DEVICE', 'cpu')


# Adaptive Avg/Max Pooling Tests
class TestAdaptiveAvgMaxPool:
    """Test adaptive_avgmax_pool module."""

    def test_adaptive_avgmax_pool2d(self):
        from timm.layers import adaptive_avgmax_pool2d
        x = torch.randn(2, 64, 7, 7, device=torch_device)
        out = adaptive_avgmax_pool2d(x, 1)
        assert out.shape == (2, 64, 1, 1)
        # Output should be the mean of the avg-pooled and max-pooled results
        expected = 0.5 * (x.mean(dim=(2, 3), keepdim=True) + x.amax(dim=(2, 3), keepdim=True))
        assert torch.allclose(out, expected)

    def test_select_adaptive_pool2d(self):
        from timm.layers import select_adaptive_pool2d
        x = torch.randn(2, 64, 7, 7, device=torch_device)
        out_avg = select_adaptive_pool2d(x, pool_type='avg', output_size=1)
        assert out_avg.shape == (2, 64, 1, 1)
        assert torch.allclose(out_avg, x.mean(dim=(2, 3), keepdim=True))
        out_max = select_adaptive_pool2d(x, pool_type='max', output_size=1)
        assert out_max.shape == (2, 64, 1, 1)
        assert torch.allclose(out_max, x.amax(dim=(2, 3), keepdim=True))

    def test_adaptive_avgmax_pool2d_module(self):
        from timm.layers import AdaptiveAvgMaxPool2d
        x = torch.randn(2, 64, 14, 14, device=torch_device)
        pool = AdaptiveAvgMaxPool2d(output_size=1).to(torch_device)
        out = pool(x)
        assert out.shape == (2, 64, 1, 1)

    def test_select_adaptive_pool2d_module(self):
        from timm.layers import SelectAdaptivePool2d
        x = torch.randn(2, 64, 7, 7, device=torch_device)
        for pool_type in ['avg', 'max', 'avgmax', 'catavgmax']:
            pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=True).to(torch_device)
            out = pool(x)
            if pool_type == 'catavgmax':
                assert out.shape == (2, 128)  # avg and max outputs concatenated
            else:
                assert out.shape == (2, 64)

    def test_select_adaptive_pool2d_fast(self):
        from timm.layers import SelectAdaptivePool2d
        x = torch.randn(2, 64, 7, 7, device=torch_device)
        for pool_type in ['fast', 'fastavg', 'fastmax', 'fastavgmax', 'fastcatavgmax']:
            pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=True).to(torch_device)
            out = pool(x)
            if 'cat' in pool_type:
                assert out.shape == (2, 128)
            else:
                assert out.shape == (2, 64)
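
    # Note: the 'fast*' pool types are (assuming the usual timm behavior) fused
    # variants that only support output_size=1 and reduce with plain mean()/amax()
    # rather than going through the nn.Adaptive*Pool2d ops.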


# Attention Pool Tests
class TestAttentionPool:
    """Test attention-based pooling layers."""

    def test_attention_pool_latent_basic(self):
        from timm.layers import AttentionPoolLatent
        x = torch.randn(2, 49, 64, device=torch_device)
        pool = AttentionPoolLatent(in_features=64, num_heads=4).to(torch_device)
        out = pool(x)
        assert out.shape == (2, 64)

    def test_attention_pool_latent_multi_latent(self):
        from timm.layers import AttentionPoolLatent
        x = torch.randn(2, 49, 64, device=torch_device)
        pool = AttentionPoolLatent(
            in_features=64,
            num_heads=4,
            latent_len=4,
            pool_type='avg',
        ).to(torch_device)
        out = pool(x)
        assert out.shape == (2, 64)

    def test_attention_pool2d_basic(self):
        from timm.layers import AttentionPool2d
        x = torch.randn(2, 64, 7, 7, device=torch_device)
        pool = AttentionPool2d(in_features=64, feat_size=7).to(torch_device)
        out = pool(x)
        assert out.shape == (2, 64)

    def test_attention_pool2d_different_feat_size(self):
        from timm.layers import AttentionPool2d
        # Spatial sizes other than feat_size require pos_embed interpolation
        pool = AttentionPool2d(in_features=64, feat_size=7).to(torch_device)
        for size in [7, 14]:
            x = torch.randn(2, 64, size, size, device=torch_device)
            out = pool(x)
            assert out.shape == (2, 64)

    def test_rot_attention_pool2d_basic(self):
        from timm.layers import RotAttentionPool2d
        x = torch.randn(2, 64, 7, 7, device=torch_device)
        pool = RotAttentionPool2d(in_features=64, ref_feat_size=7).to(torch_device)
        out = pool(x)
        assert out.shape == (2, 64)

    def test_rot_attention_pool2d_different_sizes(self):
        from timm.layers import RotAttentionPool2d
        pool = RotAttentionPool2d(in_features=64, ref_feat_size=7).to(torch_device)
        for size in [7, 14, 10]:
            x = torch.randn(2, 64, size, size, device=torch_device)
            out = pool(x)
            assert out.shape == (2, 64)

    def test_rot_attention_pool2d_rope_types(self):
        from timm.layers import RotAttentionPool2d
        x = torch.randn(2, 64, 7, 7, device=torch_device)
        for rope_type in ['base', 'cat', 'dinov3']:
            pool = RotAttentionPool2d(
                in_features=64,
                ref_feat_size=7,
                rope_type=rope_type,
            ).to(torch_device)
            out = pool(x)
            assert out.shape == (2, 64)


# LSE Pool Tests
class TestLsePool:
    """Test LogSumExp pooling layers."""

    def test_lse_plus_2d_basic(self):
        from timm.layers import LsePlus2d
        x = torch.randn(2, 64, 7, 7, device=torch_device)
        pool = LsePlus2d().to(torch_device)
        out = pool(x)
        # Default is flatten=True
        assert out.shape == (2, 64)

    def test_lse_plus_2d_no_flatten(self):
        from timm.layers import LsePlus2d
        x = torch.randn(2, 64, 7, 7, device=torch_device)
        pool = LsePlus2d(flatten=False).to(torch_device)
        out = pool(x)
        assert out.shape == (2, 64, 1, 1)

    def test_lse_plus_1d_basic(self):
        from timm.layers import LsePlus1d
        x = torch.randn(2, 49, 64, device=torch_device)
        pool = LsePlus1d().to(torch_device)
        out = pool(x)
        assert out.shape == (2, 64)

    def test_lse_high_r_approximates_max(self):
        from timm.layers import LsePlus2d
        x = torch.randn(2, 64, 7, 7, device=torch_device)
        pool = LsePlus2d(r=100.0, r_learnable=False).to(torch_device)
        out = pool(x)
        out_max = x.amax(dim=(2, 3))
        assert torch.allclose(out, out_max, atol=0.1)

    def test_lse_low_r_approximates_avg(self):
        from timm.layers import LsePlus2d
        x = torch.randn(2, 64, 7, 7, device=torch_device)
        pool = LsePlus2d(r=0.01, r_learnable=False).to(torch_device)
        out = pool(x)
        out_avg = x.mean(dim=(2, 3))
        assert torch.allclose(out, out_avg, atol=0.1)

    def test_lse_learnable_r_gradient(self):
        from timm.layers import LsePlus2d
        x = torch.randn(2, 64, 7, 7, device=torch_device)
        pool = LsePlus2d(r=10.0, r_learnable=True).to(torch_device)
        out = pool(x).sum()
        out.backward()
        assert pool.r.grad is not None
        assert pool.r.grad.abs() > 0


# SimPool Tests
class TestSimPool:
    """Test SimPool attention-based pooling layers."""

    def test_simpool_2d_basic(self):
        from timm.layers import SimPool2d
        x = torch.randn(2, 64, 7, 7, device=torch_device)
        pool = SimPool2d(dim=64).to(torch_device)
        out = pool(x)
        assert out.shape == (2, 64)

    def test_simpool_1d_basic(self):
        from timm.layers import SimPool1d
        x = torch.randn(2, 49, 64, device=torch_device)
        pool = SimPool1d(dim=64).to(torch_device)
        out = pool(x)
        assert out.shape == (2, 64)

    def test_simpool_multi_head(self):
        from timm.layers import SimPool2d
        x = torch.randn(2, 64, 7, 7, device=torch_device)
        for num_heads in [1, 2, 4, 8]:
            pool = SimPool2d(dim=64, num_heads=num_heads).to(torch_device)
            out = pool(x)
            assert out.shape == (2, 64)

    def test_simpool_with_gamma(self):
        from timm.layers import SimPool2d
        x = torch.randn(2, 64, 7, 7, device=torch_device)
        pool = SimPool2d(dim=64, gamma=2.0).to(torch_device)
        out = pool(x)
        assert out.shape == (2, 64)
        assert not torch.isnan(out).any()

    def test_simpool_qk_norm(self):
        from timm.layers import SimPool2d
        x = torch.randn(2, 64, 7, 7, device=torch_device)
        pool = SimPool2d(dim=64, qk_norm=True).to(torch_device)
        out = pool(x)
        assert out.shape == (2, 64)


# Common Tests (gradient flow, TorchScript, determinism, input sizes)
class TestPoolingCommon:
    """Common tests across all pooling layers."""

    @pytest.mark.parametrize('pool_cls,kwargs,input_shape', [
        ('LsePlus2d', {}, (2, 64, 7, 7)),
        ('LsePlus1d', {}, (2, 49, 64)),
        ('SimPool2d', {'dim': 64}, (2, 64, 7, 7)),
        ('SimPool1d', {'dim': 64}, (2, 49, 64)),
        ('SelectAdaptivePool2d', {'pool_type': 'avg', 'flatten': True}, (2, 64, 7, 7)),
        ('AttentionPoolLatent', {'in_features': 64, 'num_heads': 4}, (2, 49, 64)),
        ('AttentionPool2d', {'in_features': 64, 'feat_size': 7}, (2, 64, 7, 7)),
        ('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)),
    ])
    def test_gradient_flow(self, pool_cls, kwargs, input_shape):
        import timm.layers as layers
        x = torch.randn(*input_shape, device=torch_device, requires_grad=True)
        pool = getattr(layers, pool_cls)(**kwargs).to(torch_device)
        out = pool(x)
        loss = out.sum()
        loss.backward()
        assert x.grad is not None
        assert x.grad.abs().sum() > 0

    @pytest.mark.parametrize('pool_cls,kwargs,input_shape', [
        ('LsePlus2d', {}, (2, 64, 7, 7)),
        ('LsePlus1d', {}, (2, 49, 64)),
        ('SimPool2d', {'dim': 64}, (2, 64, 7, 7)),
        ('SimPool1d', {'dim': 64}, (2, 49, 64)),
        ('AttentionPool2d', {'in_features': 64, 'feat_size': 7}, (2, 64, 7, 7)),
        ('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)),
    ])
    def test_torchscript(self, pool_cls, kwargs, input_shape):
        import timm.layers as layers
        x = torch.randn(*input_shape, device=torch_device)
        pool = getattr(layers, pool_cls)(**kwargs).to(torch_device)
        pool.eval()
        scripted = torch.jit.script(pool)
        out_orig = pool(x)
        out_script = scripted(x)
        assert torch.allclose(out_orig, out_script, atol=1e-5)

    @pytest.mark.parametrize('pool_cls,kwargs,input_shape', [
        ('LsePlus2d', {}, (2, 64, 7, 7)),
        ('LsePlus1d', {}, (2, 49, 64)),
        ('SimPool2d', {'dim': 64}, (2, 64, 7, 7)),
        ('SimPool1d', {'dim': 64}, (2, 49, 64)),
        ('AttentionPool2d', {'in_features': 64, 'feat_size': 7}, (2, 64, 7, 7)),
        ('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)),
    ])
    def test_eval_deterministic(self, pool_cls, kwargs, input_shape):
        import timm.layers as layers
        x = torch.randn(*input_shape, device=torch_device)
        pool = getattr(layers, pool_cls)(**kwargs).to(torch_device)
        pool.eval()
        with torch.no_grad():
            out1 = pool(x)
            out2 = pool(x)
        assert torch.allclose(out1, out2)

    @pytest.mark.parametrize('pool_cls,kwargs,input_shape', [
        ('LsePlus2d', {}, (2, 64, 7, 7)),
        ('SimPool2d', {'dim': 64}, (2, 64, 7, 7)),
        ('RotAttentionPool2d', {'in_features': 64, 'ref_feat_size': 7}, (2, 64, 7, 7)),
    ])
    def test_different_spatial_sizes(self, pool_cls, kwargs, input_shape):
        import timm.layers as layers
        B, C, _, _ = input_shape
        pool = getattr(layers, pool_cls)(**kwargs).to(torch_device)
        for H, W in [(7, 7), (14, 14), (1, 1), (3, 5)]:
            x = torch.randn(B, C, H, W, device=torch_device)
            out = pool(x)
            assert out.shape[0] == B
            assert out.shape[-1] == C


# BlurPool Tests
class TestBlurPool:
    """Test BlurPool anti-aliasing layer."""

    def test_blur_pool_2d_basic(self):
        from timm.layers import BlurPool2d
        x = torch.randn(2, 64, 14, 14, device=torch_device)
        pool = BlurPool2d(channels=64).to(torch_device)
        out = pool(x)
        assert out.shape == (2, 64, 7, 7)

    def test_blur_pool_2d_stride(self):
        from timm.layers import BlurPool2d
        x = torch.randn(2, 64, 28, 28, device=torch_device)
        pool = BlurPool2d(channels=64, stride=4).to(torch_device)
        out = pool(x)
        assert out.shape == (2, 64, 8, 8)
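        # Size check, assuming the usual conv arithmetic with timm's default
        # 'same'-style padding p = ((stride - 1) + (filt_size - 1)) // 2 = 2:
        # floor((28 + 2 * 2 - 3) / 4) + 1 = 8, hence the 8x8 assertion above.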


# Pool1d Tests
class TestPool1d:
    """Test 1D pooling utilities."""

    def test_global_pool_nlc(self):
        from timm.layers import global_pool_nlc
        x = torch.randn(2, 49, 64, device=torch_device)
        # By default, avg/max excludes the first token (num_prefix_tokens=1)
        out_avg = global_pool_nlc(x, pool_type='avg')
        assert out_avg.shape == (2, 64)
        assert torch.allclose(out_avg, x[:, 1:].mean(dim=1))
        out_max = global_pool_nlc(x, pool_type='max')
        assert out_max.shape == (2, 64)
        assert torch.allclose(out_max, x[:, 1:].amax(dim=1))
        out_first = global_pool_nlc(x, pool_type='token')
        assert out_first.shape == (2, 64)
        assert torch.allclose(out_first, x[:, 0])
        # With reduce_include_prefix=True the prefix token joins the reduction
        out_avg_all = global_pool_nlc(x, pool_type='avg', reduce_include_prefix=True)
        assert torch.allclose(out_avg_all, x.mean(dim=1))