Skip to content

Commit

Permalink
Support changing the input layout
Browse files Browse the repository at this point in the history
  • Loading branch information
reillyeon committed Apr 30, 2024
1 parent f078c04 commit 3125578
Showing 1 changed file with 43 additions and 18 deletions.
61 changes: 43 additions & 18 deletions webnn-conv2d.html
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,22 @@
<body>
<p>
<label for="deviceType">Device preference:</label>
<select id="deviceType">
<select id="deviceType" disabled>
<option selected value="cpu">CPU</option>
<option value="gpu">GPU</option>
<option value="npu">NPU</option>
</select>
</p>
<p>
<label for="filterSize">Filter size:</label>
<input id="filterSize" type="number" min="1" max="50" value="5">
<input id="filterSize" type="number" min="1" max="50" value="5" disabled>
</p>
<p>
<label for="inputLayout">Input layout:</label>
<select id="inputLayout" disabled>
<option value="nchw">Channels first (NCHW)</option>
<option selected value="nhwc">Channels last (NHWC)</option>
</select>
</p>
<table>
<tr><th>Input</th><th>Output</th></tr>
Expand All @@ -28,21 +35,39 @@
let inputData;

async function createGraph(context) {
const inputLayout = inputLayoutElement.value;
const builder = new MLGraphBuilder(context);
const filterWidth = Number(filterSizeElement.value);
const filterHeight = Number(filterSizeElement.value);
const input = builder.input(

let input = builder.input(
'input', {dataType: 'float32', dimensions: [1, 500, 500, channels]});
const filterData = new Float32Array(filterWidth * filterHeight * channels);
filterData.fill(1 / (filterWidth * filterHeight));
if (inputLayout == 'nchw') {
input = builder.transpose(input, {permutation: [0, 3, 1, 2]})
}

// Right now Chromium only supports one filter layout for each input layout.
const filterHeight = Number(filterSizeElement.value);
const filterWidth = Number(filterSizeElement.value);
const filterLayout = inputLayout == 'nchw' ? 'oihw' : 'ihwo';
const filterShape =
filterLayout == 'oihw' ?
[channels, 1, filterHeight, filterWidth] :
[1, filterHeight, filterWidth, channels]

// A simple blur filter is easy because the layout doesn't matter, the
// elements simply have to sum to 1.
const filterData = new Float32Array(filterHeight * filterWidth * channels);
filterData.fill(1 / (filterHeight * filterWidth));
const filter = builder.constant(
{dataType: 'float32', dimensions: [1, filterWidth, filterHeight, channels]},
filterData);
const output = builder.conv2d(input, filter, {
inputLayout: 'nhwc',
filterLayout: 'ihwo', // IHWO is required for depthwise convolution.
{dataType: 'float32', dimensions: filterShape}, filterData);

let output = builder.conv2d(input, filter, {
inputLayout, filterLayout,
groups: channels, // Convolve each input channel with its own filter.
});
if (inputLayout == 'nchw') {
output = builder.transpose(output, {permutation: [0, 2, 3, 1]})
}

return {
graph: await builder.build({'output': output}),
outputHeight: output.shape()[1],
Expand Down Expand Up @@ -92,14 +117,13 @@
}

const deviceTypeElement = document.getElementById('deviceType');
deviceTypeElement.onchange = () => {
run();
};
deviceTypeElement.onchange = run

const filterSizeElement = document.getElementById('filterSize');
filterSizeElement.onchange = () => {
run();
};
filterSizeElement.onchange = run

const inputLayoutElement = document.getElementById('inputLayout');
inputLayoutElement.onchange = run

const image = new Image();
image.onload = () => {
Expand All @@ -110,6 +134,7 @@

deviceTypeElement.disabled = false;
filterSizeElement.disabled = false;
inputLayoutElement.disabled = false;
run();
};
image.src = 'photo.jpg';
Expand Down

0 comments on commit 3125578

Please sign in to comment.