#!/usr/bin/env python3
"""
Perplexity Test Script (microsoft/BitNet, the official inference framework for
1-bit LLMs). Tests GGUF model perplexity on multiple datasets using
llama-perplexity, optionally sweeping token-embedding quantization types.
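
Example usage (script and model paths are illustrative):
    python ppl_test.py -m models/ggml-model-i2_s.gguf -d data
    python ppl_test.py -m models/ggml-model-i2_s.gguf --quick
    python ppl_test.py -m models/ggml-model-i2_s.gguf --test-embeddings --csv-output results/embed_sweep.csv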
"""
import os
import subprocess
import time
import csv
import re
from datetime import datetime
from pathlib import Path
import argparse
import tempfile
import sys
class PerplexityTester:
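"""Run llama-perplexity over every dataset under data_dir and collect results.
With test_embeddings=True, the base model is first re-quantized once per token
embedding type in self.embedding_types, and every dataset is tested per type.
"""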
def __init__(self, model_path, llama_perplexity_bin="../build/bin/llama-perplexity",
data_dir="../data", output_dir="perplexity_results", quick_mode=False,
quantize_bin="../build/bin/llama-quantize", test_embeddings=False, csv_output=None):
self.model_path = Path(model_path)
self.llama_perplexity_bin = Path(llama_perplexity_bin)
self.quantize_bin = Path(quantize_bin)
self.data_dir = Path(data_dir)
self.output_dir = Path(output_dir)
self.quick_mode = quick_mode
self.test_embeddings = test_embeddings
self.csv_output = Path(csv_output) if csv_output else None
self.results = []
self.created_models = set() # Track newly created model files
self.temp_files = [] # Track temporary files for cleanup
# Embedding types to test
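# Each entry is (llama-quantize --token-embedding-type value, output filename
# suffix), ordered roughly from highest to lowest embedding precision.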
self.embedding_types = [
('F32', 'f32'),
('F16', 'f16'),
('Q8_0', 'q8_0'),
('Q6_K', 'q6_k'),
('Q5_0', 'q5_0'),
('Q4_0', 'q4_0'),
('Q3_K', 'q3_k'),
('TQ2_0', 'tq2_0'),
]
# Create output directory
self.output_dir.mkdir(parents=True, exist_ok=True)
# Verify llama-perplexity binary exists
if not self.llama_perplexity_bin.exists():
raise FileNotFoundError(f"llama-perplexity binary not found: {self.llama_perplexity_bin}")
# Verify quantize binary exists if testing embeddings
if self.test_embeddings and not self.quantize_bin.exists():
raise FileNotFoundError(f"llama-quantize binary not found: {self.quantize_bin}")
# Verify model file exists
if not self.model_path.exists():
raise FileNotFoundError(f"Model file not found: {self.model_path}")
def find_datasets(self):
"""Find all test.txt files in dataset directories."""
datasets = []
if not self.data_dir.exists():
print(f"❌ Data directory not found: {self.data_dir}")
return datasets
print(f"\n🔍 Searching for datasets in {self.data_dir}...")
# Look for test.txt files in subdirectories
for dataset_dir in sorted(self.data_dir.iterdir()):
if dataset_dir.is_dir():
test_file = dataset_dir / "test.txt"
if test_file.exists():
size_mb = test_file.stat().st_size / (1024 * 1024)
datasets.append({
'name': dataset_dir.name,
'path': test_file,
'size': test_file.stat().st_size,
'size_mb': size_mb
})
print(f"{dataset_dir.name:<20} ({size_mb:.2f} MB)")
else:
print(f" ⚠️ {dataset_dir.name:<20} (no test.txt found)")
return datasets
def create_quick_dataset(self, dataset_path, num_chars=4096):
"""Create a temporary dataset with only the first N characters for quick testing."""
temp_file = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt', encoding='utf-8')
self.temp_files.append(temp_file.name)
try:
with open(dataset_path, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read(num_chars)
temp_file.write(content)
temp_file.close()
return Path(temp_file.name)
except Exception as e:
print(f"⚠️ Failed to create quick dataset: {e}")
temp_file.close()
return dataset_path
def cleanup_temp_files(self):
"""Clean up temporary files."""
for temp_file in self.temp_files:
try:
os.unlink(temp_file)
except OSError:
pass
self.temp_files = []
def run_perplexity_test(self, dataset_name, dataset_path, threads=16, ctx_size=512, model_override=None):
"""Run perplexity test on a single dataset."""
test_model = model_override if model_override else self.model_path
print(f"\n{'='*80}")
print(f"📊 Testing on dataset: {dataset_name}")
print(f" File: {dataset_path}")
print(f" Model: {test_model.name}")
print(f"{'='*80}")
cmd = [
str(self.llama_perplexity_bin),
"-m", str(test_model),
"-f", str(dataset_path),
"-t", str(threads),
"-c", str(ctx_size),
"-ngl", "0" # CPU only
]
print(f"💻 Command: {' '.join(cmd)}")
print(f"⏱️ Starting test...\n")
start_time = time.time()
try:
result = subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=3600, # 1 hour timeout
cwd=os.getcwd()
)
elapsed_time = time.time() - start_time
if result.returncode == 0:
# Parse perplexity from output (check both stdout and stderr)
combined_output = result.stdout + "\n" + result.stderr
ppl = self.parse_perplexity(combined_output)
if ppl is not None:
print(f"\n✅ Perplexity: {ppl}")
print(f"⏱️ Time: {elapsed_time:.2f}s ({elapsed_time/60:.2f} min)")
status = "success"
else:
print(f"\n⚠️ Test completed but could not parse perplexity")
print(f"Last 500 chars of stdout:")
print(result.stdout[-500:])
print(f"Last 500 chars of stderr:")
print(result.stderr[-500:])
status = "parse_error"
ppl = None
else:
print(f"\n❌ Test failed with return code {result.returncode}")
print(f"Error: {result.stderr[:500]}")
status = "failed"
ppl = None
return {
'dataset': dataset_name,
'perplexity': ppl,
'time': elapsed_time,
'status': status,
'stdout': result.stdout,
'stderr': result.stderr
}
except subprocess.TimeoutExpired:
elapsed_time = time.time() - start_time
print(f"\n❌ Timeout after {elapsed_time:.2f}s")
return {
'dataset': dataset_name,
'perplexity': None,
'time': elapsed_time,
'status': 'timeout',
'stdout': '',
'stderr': 'Test exceeded 1 hour timeout'
}
except Exception as e:
elapsed_time = time.time() - start_time
print(f"\n❌ Error: {e}")
return {
'dataset': dataset_name,
'perplexity': None,
'time': elapsed_time,
'status': 'error',
'stdout': '',
'stderr': str(e)
}
def parse_perplexity(self, output):
"""Parse perplexity value (mean±std format) from llama-perplexity output."""
# First try to match "PPL = mean +/- std" format
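# Typical summary line emitted by llama-perplexity (values illustrative):
#   Final estimate: PPL = 11.2345 +/- 0.07890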
pattern_with_std = r'PPL\s*=\s*(\d+\.?\d*)\s*\+/-\s*(\d+\.?\d*)'
match = re.search(pattern_with_std, output, re.IGNORECASE | re.MULTILINE)
if match:
try:
mean = float(match.group(1))
std = float(match.group(2))
return f"{mean:.4f}±{std:.4f}"
except ValueError:
pass
# Fallback to patterns without std
patterns = [
r'Final estimate:\s*PPL\s*=\s*(\d+\.?\d*)',
r'Final perplexity:\s*(\d+\.?\d*)',
r'PPL\s*=\s*(\d+\.?\d*)',
r'PPL:\s*(\d+\.?\d*)',
r'perplexity:\s*(\d+\.?\d*)',
r'ppl\s*=\s*(\d+\.?\d*)',
r'Perplexity:\s*(\d+\.?\d*)',
]
for pattern in patterns:
match = re.search(pattern, output, re.IGNORECASE | re.MULTILINE)
if match:
try:
return f"{float(match.group(1)):.4f}"
except ValueError:
continue
return None
def quantize_embedding(self, embedding_type, output_suffix):
"""
Quantize model with specific embedding type.
Args:
embedding_type: Token embedding type (uppercase, e.g., 'Q6_K')
output_suffix: Output file suffix (lowercase, e.g., 'q6_k')
Returns:
Path to quantized model or None if failed
"""
# Construct output path
model_dir = self.model_path.parent
output_path = model_dir / f"ggml-model-i2_s-embed-{output_suffix}.gguf"
# Check if file already exists
file_existed = output_path.exists()
if file_existed:
print(f" Model already exists: {output_path.name}")
return output_path
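# Note: the trailing "1", "1" positional args follow this fork's llama-quantize
# CLI; upstream llama.cpp only takes an optional [nthreads] after the type.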
cmd = [
str(self.quantize_bin),
"--token-embedding-type", embedding_type,
str(self.model_path),
str(output_path),
"I2_S",
"1",
"1"
]
print(f"\n{'='*80}")
print(f"🔄 Quantizing with embedding type: {embedding_type}")
print(f"📥 Input: {self.model_path.name}")
print(f"📤 Output: {output_path.name}")
print(f"💻 Command: {' '.join(cmd)}")
print(f"{'='*80}\n")
start_time = time.time()
try:
result = subprocess.run(
cmd,
capture_output=True,
text=True,
cwd=os.getcwd(),
timeout=600 # 10 minutes timeout
)
duration = time.time() - start_time
if result.returncode == 0:
file_size_mb = output_path.stat().st_size / (1024 * 1024)
print(f"✅ Quantization successful!")
print(f" Duration: {duration:.2f}s")
print(f" Size: {file_size_mb:.2f} MB")
# Mark as newly created
self.created_models.add(output_path)
return output_path
else:
print(f"❌ Quantization failed with return code {result.returncode}")
print(f"Error: {result.stderr[:500]}")
return None
except subprocess.TimeoutExpired:
print(f"❌ Quantization timeout (exceeded 10 minutes)")
return None
except Exception as e:
print(f"❌ Quantization error: {e}")
return None
def cleanup_model(self, model_path):
"""Delete model file if it was created during this session."""
if model_path in self.created_models:
try:
model_path.unlink()
print(f"🗑️ Deleted: {model_path.name}")
self.created_models.remove(model_path)
except Exception as e:
print(f"⚠️ Failed to delete {model_path.name}: {e}")
else:
print(f" Keeping existing file: {model_path.name}")
def run_all_tests(self, threads=16, ctx_size=512):
"""Run perplexity tests on all datasets."""
datasets = self.find_datasets()
if not datasets:
print(f"\n❌ No datasets found in {self.data_dir}")
print(f" Make sure each dataset directory has a test.txt file")
return
# Quick mode: test all datasets but only first 4096 chars with smaller context
if self.quick_mode:
ctx_size = min(ctx_size, 128) # Use smaller context in quick mode
print(f"\n⚡ QUICK TEST MODE ENABLED")
print(f" - Testing all datasets with first 4096 characters only")
print(f" - Using reduced context size: {ctx_size}")
# Determine models to test
if self.test_embeddings:
print(f"\n{'='*80}")
print(f"🧪 EMBEDDING QUANTIZATION TEST MODE")
print(f"{'='*80}")
print(f"📦 Base model: {self.model_path.name}")
print(f"🔢 Embedding types to test: {len(self.embedding_types)}")
print(f"📊 Datasets: {len(datasets)}")
print(f"🧵 Threads: {threads}")
print(f"📏 Context size: {ctx_size}")
print(f"{'='*80}")
total_start = time.time()
# Test each embedding type
for i, (embedding_type, output_suffix) in enumerate(self.embedding_types, 1):
print(f"\n\n{'#'*80}")
print(f"[{i}/{len(self.embedding_types)}] Testing embedding type: {output_suffix} ({embedding_type})")
print(f"{'#'*80}")
# Quantize model
quantized_model = self.quantize_embedding(embedding_type, output_suffix)
if quantized_model is None:
print(f"⚠️ Skipping tests for {output_suffix} due to quantization failure")
continue
# Test on all datasets
for j, dataset in enumerate(datasets, 1):
print(f"\n[{j}/{len(datasets)}] Testing {dataset['name']} with {output_suffix}...")
# Use quick dataset if in quick mode
test_path = dataset['path']
if self.quick_mode:
test_path = self.create_quick_dataset(dataset['path'])
result = self.run_perplexity_test(
f"{dataset['name']}_embed-{output_suffix}",
test_path,
threads,
ctx_size,
model_override=quantized_model
)
self.results.append(result)
# Cleanup model after testing
print(f"\n🧹 Cleaning up {output_suffix} model...")
self.cleanup_model(quantized_model)
print(f"\n{'#'*80}")
print(f"✅ Completed {output_suffix}")
print(f"{'#'*80}")
total_time = time.time() - total_start
else:
# Regular single model test
print(f"\n{'='*80}")
print(f"🚀 PERPLEXITY TEST SESSION{' (QUICK MODE)' if self.quick_mode else ''}")
print(f"{'='*80}")
print(f"📦 Model: {self.model_path.name}")
print(f"📁 Model path: {self.model_path}")
print(f"📊 Datasets {'to test' if self.quick_mode else 'found'}: {len(datasets)}")
print(f"🧵 Threads: {threads}")
print(f"📏 Context size: {ctx_size}")
print(f"{'='*80}")
total_start = time.time()
# Run tests
for i, dataset in enumerate(datasets, 1):
print(f"\n\n[{i}/{len(datasets)}] Processing {dataset['name']}...")
# Use quick dataset if in quick mode
test_path = dataset['path']
if self.quick_mode:
test_path = self.create_quick_dataset(dataset['path'])
result = self.run_perplexity_test(
dataset['name'],
test_path,
threads,
ctx_size
)
self.results.append(result)
total_time = time.time() - total_start
# Clean up temporary files
if self.quick_mode:
print(f"\n🧹 Cleaning up temporary files...")
self.cleanup_temp_files()
# Save results
self.save_results()
# Print summary
self.print_summary(total_time)
def save_results(self):
"""Save results to CSV file."""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_name = self.model_path.stem
# Use custom CSV path if provided
if self.csv_output:
csv_file = self.csv_output
# Create parent directory if needed
csv_file.parent.mkdir(parents=True, exist_ok=True)
else:
csv_file = self.output_dir / f"ppl_{model_name}_{timestamp}.csv"
print(f"\n💾 Saving results...")
with open(csv_file, 'w', newline='', encoding='utf-8') as f:
writer = csv.DictWriter(f, fieldnames=['dataset', 'perplexity', 'time_seconds', 'status'])
writer.writeheader()
for result in self.results:
writer.writerow({
'dataset': result['dataset'],
'perplexity': result['perplexity'] if result['perplexity'] is not None else 'N/A',
'time_seconds': f"{result['time']:.2f}",
'status': result['status']
})
print(f" ✅ CSV saved: {csv_file}")
# Save detailed log
log_file = self.output_dir / f"ppl_{model_name}_{timestamp}.log"
with open(log_file, 'w', encoding='utf-8') as f:
f.write(f"Perplexity Test Results\n")
f.write(f"{'='*80}\n")
f.write(f"Model: {self.model_path}\n")
f.write(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
f.write(f"{'='*80}\n\n")
for result in self.results:
f.write(f"\n{'='*80}\n")
f.write(f"Dataset: {result['dataset']}\n")
f.write(f"Perplexity: {result['perplexity']}\n")
f.write(f"Time: {result['time']:.2f}s\n")
f.write(f"Status: {result['status']}\n")
f.write(f"\nOutput:\n{result['stdout']}\n")
if result['stderr']:
f.write(f"\nErrors:\n{result['stderr']}\n")
print(f" ✅ Log saved: {log_file}")
def print_summary(self, total_time):
"""Print summary of all tests."""
print(f"\n\n{'='*80}")
print(f"📊 TEST SUMMARY")
print(f"{'='*80}\n")
# Sort results by perplexity (lower is better)
successful = [r for r in self.results if r['perplexity'] is not None]
failed = [r for r in self.results if r['perplexity'] is None]
if successful:
# Extract numeric value from "mean±std" format for sorting
def get_ppl_value(result):
ppl = result['perplexity']
if isinstance(ppl, str) and '±' in ppl:
return float(ppl.split('±')[0])
elif isinstance(ppl, str):
try:
return float(ppl)
except ValueError:
return float('inf')
return ppl
successful_sorted = sorted(successful, key=get_ppl_value)
print(f"{'Dataset':<20} {'Perplexity':>20} {'Time (s)':>12} {'Status':<15}")
print(f"{'-'*80}")
for result in successful_sorted:
ppl_str = str(result['perplexity']) if result['perplexity'] is not None else 'N/A'
print(f"{result['dataset']:<20} {ppl_str:>20} "
f"{result['time']:>12.2f} {result['status']:<15}")
best_ppl = str(successful_sorted[0]['perplexity'])
print(f"\n🏆 Best result: {successful_sorted[0]['dataset']} "
f"(PPL: {best_ppl})")
if failed:
print(f"\n❌ Failed tests ({len(failed)}):")
for result in failed:
print(f" - {result['dataset']}: {result['status']}")
print(f"\n{'='*80}")
print(f"✅ Completed: {len(successful)}/{len(self.results)}")
print(f"⏱️ Total time: {total_time:.2f}s ({total_time/60:.2f} min)")
print(f"📁 Results saved in: {self.output_dir}")
print(f"{'='*80}\n")
def main():
parser = argparse.ArgumentParser(description='Test model perplexity on multiple datasets')
parser.add_argument('--model', '-m',
required=True,
help='Path to GGUF model file')
parser.add_argument('--data-dir', '-d',
default='data',
help='Directory containing dataset folders (default: data)')
parser.add_argument('--threads', '-t',
type=int,
default=16,
help='Number of threads (default: 16)')
parser.add_argument('--ctx-size', '-c',
type=int,
default=512,
help='Context size (default: 512)')
parser.add_argument('--output-dir', '-o',
default='perplexity_results',
help='Output directory for results (default: perplexity_results)')
parser.add_argument('--llama-perplexity',
default='./build/bin/llama-perplexity',
help='Path to llama-perplexity binary (default: ./build/bin/llama-perplexity)')
parser.add_argument('--quick', '-q',
action='store_true',
help='Quick test mode: test all datasets with first 4096 characters and reduced context size (128)')
parser.add_argument('--test-embeddings', '-e',
action='store_true',
help='Test different embedding quantization types (f32, f16, q8_0, q6_k, q5_0, q4_0, q3_k, tq2_0)')
parser.add_argument('--csv-output',
help='Custom path for CSV output file (e.g., results/my_ppl_results.csv)')
parser.add_argument('--quantize-bin',
default='./build/bin/llama-quantize',
help='Path to llama-quantize binary (default: ./build/bin/llama-quantize)')
args = parser.parse_args()
try:
tester = PerplexityTester(
model_path=args.model,
llama_perplexity_bin=args.llama_perplexity,
data_dir=args.data_dir,
output_dir=args.output_dir,
quick_mode=args.quick,
quantize_bin=args.quantize_bin,
test_embeddings=args.test_embeddings,
csv_output=args.csv_output
)
tester.run_all_tests(
threads=args.threads,
ctx_size=args.ctx_size
)
except FileNotFoundError as e:
print(f"❌ Error: {e}")
return 1
except KeyboardInterrupt:
print("\n\n⚠️ Test interrupted by user")
return 1
except Exception as e:
print(f"\n❌ Unexpected error: {e}")
import traceback
traceback.print_exc()
return 1
return 0
if __name__ == "__main__":
sys.exit(main())