tensor.ts 9.9 KB


  1. /**
  2. * Copyright (c) 2017 mol* contributors, licensed under MIT, See LICENSE file for more info.
  3. *
  4. * @author David Sehnal <david.sehnal@gmail.com>
  5. */
  6. import { Mat4, Vec3, Vec4, Mat3 } from './3d'
  /** A tensor value: flat element storage paired with the space that describes how to index it. */
  export interface Tensor {
      data: Tensor.Data;
      space: Tensor.Space;
  }
  export namespace Tensor {
      /** Constructor of a numeric array-like buffer of a given length (e.g. `Float64Array`, `Int32Array`). */
      export type ArrayCtor = { new (size: number): ArrayLike<number> }

      /**
       * Flat element storage of a tensor.
       * The `'@type'` member is a compile-time tag only: values are obtained by casting
       * (see `creator` and `Data1`), so at runtime this may be a typed array or a plain array.
       */
      export interface Data extends Array<number> { '@type': 'tensor' }

      /** Shape and memory layout of a tensor, plus element accessors bound to that layout. */
      export interface Space {
          /** Number of axes (the length of `dimensions`). */
          readonly rank: number,
          /** Extent of each axis, indexed by axis number. */
          readonly dimensions: ReadonlyArray<number>,
          /** Axis numbers listed from the slowest- to the fastest-varying axis of the flat data. */
          readonly axisOrderSlowToFast: ReadonlyArray<number>,
          /** Allocates storage for all elements, using `array` if given, otherwise the space's default constructor. */
          create(array?: ArrayCtor): Tensor.Data,
          /** Reads the element at the given coordinates (one coordinate per axis). */
          get(data: Tensor.Data, ...coords: number[]): number
          /** Stores the last argument at the coordinates given by the preceding arguments; returns the last argument. */
          set(data: Tensor.Data, ...coordsAndValue: number[]): number
          /** Adds the last argument to the element at the preceding coordinates; returns the result of the addition. */
          add(data: Tensor.Data, ...coordsAndValue: number[]): number
      }

      /** Precomputed indexing information shared by the accessors of a `Space`. */
      interface Layout {
          dimensions: number[],
          axisOrderSlowToFast: number[],
          /** `axisOrderSlowToFast` reversed. */
          axisOrderFastToSlow: number[],
          /**
           * `accessDimensions[0] = 1` and `accessDimensions[i]` is the extent of the (i-1)-th axis
           * in fast-to-slow order; consumed by `dataOffset`.
           */
          accessDimensions: number[],
          /** Used by `create` when no constructor is passed; Float64Array if not specified. */
          defaultCtor: ArrayCtor
      }

      function Layout(dimensions: number[], axisOrderSlowToFast: number[], ctor?: ArrayCtor): Layout {
          // Offsets are accumulated from the fastest axis outwards, so keep the reversed order too.
          const axisOrderFastToSlow = invertAxisOrder(axisOrderSlowToFast);
          const accessDimensions = [1];
          for (let i = 1; i < dimensions.length; i++) accessDimensions[i] = dimensions[axisOrderFastToSlow[i - 1]];
          return { dimensions, axisOrderFastToSlow, axisOrderSlowToFast, accessDimensions, defaultCtor: ctor || Float64Array };
      }

      /** Pairs existing data with the space that describes it (no copy, no validation). */
      export function create(space: Space, data: Data): Tensor { return { space, data }; }

      /**
       * Builds a tensor space.
       * @param dimensions extent of each axis
       * @param axisOrderSlowToFast axis numbers from slowest- to fastest-varying in memory
       * @param ctor array constructor used by `create()` when none is passed (defaults to Float64Array)
       * @throws Error('bad axis order') for rank 2 or 3 when the axis order is not a permutation of the axes
       */
      export function Space(dimensions: number[], axisOrderSlowToFast: number[], ctor?: ArrayCtor): Space {
          const layout = Layout(dimensions, axisOrderSlowToFast, ctor);
          const { get, set, add } = accessors(layout);
          return { rank: dimensions.length, dimensions, axisOrderSlowToFast, create: creator(layout), get, set, add };
      }

      /** Reinterprets an existing numeric array as tensor data without copying it. */
      export function Data1(values: ArrayLike<number>): Data { return values as Data; }

      /** Rank-1 space with `d` elements. */
      export function Vector(d: number, ctor?: ArrayCtor) { return Space([d], [0], ctor); }
      /** `rows` x `cols` matrix stored column by column (the row index varies fastest). */
      export function ColumnMajorMatrix(rows: number, cols: number, ctor?: ArrayCtor) { return Space([rows, cols], [1, 0], ctor); }
      /** `rows` x `cols` matrix stored row by row (the column index varies fastest). */
      export function RowMajorMatrix(rows: number, cols: number, ctor?: ArrayCtor) { return Space([rows, cols], [0, 1], ctor); }

      /**
       * Copies the top-left (at most 4x4) elements of a rank-2 tensor into `out`.
       * Entries of `out` beyond the tensor's extent are left untouched.
       * @throws Error if the space is not rank 2
       */
      export function toMat4(out: Mat4, space: Space, data: Tensor.Data): Mat4 {
          if (space.rank !== 2) throw new Error('Invalid tensor rank');
          const d0 = Math.min(4, space.dimensions[0]), d1 = Math.min(4, space.dimensions[1]);
          for (let i = 0; i < d0; i++) {
              for (let j = 0; j < d1; j++) Mat4.setValue(out, i, j, space.get(data, i, j));
          }
          return out;
      }

      /**
       * Copies the top-left (at most 3x3) elements of a rank-2 tensor into `out`.
       * Entries of `out` beyond the tensor's extent are left untouched.
       * @throws Error if the space is not rank 2
       */
      export function toMat3(out: Mat3, space: Space, data: Tensor.Data): Mat3 {
          if (space.rank !== 2) throw new Error('Invalid tensor rank');
          const d0 = Math.min(3, space.dimensions[0]), d1 = Math.min(3, space.dimensions[1]);
          for (let i = 0; i < d0; i++) {
              for (let j = 0; j < d1; j++) Mat3.setValue(out, i, j, space.get(data, i, j));
          }
          return out;
      }

      /**
       * Copies the first (at most 3) elements of a rank-1 tensor into `out`;
       * remaining components of `out` are left untouched.
       * @throws Error if the space is not rank 1
       */
      export function toVec3(out: Vec3, space: Space, data: Tensor.Data): Vec3 {
          if (space.rank !== 1) throw new Error('Invalid tensor rank');
          const d0 = Math.min(3, space.dimensions[0]);
          for (let i = 0; i < d0; i++) out[i] = data[i];
          return out;
      }

      /**
       * Copies the first (at most 4) elements of a rank-1 tensor into `out`;
       * remaining components of `out` are left untouched.
       * @throws Error if the space is not rank 1
       */
      export function toVec4(out: Vec4, space: Space, data: Tensor.Data): Vec4 {
          if (space.rank !== 1) throw new Error('Invalid tensor rank');
          const d0 = Math.min(4, space.dimensions[0]);
          for (let i = 0; i < d0; i++) out[i] = data[i];
          return out;
      }

      /** True when both buffers have the same length and are element-wise `===` (so NaN never matches). */
      export function areEqualExact(a: Tensor.Data, b: Tensor.Data) {
          const len = a.length;
          if (len !== b.length) return false;
          for (let i = 0; i < len; i++) if (a[i] !== b[i]) return false;
          return true;
      }

      /**
       * Builds get/set/add closures for a layout. Ranks 1-3 use hand-specialized offset
       * formulas; higher ranks fall back to the generic `dataOffset`.
       * @throws Error('bad axis order') for rank 2 or 3 when the axis order is not a permutation
       */
      function accessors(layout: Layout): Pick<Space, 'get' | 'set' | 'add'> {
          const { dimensions, axisOrderFastToSlow: ao } = layout;
          switch (dimensions.length) {
              case 1: return {
                  get: (t, d) => t[d],
                  set: (t, d, x) => t[d] = x,
                  add: (t, d, x) => t[d] += x
              };
              case 2: {
                  // ao = [0, 1]: row index fastest, i.e. column-major storage.
                  if (ao[0] === 0 && ao[1] === 1) {
                      const rows = dimensions[0];
                      return {
                          get: (t, i, j) => t[j * rows + i],
                          set: (t, i, j, x) => t[j * rows + i] = x,
                          add: (t, i, j, x) => t[j * rows + i] += x
                      };
                  }
                  // ao = [1, 0]: column index fastest, i.e. row-major storage.
                  if (ao[0] === 1 && ao[1] === 0) {
                      const cols = dimensions[1];
                      return {
                          get: (t, i, j) => t[i * cols + j],
                          set: (t, i, j, x) => t[i * cols + j] = x,
                          add: (t, i, j, x) => t[i * cols + j] += x
                      };
                  }
                  throw new Error('bad axis order');
              }
              case 3: {
                  // One branch per permutation; the tag is the fast-to-slow axis order
                  // (e.g. `021 ikj`: i varies fastest, then k, then j).
                  if (ao[0] === 0 && ao[1] === 1 && ao[2] === 2) { // 012 ijk
                      const u = dimensions[0], v = dimensions[1], uv = u * v;
                      return {
                          get: (t, i, j, k) => t[i + j * u + k * uv],
                          set: (t, i, j, k, x) => t[i + j * u + k * uv] = x,
                          add: (t, i, j, k, x) => t[i + j * u + k * uv] += x
                      };
                  }
                  if (ao[0] === 0 && ao[1] === 2 && ao[2] === 1) { // 021 ikj
                      const u = dimensions[0], v = dimensions[2], uv = u * v;
                      return {
                          get: (t, i, j, k) => t[i + k * u + j * uv],
                          set: (t, i, j, k, x) => t[i + k * u + j * uv] = x,
                          add: (t, i, j, k, x) => t[i + k * u + j * uv] += x
                      };
                  }
                  if (ao[0] === 1 && ao[1] === 0 && ao[2] === 2) { // 102 jik
                      const u = dimensions[1], v = dimensions[0], uv = u * v;
                      return {
                          get: (t, i, j, k) => t[j + i * u + k * uv],
                          set: (t, i, j, k, x) => t[j + i * u + k * uv] = x,
                          add: (t, i, j, k, x) => t[j + i * u + k * uv] += x
                      };
                  }
                  if (ao[0] === 1 && ao[1] === 2 && ao[2] === 0) { // 120 jki
                      const u = dimensions[1], v = dimensions[2], uv = u * v;
                      return {
                          get: (t, i, j, k) => t[j + k * u + i * uv],
                          set: (t, i, j, k, x) => t[j + k * u + i * uv] = x,
                          add: (t, i, j, k, x) => t[j + k * u + i * uv] += x
                      };
                  }
                  if (ao[0] === 2 && ao[1] === 0 && ao[2] === 1) { // 201 kij
                      const u = dimensions[2], v = dimensions[0], uv = u * v;
                      return {
                          get: (t, i, j, k) => t[k + i * u + j * uv],
                          set: (t, i, j, k, x) => t[k + i * u + j * uv] = x,
                          add: (t, i, j, k, x) => t[k + i * u + j * uv] += x
                      };
                  }
                  if (ao[0] === 2 && ao[1] === 1 && ao[2] === 0) { // 210 kji
                      const u = dimensions[2], v = dimensions[1], uv = u * v;
                      return {
                          get: (t, i, j, k) => t[k + j * u + i * uv],
                          set: (t, i, j, k, x) => t[k + j * u + i * uv] = x,
                          add: (t, i, j, k, x) => t[k + j * u + i * uv] += x
                      };
                  }
                  throw new Error('bad axis order');
              }
              default: return {
                  // The trailing value passed to set/add is not read by dataOffset.
                  get: (t, ...c) => t[dataOffset(layout, c)],
                  set: (t, ...c) => t[dataOffset(layout, c)] = c[c.length - 1],
                  add: (t, ...c) => t[dataOffset(layout, c)] += c[c.length - 1]
              };
          }
      }

      /** Returns a `create` function that allocates one slot per element (the product of all dimensions). */
      function creator(layout: Layout): Space['create'] {
          const { dimensions: ds } = layout;
          let size = 1;
          for (let i = 0, _i = ds.length; i < _i; i++) size *= ds[i];
          return ctor => new (ctor || layout.defaultCtor)(size) as Tensor.Data;
      }

      /**
       * Flat index of `coord` under `layout`, i.e.
       * `sum_k coord[ao[k]] * prod_{m<k} dimensions[ao[m]]` with `ao` the fast-to-slow axis order,
       * evaluated Horner-style starting from the slowest axis.
       * Only `coord[ao[0..rank-1]]` are read, so extra trailing entries are ignored.
       */
      function dataOffset(layout: Layout, coord: number[]) {
          const { accessDimensions: acc, axisOrderFastToSlow: ao } = layout;
          const d = acc.length - 1;
          let o = acc[d] * coord[ao[d]];
          for (let i = d - 1; i >= 0; i--) {
              o = (o + coord[ao[i]]) * acc[i];
          }
          return o;
      }

      /** Converts a "slow to fast" axis order to "fast to slow" and vice versa (returns a reversed copy). */
      export function invertAxisOrder(v: number[]) {
          const ret: number[] = [];
          for (let i = 0; i < v.length; i++) {
              ret[i] = v[v.length - i - 1];
          }
          return ret;
      }

      /** Returns a new array with `result[i] = xs[indices[i]]` for every index of `xs`. */
      function reorder(xs: number[], indices: ArrayLike<number>) {
          const ret: number[] = [];
          for (let i = 0; i < xs.length; i++) ret[i] = xs[indices[i]];
          return ret;
      }

      /**
       * For an axis order listed fast-to-slow, returns a function that places per-axis values
       * into canonical axis positions: `result[order[i]] = xs[i]`.
       */
      export function convertToCanonicalAxisIndicesFastToSlow(order: number[]) {
          const indices = new Int32Array(order.length);
          for (let i = 0; i < order.length; i++) indices[order[i]] = i;
          return (xs: number[]) => reorder(xs, indices);
      }

      /**
       * Same as `convertToCanonicalAxisIndicesFastToSlow(invertAxisOrder(order))`: `order` is given
       * slow-to-fast, while `xs[i]` is still the value of the i-th fastest axis.
       */
      export function convertToCanonicalAxisIndicesSlowToFast(order: number[]) {
          const indices = new Int32Array(order.length);
          for (let i = 0; i < order.length; i++) indices[order[order.length - i - 1]] = i;
          return (xs: number[]) => reorder(xs, indices);
      }
  }