Browse Source

wip, refactoring CIF encoder

David Sehnal 6 years ago
parent
commit
08c2fdb766

+ 21 - 7
src/mol-io/writer/cif/encoder.ts

@@ -36,7 +36,11 @@ export namespace Field {
         typedArray?: ArrayEncoding.TypedArrayCtor
     }
 
-    export type ParamsBase<K, D> = { valueKind?: (k: K, d: D) => Column.ValueKind, encoder?: ArrayEncoder, shouldInclude?: (data: D) => boolean }
+    export type ParamsBase<K, D> = {
+        valueKind?: (k: K, d: D) => Column.ValueKind,
+        encoder?: ArrayEncoder,
+        shouldInclude?: (data: D) => boolean
+    }
 
     export function str<K, D = any>(name: string, value: (k: K, d: D, index: number) => string, params?: ParamsBase<K, D>): Field<K, D> {
         return { name, type: Type.Str, value, valueKind: params && params.valueKind, defaultFormat: params && params.encoder ? { encoder: params.encoder } : void 0, shouldInclude: params && params.shouldInclude };
@@ -105,15 +109,19 @@ export interface Category<Ctx = any> {
 }
 
 export namespace Category {
-    export const Empty: Instance = { fields: [], rowCount: 0 };
+    export const Empty: Instance = { fields: [], source: [] };
 
-    export interface Instance<Key = any, Data = any> {
-        fields: Field[],
+    export interface DataSource<Key = any, Data = any> {
         data?: Data,
         rowCount: number,
         keys?: () => Iterator<Key>
     }
 
+    export interface Instance<Key = any, Data = any> {
+        fields: Field[],
+        source: DataSource<Key, Data>[]
+    }
+
     export interface Filter {
         includeCategory(categoryName: string): boolean,
         includeField(categoryName: string, fieldName: string): boolean,
@@ -134,9 +142,15 @@ export namespace Category {
 
     export function ofTable(table: Table<Table.Schema>, indices?: ArrayLike<number>): Category.Instance {
         if (indices) {
-            return { fields: cifFieldsFromTableSchema(table._schema), data: table, rowCount: indices.length, keys: () => Iterator.Array(indices) };
+            return {
+                fields: cifFieldsFromTableSchema(table._schema),
+                source: [{ data: table, rowCount: indices.length, keys: () => Iterator.Array(indices) }]
+            };
         }
-        return { fields: cifFieldsFromTableSchema(table._schema), data: table, rowCount: table._rowCount };
+        return {
+            fields: cifFieldsFromTableSchema(table._schema),
+            source: [{ data: table, rowCount: table._rowCount }]
+        };
     }
 }
 
@@ -145,7 +159,7 @@ export interface Encoder<T = string | Uint8Array> extends EncoderBase {
     setFormatter(formatter?: Category.Formatter): void,
 
     startDataBlock(header: string): void,
-    writeCategory<Ctx>(category: Category<Ctx>, contexts?: Ctx[]): void,
+    writeCategory<Ctx>(category: Category<Ctx>, context?: Ctx): void,
     getData(): T
 }
 

+ 9 - 15
src/mol-io/writer/cif/encoder/binary.ts

@@ -6,7 +6,6 @@
  * @author David Sehnal <david.sehnal@gmail.com>
  */
 
-import { Iterator } from 'mol-data'
 import { Column } from 'mol-data/db'
 import encodeMsgPack from '../../../common/msgpack/encode'
 import {
@@ -14,7 +13,7 @@ import {
 } from '../../../common/binary-cif'
 import { Field, Category, Encoder } from '../encoder'
 import Writer from '../../writer'
-import { getIncludedFields } from './util';
+import { getIncludedFields, getCategoryInstanceData, CategoryInstanceData } from './util';
 import { classifyIntArray, classifyFloatArray } from '../../../common/binary-cif/classifier';
 
 export interface EncodingProvider {
@@ -43,7 +42,7 @@ export default class BinaryEncoder implements Encoder<Uint8Array> {
         });
     }
 
-    writeCategory<Ctx>(category: Category<Ctx>, contexts?: Ctx[]) {
+    writeCategory<Ctx>(category: Category<Ctx>, context?: Ctx) {
         if (!this.data) {
             throw new Error('The writer contents have already been encoded, no more writing.');
         }
@@ -54,22 +53,17 @@ export default class BinaryEncoder implements Encoder<Uint8Array> {
 
         if (!this.filter.includeCategory(category.name)) return;
 
-        const src = !contexts || !contexts.length ? [category.instance(<any>void 0)] : contexts.map(c => category.instance(c));
-        const instances = src.filter(c => c && c.rowCount > 0);
-        if (!instances.length) return;
+        const { instance, rowCount, source } = getCategoryInstanceData(category, context);
+        if (!rowCount) return;
 
-        const count = instances.reduce((a, c) => a + c.rowCount, 0);
-        if (!count) return;
-
-        const cat: EncodedCategory = { name: '_' + category.name, columns: [], rowCount: count };
-        const data = instances.map(c => ({ data: c.data, keys: () => c.keys ? c.keys() : Iterator.Range(0, c.rowCount - 1) }));
-        const fields = getIncludedFields(instances[0]);
+        const cat: EncodedCategory = { name: '_' + category.name, columns: [], rowCount };
+        const fields = getIncludedFields(instance);
 
         for (const f of fields) {
             if (!this.filter.includeField(category.name, f.name)) continue;
 
             const format = this.formatter.getFormat(category.name, f.name);
-            cat.columns.push(encodeField(category.name, f, data, count, format, this.encodingProvider, this.autoClassify));
+            cat.columns.push(encodeField(category.name, f, source, rowCount, format, this.encodingProvider, this.autoClassify));
         }
         // no columns included.
         if (!cat.columns.length) return;
@@ -133,7 +127,7 @@ function classify(type: Field.Type, data: ArrayLike<any>) {
     return classifyFloatArray(data);
 }
 
-function encodeField(categoryName: string, field: Field, data: { data: any, keys: () => Iterator<any> }[], totalCount: number, 
+function encodeField(categoryName: string, field: Field, data: CategoryInstanceData['source'], totalCount: number,
     format: Field.Format | undefined, encoderProvider: EncodingProvider | undefined, autoClassify: boolean): EncodedColumn {
 
     const { array, allPresent, mask } = getFieldData(field, getArrayCtor(field, format), totalCount, data);
@@ -163,7 +157,7 @@ function encodeField(categoryName: string, field: Field, data: { data: any, keys
     };
 }
 
-function getFieldData(field: Field<any, any>, arrayCtor: Helpers.ArrayCtor<string | number>, totalCount: number, data: { data: any; keys: () => Iterator<any>; }[]) {
+function getFieldData(field: Field<any, any>, arrayCtor: Helpers.ArrayCtor<string | number>, totalCount: number, data: CategoryInstanceData['source']) {
     const isStr = field.type === Field.Type.Str;
     const array = new arrayCtor(totalCount);
     const mask = new Uint8Array(totalCount);

+ 17 - 24
src/mol-io/writer/cif/encoder/text.ts

@@ -6,12 +6,11 @@
  * @author David Sehnal <david.sehnal@gmail.com>
  */
 
-import { Iterator } from 'mol-data'
 import { Column } from 'mol-data/db'
 import StringBuilder from 'mol-util/string-builder'
 import { Category, Field, Encoder } from '../encoder'
 import Writer from '../../writer'
-import { getFieldDigitCount, getIncludedFields } from './util';
+import { getFieldDigitCount, getIncludedFields, getCategoryInstanceData, CategoryInstanceData } from './util';
 
 export default class TextEncoder implements Encoder<string> {
     private builder = StringBuilder.create();
@@ -33,7 +32,7 @@ export default class TextEncoder implements Encoder<string> {
         StringBuilder.write(this.builder, `data_${(header || '').replace(/[ \n\t]/g, '').toUpperCase()}\n#\n`);
     }
 
-    writeCategory<Ctx>(category: Category<Ctx>, contexts?: Ctx[]) {
+    writeCategory<Ctx>(category: Category<Ctx>, context?: Ctx) {
         if (this.encoded) {
             throw new Error('The writer contents have already been encoded, no more writing.');
         }
@@ -43,19 +42,13 @@ export default class TextEncoder implements Encoder<string> {
         }
 
         if (!this.filter.includeCategory(category.name)) return;
-
-        const src = !contexts || !contexts.length ? [category.instance(<any>void 0)] : contexts.map(c => category.instance(c));
-        const instances = src.filter(c => c && c.rowCount > 0);
-        if (!instances.length) return;
-
-        const rowCount = instances.reduce((v, c) => v + c.rowCount, 0);
-
-        if (rowCount === 0) return;
+        const { instance, rowCount, source } = getCategoryInstanceData(category, context);
+        if (!rowCount) return;
 
         if (rowCount === 1) {
-            writeCifSingleRecord(category, instances[0]!, this.builder, this.filter, this.formatter);
+            writeCifSingleRecord(category, instance, source, this.builder, this.filter, this.formatter);
         } else {
-            writeCifLoop(category, instances, this.builder, this.filter, this.formatter);
+            writeCifLoop(category, instance, source, this.builder, this.filter, this.formatter);
         }
     }
 
@@ -110,18 +103,18 @@ function getFloatPrecisions(categoryName: string, fields: Field[], formatter: Ca
     return ret;
 }
 
-function writeCifSingleRecord(category: Category, instance: Category.Instance, builder: StringBuilder, filter: Category.Filter, formatter: Category.Formatter) {
+function writeCifSingleRecord(category: Category, instance: Category.Instance, source: CategoryInstanceData['source'], builder: StringBuilder, filter: Category.Filter, formatter: Category.Formatter) {
     const fields = getIncludedFields(instance);
-    const data = instance.data;
+    const src = source[0];
+    const data = src.data;
     let width = fields.reduce((w, f) => filter.includeField(category.name, f.name) ? Math.max(w, f.name.length) : 0, 0);
 
     // this means no field from this category is included.
     if (width === 0) return;
     width += category.name.length + 6;
 
-    const it = instance.keys ? instance.keys() : Iterator.Range(0, instance.rowCount - 1);
+    const it = src.keys();
     const key = it.move();
-
     const precisions = getFloatPrecisions(category.name, instance.fields, formatter);
 
     for (let _f = 0; _f < fields.length; _f++) {
@@ -135,8 +128,8 @@ function writeCifSingleRecord(category: Category, instance: Category.Instance, b
     StringBuilder.write(builder, '#\n');
 }
 
-function writeCifLoop(category: Category, instances: Category.Instance[], builder: StringBuilder, filter: Category.Filter, formatter: Category.Formatter) {
-    const fieldSource = getIncludedFields(instances[0]);
+function writeCifLoop(category: Category, instance: Category.Instance, source: CategoryInstanceData['source'], builder: StringBuilder, filter: Category.Filter, formatter: Category.Formatter) {
+    const fieldSource = getIncludedFields(instance);
     const fields = filter === Category.DefaultFilter ? fieldSource : fieldSource.filter(f => filter.includeField(category.name, f.name));
     const fieldCount = fields.length;
     if (fieldCount === 0) return;
@@ -149,13 +142,13 @@ function writeCifLoop(category: Category, instances: Category.Instance[], builde
     }
 
     let index = 0;
-    for (let _c = 0; _c < instances.length; _c++) {
-        const instance = instances[_c];
-        const data = instance.data;
+    for (let _c = 0; _c < source.length; _c++) {
+        const src = source[_c];
+        const data = src.data;
 
-        if (instance.rowCount === 0) continue;
+        if (src.rowCount === 0) continue;
 
-        const it = instance.keys ? instance.keys() : Iterator.Range(0, instance.rowCount - 1);
+        const it = src.keys();
         while (it.hasNext)  {
             const key = it.move();
 

+ 23 - 1
src/mol-io/writer/cif/encoder/util.ts

@@ -4,6 +4,7 @@
  * @author David Sehnal <david.sehnal@gmail.com>
  */
 
+import { Iterator } from 'mol-data'
 import { Field, Category } from '../encoder';
 
 export function getFieldDigitCount(field: Field) {
@@ -13,6 +14,27 @@ export function getFieldDigitCount(field: Field) {
 
 export function getIncludedFields(category: Category.Instance) {
     return category.fields.some(f => !!f.shouldInclude)
-        ? category.fields.filter(f => !f.shouldInclude || f.shouldInclude(category.data))
+        ? category.fields.filter(f => !f.shouldInclude || category.source.some(src => f.shouldInclude!(src.data)))
         : category.fields;
+}
+
+export interface CategoryInstanceData<Ctx = any> { // Result shape of getCategoryInstanceData: a category instance plus its normalized data sources.
+    instance: Category.Instance<Ctx>, // the instance obtained from category.instance(ctx)
+    rowCount: number, // sum of rowCount over the entries in `source` (0 when `source` is empty)
+    source: { data: any, keys: () => Iterator<any>, rowCount: number }[] // unlike Category.DataSource, `keys` is required here
+}
+
+export function getCategoryInstanceData<Ctx>(category: Category<Ctx>, ctx?: Ctx): CategoryInstanceData<Ctx> { // Instantiates `category` for `ctx` and gathers its non-empty data sources.
+    const instance = category.instance(ctx as any); // ctx may be undefined; the cast passes it through unchecked
+    let sources = instance.source.filter(s => s.rowCount > 0); // drop sources that contribute no rows
+    if (!sources.length) return { instance, rowCount: 0, source: [] }; // nothing to write: callers treat rowCount 0 as "skip"
+
+    const rowCount = sources.reduce((a, c) => a + c.rowCount, 0); // total rows across the remaining sources
+    const source = sources.map(c => ({ // normalize each source so that `keys` is always callable
+        data: c.data,
+        keys: () => c.keys ? c.keys() : Iterator.Range(0, c.rowCount - 1), // fallback when no keys are supplied; presumably yields 0..rowCount-1 inclusive - confirm against Iterator.Range
+        rowCount: c.rowCount
+    }));
+
+    return { instance, rowCount, source };
 }