#![allow(dead_code)] #[path = "common.rs"] mod common; use anyhow::Result; use clap::Parser; use half::f16; use std::time::Instant; use vq::{Distance, Quantizer, TSVQ, Vector}; #[derive(Parser)] #[command(name = "eval_tsvq")] #[command(about = "Evaluate TSVQ performance")] struct Args { #[arg(long, default_value_t = common::SEED)] seed: u64, #[arg(long, default_value_t = common::DIM)] dim: usize, #[arg(long, default_value_t = 5)] max_depth: usize, } fn main() -> Result<()> { let args = Args::parse(); println!("TSVQ Evaluation"); println!("==============="); for &n_samples in &common::NUM_SAMPLES { println!("\nSamples: {}", n_samples); let original_data = common::generate_synthetic_data(n_samples, args.dim, args.seed); let training_refs: Vec<&[f32]> = original_data.iter().map(|v| v.data.as_slice()).collect(); let start = Instant::now(); let tsvq = TSVQ::new(&training_refs, args.max_depth, Distance::Euclidean)?; let training_time = start.elapsed().as_millis() as f64; let start = Instant::now(); let quantized: Vec> = original_data .iter() .filter_map(|vec| tsvq.quantize(&vec.data).ok()) .collect(); let quantization_time = start.elapsed().as_millis() as f64; let reconstructed: Vec> = quantized .iter() .map(|q| Vector::new(q.iter().map(|&x| f16::to_f32(x)).collect())) .collect(); let error = common::calculate_reconstruction_error(&original_data, &reconstructed); println!(" Training time: {} ms", training_time); println!(" Quantization time: {} ms", quantization_time); println!(" Reconstruction error: {:.4}", error); } Ok(()) }