Ver Fonte

Increased performances of lookup3d nearest search.

giagitom há 2 anos atrás
pai
commit
8814b60d0b

+ 25 - 1
src/mol-math/geometry/_spec/lookup3d.spec.ts

@@ -2,6 +2,7 @@
  * Copyright (c) 2018-2020 mol* contributors, licensed under MIT, See LICENSE file for more info.
  *
  * @author David Sehnal <david.sehnal@gmail.com>
+ * @author Gianluca Tomasello <giagitom@gmail.com>
  */
 
 import { GridLookup3D } from '../../geometry';
@@ -24,9 +25,17 @@ describe('GridLookup3d', () => {
         expect(r.count).toBe(1);
         expect(r.indices[0]).toBe(0);
 
+        r = grid.nearest(0, 0, 0, 1);
+        expect(r.count).toBe(1);
+        expect(r.indices[0]).toBe(0);
+
         r = grid.find(0, 0, 0, 1);
         expect(r.count).toBe(3);
         expect(sortArray(r.indices)).toEqual([0, 1, 2]);
+
+        r = grid.nearest(0, 0, 0, 3);
+        expect(r.count).toBe(3);
+        expect(sortArray(r.indices)).toEqual([0, 1, 2]);
     });
 
     it('radius', () => {
@@ -38,9 +47,17 @@ describe('GridLookup3d', () => {
         expect(r.count).toBe(1);
         expect(r.indices[0]).toBe(0);
 
+        r = grid.nearest(0, 0, 0, 1);
+        expect(r.count).toBe(1);
+        expect(r.indices[0]).toBe(0);
+
         r = grid.find(0, 0, 0, 0.5);
         expect(r.count).toBe(2);
         expect(sortArray(r.indices)).toEqual([0, 1]);
+
+        r = grid.nearest(0, 0, 0, 3);
+        expect(r.count).toBe(3);
+        expect(sortArray(r.indices)).toEqual([0, 1, 2]);
     });
 
     it('indexed', () => {
@@ -51,8 +68,15 @@ describe('GridLookup3d', () => {
         let r = grid.find(0, 0, 0, 0);
         expect(r.count).toBe(0);
 
+        r = grid.nearest(0, 0, 0, 1);
+        expect(r.count).toBe(1);
+
         r = grid.find(0, 0, 0, 0.5);
         expect(r.count).toBe(1);
         expect(sortArray(r.indices)).toEqual([0]);
+
+        r = grid.nearest(0, 0, 0, 3);
+        expect(r.count).toBe(1);
+        expect(sortArray(r.indices)).toEqual([0]);
     });
-});
+});

+ 1 - 1
src/mol-math/geometry/lookup3d/common.ts

@@ -41,7 +41,7 @@ export namespace Result {
 export interface Lookup3D<T = number> {
     // The result is mutated with each call to find.
     find(x: number, y: number, z: number, radius: number, result?: Result<T>): Result<T>,
-    nearest(x: number, y: number, z: number, k: number, result?: Result<T>): Result<T>,
+    nearest(x: number, y: number, z: number, k: number, stopIf?: Function, result?: Result<T>): Result<T>,
     check(x: number, y: number, z: number, radius: number): boolean,
     readonly boundary: { readonly box: Box3D, readonly sphere: Sphere3D }
     /** transient result */

+ 97 - 82
src/mol-math/geometry/lookup3d/grid.ts

@@ -10,7 +10,7 @@ import { Result, Lookup3D } from './common';
 import { Box3D } from '../primitives/box3d';
 import { Sphere3D } from '../primitives/sphere3d';
 import { PositionData } from '../common';
-import { Vec3, EPSILON } from '../../linear-algebra';
+import { Vec3 } from '../../linear-algebra';
 import { OrderedSet } from '../../../mol-data/int';
 import { Boundary } from '../boundary';
 import { FibonacciHeap } from '../../../mol-util/fibonacci-heap';
@@ -42,11 +42,12 @@ class GridLookup3DImpl<T extends number = number> implements GridLookup3D<T> {
         return ret;
     }
 
-    nearest(x: number, y: number, z: number, k: number = 1, result?: Result<T>): Result<T> {
+    nearest(x: number, y: number, z: number, k: number = 1, stopIf?: Function, result?: Result<T>): Result<T> {
         this.ctx.x = x;
         this.ctx.y = y;
         this.ctx.z = z;
         this.ctx.k = k;
+        this.ctx.stopIf = stopIf;
         const ret = result ?? this.result;
         queryNearest(this.ctx, ret);
         return ret;
@@ -234,12 +235,13 @@ interface QueryContext {
     y: number,
     z: number,
     k: number,
+    stopIf?: Function,
     radius: number,
     isCheck: boolean
 }
 
 function createContext(grid: Grid3D): QueryContext {
-    return { grid, x: 0.1, y: 0.1, z: 0.1, k: 1, radius: 0.1, isCheck: false };
+    return { grid, x: 0.1, y: 0.1, z: 0.1, k: 1, stopIf: undefined, radius: 0.1, isCheck: false };
 }
 
 function query<T extends number = number>(ctx: QueryContext, result: Result<T>): boolean {
@@ -294,124 +296,137 @@ function query<T extends number = number>(ctx: QueryContext, result: Result<T>):
 
 const tmpDirVec = Vec3();
 const tmpVec = Vec3();
-const tmpMapG = new Map<number, boolean>();
+const tmpSetG = new Set<number>();
+const tmpSetG2 = new Set<number>();
 const tmpArrG1 = [0.1];
 const tmpArrG2 = [0.1];
+const tmpArrG3 = [0.1];
 const tmpHeapG = new FibonacciHeap();
 function queryNearest<T extends number = number>(ctx: QueryContext, result: Result<T>): boolean {
-    const { expandedBox: box, boundingSphere: { center }, size: [sX, sY, sZ], bucketOffset, bucketCounts, bucketArray, grid, data: { x: px, y: py, z: pz, indices, radius }, delta, maxRadius } = ctx.grid;
-    const [minX, minY, minZ] = box.min;
-    const { x, y, z, k } = ctx;
+    const { min, expandedBox: box, boundingSphere: { center }, size: [sX, sY, sZ], bucketOffset, bucketCounts, bucketArray, grid, data: { x: px, y: py, z: pz, indices, radius }, delta, maxRadius } = ctx.grid;
+    const { x, y, z, k, stopIf } = ctx;
     const indicesCount = OrderedSet.size(indices);
     Result.reset(result);
     if (indicesCount === 0 || k <= 0) return false;
-    let gX, gY, gZ;
+    let gX, gY, gZ, stop = false, gCount = 1, expandGrid = true, nextGCount = 0, arrG = tmpArrG1, nextArrG = tmpArrG2, maxRange = 0, expandRange = true, gridId: number, gridPointsFinished = false;
+    const expandedArrG = tmpArrG3, sqMaxRadius = maxRadius * maxRadius;
+    arrG.length = 0;
+    expandedArrG.length = 0;
+    tmpSetG.clear();
+    tmpHeapG.clear();
     Vec3.set(tmpVec, x, y, z);
     if (!Box3D.containsVec3(box, tmpVec)) {
         // intersect ray pointing to box center
         Box3D.nearestIntersectionWithRay(tmpVec, box, tmpVec, Vec3.normalize(tmpDirVec, Vec3.sub(tmpDirVec, center, tmpVec)));
-        gX = Math.max(0, Math.min(sX - 1, Math.floor((tmpVec[0] - minX) / delta[0])));
-        gY = Math.max(0, Math.min(sY - 1, Math.floor((tmpVec[1] - minY) / delta[1])));
-        gZ = Math.max(0, Math.min(sZ - 1, Math.floor((tmpVec[2] - minZ) / delta[2])));
+        gX = Math.max(0, Math.min(sX - 1, Math.floor((tmpVec[0] - min[0]) / delta[0])));
+        gY = Math.max(0, Math.min(sY - 1, Math.floor((tmpVec[1] - min[1]) / delta[1])));
+        gZ = Math.max(0, Math.min(sZ - 1, Math.floor((tmpVec[2] - min[2]) / delta[2])));
     } else {
-        gX = Math.floor((x - minX) / delta[0]);
-        gY = Math.floor((y - minY) / delta[1]);
-        gZ = Math.floor((z - minZ) / delta[2]);
+        gX = Math.floor((x - min[0]) / delta[0]);
+        gY = Math.floor((y - min[1]) / delta[1]);
+        gZ = Math.floor((z - min[2]) / delta[2]);
     }
-    let gCount = 1, nextGCount = 0, arrG = tmpArrG1, nextArrG = tmpArrG2, prevFurthestDist = Number.MAX_VALUE, prevNearestDist = -Number.MAX_VALUE, nearestDist = Number.MAX_VALUE, findFurthest = true, furthestDist = - Number.MAX_VALUE, distSqG: number;
-    arrG.length = 0;
-    nextArrG.length = 0;
-    arrG.push(gX, gY, gZ);
-    tmpMapG.clear();
-    tmpHeapG.clear();
+    const dX = maxRadius !== 0 ? Math.max(1, Math.min(sX - 1, Math.ceil(maxRadius / delta[0]))) : 1;
+    const dY = maxRadius !== 0 ? Math.max(1, Math.min(sY - 1, Math.ceil(maxRadius / delta[1]))) : 1;
+    const dZ = maxRadius !== 0 ? Math.max(1, Math.min(sZ - 1, Math.ceil(maxRadius / delta[2]))) : 1;
+    arrG.push(gX, gY, gZ, (((gX * sY) + gY) * sZ) + gZ);
     while (result.count < indicesCount) {
-        const arrGLen = gCount * 3;
-        for (let ig = 0; ig < arrGLen; ig += 3) {
-            gX = arrG[ig];
-            gY = arrG[ig + 1];
-            gZ = arrG[ig + 2];
-            const gridId = (((gX * sY) + gY) * sZ) + gZ;
-            if (tmpMapG.get(gridId)) continue; // already visited
-            tmpMapG.set(gridId, true);
-            distSqG = (gX - x) * (gX - x) + (gY - y) * (gY - y) + (gZ - z) * (gZ - z);
-            if (!findFurthest && distSqG > furthestDist && distSqG < nearestDist) continue;
-            // evaluate distances in the current grid point
-            const bucketIdx = grid[gridId];
-            if (bucketIdx !== 0) {
-                const ki = bucketIdx - 1;
-                const offset = bucketOffset[ki];
-                const count = bucketCounts[ki];
-                const end = offset + count;
-                for (let i = offset; i < end; i++) {
-                    const idx = OrderedSet.getAt(indices, bucketArray[i]);
-                    const dx = px[idx] - x;
-                    const dy = py[idx] - y;
-                    const dz = pz[idx] - z;
-                    let distSq = dx * dx + dy * dy + dz * dz;
-                    if (maxRadius !== 0) {
-                        const r = radius![idx];
-                        const sqR = r * r;
-                        if (findFurthest && distSq > furthestDist) furthestDist = distSq + sqR;
-                        distSq = distSq - sqR;
-                    } else {
-                        if (findFurthest && distSq > furthestDist) furthestDist = distSq;
-                    }
-                    if (distSq > prevNearestDist && distSq <= furthestDist) {
-                        tmpHeapG.insert(distSq, idx);
-                        nearestDist = tmpHeapG.findMinimum()!.key as unknown as number;
+        const arrGLen = gCount * 4;
+        for (let ig = 0; ig < arrGLen; ig += 4) {
+            gridId = arrG[ig + 3];
+            if (!tmpSetG.has(gridId)) {
+                tmpSetG.add(gridId);
+                gridPointsFinished = tmpSetG.size >= grid.length;
+                const bucketIdx = grid[gridId];
+                if (bucketIdx !== 0) {
+                    const _maxRange = maxRange;
+                    const ki = bucketIdx - 1;
+                    const offset = bucketOffset[ki];
+                    const count = bucketCounts[ki];
+                    const end = offset + count;
+                    for (let i = offset; i < end; i++) {
+                        const bIdx = bucketArray[i];
+                        const idx = OrderedSet.getAt(indices, bIdx);
+                        const dx = px[idx] - x;
+                        const dy = py[idx] - y;
+                        const dz = pz[idx] - z;
+                        let distSq = dx * dx + dy * dy + dz * dz;
+                        if (maxRadius !== 0) {
+                            const r = radius![idx];
+                            distSq -= r * r;
+                        }
+                        if (expandRange && distSq > maxRange) {
+                            maxRange = distSq;
+                        }
+                        tmpHeapG.insert(distSq, bIdx);
                     }
+                    if (_maxRange < maxRange) expandRange = false;
                 }
-                if (prevFurthestDist < furthestDist) findFurthest = false;
             }
+        }
+        // find next grid points
+        nextArrG.length = 0;
+        nextGCount = 0;
+        tmpSetG2.clear();
+        for (let ig = 0; ig < arrGLen; ig += 4) {
+            gX = arrG[ig];
+            gY = arrG[ig + 1];
+            gZ = arrG[ig + 2];
             // fill grid points array with valid adiacent positions
-            for (let ix = -1; ix <= 1; ix++) {
+            for (let ix = -dX; ix <= dX; ix++) {
                 const xPos = gX + ix;
                 if (xPos < 0 || xPos >= sX) continue;
-                for (let iy = -1; iy <= 1; iy++) {
+                for (let iy = -dY; iy <= dY; iy++) {
                     const yPos = gY + iy;
                     if (yPos < 0 || yPos >= sY) continue;
-                    for (let iz = -1; iz <= 1; iz++) {
+                    for (let iz = -dZ; iz <= dZ; iz++) {
                         const zPos = gZ + iz;
                         if (zPos < 0 || zPos >= sZ) continue;
-                        const gridId = (((xPos * sY) + yPos) * sZ) + zPos;
-                        if (tmpMapG.get(gridId)) continue; // already visited
-                        nextArrG.push(xPos, yPos, zPos);
+                        gridId = (((xPos * sY) + yPos) * sZ) + zPos;
+                        if (tmpSetG2.has(gridId)) continue; // already scanned
+                        tmpSetG2.add(gridId);
+                        if (tmpSetG.has(gridId)) continue; // already visited
+                        if (!expandGrid) {
+                            const xP = min[0] + xPos * delta[0] - x;
+                            const yP = min[1] + yPos * delta[1] - y;
+                            const zP = min[2] + zPos * delta[2] - z;
+                            const distSqG = (xP * xP) + (yP * yP) + (zP * zP) - sqMaxRadius; // is sqMaxRadius necessary?
+                            if (distSqG > maxRange) {
+                                expandedArrG.push(xPos, yPos, zPos, gridId);
+                                continue;
+                            }
+                        }
+                        nextArrG.push(xPos, yPos, zPos, gridId);
                         nextGCount++;
                     }
                 }
             }
         }
+        expandGrid = false;
         if (nextGCount === 0) {
-            while (!tmpHeapG.isEmpty() && result.count < k) {
+            while (!tmpHeapG.isEmpty() && (gridPointsFinished || tmpHeapG.findMinimum()!.key as unknown as number <= maxRange) && result.count < k) {
                 const node = tmpHeapG.extractMinimum();
-                if (!node) throw new Error('Cannot extract minimum, should not happen');
-                const { key: squaredDistance, value: index } = node;
+                const squaredDistance = node!.key, index = node!.value;
                 Result.add(result, index as number, squaredDistance as number);
+                if (stopIf && !stop) {
+                    stop = stopIf(index, squaredDistance);
+                }
             }
-            if (result.count >= k) return result.count > 0;
-            prevNearestDist = nearestDist;
-            if (furthestDist === nearestDist) {
-                findFurthest = true;
-                prevFurthestDist = furthestDist;
-                nearestDist = Number.MAX_VALUE;
-            } else {
-                nearestDist = furthestDist + EPSILON; // adding EPSILON fixes a bug
-            }
-            // resotre visibility of current gid points
-            for (let ig = 0; ig < arrGLen; ig += 3) {
-                gX = arrG[ig];
-                gY = arrG[ig + 1];
-                gZ = arrG[ig + 2];
-                const gridId = (((gX * sY) + gY) * sZ) + gZ;
-                tmpMapG.set(gridId, false);
+            if (result.count >= k || stop || result.count >= indicesCount) return result.count > 0;
+            expandGrid = true;
+            expandRange = true;
+            if (expandedArrG.length > 0) {
+                for (let i = 0, l = expandedArrG.length; i < l; i++) {
+                    arrG.push(expandedArrG[i]);
+                }
+                expandedArrG.length = 0;
+                gCount = arrG.length;
             }
         } else {
             const tmp = arrG;
             arrG = nextArrG;
             nextArrG = tmp;
-            nextArrG.length = 0;
             gCount = nextGCount;
-            nextGCount = 0;
         }
     }
     return result.count > 0;

+ 10 - 13
src/mol-model/structure/structure/util/lookup3d.ts

@@ -53,19 +53,16 @@ export function StructureLookup3DResultContext(): StructureLookup3DResultContext
     return { result: StructureResult.create(), closeUnitsResult: Result.create(), unitGroupResult: Result.create() };
 }
 
+const tmpHeap = new FibonacciHeap();
+
 export class StructureLookup3D {
     private unitLookup: Lookup3D;
     private pivot = Vec3();
-    private tmpHeap = new FibonacciHeap();
 
     findUnitIndices(x: number, y: number, z: number, radius: number): Result<number> {
         return this.unitLookup.find(x, y, z, radius);
     }
 
-    nearestUnitIndices(x: number, y: number, z: number, k: number = 1): Result<number> {
-        return this.unitLookup.nearest(x, y, z, k);
-    }
-
     private findContext = StructureLookup3DResultContext();
 
     find(x: number, y: number, z: number, radius: number, ctx?: StructureLookup3DResultContext): StructureResult {
@@ -99,11 +96,11 @@ export class StructureLookup3D {
 
     _nearest(x: number, y: number, z: number, k: number, ctx: StructureLookup3DResultContext): StructureResult {
         const result = ctx.result;
-        const heap = this.tmpHeap;
         Result.reset(result);
-        this.tmpHeap.clear();
+        tmpHeap.clear();
         const { units } = this.structure;
-        const closeUnits = this.unitLookup.nearest(x, y, z, units.length, ctx.closeUnitsResult); // sort all units based on distance to the point
+        let elementsCount = 0;
+        const closeUnits = this.unitLookup.nearest(x, y, z, units.length, (uid: number) => (elementsCount += units[uid].elements.length) >= k, ctx.closeUnitsResult); // sort units based on distance to the point
         if (closeUnits.count === 0) return result;
         let totalCount = 0, maxDistResult = -Number.MAX_VALUE;
         for (let t = 0, _t = closeUnits.count; t < _t; t++) {
@@ -115,16 +112,16 @@ export class StructureLookup3D {
                 Vec3.transformMat4(this.pivot, this.pivot, unit.conformation.operator.inverse);
             }
             const unitLookup = unit.lookup3d;
-            const groupResult = unitLookup.nearest(this.pivot[0], this.pivot[1], this.pivot[2], k, ctx.unitGroupResult);
+            const groupResult = unitLookup.nearest(this.pivot[0], this.pivot[1], this.pivot[2], k, void 0, ctx.unitGroupResult);
             if (groupResult.count === 0) continue;
-            maxDistResult = Math.max(maxDistResult, groupResult.squaredDistances[groupResult.count - 1]);
             totalCount += groupResult.count;
+            maxDistResult = Math.max(maxDistResult, groupResult.squaredDistances[groupResult.count - 1]);
             for (let j = 0, _j = groupResult.count; j < _j; j++) {
-                heap.insert(groupResult.squaredDistances[j], { index: groupResult.indices[j], unit: unit });
+                tmpHeap.insert(groupResult.squaredDistances[j], { index: groupResult.indices[j], unit: unit });
             }
         }
-        while (!heap.isEmpty() && result.count < k) {
-            const node = heap.extractMinimum();
+        while (!tmpHeap.isEmpty() && result.count < k) {
+            const node = tmpHeap.extractMinimum();
             if (!node) throw new Error('Cannot extract minimum, should not happen');
             const { key: squaredDistance } = node;
             const { unit, index } = node.value as { index: UnitIndex, unit: Unit };

+ 22 - 0
src/mol-util/_spec/fibonacci-heap.spec.ts

@@ -0,0 +1,22 @@
+/**
+ * Copyright (c) 2022 mol* contributors, licensed under MIT, See LICENSE file for more info.
+ *
+ * @author Gianluca Tomasello <giagitom@gmail.com>
+ */
+
+import { FibonacciHeap } from '../fibonacci-heap';
+
+describe('fibonacci-heap', () => {
+    it('basic', () => {
+        const heap = new FibonacciHeap();
+        heap.insert(1, 2);
+        heap.insert(4);
+        heap.insert(2);
+        heap.insert(3);
+        expect(heap.size()).toBe(4);
+        const node = heap.extractMinimum();
+        expect(node!.key).toBe(1);
+        expect(node!.value).toBe(2);
+        expect(heap.size()).toBe(3);
+    });
+});

+ 33 - 14
src/mol-util/fibonacci-heap.ts

@@ -1,18 +1,18 @@
 /**
- * Copyright (c) 2018-2022 mol* contributors, licensed under MIT, See LICENSE file for more info.
+ * Copyright (c) 2022 mol* contributors, licensed under MIT, See LICENSE file for more info.
  *
  * @author Gianluca Tomasello <giagitom@gmail.com>
  *
  * Adapted from https://github.com/gwtw/ts-fibonacci-heap, Copyright (c) 2014 Daniel Imms, MIT
  */
 
-type CompareFunction<K, V> = (a: INode<K, V>, b: INode<K, V>) => number;
-
 interface INode<K, V> {
     key: K;
     value?: V;
 }
 
+type CompareFunction<K, V> = (a: INode<K, V>, b: INode<K, V>) => number;
+
 class Node<K, V> implements INode<K, V> {
     public key: K;
     public value: V | undefined;
@@ -35,28 +35,32 @@ class Node<K, V> implements INode<K, V> {
 class NodeListIterator<K, V> {
     private _index: number;
     private _items: Node<K, V>[];
-
+    private _len: number;
     /**
    * Creates an Iterator used to simplify the consolidate() method. It works by
    * making a shallow copy of the nodes in the root list and iterating over the
    * shallow copy instead of the source as the source will be modified.
    * @param start A node from the root list.
    */
-    constructor(start: Node<K, V>) {
+    constructor(start?: Node<K, V>) {
         this._index = -1;
         this._items = [];
-        let current = start;
-        do {
-            this._items.push(current);
-            current = current.next;
-        } while (start !== current);
+        this._len = 0;
+        if (start) {
+            let current = start, l = 0;
+            do {
+                this._items[l++] = current;
+                current = current.next;
+            } while (start !== current);
+            this._len = l;
+        }
     }
 
     /**
    * @return Whether there is a next node in the iterator.
    */
     public hasNext(): boolean {
-        return this._index < this._items.length - 1;
+        return this._index < this._len - 1;
     }
 
     /**
@@ -65,8 +69,23 @@ class NodeListIterator<K, V> {
     public next(): Node<K, V> {
         return this._items[++this._index];
     }
+
+    /**
+   * @return Resets iterator to reuse it.
+   */
+    public reset(start: Node<K, V>) {
+        this._index = -1;
+        this._len = 0;
+        let current = start, l = 0;
+        do {
+            this._items[l++] = current;
+            current = current.next;
+        } while (start !== current);
+        this._len = l;
+    }
 }
 
+const tmpIt = new NodeListIterator<any, any>();
 /**
  * A Fibonacci heap data structure with a key and optional value.
 */
@@ -277,9 +296,9 @@ export class FibonacciHeap<K, V> {
     private _consolidate(minNode: Node<K, V>): Node<K, V> | null {
 
         const aux = [];
-        const it = new NodeListIterator<K, V>(minNode);
-        while (it.hasNext()) {
-            let current = it.next();
+        tmpIt.reset(minNode);
+        while (tmpIt.hasNext()) {
+            let current = tmpIt.next();
 
             // If there exists another node with the same degree, merge them
             let auxCurrent = aux[current.degree];