This commit is contained in:
Yggdrasil75
2026-01-29 09:15:02 -05:00
parent 670ff42b82
commit c282acd725
3 changed files with 86 additions and 113 deletions

View File

@@ -30,6 +30,7 @@ public:
struct NodeData {
T data;
PointType position;
int objectId;
bool active;
bool visible;
float size;
@@ -40,12 +41,12 @@ public:
float reflection;
NodeData(const T& data, const PointType& pos, bool visible, Eigen::Vector3f color, float size = 0.01f,
bool active = true, bool light = false, float emittance = 0.0f, float refraction = 0.0f,
float reflection = 0.0f) : data(data), position(pos), active(active), visible(visible),
bool active = true, int objectId = -1, bool light = false, float emittance = 0.0f, float refraction = 0.0f,
float reflection = 0.0f) : data(data), position(pos), objectId(objectId), active(active), visible(visible),
color(color), size(size), light(light), emittance(emittance), refraction(refraction),
reflection(reflection) {}
NodeData() : active(false), visible(false), size(0.0f), light(false),
NodeData() : objectId(-1), active(false), visible(false), size(0.0f), light(false),
emittance(0.0f), refraction(0.0f), reflection(0.0f) {}
};
@@ -175,6 +176,7 @@ private:
writeVal(out, pt->data);
// Write properties
writeVec3(out, pt->position);
writeVal(out, pt->objectId);
writeVal(out, pt->active);
writeVal(out, pt->visible);
writeVal(out, pt->size);
@@ -217,6 +219,7 @@ private:
auto pt = std::make_shared<NodeData>();
readVal(in, pt->data);
readVec3(in, pt->position);
readVal(in, pt->objectId);
readVal(in, pt->active);
readVal(in, pt->visible);
readVal(in, pt->size);
@@ -253,98 +256,41 @@ private:
}
void bitonic_sort_8(std::array<std::pair<int, float>, 8>& arr) const noexcept {
#ifdef SSE
alignas(32) float values[8];
alignas(32) uint32_t indices[8];
for (int i = 0; i < 8; i++) {
values[i] = arr[i].second;
indices[i] = arr[i].first;
}
__m256 val = _mm256_load_ps(values);
__m256i idx = _mm256_load_si256((__m256i*)indices);
__m256 swapped1 = _mm256_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
__m256i swapped_idx1 = _mm256_shuffle_epi32(idx, _MM_SHUFFLE(2, 3, 0, 1));
__m256 mask1 = _mm256_cmp_ps(val, swapped1, _CMP_GT_OQ);
val = _mm256_blendv_ps(val, swapped1, mask1);
idx = _mm256_castps_si256(_mm256_blendv_ps(
_mm256_castsi256_ps(idx),
_mm256_castsi256_ps(swapped_idx1),
mask1));
__m256 swapped2 = _mm256_permute2f128_ps(val, val, 0x01);
__m256i swapped_idx2 = _mm256_permute2f128_si256(idx, idx, 0x01);
__m256 mask2 = _mm256_cmp_ps(val, swapped2, _CMP_GT_OQ);
val = _mm256_blendv_ps(val, swapped2, mask2);
idx = _mm256_castps_si256(_mm256_blendv_ps(
_mm256_castsi256_ps(idx),
_mm256_castsi256_ps(swapped_idx2),
mask2));
__m256 swapped3 = _mm256_shuffle_ps(val, val, _MM_SHUFFLE(1, 0, 3, 2));
__m256i swapped_idx3 = _mm256_shuffle_epi32(idx, _MM_SHUFFLE(1, 0, 3, 2));
__m256 mask3 = _mm256_cmp_ps(val, swapped3, _CMP_GT_OQ);
val = _mm256_blendv_ps(val, swapped3, mask3);
idx = _mm256_castps_si256(_mm256_blendv_ps(
_mm256_castsi256_ps(idx),
_mm256_castsi256_ps(swapped_idx3),
mask3));
__m256 swapped4 = _mm256_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
__m256i swapped_idx4 = _mm256_shuffle_epi32(idx, _MM_SHUFFLE(2, 3, 0, 1));
__m256 mask4 = _mm256_cmp_ps(val, swapped4, _CMP_GT_OQ);
val = _mm256_blendv_ps(val, swapped4, mask4);
idx = _mm256_castps_si256(_mm256_blendv_ps(
_mm256_castsi256_ps(idx),
_mm256_castsi256_ps(swapped_idx4),
mask4));
_mm256_store_ps(values, val);
_mm256_store_si256((__m256i*)indices, idx);
for (int i = 0; i < 8; i++) {
arr[i].second = values[i];
arr[i].first = (uint8_t)indices[i];
}
#else
auto a0 = arr[0], a1 = arr[1], a2 = arr[2], a3 = arr[3];
auto a4 = arr[4], a5 = arr[5], a6 = arr[6], a7 = arr[7];
if (a0.second > a1.second) std::swap(a0, a1);
if (a2.second < a3.second) std::swap(a2, a3);
if (a4.second > a5.second) std::swap(a4, a5);
if (a6.second < a7.second) std::swap(a6, a7);
if (a0.second > a2.second) std::swap(a0, a2);
if (a1.second > a3.second) std::swap(a1, a3);
if (a0.second > a1.second) std::swap(a0, a1);
if (a2.second > a3.second) std::swap(a2, a3);
if (a4.second < a6.second) std::swap(a4, a6);
if (a5.second < a7.second) std::swap(a5, a7);
if (a4.second < a5.second) std::swap(a4, a5);
if (a6.second < a7.second) std::swap(a6, a7);
if (a0.second > a4.second) std::swap(a0, a4);
if (a1.second > a5.second) std::swap(a1, a5);
if (a2.second > a6.second) std::swap(a2, a6);
if (a3.second > a7.second) std::swap(a3, a7);
if (a0.second > a2.second) std::swap(a0, a2);
if (a1.second > a3.second) std::swap(a1, a3);
if (a4.second > a6.second) std::swap(a4, a6);
if (a5.second > a7.second) std::swap(a5, a7);
if (a0.second > a1.second) std::swap(a0, a1);
if (a2.second > a3.second) std::swap(a2, a3);
if (a4.second > a5.second) std::swap(a4, a5);
if (a6.second > a7.second) std::swap(a6, a7);
arr[0] = a0; arr[1] = a1; arr[2] = a2; arr[3] = a3;
arr[4] = a4; arr[5] = a5; arr[6] = a6; arr[7] = a7;
#endif
auto a0 = arr[0], a1 = arr[1], a2 = arr[2], a3 = arr[3];
auto a4 = arr[4], a5 = arr[5], a6 = arr[6], a7 = arr[7];
if (a0.second > a1.second) std::swap(a0, a1);
if (a2.second < a3.second) std::swap(a2, a3);
if (a4.second > a5.second) std::swap(a4, a5);
if (a6.second < a7.second) std::swap(a6, a7);
if (a0.second > a2.second) std::swap(a0, a2);
if (a1.second > a3.second) std::swap(a1, a3);
if (a0.second > a1.second) std::swap(a0, a1);
if (a2.second > a3.second) std::swap(a2, a3);
if (a4.second < a6.second) std::swap(a4, a6);
if (a5.second < a7.second) std::swap(a5, a7);
if (a4.second < a5.second) std::swap(a4, a5);
if (a6.second < a7.second) std::swap(a6, a7);
if (a0.second > a4.second) std::swap(a0, a4);
if (a1.second > a5.second) std::swap(a1, a5);
if (a2.second > a6.second) std::swap(a2, a6);
if (a3.second > a7.second) std::swap(a3, a7);
if (a0.second > a2.second) std::swap(a0, a2);
if (a1.second > a3.second) std::swap(a1, a3);
if (a4.second > a6.second) std::swap(a4, a6);
if (a5.second > a7.second) std::swap(a5, a7);
if (a0.second > a1.second) std::swap(a0, a1);
if (a2.second > a3.second) std::swap(a2, a3);
if (a4.second > a5.second) std::swap(a4, a5);
if (a6.second > a7.second) std::swap(a6, a7);
arr[0] = a0; arr[1] = a1; arr[2] = a2; arr[3] = a3;
arr[4] = a4; arr[5] = a5; arr[6] = a6; arr[7] = a7;
}
bool rayBoxIntersect(const PointType& origin, const PointType& dir, const BoundingBox& box,
@@ -402,8 +348,8 @@ public:
Octree() : root_(nullptr), maxPointsPerNode(16), maxDepth(16), size(0) {}
bool set(const T& data, const PointType& pos, bool visible, Eigen::Vector3f color, float size, bool active,
bool light = false, float emittance = 0.0f, float refraction = 0.0f, float reflection = 0.0f) {
auto pointData = std::make_shared<NodeData>(data, pos, visible, color, size, active,
int objectId = -1, bool light = false, float emittance = 0.0f, float refraction = 0.0f, float reflection = 0.0f) {
auto pointData = std::make_shared<NodeData>(data, pos, visible, color, size, active, objectId,
light, emittance, refraction, reflection);
if (insertRecursive(root_.get(), pointData, 0)) {
this->size++;
@@ -504,7 +450,7 @@ public:
if (found) {
node->points.erase(it, node->points.end());
size--; // Decrement size counter
size--;
return true;
}
return false;
@@ -520,7 +466,7 @@ public:
return removeNode(root_.get());
}
std::vector<std::shared_ptr<NodeData>> findInRadius(const PointType& center, float radius) {
std::vector<std::shared_ptr<NodeData>> findInRadius(const PointType& center, float radius) const {
std::vector<std::shared_ptr<NodeData>> results;
if (!root_) return results;
@@ -538,8 +484,8 @@ public:
}
float distSq = (closestPoint - center).squaredNorm();
if (distSq > (radius + boxHalfSize.norm()) * (radius + boxHalfSize.norm())) {
return; // No intersection
if (distSq > radiusSq) {
return;
}
if (node->isLeaf) {
@@ -566,7 +512,7 @@ public:
bool update(const PointType& oldPos, const PointType& newPos, const T& newData = T(), bool newVisible = true,
Eigen::Vector3f newColor = Eigen::Vector3f(1.0f, 1.0f, 1.0f), float newSize = 0.01f, bool newActive = true,
bool newLight = false, float newEmittance = 0.0f, float newRefraction = 0.0f, float newReflection = 0.0f,
int newObjectId = -2, bool newLight = false, float newEmittance = 0.0f, float newRefraction = 0.0f, float newReflection = 0.0f,
float tolerance = 0.0001f) {
// Find the existing point
@@ -582,6 +528,7 @@ public:
bool visibleCopy = pointData->visible;
Eigen::Vector3f colorCopy = pointData->color;
float sizeCopy = pointData->size;
int objectIdCopy = pointData->objectId;
bool lightCopy = pointData->light;
float emittanceCopy = pointData->emittance;
float refractionCopy = pointData->refraction;
@@ -598,6 +545,7 @@ public:
newColor != Eigen::Vector3f(1.0f, 1.0f, 1.0f) ? newColor : colorCopy,
newSize > 0 ? newSize : sizeCopy,
newActive ? newActive : activeCopy,
newObjectId != -2 ? newObjectId : objectIdCopy,
newLight ? newLight : lightCopy,
newEmittance > 0 ? newEmittance : emittanceCopy,
newRefraction >= 0 ? newRefraction : refractionCopy,
@@ -609,6 +557,7 @@ public:
pointData->visible = newVisible;
pointData->color = newColor;
pointData->size = newSize;
if (newObjectId != -2) pointData->objectId = newObjectId;
pointData->active = newActive;
pointData->light = newLight;
pointData->emittance = newEmittance;
@@ -618,6 +567,13 @@ public:
}
}
bool setObjectId(const PointType& pos, int objectId, float tolerance = 0.0001f) {
auto pointData = find(pos, tolerance);
if (!pointData) return false;
pointData->objectId = objectId;
return true;
}
bool updateData(const PointType& pos, const T& newData, float tolerance = 0.0001f) {
auto pointData = find(pos, tolerance);
if (!pointData) return false;
@@ -675,7 +631,7 @@ public:
}
std::vector<std::shared_ptr<NodeData>> voxelTraverse(const PointType& origin, const PointType& direction,
float maxDist, bool stopAtFirstHit) {
float maxDist, bool stopAtFirstHit) const {
std::vector<std::shared_ptr<NodeData>> hits;
if (empty()) return hits;
@@ -763,7 +719,7 @@ public:
float tanfovy = tanHalfFov;
float tanfovx = tanHalfFov * aspect;
PointType space(0,0,0);
if (globalIllumination) space = {0.1,0.1,0.1};
if (globalIllumination) space = {0.1f, 0.1f, 0.1f};
const Eigen::Vector3f defaultColor(0.01f, 0.01f, 0.01f);
float rayLength = std::numeric_limits<float>::max();
@@ -917,7 +873,6 @@ public:
size_t maxPointsInLeaf = 0;
size_t minPointsInLeaf = std::numeric_limits<size_t>::max();
// Recursive lambda to gather stats
std::function<void(const OctreeNode*, size_t)> traverse =
[&](const OctreeNode* node, size_t depth) {
if (!node) return;