Skip to content

Add fake tensor#32

Open
eellison wants to merge 3 commits into
albanD:mainfrom
eellison:add_fake_tensor
Open

Add fake tensor#32
eellison wants to merge 3 commits into
albanD:mainfrom
eellison:add_fake_tensor

Conversation

@eellison

@eellison eellison commented May 10, 2022

Copy link
Copy Markdown

Adds a fake tensor, which augments meta tensors with a device, and does device propagation on operators.

Miscellaneous notes:

  • I still need to add FakeMode which will cover constructors. Either with enable_python_mode or TorchFunctionMode still need to get clarity there.
  • type_as doesn't get invoked (probably a peephole because they're both meta under the hood) is this a sign that this should be composition instead of inheritance, or that I need to intercept at higher level of dispatcher ? (Edit: define __torch_function__ to handle some composite operators ?)
  • I assume you want to be able to run python code which calls into .device, so I overrided device in python, but I'm not sure if that's sufficient.
  • The code currently doesn't distinguish between ops which allow 0-dim cpu tensors to convert to cuda and those which don't. It should always give correct answers for operators which would have run successfully, but might allow 0-dim cpu conversion for operators which would have otherwise thrown. TODO: use full operator tagging
  • resize_as calls into https://github.com/pytorch/pytorch/blob/master/torch/autograd/_functions/tensor.py#L28 and fails
  • It wasn't particularly clear to me which operators will get seen by this class. I guess resize_as_ does because it's a CompositeExplicitAutograd but expand_as is not.

@eellison

eellison commented May 10, 2022

Copy link
Copy Markdown
Author

cc @albanD @ezyang @Chillee

Comment thread fake_tensor.py Outdated
@eellison

Copy link
Copy Markdown
Author

@ezyang any other thoughts ? particularly re: type_as problem ? do you think that is solvable via inheritance or is it a good sign I should be using composition instead ?

@ezyang

ezyang commented May 10, 2022

Copy link
Copy Markdown
Collaborator

@anjali411 operator tagging would be one way to get the correct semantics without having to code operators manually here

Comment thread fake_tensor.py

@property
def device(self):
return self.fake_device

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned in chats, this would be insufficient for reporting CPU/CUDA device in C++, which matters if you're sending this tensor through C++ code that does device tests. This might get us pretty far for now though, so I wouldn't block on it.

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the new CustomPython TensorImpl policy also covers the device attribute?

@ezyang

ezyang commented May 10, 2022

Copy link
Copy Markdown
Collaborator

Re testing, a first step is to wire this up the same way as done in pytorch/pytorch#77008 ; the torture test is pytorch/pytorch#75994 (but do the OpInfo stuff first)

Comment thread fake_tensor.py
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
# Run the original computation
r = super().__torch_dispatch__(func, types, args, kwargs)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Big benefit of is-a :>

Comment thread fake_tensor.py
# if device is specified, use that
# not sure this is actually needed.. device only shows up in constructors
if kwargs.get("device", None):
return tree_map(partial(wrap, device=kwargs["device"]), r)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This rule is not "sound", in the sense that someone could add an operator with device kwarg used in a non-standard way which would break the calculation here. We can keep doing this cheap and cheerful calculation though if we added a tag; we'd just go through and tag all the device operators with a tag "yes the output tensor matches what device you pass in here" LOL. Is this the right ontology? Probably not. But it's easy and gives a bit more safety.

@ezyang

ezyang commented May 11, 2022

Copy link
Copy Markdown
Collaborator

Regarding modes, cc'ing @zou3519 and @samdow about it. Now that we have proper APIs for modes we have a code duplication problem, where mode dispatch and tensor dispatch have potentially the same code; how exactly should the code be shared between these two things? There's also a degree of freedom in how you do the implementation: modes get handled before subclass dispatch, but the mode could handle getting rid of subclasses so the subclass dispatch never gets involved.

Here's my first guess for what we should do:

  • Fake mode should apply to ALL tensors universally, whether or not they are fake. This matches the semantics of https://github.com/pytorch/torchdistx#fake-tensor and saves you from having to do a bunch of conversions into fake tensors before running fake. It will make testing easier. But it also means more work for us: we need a working meta mode too (I'm not sure how hard it would be--a first start would just be to slap Meta in the local dispatch key set and pray). Or if we don't have a meta mode, we can't do the super() trick anymore; we have to convert all of the inputs into meta tensors, run the meta function, and then turn it back into a fake tensor.
  • Therefore, if the mode applies universally, we should handle all fake tensors in the input as we go, no need to dispatch again into them later. This also suggests that the fake tensor subclass handler is solely used to handle tensors that escape "out of the mode" (non-lexical mode).
  • The torch handler should probably do something like turn on the mode locally and then run the result, that's probably the easiest way to share code

Comment thread fake_tensor.py
# TODO: https://github.com/pytorch/pytorch/pull/77182
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we be recommending people generally use the FX normalize function? It still seems a bit different from how I personally feel idiomatic handling for functions should be done (via defining a function, and then hooking it up via some register decorator).

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I originally thought I was going to be handling more than just resize_as_ (and I still might, e.g. when this is removed). When you only care about one argument input it would be easier to handle these functions in bulk than separately defining each operator, and also means that if the signature changes you dont also need to change the signature of your function.

Would also save having to specify defaults as in https://github.com/albanD/subclass_zoo/blob/main/python_meta_tensor.py#L99

And for functions intercepted at the __torch_function__ level, could potentially handle some of the implicit conversions that currently before the kernel is invoked - e.g. things like broadcasting lists, and allowing tensor(0) for int arguments

Comment thread fake_tensor.py
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
return tree_map(partial(wrap, device=new_kwargs["other"].device), r)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a composite, that's why it doesn't get called

- func: type_as(Tensor self, Tensor other) -> Tensor
  variants: method

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess I just need to define __torch_function__ to deal with some of the composite operators

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why? Shouldn't they just be properly handled by the natural decomposition?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because meta_tensor.type_as(meta_tensor) short-circuits, they both have the same type

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will probably also need to handle all variants of aten::to ..

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is the part where the "is-a" is a problem :( But hopefully the new TensorImpl override should help.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have any details on this ?

@albanD albanD May 16, 2022

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Work like pytorch/pytorch#77396 (for device/type) should allow us to intercept all the calls on Tensors. And in particular avoid this short circuit?

Comment thread fake_tensor.py Outdated
return t.device == torch.device("cpu") and t.dim() == 0

# cpu - zero-dim tensors can be called in cuda kernels,
# so overwrite cuda kernels

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only for some kernels! Tagging would help.

@albanD albanD left a comment

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pretty cool!

Comment thread fake_tensor.py
def __init__(self, elem, device: Union[torch.device, str]):
# elem does not need to be recorded, because FakeTensor *is a* elem
assert elem.device.type == "meta"
device if isinstance(device, torch.device) else torch.device(device)

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing a device = ?

Comment thread fake_tensor.py

@property
def device(self):
return self.fake_device

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the new CustomPython TensorImpl policy also covers the device attribute?

Comment thread fake_tensor.py
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
return tree_map(partial(wrap, device=new_kwargs["other"].device), r)

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is the part where the "is-a" is a problem :( But hopefully the new TensorImpl override should help.

Comment thread fake_tensor.py
self.assertEqual(y.shape, (4, 2, 2))
self.assertEqual(y.device, torch.device("cpu"))

@unittest.skip("Waiting on https://github.com/pytorch/pytorch/pull/77182")

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: that landed

Comment thread fake_tensor.py
self.assertEqual(out.shape, (8, 8))
self.assertEqual(out.device, torch.device("cpu"))

def test_zero_dim(self):

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: could you add a skip if cuda is not available to make sure this file runs on CPU-only machines?

@ezyang

ezyang commented May 13, 2022

Copy link
Copy Markdown
Collaborator

cc @anjali411

eellison pushed a commit to pytorch/pytorch that referenced this pull request May 20, 2022
This is just copying over the PR from albanD/subclass_zoo#32, and pulling in `BaseTensor`, plus adding one test based on invariants about schemas. 

From code comments
>  Meta tensors give you the ability to run PyTorch code without having to
 actually do computation through tensors allocated on a `meta` device.
 Because the device is `meta`, meta tensors do not model device propagation.
 FakeTensor extends MetaTensors to also carry an additional `fake_device`
 which tracks devices that would have been used.



[ghstack-poisoned]
eellison pushed a commit to pytorch/pytorch that referenced this pull request May 20, 2022
This is just copying over the PR from albanD/subclass_zoo#32, and pulling in `BaseTensor`, plus adding one test based on invariants about schemas. 

From code comments
>  Meta tensors give you the ability to run PyTorch code without having to
 actually do computation through tensors allocated on a `meta` device.
 Because the device is `meta`, meta tensors do not model device propagation.
 FakeTensor extends MetaTensors to also carry an additional `fake_device`
 which tracks devices that would have been used.



[ghstack-poisoned]
eellison pushed a commit to pytorch/pytorch that referenced this pull request May 20, 2022
This is just copying over the PR from albanD/subclass_zoo#32, and pulling in `BaseTensor`, plus adding one test based on invariants about schemas. 

From code comments
>  Meta tensors give you the ability to run PyTorch code without having to
 actually do computation through tensors allocated on a `meta` device.
 Because the device is `meta`, meta tensors do not model device propagation.
 FakeTensor extends MetaTensors to also carry an additional `fake_device`
 which tracks devices that would have been used.



[ghstack-poisoned]
eellison pushed a commit to pytorch/pytorch that referenced this pull request May 20, 2022
This is just copying over the PR from albanD/subclass_zoo#32, and pulling in `BaseTensor`, plus adding one test based on invariants about schemas. 

From code comments
>  Meta tensors give you the ability to run PyTorch code without having to
 actually do computation through tensors allocated on a `meta` device.
 Because the device is `meta`, meta tensors do not model device propagation.
 FakeTensor extends MetaTensors to also carry an additional `fake_device`
 which tracks devices that would have been used.



[ghstack-poisoned]
eellison pushed a commit to pytorch/pytorch that referenced this pull request May 20, 2022
This is just copying over the PR from albanD/subclass_zoo#32, and pulling in `BaseTensor`, plus adding one test based on invariants about schemas. 

From code comments
>  Meta tensors give you the ability to run PyTorch code without having to
 actually do computation through tensors allocated on a `meta` device.
 Because the device is `meta`, meta tensors do not model device propagation.
 FakeTensor extends MetaTensors to also carry an additional `fake_device`
 which tracks devices that would have been used.



[ghstack-poisoned]
eellison pushed a commit to pytorch/pytorch that referenced this pull request May 23, 2022
This is just copying over the PR from albanD/subclass_zoo#32, and pulling in `BaseTensor`, plus adding one test based on invariants about schemas. 

From code comments
>  Meta tensors give you the ability to run PyTorch code without having to
 actually do computation through tensors allocated on a `meta` device.
 Because the device is `meta`, meta tensors do not model device propagation.
 FakeTensor extends MetaTensors to also carry an additional `fake_device`
 which tracks devices that would have been used.



[ghstack-poisoned]
eellison pushed a commit to pytorch/pytorch that referenced this pull request May 23, 2022
This is just copying over the PR from albanD/subclass_zoo#32, and pulling in `BaseTensor`, plus adding one test based on invariants about schemas. 

From code comments
>  Meta tensors give you the ability to run PyTorch code without having to
 actually do computation through tensors allocated on a `meta` device.
 Because the device is `meta`, meta tensors do not model device propagation.
 FakeTensor extends MetaTensors to also carry an additional `fake_device`
 which tracks devices that would have been used.



[ghstack-poisoned]
eellison pushed a commit to pytorch/pytorch that referenced this pull request May 24, 2022
This is just copying over the PR from albanD/subclass_zoo#32, and pulling in `BaseTensor`, plus adding one test based on invariants about schemas. 

From code comments
>  Meta tensors give you the ability to run PyTorch code without having to
 actually do computation through tensors allocated on a `meta` device.
 Because the device is `meta`, meta tensors do not model device propagation.
 FakeTensor extends MetaTensors to also carry an additional `fake_device`
 which tracks devices that would have been used.



[ghstack-poisoned]
eellison pushed a commit to pytorch/pytorch that referenced this pull request May 24, 2022
This is just copying over the PR from albanD/subclass_zoo#32, and pulling in `BaseTensor`, plus adding one test based on invariants about schemas. 

From code comments
>  Meta tensors give you the ability to run PyTorch code without having to
 actually do computation through tensors allocated on a `meta` device.
 Because the device is `meta`, meta tensors do not model device propagation.
 FakeTensor extends MetaTensors to also carry an additional `fake_device`
 which tracks devices that would have been used.



[ghstack-poisoned]
eellison pushed a commit to pytorch/pytorch that referenced this pull request May 24, 2022
This is just copying over the PR from albanD/subclass_zoo#32, and pulling in `BaseTensor`, plus adding one test based on invariants about schemas. 

From code comments
>  Meta tensors give you the ability to run PyTorch code without having to
 actually do computation through tensors allocated on a `meta` device.
 Because the device is `meta`, meta tensors do not model device propagation.
 FakeTensor extends MetaTensors to also carry an additional `fake_device`
 which tracks devices that would have been used.

Differential Revision: [D36618467](https://our.internmc.facebook.com/intern/diff/D36618467)

[ghstack-poisoned]
eellison pushed a commit to pytorch/pytorch that referenced this pull request May 24, 2022
This is just copying over the PR from albanD/subclass_zoo#32, and pulling in `BaseTensor`, plus adding one test based on invariants about schemas. 

From code comments
>  Meta tensors give you the ability to run PyTorch code without having to
 actually do computation through tensors allocated on a `meta` device.
 Because the device is `meta`, meta tensors do not model device propagation.
 FakeTensor extends MetaTensors to also carry an additional `fake_device`
 which tracks devices that would have been used.

Differential Revision: [D36618467](https://our.internmc.facebook.com/intern/diff/D36618467)

[ghstack-poisoned]
eellison pushed a commit to pytorch/pytorch that referenced this pull request May 24, 2022
This is just copying over the PR from albanD/subclass_zoo#32, and pulling in `BaseTensor`, plus adding one test based on invariants about schemas. 

From code comments
>  Meta tensors give you the ability to run PyTorch code without having to
 actually do computation through tensors allocated on a `meta` device.
 Because the device is `meta`, meta tensors do not model device propagation.
 FakeTensor extends MetaTensors to also carry an additional `fake_device`
 which tracks devices that would have been used.

Differential Revision: [D36618467](https://our.internmc.facebook.com/intern/diff/D36618467)

[ghstack-poisoned]
eellison pushed a commit to pytorch/pytorch that referenced this pull request May 24, 2022
This is just copying over the PR from albanD/subclass_zoo#32, and pulling in `BaseTensor`, plus adding one test based on invariants about schemas. 

From code comments
>  Meta tensors give you the ability to run PyTorch code without having to
 actually do computation through tensors allocated on a `meta` device.
 Because the device is `meta`, meta tensors do not model device propagation.
 FakeTensor extends MetaTensors to also carry an additional `fake_device`
 which tracks devices that would have been used.

Differential Revision: [D36618467](https://our.internmc.facebook.com/intern/diff/D36618467)

[ghstack-poisoned]
eellison pushed a commit to pytorch/pytorch that referenced this pull request May 24, 2022
This is just copying over the PR from albanD/subclass_zoo#32, and pulling in `BaseTensor`, plus adding one test based on invariants about schemas. 

From code comments
>  Meta tensors give you the ability to run PyTorch code without having to
 actually do computation through tensors allocated on a `meta` device.
 Because the device is `meta`, meta tensors do not model device propagation.
 FakeTensor extends MetaTensors to also carry an additional `fake_device`
 which tracks devices that would have been used.

Differential Revision: [D36618467](https://our.internmc.facebook.com/intern/diff/D36618467)

[ghstack-poisoned]
eellison pushed a commit to pytorch/pytorch that referenced this pull request May 24, 2022
This is just copying over the PR from albanD/subclass_zoo#32, and pulling in `BaseTensor`, plus adding one test based on invariants about schemas. 

From code comments
>  Meta tensors give you the ability to run PyTorch code without having to
 actually do computation through tensors allocated on a `meta` device.
 Because the device is `meta`, meta tensors do not model device propagation.
 FakeTensor extends MetaTensors to also carry an additional `fake_device`
 which tracks devices that would have been used.

Differential Revision: [D36618467](https://our.internmc.facebook.com/intern/diff/D36618467)

[ghstack-poisoned]
eellison pushed a commit to pytorch/pytorch that referenced this pull request May 25, 2022
This is just copying over the PR from albanD/subclass_zoo#32, and pulling in `BaseTensor`, plus adding one test based on invariants about schemas. 

From code comments
>  Meta tensors give you the ability to run PyTorch code without having to
 actually do computation through tensors allocated on a `meta` device.
 Because the device is `meta`, meta tensors do not model device propagation.
 FakeTensor extends MetaTensors to also carry an additional `fake_device`
 which tracks devices that would have been used.

Differential Revision: [D36618467](https://our.internmc.facebook.com/intern/diff/D36618467)

[ghstack-poisoned]
eellison pushed a commit to pytorch/pytorch that referenced this pull request May 25, 2022
This is just copying over the PR from albanD/subclass_zoo#32, and pulling in `BaseTensor`, plus adding one test based on invariants about schemas. 

From code comments
>  Meta tensors give you the ability to run PyTorch code without having to
 actually do computation through tensors allocated on a `meta` device.
 Because the device is `meta`, meta tensors do not model device propagation.
 FakeTensor extends MetaTensors to also carry an additional `fake_device`
 which tracks devices that would have been used.

Differential Revision: [D36618467](https://our.internmc.facebook.com/intern/diff/D36618467)

[ghstack-poisoned]
eellison pushed a commit to pytorch/pytorch that referenced this pull request May 31, 2022
This is just copying over the PR from albanD/subclass_zoo#32, and pulling in `BaseTensor`, plus adding one test based on invariants about schemas. 

From code comments
>  Meta tensors give you the ability to run PyTorch code without having to
 actually do computation through tensors allocated on a `meta` device.
 Because the device is `meta`, meta tensors do not model device propagation.
 FakeTensor extends MetaTensors to also carry an additional `fake_device`
 which tracks devices that would have been used.

Differential Revision: [D36618467](https://our.internmc.facebook.com/intern/diff/D36618467)

[ghstack-poisoned]
eellison pushed a commit to pytorch/pytorch that referenced this pull request May 31, 2022
This is just copying over the PR from albanD/subclass_zoo#32, and pulling in `BaseTensor`, plus adding one test based on invariants about schemas. 

From code comments
>  Meta tensors give you the ability to run PyTorch code without having to
 actually do computation through tensors allocated on a `meta` device.
 Because the device is `meta`, meta tensors do not model device propagation.
 FakeTensor extends MetaTensors to also carry an additional `fake_device`
 which tracks devices that would have been used.

Differential Revision: [D36618467](https://our.internmc.facebook.com/intern/diff/D36618467)

[ghstack-poisoned]
eellison pushed a commit to pytorch/pytorch that referenced this pull request May 31, 2022
This is just copying over the PR from albanD/subclass_zoo#32, and pulling in `BaseTensor`, plus adding one test based on invariants about schemas. 

From code comments
>  Meta tensors give you the ability to run PyTorch code without having to
 actually do computation through tensors allocated on a `meta` device.
 Because the device is `meta`, meta tensors do not model device propagation.
 FakeTensor extends MetaTensors to also carry an additional `fake_device`
 which tracks devices that would have been used.



[ghstack-poisoned]
eellison pushed a commit to pytorch/pytorch that referenced this pull request May 31, 2022
This is just copying over the PR from albanD/subclass_zoo#32, and pulling in `BaseTensor`, plus adding one test based on invariants about schemas. 

From code comments
>  Meta tensors give you the ability to run PyTorch code without having to
 actually do computation through tensors allocated on a `meta` device.
 Because the device is `meta`, meta tensors do not model device propagation.
 FakeTensor extends MetaTensors to also carry an additional `fake_device`
 which tracks devices that would have been used.



[ghstack-poisoned]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants