|
@@ -6,18 +6,20 @@
|
|
|
|
|
|
import { Mat4, Vec3, Vec4 } from './3d'
|
|
import { Mat4, Vec3, Vec4 } from './3d'
|
|
|
|
|
|
-export interface Tensor extends Array<number> { '@type': 'tensor' }
|
|
|
|
|
|
+export interface Tensor { data: Tensor.Data, space: Tensor.Space }
|
|
|
|
|
|
export namespace Tensor {
|
|
export namespace Tensor {
|
|
export type ArrayCtor = { new (size: number): ArrayLike<number> }
|
|
export type ArrayCtor = { new (size: number): ArrayLike<number> }
|
|
|
|
|
|
|
|
+ export interface Data extends Array<number> { '@type': 'tensor' }
|
|
|
|
+
|
|
export interface Space {
|
|
export interface Space {
|
|
readonly rank: number,
|
|
readonly rank: number,
|
|
readonly dimensions: ReadonlyArray<number>,
|
|
readonly dimensions: ReadonlyArray<number>,
|
|
readonly axisOrderSlowToFast: ReadonlyArray<number>,
|
|
readonly axisOrderSlowToFast: ReadonlyArray<number>,
|
|
- create(array?: ArrayCtor): Tensor,
|
|
|
|
- get(data: Tensor, ...coords: number[]): number
|
|
|
|
- set(data: Tensor, ...coordsAndValue: number[]): number
|
|
|
|
|
|
+ create(array?: ArrayCtor): Tensor.Data,
|
|
|
|
+ get(data: Tensor.Data, ...coords: number[]): number
|
|
|
|
+ set(data: Tensor.Data, ...coordsAndValue: number[]): number
|
|
}
|
|
}
|
|
|
|
|
|
interface Layout {
|
|
interface Layout {
|
|
@@ -39,6 +41,8 @@ export namespace Tensor {
|
|
return { dimensions, axisOrderFastToSlow, axisOrderSlowToFast, accessDimensions, defaultCtor: ctor || Float64Array }
|
|
return { dimensions, axisOrderFastToSlow, axisOrderSlowToFast, accessDimensions, defaultCtor: ctor || Float64Array }
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ export function create(space: Space, data: Data): Tensor { return { space, data }; }
|
|
|
|
+
|
|
export function Space(dimensions: number[], axisOrderSlowToFast: number[], ctor?: ArrayCtor): Space {
|
|
export function Space(dimensions: number[], axisOrderSlowToFast: number[], ctor?: ArrayCtor): Space {
|
|
const layout = Layout(dimensions, axisOrderSlowToFast, ctor);
|
|
const layout = Layout(dimensions, axisOrderSlowToFast, ctor);
|
|
const { get, set } = accessors(layout);
|
|
const { get, set } = accessors(layout);
|
|
@@ -49,7 +53,7 @@ export namespace Tensor {
|
|
export function ColumnMajorMatrix(rows: number, cols: number, ctor?: ArrayCtor) { return Space([rows, cols], [1, 0], ctor); }
|
|
export function ColumnMajorMatrix(rows: number, cols: number, ctor?: ArrayCtor) { return Space([rows, cols], [1, 0], ctor); }
|
|
export function RowMajorMatrix(rows: number, cols: number, ctor?: ArrayCtor) { return Space([rows, cols], [0, 1], ctor); }
|
|
export function RowMajorMatrix(rows: number, cols: number, ctor?: ArrayCtor) { return Space([rows, cols], [0, 1], ctor); }
|
|
|
|
|
|
- export function toMat4(space: Space, data: Tensor): Mat4 {
|
|
|
|
|
|
+ export function toMat4(space: Space, data: Tensor.Data): Mat4 {
|
|
if (space.rank !== 2) throw new Error('Invalid tensor rank');
|
|
if (space.rank !== 2) throw new Error('Invalid tensor rank');
|
|
const mat = Mat4.zero();
|
|
const mat = Mat4.zero();
|
|
const d0 = Math.min(4, space.dimensions[0]), d1 = Math.min(4, space.dimensions[1]);
|
|
const d0 = Math.min(4, space.dimensions[0]), d1 = Math.min(4, space.dimensions[1]);
|
|
@@ -59,7 +63,7 @@ export namespace Tensor {
|
|
return mat;
|
|
return mat;
|
|
}
|
|
}
|
|
|
|
|
|
- export function toVec3(space: Space, data: Tensor): Vec3 {
|
|
|
|
|
|
+ export function toVec3(space: Space, data: Tensor.Data): Vec3 {
|
|
if (space.rank !== 1) throw new Error('Invalid tensor rank');
|
|
if (space.rank !== 1) throw new Error('Invalid tensor rank');
|
|
const vec = Vec3.zero();
|
|
const vec = Vec3.zero();
|
|
const d0 = Math.min(3, space.dimensions[0]);
|
|
const d0 = Math.min(3, space.dimensions[0]);
|
|
@@ -67,7 +71,7 @@ export namespace Tensor {
|
|
return vec;
|
|
return vec;
|
|
}
|
|
}
|
|
|
|
|
|
- export function toVec4(space: Space, data: Tensor): Vec4 {
|
|
|
|
|
|
+ export function toVec4(space: Space, data: Tensor.Data): Vec4 {
|
|
if (space.rank !== 1) throw new Error('Invalid tensor rank');
|
|
if (space.rank !== 1) throw new Error('Invalid tensor rank');
|
|
const vec = Vec4.zero();
|
|
const vec = Vec4.zero();
|
|
const d0 = Math.min(4, space.dimensions[0]);
|
|
const d0 = Math.min(4, space.dimensions[0]);
|
|
@@ -75,7 +79,7 @@ export namespace Tensor {
|
|
return vec;
|
|
return vec;
|
|
}
|
|
}
|
|
|
|
|
|
- export function areEqualExact(a: Tensor, b: Tensor) {
|
|
|
|
|
|
+ export function areEqualExact(a: Tensor.Data, b: Tensor.Data) {
|
|
const len = a.length;
|
|
const len = a.length;
|
|
if (len !== b.length) return false;
|
|
if (len !== b.length) return false;
|
|
for (let i = 0; i < len; i++) if (a[i] !== b[i]) return false;
|
|
for (let i = 0; i < len; i++) if (a[i] !== b[i]) return false;
|
|
@@ -136,7 +140,7 @@ export namespace Tensor {
|
|
const { dimensions: ds } = layout;
|
|
const { dimensions: ds } = layout;
|
|
let size = 1;
|
|
let size = 1;
|
|
for (let i = 0, _i = ds.length; i < _i; i++) size *= ds[i];
|
|
for (let i = 0, _i = ds.length; i < _i; i++) size *= ds[i];
|
|
- return ctor => new (ctor || layout.defaultCtor)(size) as Tensor;
|
|
|
|
|
|
+ return ctor => new (ctor || layout.defaultCtor)(size) as Tensor.Data;
|
|
}
|
|
}
|
|
|
|
|
|
function dataOffset(layout: Layout, coord: number[]) {
|
|
function dataOffset(layout: Layout, coord: number[]) {
|