Browse Source

Tensors in schemas

David Sehnal 7 years ago
parent
commit
10e4306746

+ 10 - 11
src/mol-base/collections/database/column.ts

@@ -5,6 +5,7 @@
  */
 
 import * as ColumnHelpers from './column-helpers'
+import Tensors from '../../math/tensor'
 
 interface Column<T> {
     readonly '@type': Column.Type,
@@ -19,27 +20,25 @@ interface Column<T> {
 }
 
 namespace Column {
-    export type Type<T = any> = Type.Str | Type.Int | Type.Float | Type.Vector | Type.Matrix | Type.Aliased<T>
+    export type Type<T = any> = Type.Str | Type.Int | Type.Float | Type.Tensor | Type.Aliased<T>
     export type ArrayCtor<T> = { new(size: number): ArrayLike<T> }
 
     export namespace Type {
         export type Str = { T: string, kind: 'str' }
         export type Int = { T: number, kind: 'int' }
         export type Float = { T: number, kind: 'float' }
-        export type Vector = { T: number[], dim: number, kind: 'vector' };
-        export type Matrix = { T: number[][], rows: number, cols: number, kind: 'matrix' };
+        export type Tensor = { T: Tensors, space: Tensors.Space, kind: 'tensor' };
         export type Aliased<T> = { T: T } & { kind: 'str' | 'int' | 'float' }
 
         export const str: Str = { T: '', kind: 'str' };
         export const int: Int = { T: 0, kind: 'int' };
         export const float: Float = { T: 0, kind: 'float' };
 
-        export function vector(dim: number): Vector { return { T: [] as number[], dim, kind: 'vector' }; }
-        export function matrix(rows: number, cols: number): Matrix { return { T: [] as number[][], rows, cols, kind: 'matrix' }; }
+        export function tensor(space: Tensors.Space): Tensor { return { T: space.create(), space, kind: 'tensor' }; }
         export function aliased<T>(t: Type): Aliased<T> { return t as any as Aliased<T>; }
     }
 
-    export type Schema<T = any> = Schema.Scalar<T> | Schema.Vector | Schema.Matrix
+    export type Schema<T = any> = Schema.Str | Schema.Int | Schema.Float | Schema.Coordinate | Schema.Aliased<T> | Schema.Tensor
 
     export namespace Schema {
         export interface FloatPrecision {
@@ -48,7 +47,7 @@ namespace Column {
             full: number
         }
 
-        export type Scalar<T = any> = Schema.Str | Schema.Int | Schema.Float | Schema.Coordinate| Schema.Aliased<T>
+        export type Scalar<T = any> = Schema.Str | Schema.Int | Schema.Float | Schema.Coordinate | Schema.Aliased<T>
 
         export function FP(full: number, acceptable: number, low: number): FloatPrecision { return { low, full, acceptable }; }
 
@@ -57,8 +56,7 @@ namespace Column {
         export type Float = { '@type': 'float', T: number, kind: 'float', precision: FloatPrecision }
         export type Coordinate = { '@type': 'coord', T: number, kind: 'float' }
 
-        export type Vector = { '@type': 'vector', T: number[], dim: number, kind: 'vector' };
-        export type Matrix = { '@type': 'matrix', T: number[][], rows: number, cols: number, kind: 'matrix' };
+        export type Tensor = { '@type': 'tensor', T: Tensors, space: Tensors.Space, kind: 'tensor' };
         export type Aliased<T> = { '@type': 'aliased', T: T } & { kind: 'str' | 'int' | 'float' }
 
         export const str: Str = { '@type': 'str', T: '', kind: 'str' };
@@ -66,8 +64,9 @@ namespace Column {
         export const coord: Coordinate = { '@type': 'coord', T: 0, kind: 'float' };
         export function float(precision: FloatPrecision): Float { return { '@type': 'float', T: 0, kind: 'float', precision } };
 
-        export function vector(dim: number): Vector { return { '@type': 'vector', T: [] as number[], dim, kind: 'vector' }; }
-        export function matrix(rows: number, cols: number): Matrix { return { '@type': 'matrix', T: [] as number[][], rows, cols, kind: 'matrix' }; }
+        export function tensor(space: Tensors.Space): Tensor { return { '@type': 'tensor', T: space.create(), space, kind: 'tensor' }; }
+        export function vector(dim: number): Tensor { return tensor(Tensors.Vector(dim)); }
+        export function matrix(rows: number, cols: number): Tensor { return tensor(Tensors.ColumnMajorMatrix(rows, cols)); }
         export function aliased<T>(t: Schema): Aliased<T> { return t as any as Aliased<T>; }
     }
 

+ 2 - 2
src/mol-base/math/tensor.ts

@@ -10,6 +10,7 @@ namespace Tensor {
     export type ArrayCtor = { new (size: number): ArrayLike<number> }
 
     export interface Space {
+        readonly rank: number,
         readonly dimensions: ReadonlyArray<number>,
         readonly axisOrderSlowToFast: ReadonlyArray<number>,
         create(array?: ArrayCtor): Tensor,
@@ -19,7 +20,6 @@ namespace Tensor {
 
     interface Layout {
         dimensions: number[],
-        // slowest to fastest changing
         axisOrderSlowToFast: number[],
         axisOrderFastToSlow: number[],
         accessDimensions: number[],
@@ -40,7 +40,7 @@ namespace Tensor {
     export function Space(dimensions: number[], axisOrderSlowToFast: number[], ctor?: ArrayCtor): Space {
         const layout = Layout(dimensions, axisOrderSlowToFast, ctor);
         const { get, set } = accessors(layout);
-        return { dimensions: [...dimensions], axisOrderSlowToFast: [...axisOrderSlowToFast], create: creator(layout), get, set };
+        return { rank: dimensions.length, dimensions, axisOrderSlowToFast, create: creator(layout), get, set };
     }
 
     export function Vector(d: number, ctor?: ArrayCtor) { return Space([d], [0], ctor); }

+ 27 - 18
src/mol-io/reader/cif/data-model.ts

@@ -5,6 +5,7 @@
  */
 
 import { Column } from 'mol-base/collections/database'
+import Tensor from 'mol-base/math/tensor'
 
 export interface File {
     readonly name?: string,
@@ -76,24 +77,32 @@ export interface Field {
     toFloatArray(params?: Column.ToArrayParams<number>): ReadonlyArray<number>
 }
 
-export function getMatrix(category: Category, field: string, rows: number, cols: number, row: number) {
-    const ret: number[][] = [];
-    for (let i = 0; i < rows; i++) {
-        const r: number[] = [];
-        for (let j = 0; j < cols; j++) {
-            const f = category.getField(`${field}[${i + 1}][${j + 1}]`);
-            r[j] = f ? f.float(row) : 0.0;
+export function getTensor(category: Category, field: string, space: Tensor.Space, row: number): Tensor {
+    const ret = space.create();
+    if (space.rank === 1) {
+        const rows = space.dimensions[0];
+        for (let i = 0; i < rows; i++) {
+            const f = category.getField(`${field}[${i + 1}]`);
+            space.set(ret, i, !!f ? f.float(row) : 0.0);
         }
-        ret[i] = r;
-    }
-    return ret;
-}
-
-export function getVector(category: Category, field: string, rows: number, row: number) {
-    const ret: number[] = [];
-    for (let i = 0; i < rows; i++) {
-        const f = category.getField(`${field}[${i + 1}]`);
-        ret[i] = f ? f.float(row) : 0.0;
-    }
+    } else if (space.rank === 2) {
+        const rows = space.dimensions[0], cols = space.dimensions[1];
+        for (let i = 0; i < rows; i++) {
+            for (let j = 0; j < cols; j++) {
+                const f = category.getField(`${field}[${i + 1}][${j + 1}]`);
+                space.set(ret, i, j, !!f ? f.float(row) : 0.0);
+            }
+        }
+    } else if (space.rank === 3) {
+        const d0 = space.dimensions[0], d1 = space.dimensions[1], d2 = space.dimensions[2];
+        for (let i = 0; i < d0; i++) {
+            for (let j = 0; j < d1; j++) {
+                for (let k = 0; k < d2; k++) {
+                    const f = category.getField(`${field}[${i + 1}][${j + 1}][${k + 1}]`);
+                    space.set(ret, i, j, k, !!f ? f.float(row) : 0.0);
+                }
+            }
+        }
+    } else throw new Error('Tensors with rank > 3 currently not supported.');
     return ret;
 }

+ 3 - 8
src/mol-io/reader/cif/schema.ts

@@ -22,14 +22,9 @@ function getColumnCtor(t: Column.Schema): ColumnCtor {
         case 'str': return (f, c, k) => createColumn(Column.Type.str, f, f.str, f.toStringArray);
         case 'int': return (f, c, k) => createColumn(Column.Type.int, f, f.int, f.toIntArray);
         case 'float': return (f, c, k) => createColumn(Column.Type.float, f, f.float, f.toFloatArray);
-        case 'vector': return (f, c, k) => {
-            const dim = t.dim;
-            const value = (row: number) => Data.getVector(c, k, dim, row);
-            return createColumn(t, f, value, params => ColumnHelpers.createAndFillArray(f.rowCount, value, params));
-        }
-        case 'matrix': return (f, c, k) => {
-            const rows = t.rows, cols = t.cols;
-            const value = (row: number) => Data.getMatrix(c, k, rows, cols, row);
+        case 'tensor': return (f, c, k) => {
+            const space = t.space;
+            const value = (row: number) => Data.getTensor(c, k, space, row);
             return createColumn(t, f, value, params => ColumnHelpers.createAndFillArray(f.rowCount, value, params));
         }
     }