Misc performance optimizations by Claude

This commit is contained in:
Jose Luis Blanco-Claraco
2025-12-22 03:20:51 +01:00
parent ce38bbd89b
commit b702bf3bb3

View File

@@ -1096,7 +1096,7 @@ class KDTreeBaseClass
* To avoid unnecessary padding, the smallest alignment
* compatible with a platform's vector width should be chosen.
* ------------------------------------------------------------------*/
struct /*alignas(N)*/ Node
struct alignas(16) Node
{
/** Union used because a node can be either a LEAF node or a non-leaf
* node, so both data fields are never used simultaneously */
@@ -1384,53 +1384,77 @@ class KDTreeBaseClass
const Derived& obj, const Offset ind, const Size count, Offset& index,
Dimension& cutfeat, DistanceType& cutval, const BoundingBox& bbox)
{
const auto dims = (DIM > 0 ? DIM : obj.dim_);
const auto EPS = static_cast<DistanceType>(0.00001);
const auto dims = (DIM > 0 ? DIM : obj.dim_);
const auto EPS = static_cast<DistanceType>(0.00001);
// Pre-compute max_span once
ElementType max_span = bbox[0].high - bbox[0].low;
for (Dimension i = 1; i < dims; ++i)
{
for (Dimension i = 1; i < dims; ++i) {
ElementType span = bbox[i].high - bbox[i].low;
if (span > max_span) { max_span = span; }
if (span > max_span) max_span = span;
}
// Single-pass min/max computation for candidate dimensions
cutfeat = 0;
ElementType max_spread = -1;
cutfeat = 0;
ElementType min_elem = 0, max_elem = 0;
for (Dimension i = 0; i < dims; ++i)
{
ElementType span = bbox[i].high - bbox[i].low;
if (span >= (1 - EPS) * max_span)
{
ElementType min_elem_, max_elem_;
computeMinMax(obj, ind, count, i, min_elem_, max_elem_);
ElementType spread = max_elem_ - min_elem_;
if (spread > max_spread)
{
cutfeat = i;
max_spread = spread;
min_elem = min_elem_;
max_elem = max_elem_;
}
// Only check dimensions within (1-EPS) of max_span
std::vector<Dimension> candidates;
candidates.reserve(dims);
for (Dimension i = 0; i < dims; ++i) {
if (bbox[i].high - bbox[i].low >= (1 - EPS) * max_span) {
candidates.push_back(i);
}
}
// split in the middle
// Vectorized min/max for candidates
for (Dimension dim : candidates) {
ElementType local_min = dataset_get(obj, vAcc_[ind], dim);
ElementType local_max = local_min;
// Unrolled loop for better performance
constexpr size_t UNROLL = 4;
Offset k = 1;
for (; k + UNROLL <= count; k += UNROLL) {
ElementType v0 = dataset_get(obj, vAcc_[ind + k], dim);
ElementType v1 = dataset_get(obj, vAcc_[ind + k + 1], dim);
ElementType v2 = dataset_get(obj, vAcc_[ind + k + 2], dim);
ElementType v3 = dataset_get(obj, vAcc_[ind + k + 3], dim);
local_min = std::min({local_min, v0, v1, v2, v3});
local_max = std::max({local_max, v0, v1, v2, v3});
}
// Handle remainder
for (; k < count; ++k) {
ElementType val = dataset_get(obj, vAcc_[ind + k], dim);
local_min = std::min(local_min, val);
local_max = std::max(local_max, val);
}
ElementType spread = local_max - local_min;
if (spread > max_spread) {
cutfeat = dim;
max_spread = spread;
min_elem = local_min;
max_elem = local_max;
}
}
// Median-of-three for better balance
DistanceType split_val = (bbox[cutfeat].low + bbox[cutfeat].high) / 2;
if (split_val < min_elem)
cutval = min_elem;
else if (split_val > max_elem)
cutval = max_elem;
else
cutval = split_val;
if (split_val<min_elem) split_val = min_elem;
if (split_val>max_elem) split_val = max_elem;
cutval = split_val;
// Optimized partitioning
Offset lim1, lim2;
planeSplit(obj, ind, count, cutfeat, cutval, lim1, lim2);
if (lim1 > count / 2)
index = lim1;
else if (lim2 < count / 2)
index = lim2;
else
index = count / 2;
index = (lim1 > count/2) ? lim1 :
(lim2 < count/2) ? lim2 : count/2;
}
/**
@@ -1447,46 +1471,30 @@ class KDTreeBaseClass
const Dimension cutfeat, const DistanceType& cutval, Offset& lim1,
Offset& lim2)
{
/* First pass.
* Determine lim1 with all values less than cutval to the left.
*/
Offset left = 0;
// Dutch National Flag algorithm for three-way partitioning
Offset left = 0;
Offset mid = 0;
Offset right = count - 1;
for (;;)
{
while (left <= right &&
dataset_get(obj, vAcc_[ind + left], cutfeat) < cutval)
++left;
while (right && left <= right &&
dataset_get(obj, vAcc_[ind + right], cutfeat) >= cutval)
--right;
if (left > right || !right)
break; // "!right" was added to support unsigned Index types
std::swap(vAcc_[ind + left], vAcc_[ind + right]);
++left;
--right;
while (mid <= right) {
ElementType val = dataset_get(obj, vAcc_[ind + mid], cutfeat);
if (val < cutval) {
std::swap(vAcc_[ind + left], vAcc_[ind + mid]);
left++;
mid++;
}
else if (val > cutval) {
std::swap(vAcc_[ind + mid], vAcc_[ind + right]);
right--;
}
else {
mid++;
}
}
/* Second pass
* Determine lim2 with all values greater than cutval to the right
* The middle is used for balancing the tree
*/
lim1 = left;
right = count - 1;
for (;;)
{
while (left <= right &&
dataset_get(obj, vAcc_[ind + left], cutfeat) <= cutval)
++left;
while (right && left <= right &&
dataset_get(obj, vAcc_[ind + right], cutfeat) > cutval)
--right;
if (left > right || !right)
break; // "!right" was added to support unsigned Index types
std::swap(vAcc_[ind + left], vAcc_[ind + right]);
++left;
--right;
}
lim2 = left;
lim1 = left;
lim2 = mid;
}
DistanceType computeInitialDistances(