浏览代码

alpha-orbitals: optimization & webgl1 support

David Sehnal 4 年之前
父节点
当前提交
0764795c08

+ 16 - 41
src/extensions/alpha-orbitals/cubes.ts

@@ -13,8 +13,8 @@ import { Mat4, Tensor, Vec3 } from '../../mol-math/linear-algebra';
 import { Grid } from '../../mol-model/volume';
 import { Task } from '../../mol-task';
 import { arrayMax, arrayMin, arrayRms } from '../../mol-util/array';
-import { CollocationParams } from './collocation';
-import { AlphaOrbitalsPass } from './gpu/pass';
+import { CollocationParams, sphericalCollocation } from './collocation';
+import { canComputeAlphaOrbitalsOnGPU, gpuComputeAlphaOrbitalsGridValues } from './gpu/compute';
 import { SphericalBasisOrder } from './orbitals';
 
 export interface CubeGridInfo {
@@ -76,38 +76,18 @@ export function createSphericalCollocationGrid(
             sphericalOrder: params.sphericalOrder
         };
 
-        // console.log(cParams);
-
-        console.time('gpu');
-        const pass = new AlphaOrbitalsPass(webgl!, cParams);
-        const matrixGL = pass.getData();
-        console.timeEnd('gpu');
-
-        // TODO: remove the 2nd run
-        // console.time('gpu');
-        // const pass0 = new AlphaOrbitalsPass(webgl!, cParams);
-        // pass0.getData();
-        // console.timeEnd('gpu');
-
-        // if (false && webgl) {
-        // } else {
-        // console.time('cpu');
-        // const matrix = await sphericalCollocation(cParams, ctx);
-        // console.timeEnd('cpu');
-        // // }
-
-        // console.log(matrixGL);
-        // console.log(matrix);
-
-        // for (let i = 0; i < matrixGL.length; i++) {
-        //     if (Math.abs(matrixGL[i] - matrix[i]) > 1e-4) {
-        //         console.log('err', i, matrixGL[i], matrix[i]);
-        //         // console.log()
-        //         break;
-        //     }
-        // }
-
-        return createCubeGrid(cParams.grid, matrixGL, [0, 1, 2], !params.doNotComputeIsovalues);
+        let matrix: Float32Array;
+        if (canComputeAlphaOrbitalsOnGPU(webgl)) {
+            console.time('gpu');
+            matrix = gpuComputeAlphaOrbitalsGridValues(webgl!, cParams);
+            console.timeEnd('gpu');
+        } else {
+            console.time('cpu');
+            matrix = await sphericalCollocation(cParams, ctx);
+            console.timeEnd('cpu');
+        }
+
+        return createCubeGrid(cParams.grid, matrix, [0, 1, 2], !params.doNotComputeIsovalues);
     });
 }
 
@@ -146,18 +126,13 @@ function createCubeGrid(gridInfo: CubeGridInfo, values: Float32Array, axisOrder:
     };
 
     // TODO: when using GPU rendering, the cumulative sum can be computed
-    // along the ray on the fly
-
+    //       along the ray on the fly?
     let isovalues: { negative?: number, positive?: number } | undefined;
 
     if (computeIsovalues) {
-        console.time('iso');
-        const isovalues = computeIsocontourValues(values, 0.85);
-        console.timeEnd('iso');
-        console.log(isovalues);
+        isovalues = computeIsocontourValues(values, 0.85);
     }
 
-
     return { grid, isovalues };
 }
 

+ 44 - 54
src/extensions/alpha-orbitals/gpu/pass.ts → src/extensions/alpha-orbitals/gpu/compute.ts

@@ -6,14 +6,14 @@
 
 import { QuadSchema, QuadValues } from '../../../mol-gl/compute/util';
 import { ComputeRenderable, createComputeRenderable } from '../../../mol-gl/renderable';
-import { TextureSpec, UniformSpec, Values } from '../../../mol-gl/renderable/schema';
+import { DefineSpec, TextureSpec, UniformSpec, Values } from '../../../mol-gl/renderable/schema';
 import { ShaderCode } from '../../../mol-gl/shader-code';
 import quad_vert from '../../../mol-gl/shader/quad.vert';
 import { WebGLContext } from '../../../mol-gl/webgl/context';
 import { createComputeRenderItem } from '../../../mol-gl/webgl/render-item';
-import { RenderTarget } from '../../../mol-gl/webgl/render-target';
 import { ValueCell } from '../../../mol-util';
 import { arrayMin } from '../../../mol-util/array';
+import { isLittleEndian } from '../../../mol-util/is-little-endian';
 import { CollocationParams } from '../collocation';
 import { normalizeBasicOrder } from '../orbitals';
 import shader_frag from './shader.frag';
@@ -30,6 +30,7 @@ const AlphaOrbitalsSchema = {
     uNCenters: UniformSpec('i'),
     uNAlpha: UniformSpec('i'),
     uNCoeff: UniformSpec('i'),
+    uMaxCoeffs: UniformSpec('i'),
     uLittleEndian: UniformSpec('i') // TODO: boolean uniforms
 };
 const AlphaOrbitalsShaderCode = ShaderCode('postprocessing', quad_vert, shader_frag);
@@ -65,6 +66,7 @@ function createTextureData({
     const alpha = new Float32Array(baseCount);
     const coeff = new Float32Array(3 * coeffCount);
 
+    let maxCoeffs = 0;
     let cO = 0, aO = 0, coeffO = 0;
     for (const atom of basis.atoms) {
         for (const shell of atom.shells) {
@@ -95,6 +97,10 @@ function createTextureData({
                     coeff[3 * (coeffO + i) + 1] = shell.exponents[i];
                 }
 
+                if (c0.length > maxCoeffs) {
+                    maxCoeffs = c0.length;
+                }
+
                 cO++;
                 aO += 2 * L + 1;
                 coeffO += shell.exponents.length;
@@ -102,14 +108,12 @@ function createTextureData({
         }
     }
 
-    return { nCenters: centerCount, nAlpha: baseCount, nCoeff: coeffCount, centers, info, alpha, coeff };
+    return { nCenters: centerCount, nAlpha: baseCount, nCoeff: coeffCount, maxCoeffs, centers, info, alpha, coeff };
 }
 
 function getPostprocessingRenderable(ctx: WebGLContext, params: CollocationParams): AlphaOrbitalsRenderable {
     const data = createTextureData(params);
 
-    // console.log(data);
-
     const values: Values<typeof AlphaOrbitalsSchema> = {
         ...QuadValues,
         uDimensions: ValueCell.create(params.grid.dimensions),
@@ -118,6 +122,7 @@ function getPostprocessingRenderable(ctx: WebGLContext, params: CollocationParam
         uNCenters: ValueCell.create(data.nCenters),
         uNAlpha: ValueCell.create(data.nAlpha),
         uNCoeff: ValueCell.create(data.nCoeff),
+        uMaxCoeffs: ValueCell.create(data.maxCoeffs),
         tCenters: ValueCell.create({ width: data.nCenters, height: 1, array: data.centers }),
         tInfo: ValueCell.create({ width: data.nCenters, height: 1, array: data.info }),
         tCoeff: ValueCell.create({ width: data.nCoeff, height: 1, array: data.coeff }),
@@ -131,64 +136,49 @@ function getPostprocessingRenderable(ctx: WebGLContext, params: CollocationParam
     return createComputeRenderable(renderItem, values);
 }
 
-export class AlphaOrbitalsPass {
-    target: RenderTarget
-    renderable: AlphaOrbitalsRenderable
+function normalizeParams(webgl: WebGLContext) {
+    if (!webgl.isWebGL2) {
+        // workaround for webgl1 limitation that loop counters need to be `const`
+        (AlphaOrbitalsSchema.uNCenters as any) = DefineSpec('number');
+        (AlphaOrbitalsSchema.uMaxCoeffs as any) = DefineSpec('number');
+    }
+}
 
-    constructor(private webgl: WebGLContext, private params: CollocationParams) {
-        const [nx, ny, nz] = params.grid.dimensions;
+export function gpuComputeAlphaOrbitalsGridValues(webgl: WebGLContext, params: CollocationParams) {
+    const [nx, ny, nz] = params.grid.dimensions;
 
-        // TODO: add single component float32 render target option for WebGL2?
-        // TODO: figure out the ordering so that it does not have to be remapped in the shader
-        this.target = webgl.createRenderTarget(nx, ny * nz, false, 'uint8', 'nearest');
-        this.renderable = getPostprocessingRenderable(webgl, params);
-    }
+    normalizeParams(webgl);
 
-    private render() {
-        const [nx, ny, nz] = this.params.grid.dimensions;
-        const width = nx;
-        const height = ny * nz;
-        const { gl, state } = this.webgl;
-        this.target.bind();
-        gl.viewport(0, 0, width, height);
-        gl.scissor(0, 0, width, height);
-        state.disable(gl.SCISSOR_TEST);
-        state.disable(gl.BLEND);
-        state.disable(gl.DEPTH_TEST);
-        state.depthMask(false);
-        this.renderable.render();
+    if (!webgl.computeTargets['alpha-oribtals']) {
+        webgl.computeTargets['alpha-oribtals'] = webgl.createRenderTarget(nx, ny * nz, false, 'uint8', 'nearest');
+    } else {
+        webgl.computeTargets['alpha-oribtals'].setSize(nx, ny * nz);
     }
 
-    getData() {
-        const [nx, ny, nz] = this.params.grid.dimensions;
-        const width = nx;
-        const height = ny * nz;
+    const target = webgl.computeTargets['alpha-oribtals'];
+    const renderable = getPostprocessingRenderable(webgl, params);
 
-        this.render();
-        this.target.bind();
-        const array = new Uint8Array(width * height * 4);
-        this.webgl.readPixels(0, 0, width, height, array);
-        // PixelData.flipY({ array, width, height });
-        const floats = new Float32Array(array.buffer, array.byteOffset, width * height);
+    const width = nx;
+    const height = ny * nz;
 
-        // console.log(array);
-        // console.log(floats);
+    const { gl, state } = webgl;
+    target.bind();
+    gl.viewport(0, 0, width, height);
+    gl.scissor(0, 0, width, height);
+    state.disable(gl.SCISSOR_TEST);
+    state.disable(gl.BLEND);
+    state.disable(gl.DEPTH_TEST);
+    state.depthMask(false);
+    renderable.render();
 
-        this.renderable.dispose();
-        this.target.destroy();
+    const array = new Uint8Array(width * height * 4);
+    webgl.readPixels(0, 0, width, height, array);
+    const floats = new Float32Array(array.buffer, array.byteOffset, width * height);
+    renderable.dispose();
 
-        return floats;
-
-        // return new ImageData(new Uint8ClampedArray(array), width, height);
-    }
+    return floats;
 }
 
-function isLittleEndian() {
-    const arrayBuffer = new ArrayBuffer(2);
-    const uint8Array = new Uint8Array(arrayBuffer);
-    const uint16array = new Uint16Array(arrayBuffer);
-    uint8Array[0] = 0xAA; // set first byte
-    uint8Array[1] = 0xBB; // set second byte
-    if(uint16array[0] === 0xBBAA) return 1;
-    return 0;
+export function canComputeAlphaOrbitalsOnGPU(webgl?: WebGLContext) {
+    return !!webgl?.extensions.textureFloat;
 }

+ 59 - 34
src/extensions/alpha-orbitals/gpu/shader.frag.ts

@@ -19,7 +19,10 @@ uniform sampler2D tInfo;
 uniform sampler2D tCoeff;
 uniform sampler2D tAlpha;
 
-uniform int uNCenters;
+#ifndef uNCenters
+    uniform int uNCenters;
+#endif
+
 uniform int uNCoeff;
 uniform int uNAlpha;
 
@@ -115,6 +118,58 @@ float alpha(const in float offset, const in float f) {
     return texture2D(tAlpha, vec2(offset * f, 0.5)).x;
 }
 
+float Y(const in int L, const in vec3 X, const in float aO, const in float fA) {
+    if (L == 0) {
+        return alpha(aO, fA);
+    } else if (L == 1) {
+        return L1(X,
+            alpha(aO, fA), alpha(aO + 1.0, fA), alpha(aO + 2.0, fA)
+        );
+    } else if (L == 2) {
+        return L2(X,
+            alpha(aO, fA), alpha(aO + 1.0, fA), alpha(aO + 2.0, fA), alpha(aO + 3.0, fA), alpha(aO + 4.0, fA)
+        );
+    } else if (L == 3) {
+        return L3(X,
+            alpha(aO, fA), alpha(aO + 1.0, fA), alpha(aO + 2.0, fA), alpha(aO + 3.0, fA), alpha(aO + 4.0, fA), 
+            alpha(aO + 5.0, fA), alpha(aO + 6.0, fA)
+        );
+    } else if (L == 4) {
+        return L4(X,
+            alpha(aO, fA), alpha(aO + 1.0, fA), alpha(aO + 2.0, fA), alpha(aO + 3.0, fA), alpha(aO + 4.0, fA), 
+            alpha(aO + 5.0, fA), alpha(aO + 6.0, fA), alpha(aO + 7.0, fA), alpha(aO + 8.0, fA)
+        );
+    }
+    // TODO: do we need L > 4?
+    return 0.0;
+}
+
+#ifndef uMaxCoeffs
+    float R(const in float R2, const in int start, const in int end, const in float fCoeff) {
+        float gauss = 0.0;
+        for (int i = start; i < end; i++) {
+            vec2 c = texture2D(tCoeff, vec2(float(i) * fCoeff, 0.5)).xy;
+            gauss += c.x * exp(-c.y * R2);
+        }
+        return gauss;
+    }
+#endif
+
+#ifdef uMaxCoeffs
+    float R(const in float R2, const in int start, const in int end, const in float fCoeff) {
+        float gauss = 0.0;
+        int o = start;
+        for (int i = 0; i < uMaxCoeffs; i++) {
+            if (o >= end) break;
+
+            vec2 c = texture2D(tCoeff, vec2(float(o) * fCoeff, 0.5)).xy;
+            o++;
+            gauss += c.x * exp(-c.y * R2);
+        }
+        return gauss;
+    }
+#endif
+
 float intDiv(const in float a, const in float b) { return float(int(a) / int(b)); }
 float intMod(const in float a, const in float b) { return a - b * float(int(a) / int(b)); }
 
@@ -122,7 +177,7 @@ void main(void) {
     float offset = round(floor(gl_FragCoord.x) + floor(gl_FragCoord.y) * uDimensions.x);
     
     // axis order fast to slow Z, Y, X
-    // TODO: support arbitrary axis orders
+    // TODO: support arbitrary axis orders?
     float k = intMod(offset, uDimensions.z), kk = intDiv(offset, uDimensions.z);
     float j = intMod(kk, uDimensions.y);
     float i = intDiv(kk, uDimensions.y);
@@ -153,38 +208,8 @@ void main(void) {
         float aO = info.y;
         int coeffStart = int(info.z);
         int coeffEnd = int(info.w);
-
-        float gauss = 0.0;
-        for (int j = coeffStart; j < coeffEnd; j++) {
-            vec2 c = texture2D(tCoeff, vec2(float(j) * fCoeff, 0.5)).xy;
-            gauss += c.x * exp(-c.y * R2);
-        }
-
-        float spherical = 0.0;
-        if (L == 0) {
-            spherical = alpha(aO, fA);
-        } else if (L == 1) {
-            spherical = L1(X,
-                alpha(aO, fA), alpha(aO + 1.0, fA), alpha(aO + 2.0, fA)
-            );
-        } else if (L == 2) {
-            spherical = L2(X,
-                alpha(aO, fA), alpha(aO + 1.0, fA), alpha(aO + 2.0, fA), alpha(aO + 3.0, fA), alpha(aO + 4.0, fA)
-            );
-        } else if (L == 3) {
-            spherical = L3(X,
-                alpha(aO, fA), alpha(aO + 1.0, fA), alpha(aO + 2.0, fA), alpha(aO + 3.0, fA), alpha(aO + 4.0, fA), 
-                alpha(aO + 5.0, fA), alpha(aO + 6.0, fA)
-            );
-        } else if (L == 4) {
-            spherical = L4(X,
-                alpha(aO, fA), alpha(aO + 1.0, fA), alpha(aO + 2.0, fA), alpha(aO + 3.0, fA), alpha(aO + 4.0, fA), 
-                alpha(aO + 5.0, fA), alpha(aO + 6.0, fA), alpha(aO + 7.0, fA), alpha(aO + 8.0, fA)
-            );
-        } 
-        // TODO: do we need L > 4?
-
-        v += gauss * spherical;
+        
+        v += R(R2, coeffStart, coeffEnd, fCoeff) * Y(L, X, aO, fA);
     }
 
     // TODO: render to single component float32 texture in WebGL2

+ 7 - 0
src/mol-gl/webgl/context.ts

@@ -191,6 +191,9 @@ export interface WebGLContext {
     setContextLost: () => void
     handleContextRestored: () => void
 
+    // A cache for compute targets, managed by the code that uses it
+    readonly computeTargets: { [name: string]: RenderTarget }
+
     createRenderTarget: (width: number, height: number, depth?: boolean, type?: 'uint8' | 'float32', filter?: TextureFilter) => RenderTarget
     unbindFramebuffer: () => void
     readPixels: (x: number, y: number, width: number, height: number, buffer: Uint8Array | Float32Array) => void
@@ -261,6 +264,8 @@ export function createContext(gl: GLRenderingContext, props: Partial<{ pixelScal
 
     const renderTargets = new Set<RenderTarget>();
 
+    const computeTargets = Object.create(null);
+
     return {
         gl,
         isWebGL2: isWebGL2(gl),
@@ -278,6 +283,8 @@ export function createContext(gl: GLRenderingContext, props: Partial<{ pixelScal
         get maxRenderbufferSize () { return parameters.maxRenderbufferSize; },
         get maxDrawBuffers () { return parameters.maxDrawBuffers; },
 
+        computeTargets,
+
         get isContextLost () {
             return isContextLost || gl.isContextLost();
         },

+ 4 - 0
src/mol-gl/webgl/render-target.ts

@@ -58,6 +58,10 @@ export function createRenderTarget(gl: GLRenderingContext, resources: WebGLResou
             gl.viewport(0, 0, _width, _height);
         },
         setSize: (width: number, height: number) => {
+            if (_width === width && _height === height) {
+                return;
+            }
+
             _width = width;
             _height = height;
             targetTexture.define(_width, _height);

+ 15 - 0
src/mol-util/is-little-endian.ts

@@ -0,0 +1,15 @@
+/**
+ * Copyright (c) 2020 mol* contributors, licensed under MIT, See LICENSE file for more info.
+ *
+ * @author David Sehnal <david.sehnal@gmail.com>
+ */
+
+export function isLittleEndian() {
+    const arrayBuffer = new ArrayBuffer(2);
+    const uint8Array = new Uint8Array(arrayBuffer);
+    const uint16array = new Uint16Array(arrayBuffer);
+    uint8Array[0] = 0xAA;
+    uint8Array[1] = 0xBB;
+    if (uint16array[0] === 0xBBAA) return 1;
+    return 0;
+}