microsoft / BitNet

Official inference framework for 1-bit LLMs

import torch
import numpy as np

def B_global_16x32_to_shared_load_16x32_layout(i, j):
    """For position (i, j) of a permuted 16x32 tile, return the (row, col) of
    the source element in the original (global-layout) tile.

    Python counterpart of the CUDA-side shared-load index expression:

        stride * 8 * (tx // HALF_WARP_expr)
        + (tx % 8) * stride
        + 16 * ((tx % HALF_WARP_expr) // 8)
    """
    thread_id = i * 2 + j // 16
    row = (thread_id // 16) * 8 + (thread_id % 8)
    col = (j % 16) + 16 * ((thread_id % 16) // 8)
    return row, col
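
# Illustrative sanity check (an addition for exposition, not upstream code).
# The mapping is a pure function, so positions can be traced by hand: for
# (i=4, j=0), thread_id = 4*2 + 0//16 = 8, row = (8//16)*8 + 8%8 = 0, and
# col = 0%16 + 16*((8%16)//8) = 16, i.e. position (4, 0) is served by source
# element (0, 16). The mapping is also a bijection on the tile.
if __name__ == "__main__":
    assert B_global_16x32_to_shared_load_16x32_layout(0, 0) == (0, 0)
    assert B_global_16x32_to_shared_load_16x32_layout(4, 0) == (0, 16)
    sources = {B_global_16x32_to_shared_load_16x32_layout(i, j)
               for i in range(16) for j in range(32)}
    assert len(sources) == 16 * 32  # every source element used exactly once
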
def permutate_weight_fastest(weight):
    """Permute an (N, K) weight matrix into (N//16, K//32, 16, 32) tiles,
    shuffling each tile according to the 16x32 shared-load layout above."""
    wmma_n = 16
    wmma_k = 32
    N = weight.shape[0]
    K = weight.shape[1]
    # Build a (16, 32, 2) lookup table of source (row, col) pairs for one tile.
    mapping = np.zeros((wmma_n, wmma_k, 2), dtype=int)
    for ii in range(wmma_n):
        for jj in range(wmma_k):
            mapping[ii, jj] = B_global_16x32_to_shared_load_16x32_layout(ii, jj)
    # Broadcast the per-tile mapping across all (N//16, K//32) tiles.
    i_indices = np.arange(N // wmma_n)[:, np.newaxis, np.newaxis, np.newaxis]
    j_indices = np.arange(K // wmma_k)[np.newaxis, :, np.newaxis, np.newaxis]
    # Absolute source indices for every destination element.
    src_i = i_indices * wmma_n + mapping[:, :, 0]
    src_j = j_indices * wmma_k + mapping[:, :, 1]
    # A single advanced-indexing gather produces the whole permuted array.
    permutated_weight = weight[src_i, src_j]
    return permutated_weight
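
# Hedged usage sketch (an addition): permute a single (16, 32) tile filled
# with distinct values and confirm the result is a pure rearrangement, with
# each destination (i, j) holding the source element named by the layout
# function.
if __name__ == "__main__":
    w = np.arange(16 * 32, dtype=np.int16).reshape(16, 32)
    p = permutate_weight_fastest(w)
    assert p.shape == (1, 1, 16, 32)
    assert np.array_equal(np.sort(p.ravel()), np.sort(w.ravel()))
    assert p[0, 0, 4, 0] == w[0, 16]  # matches the worked example above
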
def compress_int2_to_int8(int2_weight):
    """Pack four 2-bit values along the last axis into one int8 byte;
    element k of each group occupies bits [2k, 2k+1] (little end first)."""
    int8_weight = np.zeros(
        (*int2_weight.shape[:-1], int2_weight.shape[-1] // 4), dtype=np.int8
    )
    for j in range(int2_weight.shape[-1] // 4):
        for k in range(4):
            int8_weight[:, :, :, j] |= int2_weight[:, :, :, j * 4 + k] << (k * 2)
    return int8_weight
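
# Worked example (an addition): the four 2-bit values [1, 2, 3, 0] pack into
# the single byte 0b00_11_10_01 = 57, element k landing at bits [2k, 2k+1].
if __name__ == "__main__":
    group = np.array([1, 2, 3, 0], dtype=np.int8).reshape(1, 1, 1, 4)
    assert compress_int2_to_int8(group).item() == 0b00111001  # 57
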
def interleave_weight_int8(qweight, nbits=2):
    """Interleave the 2-bit elements within each 32-bit word: the element at
    2-bit offset o moves to bit position (o % 4) * 8 + (o // 4) * 2.

        shift: [0, 8, 16, 24, 2, 10, 18, 26, 4, 12, 20, 28, 6, 14, 22, 30]
        index: [0, 4,  8, 12, 1,  5,  9, 13, 2,  6, 10, 14, 3,  7, 11, 15]
    """
    # Reinterpret groups of four int8 bytes as single int32 words.
    qweight = qweight.view(np.int32)
    new_qweight = np.zeros_like(qweight)
    bits_stride = 8
    mask = (1 << nbits) - 1  # 0x3 for 2-bit (0x0000000f for 4-bit)
    num_groups = 32 // bits_stride  # 4
    elems_per_group = bits_stride // nbits  # 4
    for i in range(num_groups):
        for j in range(elems_per_group):
            offset = i * elems_per_group + j
            shift = (offset % num_groups) * bits_stride + (offset // num_groups) * nbits
            new_qweight |= ((qweight >> (nbits * offset)) & mask) << shift
    return new_qweight.view(np.int8)
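
# Worked example (an addition; assumes a little-endian host so the
# int8 <-> int32 views compose as shown): the 2-bit element at offset 1
# (bits 2-3 of the input word) moves to bit position (1 % 4)*8 + (1 // 4)*2 = 8,
# i.e. into byte 1 of the output.
if __name__ == "__main__":
    raw = np.array([0b1100, 0, 0, 0], dtype=np.int8)  # element 1 holds value 3
    out = interleave_weight_int8(raw, nbits=2)
    assert list(out) == [0, 3, 0, 0]
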
def convert_weight_int8_to_int2(weight):
    """Convert an (N, K) low-bit weight tensor into the permuted, packed, and
    interleaved 2-bit format, returned as an int8 tensor of shape (N, K // 4)."""
    N = weight.shape[0]
    K = weight.shape[1]
    # Offset the values into the non-negative 2-bit range (for ternary
    # {-1, 0, +1} weights this yields the codes {1, 2, 3}).
    weight = weight + 2
    weight = weight.cpu().numpy()
    # Tile-permute, pack four 2-bit codes per byte, then interleave the words.
    permutated_weight = permutate_weight_fastest(weight)
    compressed_weight = compress_int2_to_int8(permutated_weight)
    interleaved_weight = interleave_weight_int8(compressed_weight, 2)
    ret = torch.from_numpy(interleaved_weight)
    ret = torch.reshape(ret, (N, K // 4))
    return ret
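
# End-to-end sketch (an addition): pack a random ternary {-1, 0, +1} weight
# matrix into the interleaved 2-bit layout. Four 2-bit codes share each output
# byte, so the packed tensor is K // 4 columns wide.
if __name__ == "__main__":
    w = torch.randint(-1, 2, (16, 32), dtype=torch.int8)
    packed = convert_weight_int8_to_int2(w)
    assert packed.dtype == torch.int8 and packed.shape == (16, 32 // 4)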