Browse Source

Tensor.Space.getCoords

David Sehnal 5 years ago
parent
commit
acf793f112
2 changed files with 131 additions and 14 deletions
  1. 55 0
      src/mol-math/linear-algebra/_spec/tensor.spec.ts
  2. 76 14
      src/mol-math/linear-algebra/tensor.ts

+ 55 - 0
src/mol-math/linear-algebra/_spec/tensor.spec.ts

@@ -263,4 +263,59 @@ describe('tensor', () => {
 
         expect(data).toEqual(exp);
     });
+
+    it('indexing', () => {
+        function permutations<T>(inputArr: T[]): T[][] {
+            let result: T[][] = [];
+            function permute(arr: any, m: any = []) {
+                if (arr.length === 0) {
+                    result.push(m);
+                } else {
+                    for (let i = 0; i < arr.length; i++) {
+                        let curr = arr.slice();
+                        let next = curr.splice(i, 1);
+                        permute(curr.slice(), m.concat(next));
+                    }
+                }
+            }
+            permute(inputArr);
+
+            return result;
+        }
+
+        for (let dim = 1; dim <= 5; dim++) {
+            const axes = [], dims: number[] = [];
+            const u: number[] = [], v: number[] = [];
+
+            for (let i = 0; i < dim; i++) {
+                axes.push(i);
+                dims.push(3);
+                u.push(0);
+                v.push(0);
+            }
+
+            const forEachDim = (space: T.Space, d: number): boolean => {
+                if (d === dim) {
+                    const o = space.dataOffset(...u);
+                    space.getCoords(o, v);
+
+                    for (let e = 0; e < dims.length; e++) {
+                        expect(u[e]).toEqual(v[e]);
+                        return false;
+                    }
+                } else {
+                    for (let i = 0; i < dims[d]; i++) {
+                        u[d] = i;
+                        if (!forEachDim(space, d + 1)) return false;
+                    }
+                }
+                return true;
+            };
+
+            for (const ao of permutations(axes)) {
+                const space = T.Space(dims, ao);
+                if (!forEachDim(space, 0)) break;
+            }
+        }
+    });
 });

+ 76 - 14
src/mol-math/linear-algebra/tensor.ts

@@ -22,7 +22,8 @@ export namespace Tensor {
         get(data: Tensor.Data, ...coords: number[]): number
         set(data: Tensor.Data, ...coordsAndValue: number[]): number
         add(data: Tensor.Data, ...coordsAndValue: number[]): number
-        dataOffset(...coords: number[]): number
+        dataOffset(...coords: number[]): number,
+        getCoords(dataOffset: number, coords: { [i: number]: number }): number[]
     }
 
     interface Layout {
@@ -48,8 +49,8 @@ export namespace Tensor {
 
     export function Space(dimensions: number[], axisOrderSlowToFast: number[], ctor?: ArrayCtor): Space {
         const layout = Layout(dimensions, axisOrderSlowToFast, ctor);
-        const { get, set, add, dataOffset } = accessors(layout);
-        return { rank: dimensions.length, dimensions, axisOrderSlowToFast, create: creator(layout), get, set, add, dataOffset };
+        const { get, set, add, dataOffset, getCoords } = accessors(layout);
+        return { rank: dimensions.length, dimensions, axisOrderSlowToFast, create: creator(layout), get, set, add, dataOffset, getCoords };
     }
 
     export function Data1(values: ArrayLike<number>): Data { return values as Data; }
@@ -97,14 +98,15 @@ export namespace Tensor {
         return true;
     }
 
-    function accessors(layout: Layout): { get: Space['get'], set: Space['set'], add: Space['add'], dataOffset: Space['dataOffset'] } {
+    function accessors(layout: Layout): { get: Space['get'], set: Space['set'], add: Space['add'], dataOffset: Space['dataOffset'], getCoords: Space['getCoords'] } {
         const { dimensions, axisOrderFastToSlow: ao } = layout;
         switch (dimensions.length) {
             case 1: return {
                 get: (t, d) => t[d],
                 set: (t, d, x) => t[d] = x,
                 add: (t, d, x) => t[d] += x,
-                dataOffset: (d) => d
+                dataOffset: (d) => d,
+                getCoords: (o, c) => { c[0] = o; return c as number[]; }
             };
             case 2: {
                 // column major
@@ -114,7 +116,8 @@ export namespace Tensor {
                         get: (t, i, j) => t[j * rows + i],
                         set: (t, i, j, x) => t[j * rows + i] = x,
                         add: (t, i, j, x) => t[j * rows + i] += x,
-                        dataOffset: (i, j) => j * rows + i
+                        dataOffset: (i, j) => j * rows + i,
+                        getCoords: (o, c) => { c[0] = o % rows; c[1] = Math.floor(o / rows) ; return c as number[]; }
                     };
                 }
                 if (ao[0] === 1 && ao[1] === 0) {
@@ -123,7 +126,8 @@ export namespace Tensor {
                         get: (t, i, j) => t[i * cols + j],
                         set: (t, i, j, x) => t[i * cols + j] = x,
                         add: (t, i, j, x) => t[i * cols + j] += x,
-                        dataOffset: (i, j) => i * cols + j
+                        dataOffset: (i, j) => i * cols + j,
+                        getCoords: (o, c) => { c[0] = Math.floor(o / cols); c[1] = o % cols; return c as number[]; }
                     };
                 }
                 throw new Error('bad axis order');
@@ -135,7 +139,14 @@ export namespace Tensor {
                         get: (t, i, j, k) => t[i + j * u + k * uv],
                         set: (t, i, j, k, x ) => t[i + j * u + k * uv] = x,
                         add: (t, i, j, k, x ) => t[i + j * u + k * uv] += x,
-                        dataOffset: (i, j, k) => i + j * u + k * uv
+                        dataOffset: (i, j, k) => i + j * u + k * uv,
+                        getCoords: (o, c) => {
+                            const p = Math.floor(o / u);
+                            c[0] = o % u;
+                            c[1] = p % v;
+                            c[2] = Math.floor(p / v);
+                            return c as number[];
+                        }
                     };
                 }
                 if (ao[0] === 0 && ao[1] === 2 && ao[2] === 1) { // 021 ikj
@@ -144,7 +155,14 @@ export namespace Tensor {
                         get: (t, i, j, k) => t[i + k * u + j * uv],
                         set: (t, i, j, k, x ) => t[i + k * u + j * uv] = x,
                         add: (t, i, j, k, x ) => t[i + k * u + j * uv] += x,
-                        dataOffset: (i, j, k) => i + k * u + j * uv
+                        dataOffset: (i, j, k) => i + k * u + j * uv,
+                        getCoords: (o, c) => {
+                            const p = Math.floor(o / u);
+                            c[0] = o % u;
+                            c[1] = Math.floor(p / v);
+                            c[2] = p % v;
+                            return c as number[];
+                        }
                     };
                 }
                 if (ao[0] === 1 && ao[1] === 0 && ao[2] === 2) { // 102 jik
@@ -153,7 +171,14 @@ export namespace Tensor {
                         get: (t, i, j, k) => t[j + i * u + k * uv],
                         set: (t, i, j, k, x ) => t[j + i * u + k * uv] = x,
                         add: (t, i, j, k, x ) => t[j + i * u + k * uv] += x,
-                        dataOffset: (i, j, k) => j + i * u + k * uv
+                        dataOffset: (i, j, k) => j + i * u + k * uv,
+                        getCoords: (o, c) => {
+                            const p = Math.floor(o / u);
+                            c[0] = p % v;
+                            c[1] = o % u;
+                            c[2] = Math.floor(p / v);
+                            return c as number[];
+                        }
                     };
                 }
                 if (ao[0] === 1 && ao[1] === 2 && ao[2] === 0) { // 120 jki
@@ -162,7 +187,14 @@ export namespace Tensor {
                         get: (t, i, j, k) => t[j + k * u + i * uv],
                         set: (t, i, j, k, x ) => t[j + k * u + i * uv] = x,
                         add: (t, i, j, k, x ) => t[j + k * u + i * uv] += x,
-                        dataOffset: (i, j, k) => j + k * u + i * uv
+                        dataOffset: (i, j, k) => j + k * u + i * uv,
+                        getCoords: (o, c) => {
+                            const p = Math.floor(o / u);
+                            c[0] = Math.floor(p / v);
+                            c[1] = o % u;
+                            c[2] = p % v;
+                            return c as number[];
+                        }
                     };
                 }
                 if (ao[0] === 2 && ao[1] === 0 && ao[2] === 1) { // 201 kij
@@ -171,7 +203,14 @@ export namespace Tensor {
                         get: (t, i, j, k) => t[k + i * u + j * uv],
                         set: (t, i, j, k, x ) => t[k + i * u + j * uv] = x,
                         add: (t, i, j, k, x ) => t[k + i * u + j * uv] += x,
-                        dataOffset: (i, j, k) => k + i * u + j * uv
+                        dataOffset: (i, j, k) => k + i * u + j * uv,
+                        getCoords: (o, c) => {
+                            const p = Math.floor(o / u);
+                            c[0] = p % v;
+                            c[1] = Math.floor(p / v);
+                            c[2] = o % u;
+                            return c as number[];
+                        }
                     };
                 }
                 if (ao[0] === 2 && ao[1] === 1 && ao[2] === 0) { // 210 kji
@@ -180,7 +219,14 @@ export namespace Tensor {
                         get: (t, i, j, k) => t[k + j * u + i * uv],
                         set: (t, i, j, k, x ) => t[k + j * u + i * uv] = x,
                         add: (t, i, j, k, x ) => t[k + j * u + i * uv] += x,
-                        dataOffset: (i, j, k) => k + j * u + i * uv
+                        dataOffset: (i, j, k) => k + j * u + i * uv,
+                        getCoords: (o, c) => {
+                            const p = Math.floor(o / u);
+                            c[0] = Math.floor(p / v);
+                            c[1] = p % v;
+                            c[2] = o % u;
+                            return c as number[];
+                        }
                     };
                 }
                 throw new Error('bad axis order');
@@ -189,7 +235,8 @@ export namespace Tensor {
                 get: (t, ...c) => t[dataOffset(layout, c)],
                 set: (t, ...c) => t[dataOffset(layout, c)] = c[c.length - 1],
                 add: (t, ...c) => t[dataOffset(layout, c)] += c[c.length - 1],
-                dataOffset: (...c) => dataOffset(layout, c)
+                dataOffset: (...c) => dataOffset(layout, c),
+                getCoords: (o, c) => getCoords(layout, o, c as number[]),
             };
         }
     }
@@ -211,6 +258,21 @@ export namespace Tensor {
         return o;
     }
 
+    function getCoords(layout: Layout, o: number, coords: number[]) {
+        const { dimensions: dim, axisOrderFastToSlow: ao } = layout;
+        const d = dim.length;
+
+        let c = o;
+        for (let i = 0; i < d; i++) {
+            const d = dim[ao[i]];
+            coords[ao[i]] = c % d;
+            c = Math.floor(c / d);
+        }
+        coords[ao[d + 1]] = c;
+
+        return coords;
+    }
+
     // Convers "slow to fast" axis order to "fast to slow" and vice versa.
     export function invertAxisOrder(v: number[]) {
         const ret: number[] = [];