Skip to content

Commit 1ca3781

Browse files
authored
Merge pull request #8538 from processing/fix/tutorial-webgpu
Fix WebGPU bugs surfaced by the Intro to Strands tutorial
2 parents 3a7992c + 0d86288 commit 1ca3781

13 files changed

Lines changed: 315 additions & 27 deletions

File tree

src/strands/ir_builders.js

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ export function unaryOpNode(strandsContext, nodeOrValue, opCode) {
4343
const { dag, cfg } = strandsContext;
4444
let dependsOn;
4545
let node;
46-
if (nodeOrValue instanceof StrandsNode) {
46+
if (nodeOrValue?.isStrandsNode) {
4747
node = nodeOrValue;
4848
} else {
4949
const { id, dimension } = primitiveConstructorNode(strandsContext, { baseType: BaseType.FLOAT, dimension: null }, nodeOrValue);
@@ -257,6 +257,20 @@ export function constructTypeFromIDs(strandsContext, typeInfo, strandsNodesArray
257257

258258
export function primitiveConstructorNode(strandsContext, typeInfo, dependsOn) {
259259
const cfg = strandsContext.cfg;
260+
dependsOn = (Array.isArray(dependsOn) ? dependsOn : [dependsOn])
261+
.flat(Infinity)
262+
.map(a => {
263+
if (
264+
a.isStrandsNode &&
265+
a.typeInfo().baseType === BaseType.INT &&
266+
// TODO: handle ivec inputs instead of just int scalars
267+
a.typeInfo().dimension === 1
268+
) {
269+
return castToFloat(strandsContext, a);
270+
} else {
271+
return a;
272+
}
273+
});
260274
const { mappedDependencies, inferredTypeInfo } = mapPrimitiveDepsToIDs(strandsContext, typeInfo, dependsOn);
261275

262276
const finalType = {
@@ -272,6 +286,24 @@ export function primitiveConstructorNode(strandsContext, typeInfo, dependsOn) {
272286
return { id, dimension: finalType.dimension, components: mappedDependencies };
273287
}
274288

289+
export function castToFloat(strandsContext, dep) {
290+
const { id, dimension } = functionCallNode(
291+
strandsContext,
292+
strandsContext.backend.getTypeName('float', dep.typeInfo().dimension),
293+
[dep],
294+
{
295+
overloads: [{
296+
params: [dep.typeInfo()],
297+
returnType: {
298+
...dep.typeInfo(),
299+
baseType: BaseType.FLOAT,
300+
},
301+
}],
302+
}
303+
);
304+
return createStrandsNode(id, dimension, strandsContext);
305+
}
306+
275307
export function structConstructorNode(strandsContext, structTypeInfo, rawUserArgs) {
276308
const { cfg, dag } = strandsContext;
277309
const { identifer, properties } = structTypeInfo;
@@ -491,7 +523,7 @@ export function swizzleTrap(id, dimension, strandsContext, onRebind) {
491523
// This may not be the most efficient way, as we swizzle each component individually,
492524
// so that .xyz becomes .x, .y, .z
493525
let scalars = [];
494-
if (value instanceof StrandsNode) {
526+
if (value?.isStrandsNode) {
495527
if (value.dimension === 1) {
496528
scalars = Array(chars.length).fill(value);
497529
} else if (value.dimension === chars.length) {

src/strands/strands_api.js

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ function _getBuiltinGlobalsCache(strandsContext) {
5656
function getBuiltinGlobalNode(strandsContext, name) {
5757
const spec = BUILTIN_GLOBAL_SPECS[name]
5858
if (!spec) return null
59-
59+
6060
const cache = _getBuiltinGlobalsCache(strandsContext)
6161
const uniformName = `_p5_global_${name}`
6262
const cached = cache.nodes.get(uniformName)
@@ -154,7 +154,7 @@ export function initGlobalStrandsAPI(p5, fn, strandsContext) {
154154
}
155155

156156
// Convert value to a StrandsNode if it isn't already
157-
const valueNode = value instanceof StrandsNode ? value : p5.strandsNode(value);
157+
const valueNode = value?.isStrandsNode ? value : p5.strandsNode(value);
158158

159159
// Create a new CFG block for the early return
160160
const earlyReturnBlockID = CFG.createBasicBlock(cfg, BlockType.DEFAULT);
@@ -369,12 +369,17 @@ export function initGlobalStrandsAPI(p5, fn, strandsContext) {
369369
fn[typeInfo.fnName] = function(...args) {
370370
if (strandsContext.active) {
371371
if (args.length === 1 && args[0].dimension && args[0].dimension === typeInfo.dimension) {
372-
const { id, dimension } = build.functionCallNode(strandsContext, typeInfo.fnName, args, {
373-
overloads: [{
374-
params: [args[0].typeInfo()],
375-
returnType: typeInfo,
376-
}]
377-
});
372+
const { id, dimension } = build.functionCallNode(
373+
strandsContext,
374+
strandsContext.backend.getTypeName(typeInfo.baseType, typeInfo.dimension),
375+
args,
376+
{
377+
overloads: [{
378+
params: [args[0].typeInfo()],
379+
returnType: typeInfo,
380+
}]
381+
}
382+
);
378383
return createStrandsNode(id, dimension, strandsContext);
379384
} else {
380385
// For vector types with a single argument, repeat it for each component
@@ -431,7 +436,7 @@ function createHookArguments(strandsContext, parameters){
431436
const oldDependsOn = dag.dependsOn[structNode.id];
432437
const newDependsOn = [...oldDependsOn];
433438
let newValueID;
434-
if (val instanceof StrandsNode) {
439+
if (val?.isStrandsNode) {
435440
newValueID = val.id;
436441
}
437442
else {
@@ -463,7 +468,7 @@ function createHookArguments(strandsContext, parameters){
463468
return args;
464469
}
465470
function enforceReturnTypeMatch(strandsContext, expectedType, returned, hookName) {
466-
if (!(returned instanceof StrandsNode)) {
471+
if (!(returned?.isStrandsNode)) {
467472
// try {
468473
const result = build.primitiveConstructorNode(strandsContext, expectedType, returned);
469474
return result.id;
@@ -578,7 +583,7 @@ export function createShaderHooksFunctions(strandsContext, fn, shader) {
578583
const handleRetVal = (retNode) => {
579584
if(isStructType(expectedReturnType)) {
580585
const expectedStructType = structType(expectedReturnType);
581-
if (retNode instanceof StrandsNode) {
586+
if (retNode?.isStrandsNode) {
582587
const returnedNode = getNodeDataFromID(strandsContext.dag, retNode.id);
583588
if (returnedNode.baseType !== expectedStructType.typeName) {
584589
const receivedTypeName = returnedNode.baseType || 'undefined';

src/strands/strands_for.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ export class StrandsFor {
309309
let initialVar = this.initialCb();
310310

311311
// Convert to StrandsNode if it's not already one
312-
if (!(initialVar instanceof StrandsNode)) {
312+
if (!(initialVar?.isStrandsNode)) {
313313
const { id, dimension } = primitiveConstructorNode(this.strandsContext, { baseType: BaseType.FLOAT, dimension: 1 }, initialVar);
314314
initialVar = createStrandsNode(id, dimension, this.strandsContext);
315315
}

src/strands/strands_node.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ export class StrandsNode {
4040
const baseType = orig?.baseType ?? BaseType.FLOAT;
4141

4242
let newValueID;
43-
if (value instanceof StrandsNode) {
43+
if (value?.isStrandsNode) {
4444
newValueID = value.id;
4545
} else {
4646
const newVal = primitiveConstructorNode(
@@ -95,7 +95,7 @@ export class StrandsNode {
9595
const baseType = orig?.baseType ?? BaseType.FLOAT;
9696

9797
let newValueID;
98-
if (value instanceof StrandsNode) {
98+
if (value?.isStrandsNode) {
9999
newValueID = value.id;
100100
} else {
101101
const newVal = primitiveConstructorNode(

src/webgpu/p5.RendererWebGPU.js

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1904,7 +1904,7 @@ function rendererWebGPU(p5, fn) {
19041904

19051905
getNextBindingIndex({ vert, frag }, group = 0) {
19061906
// Get the highest binding index in the specified group and return the next available
1907-
const samplerRegex = /@group\((\d+)\)\s*@binding\((\d+)\)\s*var\s+(\w+)\s*:\s*(texture_2d<f32>|sampler|uniform)/g;
1907+
const samplerRegex = /@group\((\d+)\)\s*@binding\((\d+)\)\s*var(?:<uniform>)?\s+(\w+)\s*:\s*(texture_2d<f32>|sampler|uniform|\w+)/g;
19081908
let maxBindingIndex = -1;
19091909

19101910
for (const [src, visibility] of [
@@ -2254,6 +2254,9 @@ function rendererWebGPU(p5, fn) {
22542254
// Inject hook uniforms as a separate struct at a new binding
22552255
let hookUniformFields = '';
22562256
for (const key in shader.hooks.uniforms) {
2257+
// Skip textures, they don't get added to structs
2258+
if (key.endsWith(': sampler2D')) continue;
2259+
22572260
// WGSL format: "name: type"
22582261
hookUniformFields += ` ${key},\n`;
22592262
}

src/webgpu/strands_wgslBackend.js

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -223,10 +223,11 @@ export const wgslBackend = {
223223
return primitiveTypeName;
224224
},
225225
generateHookUniformKey(name, typeInfo) {
226-
// For sampler2D types, we don't add them to the uniform struct
227-
// Instead, they become separate texture and sampler bindings
226+
// For sampler2D types, we don't add them to the uniform struct,
227+
// but we still need them in the shader's hooks object so that
228+
// they can be set by users.
228229
if (typeInfo.baseType === 'sampler2D') {
229-
return null; // Signal that this should not be added to uniform struct
230+
return `${name}: sampler2D`; // Signal that this should not be added to uniform struct
230231
}
231232
return `${name}: ${this.getTypeName(typeInfo.baseType, typeInfo.dimension)}`;
232233
},

test/unit/visual/cases/webgl.js

Lines changed: 124 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,6 +1144,128 @@ visualSuite('WebGL', function() {
11441144
screenshot();
11451145
});
11461146
});
1147+
1148+
visualTest('Strands tutorial', function(p5, screenshot) {
1149+
// From Luke Plowden's Intro to Strands tutorial
1150+
// https://beta.p5js.org/tutorials/intro-to-p5-strands/
1151+
1152+
function starShaderCallback({ p5 }) {
1153+
const time = p5.uniformFloat(() => p5.millis());
1154+
const skyRadius = p5.uniformFloat(90);
1155+
1156+
function rand2(st) {
1157+
return p5.sin((st.x + st.y) * 123.456);
1158+
}
1159+
1160+
function semiSphere() {
1161+
let id = p5.instanceID();
1162+
let theta = rand2([id, 0.1234]) * p5.TWO_PI + time / 100000;
1163+
let phi = rand2([id, 3.321]) * p5.PI + time / 50000;
1164+
1165+
let r = skyRadius;
1166+
r *= p5.sin(phi);
1167+
let x = r * p5.sin(phi) * p5.cos(theta);
1168+
let y = r * 1.5 * p5.cos(phi);
1169+
let z = r * p5.sin(phi) * p5.sin(theta);
1170+
return [x, y, z];
1171+
}
1172+
1173+
p5.getWorldInputs((inputs) => {
1174+
inputs.position += semiSphere();
1175+
return inputs;
1176+
});
1177+
1178+
p5.getObjectInputs((inputs) => {
1179+
let size = 1 + 0.5 * p5.sin(time * 0.002 + p5.instanceID());
1180+
inputs.position *= size;
1181+
return inputs;
1182+
});
1183+
}
1184+
1185+
function pixelateShaderCallback({ p5 }) {
1186+
const pixelCountX = p5.uniformFloat(() => 100);
1187+
1188+
p5.getColor((inputs, canvasContent) => {
1189+
const aspectRatio = inputs.canvasSize.x / inputs.canvasSize.y;
1190+
const pixelSize = [pixelCountX, pixelCountX / aspectRatio];
1191+
1192+
let coord = inputs.texCoord;
1193+
coord = p5.floor(coord * pixelSize) / pixelSize;
1194+
1195+
let col = p5.getTexture(canvasContent, coord);
1196+
return col//[coord, 0, 1];
1197+
});
1198+
}
1199+
1200+
function bloomShaderCallback({ p5, originalImage }) {
1201+
const preBlur = p5.uniformTexture(() => originalImage);
1202+
1203+
getColor((input, canvasContent) => {
1204+
const blurredCol = p5.getTexture(canvasContent, input.texCoord);
1205+
const originalCol = p5.getTexture(preBlur, input.texCoord);
1206+
1207+
const intensity = p5.max(originalCol, 0.1) * 12.2;
1208+
1209+
const bloom = originalCol + blurredCol * intensity;
1210+
return [bloom.rgb, 1];
1211+
});
1212+
}
1213+
1214+
p5.createCanvas(200, 200, p5.WEBGL);
1215+
const stars = p5.buildGeometry(() => p5.sphere(4, 4, 2))
1216+
const originalImage = p5.createFramebuffer();
1217+
1218+
function fresnelShaderCallback({ p5 }) {
1219+
const fresnelPower = p5.uniformFloat(2);
1220+
const fresnelBias = p5.uniformFloat(-0.1);
1221+
const fresnelScale = p5.uniformFloat(2);
1222+
1223+
p5.getCameraInputs((inputs) => {
1224+
let n = p5.normalize(inputs.normal);
1225+
let v = p5.normalize(-inputs.position);
1226+
let base = 1.0 - p5.dot(n, v);
1227+
let fresnel = fresnelScale * p5.pow(base, fresnelPower) + fresnelBias;
1228+
let col = p5.mix([0, 0, 0], [1, .5, .7], fresnel);
1229+
inputs.color = [col, 1];
1230+
return inputs;
1231+
});
1232+
}
1233+
1234+
const starShader = p5.baseMaterialShader().modify(starShaderCallback, { p5 });
1235+
const starStrokeShader = p5.baseStrokeShader().modify(starShaderCallback, { p5 })
1236+
const fresnelShader = p5.baseColorShader().modify(fresnelShaderCallback, { p5 });
1237+
const bloomShader = p5.baseFilterShader().modify(bloomShaderCallback, { p5, originalImage });
1238+
const pixelateShader = p5.baseFilterShader().modify(pixelateShaderCallback, { p5 });
1239+
1240+
originalImage.begin();
1241+
p5.background(0);
1242+
1243+
p5.push()
1244+
p5.strokeWeight(2)
1245+
p5.stroke(255,0,0)
1246+
p5.fill(255,100, 150)
1247+
p5.strokeShader(starStrokeShader)
1248+
p5.shader(starShader);
1249+
p5.model(stars, 100);
1250+
p5.pop()
1251+
1252+
p5.push()
1253+
p5.shader(fresnelShader)
1254+
p5.noStroke()
1255+
p5.sphere(30);
1256+
p5.filter(pixelateShader);
1257+
p5.pop()
1258+
1259+
originalImage.end();
1260+
1261+
p5.imageMode(p5.CENTER)
1262+
p5.image(originalImage, 0, 0)
1263+
1264+
p5.filter(p5.BLUR, 5)
1265+
p5.filter(bloomShader);
1266+
1267+
screenshot();
1268+
});
11471269
});
11481270

11491271
visualSuite('background()', function () {
@@ -1316,7 +1438,7 @@ visualSuite('WebGL', function() {
13161438
visualSuite('Tessellation', function() {
13171439
visualTest('Handles nearly identical consecutive vertices', function(p5, screenshot) {
13181440
p5.createCanvas(400, 400, p5.WEBGL);
1319-
1441+
13201442
const contours = [
13211443
[
13221444
[-3.8642425537109375, -6.120738636363637, 0],
@@ -1355,7 +1477,7 @@ visualSuite('WebGL', function() {
13551477
[-1.8045834628018462, 4.177556818181818, 0]
13561478
]
13571479
];
1358-
1480+
13591481
p5.background('red');
13601482
p5.push();
13611483
p5.stroke(0);

0 commit comments

Comments
 (0)