|
8 | 8 | from tinygrad.helpers import PCONTIG, FLOAT16, OPENPILOT_HACKS, argsort, partition, get_single_element |
9 | 9 | from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify |
10 | 10 | from tinygrad.codegen.opt import Opt |
11 | | -from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op |
| 11 | +from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, IndexingContext, apply_movement_op |
12 | 12 | from tinygrad.schedule.multi import multi_pm |
13 | 13 | from tinygrad.schedule.allreduce import create_allreduce_function |
14 | 14 |
|
@@ -71,8 +71,11 @@ def found_assign(ctx:dict[UOp, UOp], assign:UOp, src:UOp): |
71 | 71 | def fix_store_after_hazard(after:UOp, target:UOp, src:UOp): |
72 | 72 | # PERMUTE and FLIP reorder indices, SHRINK can have overlapping regions when dest is also shrunk |
73 | 73 | unsafe = {Ops.PERMUTE, Ops.FLIP} | ({Ops.SHRINK} if target.op_in_backward_slice_with_self(Ops.SHRINK) else set()) |
74 | | - if any(s.op in unsafe and target.base in s.backward_slice for s in src.toposort(gate=lambda s:s.op not in ALWAYS_CONTIGUOUS or s.op is Ops.AFTER)): |
75 | | - return after.replace(src=(after.src[0], target.store(src.contiguous()))) |
| 74 | + base = target.base |
| 75 | + reaches_base: dict[UOp, bool] = {} |
| 76 | + for s in src.toposort(gate=lambda s: s.op is not Ops.CONTIGUOUS): |
| 77 | + reaches_base[s] = s is base or any(reaches_base.get(c) for c in s.src) |
| 78 | + if reaches_base[s] and s.op in unsafe: return after.replace(src=(after.src[0], target.store(src.contiguous()))) |
76 | 79 |
|
77 | 80 | def normalize_store_after_target_chain(after:UOp, target:UOp, src:UOp): |
78 | 81 | root_target = target |
|
0 commit comments