
[Issue]: Triton Compiler Takes Indefinite Time in ttgir -> llir Stage. #596

Closed
xinyazhang opened this issue Jun 3, 2024 · 20 comments

@xinyazhang

Problem Description

Full source code to reproduce:
rep.py.gz

Triton version: upstream d688063f731cfc4d9431bb8c0d0d73dce8cd1c38
Docker Container: rocm/pytorch-private:compute-rocm-rel-6.1-116_ubuntu22.04_py3.9_pytorch_rocm6.1_internal_testing_ae01701

Reproducible on both MI200 (gfx90a) and Navi3x. Debug prints show the compiler hangs during the ttgir -> llir stage.

Operating System

Ubuntu 22.04.4 LTS (Jammy Jellyfish)

CPU

AMD Ryzen Threadripper PRO 5975WX 32-Cores

GPU

AMD Instinct MI210

ROCm Version

ROCm 6.1.0

ROCm Component

No response

Steps to Reproduce

Download rep.py.gz from the Problem Description section, and then:

gunzip rep.py.gz
python rep.py

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response

@xinyazhang
Author

The following passes do not exist in the newer code

#ifdef USE_ROCM
pm.addPass(mlir::createConvertSCFToCFPass());
pm.addPass(createConvertControlFlowToLLVMPass());
#endif

Maybe we can try adding them to the make_llir function to see whether that fixes the problem.
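
A minimal sketch of what re-adding them could look like; the wrapper function is a placeholder, and in current Triton the make_llir pipeline is assembled from Python via bindings, so this only illustrates the two addPass calls (exact headers and namespaces depend on the MLIR revision):

#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Pass/PassManager.h"

// Hypothetical helper: append the two lowering passes quoted above to the
// ttgir -> llir pass manager.
static void addLegacyCfLoweringPasses(mlir::PassManager &pm) {
  // Lower structured control flow (scf.for / scf.if) to unstructured cf branches...
  pm.addPass(mlir::createConvertSCFToCFPass());
  // ...then lower cf dialect branches to the LLVM dialect.
  pm.addPass(mlir::createConvertControlFlowToLLVMPass());
}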

@zhanglx13 zhanglx13 assigned zhanglx13 and unassigned jayfurmanek Jun 5, 2024

@giuseros

giuseros commented Jun 10, 2024

I am working on this because it looks (very) slightly simpler than https://github.com/ROCm/triton-internal/issues/104

This is what I got so far:

  • The bug (as mentioned previously) originates from convert-builtin-func-to-llvm.
  • The main issue for this specific case is the stores. I made a reproduction, bug.mlir, which is the IR just before convert-builtin-func-to-llvm (attached to this comment), and commented out all the loads and some of the stores. While triton-opt terminates, the output produced is massively large. The more stores we add back into bug.mlir, the longer it takes to complete (I think that if we leave it long enough it will eventually finish).
  • The source of the issue looks like the mergeIdenticalBlocks transformation contained in the simplifyRegion utility. If I disable that transformation (enableRegionSimplify=false), compilation is quite quick (a programmatic sketch of this knob follows below).
  • I produced a disable_simplify.mlir output that comes from bug.mlir when passing enableRegionSimplify=false to the rewriter. If we run triton-opt --canonicalize disable_simplify.mlir, we see the same massive output as before, and triton-opt takes some time to finish. Instead, if we run triton-opt --canonicalize="region-simplify=false" disable_simplify.mlir, the output is normal and triton-opt terminates quickly.

repro.zip
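
For reference, a minimal sketch of driving that knob programmatically (assuming the MLIR API of this era, where enableRegionSimplification is still a bool); it mirrors the --canonicalize="region-simplify=false" invocation above, and the helper name is only illustrative:

#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

// Build the canonicalizer with region simplification disabled, so the
// mergeIdenticalBlocks step never runs.
static void addCanonicalizeWithoutRegionSimplify(mlir::PassManager &pm) {
  mlir::GreedyRewriteConfig config;
  config.enableRegionSimplification = false;
  pm.addPass(mlir::createCanonicalizerPass(config));
}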

@jayfurmanek
Collaborator

Another note that might help:
In the repro script, if I break the nested if below (at line 367) by just deleting the else there, then it doesn't hang.
Perhaps this is related to stores inside nested if statements.

    if q_padded:
        if PADDED_HEAD:
            tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0,1))
        else:
            tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0,))
    else:
        if PADDED_HEAD:
            tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(1,))
        else:
            tl.store(O_block_ptr, acc.to(Out.type.element_ty))

@jayfurmanek
Collaborator

jayfurmanek commented Jun 10, 2024

I guess the comment here basically confirms that:

  # This pass (`add_builtin_func_to_llvmir`) serves as a temporary workaround to address the issue of excessive basic block
  # count caused by predicated loads/stores. In certain kernels, the addition of these blocks can cause the MLIR
  # canonicalizer to never finish when attempting to merge blocks. The permanent solution under consideration
  # involves using MUBUF instructions that have built-in out-of-bounds checks, which would eliminate the need
  # for conditional branching around memory accesses.

@giuseros

So yes, I was aware of this comment, but @antiagainst was asking if there could be a simpler solution than implementing buffer loads. I guess the main question is: is this a bug, or is it unavoidable because of so many blocks?

Should you, me, and @antiagainst have a chat to decide the best step forward?

@giuseros

giuseros commented Jun 10, 2024

So this is the situation we have in the CFG (cc @antiagainst ):
[screenshot of the CFG attached]

I think the problem is that MLIR is trying to produce a single big if-block into which to put all those subgraphs.

@giuseros

So, all in all, I think this is a correct transformation, even in our case. What happens is that we hit the following cases:

Store case

leader block:
^bb152:  // pred: ^bb151
  llvm.store %3316, %3245 : i16, !llvm.ptr<1>
  llvm.br ^bb153
blocks to merge:
^bb181:  // pred: ^bb180
  "llvm.store"(%3409, %3375) <{ordering = 0 : i64}> : (i16, !llvm.ptr<1>) -> ()
  "llvm.br"()[^bb153] : () -> ()

In this case the blocks can be merged, and the merged block will have two additional operands.

Insertelement case

^bb151:  // 2 preds: ^bb149, ^bb150
  %3315 = llvm.insertelement %3148, %3251[%60 : i32] : vector<1xf16>
  %3316 = llvm.bitcast %3315 : vector<1xf16> to i16
  llvm.cond_br %3254, ^bb152(%3316, %3245 : i16, !llvm.ptr<1>), ^bb153
blocks to merge:
^bb180:  // 2 preds: ^bb178, ^bb179
  %3412 = "llvm.insertelement"(%3385, %3336, %60) : (vector<1xf16>, f16, i32) -> vector<1xf16>
  %3413 = "llvm.bitcast"(%3412) : (vector<1xf16>) -> i16
  "llvm.cond_br"(%3384, %3413, %3379)[^bb152, ^bb153] <{operandSegmentSizes = array<i32: 1, 2, 0>}> : (i1, i16, !llvm.ptr<1>) -> ()

In this case the blocks are still structurally similar, but we are doubling the number of input operands of the merged block. When we do that 64 times, we end up with blocks that have 32764 input operands, which is very slow to handle.

Possible (quick) workaround

We can introduce a threshold: don't merge the blocks if doing so would result in more than K (defaulting to 16?) input operands in the resulting block. A rough sketch follows below.
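
A rough sketch of the kind of check I have in mind; the helper and the way the number of newly created block arguments is obtained are only illustrative, since the real change would live inside MLIR's block-merging logic in simplifyRegions:

#include "mlir/IR/Block.h"

// Hypothetical predicate: refuse to merge a candidate block into the leader
// when the merged block would end up with more than maxBlockArguments block
// arguments (each differing operand becomes one new block argument).
static bool mergeWouldExceedThreshold(mlir::Block &leader,
                                      unsigned newArgsFromMerge,
                                      unsigned maxBlockArguments = 16) {
  return leader.getNumArguments() + newArgsFromMerge > maxBlockArguments;
}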

@jayfurmanek
Collaborator

A threshold is a good idea, I think.
Where would we implement the threshold? In the canonicalizer?

@giuseros

Yes, we can have an option like maxBlockArguments in the canonicalizer pass, defaulting to 16. I tried hardcoding that and indeed it works fine. I will try to put up a patch.

I also want to underline that by not merging those blocks we are creating super-branchy code that will probably be very slow. So once I implement this, I will try to finish the buffer_load implementation.

@antiagainst

Yup, agreed that having a threshold in the greedy pattern rewriter configuration to control this would be good. Once you have the patch to MLIR, please add me as a reviewer.

@giuseros

giuseros commented Jun 11, 2024

They were faster than me :) : llvm/llvm-project#95057

Not sure whether the threshold solution is better, but I commented on the PR instead of creating a separate one.

(Note: once the PR is merged, we should bump the Triton commit to pick up the change.)

@jerryyin
Member

@giuseros Have you verified that the upstream PR will address both use cases, i.e. this ticket and https://github.com/ROCm/triton-internal/issues/104?

@giuseros

Yes, it disables block merging during canonicalization, which is the root cause of both.

@giuseros

giuseros commented Jun 12, 2024

Update on this: they made a further change (or the change was already there and escaped my eye) by which they now enable block merging in the rewriter. If they stick with that, we will still have the hang (see llvm/llvm-project#95057 (comment)).

Either we convince them to disable merging in the rewriter, or I will have to (urgently) implement this:

@giuseros

giuseros commented Jun 12, 2024

After thinking about this, I guess we can set

 GreedySimplifyRegionLevel enableRegionSimplification = GreedySimplifyRegionLevel::Normal;

when we instantiate the rewriter. Meanwhile, I can work on llvm/llvm-project#63230 to solve the core issue.
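
A minimal sketch of what that would look like wherever we drive the rewriter (e.g. in the builtin_func_to_llvm lowering), assuming the enum-based API from that PR; the wrapper and the pattern set are placeholders:

#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Run a pattern set with region simplification capped at Normal: unreachable
// blocks are still erased, but identical blocks are no longer merged, which
// avoids the block-argument blow-up described above.
static mlir::LogicalResult applyPatternsWithoutBlockMerging(
    mlir::Operation *op, const mlir::FrozenRewritePatternSet &patterns) {
  mlir::GreedyRewriteConfig config;
  config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Normal;
  return mlir::applyPatternsAndFoldGreedily(op, patterns, config);
}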

@zhanglx13

I was trying to follow your discussion with Mehdi on that upstream PR. What is meant by "they disable block merging for canonicalization but enable it for the rewriter"?

@giuseros

Both the canonicalize pass and the rewriter use the simplifyRegions function. The solution Mehdi is proposing is to default to simplifyRegions(normal) in the canonicalize pass (block merging disabled) and simplifyRegions(aggressive) in the rewriter (block merging enabled -> hang). We can change the default behaviour of the rewriter in Triton (so that it calls simplifyRegions(normal)), but this means passing a config every time we invoke it (with config.enableRegionSimplification = Normal).

@zhanglx13

Does the rewriter call simplifyRegions (and probably other passes to canonicalize stuff) after it matches and rewrites all the ops?

We can change the default behaviour of the rewriter in Triton (so that it calls simplifyRegions(normal)), but this means passing a config every time we invoke it (with config.enableRegionSimplification = Normal).

We only need to set it for the rewriter in the builtin_func_to_llvm pass, right? If so, are there any other drawbacks?

@giuseros

Does the rewriter call simplifyRegions (and probably other passes to canonicalize stuff) after it matches and rewrites all the ops?

Yes

We only need to set it for the rewriter in the builtin_func_to_llvm pass, right? If so, are there any other drawbacks?

And anytime we invoke the rewriter after that. I see that builtin_func_to_llvm is the last pass, so it shouldn't be an issue.

Of course there is the core drawback that we will disable block merging in all cases. But this is something we can worry about later (and I will try to work on it in my "spare" time).

xinyazhang added a commit to ROCm/aotriton that referenced this issue Jul 15, 2024