diff --git a/example.py b/example.py
index 21033c15..acbbca95 100644
--- a/example.py
+++ b/example.py
@@ -10,7 +10,7 @@
 from trellis.utils import render_utils, postprocessing_utils
 
 # Load a pipeline from a model folder or a Hugging Face model hub.
-pipeline = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
+pipeline = TrellisImageTo3DPipeline.from_pretrained("jetx/trellis-image-large")
 pipeline.cuda()
 
 # Load an image
diff --git a/trellis/modules/sparse/attention/full_attn.py b/trellis/modules/sparse/attention/full_attn.py
index e9e27aeb..99e6516e 100755
--- a/trellis/modules/sparse/attention/full_attn.py
+++ b/trellis/modules/sparse/attention/full_attn.py
@@ -194,7 +194,7 @@ def sparse_scaled_dot_product_attention(*args, **kwargs):
         q = q.unsqueeze(0)
         k = k.unsqueeze(0)
         v = v.unsqueeze(0)
-        mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
+        mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
         out = xops.memory_efficient_attention(q, k, v, mask)[0]
     elif ATTN == 'flash_attn':
         cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
diff --git a/trellis/modules/sparse/attention/serialized_attn.py b/trellis/modules/sparse/attention/serialized_attn.py
index 5950b75b..ca37a162 100755
--- a/trellis/modules/sparse/attention/serialized_attn.py
+++ b/trellis/modules/sparse/attention/serialized_attn.py
@@ -177,7 +177,7 @@ def sparse_serialized_scaled_dot_product_self_attention(
         q = q.unsqueeze(0)   # [1, M, H, C]
         k = k.unsqueeze(0)   # [1, M, H, C]
         v = v.unsqueeze(0)   # [1, M, H, C]
-        mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
+        mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(seq_lens)
         out = xops.memory_efficient_attention(q, k, v, mask)[0]   # [M, H, C]
     elif ATTN == 'flash_attn':
         cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
diff --git a/trellis/modules/sparse/attention/windowed_attn.py b/trellis/modules/sparse/attention/windowed_attn.py
index cd642c52..a704c9ea 100755
--- a/trellis/modules/sparse/attention/windowed_attn.py
+++ b/trellis/modules/sparse/attention/windowed_attn.py
@@ -119,7 +119,7 @@ def sparse_windowed_scaled_dot_product_self_attention(
         q = q.unsqueeze(0)   # [1, M, H, C]
         k = k.unsqueeze(0)   # [1, M, H, C]
         v = v.unsqueeze(0)   # [1, M, H, C]
-        mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
+        mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(seq_lens)
         out = xops.memory_efficient_attention(q, k, v, mask)[0]   # [M, H, C]
     elif ATTN == 'flash_attn':
         cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
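
Note on the mask change: BlockDiagonalMask is defined in xformers.ops.fmha.attn_bias, and the shorter xops.fmha.BlockDiagonalMask alias is apparently not re-exported in every xformers release, so spelling out the full path avoids an AttributeError while staying compatible with versions that still have the alias. A minimal standalone sketch of the packed variable-length attention pattern that all three attention files share; the sequence lengths and tensor sizes below are made-up example values, and it assumes a CUDA build of xformers:

# Sketch only: shapes and seq_lens are illustrative, not from the repo.
import torch
import xformers.ops as xops

seq_lens = [4, 7, 5]                        # three variable-length sequences
M, H, C = sum(seq_lens), 8, 64              # total tokens, heads, head dim
q = torch.randn(1, M, H, C, device='cuda', dtype=torch.half)
k = torch.randn(1, M, H, C, device='cuda', dtype=torch.half)
v = torch.randn(1, M, H, C, device='cuda', dtype=torch.half)

# Block-diagonal bias: each token attends only within its own sequence,
# so several sequences can be packed into a single batch of size 1.
mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(seq_lens)
out = xops.memory_efficient_attention(q, k, v, mask)[0]   # [M, H, C]

This mirrors the xformers branch in the patched files: full_attn.py builds the mask from separate q_seqlen and kv_seqlen lists (from_seqlens accepts a second kv_seqlen argument for cross-attention), while serialized_attn.py and windowed_attn.py use a single seq_lens list for self-attention.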