From dd5599500be87fb90f85f6154859c5a71735c7f0 Mon Sep 17 00:00:00 2001 From: Yuhan Date: Fri, 23 May 2025 13:51:45 -0400 Subject: [PATCH 1/2] Adjust the relative path of BlockDiagonalMask to support all versions of xformers. --- trellis/modules/sparse/attention/full_attn.py | 2 +- trellis/modules/sparse/attention/serialized_attn.py | 2 +- trellis/modules/sparse/attention/windowed_attn.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/trellis/modules/sparse/attention/full_attn.py b/trellis/modules/sparse/attention/full_attn.py index e9e27aeb..99e6516e 100755 --- a/trellis/modules/sparse/attention/full_attn.py +++ b/trellis/modules/sparse/attention/full_attn.py @@ -194,7 +194,7 @@ def sparse_scaled_dot_product_attention(*args, **kwargs): q = q.unsqueeze(0) k = k.unsqueeze(0) v = v.unsqueeze(0) - mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen) + mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen) out = xops.memory_efficient_attention(q, k, v, mask)[0] elif ATTN == 'flash_attn': cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) diff --git a/trellis/modules/sparse/attention/serialized_attn.py b/trellis/modules/sparse/attention/serialized_attn.py index 5950b75b..ca37a162 100755 --- a/trellis/modules/sparse/attention/serialized_attn.py +++ b/trellis/modules/sparse/attention/serialized_attn.py @@ -177,7 +177,7 @@ def sparse_serialized_scaled_dot_product_self_attention( q = q.unsqueeze(0) # [1, M, H, C] k = k.unsqueeze(0) # [1, M, H, C] v = v.unsqueeze(0) # [1, M, H, C] - mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) + mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(seq_lens) out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C] elif ATTN == 'flash_attn': cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \ diff --git
a/trellis/modules/sparse/attention/windowed_attn.py b/trellis/modules/sparse/attention/windowed_attn.py index cd642c52..a704c9ea 100755 --- a/trellis/modules/sparse/attention/windowed_attn.py +++ b/trellis/modules/sparse/attention/windowed_attn.py @@ -119,7 +119,7 @@ def sparse_windowed_scaled_dot_product_self_attention( q = q.unsqueeze(0) # [1, M, H, C] k = k.unsqueeze(0) # [1, M, H, C] v = v.unsqueeze(0) # [1, M, H, C] - mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) + mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(seq_lens) out = xops.memory_efficient_attention(q, k, v, mask)[0] # [M, H, C] elif ATTN == 'flash_attn': cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \ From cfd0d52cddb7e338ec7ca70f0fb25ca365b8181f Mon Sep 17 00:00:00 2001 From: Yuhan Date: Fri, 23 May 2025 13:52:03 -0400 Subject: [PATCH 2/2] Update the huggingface repo in example.py. --- example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example.py b/example.py index 21033c15..acbbca95 100644 --- a/example.py +++ b/example.py @@ -10,7 +10,7 @@ from trellis.utils import render_utils, postprocessing_utils # Load a pipeline from a model folder or a Hugging Face model hub. -pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large") +pipeline = TrellisImageTo3DPipeline.from_pretrained("jetx/trellis-image-large") pipeline.cuda() # Load an image