import torch
import numpy as np


def B_global_16x32_to_shared_load_16x32_layout(i, j):
    """Map element (i, j) of a 16x32 tile to its (row, col) in the
    permuted shared-memory load layout.

    Mirrors the CUDA thread-index expression:
        stride * 8 * (tx // HALF_WARP) + (tx % 8) * stride
        + 16 * ((tx % HALF_WARP) // 8)
    """
    # Two logical threads cover each row: each owns 16 consecutive columns.
    thread_id = i * 2 + j // 16
    row = (thread_id // 16) * 8 + (thread_id % 8)
    col = (j % 16) + 16 * ((thread_id % 16) // 8)
    return row, col


def permutate_weight_fastest(weight):
    """Permute an (N, K) weight matrix into tiles of shape
    (N // 16, K // 32, 16, 32), reordering each tile with
    ``B_global_16x32_to_shared_load_16x32_layout``.

    Raises ValueError if N is not divisible by 16 or K by 32.
    """
    wmma_n = 16
    wmma_k = 32
    N, K = weight.shape
    if N % wmma_n or K % wmma_k:
        raise ValueError(
            f"weight shape ({N}, {K}) must be divisible by ({wmma_n}, {wmma_k})"
        )

    # Precompute the intra-tile permutation once as a lookup table.
    mapping = np.zeros((wmma_n, wmma_k, 2), dtype=int)
    for ii in range(wmma_n):
        for jj in range(wmma_k):
            mapping[ii, jj] = B_global_16x32_to_shared_load_16x32_layout(ii, jj)

    # Broadcast the tile offsets against the per-tile mapping so a single
    # fancy-indexing gather builds the whole permuted array:
    #   (N//16, 1, 1, 1) and (1, K//32, 1, 1) vs (16, 32) -> (N//16, K//32, 16, 32)
    i_indices = np.arange(N // wmma_n)[:, np.newaxis, np.newaxis, np.newaxis]
    j_indices = np.arange(K // wmma_k)[np.newaxis, :, np.newaxis, np.newaxis]
    src_i = i_indices * wmma_n + mapping[:, :, 0]
    src_j = j_indices * wmma_k + mapping[:, :, 1]
    return weight[src_i, src_j]


def compress_int2_to_int8(int2_weight):
    """Pack groups of four 2-bit values (one per int8 element, low bits)
    into single int8 bytes, little-endian within each byte.

    ``int2_weight`` is a 4-D array whose last axis is divisible by 4;
    the result's last axis is 4x smaller.
    """
    int8_weight = np.zeros(
        (*int2_weight.shape[:-1], int2_weight.shape[-1] // 4), dtype=np.int8
    )
    for j in range(int2_weight.shape[-1] // 4):
        for k in range(4):
            # Element k of the group lands in bits [2k, 2k+2) of byte j.
            int8_weight[:, :, :, j] |= int2_weight[:, :, :, j * 4 + k] << (k * 2)
    return int8_weight


def interleave_weight_int8(qweight, nbits=2):
    """Interleave packed sub-byte elements inside each 32-bit word so the
    kernel can extract them with a fixed shift pattern.

    For nbits=2 the 16 elements of a word are redistributed as:
        shift: [0, 8, 16, 24, 2, 10, 18, 26, 4, 12, 20, 28, 6, 14, 22, 30]
        index: [0, 4,  8, 12, 1,  5,  9, 13, 2,  6, 10, 14, 3,  7, 11, 15]

    ``qweight`` is int8 with a contiguous last axis whose byte count is
    divisible by 4 (so it can be reinterpreted as int32 words).
    """
    # Reinterpret groups of 4 bytes as int32 words (no copy).
    qweight = qweight.view(np.int32)
    new_qweight = np.zeros_like(qweight)
    bits_stride = 8
    mask = (1 << nbits) - 1  # 0x3 for 2-bit; would be 0xF for 4-bit
    num_groups = 32 // bits_stride  # 4 byte lanes per 32-bit word
    elems_per_group = bits_stride // nbits  # elements packed per byte lane
    for i in range(num_groups):
        for j in range(elems_per_group):
            offset = i * elems_per_group + j
            # Spread consecutive elements across byte lanes first, then
            # across bit positions within a lane.
            shift = (offset % num_groups) * bits_stride + (offset // num_groups) * nbits
            new_qweight |= ((qweight >> (nbits * offset)) & mask) << shift
    return new_qweight.view(np.int8)


def convert_weight_int8_to_int2(weight):
    """Convert an (N, K) torch int tensor of 2-bit values (presumably in
    [-2, 1], so that the +2 bias maps them into [0, 3] — confirm against
    the quantizer) into the packed, permuted, interleaved (N, K // 4)
    int8 layout consumed by the int2 GEMM kernel.
    """
    N = weight.shape[0]
    K = weight.shape[1]
    # Bias into the unsigned 2-bit range before packing.
    weight = (weight + 2).cpu().numpy()
    permutated_weight = permutate_weight_fastest(weight)
    compressed_weight = compress_int2_to_int8(permutated_weight)
    interleaved_weight = interleave_weight_int8(compressed_weight, 2)
    ret = torch.from_numpy(interleaved_weight)
    return torch.reshape(ret, (N, K // 4))