Browse Source

Adding nearest method to lookup3d at unit and structure level.

giagitom 2 years ago
parent
commit
718f76313f

+ 1 - 2
src/mol-math/geometry/lookup3d/common.ts

@@ -41,8 +41,7 @@ export namespace Result {
 export interface Lookup3D<T = number> {
     // The result is mutated with each call to find.
     find(x: number, y: number, z: number, radius: number, result?: Result<T>): Result<T>,
-    nearest(x: number, y: number, z: number): { index: number, squaredDistance: number } | undefined,
-    distanceTo(x: number, y: number, z: number): number,
+    nearest(x: number, y: number, z: number, k: number, result?: Result<T>): Result<T>,
     check(x: number, y: number, z: number, radius: number): boolean,
     readonly boundary: { readonly box: Box3D, readonly sphere: Sphere3D }
     /** transient result */

+ 128 - 25
src/mol-math/geometry/lookup3d/grid.ts

@@ -3,13 +3,14 @@
  *
  * @author David Sehnal <david.sehnal@gmail.com>
  * @author Alexander Rose <alexander.rose@weirdbyte.de>
+ * @author Gianluca Tomasello <giagitom@gmail.com>
  */
 
 import { Result, Lookup3D } from './common';
 import { Box3D } from '../primitives/box3d';
 import { Sphere3D } from '../primitives/sphere3d';
 import { PositionData } from '../common';
-import { Vec3 } from '../../linear-algebra';
+import { Vec3, EPSILON } from '../../linear-algebra';
 import { OrderedSet } from '../../../mol-data/int';
 import { Boundary } from '../boundary';
 
@@ -40,32 +41,14 @@ class GridLookup3DImpl<T extends number = number> implements GridLookup3D<T> {
         return ret;
     }
 
-    nearest(x: number, y: number, z: number): { index: number, squaredDistance: number } | undefined {
-        if (!OrderedSet.size(this.ctx.grid.data.indices)) return undefined;
+    nearest(x: number, y: number, z: number, k: number = 1, result?: Result<T>): Result<T> {
         this.ctx.x = x;
         this.ctx.y = y;
         this.ctx.z = z;
-        const radiusIncrement = 0.1; // how to choose a good increment?
-        const startingRadius = this.distanceTo(x, y, z);
-        this.ctx.radius = startingRadius < 0 ? radiusIncrement : startingRadius + radiusIncrement;
-        this.ctx.isCheck = false;
-        const result = this.result;
-        query(this.ctx, result);
-        while (!result.count) { // necessary to check if grid is empty to avoid infinite loop
-            this.ctx.radius += radiusIncrement;
-            query(this.ctx, result);
-        }
-        const { indices, squaredDistances } = result;
-        let index = indices[0], nearestDist = squaredDistances[0];
-
-        for (let i = 1, l = indices.length; i < l; i++) {
-            const dist = squaredDistances[i];
-            if (dist < nearestDist) {
-                nearestDist = dist;
-                index = indices[i];
-            }
-        }
-        return { index: index, squaredDistance: nearestDist };
+        this.ctx.k = k;
+        const ret = result ?? this.result;
+        queryNearest(this.ctx, ret);
+        return ret;
     }
 
     check(x: number, y: number, z: number, radius: number): boolean {
@@ -249,12 +232,13 @@ interface QueryContext {
     x: number,
     y: number,
     z: number,
+    k: number,
     radius: number,
     isCheck: boolean
 }
 
 function createContext(grid: Grid3D): QueryContext {
-    return { grid, x: 0.1, y: 0.1, z: 0.1, radius: 0.1, isCheck: false };
+    return { grid, x: 0.1, y: 0.1, z: 0.1, k: 1, radius: 0.1, isCheck: false };
 }
 
 function query<T extends number = number>(ctx: QueryContext, result: Result<T>): boolean {
@@ -306,3 +290,122 @@ function query<T extends number = number>(ctx: QueryContext, result: Result<T>):
     }
     return result.count > 0;
 }
+
+const tmpDirVec = Vec3();
+const tmpVec = Vec3();
+const tmpMapG = new Map<number, boolean>();
+const tmpArrG1 = new Array();
+const tmpArrG2 = new Array();
+const tmpNearestIds = new Array();
+function queryNearest<T extends number = number>(ctx: QueryContext, result: Result<T>): boolean {
+    const { expandedBox: box, boundingSphere: { center }, size: [sX, sY, sZ], bucketOffset, bucketCounts, bucketArray, grid, data: { x: px, y: py, z: pz, indices, radius }, delta, maxRadius } = ctx.grid;
+    const [minX, minY, minZ] = box.min;
+    const { x, y, z, k } = ctx;
+    const indicesCount = OrderedSet.size(indices);
+    Result.reset(result);
+    if (indicesCount === 0) return false;
+    let gX, gY, gZ;
+    Vec3.set(tmpVec, x, y, z);
+    if (!Box3D.containsVec3(box, tmpVec)) {
+        // intersect ray pointing to box center
+        Box3D.nearestIntersectionWithRay(tmpVec, box, tmpVec, Vec3.normalize(tmpDirVec, Vec3.sub(tmpDirVec, center, tmpVec)));
+        gX = Math.max(0, Math.min(sX - 1, Math.floor((tmpVec[0] - minX) / delta[0])));
+        gY = Math.max(0, Math.min(sY - 1, Math.floor((tmpVec[1] - minY) / delta[1])));
+        gZ = Math.max(0, Math.min(sZ - 1, Math.floor((tmpVec[2] - minZ) / delta[2])));
+    } else {
+        gX = Math.floor((x - minX) / delta[0]);
+        gY = Math.floor((y - minY) / delta[1]);
+        gZ = Math.floor((z - minZ) / delta[2]);
+    }
+    let gCount = 1, nextGCount = 0, arrG = tmpArrG1, nextArrG = tmpArrG2, prevFurthestDist = Number.MAX_VALUE, prevNearestDist = -Number.MAX_VALUE, nearestDist = Number.MAX_VALUE, findFurthest = true, furthestDist = - Number.MAX_VALUE, distSqG: number;
+    arrG.length = 0;
+    arrG.push(gX, gY, gZ);
+    tmpMapG.clear();
+    while (result.count < k && result.count < indicesCount) {
+        const arrGLen = gCount * 3;
+        for (let ig = 0; ig < arrGLen; ig += 3) {
+            gX = arrG[ig];
+            gY = arrG[ig + 1];
+            gZ = arrG[ig + 2];
+            const gridId = (((gX * sY) + gY) * sZ) + gZ;
+            if (tmpMapG.get(gridId)) continue; // already visited
+            tmpMapG.set(gridId, true);
+            distSqG = (gX - x) * (gX - x) + (gY - y) * (gY - y) + (gZ - z) * (gZ - z);
+            if (!findFurthest && distSqG > furthestDist && distSqG < nearestDist) continue;
+
+            // evaluate distances in the current grid point
+            const bucketIdx = grid[gridId];
+            if (bucketIdx !== 0) {
+                const ki = bucketIdx - 1;
+                const offset = bucketOffset[ki];
+                const count = bucketCounts[ki];
+                const end = offset + count;
+                for (let i = offset; i < end; i++) {
+                    const idx = OrderedSet.getAt(indices, bucketArray[i]);
+                    const dx = px[idx] - x;
+                    const dy = py[idx] - y;
+                    const dz = pz[idx] - z;
+                    let distSq = dx * dx + dy * dy + dz * dz;
+                    if (maxRadius !== 0) {
+                        const r = radius![idx];
+                        const sqR = r * r;
+                        if (findFurthest && distSq > furthestDist) furthestDist = distSq + sqR;
+                        distSq = distSq - sqR;
+
+                    } else {
+                        if (findFurthest && distSq > furthestDist) furthestDist = distSq;
+                    }
+
+                    if (distSq <= nearestDist && distSq > prevNearestDist) {
+                        if (nearestDist === distSq) { // handles multiple elements exactly at same distance
+                            tmpNearestIds.push(idx);
+                        } else {
+                            if (tmpNearestIds.length > 1) tmpNearestIds.length = 1;
+                            tmpNearestIds[0] = idx;
+                        }
+                        nearestDist = distSq;
+                    }
+                }
+                if (prevFurthestDist < furthestDist) findFurthest = false;
+            }
+
+            // fill grid points array with valid adiacent positions
+            for (let ix = -1; ix <= 1; ix++) {
+                const xPos = gX + ix;
+                if (xPos < 0 || xPos >= sX) continue;
+                for (let iy = -1; iy <= 1; iy++) {
+                    const yPos = gY + iy;
+                    if (yPos < 0 || yPos >= sY) continue;
+                    for (let iz = -1; iz <= 1; iz++) {
+                        const zPos = gZ + iz;
+                        if (zPos < 0 || zPos >= sZ) continue;
+                        nextArrG.push(xPos, yPos, zPos);
+                        nextGCount++;
+                    }
+                }
+            }
+        }
+        if (nextGCount === 0) {
+            prevNearestDist = nearestDist;
+            for (let i = 0, l = tmpNearestIds.length; i < l; i++) {
+                Result.add(result, tmpNearestIds[i], nearestDist);
+            }
+            if (furthestDist === nearestDist) {
+                findFurthest = true;
+                prevFurthestDist = furthestDist;
+                nearestDist = Number.MAX_VALUE;
+            } else {
+                nearestDist = furthestDist + EPSILON; // adding EPSILON fixes a bug
+            }
+            tmpMapG.clear();
+        } else {
+            const tmp = arrG;
+            arrG = nextArrG;
+            nextArrG = tmp;
+            nextArrG.length = 0;
+            gCount = nextGCount;
+            nextGCount = 0;
+        }
+    }
+    return result.count > 0;
+}

+ 43 - 1
src/mol-math/geometry/primitives/box3d.ts

@@ -138,6 +138,48 @@ namespace Box3D {
             a.max[2] < b.min[2] || a.min[2] > b.max[2]
         );
     }
+
+    // const tmpTransformV = Vec3();
+    export function nearestIntersectionWithRay(out: Vec3, box: Box3D, origin: Vec3, dir: Vec3): Vec3 {
+        const [minX, minY, minZ] = box.min;
+        const [maxX, maxY, maxZ] = box.max;
+        const [x, y, z] = origin;
+        const invDirX = 1.0 / dir[0];
+        const invDirY = 1.0 / dir[1];
+        const invDirZ = 1.0 / dir[2];
+        let tmin, tmax, tymin, tymax, tzmin, tzmax;
+        if (invDirX >= 0) {
+            tmin = (minX - x) * invDirX;
+            tmax = (maxX - x) * invDirX;
+        } else {
+            tmin = (maxX - x) * invDirX;
+            tmax = (minX - x) * invDirX;
+        }
+        if (invDirY >= 0) {
+            tymin = (minY - y) * invDirY;
+            tymax = (maxY - y) * invDirY;
+        } else {
+            tymin = (maxY - y) * invDirY;
+            tymax = (minY - y) * invDirY;
+        }
+        if (invDirZ >= 0) {
+            tzmin = (minZ - z) * invDirZ;
+            tzmax = (maxZ - z) * invDirZ;
+        } else {
+            tzmin = (maxZ - z) * invDirZ;
+            tzmax = (minZ - z) * invDirZ;
+        }
+        if (tymin > tmin)
+            tmin = tymin;
+        if (tymax < tmax)
+            tmax = tymax;
+        if (tzmin > tmin)
+            tmin = tzmin;
+        if (tzmax < tmax)
+            tmax = tzmax;
+        Vec3.scale(out, dir, tmin);
+        return Vec3.set(out, out[0] + x, out[1] + y, out[2] + z);
+    }
 }
 
-export { Box3D };
+export { Box3D };

+ 55 - 35
src/mol-model/structure/structure/util/lookup3d.ts

@@ -3,12 +3,14 @@
  *
  * @author David Sehnal <david.sehnal@gmail.com>
  * @author Alexander Rose <alexander.rose@weirdbyte.de>
+ * @author Gianluca Tomasello <giagitom@gmail.com>
  */
 
 import { Structure } from '../structure';
 import { Lookup3D, GridLookup3D, Result } from '../../../../mol-math/geometry';
 import { Vec3 } from '../../../../mol-math/linear-algebra';
 import { OrderedSet } from '../../../../mol-data/int';
+import { arrayLess, arraySwap } from '../../../../mol-data/util';
 import { StructureUniqueSubsetBuilder } from './unique-subset-builder';
 import { StructureElement } from '../element';
 import { Unit } from '../unit';
@@ -39,6 +41,25 @@ export namespace StructureResult {
         out.count = result.count;
         return out;
     }
+
+    // sort in ascending order based on squaredDistances
+    export function sort(result: StructureResult) {
+        const { indices, squaredDistances, units, count } = result;
+        if (indices.length > count) {
+            // clear arrays before doing sorting if necessary
+            indices.length = count;
+            squaredDistances.length = count;
+            units.length = count;
+        }
+        for (let i = 1, c = result.count; i < c; i++) {
+            if (arrayLess(squaredDistances, i - 1, i) > 0) {
+                arraySwap(squaredDistances, i - 1, i);
+                arraySwap(indices, i - 1, i);
+                arraySwap(units, i - 1, i);
+                i = Math.max(0, i - 2);
+            }
+        }
+    }
 }
 
 export interface StructureLookup3DResultContext {
@@ -59,6 +80,10 @@ export class StructureLookup3D {
         return this.unitLookup.find(x, y, z, radius);
     }
 
+    nearestUnitIndices(x: number, y: number, z: number, k: number = 1): Result<number> {
+        return this.unitLookup.nearest(x, y, z, k);
+    }
+
     private findContext = StructureLookup3DResultContext();
 
     find(x: number, y: number, z: number, radius: number, ctx?: StructureLookup3DResultContext): StructureResult {
@@ -86,45 +111,40 @@ export class StructureLookup3D {
         return ctx.result;
     }
 
-    nearest(x: number, y: number, z: number): { index: number, unit: Unit, squaredDistance: number } | undefined {
-        return this._nearest(x, y, z);
+    nearest(x: number, y: number, z: number, k: number = 1, ctx?: StructureLookup3DResultContext): StructureResult {
+        return this._nearest(x, y, z, k, ctx ?? this.findContext);
     }
 
-    _nearest(x: number, y: number, z: number): { index: number, unit: Unit, squaredDistance: number } | undefined {
-        const ctx = this.findContext;
-        const { units, elementCount } = this.structure;
-
-        if (!elementCount) return undefined;
-
-        const radiusIncrement = 0.1; // how to choose a good increment?
-        const startingRadius = this.distanceTo(x, y, z);
-        let radius = startingRadius < 0 ? radiusIncrement : startingRadius + radiusIncrement;
-
-        while (true) {
-            const closeUnits = this.unitLookup.find(x, y, z, radius, ctx.closeUnitsResult);
-            radius += radiusIncrement;
-            if (closeUnits.count) {
-                let nearestIndex: number, nearestUnit: Unit, nearestDist = Number.MAX_SAFE_INTEGER;
-                for (let t = 1, _t = closeUnits.count; t < _t; t++) {
-                    const unit = units[closeUnits.indices[t]];
-                    Vec3.set(this.pivot, x, y, z);
-                    if (!unit.conformation.operator.isIdentity) {
-                        Vec3.transformMat4(this.pivot, this.pivot, unit.conformation.operator.inverse);
-                    }
-                    const unitLookup = unit.lookup3d;
-                    const nearestResult = unitLookup.nearest(this.pivot[0], this.pivot[1], this.pivot[2]);
-                    if (nearestResult) {
-                        const { index: unitNearestIndex, squaredDistance: unitNearestDist } = nearestResult;
-                        if (unitNearestDist < nearestDist) {
-                            nearestDist = unitNearestDist;
-                            nearestIndex = unitNearestIndex;
-                            nearestUnit = unit;
-                        }
-                    }
-                }
-                return { index: nearestIndex!, unit: nearestUnit!, squaredDistance: nearestDist };
+    _nearest(x: number, y: number, z: number, k: number, ctx: StructureLookup3DResultContext): StructureResult {
+        const result = ctx.result;
+        Result.reset(result);
+        const { units } = this.structure;
+        const closeUnits = this.unitLookup.nearest(x, y, z, units.length, ctx.closeUnitsResult); // sort all units based on distance to the point
+        if (closeUnits.count === 0) return result;
+        let totalCount = 0, maxDistResult = -Number.MAX_VALUE;
+        for (let t = 0, _t = closeUnits.count; t < _t; t++) {
+            const unitSqDist = closeUnits.squaredDistances[t];
+            if (totalCount >= k && maxDistResult < unitSqDist) break;
+            Vec3.set(this.pivot, x, y, z);
+            const unit = units[closeUnits.indices[t]];
+            if (!unit.conformation.operator.isIdentity) {
+                Vec3.transformMat4(this.pivot, this.pivot, unit.conformation.operator.inverse);
+            }
+            const unitLookup = unit.lookup3d;
+            const groupResult = unitLookup.nearest(this.pivot[0], this.pivot[1], this.pivot[2], k, ctx.unitGroupResult);
+            if (groupResult.count === 0) continue;
+            maxDistResult = Math.max(maxDistResult, groupResult.squaredDistances[groupResult.count - 1]);
+            totalCount += groupResult.count;
+            for (let j = 0, _j = groupResult.count; j < _j; j++) {
+                StructureResult.add(result, unit, groupResult.indices[j], groupResult.squaredDistances[j]);
+            }
+            StructureResult.sort(result);
+            if (totalCount > k) {
+                result.count = k;
+                totalCount = k;
             }
         }
+        return result;
     }
 
     findIntoBuilder(x: number, y: number, z: number, radius: number, builder: StructureUniqueSubsetBuilder) {