color-smoothing.ts 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
  1. /**
  2. * Copyright (c) 2021 mol* contributors, licensed under MIT, See LICENSE file for more info.
  3. *
  4. * @author Alexander Rose <alexander.rose@weirdbyte.de>
  5. */
  6. import { ValueCell } from '../../../mol-util';
  7. import { createComputeRenderable, ComputeRenderable } from '../../../mol-gl/renderable';
  8. import { WebGLContext } from '../../../mol-gl/webgl/context';
  9. import { Texture } from '../../../mol-gl/webgl/texture';
  10. import { ShaderCode } from '../../../mol-gl/shader-code';
  11. import { createComputeRenderItem } from '../../../mol-gl/webgl/render-item';
  12. import { ValueSpec, AttributeSpec, UniformSpec, TextureSpec, Values, DefineSpec } from '../../../mol-gl/renderable/schema';
  13. import { quad_vert } from '../../../mol-gl/shader/quad.vert';
  14. import { normalize_frag } from '../../../mol-gl/shader/compute/color-smoothing/normalize.frag';
  15. import { QuadSchema, QuadValues } from '../../../mol-gl/compute/util';
  16. import { Vec2, Vec3, Vec4 } from '../../../mol-math/linear-algebra';
  17. import { Box3D, Sphere3D } from '../../../mol-math/geometry';
  18. import { accumulate_frag } from '../../../mol-gl/shader/compute/color-smoothing/accumulate.frag';
  19. import { accumulate_vert } from '../../../mol-gl/shader/compute/color-smoothing/accumulate.vert';
  20. import { TextureImage } from '../../../mol-gl/renderable/util';
  21. export const ColorAccumulateSchema = {
  22. drawCount: ValueSpec('number'),
  23. instanceCount: ValueSpec('number'),
  24. uTotalCount: UniformSpec('i'),
  25. uInstanceCount: UniformSpec('i'),
  26. uGroupCount: UniformSpec('i'),
  27. aTransform: AttributeSpec('float32', 16, 1),
  28. aInstance: AttributeSpec('float32', 1, 1),
  29. aSample: AttributeSpec('float32', 1, 0),
  30. uGeoTexDim: UniformSpec('v2', 'buffered'),
  31. tPosition: TextureSpec('texture', 'rgba', 'float', 'nearest'),
  32. tGroup: TextureSpec('texture', 'rgba', 'float', 'nearest'),
  33. uColorTexDim: UniformSpec('v2'),
  34. tColor: TextureSpec('image-uint8', 'rgb', 'ubyte', 'nearest'),
  35. dColorType: DefineSpec('string', ['group', 'groupInstance', 'vertex', 'vertexInstance']),
  36. uCurrentSlice: UniformSpec('f'),
  37. uCurrentX: UniformSpec('f'),
  38. uCurrentY: UniformSpec('f'),
  39. uBboxMin: UniformSpec('v3', 'material'),
  40. uBboxSize: UniformSpec('v3', 'material'),
  41. uResolution: UniformSpec('f', 'material'),
  42. };
  43. type ColorAccumulateValues = Values<typeof ColorAccumulateSchema>
  44. const ColorAccumulateName = 'color-accumulate';
  45. interface AccumulateInput {
  46. vertexCount: number
  47. instanceCount: number
  48. groupCount: number
  49. transformBuffer: Float32Array
  50. instanceBuffer: Float32Array
  51. positionTexture: Texture
  52. groupTexture: Texture
  53. colorData: TextureImage<Uint8Array>
  54. colorType: 'group' | 'groupInstance'
  55. }
  56. function getSampleBuffer(sampleCount: number, stride: number) {
  57. const sampleBuffer = new Float32Array(sampleCount);
  58. for (let i = 0; i < sampleCount; ++i) {
  59. sampleBuffer[i] = i * stride;
  60. }
  61. return sampleBuffer;
  62. }
  63. function getAccumulateRenderable(ctx: WebGLContext, input: AccumulateInput, box: Box3D, resolution: number, stride: number): ComputeRenderable<ColorAccumulateValues> {
  64. if (ctx.namedComputeRenderables[ColorAccumulateName]) {
  65. const extent = Vec3.sub(Vec3(), box.max, box.min);
  66. const v = ctx.namedComputeRenderables[ColorAccumulateName].values as ColorAccumulateValues;
  67. const sampleCount = input.vertexCount / stride;
  68. if (sampleCount > v.drawCount.ref.value) {
  69. ValueCell.update(v.aSample, getSampleBuffer(sampleCount, stride));
  70. }
  71. ValueCell.updateIfChanged(v.drawCount, sampleCount);
  72. ValueCell.updateIfChanged(v.instanceCount, input.instanceCount);
  73. ValueCell.updateIfChanged(v.uTotalCount, input.vertexCount);
  74. ValueCell.updateIfChanged(v.uInstanceCount, input.instanceCount);
  75. ValueCell.updateIfChanged(v.uGroupCount, input.groupCount);
  76. ValueCell.update(v.aTransform, input.transformBuffer);
  77. ValueCell.update(v.aInstance, input.instanceBuffer);
  78. ValueCell.update(v.uGeoTexDim, Vec2.set(v.uGeoTexDim.ref.value, input.positionTexture.getWidth(), input.positionTexture.getHeight()));
  79. ValueCell.update(v.tPosition, input.positionTexture);
  80. ValueCell.update(v.tGroup, input.groupTexture);
  81. ValueCell.update(v.uColorTexDim, Vec2.set(v.uColorTexDim.ref.value, input.colorData.width, input.colorData.height));
  82. ValueCell.update(v.tColor, input.colorData);
  83. ValueCell.updateIfChanged(v.dColorType, input.colorType);
  84. ValueCell.updateIfChanged(v.uCurrentSlice, 0);
  85. ValueCell.updateIfChanged(v.uCurrentX, 0);
  86. ValueCell.updateIfChanged(v.uCurrentY, 0);
  87. ValueCell.update(v.uBboxMin, box.min);
  88. ValueCell.update(v.uBboxSize, extent);
  89. ValueCell.updateIfChanged(v.uResolution, resolution);
  90. ctx.namedComputeRenderables[ColorAccumulateName].update();
  91. } else {
  92. ctx.namedComputeRenderables[ColorAccumulateName] = createAccumulateRenderable(ctx, input, box, resolution, stride);
  93. }
  94. return ctx.namedComputeRenderables[ColorAccumulateName];
  95. }
  96. function createAccumulateRenderable(ctx: WebGLContext, input: AccumulateInput, box: Box3D, resolution: number, stride: number) {
  97. const extent = Vec3.sub(Vec3(), box.max, box.min);
  98. const sampleCount = input.vertexCount / stride;
  99. const values: ColorAccumulateValues = {
  100. drawCount: ValueCell.create(sampleCount),
  101. instanceCount: ValueCell.create(input.instanceCount),
  102. uTotalCount: ValueCell.create(input.vertexCount),
  103. uInstanceCount: ValueCell.create(input.instanceCount),
  104. uGroupCount: ValueCell.create(input.groupCount),
  105. aTransform: ValueCell.create(input.transformBuffer),
  106. aInstance: ValueCell.create(input.instanceBuffer),
  107. aSample: ValueCell.create(getSampleBuffer(sampleCount, stride)),
  108. uGeoTexDim: ValueCell.create(Vec2.create(input.positionTexture.getWidth(), input.positionTexture.getHeight())),
  109. tPosition: ValueCell.create(input.positionTexture),
  110. tGroup: ValueCell.create(input.groupTexture),
  111. uColorTexDim: ValueCell.create(Vec2.create(input.colorData.width, input.colorData.height)),
  112. tColor: ValueCell.create(input.colorData),
  113. dColorType: ValueCell.create(input.colorType),
  114. uCurrentSlice: ValueCell.create(0),
  115. uCurrentX: ValueCell.create(0),
  116. uCurrentY: ValueCell.create(0),
  117. uBboxMin: ValueCell.create(box.min),
  118. uBboxSize: ValueCell.create(extent),
  119. uResolution: ValueCell.create(resolution),
  120. };
  121. const schema = { ...ColorAccumulateSchema };
  122. const shaderCode = ShaderCode('accumulate', accumulate_vert, accumulate_frag);
  123. const renderItem = createComputeRenderItem(ctx, 'points', shaderCode, schema, values);
  124. return createComputeRenderable(renderItem, values);
  125. }
  126. function setAccumulateDefaults(ctx: WebGLContext) {
  127. const { gl, state } = ctx;
  128. state.disable(gl.CULL_FACE);
  129. state.enable(gl.BLEND);
  130. state.disable(gl.DEPTH_TEST);
  131. state.enable(gl.SCISSOR_TEST);
  132. state.depthMask(false);
  133. state.clearColor(0, 0, 0, 0);
  134. state.blendFunc(gl.ONE, gl.ONE);
  135. state.blendEquation(gl.FUNC_ADD);
  136. }
  137. //
  138. export const ColorNormalizeSchema = {
  139. ...QuadSchema,
  140. tColor: TextureSpec('texture', 'rgba', 'float', 'nearest'),
  141. uTexSize: UniformSpec('v2'),
  142. };
  143. type ColorNormalizeValues = Values<typeof ColorNormalizeSchema>
  144. const ColorNormalizeName = 'color-normalize';
  145. function getNormalizeRenderable(ctx: WebGLContext, color: Texture): ComputeRenderable<ColorNormalizeValues> {
  146. if (ctx.namedComputeRenderables[ColorNormalizeName]) {
  147. const v = ctx.namedComputeRenderables[ColorNormalizeName].values as ColorNormalizeValues;
  148. ValueCell.update(v.tColor, color);
  149. ValueCell.update(v.uTexSize, Vec2.set(v.uTexSize.ref.value, color.getWidth(), color.getHeight()));
  150. ctx.namedComputeRenderables[ColorNormalizeName].update();
  151. } else {
  152. ctx.namedComputeRenderables[ColorNormalizeName] = createColorNormalizeRenderable(ctx, color);
  153. }
  154. return ctx.namedComputeRenderables[ColorNormalizeName];
  155. }
  156. function createColorNormalizeRenderable(ctx: WebGLContext, color: Texture) {
  157. const values: ColorNormalizeValues = {
  158. ...QuadValues,
  159. tColor: ValueCell.create(color),
  160. uTexSize: ValueCell.create(Vec2.create(color.getWidth(), color.getHeight())),
  161. };
  162. const schema = { ...ColorNormalizeSchema };
  163. const shaderCode = ShaderCode('normalize', quad_vert, normalize_frag);
  164. const renderItem = createComputeRenderItem(ctx, 'triangles', shaderCode, schema, values);
  165. return createComputeRenderable(renderItem, values);
  166. }
  167. function setNormalizeDefaults(ctx: WebGLContext) {
  168. const { gl, state } = ctx;
  169. state.disable(gl.CULL_FACE);
  170. state.enable(gl.BLEND);
  171. state.disable(gl.DEPTH_TEST);
  172. state.enable(gl.SCISSOR_TEST);
  173. state.depthMask(false);
  174. state.clearColor(0, 0, 0, 0);
  175. state.blendFunc(gl.ONE, gl.ONE);
  176. state.blendEquation(gl.FUNC_ADD);
  177. }
  178. //
  179. function getTexture2dSize(gridDim: Vec3) {
  180. const area = gridDim[0] * gridDim[1] * gridDim[2];
  181. const squareDim = Math.sqrt(area);
  182. const powerOfTwoSize = Math.pow(2, Math.ceil(Math.log(squareDim) / Math.log(2)));
  183. let texDimX = 0;
  184. let texDimY = gridDim[1];
  185. let texRows = 1;
  186. let texCols = gridDim[2];
  187. if (powerOfTwoSize < gridDim[0] * gridDim[2]) {
  188. texCols = Math.floor(powerOfTwoSize / gridDim[0]);
  189. texRows = Math.ceil(gridDim[2] / texCols);
  190. texDimX = texCols * gridDim[0];
  191. texDimY *= texRows;
  192. } else {
  193. texDimX = gridDim[0] * gridDim[2];
  194. }
  195. // console.log(texDimX, texDimY, texDimY < powerOfTwoSize ? powerOfTwoSize : powerOfTwoSize * 2);
  196. return { texDimX, texDimY, texRows, texCols, powerOfTwoSize: texDimY < powerOfTwoSize ? powerOfTwoSize : powerOfTwoSize * 2 };
  197. }
  198. interface ColorSmoothingInput extends AccumulateInput {
  199. boundingSphere: Sphere3D
  200. invariantBoundingSphere: Sphere3D
  201. }
  202. export function calcTextureMeshColorSmoothing(webgl: WebGLContext, input: ColorSmoothingInput, resolution: number, stride: number, texture?: Texture) {
  203. const { gl, resources, state, extensions: { colorBufferHalfFloat, textureHalfFloat } } = webgl;
  204. const isInstanceType = input.colorType.endsWith('Instance');
  205. const box = Box3D.fromSphere3D(Box3D(), isInstanceType ? input.boundingSphere : input.invariantBoundingSphere);
  206. const scaleFactor = 1 / resolution;
  207. const scaledBox = Box3D.scale(Box3D(), box, scaleFactor);
  208. const dim = Box3D.size(Vec3(), scaledBox);
  209. Vec3.ceil(dim, dim);
  210. Vec3.add(dim, dim, Vec3.create(2, 2, 2));
  211. const { min } = box;
  212. const [ dx, dy, dz ] = dim;
  213. const { texDimX: width, texDimY: height, texCols } = getTexture2dSize(dim);
  214. // console.log({ width, height, texCols, dim, resolution });
  215. if (!webgl.namedTextures[ColorAccumulateName]) {
  216. webgl.namedTextures[ColorAccumulateName] = colorBufferHalfFloat && textureHalfFloat
  217. ? resources.texture('image-float16', 'rgba', 'fp16', 'nearest')
  218. : resources.texture('image-float32', 'rgba', 'float', 'nearest');
  219. }
  220. const accumulateTexture = webgl.namedTextures[ColorAccumulateName];
  221. accumulateTexture.define(width, height);
  222. const accumulateRenderable = getAccumulateRenderable(webgl, input, box, resolution, stride);
  223. //
  224. const { uCurrentSlice, uCurrentX, uCurrentY } = accumulateRenderable.values;
  225. if (!webgl.namedFramebuffers[ColorAccumulateName]) {
  226. webgl.namedFramebuffers[ColorAccumulateName] = webgl.resources.framebuffer();
  227. }
  228. const framebuffer = webgl.namedFramebuffers[ColorAccumulateName];
  229. framebuffer.bind();
  230. setAccumulateDefaults(webgl);
  231. state.currentRenderItemId = -1;
  232. accumulateTexture.attachFramebuffer(framebuffer, 0);
  233. gl.viewport(0, 0, width, height);
  234. gl.scissor(0, 0, width, height);
  235. gl.clear(gl.COLOR_BUFFER_BIT);
  236. ValueCell.update(uCurrentY, 0);
  237. let currCol = 0;
  238. let currY = 0;
  239. let currX = 0;
  240. for (let i = 0; i < dz; ++i) {
  241. if (currCol >= texCols) {
  242. currCol -= texCols;
  243. currY += dy;
  244. currX = 0;
  245. ValueCell.update(uCurrentY, currY);
  246. }
  247. // console.log({ i, currX, currY });
  248. ValueCell.update(uCurrentX, currX);
  249. ValueCell.update(uCurrentSlice, i);
  250. gl.viewport(currX, currY, dx, dy);
  251. gl.scissor(currX, currY, dx, dy);
  252. accumulateRenderable.render();
  253. ++currCol;
  254. currX += dx;
  255. }
  256. // const accImage = new Float32Array(width * height * 4);
  257. // accumulateTexture.attachFramebuffer(framebuffer, 0);
  258. // webgl.readPixels(0, 0, width, height, accImage);
  259. // console.log(accImage);
  260. // printTextureImage({ array: accImage, width, height }, 1 / 4);
  261. // normalize
  262. if (!texture) texture = resources.texture('image-uint8', 'rgb', 'ubyte', 'linear');
  263. texture.define(width, height);
  264. const normalizeRenderable = getNormalizeRenderable(webgl, accumulateTexture);
  265. setNormalizeDefaults(webgl);
  266. state.currentRenderItemId = -1;
  267. texture.attachFramebuffer(framebuffer, 0);
  268. gl.viewport(0, 0, width, height);
  269. gl.scissor(0, 0, width, height);
  270. gl.clear(gl.COLOR_BUFFER_BIT);
  271. normalizeRenderable.render();
  272. // const normImage = new Uint8Array(width * height * 4);
  273. // texture.attachFramebuffer(framebuffer, 0);
  274. // webgl.readPixels(0, 0, width, height, normImage);
  275. // console.log(normImage);
  276. // printTextureImage({ array: normImage, width, height }, 1 / 4);
  277. const transform = Vec4.create(min[0], min[1], min[2], scaleFactor);
  278. const type = isInstanceType ? 'volumeInstance' : 'volume';
  279. return { texture, gridDim: dim, gridTexDim: Vec2.create(width, height), transform, type };
  280. }