Add fake tensor #32
|
@ezyang any other thoughts ? particularly re: |
|
@anjali411 operator tagging would be one way to get the correct semantics without having to code operators manually here |
|
|
||
| @property | ||
| def device(self): | ||
| return self.fake_device |
As mentioned in chats, this would be insufficient for reporting CPU/CUDA device in C++, which matters if you're sending this tensor through C++ code that does device tests. This might get us pretty far for now though, so I wouldn't block on it.
Does the new CustomPython TensorImpl policy also cover the device attribute?
|
Re testing, a first step is to wire this up the same way as done in pytorch/pytorch#77008 ; the torture test is pytorch/pytorch#75994 (but do the OpInfo stuff first) |
| @classmethod | ||
| def __torch_dispatch__(cls, func, types, args=(), kwargs=None): | ||
| # Run the original computation | ||
| r = super().__torch_dispatch__(func, types, args, kwargs) |
| # if device is specified, use that | ||
| # not sure this is actually needed.. device only shows up in constructors | ||
| if kwargs.get("device", None): | ||
| return tree_map(partial(wrap, device=kwargs["device"]), r) |
This rule is not "sound", in the sense that someone could add an operator with device kwarg used in a non-standard way which would break the calculation here. We can keep doing this cheap and cheerful calculation though if we added a tag; we'd just go through and tag all the device operators with a tag "yes the output tensor matches what device you pass in here" LOL. Is this the right ontology? Probably not. But it's easy and gives a bit more safety.
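A minimal sketch of the tagging idea, in plain Python so it stands alone. The tag name `device_kwarg_sets_output_device`, the `OP_TAGS` table, and `output_device` are all hypothetical, not existing PyTorch APIs. The point is only that the `device` kwarg is trusted for operators that opted in via a tag, and everything else falls back to the input devices.

```python
from typing import Mapping, Optional, Sequence

# Hypothetical tag: "the output tensor lives on the device passed as `device=`".
DEVICE_KWARG_TAG = "device_kwarg_sets_output_device"

# Hypothetical registry mapping operator names to their tags.
OP_TAGS = {
    "aten::empty": {DEVICE_KWARG_TAG},
    "aten::zeros": {DEVICE_KWARG_TAG},
    "aten::add": set(),
}


def output_device(
    op: str, input_devices: Sequence[str], kwargs: Mapping[str, object]
) -> Optional[str]:
    """Return the fake device for the output of `op`, or None if undecidable."""
    device = kwargs.get("device")
    # Only trust the kwarg when the operator has been explicitly tagged.
    if device is not None and DEVICE_KWARG_TAG in OP_TAGS.get(op, set()):
        return str(device)
    # Otherwise fall back to the inputs, if they all agree.
    unique = set(input_devices)
    if len(unique) == 1:
        return next(iter(unique))
    return None
```

With this shape, an operator that uses a `device` kwarg in a non-standard way is simply left untagged, so the cheap rule never fires for it.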
|
Regarding modes, cc'ing @zou3519 and @samdow about it. Now that we have proper APIs for modes we have a code duplication problem, where mode dispatch and tensor dispatch have potentially the same code; how exactly should the code be shared between these two things? There's also a degree of freedom in how you do the implementation: modes get handled before subclass dispatch, but the mode could handle getting rid of subclasses so the subclass dispatch never gets involved. Here's my first guess for what we should do:
|
| # TODO: https://github.com/pytorch/pytorch/pull/77182 | ||
| _, new_kwargs = normalize_function( | ||
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True | ||
| ) |
Should we be recommending people generally use the FX normalize function? It still seems a bit different from how I personally feel idiomatic handling for functions should be done (via defining a function, and then hooking it up via some register decorator).
I originally thought I was going to be handling more than just resize_as_ (and I still might, e.g. when this is removed). When you only care about one argument input, it is easier to handle these functions in bulk than to define each operator separately, and it also means that if the signature changes you don't need to change the signature of your function as well.
It would also save having to specify defaults as in https://github.com/albanD/subclass_zoo/blob/main/python_meta_tensor.py#L99
And for functions intercepted at the __torch_function__ level, it could potentially handle some of the implicit conversions that currently happen before the kernel is invoked, e.g. things like broadcasting lists, and allowing tensor(0) for int arguments
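To make the bulk-handling argument concrete, here is a sketch of the same idea using only the standard library. `inspect.signature(...).bind` plus `apply_defaults` plays the role that `normalize_function(..., normalize_to_only_use_kwargs=True)` plays for operator schemas. `resize_as_like` and `to_kwargs` are hypothetical stand-ins, not PyTorch functions.

```python
import inspect


def resize_as_like(self, other, *, memory_format=None):
    """Hypothetical stand-in for an operator with a resize_as_-style signature."""


def to_kwargs(fn, args=(), kwargs=None):
    """Bind a call against fn's signature and return every argument by name."""
    bound = inspect.signature(fn).bind(*args, **(kwargs or {}))
    bound.apply_defaults()  # defaults are filled in, so the handler need not restate them
    return dict(bound.arguments)


# The handler looks up the one argument it cares about by name, regardless of
# whether the caller passed it positionally or by keyword.
new_kwargs = to_kwargs(resize_as_like, args=("x", "y"))
other = new_kwargs["other"]
```

The handler never spells out the full signature, which is the property being argued for: a signature change that keeps the `other` name does not require touching the handler.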
| _, new_kwargs = normalize_function( | ||
| func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True | ||
| ) | ||
| return tree_map(partial(wrap, device=new_kwargs["other"].device), r) |
It's a composite, that's why it doesn't get called
- func: type_as(Tensor self, Tensor other) -> Tensor
variants: method
I guess I just need to define __torch_function__ to deal with some of the composite operators
Why? Shouldn't they just be properly handled by the natural decomposition?
Because meta_tensor.type_as(meta_tensor) short-circuits, they both have the same type
I will probably also need to handle all variants of aten::to ..
I think this is the part where the "is-a" is a problem :( But hopefully the new TensorImpl override should help.
Do you have any details on this ?
Work like pytorch/pytorch#77396 (for device/type) should allow us to intercept all the calls on Tensors. And in particular avoid this short circuit?
| return t.device == torch.device("cpu") and t.dim() == 0 | ||
|
|
||
| # cpu - zero-dim tensors can be called in cuda kernels, | ||
| # so overwrite cuda kernels |
Only for some kernels! Tagging would help.
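A rough sketch of what a tag-gated version of this rule could look like, written without torch so it runs anywhere. `FakeArg`, `common_device`, and the `allows_cpu_scalars` flag are hypothetical names; the flag stands in for whatever per-operator tag would mark kernels that accept zero-dim CPU tensors alongside CUDA tensors.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass(frozen=True)
class FakeArg:
    """Hypothetical stand-in for a fake tensor: a device string and a rank."""
    device: str
    ndim: int


def is_cpu_zero_dim(t: FakeArg) -> bool:
    return t.device == "cpu" and t.ndim == 0


def common_device(args: Sequence[FakeArg], allows_cpu_scalars: bool) -> str:
    """Pick the output device; allows_cpu_scalars models a per-operator tag."""
    candidates = list(args)
    if allows_cpu_scalars:
        non_scalar = [a for a in args if not is_cpu_zero_dim(a)]
        # Zero-dim CPU tensors defer to the other arguments,
        # unless they are the only arguments there are.
        candidates = non_scalar or candidates
    devices = {a.device for a in candidates}
    if len(devices) != 1:
        raise RuntimeError(f"arguments are on different devices: {sorted(devices)}")
    return devices.pop()
```

For an untagged operator the same inputs raise instead of silently reporting `cuda`, which is the extra safety the tag buys.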
| def __init__(self, elem, device: Union[torch.device, str]): | ||
| # elem does not need to be recorded, because FakeTensor *is a* elem | ||
| assert elem.device.type == "meta" | ||
| device if isinstance(device, torch.device) else torch.device(device) |
|
|
||
| self.assertEqual(y.shape, (4, 2, 2)) | ||
| self.assertEqual(y.device, torch.device("cpu")) | ||
|
|
||
| @unittest.skip("Waiting on https://github.com/pytorch/pytorch/pull/77182") |
| self.assertEqual(out.shape, (8, 8)) | ||
| self.assertEqual(out.device, torch.device("cpu")) | ||
|
|
||
| def test_zero_dim(self): |
nit: could you add a skip if cuda is not available to make sure this file runs on CPU-only machines?
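Something along these lines should do it. This is only a sketch: the test body below is a placeholder that exercises plain tensors rather than the `FakeTensor` from this PR, and the `HAS_CUDA` name and the `ImportError` guard are my own additions.

```python
import unittest

try:
    import torch
    HAS_CUDA = torch.cuda.is_available()
except ImportError:  # lets the file load even where torch is missing
    HAS_CUDA = False


class ZeroDimTest(unittest.TestCase):
    @unittest.skipIf(not HAS_CUDA, "requires CUDA")
    def test_zero_dim(self):
        x = torch.ones(2, 2, device="cuda")
        y = torch.tensor(2.0)  # zero-dim CPU tensor
        # A zero-dim CPU tensor may be combined with a CUDA tensor.
        self.assertEqual((x * y).device.type, "cuda")
```

On a CPU-only machine the test is reported as skipped instead of failing.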
|
cc @anjali411 |
This is just copying over the PR from albanD/subclass_zoo#32, and pulling in `BaseTensor`, plus adding one test based on invariants about schemas. From code comments:

> Meta tensors give you the ability to run PyTorch code without having to actually do computation through tensors allocated on a `meta` device. Because the device is `meta`, meta tensors do not model device propagation. FakeTensor extends MetaTensors to also carry an additional `fake_device` which tracks devices that would have been used.
Differential Revision: [D36618467](https://our.internmc.facebook.com/intern/diff/D36618467)
Adds a fake tensor, which augments meta tensors with a device, and does device propagation on operators.
Miscellaneous notes:
- `FakeMode` which will cover constructors. Either with `enable_python_mode` or `TorchFunctionMode`, still need to get clarity there.
- `type_as` doesn't get invoked (probably a peephole because they're both meta under the hood). Is this a sign that this should be composition instead of inheritance, or that I need to intercept at higher level of dispatcher? (Edit: define `__torch_function__` to handle some composite operators?)
- `resize_as_` does because it's a `CompositeExplicitAutograd` but `expand_as` is not.