Ver Fonte

wip, molsurf

Alexander Rose há 6 anos atrás
pai
commit
ec379968cd

+ 116 - 85
src/mol-math/geometry/molecular-surface.ts

@@ -45,6 +45,12 @@ function getAngleTables (probePositions: number): AnglesTables {
     return { cosTable, sinTable}
 }
 
+function fillGridDim(a: Float32Array, start: number, step: number) {
+    for (let i = 0; i < a.length; i++) {
+        a[i] = start + (step * i)
+    }
+}
+
 /**
  * Is the point at x,y,z obscured by any of the atoms specifeid by indices in neighbours.
  * Ignore indices a and b (these are the relevant atoms in projectPoints/Torii)
@@ -74,7 +80,7 @@ function obscured (state: MolSurfCalcState, x: number, y: number, z: number, a:
 
 function singleAtomObscures (state: MolSurfCalcState, ai: number, x: number, y: number, z: number) {
     const j = OrderedSet.getAt(state.position.indices, ai)
-    const r = state.position.radius[j] + state.probeRadius
+    const r = state.position.radius[j]
     const dx = state.position.x[j] - x
     const dy = state.position.y[j] - y
     const dz = state.position.z[j] - z
@@ -93,75 +99,73 @@ function singleAtomObscures (state: MolSurfCalcState, ai: number, x: number, y:
  *             itself and delta
  */
 async function projectPoints (ctx:  RuntimeContext, state: MolSurfCalcState) {
-    const { position, probeRadius, lookup3d, min, delta, space, data, idData } = state
+    const { position, lookup3d, min, space, data, idData, scaleFactor } = state
+    const { gridx, gridy, gridz } = state
 
     const { indices, x, y, z, radius } = position
     const n = OrderedSet.size(indices)
 
-    const v = Vec3()
-    const p = Vec3()
-    const c = Vec3()
-    const sp = Vec3()
-
-    const beg = Vec3()
-    const end = Vec3()
-
-    // const gridPad = 1 / Math.max(...delta)
+    const [ dimX, dimY, dimZ ] = space.dimensions
+    const iu = dimZ, iv = dimY, iuv = iu * iv
 
     for (let i = 0; i < n; ++i) {
         const j = OrderedSet.getAt(indices, i)
+        const vx = x[j], vy = y[j], vz = z[j]
+        const rad = radius[j]
+        const rSq = rad * rad
 
-        Vec3.set(v, x[j], y[j], z[j])
-
-        state.neighbours = lookup3d.find(v[0], v[1], v[2], radius[j] + probeRadius)
+        state.neighbours = lookup3d.find(vx, vy, vz, rad)
 
-        Vec3.sub(v, v, min)
-        Vec3.mul(c, v, delta)
+        // Number of grid points, round this up...
+        const ng = Math.ceil(rad * scaleFactor)
 
-        const rad = radius[j] + probeRadius
-        const rSq = rad * rad
+        // Center of the atom, mapped to grid points (take floor)
+        const iax = Math.floor(scaleFactor * (vx - min[0]))
+        const iay = Math.floor(scaleFactor * (vy - min[1]))
+        const iaz = Math.floor(scaleFactor * (vz - min[2]))
 
-        const r2 = rad // * 2 + gridPad
-        const rad2 = Vec3.create(r2, r2, r2)
-        Vec3.mul(rad2, rad2, delta)
+        // Extents of grid to consider for this atom
+        const begX = Math.max(0, iax - ng)
+        const begY = Math.max(0, iay - ng)
+        const begZ = Math.max(0, iaz - ng)
 
-        const [ begX, begY, begZ ] = Vec3.floor(beg, Vec3.sub(beg, c, rad2))
-        const [ endX, endY, endZ ] = Vec3.ceil(end, Vec3.add(end, c, rad2))
+        // Add two to these points:
+        // - iax are floor'd values so this ensures coverage
+        // - these are loop limits (exclusive)
+        const endX = Math.min(dimX, iax + ng + 2)
+        const endY = Math.min(dimY, iay + ng + 2)
+        const endZ = Math.min(dimZ, iaz + ng + 2)
 
         for (let xi = begX; xi < endX; ++xi) {
+            const dx = gridx[xi] - vx
+            const xIdx = xi * iuv
             for (let yi = begY; yi < endY; ++yi) {
+                const dy = gridy[yi] - vy
+                const dxySq = dx * dx + dy * dy
+                const xyIdx = yi * iu + xIdx
                 for (let zi = begZ; zi < endZ; ++zi) {
-                    Vec3.set(p, xi, yi, zi)
-                    Vec3.div(p, p, delta)
-
-                    const dv = Vec3()
-                    Vec3.sub(dv, p, v)
-
-                    // const distSq = Vec3.squaredDistance(p, v)
-                    const dSq = Vec3.squaredMagnitude(dv)
+                    const dz = gridz[zi] - vz
+                    const dSq = dxySq + dz * dz
 
                     if (dSq < rSq) {
-                        const val = space.get(data, xi, yi, zi)
-                        if (val < 0.0) {
-                            // Unvisited, make positive
-                            space.set(data, xi, yi, zi, -val)
-                        }
+                        const idx = zi + xyIdx
+
+                        // if unvisited, make positive
+                        if (data[idx] < 0.0) data[idx] *= -1
 
                         // Project on to the surface of the sphere
                         // sp is the projected point ( dx, dy, dz ) * ( ra / d )
-                        // const dist = Math.sqrt(distSq)
                         const d = Math.sqrt(dSq)
                         const ap = rad / d
-                        Vec3.scale(sp, dv, ap)
-                        Vec3.add(sp, sp, v)
-                        Vec3.add(sp, sp, min)
-                        // Vec3.add(sp, v, Vec3.setMagnitude(sp, Vec3.sub(sp, p, v), rad - probeRadius))
+                        const spx = dx * ap + vx
+                        const spy = dy * ap + vy
+                        const spz = dz * ap + vz
 
-                        if (obscured(state, sp[0], sp[1], sp[2], i, -1) === -1) {
+                        if (obscured(state, spx, spy, spz, i, -1) === -1) {
                             const dd = rad - d
-                            if (dd < space.get(data, xi, yi, zi)) {
-                                space.set(data, xi, yi, zi, dd)
-                                space.set(idData, xi, yi, zi, i)
+                            if (dd < data[idx]) {
+                                data[idx] = dd
+                                idData[idx] = i
                             }
                         }
                     }
@@ -180,34 +184,34 @@ const atob = Vec3()
 const mid = Vec3()
 const n1 = Vec3()
 const n2 = Vec3()
-const v = Vec3()
-const p = Vec3()
-const c = Vec3()
-const beg = Vec3()
-const end = Vec3()
-const radVec = Vec3()
 function projectTorus (state: MolSurfCalcState, a: number, b: number) {
-    const { position, min, delta, space, data, idData } = state
-    const { cosTable, sinTable, probePositions, probeRadius, resolution } = state
+    const { position, min, space, data, idData } = state
+    const { cosTable, sinTable, probePositions, probeRadius, scaleFactor } = state
+    const { gridx, gridy, gridz } = state
 
-    const r1 = position.radius[a] + probeRadius
-    const r2 = position.radius[b] + probeRadius
+    const [ dimX, dimY, dimZ ] = space.dimensions
+    const iu = dimZ, iv = dimY, iuv = iu * iv
+
+    const ng = Math.max(5, 2 + Math.floor(probeRadius * scaleFactor))
+
+    const rA = position.radius[a]
+    const rB = position.radius[b]
     const dx = atob[0] = position.x[b] - position.x[a]
     const dy = atob[1] = position.y[b] - position.y[a]
     const dz = atob[2] = position.z[b] - position.z[a]
     const dSq = dx * dx + dy * dy + dz * dz
 
     // This check now redundant as already done in AVHash.withinRadii
-    if (dSq > ((r1 + r2) * (r1 + r2))) { return }
+    if (dSq > ((rA + rB) * (rA + rB))) { return }
 
     const d = Math.sqrt(dSq)
 
     // Find angle between a->b vector and the circle
     // of their intersection by cosine rule
-    const cosA = (r1 * r1 + d * d - r2 * r2) / (2.0 * r1 * d)
+    const cosA = (rA * rA + d * d - rB * rB) / (2.0 * rA * d)
 
     // distance along a->b at intersection
-    const dmp = r1 * cosA
+    const dmp = rA * cosA
 
     Vec3.normalize(atob, atob)
 
@@ -220,7 +224,7 @@ function projectTorus (state: MolSurfCalcState, a: number, b: number) {
     Vec3.normalize(n2, n2)
 
     // r is radius of circle of intersection
-    const rInt = Math.sqrt(r1 * r1 - dmp * dmp)
+    const rInt = Math.sqrt(rA * rA - dmp * dmp)
 
     Vec3.scale(n1, n1, rInt)
     Vec3.scale(n2, n2, rInt)
@@ -241,38 +245,40 @@ function projectTorus (state: MolSurfCalcState, a: number, b: number) {
         const pz = mid[2] + cost * n1[2] + sint * n2[2]
 
         if (obscured(state, px, py, pz, a, b) === -1) {
+            const iax = Math.floor(scaleFactor * (px - min[0]))
+            const iay = Math.floor(scaleFactor * (py - min[1]))
+            const iaz = Math.floor(scaleFactor * (pz - min[2]))
 
-            Vec3.set(v, px, py, pz)
-
-            Vec3.sub(v, v, min)
-            Vec3.mul(c, v, delta)
+            const begX = Math.max(0, iax - ng)
+            const begY = Math.max(0, iay - ng)
+            const begZ = Math.max(0, iaz - ng)
 
-            const rad = probeRadius / resolution
-            Vec3.set(radVec, rad, rad, rad)
-            Vec3.mul(radVec, radVec, delta)
-
-            const [ begX, begY, begZ ] = Vec3.floor(beg, Vec3.sub(beg, c, radVec))
-            const [ endX, endY, endZ ] = Vec3.ceil(end, Vec3.add(end, c, radVec))
+            const endX = Math.min(dimX, iax + ng + 2)
+            const endY = Math.min(dimY, iay + ng + 2)
+            const endZ = Math.min(dimZ, iaz + ng + 2)
 
             for (let xi = begX; xi < endX; ++xi) {
+                const dx = px - gridx[xi]
+                const xIdx = xi * iuv
+
                 for (let yi = begY; yi < endY; ++yi) {
-                    for (let zi = begZ; zi < endZ; ++zi) {
-                        Vec3.set(p, xi, yi, zi)
-                        Vec3.div(p, p, delta)
+                    const dy = py - gridy[yi]
+                    const dxySq = dx * dx + dy * dy
+                    const xyIdx = yi * iu + xIdx
 
-                        const dv = Vec3()
-                        Vec3.sub(dv, v, p)
-                        const dSq = Vec3.squaredMagnitude(dv)
+                    for (let zi = begZ; zi < endZ; ++zi) {
+                        const dz = pz - gridz[zi]
+                        const dSq = dxySq + dz * dz
 
-                        // const distSq = Vec3.squaredDistance(p, v)
-                        const current = space.get(data, xi, yi, zi)
+                        const idx = zi + xyIdx
+                        const current = data[idx]
 
                         if (current > 0.0 && dSq < (current * current)) {
-                            space.set(data, xi, yi, zi, Math.sqrt(dSq))
+                            data[idx] = Math.sqrt(dSq)
                             // Is this grid point closer to a or b?
                             // Take dot product of atob and gridpoint->p (dx, dy, dz)
-                            const dp = dx * atob[0] + dy * atob [1] + dz * atob[2]
-                            space.set(idData, xi, yi, zi, dp < 0.0 ? b : a)
+                            const dp = dx * atob[0] + dy * atob[1] + dz * atob[2]
+                            idData[idx] = dp < 0.0 ? b : a
                         }
                     }
                 }
@@ -282,13 +288,14 @@ function projectTorus (state: MolSurfCalcState, a: number, b: number) {
 }
 
 async function projectTorii (ctx: RuntimeContext, state: MolSurfCalcState) {
-    const { n, lookup3d, position, probeRadius, resolution } = state
-    const { x, y, z, radius } = position
+    const { n, lookup3d, position } = state
+    const { x, y, z, indices, radius } = position
     for (let i = 0; i < n; ++i) {
-        state.neighbours = lookup3d.find(x[i], y[i], z[i], radius[i] + probeRadius / resolution)
+        const k = OrderedSet.getAt(indices, i)
+        state.neighbours = lookup3d.find(x[k], y[k], z[k], radius[k])
         for (let j = 0, jl = state.neighbours.count; j < jl; ++j) {
-            const ib = state.neighbours.indices[j]
-            if (i < ib) projectTorus(state, i, ib)
+            const l = state.neighbours.indices[j]
+            if (k < l) projectTorus(state, k, l)
         }
 
         if (i % 10000 === 0 && ctx.shouldUpdate) {
@@ -308,12 +315,14 @@ interface MolSurfCalcState {
     lookup3d: Lookup3D
     position: Required<PositionData>
     delta: Vec3
+    invDelta: Vec3
     min: Vec3
 
     maxRadius: number
 
     n: number
     resolution: number
+    scaleFactor: number
     probeRadius: number
 
     /** Angle lookup tables */
@@ -321,6 +330,11 @@ interface MolSurfCalcState {
     sinTable: Float32Array
     probePositions: number
 
+    /** grid lookup tables */
+    gridx: Float32Array
+    gridy: Float32Array
+    gridz: Float32Array
+
     expandedBox: Box3D
     space: Tensor.Space
     data: Tensor.Data
@@ -330,6 +344,8 @@ interface MolSurfCalcState {
 async function createState(ctx: RuntimeContext, position: Required<PositionData>, maxRadius: number, props: MolecularSurfaceCalculationProps): Promise<MolSurfCalcState> {
     const { resolution, probeRadius, probePositions } = props
 
+    const scaleFactor = 1 / resolution
+
     const lookup3d = GridLookup3D(position)
     const box = lookup3d.boundary.box
     const { indices } = position
@@ -344,6 +360,7 @@ async function createState(ctx: RuntimeContext, position: Required<PositionData>
     const dim = Vec3.zero()
     Vec3.ceil(dim, Vec3.mul(dim, extent, delta))
     console.log('grid dim surf', dim)
+    const invDelta = Vec3.inverse(Vec3(), delta)
 
     const { cosTable, sinTable } = getAngleTables(probePositions)
 
@@ -354,6 +371,14 @@ async function createState(ctx: RuntimeContext, position: Required<PositionData>
     fillUniform(data, -1001.0)
     fillUniform(idData, -1)
 
+    const gridx = new Float32Array(dim[0])
+    const gridy = new Float32Array(dim[1])
+    const gridz = new Float32Array(dim[2])
+
+    fillGridDim(gridx, min[0], resolution)
+    fillGridDim(gridy, min[1], resolution)
+    fillGridDim(gridz, min[2], resolution)
+
     return {
         lastClip: -1,
         neighbours: lookup3d.find(0, 0, 0, 0),
@@ -361,18 +386,24 @@ async function createState(ctx: RuntimeContext, position: Required<PositionData>
         lookup3d,
         position,
         delta,
+        invDelta,
         min,
 
         maxRadius,
 
         n,
         resolution,
+        scaleFactor,
         probeRadius,
 
         cosTable,
         sinTable,
         probePositions,
 
+        gridx,
+        gridy,
+        gridz,
+
         expandedBox,
         space,
         data,

+ 1 - 1
src/mol-repr/structure/visual/util/molecular-surface.ts

@@ -23,7 +23,7 @@ export function computeUnitMolecularSurface(unit: Unit, props: MolecularSurfaceC
         for (let i = 0; i < n; ++i) {
             const r = radius(OrderedSet.getAt(indices, i))
             if (maxRadius < r) maxRadius = r
-            radii[i] = r
+            radii[i] = r + props.probeRadius
 
             if (i % 10000 === 0 && ctx.shouldUpdate) {
                 await ctx.update({ message: 'calculating max radius', current: i, max: n })

+ 28 - 11
src/tests/browser/render-structure.ts

@@ -15,6 +15,7 @@ import { trajectoryFromMmCIF } from 'mol-model-formats/structure/mmcif';
 import { computeModelDSSP } from 'mol-model/structure/model/properties/utils/secondary-structure';
 import { MolecularSurfaceRepresentationProvider } from 'mol-repr/structure/representation/molecular-surface';
 import { BallAndStickRepresentationProvider } from 'mol-repr/structure/representation/ball-and-stick';
+import { GaussianSurfaceRepresentationProvider } from 'mol-repr/structure/representation/gaussian-surface';
 
 const parent = document.getElementById('app')!
 parent.style.width = '100%'
@@ -71,6 +72,10 @@ function getMolecularSurfaceRepr() {
     return MolecularSurfaceRepresentationProvider.factory(reprCtx, MolecularSurfaceRepresentationProvider.getParams)
 }
 
+function getGaussianSurfaceRepr() {
+    return GaussianSurfaceRepresentationProvider.factory(reprCtx, GaussianSurfaceRepresentationProvider.getParams)
+}
+
 async function init() {
     const cif = await downloadFromPdb('1crn')
     const models = await getModels(cif)
@@ -82,28 +87,40 @@ async function init() {
     const cartoonRepr = getCartoonRepr()
     const ballAndStick = getBallAndStickRepr()
     const molecularSurfaceRepr = getMolecularSurfaceRepr()
+    const gaussianSurfaceRepr = getGaussianSurfaceRepr()
 
-    cartoonRepr.setTheme({
-        color: reprCtx.colorThemeRegistry.create('secondary-structure', { structure }),
-        size: reprCtx.sizeThemeRegistry.create('uniform', { structure })
-    })
-    await cartoonRepr.createOrUpdate({ ...CartoonRepresentationProvider.defaultValues, quality: 'auto' }, structure).run()
+    // cartoonRepr.setTheme({
+    //     color: reprCtx.colorThemeRegistry.create('secondary-structure', { structure }),
+    //     size: reprCtx.sizeThemeRegistry.create('uniform', { structure })
+    // })
+    // await cartoonRepr.createOrUpdate({ ...CartoonRepresentationProvider.defaultValues, quality: 'auto' }, structure).run()
 
-    ballAndStick.setTheme({
-        color: reprCtx.colorThemeRegistry.create('secondary-structure', { structure }),
-        size: reprCtx.sizeThemeRegistry.create('uniform', { structure })
-    })
-    await ballAndStick.createOrUpdate({ ...BallAndStickRepresentationProvider.defaultValues, quality: 'auto' }, structure).run()
+    // ballAndStick.setTheme({
+    //     color: reprCtx.colorThemeRegistry.create('secondary-structure', { structure }),
+    //     size: reprCtx.sizeThemeRegistry.create('uniform', { structure })
+    // })
+    // await ballAndStick.createOrUpdate({ ...BallAndStickRepresentationProvider.defaultValues, quality: 'auto' }, structure).run()
 
     molecularSurfaceRepr.setTheme({
         color: reprCtx.colorThemeRegistry.create('secondary-structure', { structure }),
         size: reprCtx.sizeThemeRegistry.create('physical', { structure })
     })
-    await molecularSurfaceRepr.createOrUpdate({ ...MolecularSurfaceRepresentationProvider.defaultValues, quality: 'custom', alpha: 1.0, flatShaded: true, doubleSided: true, resolution: 0.5 }, structure).run()
+    console.time('molecular surface')
+    await molecularSurfaceRepr.createOrUpdate({ ...MolecularSurfaceRepresentationProvider.defaultValues, quality: 'custom', alpha: 1.0, flatShaded: true, doubleSided: true, resolution: 0.3 }, structure).run()
+    console.timeEnd('molecular surface')
+
+    // gaussianSurfaceRepr.setTheme({
+    //     color: reprCtx.colorThemeRegistry.create('secondary-structure', { structure }),
+    //     size: reprCtx.sizeThemeRegistry.create('physical', { structure })
+    // })
+    // console.time('gaussian surface')
+    // await gaussianSurfaceRepr.createOrUpdate({ ...GaussianSurfaceRepresentationProvider.defaultValues, quality: 'custom', alpha: 1.0, flatShaded: true, doubleSided: true, resolution: 0.3 }, structure).run()
+    // console.timeEnd('gaussian surface')
 
     // canvas3d.add(cartoonRepr)
     // canvas3d.add(ballAndStick)
     canvas3d.add(molecularSurfaceRepr)
+    // canvas3d.add(gaussianSurfaceRepr)
     canvas3d.resetCamera()
 }