Browse Source

CifWriter filter and formatter

David Sehnal 6 years ago
parent
commit
2766f3ccbb

+ 3 - 3
src/mol-io/writer/cif.ts

@@ -5,8 +5,8 @@
  * @author Alexander Rose <alexander.rose@weirdbyte.de>
  */
 
-import TextCIFEncoder from './cif/encoder/text'
-import BinaryCIFEncoder from './cif/encoder/binary'
+import TextEncoder from './cif/encoder/text'
+import BinaryEncoder from './cif/encoder/binary'
 import * as _Encoder from './cif/encoder'
 import { ArrayEncoding } from '../common/binary-cif';
 
@@ -18,7 +18,7 @@ export namespace CifWriter {
 
     export function createEncoder(params?: { binary?: boolean, encoderName?: string }): Encoder {
         const { binary = false, encoderName = 'mol*' } = params || {};
-        return binary ? new BinaryCIFEncoder(encoderName) : new TextCIFEncoder();
+        return binary ? new BinaryEncoder(encoderName) : new TextEncoder(); // TextEncoder is the CIF text encoder imported from './cif/encoder/text', not the platform global; encoderName is only passed to the binary encoder
     }
 
     import E = Encoding

+ 21 - 1
src/mol-io/writer/cif/encoder.ts

@@ -79,6 +79,24 @@ export namespace Category {
         (ctx: Ctx): Category
     }
 
+    export interface Filter { // Decides which categories and fields an encoder writes.
+        includeCategory(categoryName: string): boolean, // false: the encoder returns without writing the category
+        includeField(categoryName: string, fieldName: string): boolean, // false: the field is left out of the written category
+    }
+
+    export const DefaultFilter: Filter = { // Pass-through filter: includes every category and field. The text loop writer compares against this object by identity to skip per-field filtering.
+        includeCategory(cat) { return true; },
+        includeField(cat, field) { return true; }
+    }
+
+    export interface Formatter { // Supplies per-field format overrides to an encoder.
+        getFormat(categoryName: string, fieldName: string): Field.Format | undefined // undefined: the encoder falls back to field-level defaults
+    }
+
+    export const DefaultFormatter: Formatter = { // No overrides: every lookup yields undefined.
+        getFormat(cat, field) { return void 0; }
+    }
+
     export function ofTable(name: string, table: Table<Table.Schema>, indices?: ArrayLike<number>): Category<number, Table<Table.Schema>> {
         if (indices) {
             return { name, fields: cifFieldsFromTableSchema(table._schema), data: table, rowCount: indices.length, keys: () => Iterator.Array(indices) };
@@ -88,7 +106,9 @@ export namespace Category {
 }
 
 export interface Encoder<T = string | Uint8Array> extends EncoderBase {
-    // setFormatter(): void,
+    setFilter(filter?: Category.Filter): void,
+    setFormatter(formatter?: Category.Formatter): void,
+
     startDataBlock(header: string): void,
     writeCategory<Ctx>(category: Category.Provider<Ctx>, contexts?: Ctx[]): void,
     getData(): T

+ 36 - 15
src/mol-io/writer/cif/encoder/binary.ts

@@ -10,15 +10,25 @@ import { Iterator } from 'mol-data'
 import { Column } from 'mol-data/db'
 import encodeMsgPack from '../../../common/msgpack/encode'
 import {
-    EncodedColumn, EncodedData, EncodedFile, EncodedDataBlock, EncodedCategory, ArrayEncoder, ArrayEncoding as E, VERSION
+    EncodedColumn, EncodedData, EncodedFile, EncodedDataBlock, EncodedCategory, ArrayEncoder, ArrayEncoding as E, VERSION, ArrayEncoding
 } from '../../../common/binary-cif'
 import { Field, Category, Encoder } from '../encoder'
 import Writer from '../../writer'
 
-export default class BinaryCIFWriter implements Encoder<Uint8Array> {
+export default class BinaryEncoder implements Encoder<Uint8Array> {
     private data: EncodedFile;
     private dataBlocks: EncodedDataBlock[] = [];
     private encodedData: Uint8Array;
+    private filter: Category.Filter = Category.DefaultFilter;
+    private formatter: Category.Formatter = Category.DefaultFormatter;
+
+    setFilter(filter?: Category.Filter) { // Installs the filter consulted when categories are written.
+        this.filter = filter || Category.DefaultFilter; // omitted/undefined resets to the include-everything default
+    }
+
+    setFormatter(formatter?: Category.Formatter) { // Installs the formatter consulted for each written field.
+        this.formatter = formatter || Category.DefaultFormatter; // omitted/undefined resets to the no-override default
+    }
 
     startDataBlock(header: string) {
         this.dataBlocks.push({
@@ -39,6 +49,7 @@ export default class BinaryCIFWriter implements Encoder<Uint8Array> {
         const src = !contexts || !contexts.length ? [category(<any>void 0)] : contexts.map(c => category(c));
         const categories = src.filter(c => c && c.rowCount > 0);
         if (!categories.length) return;
+        if (!this.filter.includeCategory(categories[0].name)) return;
 
         const count = categories.reduce((a, c) => a + c.rowCount, 0);
         if (!count) return;
@@ -47,7 +58,10 @@ export default class BinaryCIFWriter implements Encoder<Uint8Array> {
         const cat: EncodedCategory = { name: '_' + first.name, columns: [], rowCount: count };
         const data = categories.map(c => ({ data: c.data, keys: () => c.keys ? c.keys() : Iterator.Range(0, c.rowCount - 1) }));
         for (const f of first.fields) {
-            cat.columns.push(encodeField(f, data, count, f.defaultFormat));
+            if (!this.filter.includeField(first.name, f.name)) continue;
+
+            const format = this.formatter.getFormat(first.name, f.name);
+            cat.columns.push(encodeField(f, data, count, getArrayCtor(f, format), getEncoder(f, format)));
         }
         this.dataBlocks[this.dataBlocks.length - 1].categories.push(cat);
     }
@@ -77,25 +91,32 @@ export default class BinaryCIFWriter implements Encoder<Uint8Array> {
     }
 }
 
-function createArray(field: Field, count: number) {
-    if (field.type === Field.Type.Str) return new Array(count) as any;
-    else if (field.defaultFormat && field.defaultFormat.typedArray) return new field.defaultFormat.typedArray(count) as any;
-    else return (field.type === Field.Type.Int ? new Int32Array(count) : new Float32Array(count)) as any;
+function createArray(type: Field.Type, arrayCtor: ArrayEncoding.TypedArrayCtor | undefined,  count: number) { // Allocates the value buffer for one encoded column.
+    if (type === Field.Type.Str) return new Array(count) as any; // strings always use a plain Array; arrayCtor is ignored
+    else if (arrayCtor) return new arrayCtor(count) as any; // caller-resolved typed array takes precedence over the type-based fallback
+    else return (type === Field.Type.Int ? new Int32Array(count) : new Float32Array(count)) as any; // fallback: Int32Array for Int fields, Float32Array for anything else
 }
 
-function encodeField(field: Field, data: { data: any, keys: () => Iterator<any> }[], totalCount: number, format?: Field.Format): EncodedColumn {
-    const isStr = field.type === Field.Type.Str;
-    const array = createArray(field, totalCount);
-    let encoder: ArrayEncoder;
+function getArrayCtor(field: Field, format: Field.Format | undefined) { // Resolves the typed-array constructor to use for a field, if any.
+    if (format && format.typedArray) return format.typedArray; // explicit format override wins
+    if (field.defaultFormat && field.defaultFormat.typedArray) return field.defaultFormat.typedArray; // then the field's own default format
+    return void 0; // undefined lets createArray choose by field type
+}
 
+function getEncoder(field: Field, format: Field.Format | undefined) { // Resolves the array encoder for a field.
+    if (format && format.encoder) return format.encoder; // explicit format override wins over the field's default format
     if (field.defaultFormat && field.defaultFormat.encoder) {
-        encoder = field.defaultFormat.encoder;
-    } else if (isStr) {
-        encoder = ArrayEncoder.by(E.stringArray);
+        return field.defaultFormat.encoder;
+    } else if (field.type === Field.Type.Str) {
+        return ArrayEncoder.by(E.stringArray);
     } else {
-        encoder = ArrayEncoder.by(E.byteArray);
+        return ArrayEncoder.by(E.byteArray); // every non-string field without an explicit encoder gets byteArray
     }
+}
 
+function encodeField(field: Field, data: { data: any, keys: () => Iterator<any> }[], totalCount: number, arrayCtor: ArrayEncoding.TypedArrayCtor | undefined, encoder: ArrayEncoder): EncodedColumn {
+    const isStr = field.type === Field.Type.Str;
+    const array = createArray(field.type, arrayCtor, totalCount);
     const mask = new Uint8Array(totalCount);
     const valueKind = field.valueKind;
     const getter = field.value;

+ 31 - 12
src/mol-io/writer/cif/encoder/text.ts

@@ -12,10 +12,20 @@ import StringBuilder from 'mol-util/string-builder'
 import { Category, Field, Encoder } from '../encoder'
 import Writer from '../../writer'
 
-export default class TextCIFEncoder implements Encoder<string> {
+export default class TextEncoder implements Encoder<string> {
     private builder = StringBuilder.create();
     private encoded = false;
     private dataBlockCreated = false;
+    private filter: Category.Filter = Category.DefaultFilter;
+    private formatter: Category.Formatter = Category.DefaultFormatter;
+
+    setFilter(filter?: Category.Filter) { // Installs the filter consulted when categories are written.
+        this.filter = filter || Category.DefaultFilter; // omitted/undefined resets to the include-everything default
+    }
+
+    setFormatter(formatter?: Category.Formatter) { // Installs the formatter passed to the single-record and loop writers.
+        this.formatter = formatter || Category.DefaultFormatter; // omitted/undefined resets to the no-override default
+    }
 
     startDataBlock(header: string) {
         this.dataBlockCreated = true;
@@ -33,15 +43,16 @@ export default class TextCIFEncoder implements Encoder<string> {
 
         const categories = !contexts || !contexts.length ? [category(<any>void 0)] : contexts.map(c => category(c));
         if (!categories.length) return;
+        if (!this.filter.includeCategory(categories[0].name)) return;
 
         const rowCount = categories.reduce((v, c) => v + c.rowCount, 0);
 
         if (rowCount === 0) return;
 
         if (rowCount === 1) {
-            writeCifSingleRecord(categories[0]!, this.builder);
+            writeCifSingleRecord(categories[0]!, this.builder, this.filter, this.formatter);
         } else {
-            writeCifLoop(categories, this.builder);
+            writeCifLoop(categories, this.builder, this.filter, this.formatter);
         }
     }
 
@@ -86,26 +97,34 @@ function writeValue(builder: StringBuilder, data: any, key: any, f: Field<any, a
     return false;
 }
 
-function getFloatPrecisions(cat: Category) {
+function getFloatPrecisions(categoryName: string, fields: Field[], formatter: Category.Formatter) { // Returns one entry per field, index-aligned with `fields`: a power of ten for Float fields, 0 for all others.
     const ret: number[] = [];
-    for (const f of cat.fields) {
-        ret[ret.length] = f.type === Field.Type.Float ? Math.pow(10, Field.getDigitCount(f)) : 0;
+    for (const f of fields) {
+        const format = formatter.getFormat(categoryName, f.name);
+        if (format && typeof format.digitCount !== 'undefined') ret[ret.length] = f.type === Field.Type.Float ? Math.pow(10, Math.max(0, Math.min(format.digitCount, 15))) : 0; // formatter override; digit count is clamped to the range 0..15
+        else ret[ret.length] = f.type === Field.Type.Float ? Math.pow(10, Field.getDigitCount(f)) : 0; // no override: use the digit count reported by Field.getDigitCount
     }
     return ret;
 }
 
-function writeCifSingleRecord(category: Category<any>, builder: StringBuilder) {
+function writeCifSingleRecord(category: Category<any>, builder: StringBuilder, filter: Category.Filter, formatter: Category.Formatter) { // Writes a one-row category as padded `_category.field` names followed by their values, skipping fields the filter excludes.
     const fields = category.fields;
     const data = category.data;
-    const width = fields.reduce((w, s) => Math.max(w, s.name.length), 0) + category.name.length + 6;
+    let width = fields.reduce((w, f) => filter.includeField(category.name, f.name) ? Math.max(w, f.name.length) : w, 0); // an excluded field must carry the running maximum forward; yielding 0 would discard the widths seen so far and drop the whole category when the last field is excluded
+
+    // this means no field from this category is included.
+    if (width === 0) return;
+    width += category.name.length + 6;
 
     const it = category.keys ? category.keys() : Iterator.Range(0, category.rowCount - 1);
     const key = it.move();
 
-    const precisions = getFloatPrecisions(category);
+    const precisions = getFloatPrecisions(category.name, category.fields, formatter); // computed over the unfiltered field list so precisions[_f] lines up with fields[_f] in the loop
 
     for (let _f = 0; _f < fields.length; _f++) {
         const f = fields[_f];
+        if (!filter.includeField(category.name, f.name)) continue;
+
         StringBuilder.writePadRight(builder, `_${category.name}.${f.name}`, width);
         const multiline = writeValue(builder, data, key, f, precisions[_f]);
         if (!multiline) StringBuilder.newline(builder);
@@ -113,11 +132,11 @@ function writeCifSingleRecord(category: Category<any>, builder: StringBuilder) {
     StringBuilder.write(builder, '#\n');
 }
 
-function writeCifLoop(categories: Category[], builder: StringBuilder) {
+function writeCifLoop(categories: Category[], builder: StringBuilder, filter: Category.Filter, formatter: Category.Formatter) {
     const first = categories[0];
-    const fields = first.fields;
+    const fields = filter === Category.DefaultFilter ? first.fields : first.fields.filter(f => filter.includeField(first.name, f.name));
     const fieldCount = fields.length;
-    const precisions = getFloatPrecisions(first);
+    const precisions = getFloatPrecisions(first.name, fields, formatter);
 
     writeLine(builder, 'loop_');
     for (let i = 0; i < fieldCount; i++) {

+ 0 - 0
src/mol-io/writer/cif/filter.ts


+ 14 - 1
src/mol-model/structure/export/mmcif.ts

@@ -81,7 +81,20 @@ const Categories = [
     copy_mmCif_cat('symmetry'),
     _entity,
     _atom_site
-]
+];
+
+mmCIF_Schema // NOTE(review): bare expression statement with no runtime effect; presumably left in to keep the import referenced - confirm or remove
+
+namespace _Filters { // Field-name sets used by the export filters in this file.
+    export const AtomSitePositionsFieldNames = new Set<string>(<(keyof typeof mmCIF_Schema.atom_site)[]>['id', 'Cartn_x', 'Cartn_y', 'Cartn_z']); // NOTE(review): the keyof cast is presumably meant to tie these names to the atom_site schema - confirm it rejects misspelled names
+}
+
+export const mmCIF_Export_Filters = { // Ready-made CifWriter category filters for mmCIF export.
+    onlyPositions: <CifWriter.Category.Filter>{ // keeps only the atom_site category and, of its fields, those named in _Filters.AtomSitePositionsFieldNames
+        includeCategory(name) { return name === 'atom_site'; },
+        includeField(cat, field) { return _Filters.AtomSitePositionsFieldNames.has(field); } // cat is not checked here; the set lookup applies to whatever category is passed
+    }
+}
 
 /** Doesn't start a data block */
 export function encode_mmCIF_categories(encoder: CifWriter.Encoder, structure: Structure) {

+ 6 - 0
src/perf-tests/cif-encoder.ts

@@ -29,7 +29,13 @@ function getInstance(ctx: { name: string, fields: CifWriter.Field[], rowCount: n
 
 const enc = CifWriter.createEncoder();
 
+const filter: CifWriter.Category.Filter = { // keeps every category; drops only field 'e2' of category 'cat2'
+    includeCategory(cat) { return true; },
+    includeField(cat, field) { return !(cat === 'cat2' && field === 'e2') }
+}
+
 enc.startDataBlock('test');
+enc.setFilter(filter); // installed before either writeCategory call, so it applies to both cat1 and cat2
 enc.writeCategory(getInstance, [{ rowCount: 5, name: 'cat1', fields: category1fields }]);
 enc.writeCategory(getInstance, [{ rowCount: 1, name: 'cat2', fields: category2fields  }]);
 console.log(enc.getData());

+ 3 - 0
src/servers/model/server/query.ts

@@ -80,7 +80,10 @@ export async function resolveRequest(req: Request, writer: Writer) {
     encoder.startDataBlock('result');
     encoder.writeCategory(_model_server_result, [req]);
     encoder.writeCategory(_model_server_params, [req]);
+
+    // encoder.setFilter(mmCIF_Export_Filters.onlyPositions);
     encode_mmCIF_categories(encoder, result);
+    // encoder.setFilter();
     perf.end('encode');
 
     ConsoleLogger.logId(req.id, 'Query', 'Encoded.');