
[BUG] Relu allocates extra memory for storing broadcasted zero tensor #91

Status: Open · Label: bug (Something isn't working)
mmengjiadai (Contributor) opened this issue Oct 12, 2023 · 0 comments
Describe the bug
allo.relu allocates an extra memory block for storing the zero tensor to which the input is compared.
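
For context, a minimal kernel along the following lines should reproduce the issue. This is only a sketch: the exact kernel in mlp.py, the type import, and the build calls are assumptions for illustration, not code from this report.

import allo
from allo.ir.types import float32

def relu_kernel(A: float32[30, 30]) -> float32[30, 30]:
    # allo.relu lowers to an elementwise max against a zero tensor,
    # and that zero tensor is currently materialized in its own buffer.
    B: float32[30, 30] = allo.relu(A)
    return B

s = allo.customize(relu_kernel)
print(s.module)  # the broadcasted zero buffer shows up as an extra memref.alloc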

To Reproduce
Run mlp.py without the enable_tensor option and print the intermediate_module. The last part of the output is:

    %alloc_50 = memref.alloc() : memref<30x30xf32>
    %c0_51 = arith.constant 0 : index
    %c30_52 = arith.constant 30 : index
    %c1_53 = arith.constant 1 : index
    scf.for %arg1 = %c0_51 to %c30_52 step %c1_53 {
      %c0_57 = arith.constant 0 : index
      %c30_58 = arith.constant 30 : index
      %c1_59 = arith.constant 1 : index
      scf.for %arg2 = %c0_57 to %c30_58 step %c1_59 {
        memref.store %cst, %alloc_50[%arg1, %arg2] : memref<30x30xf32>
      }
    }
    %c0_54 = arith.constant 0 : index
    %c30_55 = arith.constant 30 : index
    %c1_56 = arith.constant 1 : index
    scf.for %arg1 = %c0_54 to %c30_55 step %c1_56 {
      %c0_57 = arith.constant 0 : index
      %c30_58 = arith.constant 30 : index
      %c1_59 = arith.constant 1 : index
      scf.for %arg2 = %c0_57 to %c30_58 step %c1_59 {
        %4 = memref.load %alloc_39[%arg1, %arg2] : memref<30x30xf32>
        %5 = memref.load %alloc_50[%arg1, %arg2] : memref<30x30xf32>
        %6 = arith.maxf %4, %5 : f32
        memref.store %6, %alloc_46[%arg1, %arg2] : memref<30x30xf32>
      }
    }

%cst, a zero constant, is broadcast into %alloc_50, which is then only read back in the elementwise max. This extra allocation (and the loop that fills it with zeros) is not needed.

Expected behavior
As a workaround, I added an allo.max function to dsl.py:

def max(x, y):
    # elementwise maximum; np refers to numpy, imported at the top of dsl.py
    return np.maximum(x, y)

and added a max entry to the op-mapping table in builder.py:

if isinstance(arg_type, (F32Type, IntegerType)):
    opcls = {
        "exp": math_d.ExpOp,
        "log": math_d.LogOp,
        "log2": math_d.Log2Op,
        "log10": math_d.Log10Op,
        "sqrt": math_d.SqrtOp,
        "sin": math_d.SinOp,
        "cos": math_d.CosOp,
        "tan": math_d.TanOp,
        "tanh": math_d.TanhOp,
        "power": math_d.PowFOp,
        "max": arith_d.MaxFOp,
    }.get(fn_name)

When relu is implemented with explicit for loops instead, there is no such overhead:

for i, j in allo.grid(30, 30):
    C[i, j] = allo.max(C[i, j], 0.0)
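
For completeness, a self-contained version of that loop-based workaround looks roughly as follows; the kernel wrapper and the customize/print calls are assumptions added for illustration, not part of the original report.

import allo
from allo.ir.types import float32

def relu_loops(C: float32[30, 30]) -> float32[30, 30]:
    for i, j in allo.grid(30, 30):
        # Elementwise max against a scalar zero: no broadcasted zero
        # tensor and no extra memref.alloc in the lowered IR.
        C[i, j] = allo.max(C[i, j], 0.0)
    return C

s = allo.customize(relu_loops)
print(s.module)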