|
@@ -1,7 +1,8 @@
|
|
|
/**
|
|
|
- * Copyright (c) 2017 mol* contributors, licensed under MIT, See LICENSE file for more info.
|
|
|
+ * Copyright (c) 2017-2020 mol* contributors, licensed under MIT, See LICENSE file for more info.
|
|
|
*
|
|
|
* @author David Sehnal <david.sehnal@gmail.com>
|
|
|
+ * @author Alexander Rose <alexander.rose@weirdbyte.de>
|
|
|
*/
|
|
|
|
|
|
import { Mat4, Vec3, Vec4, Mat3 } from './3d';
|
|
@@ -21,6 +22,7 @@ export namespace Tensor {
|
|
|
get(data: Tensor.Data, ...coords: number[]): number
|
|
|
set(data: Tensor.Data, ...coordsAndValue: number[]): number
|
|
|
add(data: Tensor.Data, ...coordsAndValue: number[]): number
|
|
|
+ dataOffset(...coords: number[]): number
|
|
|
}
|
|
|
|
|
|
interface Layout {
|
|
@@ -46,8 +48,8 @@ export namespace Tensor {
|
|
|
|
|
|
export function Space(dimensions: number[], axisOrderSlowToFast: number[], ctor?: ArrayCtor): Space {
|
|
|
const layout = Layout(dimensions, axisOrderSlowToFast, ctor);
|
|
|
- const { get, set, add } = accessors(layout);
|
|
|
- return { rank: dimensions.length, dimensions, axisOrderSlowToFast, create: creator(layout), get, set, add };
|
|
|
+ const { get, set, add, dataOffset } = accessors(layout);
|
|
|
+ return { rank: dimensions.length, dimensions, axisOrderSlowToFast, create: creator(layout), get, set, add, dataOffset };
|
|
|
}
|
|
|
|
|
|
export function Data1(values: ArrayLike<number>): Data { return values as Data; }
|
|
@@ -95,13 +97,14 @@ export namespace Tensor {
|
|
|
return true;
|
|
|
}
|
|
|
|
|
|
- function accessors(layout: Layout): { get: Space['get'], set: Space['set'], add: Space['add'] } {
|
|
|
+ function accessors(layout: Layout): { get: Space['get'], set: Space['set'], add: Space['add'], dataOffset: Space['dataOffset'] } {
|
|
|
const { dimensions, axisOrderFastToSlow: ao } = layout;
|
|
|
switch (dimensions.length) {
|
|
|
case 1: return {
|
|
|
get: (t, d) => t[d],
|
|
|
set: (t, d, x) => t[d] = x,
|
|
|
- add: (t, d, x) => t[d] += x
|
|
|
+ add: (t, d, x) => t[d] += x,
|
|
|
+ dataOffset: (d) => d
|
|
|
};
|
|
|
case 2: {
|
|
|
// column major
|
|
@@ -110,7 +113,8 @@ export namespace Tensor {
|
|
|
return {
|
|
|
get: (t, i, j) => t[j * rows + i],
|
|
|
set: (t, i, j, x) => t[j * rows + i] = x,
|
|
|
- add: (t, i, j, x) => t[j * rows + i] += x
|
|
|
+ add: (t, i, j, x) => t[j * rows + i] += x,
|
|
|
+ dataOffset: (i, j) => j * rows + i
|
|
|
};
|
|
|
}
|
|
|
if (ao[0] === 1 && ao[1] === 0) {
|
|
@@ -118,7 +122,8 @@ export namespace Tensor {
|
|
|
return {
|
|
|
get: (t, i, j) => t[i * cols + j],
|
|
|
set: (t, i, j, x) => t[i * cols + j] = x,
|
|
|
- add: (t, i, j, x) => t[i * cols + j] += x
|
|
|
+ add: (t, i, j, x) => t[i * cols + j] += x,
|
|
|
+ dataOffset: (i, j) => i * cols + j
|
|
|
};
|
|
|
}
|
|
|
throw new Error('bad axis order');
|
|
@@ -129,7 +134,8 @@ export namespace Tensor {
|
|
|
return {
|
|
|
get: (t, i, j, k) => t[i + j * u + k * uv],
|
|
|
set: (t, i, j, k, x ) => t[i + j * u + k * uv] = x,
|
|
|
- add: (t, i, j, k, x ) => t[i + j * u + k * uv] += x
|
|
|
+ add: (t, i, j, k, x ) => t[i + j * u + k * uv] += x,
|
|
|
+ dataOffset: (i, j, k) => i + j * u + k * uv
|
|
|
};
|
|
|
}
|
|
|
if (ao[0] === 0 && ao[1] === 2 && ao[2] === 1) { // 021 ikj
|
|
@@ -137,7 +143,8 @@ export namespace Tensor {
|
|
|
return {
|
|
|
get: (t, i, j, k) => t[i + k * u + j * uv],
|
|
|
set: (t, i, j, k, x ) => t[i + k * u + j * uv] = x,
|
|
|
- add: (t, i, j, k, x ) => t[i + k * u + j * uv] += x
|
|
|
+ add: (t, i, j, k, x ) => t[i + k * u + j * uv] += x,
|
|
|
+ dataOffset: (i, j, k) => i + k * u + j * uv
|
|
|
};
|
|
|
}
|
|
|
if (ao[0] === 1 && ao[1] === 0 && ao[2] === 2) { // 102 jik
|
|
@@ -145,7 +152,8 @@ export namespace Tensor {
|
|
|
return {
|
|
|
get: (t, i, j, k) => t[j + i * u + k * uv],
|
|
|
set: (t, i, j, k, x ) => t[j + i * u + k * uv] = x,
|
|
|
- add: (t, i, j, k, x ) => t[j + i * u + k * uv] += x
|
|
|
+ add: (t, i, j, k, x ) => t[j + i * u + k * uv] += x,
|
|
|
+ dataOffset: (i, j, k) => j + i * u + k * uv
|
|
|
};
|
|
|
}
|
|
|
if (ao[0] === 1 && ao[1] === 2 && ao[2] === 0) { // 120 jki
|
|
@@ -153,7 +161,8 @@ export namespace Tensor {
|
|
|
return {
|
|
|
get: (t, i, j, k) => t[j + k * u + i * uv],
|
|
|
set: (t, i, j, k, x ) => t[j + k * u + i * uv] = x,
|
|
|
- add: (t, i, j, k, x ) => t[j + k * u + i * uv] += x
|
|
|
+ add: (t, i, j, k, x ) => t[j + k * u + i * uv] += x,
|
|
|
+ dataOffset: (i, j, k) => j + k * u + i * uv
|
|
|
};
|
|
|
}
|
|
|
if (ao[0] === 2 && ao[1] === 0 && ao[2] === 1) { // 201 kij
|
|
@@ -161,7 +170,8 @@ export namespace Tensor {
|
|
|
return {
|
|
|
get: (t, i, j, k) => t[k + i * u + j * uv],
|
|
|
set: (t, i, j, k, x ) => t[k + i * u + j * uv] = x,
|
|
|
- add: (t, i, j, k, x ) => t[k + i * u + j * uv] += x
|
|
|
+ add: (t, i, j, k, x ) => t[k + i * u + j * uv] += x,
|
|
|
+ dataOffset: (i, j, k) => k + i * u + j * uv
|
|
|
};
|
|
|
}
|
|
|
if (ao[0] === 2 && ao[1] === 1 && ao[2] === 0) { // 210 kji
|
|
@@ -169,7 +179,8 @@ export namespace Tensor {
|
|
|
return {
|
|
|
get: (t, i, j, k) => t[k + j * u + i * uv],
|
|
|
set: (t, i, j, k, x ) => t[k + j * u + i * uv] = x,
|
|
|
- add: (t, i, j, k, x ) => t[k + j * u + i * uv] += x
|
|
|
+ add: (t, i, j, k, x ) => t[k + j * u + i * uv] += x,
|
|
|
+ dataOffset: (i, j, k) => k + j * u + i * uv
|
|
|
};
|
|
|
}
|
|
|
throw new Error('bad axis order');
|
|
@@ -177,7 +188,8 @@ export namespace Tensor {
|
|
|
default: return {
|
|
|
get: (t, ...c) => t[dataOffset(layout, c)],
|
|
|
set: (t, ...c) => t[dataOffset(layout, c)] = c[c.length - 1],
|
|
|
- add: (t, ...c) => t[dataOffset(layout, c)] += c[c.length - 1]
|
|
|
+ add: (t, ...c) => t[dataOffset(layout, c)] += c[c.length - 1],
|
|
|
+ dataOffset: (...c) => dataOffset(layout, c)
|
|
|
};
|
|
|
}
|
|
|
}
|