Browse Source

Implement lookup & grid nearest search using fibonacci heap

giagitom 2 years ago
parent
commit
d7cbd5570c

+ 28 - 21
src/mol-math/geometry/lookup3d/grid.ts

@@ -13,6 +13,7 @@ import { PositionData } from '../common';
 import { Vec3, EPSILON } from '../../linear-algebra';
 import { OrderedSet } from '../../../mol-data/int';
 import { Boundary } from '../boundary';
+import { FibonacciHeap } from '../../../mol-util/fibonacci-heap';
 
 interface GridLookup3D<T = number> extends Lookup3D<T> {
     readonly buckets: { readonly offset: ArrayLike<number>, readonly count: ArrayLike<number>, readonly array: ArrayLike<number> }
@@ -294,16 +295,16 @@ function query<T extends number = number>(ctx: QueryContext, result: Result<T>):
 const tmpDirVec = Vec3();
 const tmpVec = Vec3();
 const tmpMapG = new Map<number, boolean>();
-const tmpArrG1 = new Array();
-const tmpArrG2 = new Array();
-const tmpNearestIds = new Array();
+const tmpArrG1 = [0.1];
+const tmpArrG2 = [0.1];
+const tmpHeapG = new FibonacciHeap();
 function queryNearest<T extends number = number>(ctx: QueryContext, result: Result<T>): boolean {
     const { expandedBox: box, boundingSphere: { center }, size: [sX, sY, sZ], bucketOffset, bucketCounts, bucketArray, grid, data: { x: px, y: py, z: pz, indices, radius }, delta, maxRadius } = ctx.grid;
     const [minX, minY, minZ] = box.min;
     const { x, y, z, k } = ctx;
     const indicesCount = OrderedSet.size(indices);
     Result.reset(result);
-    if (indicesCount === 0) return false;
+    if (indicesCount === 0 || k <= 0) return false;
     let gX, gY, gZ;
     Vec3.set(tmpVec, x, y, z);
     if (!Box3D.containsVec3(box, tmpVec)) {
@@ -319,9 +320,11 @@ function queryNearest<T extends number = number>(ctx: QueryContext, result: Resu
     }
     let gCount = 1, nextGCount = 0, arrG = tmpArrG1, nextArrG = tmpArrG2, prevFurthestDist = Number.MAX_VALUE, prevNearestDist = -Number.MAX_VALUE, nearestDist = Number.MAX_VALUE, findFurthest = true, furthestDist = - Number.MAX_VALUE, distSqG: number;
     arrG.length = 0;
+    nextArrG.length = 0;
     arrG.push(gX, gY, gZ);
     tmpMapG.clear();
-    while (result.count < k && result.count < indicesCount) {
+    tmpHeapG.clear();
+    while (result.count < indicesCount) {
         const arrGLen = gCount * 3;
         for (let ig = 0; ig < arrGLen; ig += 3) {
             gX = arrG[ig];
@@ -332,7 +335,6 @@ function queryNearest<T extends number = number>(ctx: QueryContext, result: Resu
             tmpMapG.set(gridId, true);
             distSqG = (gX - x) * (gX - x) + (gY - y) * (gY - y) + (gZ - z) * (gZ - z);
             if (!findFurthest && distSqG > furthestDist && distSqG < nearestDist) continue;
-
             // evaluate distances in the current grid point
             const bucketIdx = grid[gridId];
             if (bucketIdx !== 0) {
@@ -351,24 +353,16 @@ function queryNearest<T extends number = number>(ctx: QueryContext, result: Resu
                         const sqR = r * r;
                         if (findFurthest && distSq > furthestDist) furthestDist = distSq + sqR;
                         distSq = distSq - sqR;
-
                     } else {
                         if (findFurthest && distSq > furthestDist) furthestDist = distSq;
                     }
-
-                    if (distSq <= nearestDist && distSq > prevNearestDist) {
-                        if (nearestDist === distSq) { // handles multiple elements exactly at same distance
-                            tmpNearestIds.push(idx);
-                        } else {
-                            if (tmpNearestIds.length > 1) tmpNearestIds.length = 1;
-                            tmpNearestIds[0] = idx;
-                        }
-                        nearestDist = distSq;
+                    if (distSq > prevNearestDist && distSq <= furthestDist) {
+                        tmpHeapG.insert(distSq, idx);
+                        nearestDist = tmpHeapG.findMinimum()!.key as unknown as number;
                     }
                 }
                 if (prevFurthestDist < furthestDist) findFurthest = false;
             }
-
             // fill grid points array with valid adiacent positions
             for (let ix = -1; ix <= 1; ix++) {
                 const xPos = gX + ix;
@@ -379,6 +373,8 @@ function queryNearest<T extends number = number>(ctx: QueryContext, result: Resu
                     for (let iz = -1; iz <= 1; iz++) {
                         const zPos = gZ + iz;
                         if (zPos < 0 || zPos >= sZ) continue;
+                        const gridId = (((xPos * sY) + yPos) * sZ) + zPos;
+                        if (tmpMapG.get(gridId)) continue; // already visited
                         nextArrG.push(xPos, yPos, zPos);
                         nextGCount++;
                     }
@@ -386,10 +382,14 @@ function queryNearest<T extends number = number>(ctx: QueryContext, result: Resu
             }
         }
         if (nextGCount === 0) {
-            prevNearestDist = nearestDist;
-            for (let i = 0, l = tmpNearestIds.length; i < l; i++) {
-                Result.add(result, tmpNearestIds[i], nearestDist);
+            while (!tmpHeapG.isEmpty() && result.count < k) {
+                const node = tmpHeapG.extractMinimum();
+                if (!node) throw new Error('Cannot extract minimum, should not happen');
+                const { key: squaredDistance, value: index } = node;
+                Result.add(result, index as number, squaredDistance as number);
             }
+            if (result.count >= k) return result.count > 0;
+            prevNearestDist = nearestDist;
             if (furthestDist === nearestDist) {
                 findFurthest = true;
                 prevFurthestDist = furthestDist;
@@ -397,7 +397,14 @@ function queryNearest<T extends number = number>(ctx: QueryContext, result: Resu
             } else {
                 nearestDist = furthestDist + EPSILON; // adding EPSILON fixes a bug
             }
-            tmpMapG.clear();
+            // resotre visibility of current gid points
+            for (let ig = 0; ig < arrGLen; ig += 3) {
+                gX = arrG[ig];
+                gY = arrG[ig + 1];
+                gZ = arrG[ig + 2];
+                const gridId = (((gX * sY) + gY) * sZ) + gZ;
+                tmpMapG.set(gridId, false);
+            }
         } else {
             const tmp = arrG;
             arrG = nextArrG;

+ 12 - 26
src/mol-model/structure/structure/util/lookup3d.ts

@@ -10,11 +10,11 @@ import { Structure } from '../structure';
 import { Lookup3D, GridLookup3D, Result } from '../../../../mol-math/geometry';
 import { Vec3 } from '../../../../mol-math/linear-algebra';
 import { OrderedSet } from '../../../../mol-data/int';
-import { arrayLess, arraySwap } from '../../../../mol-data/util';
 import { StructureUniqueSubsetBuilder } from './unique-subset-builder';
 import { StructureElement } from '../element';
 import { Unit } from '../unit';
 import { UnitIndex } from '../element/util';
+import { FibonacciHeap } from '../../../../mol-util/fibonacci-heap';
 
 export interface StructureResult extends Result<StructureElement.UnitIndex> {
     units: Unit[]
@@ -41,25 +41,6 @@ export namespace StructureResult {
         out.count = result.count;
         return out;
     }
-
-    // sort in ascending order based on squaredDistances
-    export function sort(result: StructureResult) {
-        const { indices, squaredDistances, units, count } = result;
-        if (indices.length > count) {
-            // clear arrays before doing sorting if necessary
-            indices.length = count;
-            squaredDistances.length = count;
-            units.length = count;
-        }
-        for (let i = 1, c = result.count; i < c; i++) {
-            if (arrayLess(squaredDistances, i - 1, i) > 0) {
-                arraySwap(squaredDistances, i - 1, i);
-                arraySwap(indices, i - 1, i);
-                arraySwap(units, i - 1, i);
-                i = Math.max(0, i - 2);
-            }
-        }
-    }
 }
 
 export interface StructureLookup3DResultContext {
@@ -75,6 +56,7 @@ export function StructureLookup3DResultContext(): StructureLookup3DResultContext
 export class StructureLookup3D {
     private unitLookup: Lookup3D;
     private pivot = Vec3();
+    private tmpHeap = new FibonacciHeap();
 
     findUnitIndices(x: number, y: number, z: number, radius: number): Result<number> {
         return this.unitLookup.find(x, y, z, radius);
@@ -117,7 +99,9 @@ export class StructureLookup3D {
 
     _nearest(x: number, y: number, z: number, k: number, ctx: StructureLookup3DResultContext): StructureResult {
         const result = ctx.result;
+        const heap = this.tmpHeap;
         Result.reset(result);
+        this.tmpHeap.clear();
         const { units } = this.structure;
         const closeUnits = this.unitLookup.nearest(x, y, z, units.length, ctx.closeUnitsResult); // sort all units based on distance to the point
         if (closeUnits.count === 0) return result;
@@ -136,14 +120,16 @@ export class StructureLookup3D {
             maxDistResult = Math.max(maxDistResult, groupResult.squaredDistances[groupResult.count - 1]);
             totalCount += groupResult.count;
             for (let j = 0, _j = groupResult.count; j < _j; j++) {
-                StructureResult.add(result, unit, groupResult.indices[j], groupResult.squaredDistances[j]);
-            }
-            StructureResult.sort(result);
-            if (totalCount > k) {
-                result.count = k;
-                totalCount = k;
+                heap.insert(groupResult.squaredDistances[j], { index: groupResult.indices[j], unit: unit });
             }
         }
+        while (!heap.isEmpty() && result.count < k) {
+            const node = heap.extractMinimum();
+            if (!node) throw new Error('Cannot extract minimum, should not happen');
+            const { key: squaredDistance } = node;
+            const { unit, index } = node.value as { index: UnitIndex, unit: Unit };
+            StructureResult.add(result, unit as Unit, index as UnitIndex, squaredDistance as number);
+        }
         return result;
     }
 

+ 388 - 0
src/mol-util/fibonacci-heap.ts

@@ -0,0 +1,388 @@
+/**
+ * Copyright (c) 2018-2022 mol* contributors, licensed under MIT, See LICENSE file for more info.
+ *
+ * @author Gianluca Tomasello <giagitom@gmail.com>
+ *
+ * Adapted from https://github.com/gwtw/ts-fibonacci-heap, Copyright (c) 2014 Daniel Imms, MIT
+ */
+
+type CompareFunction<K, V> = (a: INode<K, V>, b: INode<K, V>) => number;
+
+interface INode<K, V> {
+    key: K;
+    value?: V;
+}
+
+class Node<K, V> implements INode<K, V> {
+    public key: K;
+    public value: V | undefined;
+    public prev: Node<K, V>;
+    public next: Node<K, V>;
+    public parent: Node<K, V> | null = null;
+    public child: Node<K, V> | null = null;
+
+    public degree: number = 0;
+    public isMarked: boolean = false;
+
+    constructor(key: K, value?: V) {
+        this.key = key;
+        this.value = value;
+        this.prev = this;
+        this.next = this;
+    }
+}
+
+class NodeListIterator<K, V> {
+    private _index: number;
+    private _items: Node<K, V>[];
+
+    /**
+   * Creates an Iterator used to simplify the consolidate() method. It works by
+   * making a shallow copy of the nodes in the root list and iterating over the
+   * shallow copy instead of the source as the source will be modified.
+   * @param start A node from the root list.
+   */
+    constructor(start: Node<K, V>) {
+        this._index = -1;
+        this._items = [];
+        let current = start;
+        do {
+            this._items.push(current);
+            current = current.next;
+        } while (start !== current);
+    }
+
+    /**
+   * @return Whether there is a next node in the iterator.
+   */
+    public hasNext(): boolean {
+        return this._index < this._items.length - 1;
+    }
+
+    /**
+   * @return The next node.
+   */
+    public next(): Node<K, V> {
+        return this._items[++this._index];
+    }
+}
+
+/**
+ * A Fibonacci heap data structure with a key and optional value.
+*/
+export class FibonacciHeap<K, V> {
+    private _minNode: Node<K, V> | null = null;
+    private _nodeCount: number = 0;
+    private _compare: CompareFunction<K, V>;
+
+    constructor(
+        compare?: CompareFunction<K, V>
+    ) {
+        this._compare = compare ? compare : this._defaultCompare;
+    }
+
+    /**
+   * Clears the heap's data, making it an empty heap.
+   */
+    public clear(): void {
+        this._minNode = null;
+        this._nodeCount = 0;
+    }
+
+    /**
+   * Decreases a key of a node.
+   * @param node The node to decrease the key of.
+   * @param newKey The new key to assign to the node.
+   */
+    public decreaseKey(node: Node<K, V>, newKey: K): void {
+        if (!node) {
+            throw new Error('Cannot decrease key of non-existent node');
+        }
+        if (this._compare({ key: newKey }, { key: node.key }) > 0) {
+            throw new Error('New key is larger than old key');
+        }
+
+        node.key = newKey;
+        const parent = node.parent;
+        if (parent && this._compare(node, parent) < 0) {
+            this._cut(node, parent, <Node<K, V>> this._minNode);
+            this._cascadingCut(parent, <Node<K, V>> this._minNode);
+        }
+        if (this._compare(node, <Node<K, V>> this._minNode) < 0) {
+            this._minNode = node;
+        }
+    }
+
+    /**
+   * Deletes a node.
+   * @param node The node to delete.
+   */
+    public delete(node: Node<K, V>): void {
+    // This is a special implementation of decreaseKey that sets the argument to
+    // the minimum value. This is necessary to make generic keys work, since there
+    // is no MIN_VALUE constant for generic types.
+        const parent = node.parent;
+        if (parent) {
+            this._cut(node, parent, <Node<K, V>> this._minNode);
+            this._cascadingCut(parent, <Node<K, V>> this._minNode);
+        }
+        this._minNode = node;
+
+        this.extractMinimum();
+    }
+
+    /**
+   * Extracts and returns the minimum node from the heap.
+   * @return The heap's minimum node or null if the heap is empty.
+   */
+    public extractMinimum(): Node<K, V> | null {
+        const extractedMin = this._minNode;
+        if (extractedMin) {
+            // Set parent to null for the minimum's children
+            if (extractedMin.child) {
+                let child = extractedMin.child;
+                do {
+                    child.parent = null;
+                    child = child.next;
+                } while (child !== extractedMin.child);
+            }
+
+            let nextInRootList = null;
+            if (extractedMin.next !== extractedMin) {
+                nextInRootList = extractedMin.next;
+            }
+            // Remove min from root list
+            this._removeNodeFromList(extractedMin);
+            this._nodeCount--;
+
+            // Merge the children of the minimum node with the root list
+            this._minNode = this._mergeLists(nextInRootList, extractedMin.child);
+            if (this._minNode) {
+                this._minNode = this._consolidate(this._minNode);
+            }
+        }
+        return extractedMin;
+    }
+
+    /**
+   * Returns the minimum node from the heap.
+   * @return The heap's minimum node or null if the heap is empty.
+   */
+    public findMinimum(): Node<K, V> | null {
+        return this._minNode;
+    }
+
+    /**
+   * Inserts a new key-value pair into the heap.
+   * @param key The key to insert.
+   * @param value The value to insert.
+   * @return node The inserted node.
+   */
+    public insert(key: K, value?: V): Node<K, V> {
+        const node = new Node(key, value);
+        this._minNode = this._mergeLists(this._minNode, node);
+        this._nodeCount++;
+        return node;
+    }
+
+    /**
+   * @return Whether the heap is empty.
+   */
+    public isEmpty(): boolean {
+        return this._minNode === null;
+    }
+
+    /**
+   * @return The size of the heap.
+   */
+    public size(): number {
+        if (this._minNode === null) {
+            return 0;
+        }
+        return this._getNodeListSize(this._minNode);
+    }
+
+    /**
+   * Joins another heap to this heap.
+   * @param other The other heap.
+   */
+    public union(other: FibonacciHeap<K, V>): void {
+        this._minNode = this._mergeLists(this._minNode, other._minNode);
+        this._nodeCount += other._nodeCount;
+    }
+
+    /**
+   * Compares two nodes with each other.
+   * @param a The first key to compare.
+   * @param b The second key to compare.
+   * @return -1, 0 or 1 if a < b, a == b or a > b respectively.
+   */
+    private _defaultCompare(a: INode<K, V>, b: INode<K, V>): number {
+        if (a.key > b.key) {
+            return 1;
+        }
+        if (a.key < b.key) {
+            return -1;
+        }
+        return 0;
+    }
+
+    /**
+   * Cut the link between a node and its parent, moving the node to the root list.
+   * @param node The node being cut.
+   * @param parent The parent of the node being cut.
+   * @param minNode The minimum node in the root list.
+   * @return The heap's new minimum node.
+   */
+    private _cut(node: Node<K, V>, parent: Node<K, V>, minNode: Node<K, V>): Node<K, V> | null {
+        node.parent = null;
+        parent.degree--;
+        if (node.next === node) {
+            parent.child = null;
+        } else {
+            parent.child = node.next;
+        }
+        this._removeNodeFromList(node);
+        const newMinNode = this._mergeLists(minNode, node);
+        node.isMarked = false;
+        return newMinNode;
+    }
+
+    /**
+   * Perform a cascading cut on a node; mark the node if it is not marked,
+   * otherwise cut the node and perform a cascading cut on its parent.
+   * @param node The node being considered to be cut.
+   * @param minNode The minimum node in the root list.
+   * @return The heap's new minimum node.
+   */
+    private _cascadingCut(node: Node<K, V>, minNode: Node<K, V> | null): Node<K, V> | null {
+        const parent = node.parent;
+        if (parent) {
+            if (node.isMarked) {
+                minNode = this._cut(node, parent, <Node<K, V>>minNode);
+                minNode = this._cascadingCut(parent, minNode);
+            } else {
+                node.isMarked = true;
+            }
+        }
+        return minNode;
+    }
+
+    /**
+   * Merge all trees of the same order together until there are no two trees of
+   * the same order.
+   * @param minNode The current minimum node.
+   * @return The new minimum node.
+   */
+    private _consolidate(minNode: Node<K, V>): Node<K, V> | null {
+
+        const aux = [];
+        const it = new NodeListIterator<K, V>(minNode);
+        while (it.hasNext()) {
+            let current = it.next();
+
+            // If there exists another node with the same degree, merge them
+            let auxCurrent = aux[current.degree];
+            while (auxCurrent) {
+                if (this._compare(current, auxCurrent) > 0) {
+                    const temp = current;
+                    current = auxCurrent;
+                    auxCurrent = temp;
+                }
+                this._linkHeaps(auxCurrent, current);
+                aux[current.degree] = null;
+                current.degree++;
+                auxCurrent = aux[current.degree];
+            }
+
+            aux[current.degree] = current;
+        }
+
+        let newMinNode = null;
+        for (let i = 0; i < aux.length; i++) {
+            const node = aux[i];
+            if (node) {
+                // Remove siblings before merging
+                node.next = node;
+                node.prev = node;
+                newMinNode = this._mergeLists(newMinNode, node);
+            }
+        }
+        return newMinNode;
+    }
+
+    /**
+   * Removes a node from a node list.
+   * @param node The node to remove.
+   */
+    private _removeNodeFromList(node: Node<K, V>): void {
+        const prev = node.prev;
+        const next = node.next;
+        prev.next = next;
+        next.prev = prev;
+        node.next = node;
+        node.prev = node;
+    }
+
+    /**
+   * Links two heaps of the same order together.
+   *
+   * @private
+   * @param max The heap with the larger root.
+   * @param min The heap with the smaller root.
+   */
+    private _linkHeaps(max: Node<K, V>, min: Node<K, V>): void {
+        this._removeNodeFromList(max);
+        min.child = this._mergeLists(max, min.child);
+        max.parent = min;
+        max.isMarked = false;
+    }
+
+    /**
+   * Merge two lists of nodes together.
+   *
+   * @private
+   * @param a The first list to merge.
+   * @param b The second list to merge.
+   * @return The new minimum node from the two lists.
+   */
+    private _mergeLists(a: Node<K, V> | null, b: Node<K, V> | null): Node<K, V> | null {
+        if (!a) {
+            if (!b) {
+                return null;
+            }
+            return b;
+        }
+        if (!b) {
+            return a;
+        }
+
+        const temp = a.next;
+        a.next = b.next;
+        a.next.prev = a;
+        b.next = temp;
+        b.next.prev = b;
+
+        return this._compare(a, b) < 0 ? a : b;
+    }
+
+    /**
+   * Gets the size of a node list.
+   * @param node A node within the node list.
+   * @return The size of the node list.
+   */
+    private _getNodeListSize(node: Node<K, V>): number {
+        let count = 0;
+        let current = node;
+
+        do {
+            count++;
+            if (current.child) {
+                count += this._getNodeListSize(current.child);
+            }
+            current = current.next;
+        } while (current !== node);
+
+        return count;
+    }
+}