Browse Source

add tensor handling to cif encoder

Alexander Rose 7 years ago
parent
commit
3add0d0636

+ 1 - 1
src/mol-io/reader/cif/data-model.ts

@@ -106,6 +106,6 @@ export function getTensor(category: Category, field: string, space: Tensor.Space
                 }
             }
         }
-    } else throw new Error('Tensors with rank > 3 currently not supported.');
+    } else throw new Error('Tensors with rank > 3 or rank 0 are currently not supported.');
     return ret;
 }

+ 9 - 2
src/mol-io/reader/cif/schema.ts

@@ -25,7 +25,7 @@ function getColumnCtor(t: Column.Schema): ColumnCtor {
         case 'int': return (f, c, k) => createColumn(t, f, f.int, f.toIntArray);
         case 'float': return (f, c, k) => createColumn(t, f, f.float, f.toFloatArray);
         case 'list': return (f, c, k) => createColumn(t, f, f.list, f.toListArray);
-        case 'tensor': throw new Error(`Use createTensorColumn instead.`);
+        case 'tensor': throw new Error('Use createTensorColumn instead.');
     }
 }
 
@@ -44,7 +44,14 @@ function createColumn<T>(schema: Column.Schema, field: Data.Field, value: (row:
 
 function createTensorColumn(schema: Column.Schema.Tensor, category: Data.Category, key: string): Column<Tensor> {
     const space = schema.space;
-    const first = category.getField(`${key}[1]`) || Column.Undefined(category.rowCount, schema);
+    let firstFieldName: string;
+    switch (space.rank) {
+        case 1: firstFieldName = `${key}[1]`; break;
+        case 2: firstFieldName = `${key}[1][1]`; break;
+        case 3: firstFieldName = `${key}[1][1][1]`; break;
+        default: throw new Error('Tensors with rank > 3 or rank 0 are currently not supported.');
+    }
+    const first = category.getField(firstFieldName) || Column.Undefined(category.rowCount, schema);
     const value = (row: number) => Data.getTensor(category, key, space, row);
     const toArray: Column<Tensor>['toArray'] = params => ColumnHelpers.createAndFillArray(category.rowCount, value, params)
 

+ 80 - 5
src/mol-io/writer/cif/encoder.ts

@@ -1,11 +1,13 @@
 /**
- * Copyright (c) 2017 mol* contributors, licensed under MIT, See LICENSE file for more info.
+ * Copyright (c) 2017-2018 mol* contributors, licensed under MIT, See LICENSE file for more info.
  *
  * @author David Sehnal <david.sehnal@gmail.com>
+ * @author Alexander Rose <alexander.rose@weirdbyte.de>
  */
 
 import Iterator from 'mol-data/iterator'
-import { Column } from 'mol-data/db'
+import { Column, Table } from 'mol-data/db'
+import { Tensor } from 'mol-math/linear-algebra'
 import Encoder from '../encoder'
 
 // TODO: support for "coordinate fields", make "coordinate precision" a parameter of the encoder
@@ -13,7 +15,6 @@ import Encoder from '../encoder'
 // TODO: automatically detect "best encoding" for integer arrays. This could be used for "fixed-point" as well.
 // TODO: add "repeat encoding"? [[1, 2], [1, 2], [1, 2]] --- Repeat ---> [[1, 2], 3]
 // TODO: Add "higher level fields"? (i.e. generalization of repeat)
-// TODO: Add tensor field definition
 // TODO: align "data blocks" to 8-byte offsets for fast typed array windows? (probably needs some testing to check whether this is actually the case)
 // TODO: "parametric encoders"? Specify encoding as [{ param: 'value1', encoding1 }, { param: 'value2', encoding2 }]
 //       then the encoder can specify { param: 'value1' } and the correct encoding will be used.
@@ -35,7 +36,6 @@ export type FieldDefinition<Key = any, Data = any> =
     | FieldDefinitionBase<Key, Data> & { type: FieldType.Str, value(key: Key, data: Data): string }
     | FieldDefinitionBase<Key, Data> & { type: FieldType.Int, value(key: Key, data: Data): number }
     | FieldDefinitionBase<Key, Data> & { type: FieldType.Float, value(key: Key, data: Data): number }
-    // TODO: add tensor
 
 export interface FieldFormat {
     // TODO: do we actually need this?
@@ -74,4 +74,79 @@ export interface CIFEncoder<T = string | Uint8Array, Context = any> extends Enco
     startDataBlock(header: string): void,
     writeCategory(category: CategoryProvider, contexts?: Context[]): void,
     getData(): T
-}
+}
+
+function columnValue(k: string) { // builds a field `value` accessor bound to key k of the data object
+    return (i: number, d: any) => d[k].value(i); // for row i, delegates to the `value(i)` method of the entry stored under d[k]
+}
+
+function columnTensorValue(k: string, ...coords: number[]) { // builds an accessor for one tensor component of d[k], selected by the fixed indices in `coords`
+    return (i: number, d: any) => d[k].schema.space.get(d[k].value(i), ...coords); // reads the row-i value of d[k], then extracts the component through the `space` found on that entry's schema
+}
+
+function columnValueKind(k: string) { // builds a field `valueKind` accessor bound to key k of the data object
+    return (i: number, d: any) => d[k].valueKind(i); // for row i, delegates to the `valueKind(i)` method of the entry stored under d[k]
+}
+
+function getTensorDefinitions(field: string, space: Tensor.Space) { // expands one tensor column into one field definition per tensor component; throws for ranks other than 1, 2 or 3
+    const fieldDefinitions: FieldDefinition[] = []
+    const type = FieldType.Float // every component is emitted as a float field
+    const valueKind = columnValueKind(field) // a single value-kind accessor, shared by all components of this column
+    if (space.rank === 1) {
+        const rows = space.dimensions[0]
+        for (let i = 0; i < rows; i++) {
+            const name = `${field}[${i + 1}]` // field names carry 1-based indices; the value accessor below receives the 0-based index
+            fieldDefinitions.push({ name, type, value: columnTensorValue(field, i), valueKind })
+        }
+    } else if (space.rank === 2) {
+        const rows = space.dimensions[0], cols = space.dimensions[1]
+        for (let i = 0; i < rows; i++) {
+            for (let j = 0; j < cols; j++) {
+                const name = `${field}[${i + 1}][${j + 1}]` // 1-based indices in the name, 0-based indices passed to the accessor
+                fieldDefinitions.push({ name, type, value: columnTensorValue(field, i, j), valueKind })
+            }
+        }
+    } else if (space.rank === 3) {
+        const d0 = space.dimensions[0], d1 = space.dimensions[1], d2 = space.dimensions[2]
+        for (let i = 0; i < d0; i++) {
+            for (let j = 0; j < d1; j++) {
+                for (let k = 0; k < d2; k++) {
+                    const name = `${field}[${i + 1}][${j + 1}][${k + 1}]` // 1-based indices in the name, 0-based indices passed to the accessor
+                    fieldDefinitions.push({ name, type, value: columnTensorValue(field, i, j, k), valueKind })
+                }
+            }
+        }
+    } else { // any rank other than 1, 2 or 3 is rejected
+        throw new Error('Tensors with rank > 3 or rank 0 are currently not supported.')
+    }
+    return fieldDefinitions // ordered with the first index varying slowest and the last index fastest
+}
+
+export namespace FieldDefinitions { // helpers that derive encoder field definitions from table schemas
+    export function ofSchema(schema: Table.Schema) { // returns the field definitions for every key of `schema`, in Object.keys order; throws on 'list' and on unknown value types
+        const fields: FieldDefinition[] = [];
+        for (const k of Object.keys(schema)) {
+            const t = schema[k];
+            if (t.valueType === 'int') { // int, float and str columns each map to exactly one field named after the key
+                fields.push({ name: k, type: FieldType.Int, value: columnValue(k), valueKind: columnValueKind(k) });
+            } else if (t.valueType === 'float') {
+                fields.push({ name: k, type: FieldType.Float, value: columnValue(k), valueKind: columnValueKind(k) });
+            } else if (t.valueType === 'str') {
+                fields.push({ name: k, type: FieldType.Str, value: columnValue(k), valueKind: columnValueKind(k) });
+            } else if (t.valueType === 'list') { // list columns are explicitly unsupported here
+                throw new Error('list not implemented');
+            } else if (t.valueType === 'tensor') { // a tensor column expands into one field per component (see getTensorDefinitions)
+                fields.push(...getTensorDefinitions(k, t.space))
+            } else {
+                throw new Error(`Unknown valueType ${t.valueType}`);
+            }
+        }
+        return fields;
+    }
+}
+
+export namespace CategoryDefinition { // helpers that build category definitions
+    export function ofTable<S extends Table.Schema>(name: string, table: Table<S>): CategoryDefinition<number> { // pairs the given category name with the fields derived from the table's schema
+        return { name, fields: FieldDefinitions.ofSchema(table._schema) } // reads the schema from the table's `_schema` property
+    }
+}

+ 5 - 28
src/mol-model/structure/export/mmcif.ts

@@ -1,13 +1,14 @@
 /**
- * Copyright (c) 2017 mol* contributors, licensed under MIT, See LICENSE file for more info.
+ * Copyright (c) 2017-2018 mol* contributors, licensed under MIT, See LICENSE file for more info.
  *
  * @author David Sehnal <david.sehnal@gmail.com>
+ * @author Alexander Rose <alexander.rose@weirdbyte.de>
  */
 
-import { Column, Table } from 'mol-data/db'
+import { Column } from 'mol-data/db'
 import Iterator from 'mol-data/iterator'
 import * as Encoder from 'mol-io/writer/cif'
-//import { mmCIF_Schema } from 'mol-io/reader/cif/schema/mmcif'
+// import { mmCIF_Schema } from 'mol-io/reader/cif/schema/mmcif'
 import { Structure, Atom, AtomSet } from '../structure'
 import { Model } from '../model'
 import P from '../query/properties'
@@ -36,29 +37,6 @@ function float<K, D = any>(name: string, value: (k: K, d: D) => number, valueKin
 //     return { name, type, value, valueKind }
 // }
 
-function columnValue(k: string) {
-    return (i: number, d: any) => d[k].value(i);
-}
-
-function columnValueKind(k: string) {
-    return (i: number, d: any) => d[k].valueKind(i);
-}
-
-function ofSchema(schema: Table.Schema) {
-    const fields: Encoder.FieldDefinition[] = [];
-    for (const k of Object.keys(schema)) {
-        const t = schema[k];
-        // TODO: matrix/vector/support
-        const type: any = t.valueType === 'str' ? Encoder.FieldType.Str : t.valueType === 'int' ? Encoder.FieldType.Int : Encoder.FieldType.Float;
-        fields.push({ name: k, type, value: columnValue(k), valueKind: columnValueKind(k) })
-    }
-    return fields;
-}
-
-function ofTable<S extends Table.Schema>(name: string, table: Table<S>): Encoder.CategoryDefinition<number> {
-    return { name, fields: ofSchema(table._schema) }
-}
-
 // type Entity = Table.Columns<typeof mmCIF_Schema.entity>
 
 // const entity: Encoder.CategoryDefinition<number, Entity> = {
@@ -66,7 +44,6 @@ function ofTable<S extends Table.Schema>(name: string, table: Table<S>): Encoder
 //     fields: ofSchema(mmCIF_Schema.entity)
 // }
 
-
 // [
 //     str('id', (i, e) => e.id.value(i)),
 //     str('type', (i, e) => e.type.value(i)),
@@ -126,7 +103,7 @@ const atom_site: Encoder.CategoryDefinition<Atom.Location> = {
 function entityProvider({ model }: Context): Encoder.CategoryInstance {
     return {
         data: model.hierarchy.entities,
-        definition: ofTable('entity', model.hierarchy.entities), //entity,
+        definition: Encoder.CategoryDefinition.ofTable('entity', model.hierarchy.entities),
         keys: () => Iterator.Range(0, model.hierarchy.entities._rowCount - 1),
         rowCount: model.hierarchy.entities._rowCount
     }