Kaynağa Gözat

alpha-orbitals: optimization

David Sehnal 4 yıl önce
ebeveyn
işleme
96a8cd789c

+ 2 - 28
src/examples/alpha-orbitals/index.ts

@@ -4,11 +4,8 @@
  * @author Alexander Rose <alexander.rose@weirdbyte.de>
  */
 
-import { Basis, computeIsocontourValues, CubeGridInfo } from '../../extensions/alpha-orbitals/cubes';
-import { AlphaOrbitalsPass } from '../../extensions/alpha-orbitals/gpu/pass';
+import { Basis, computeIsocontourValues } from '../../extensions/alpha-orbitals/cubes';
 import { SphericalBasisOrder } from '../../extensions/alpha-orbitals/orbitals';
-import { Box3D } from '../../mol-math/geometry';
-import { Vec3 } from '../../mol-math/linear-algebra';
 import { createPluginAsync, DefaultPluginSpec } from '../../mol-plugin';
 import { createVolumeRepresentationParams } from '../../mol-plugin-state/helpers/volume-representation-params';
 import { StateTransforms } from '../../mol-plugin-state/transforms';
@@ -16,7 +13,6 @@ import { PluginContext } from '../../mol-plugin/context';
 import { ColorNames } from '../../mol-util/color/names';
 import { DemoMoleculeSDF, DemoOrbitals } from './example-data';
 import './index.html';
-import { TestWaterParams } from './test-water';
 import { CreateOrbitalVolume, StaticBasisAndOrbitals } from './transforms';
 require('mol-plugin-ui/skin/light.scss');
 
@@ -66,7 +62,7 @@ class AlphaOrbitalsExample {
 
         const volumeRef = await this.plugin.build().toRoot()
             .apply(StaticBasisAndOrbitals, { basis: input.basis, order: input.order, orbitals: input.orbitals })
-            .apply(CreateOrbitalVolume, { index: 44 })
+            .apply(CreateOrbitalVolume, { index: 32 })
             .commit();
 
         if (!volumeRef.isOk) return;
@@ -96,29 +92,7 @@ class AlphaOrbitalsExample {
         }
 
         await repr.commit();
-
-        this.gpu();
-    }
-
-    gpu() {
-        // const pass = new AlphaOrbitalsPass(this.plugin.canvas3d!.webgl, TestWaterParams);
-        // pass.getData();
-        // pass.getData();
     }
 }
 
-function createCubeGrid(): CubeGridInfo {
-    const box = Box3D.create(Vec3.create(-1, -1, -1), Vec3.create(1, 1, 1));
-    const dimensions = Vec3.create(50, 50, 50);
-    const size = Box3D.size(Vec3(), box);
-
-    return {
-        box,
-        dimensions,
-        npoints: dimensions[0] * dimensions[1] * dimensions[2],
-        size,
-        delta: Vec3.div(Vec3(), size, Vec3.subScalar(Vec3(), dimensions, 1))
-    };
-}
-
 (window as any).AlphaOrbitalsExample = new AlphaOrbitalsExample();

+ 0 - 2
src/extensions/alpha-orbitals/collocation.ts

@@ -51,8 +51,6 @@ export async function sphericalCollocation(
 
     const matrix = new Float32Array(grid.npoints);
 
-    let ii = 0;
-
     let baseIndex = 0;
     for (const atom of basis.atoms) {
         for (const shell of atom.shells) {

+ 45 - 16
src/extensions/alpha-orbitals/cubes.ts

@@ -65,13 +65,17 @@ export function createSphericalCollocationGrid(
     params: SphericalCollocationParams, webgl?: WebGLContext
 ): Task<CubeGrid> {
     return Task.create('Spherical Collocation Grid', async (ctx) => {
-        const grid = initBox(
-            params.basis.atoms.map((a) => a.center),
-            params.gridSpacing,
-            params.boxExpand
-        );
+        const centers = params.basis.atoms.map(a => a.center);
 
-        let matrix: Float32Array;
+        const grid = initBox(centers, params.gridSpacing, params.boxExpand, true);
+
+        // const cParams: CollocationParams = {
+        //     grid,
+        //     basis: params.basis,
+        //     alphaOrbitals: params.alphaOrbitals,
+        //     cutoffThreshold: params.cutoffThreshold,
+        //     sphericalOrder: params.sphericalOrder
+        // };
 
         const cParams: CollocationParams = {
             grid,
@@ -81,25 +85,33 @@ export function createSphericalCollocationGrid(
             sphericalOrder: params.sphericalOrder
         };
 
+        console.log(cParams);
 
         console.time('gpu');
         const pass = new AlphaOrbitalsPass(webgl!, cParams);
         const matrixGL = pass.getData();
         console.timeEnd('gpu');
 
+        // TODO: remove the 2nd run
         console.time('gpu');
         const pass0 = new AlphaOrbitalsPass(webgl!, cParams);
         pass0.getData();
         console.timeEnd('gpu');
 
-        if (false && webgl) {
-        } else {
-            console.time('cpu');
-            matrix = await sphericalCollocation(cParams, ctx);
-            console.timeEnd('cpu');
-        }
+        // if (false && webgl) {
+        // } else {
+        console.time('cpu');
+        const matrix = await sphericalCollocation({
+            grid: initBox(centers, params.gridSpacing, params.boxExpand, false),
+            basis: params.basis,
+            alphaOrbitals: params.alphaOrbitals,
+            cutoffThreshold: params.cutoffThreshold,
+            sphericalOrder: params.sphericalOrder
+        }, ctx);
+        console.timeEnd('cpu');
+        // // }
 
-        // console.log(matrixGL);
+        console.log(matrixGL);
         // console.log(matrix);
 
         // for (let i = 0; i < matrixGL.length; i++) {
@@ -160,7 +172,8 @@ function createCubeGrid(gridInfo: CubeGridInfo, values: Float32Array, axisOrder:
 function initBox(
     geometry: Vec3[],
     spacing: SphericalCollocationParams['gridSpacing'],
-    expand: number
+    expand: number,
+    isGpu: boolean
 ): CubeGridInfo {
     const count = geometry.length;
     const box = Box3D.expand(
@@ -180,9 +193,16 @@ function initBox(
         if (spacingThresholds[i][0] <= count) break;
     }
 
+    const dimensions = Vec3.ceil(Vec3(), Vec3.scale(Vec3(), size, 1 / s));
+
     // dimensions need to be powers of 2 otherwise it leads to roudning error
-    // TODO: possible to avoid?
-    const dimensions = Vec3.create(64, 64, 64); //   Vec3.ceil(Vec3(), Vec3.scale(Vec3(), size, 1 / s));
+    // TODO: possible to avoid this having to be power of 2?
+    if (isGpu) {
+        dimensions[0] = reasonablePowerOf2(dimensions[0]);
+        dimensions[1] = reasonablePowerOf2(dimensions[1]);
+        dimensions[2] = reasonablePowerOf2(dimensions[2]);
+    }
+
     return {
         box,
         dimensions,
@@ -192,6 +212,15 @@ function initBox(
     };
 }
 
+function reasonablePowerOf2(x: number) {
+    const high = Math.pow(2, Math.ceil(Math.log2(x)));
+    if (high < 129) return high | 0;
+
+    const low = Math.pow(2, Math.floor(Math.log2(x)));
+    if ((x - low) / (high - low) < 0.33) return low | 0;
+    return high | 0;
+}
+
 export function computeIsocontourValues(
     values: Float32Array,
     cumulativeThreshold: number

+ 15 - 9
src/extensions/alpha-orbitals/gpu/pass.ts

@@ -13,6 +13,7 @@ import { WebGLContext } from '../../../mol-gl/webgl/context';
 import { createComputeRenderItem } from '../../../mol-gl/webgl/render-item';
 import { RenderTarget } from '../../../mol-gl/webgl/render-target';
 import { ValueCell } from '../../../mol-util';
+import { arrayMin } from '../../../mol-util/array';
 import { CollocationParams } from '../collocation';
 import { normalizeBasicOrder } from '../orbitals';
 import shader_frag from './shader.frag';
@@ -22,7 +23,7 @@ const AlphaOrbitalsSchema = {
     uDimensions: UniformSpec('v3'),
     uMin: UniformSpec('v3'),
     uDelta: UniformSpec('v3'),
-    tCenters: TextureSpec('image-float32', 'rgb', 'float', 'nearest'),
+    tCenters: TextureSpec('image-float32', 'rgba', 'float', 'nearest'),
     tInfo: TextureSpec('image-float32', 'rgba', 'float', 'nearest'),
     tCoeff: TextureSpec('image-float32', 'rgb', 'float', 'nearest'),
     tAlpha: TextureSpec('image-float32', 'alpha', 'float', 'nearest'),
@@ -38,6 +39,7 @@ function createTextureData({
     basis,
     sphericalOrder,
     alphaOrbitals,
+    cutoffThreshold
 }: CollocationParams) {
     let centerCount = 0;
     let baseCount = 0;
@@ -57,7 +59,7 @@ function createTextureData({
         }
     }
 
-    const centers = new Float32Array(3 * centerCount);
+    const centers = new Float32Array(4 * centerCount);
     // L, alpha_offset, coeff_offset_start, coeff_offset_end
     const info = new Float32Array(4 * centerCount);
     const alpha = new Float32Array(baseCount);
@@ -71,13 +73,14 @@ function createTextureData({
             for (const L of shell.angularMomentum) {
                 const a0 = normalizeBasicOrder(L, alphaOrbitals.slice(aO, aO + 2 * L + 1), sphericalOrder);
 
-                if (cO === 1) {
-                    console.log('y', atom.center[1]);
-                }
+                const cutoffRadius = cutoffThreshold > 0
+                    ? Math.sqrt(-Math.log(cutoffThreshold) / arrayMin(shell.exponents))
+                    : 10000;
 
-                centers[3 * cO + 0] = atom.center[0];
-                centers[3 * cO + 1] = atom.center[1];
-                centers[3 * cO + 2] = atom.center[2];
+                centers[4 * cO + 0] = atom.center[0];
+                centers[4 * cO + 1] = atom.center[1];
+                centers[4 * cO + 2] = atom.center[2];
+                centers[4 * cO + 3] = cutoffRadius * cutoffRadius;
 
                 info[4 * cO + 0] = L;
                 info[4 * cO + 1] = aO;
@@ -141,7 +144,7 @@ export class AlphaOrbitalsPass {
         this.renderable = getPostprocessingRenderable(webgl, params);
     }
 
-    render() {
+    private render() {
         const [nx, ny, nz] = this.params.grid.dimensions;
         const width = nx;
         const height = ny * nz;
@@ -175,6 +178,9 @@ export class AlphaOrbitalsPass {
         // console.log(array);
         // console.log(floats);
 
+        this.renderable.dispose();
+        this.target.destroy();
+
         return floats;
 
         // return new ImageData(new Uint8ClampedArray(array), width, height);

+ 9 - 2
src/extensions/alpha-orbitals/gpu/shader.frag.ts

@@ -114,7 +114,15 @@ void main(void) {
     for (int i = 0; i < uNCenters; i++) {
         vec2 cCoord = vec2(float(i) * fCenter, 0.5);
 
-        vec3 X = xyz - texture2D(tCenters, cCoord).xyz;
+        vec4 center = texture2D(tCenters, cCoord);
+        vec3 X = xyz - center.xyz;
+        float R2 = dot(X, X);
+
+        // center.w is squared cutoff radius
+        if (R2 > center.w) {
+            continue;
+        }
+
         vec4 info = texture2D(tInfo, cCoord);
 
         int L = int(info.x);
@@ -122,7 +130,6 @@ void main(void) {
         int coeffStart = int(info.z);
         int coeffEnd = int(info.w);
 
-        float R2 = dot(X, X);
 
         float gauss = 0.0;
         for (int j = coeffStart; j < coeffEnd; j++) {