|
|
import torch
|
||
|
|
from torch.utils import benchmark
|
||
|
|
from torch import nn
|
||
|
|
|
||
|
|
from pack_weight import convert_weight_int8_to_int2
|
||
|
|
from torch.profiler import profile, record_function, ProfilerActivity
|
||
|
|
import ctypes
|
||
|
|
import numpy as np
|
||
|
|
# set all seed
|
||
|
|
torch.manual_seed(42)
|
||
|
|
np.random.seed(42)
|
||
|
|
|
||
|
|
bitnet_lib = ctypes.CDLL('bitnet_kernels/libbitnet.so')
|
||
|
|
|
||
|
|
def bitnet_int8xint2_linear(input0, input1, s, ws, ret):
    """Run the custom W2A8 kernel: ret = input0 (int8) @ input1 (packed int2)^T.

    Args:
        input0: int8 activation tensor on CUDA, shape (M, K) or (B, S, K).
        input1: int2-packed weight tensor on CUDA, shape (N, K // 4) —
            four 2-bit weights per int8 byte, hence K = input1.shape[1] * 4.
        s: activation scale tensor (bfloat16, on CUDA).
        ws: weight scale tensor (bfloat16, on CUDA).
        ret: pre-allocated bfloat16 output tensor; written in place by the kernel.

    Returns:
        ret, filled with the kernel output.
    """
    # Collapse a leading (batch, seq) pair into a single M row count; the
    # kernel only sees a 2-D (M, K) x (N, K) problem.
    M = input0.shape[0]
    if input0.dim() == 3:
        M *= input0.shape[1]
    N = input1.shape[0]
    K = input1.shape[1] * 4  # 4x int2 packed per byte

    # Launch on the current CUDA stream so the FFI call serializes correctly
    # with surrounding torch work instead of racing on the default stream.
    stream = torch.cuda.current_stream()
    bitnet_lib.bitlinear_int8xint2(
        ctypes.c_void_p(input0.data_ptr()),
        ctypes.c_void_p(input1.data_ptr()),
        ctypes.c_void_p(ret.data_ptr()),
        ctypes.c_void_p(s.data_ptr()),
        ctypes.c_void_p(ws.data_ptr()),
        ctypes.c_int(M),
        ctypes.c_int(N),
        ctypes.c_int(K),
        ctypes.c_void_p(stream.cuda_stream),
    )
    return ret
|
||
|
|
|
||
|
|
if __name__ == '__main__':
    # (N, K) weight shapes to benchmark (projection sizes of the target models).
    test_list = [
        (2560, 2560),
        (3840, 2560),
        (13824, 2560),
        (2560, 6912),
        (3200, 3200),
        (4800, 3200),
        (3200, 10240),
        (20480, 3200),
    ]

    for N, K in test_list:
        # Ternary {-1, 0, 1} weights, then packed to int2 for the custom kernel.
        weight = torch.randint(-1, 2, (N, K), dtype=torch.int8, device='cuda')
        weight_compressed = convert_weight_int8_to_int2(weight).to('cuda')

        # --- correctness: compare kernel output to a NumPy int32 reference ---
        # Increase the range to repeat the check with fresh random activations.
        for _ in range(1):
            input0 = torch.randint(-128, 127, (1, K), dtype=torch.int8, device='cuda')

            # Reference matmul in int32 to avoid overflow, then cast to bf16
            # so it is directly comparable with the kernel output.
            input_np = input0.cpu().to(torch.int32).numpy()
            weight_np = weight.cpu().to(torch.int32).T.numpy()
            out_np = np.matmul(input_np, weight_np)
            out_np = torch.tensor(out_np).cuda().to(torch.bfloat16)

            # Unit scales so the kernel result equals the unscaled reference.
            s = torch.ones(1, dtype=torch.bfloat16, device='cuda')
            ws = torch.ones(6, dtype=torch.bfloat16, device='cuda')

            ret = torch.empty((1, N), dtype=torch.bfloat16, device=input0.device)
            out = bitnet_int8xint2_linear(input0, weight_compressed, s, ws, ret)

            print(f'custom == np {torch.all(out==out_np)}')

        # --- timing: custom W2A8 kernel vs torch bf16 matmul ---
        input0 = torch.randint(-128, 127, (1, K), dtype=torch.int8, device='cuda')
        input0_bf16 = input0.to(torch.bfloat16)
        weight_bf16 = weight.to(torch.bfloat16).T
        ret = torch.empty((1, N), dtype=torch.bfloat16, device=input0.device)
        s = torch.ones(1, dtype=torch.bfloat16, device='cuda')
        ws = torch.ones(6, dtype=torch.bfloat16, device='cuda')

        # benchmark.Timer synchronizes CUDA around each measurement, so the
        # reported means are wall-clock kernel times.
        t0 = benchmark.Timer(
            stmt="bitnet_int8xint2_linear(input0, weight_compressed, s, ws, ret)",
            setup="from __main__ import input0, weight_compressed, s, ws, ret, bitnet_int8xint2_linear",
            num_threads=1,
        )
        t1 = benchmark.Timer(
            stmt="torch.matmul(input0_bf16,weight_bf16)",
            setup="from __main__ import input0_bf16, weight_bf16",
            num_threads=1,
        )

        time0 = t0.timeit(50)
        time1 = t1.timeit(50)

        print(f'Shape{N,K}, W2A8: {time0.mean * 1e6:.2f}us, torch BF16: {time1.mean * 1e6:.2f}us')