Browse Source

add dataOffset method Tensor.Space

Alexander Rose 5 years ago
parent
commit
e39304c7cf
1 changed files with 26 additions and 14 deletions
  1. 26 14
      src/mol-math/linear-algebra/tensor.ts

+ 26 - 14
src/mol-math/linear-algebra/tensor.ts

@@ -1,7 +1,8 @@
 /**
- * Copyright (c) 2017 mol* contributors, licensed under MIT, See LICENSE file for more info.
+ * Copyright (c) 2017-2020 mol* contributors, licensed under MIT, See LICENSE file for more info.
  *
  * @author David Sehnal <david.sehnal@gmail.com>
+ * @author Alexander Rose <alexander.rose@weirdbyte.de>
  */
 
 import { Mat4, Vec3, Vec4, Mat3 } from './3d';
@@ -21,6 +22,7 @@ export namespace Tensor {
         get(data: Tensor.Data, ...coords: number[]): number
         set(data: Tensor.Data, ...coordsAndValue: number[]): number
         add(data: Tensor.Data, ...coordsAndValue: number[]): number
+        dataOffset(...coords: number[]): number
     }
 
     interface Layout {
@@ -46,8 +48,8 @@ export namespace Tensor {
 
     export function Space(dimensions: number[], axisOrderSlowToFast: number[], ctor?: ArrayCtor): Space {
         const layout = Layout(dimensions, axisOrderSlowToFast, ctor);
-        const { get, set, add } = accessors(layout);
-        return { rank: dimensions.length, dimensions, axisOrderSlowToFast, create: creator(layout), get, set, add };
+        const { get, set, add, dataOffset } = accessors(layout);
+        return { rank: dimensions.length, dimensions, axisOrderSlowToFast, create: creator(layout), get, set, add, dataOffset };
     }
 
     export function Data1(values: ArrayLike<number>): Data { return values as Data; }
@@ -95,13 +97,14 @@ export namespace Tensor {
         return true;
     }
 
-    function accessors(layout: Layout): { get: Space['get'], set: Space['set'], add: Space['add'] } {
+    function accessors(layout: Layout): { get: Space['get'], set: Space['set'], add: Space['add'], dataOffset: Space['dataOffset'] } {
         const { dimensions, axisOrderFastToSlow: ao } = layout;
         switch (dimensions.length) {
             case 1: return {
                 get: (t, d) => t[d],
                 set: (t, d, x) => t[d] = x,
-                add: (t, d, x) => t[d] += x
+                add: (t, d, x) => t[d] += x,
+                dataOffset: (d) => d
             };
             case 2: {
                 // column major
@@ -110,7 +113,8 @@ export namespace Tensor {
                     return {
                         get: (t, i, j) => t[j * rows + i],
                         set: (t, i, j, x) => t[j * rows + i] = x,
-                        add: (t, i, j, x) => t[j * rows + i] += x
+                        add: (t, i, j, x) => t[j * rows + i] += x,
+                        dataOffset: (i, j) => j * rows + i
                     };
                 }
                 if (ao[0] === 1 && ao[1] === 0) {
@@ -118,7 +122,8 @@ export namespace Tensor {
                     return {
                         get: (t, i, j) => t[i * cols + j],
                         set: (t, i, j, x) => t[i * cols + j] = x,
-                        add: (t, i, j, x) => t[i * cols + j] += x
+                        add: (t, i, j, x) => t[i * cols + j] += x,
+                        dataOffset: (i, j) => i * cols + j
                     };
                 }
                 throw new Error('bad axis order');
@@ -129,7 +134,8 @@ export namespace Tensor {
                     return {
                         get: (t, i, j, k) => t[i + j * u + k * uv],
                         set: (t, i, j, k, x ) => t[i + j * u + k * uv] = x,
-                        add: (t, i, j, k, x ) => t[i + j * u + k * uv] += x
+                        add: (t, i, j, k, x ) => t[i + j * u + k * uv] += x,
+                        dataOffset: (i, j, k) => i + j * u + k * uv
                     };
                 }
                 if (ao[0] === 0 && ao[1] === 2 && ao[2] === 1) { // 021 ikj
@@ -137,7 +143,8 @@ export namespace Tensor {
                     return {
                         get: (t, i, j, k) => t[i + k * u + j * uv],
                         set: (t, i, j, k, x ) => t[i + k * u + j * uv] = x,
-                        add: (t, i, j, k, x ) => t[i + k * u + j * uv] += x
+                        add: (t, i, j, k, x ) => t[i + k * u + j * uv] += x,
+                        dataOffset: (i, j, k) => i + k * u + j * uv
                     };
                 }
                 if (ao[0] === 1 && ao[1] === 0 && ao[2] === 2) { // 102 jik
@@ -145,7 +152,8 @@ export namespace Tensor {
                     return {
                         get: (t, i, j, k) => t[j + i * u + k * uv],
                         set: (t, i, j, k, x ) => t[j + i * u + k * uv] = x,
-                        add: (t, i, j, k, x ) => t[j + i * u + k * uv] += x
+                        add: (t, i, j, k, x ) => t[j + i * u + k * uv] += x,
+                        dataOffset: (i, j, k) => j + i * u + k * uv
                     };
                 }
                 if (ao[0] === 1 && ao[1] === 2 && ao[2] === 0) { // 120 jki
@@ -153,7 +161,8 @@ export namespace Tensor {
                     return {
                         get: (t, i, j, k) => t[j + k * u + i * uv],
                         set: (t, i, j, k, x ) => t[j + k * u + i * uv] = x,
-                        add: (t, i, j, k, x ) => t[j + k * u + i * uv] += x
+                        add: (t, i, j, k, x ) => t[j + k * u + i * uv] += x,
+                        dataOffset: (i, j, k) => j + k * u + i * uv
                     };
                 }
                 if (ao[0] === 2 && ao[1] === 0 && ao[2] === 1) { // 201 kij
@@ -161,7 +170,8 @@ export namespace Tensor {
                     return {
                         get: (t, i, j, k) => t[k + i * u + j * uv],
                         set: (t, i, j, k, x ) => t[k + i * u + j * uv] = x,
-                        add: (t, i, j, k, x ) => t[k + i * u + j * uv] += x
+                        add: (t, i, j, k, x ) => t[k + i * u + j * uv] += x,
+                        dataOffset: (i, j, k) => k + i * u + j * uv
                     };
                 }
                 if (ao[0] === 2 && ao[1] === 1 && ao[2] === 0) { // 210 kji
@@ -169,7 +179,8 @@ export namespace Tensor {
                     return {
                         get: (t, i, j, k) => t[k + j * u + i * uv],
                         set: (t, i, j, k, x ) => t[k + j * u + i * uv] = x,
-                        add: (t, i, j, k, x ) => t[k + j * u + i * uv] += x
+                        add: (t, i, j, k, x ) => t[k + j * u + i * uv] += x,
+                        dataOffset: (i, j, k) => k + j * u + i * uv
                     };
                 }
                 throw new Error('bad axis order');
@@ -177,7 +188,8 @@ export namespace Tensor {
             default: return {
                 get: (t, ...c) => t[dataOffset(layout, c)],
                 set: (t, ...c) => t[dataOffset(layout, c)] = c[c.length - 1],
-                add: (t, ...c) => t[dataOffset(layout, c)] += c[c.length - 1]
+                add: (t, ...c) => t[dataOffset(layout, c)] += c[c.length - 1],
+                dataOffset: (...c) => dataOffset(layout, c)
             };
         }
     }