Ver Fonte

wip, molsurf

Alexander Rose há 6 anos atrás
pai
commit
2d164232ac

+ 186 - 332
src/mol-math/geometry/molecular-surface.ts

@@ -8,14 +8,15 @@
  */
 
 import { fillUniform } from 'mol-util/array';
-import { Vec3 } from 'mol-math/linear-algebra';
-import { NumberArray } from 'mol-util/type-helpers';
+import { Vec3, Tensor } from 'mol-math/linear-algebra';
 import { ParamDefinition as PD } from 'mol-util/param-definition';
-import { number } from 'prop-types';
 import { Lookup3D, Result } from './lookup3d/common';
 import { RuntimeContext } from 'mol-task';
 import { OrderedSet } from 'mol-data/int';
 import { PositionData } from './common';
+import { Mat4 } from 'mol-math/linear-algebra/3d';
+import { Box3D, GridLookup3D } from 'mol-math/geometry';
+import { getDelta } from './gaussian-density';
 
 function normalToLine (out: Vec3, p: Vec3) {
     out[0] = out[1] = out[2] = 1.0
@@ -29,12 +30,6 @@ function normalToLine (out: Vec3, p: Vec3) {
     return out
 }
 
-function fillGridDim (a: Float32Array, start: number, step: number) {
-    for (let i = 0; i < a.length; i++) {
-        a[i] = start + (step * i)
-    }
-}
-
 type AnglesTables = { cosTable: Float32Array, sinTable: Float32Array }
 function getAngleTables (probePositions: number): AnglesTables {
     let theta = 0.0
@@ -43,8 +38,8 @@ function getAngleTables (probePositions: number): AnglesTables {
     const cosTable = new Float32Array(probePositions)
     const sinTable = new Float32Array(probePositions)
     for (let i = 0; i < probePositions; i++) {
-        cosTable[ i ] = Math.cos(theta)
-        sinTable[ i ] = Math.sin(theta)
+        cosTable[i] = Math.cos(theta)
+        sinTable[i] = Math.sin(theta)
         theta += step
     }
     return { cosTable, sinTable}
@@ -57,10 +52,8 @@ function getAngleTables (probePositions: number): AnglesTables {
  * Cache the last clipped atom (as very often the same one in subsequent calls)
  */
 function obscured (state: MolSurfCalcState, x: number, y: number, z: number, a: number, b: number) {
-    let ai: number
-
     if (state.lastClip !== -1) {
-        ai = state.lastClip
+        const ai = state.lastClip
         if (ai !== a && ai !== b && singleAtomObscures(state, ai, x, y, z)) {
             return ai
         } else {
@@ -68,28 +61,25 @@ function obscured (state: MolSurfCalcState, x: number, y: number, z: number, a:
         }
     }
 
-    let ni = 0
-    ai = state.neighbours[ni]
-    while (ai >= 0) {
+    for (let j = 0, jl = state.neighbours.count; j < jl; ++j) {
+        const ai = state.neighbours.indices[j]
         if (ai !== a && ai !== b && singleAtomObscures(state, ai, x, y, z)) {
             state.lastClip = ai
             return ai
         }
-        ai = state.neighbours[++ni]
     }
 
-    state.lastClip = -1
-
     return -1
 }
 
 function singleAtomObscures (state: MolSurfCalcState, ai: number, x: number, y: number, z: number) {
-    const ra2 = state.radiusSq[ai]
-    const dx = state.xCoord[ai] - x
-    const dy = state.yCoord[ai] - y
-    const dz = state.zCoord[ai] - z
-    const d2 = dx * dx + dy * dy + dz * dz
-    return d2 < ra2
+    const j = OrderedSet.getAt(state.position.indices, ai)
+    const r = state.position.radius[j] + state.probeRadius
+    const dx = state.position.x[j] - x
+    const dy = state.position.y[j] - y
+    const dz = state.position.z[j] - z
+    const dSq = dx * dx + dy * dy + dz * dz
+    return dSq < (r * r)
 }
 
 /**
@@ -103,35 +93,37 @@ function singleAtomObscures (state: MolSurfCalcState, ai: number, x: number, y:
  *             itself and delta
  */
 async function projectPoints (ctx:  RuntimeContext, state: MolSurfCalcState) {
-    const { position, radius, scaleFactor, lookup3D, min, delta } = state
+    const { position, probeRadius, lookup3d, min, delta, space, data, idData } = state
 
-    const { indices, x, y, z } = position
+    const { indices, x, y, z, radius } = position
     const n = OrderedSet.size(indices)
 
     const v = Vec3()
     const p = Vec3()
     const c = Vec3()
+    const sp = Vec3()
 
     const beg = Vec3()
     const end = Vec3()
 
-    const gridPad = 1 / Math.max(...delta)
+    // const gridPad = 1 / Math.max(...delta)
 
     for (let i = 0; i < n; ++i) {
         const j = OrderedSet.getAt(indices, i)
 
         Vec3.set(v, x[j], y[j], z[j])
 
+        state.neighbours = lookup3d.find(v[0], v[1], v[2], radius[j] + probeRadius)
+
         Vec3.sub(v, v, min)
         Vec3.mul(c, v, delta)
 
-        const rad = radius(j)
+        const rad = radius[j] + probeRadius
         const rSq = rad * rad
 
-        const r2 = rad * 2 + gridPad
+        const r2 = rad // * 2 + gridPad
         const rad2 = Vec3.create(r2, r2, r2)
         Vec3.mul(rad2, rad2, delta)
-        const r2sq = r2 * r2
 
         const [ begX, begY, begZ ] = Vec3.floor(beg, Vec3.sub(beg, c, rad2))
         const [ endX, endY, endZ ] = Vec3.ceil(end, Vec3.add(end, c, rad2))
@@ -141,96 +133,45 @@ async function projectPoints (ctx:  RuntimeContext, state: MolSurfCalcState) {
                 for (let zi = begZ; zi < endZ; ++zi) {
                     Vec3.set(p, xi, yi, zi)
                     Vec3.div(p, p, delta)
-                    const distSq = Vec3.squaredDistance(p, v)
-                    if (distSq <= r2sq) {
-                        const dens = Math.exp(-alpha * (distSq / rSq))
-                        space.add(data, xi, yi, zi, dens)
-                        if (dens > space.get(densData, xi, yi, zi)) {
-                            space.set(densData, xi, yi, zi, dens)
-                            space.set(idData, xi, yi, zi, i)
-                        }
-                    }
-                }
-            }
-        }
 
-        if (i % 10000 === 0 && ctx.shouldUpdate) {
-            await ctx.update({ message: 'projecting points', current: i, max: n })
-        }
-    }
-
-    for (let i = 0; i < nAtoms; i++) {
-        const ax = xCoord[ i ]
-        const ay = yCoord[ i ]
-        const az = zCoord[ i ]
-        const ar = radius[ i ]
-        const ar2 = radiusSq[ i ]
-
-        state.neighbours = lookup3D.find(ax, ay, az, ar)
-
-        // Number of grid points, round this up...
-        const ng = Math.ceil(ar * scaleFactor)
-
-        // Center of the atom, mapped to grid points (take floor)
-        const iax = Math.floor(scaleFactor * (ax - min[ 0 ]))
-        const iay = Math.floor(scaleFactor * (ay - min[ 1 ]))
-        const iaz = Math.floor(scaleFactor * (az - min[ 2 ]))
-
-        // Extents of grid to consider for this atom
-        const minx = Math.max(0, iax - ng)
-        const miny = Math.max(0, iay - ng)
-        const minz = Math.max(0, iaz - ng)
+                    const dv = Vec3()
+                    Vec3.sub(dv, p, v)
 
-        // Add two to these points:
-        // - iax are floor'd values so this ensures coverage
-        // - these are loop limits (exclusive)
-        const maxx = Math.min(dim[ 0 ], iax + ng + 2)
-        const maxy = Math.min(dim[ 1 ], iay + ng + 2)
-        const maxz = Math.min(dim[ 2 ], iaz + ng + 2)
+                    // const distSq = Vec3.squaredDistance(p, v)
+                    const dSq = Vec3.squaredMagnitude(dv)
 
-        for (let ix = minx; ix < maxx; ix++) {
-            const dx = gridx[ ix ] - ax
-            const xoffset = dim[ 1 ] * dim[ 2 ] * ix
-
-            for (let iy = miny; iy < maxy; iy++) {
-                const dy = gridy[ iy ] - ay
-                const dxy2 = dx * dx + dy * dy
-                const xyoffset = xoffset + dim[ 2 ] * iy
-
-                for (let iz = minz; iz < maxz; iz++) {
-                    const dz = gridz[ iz ] - az
-                    const d2 = dxy2 + dz * dz
-
-                    if (d2 < ar2) {
-                        const idx = iz + xyoffset
-
-                        if (grid[idx] < 0.0) {
+                    if (dSq < rSq) {
+                        const val = space.get(data, xi, yi, zi)
+                        if (val < 0.0) {
                             // Unvisited, make positive
-                            grid[ idx ] = -grid[ idx ]
+                            space.set(data, xi, yi, zi, -val)
                         }
+
                         // Project on to the surface of the sphere
                         // sp is the projected point ( dx, dy, dz ) * ( ra / d )
-                        const d = Math.sqrt(d2)
-                        const ap = ar / d
-                        let spx = dx * ap
-                        let spy = dy * ap
-                        let spz = dz * ap
-
-                        spx += ax
-                        spy += ay
-                        spz += az
-
-                        if (obscured(state, spx, spy, spz, i, -1) === -1) {
-                            const dd = ar - d
-                            if (dd < grid[ idx ]) {
-                                grid[ idx ] = dd
-                                atomIndex[ idx ] = i
+                        // const dist = Math.sqrt(distSq)
+                        const d = Math.sqrt(dSq)
+                        const ap = rad / d
+                        Vec3.scale(sp, dv, ap)
+                        Vec3.add(sp, sp, v)
+                        Vec3.add(sp, sp, min)
+                        // Vec3.add(sp, v, Vec3.setMagnitude(sp, Vec3.sub(sp, p, v), rad - probeRadius))
+
+                        if (obscured(state, sp[0], sp[1], sp[2], i, -1) === -1) {
+                            const dd = rad - d
+                            if (dd < space.get(data, xi, yi, zi)) {
+                                space.set(data, xi, yi, zi, dd)
+                                space.set(idData, xi, yi, zi, i)
                             }
                         }
                     }
                 }
             }
         }
+
+        if (i % 10000 === 0 && ctx.shouldUpdate) {
+            await ctx.update({ message: 'projecting points', current: i, max: n })
+        }
     }
 }
 
@@ -239,18 +180,27 @@ const atob = Vec3()
 const mid = Vec3()
 const n1 = Vec3()
 const n2 = Vec3()
+const v = Vec3()
+const p = Vec3()
+const c = Vec3()
+const beg = Vec3()
+const end = Vec3()
+const radVec = Vec3()
 function projectTorus (state: MolSurfCalcState, a: number, b: number) {
-    const r1 = state.radius[a]
-    const r2 = state.radius[b]
-    const dx = atob[0] = state.xCoord[b] - state.xCoord[a]
-    const dy = atob[1] = state.yCoord[b] - state.yCoord[a]
-    const dz = atob[2] = state.zCoord[b] - state.zCoord[a]
-    const d2 = dx * dx + dy * dy + dz * dz
+    const { position, min, delta, space, data, idData } = state
+    const { cosTable, sinTable, probePositions, probeRadius, resolution } = state
+
+    const r1 = position.radius[a] + probeRadius
+    const r2 = position.radius[b] + probeRadius
+    const dx = atob[0] = position.x[b] - position.x[a]
+    const dy = atob[1] = position.y[b] - position.y[a]
+    const dz = atob[2] = position.z[b] - position.z[a]
+    const dSq = dx * dx + dy * dy + dz * dz
 
     // This check now redundant as already done in AVHash.withinRadii
-    // if (d2 > ((r1 + r2) * (r1 + r2))){ return; }
+    if (dSq > ((r1 + r2) * (r1 + r2))) { return }
 
-    const d = Math.sqrt(d2)
+    const d = Math.sqrt(dSq)
 
     // Find angle between a->b vector and the circle
     // of their intersection by cosine rule
@@ -276,58 +226,53 @@ function projectTorus (state: MolSurfCalcState, a: number, b: number) {
     Vec3.scale(n2, n2, rInt)
     Vec3.scale(atob, atob, dmp)
 
-    mid[0] = atob[0] + state.xCoord[a]
-    mid[1] = atob[1] + state.yCoord[a]
-    mid[2] = atob[2] + state.zCoord[a]
+    mid[0] = atob[0] + position.x[a]
+    mid[1] = atob[1] + position.y[a]
+    mid[2] = atob[2] + position.z[a]
 
     state.lastClip = -1
 
-    const { ngTorus, cosTable, sinTable, scaleFactor } = state
-
-    for (let i = 0; i < state.probePositions; i++) {
-        const cost = cosTable[ i ]
-        const sint = sinTable[ i ]
+    for (let i = 0; i < probePositions; ++i) {
+        const cost = cosTable[i]
+        const sint = sinTable[i]
 
         const px = mid[0] + cost * n1[0] + sint * n2[0]
         const py = mid[1] + cost * n1[1] + sint * n2[1]
         const pz = mid[2] + cost * n1[2] + sint * n2[2]
 
         if (obscured(state, px, py, pz, a, b) === -1) {
-            // As above, iterate over our grid...
-            // px, py, pz in grid coords
-            const iax = Math.floor(scaleFactor * (px - min[0]))
-            const iay = Math.floor(scaleFactor * (py - min[1]))
-            const iaz = Math.floor(scaleFactor * (pz - min[2]))
-
-            const minx = Math.max(0, iax - ngTorus)
-            const miny = Math.max(0, iay - ngTorus)
-            const minz = Math.max(0, iaz - ngTorus)
-
-            const maxx = Math.min(dim[0], iax + ngTorus + 2)
-            const maxy = Math.min(dim[1], iay + ngTorus + 2)
-            const maxz = Math.min(dim[2], iaz + ngTorus + 2)
-
-            for (let ix = minx; ix < maxx; ix++) {
-                const dx = px - gridx[ ix ]
-                const xoffset = dim[1] * dim[2] * ix
-
-                for (let iy = miny; iy < maxy; iy++) {
-                    const dy = py - gridy[iy]
-                    const  dxy2 = dx * dx + dy * dy
-                    const  xyoffset = xoffset + dim[2] * iy
-
-                    for (let iz = minz; iz < maxz; iz++) {
-                        const dz = pz - gridz[iz]
-                        const d2 = dxy2 + dz * dz
-                        const  idx = iz + xyoffset
-                        const  current = grid[idx]
-
-                        if (current > 0.0 && d2 < (current * current)) {
-                            grid[idx] = Math.sqrt(d2)
+
+            Vec3.set(v, px, py, pz)
+
+            Vec3.sub(v, v, min)
+            Vec3.mul(c, v, delta)
+
+            const rad = probeRadius / resolution
+            Vec3.set(radVec, rad, rad, rad)
+            Vec3.mul(radVec, radVec, delta)
+
+            const [ begX, begY, begZ ] = Vec3.floor(beg, Vec3.sub(beg, c, radVec))
+            const [ endX, endY, endZ ] = Vec3.ceil(end, Vec3.add(end, c, radVec))
+
+            for (let xi = begX; xi < endX; ++xi) {
+                for (let yi = begY; yi < endY; ++yi) {
+                    for (let zi = begZ; zi < endZ; ++zi) {
+                        Vec3.set(p, xi, yi, zi)
+                        Vec3.div(p, p, delta)
+
+                        const dv = Vec3()
+                        Vec3.sub(dv, v, p)
+                        const dSq = Vec3.squaredMagnitude(dv)
+
+                        // const distSq = Vec3.squaredDistance(p, v)
+                        const current = space.get(data, xi, yi, zi)
+
+                        if (current > 0.0 && dSq < (current * current)) {
+                            space.set(data, xi, yi, zi, Math.sqrt(dSq))
                             // Is this grid point closer to a or b?
                             // Take dot product of atob and gridpoint->p (dx, dy, dz)
                             const dp = dx * atob[0] + dy * atob [1] + dz * atob[2]
-                            atomIndex[idx] = dp < 0.0 ? b : a
+                            space.set(idData, xi, yi, zi, dp < 0.0 ? b : a)
                         }
                     }
                 }
@@ -336,30 +281,19 @@ function projectTorus (state: MolSurfCalcState, a: number, b: number) {
     }
 }
 
-function projectTorii (state: MolSurfCalcState) {
-    const { n: nAtoms, neighbours, hash, xCoord, yCoord, zCoord, radius } = state
-    for (let i = 0; i < nAtoms; i++) {
-        hash.withinRadii(xCoord[i], yCoord[i], zCoord[i], radius[i], neighbours)
-        let ia = 0
-        let ni = neighbours[ ia ]
-        while (ni >= 0) {
-            if (i < ni) {
-            projectTorus(state, i, ni)
-            }
-            ni = neighbours[ ++ia ]
+async function projectTorii (ctx: RuntimeContext, state: MolSurfCalcState) {
+    const { n, lookup3d, position, probeRadius, resolution } = state
+    const { x, y, z, radius } = position
+    for (let i = 0; i < n; ++i) {
+        state.neighbours = lookup3d.find(x[i], y[i], z[i], radius[i] + probeRadius / resolution)
+        for (let j = 0, jl = state.neighbours.count; j < jl; ++j) {
+            const ib = state.neighbours.indices[j]
+            if (i < ib) projectTorus(state, i, ib)
         }
-    }
-}
 
-function fixNegatives (grid: NumberArray) {
-    for (let i = 0; i < grid.length; i++) {
-        if (grid[i] < 0) grid[i] = 0
-    }
-}
-
-function fixAtomIDs (atomIndex: NumberArray, indexList: NumberArray) {
-    for (let i = 0; i < atomIndex.length; i++) {
-        atomIndex[i] = indexList[atomIndex[i]]
+        if (i % 10000 === 0 && ctx.shouldUpdate) {
+            await ctx.update({ message: 'projecting torii', current: i, max: n })
+        }
     }
 }
 
@@ -371,198 +305,118 @@ interface MolSurfCalcState {
     /** Neighbours as transient result array from lookup3d */
     neighbours: Result<number>
 
-    lookup3D: Lookup3D
-    position: PositionData
-    radius: (index: number) => number
+    lookup3d: Lookup3D
+    position: Required<PositionData>
     delta: Vec3
     min: Vec3
 
     maxRadius: number
 
     n: number
-    scaleFactor: number
+    resolution: number
+    probeRadius: number
 
     /** Angle lookup tables */
     cosTable: Float32Array
     sinTable: Float32Array
-
     probePositions: number
-    ngTorus: number
+
+    expandedBox: Box3D
+    space: Tensor.Space
+    data: Tensor.Data
+    idData: Tensor.Data
 }
 
+async function createState(ctx: RuntimeContext, position: Required<PositionData>, maxRadius: number, props: MolecularSurfaceCalculationProps): Promise<MolSurfCalcState> {
+    const { resolution, probeRadius, probePositions } = props
 
+    const lookup3d = GridLookup3D(position)
+    const box = lookup3d.boundary.box
+    const { indices } = position
+    const n = OrderedSet.size(indices)
 
-export const MolecularSurfaceCalculationParams = {
-    scaleFactor: PD.Numeric(2, { min: 0.1, max: 10, step: 0.1 }),
-    probeRadius: PD.Numeric(1.4, { min: 0, max: 10, step: 0.1 }),
-    probePositions: PD.Numeric(30, { min: 12, max: 90, step: 1 }),
-}
-export const DefaultMolecularSurfaceCalculationProps = PD.getDefaultValues(MolecularSurfaceCalculationParams)
-export type MolecularSurfaceCalculationProps = typeof DefaultMolecularSurfaceCalculationProps
+    const pad = maxRadius * 2 + resolution
+    const expandedBox = Box3D.expand(Box3D.empty(), box, Vec3.create(pad, pad, pad))
+    const extent = Vec3.sub(Vec3.zero(), expandedBox.max, expandedBox.min)
+    const min = expandedBox.min
+
+    const delta = getDelta(Box3D.expand(Box3D.empty(), box, Vec3.create(pad, pad, pad)), resolution)
+    const dim = Vec3.zero()
+    Vec3.ceil(dim, Vec3.mul(dim, extent, delta))
+    console.log('grid dim surf', dim)
 
-function createState(nAtoms: number, props: MolecularSurfaceCalculationProps): MolSurfCalcState {
-    const { scaleFactor, probeRadius, probePositions } = props
     const { cosTable, sinTable } = getAngleTables(probePositions)
-    const ngTorus = Math.max(5, 2 + Math.floor(probeRadius * scaleFactor))
 
+    const space = Tensor.Space(dim, [0, 1, 2], Float32Array)
+    const data = space.create()
+    const idData = space.create()
+
+    fillUniform(data, -1001.0)
+    fillUniform(idData, -1)
 
     return {
         lastClip: -1,
-        neighbours: Int32Array,
+        neighbours: lookup3d.find(0, 0, 0, 0),
+
+        lookup3d,
+        position,
+        delta,
+        min,
 
-        xCoord: new Float32Array(nAtoms),
-        yCoord: new Float32Array(nAtoms),
-        zCoord: new Float32Array(nAtoms),
-        radius: new Float32Array(nAtoms),
-        radiusSq: new Float32Array(nAtoms),
-        maxRadius: 0,
+        maxRadius,
 
-        n: nAtoms,
-        scaleFactor,
+        n,
+        resolution,
+        probeRadius,
 
         cosTable,
         sinTable,
-
         probePositions,
-        ngTorus,
+
+        expandedBox,
+        space,
+        data,
+        idData,
     }
 }
 
 //
 
-export function MolecularSurface(coordList: Float32Array, radiusList: Float32Array, indexList: Uint16Array|Uint32Array) {
-    // Field generation method adapted from AstexViewer (Mike Hartshorn)
-    // by Fred Ludlow.
-    // Other parts based heavily on NGL (Alexander Rose) EDT Surface class
-    //
-    // Should work as a drop-in alternative to EDTSurface (though some of
-    // the EDT paramters are not relevant in this method).
-
-    const nAtoms = radiusList.length
-
-    const x = new Float32Array(nAtoms)
-    const y = new Float32Array(nAtoms)
-    const z = new Float32Array(nAtoms)
-
-    for (let i = 0; i < nAtoms; i++) {
-        const ci = 3 * i
-        x[ i ] = coordList[ ci ]
-        y[ i ] = coordList[ ci + 1 ]
-        z[ i ] = coordList[ ci + 2 ]
-    }
-
-    let bbox = computeBoundingBox(coordList)
-    if (coordList.length === 0) {
-        bbox[ 0 ].set([ 0, 0, 0 ])
-        bbox[ 1 ].set([ 0, 0, 0 ])
-    }
-    const min = bbox[0]
-    const max = bbox[1]
-
-    let r: Float32Array, r2: Float32Array // Atom positions, expanded radii (squared)
-    let maxRadius: number
-
-    // Parameters
-    let probeRadius: number, scaleFactor: number, setAtomID: boolean, probePositions: number
-
-    // Grid params
-    let dim: Float32Array, matrix: Float32Array, grid: NumberArray, atomIndex: Int32Array
-
-    // grid indices -> xyz coords
-    let gridx: Float32Array, gridy: Float32Array, gridz: Float32Array
-
-    // Spatial Hash
-    let hash: iAVHash
-
-    // Neighbour array to be filled by hash
-    let neighbours: Int32Array
-
-    let ngTorus: number
-
-    function init (_probeRadius?: number, _scaleFactor?: number, _setAtomID?: boolean, _probePositions?: number) {
-        probeRadius = defaults(_probeRadius, 1.4)
-        scaleFactor = defaults(_scaleFactor, 2.0)
-        setAtomID = defaults(_setAtomID, true)
-        probePositions = defaults(_probePositions, 30)
-
-        r = new Float32Array(nAtoms)
-        r2 = new Float32Array(nAtoms)
-
-        for (let i = 0; i < r.length; ++i) {
-            var rExt = radiusList[ i ] + probeRadius
-            r[ i ] = rExt
-            r2[ i ] = rExt * rExt
-        }
-
-        maxRadius = 0
-        for (let j = 0; j < r.length; ++j) {
-            if (r[ j ] > maxRadius) maxRadius = r[ j ]
-        }
-
-        initializeGrid()
-        getAngleTables(probePositions)
-        initializeHash()
-
-        lastClip = -1
-    }
-
-    function initializeGrid () {
-        const surfGrid = getSurfaceGrid(
-            min, max, maxRadius, scaleFactor, 0.0
-        )
-
-        scaleFactor = surfGrid.scaleFactor
-        dim = surfGrid.dim
-        matrix = surfGrid.matrix
-
-        ngTorus = Math.max(5, 2 + Math.floor(probeRadius * scaleFactor))
-
-        grid = fillUniform(new Float32Array(dim[0] * dim[1] * dim[2]), -1001.0)
-
-        atomIndex = new Int32Array(grid.length)
-
-        gridx = new Float32Array(dim[0])
-        gridy = new Float32Array(dim[1])
-        gridz = new Float32Array(dim[2])
-
-        fillGridDim(gridx, min[0], 1 / scaleFactor)
-        fillGridDim(gridy, min[1], 1 / scaleFactor)
-        fillGridDim(gridz, min[2], 1 / scaleFactor)
-    }
-
-
-
-    function initializeHash () {
-        hash = makeAVHash(x, y, z, r, min, max, 2.01 * maxRadius)
-        neighbours = new Int32Array(hash.neighbourListLength)
-    }
-
+export const MolecularSurfaceCalculationParams = {
+    resolution: PD.Numeric(0.5, { min: 0.01, max: 10, step: 0.01 }),
+    probeRadius: PD.Numeric(1.4, { min: 0, max: 10, step: 0.1 }),
+    probePositions: PD.Numeric(30, { min: 12, max: 90, step: 1 }),
+}
+export const DefaultMolecularSurfaceCalculationProps = PD.getDefaultValues(MolecularSurfaceCalculationParams)
+export type MolecularSurfaceCalculationProps = typeof DefaultMolecularSurfaceCalculationProps
 
 
+export async function calcMolecularSurface(ctx: RuntimeContext, position: Required<PositionData>, maxRadius: number,  props: MolecularSurfaceCalculationProps) {
+    // Field generation method adapted from AstexViewer (Mike Hartshorn) by Fred Ludlow.
+    // Other parts based heavily on NGL (Alexander Rose) EDT Surface class
 
+    console.time('MolecularSurface')
 
-    function getVolume (probeRadius: number, scaleFactor: number, setAtomID: boolean) {
-        // Basic steps are:
-        // 1) Initialize
-        // 2) Project points
-        // 3) Project torii
+    console.time('MolecularSurface createState')
+    const state = await createState(ctx, position, maxRadius, props)
+    console.timeEnd('MolecularSurface createState')
 
-        console.time('AVSurface.getVolume')
+    console.time('MolecularSurface projectPoints')
+    await projectPoints(ctx, state)
+    console.timeEnd('MolecularSurface projectPoints')
 
-        console.time('AVSurface.init')
-        init(probeRadius, scaleFactor, setAtomID)
-        console.timeEnd('AVSurface.init')
+    console.time('MolecularSurface projectTorii')
+    await projectTorii(ctx, state)
+    console.timeEnd('MolecularSurface projectTorii')
 
-        console.time('AVSurface.projectPoints')
-        projectPoints()
-        console.timeEnd('AVSurface.projectPoints')
+    console.timeEnd('MolecularSurface')
 
-        console.time('AVSurface.projectTorii')
-        projectTorii()
-        console.timeEnd('AVSurface.projectTorii')
-        fixNegatives()
-        fixAtomIDs()
+    const field = Tensor.create(state.space, state.data)
+    const idField = Tensor.create(state.space, state.idData)
 
-        console.timeEnd('AVSurface.getVolume')
-    }
+    const transform = Mat4.identity()
+    Mat4.fromScaling(transform, Vec3.inverse(Vec3.zero(), state.delta))
+    Mat4.setTranslation(transform, state.expandedBox.min)
+    console.log({ field, idField, transform, state })
+    return { field, idField, transform }
 }

+ 6 - 3
src/mol-repr/structure/visual/molecular-surface-mesh.ts

@@ -14,7 +14,8 @@ import { Mesh } from 'mol-geo/geometry/mesh/mesh';
 import { computeMarchingCubesMesh } from 'mol-geo/util/marching-cubes/algorithm';
 import { VisualContext } from 'mol-repr/visual';
 import { Theme } from 'mol-theme/theme';
-import { MolecularSurfaceCalculationParams, MolecularSurfaceCalculationProps, computeUnitMolecularSurface } from './util/molecular-surface';
+import { computeUnitMolecularSurface } from './util/molecular-surface';
+import { MolecularSurfaceCalculationParams, MolecularSurfaceCalculationProps } from 'mol-math/geometry/molecular-surface';
 
 export const MolecularSurfaceMeshParams = {
     ...UnitsMeshParams,
@@ -26,10 +27,11 @@ export type MolecularSurfaceMeshParams = typeof MolecularSurfaceMeshParams
 //
 
 async function createMolecularSurfaceMesh(ctx: VisualContext, unit: Unit, structure: Structure, theme: Theme, props: MolecularSurfaceCalculationProps, mesh?: Mesh): Promise<Mesh> {
-    const { transform, field, idField } = await computeUnitMolecularSurface(unit, props).runInContext(ctx.runtime)
+    console.log(props)
 
+    const { transform, field, idField } = await computeUnitMolecularSurface(unit, props).runInContext(ctx.runtime)
     const params = {
-        isoLevel: 1,
+        isoLevel: props.probeRadius,
         scalarField: field,
         idField
     }
@@ -52,6 +54,7 @@ export function MolecularSurfaceMeshVisual(materialId: number): UnitsVisual<Mole
         setUpdateState: (state: VisualUpdateState, newProps: PD.Values<MolecularSurfaceMeshParams>, currentProps: PD.Values<MolecularSurfaceMeshParams>) => {
             if (newProps.resolution !== currentProps.resolution) state.createGeometry = true
             if (newProps.probeRadius !== currentProps.probeRadius) state.createGeometry = true
+            if (newProps.probePositions !== currentProps.probePositions) state.createGeometry = true
         }
     }, materialId)
 }

+ 22 - 16
src/mol-repr/structure/visual/util/molecular-surface.ts

@@ -4,32 +4,38 @@
  * @author Alexander Rose <alexander.rose@weirdbyte.de>
  */
 
-import { ParamDefinition as PD } from 'mol-util/param-definition';
 import { Unit } from 'mol-model/structure';
 import { Task, RuntimeContext } from 'mol-task';
 import { getUnitConformationAndRadius } from './common';
-import { PositionData, Box3D, DensityData } from 'mol-math/geometry';
-
-export const MolecularSurfaceCalculationParams = {
-    resolution: PD.Numeric(1, { min: 0.1, max: 10, step: 0.1 }),
-    probeRadius: PD.Numeric(0, { min: 0, max: 10, step: 0.1 }),
-}
-export const DefaultMolecularSurfaceCalculationProps = PD.getDefaultValues(MolecularSurfaceCalculationParams)
-export type MolecularSurfaceCalculationProps = typeof DefaultMolecularSurfaceCalculationProps
+import { PositionData, DensityData } from 'mol-math/geometry';
+import { MolecularSurfaceCalculationProps, calcMolecularSurface } from 'mol-math/geometry/molecular-surface';
+import { OrderedSet } from 'mol-data/int';
 
 export function computeUnitMolecularSurface(unit: Unit, props: MolecularSurfaceCalculationProps) {
     const { position, radius } = getUnitConformationAndRadius(unit)
+
     return Task.create('Molecular Surface', async ctx => {
-        return await MolecularSurface(ctx, position, unit.lookup3d.boundary.box, radius, props);
+        const { indices } = position
+        const n = OrderedSet.size(indices)
+        const radii = new Float32Array(n)
+
+        let maxRadius = 0
+        for (let i = 0; i < n; ++i) {
+            const r = radius(OrderedSet.getAt(indices, i))
+            if (maxRadius < r) maxRadius = r
+            radii[i] = r
+
+            if (i % 10000 === 0 && ctx.shouldUpdate) {
+                await ctx.update({ message: 'calculating max radius', current: i, max: n })
+            }
+        }
+
+        return await MolecularSurface(ctx, { ...position, radius: radii }, maxRadius, props);
     });
 }
 
 //
 
-async function MolecularSurface(ctx: RuntimeContext, position: PositionData, box: Box3D, radius: (index: number) => number,  props: MolecularSurfaceCalculationProps): Promise<DensityData> {
-    return {
-        transform: Mat4,
-        field: Tensor,
-        idField: Tensor,
-    }
+async function MolecularSurface(ctx: RuntimeContext, position: Required<PositionData>, maxRadius: number,  props: MolecularSurfaceCalculationProps): Promise<DensityData> {
+    return calcMolecularSurface(ctx, position, maxRadius, props)
 }

+ 28 - 2
src/tests/browser/render-structure.ts

@@ -13,6 +13,8 @@ import { SizeTheme } from 'mol-theme/size';
 import { CartoonRepresentationProvider } from 'mol-repr/structure/representation/cartoon';
 import { trajectoryFromMmCIF } from 'mol-model-formats/structure/mmcif';
 import { computeModelDSSP } from 'mol-model/structure/model/properties/utils/secondary-structure';
+import { MolecularSurfaceRepresentationProvider } from 'mol-repr/structure/representation/molecular-surface';
+import { BallAndStickRepresentationProvider } from 'mol-repr/structure/representation/ball-and-stick';
 
 const parent = document.getElementById('app')!
 parent.style.width = '100%'
@@ -61,8 +63,16 @@ function getCartoonRepr() {
     return CartoonRepresentationProvider.factory(reprCtx, CartoonRepresentationProvider.getParams)
 }
 
+function getBallAndStickRepr() {
+    return BallAndStickRepresentationProvider.factory(reprCtx, BallAndStickRepresentationProvider.getParams)
+}
+
+function getMolecularSurfaceRepr() {
+    return MolecularSurfaceRepresentationProvider.factory(reprCtx, MolecularSurfaceRepresentationProvider.getParams)
+}
+
 async function init() {
-    const cif = await downloadFromPdb('3j3q')
+    const cif = await downloadFromPdb('1crn')
     const models = await getModels(cif)
     console.time('computeModelDSSP')
     const secondaryStructure = computeModelDSSP(models[0].atomicHierarchy, models[0].atomicConformation)
@@ -70,6 +80,8 @@ async function init() {
     (models[0].properties as any).secondaryStructure = secondaryStructure
     const structure = await getStructure(models[0])
     const cartoonRepr = getCartoonRepr()
+    const ballAndStick = getBallAndStickRepr()
+    const molecularSurfaceRepr = getMolecularSurfaceRepr()
 
     cartoonRepr.setTheme({
         color: reprCtx.colorThemeRegistry.create('secondary-structure', { structure }),
@@ -77,7 +89,21 @@ async function init() {
     })
     await cartoonRepr.createOrUpdate({ ...CartoonRepresentationProvider.defaultValues, quality: 'auto' }, structure).run()
 
-    canvas3d.add(cartoonRepr)
+    ballAndStick.setTheme({
+        color: reprCtx.colorThemeRegistry.create('secondary-structure', { structure }),
+        size: reprCtx.sizeThemeRegistry.create('uniform', { structure })
+    })
+    await ballAndStick.createOrUpdate({ ...BallAndStickRepresentationProvider.defaultValues, quality: 'auto' }, structure).run()
+
+    molecularSurfaceRepr.setTheme({
+        color: reprCtx.colorThemeRegistry.create('secondary-structure', { structure }),
+        size: reprCtx.sizeThemeRegistry.create('physical', { structure })
+    })
+    await molecularSurfaceRepr.createOrUpdate({ ...MolecularSurfaceRepresentationProvider.defaultValues, quality: 'custom', alpha: 1.0, flatShaded: true, doubleSided: true, resolution: 0.5 }, structure).run()
+
+    // canvas3d.add(cartoonRepr)
+    // canvas3d.add(ballAndStick)
+    canvas3d.add(molecularSurfaceRepr)
     canvas3d.resetCamera()
 }