Skip to content

Commit c23a620

Browse files
facebook-github-bot
authored and committed
Add multi-reader tests for add/sub ifm scaling (#18758)
Summary: Add AddMultiReader and SubMultiReader test models (conv2(conv1(x)) +/- conv3(conv1(x))) where conv1's output Rescale has two readers. Differential Revision: D99939008
1 parent 0fc08b0 commit c23a620

File tree

2 files changed

+90
-0
lines changed

2 files changed

+90
-0
lines changed

backends/arm/test/ops/test_add.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,52 @@ def test_add_dual_conv_u85_INT(test_data: input_t1):
363363
pipeline.run()
364364

365365

366+
class AddMultiReader(torch.nn.Module):
    """Module whose first convolution feeds two consumers.

    Computes conv2(conv1(x)) + conv3(conv1(x)), so the Rescale produced for
    conv1's output ends up with two readers in the lowered graph.
    """

    def __init__(self):
        super().__init__()
        # Three identical 1x1 convs; registration order matters for
        # reproducible parameter initialization, so keep conv1..conv3.
        for attr in ("conv1", "conv2", "conv3"):
            setattr(self, attr, torch.nn.Conv2d(3, 3, 1, bias=False))

    def forward(self, x):
        shared = self.conv1(x)  # single producer consumed by two branches
        return self.conv2(shared) + self.conv3(shared)

    test_data = {
        "4d_randn": lambda: (torch.randn(1, 3, 4, 4),),
    }
384+
385+
386+
@common.parametrize("test_data", AddMultiReader.test_data)
def test_add_multi_reader_tosa_INT(test_data: input_t1):
    """TOSA INT lowering when conv1's output rescale has two add readers."""
    module = AddMultiReader()
    inputs = test_data()
    TosaPipelineINT[input_t1](module, inputs, aten_op, exir_op).run()
392+
393+
394+
@common.parametrize("test_data", AddMultiReader.test_data)
@common.XfailIfNoCorstone300
def test_add_multi_reader_u55_INT(test_data: input_t1):
    """Ethos-U55 INT lowering of the two-reader add model."""
    module = AddMultiReader()
    inputs = test_data()
    EthosU55PipelineINT[input_t1](module, inputs, aten_op, exir_op).run()
401+
402+
403+
@common.parametrize("test_data", AddMultiReader.test_data)
@common.XfailIfNoCorstone320
def test_add_multi_reader_u85_INT(test_data: input_t1):
    """Ethos-U85 INT lowering of the two-reader add model."""
    module = AddMultiReader()
    inputs = test_data()
    EthosU85PipelineINT[input_t1](module, inputs, aten_op, exir_op).run()
410+
411+
366412
@common.parametrize("test_data", Add.test_data)
367413
def test_add_tensor_tosa_INT_16a8w(test_data: input_t1):
368414
"""Test add operation with 16A8W quantization (16-bit activations, 8-bit

backends/arm/test/ops/test_sub.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,50 @@ def test_sub_dual_conv_u85_INT(test_data: input_t1):
394394
pipeline.run()
395395

396396

397+
class SubMultiReader(torch.nn.Module):
    """Module whose first convolution feeds two consumers.

    Computes conv2(conv1(x)) - conv3(conv1(x)), so the Rescale produced for
    conv1's output ends up with two readers in the lowered graph.
    """

    def __init__(self):
        super().__init__()
        # Three identical 1x1 convs; registration order matters for
        # reproducible parameter initialization, so keep conv1..conv3.
        for attr in ("conv1", "conv2", "conv3"):
            setattr(self, attr, torch.nn.Conv2d(3, 3, 1, bias=False))

    def forward(self, x):
        shared = self.conv1(x)  # single producer consumed by two branches
        return self.conv2(shared) - self.conv3(shared)

    test_data = {
        "4d_randn": lambda: (torch.randn(1, 3, 4, 4),),
    }
413+
414+
415+
@common.parametrize("test_data", SubMultiReader.test_data)
def test_sub_multi_reader_tosa_INT(test_data: input_t1):
    """TOSA INT lowering when conv1's output rescale has two sub readers."""
    module = SubMultiReader()
    inputs = test_data()
    TosaPipelineINT[input_t1](module, inputs, aten_op, exir_op).run()
421+
422+
423+
@common.parametrize("test_data", SubMultiReader.test_data)
@common.XfailIfNoCorstone300
def test_sub_multi_reader_u55_INT(test_data: input_t1):
    """Ethos-U55 INT lowering of the two-reader sub model."""
    module = SubMultiReader()
    inputs = test_data()
    EthosU55PipelineINT[input_t1](module, inputs, aten_op, exir_op).run()
430+
431+
432+
@common.parametrize("test_data", SubMultiReader.test_data)
@common.XfailIfNoCorstone320
def test_sub_multi_reader_u85_INT(test_data: input_t1):
    """Ethos-U85 INT lowering of the two-reader sub model."""
    module = SubMultiReader()
    inputs = test_data()
    EthosU85PipelineINT[input_t1](module, inputs, aten_op, exir_op).run()
439+
440+
397441
@common.parametrize("test_data", sub_test_data)
398442
def test_sub_tensor_16a8w_tosa_INT(test_data: input_t1):
399443
"""Test sub operation with 16A8W quantization (16-bit activations, 8-bit

0 commit comments

Comments
 (0)