-
Notifications
You must be signed in to change notification settings - Fork 17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement moving to device
and changing dtype
#209
Implement moving to device
and changing dtype
#209
Conversation
Can lines like factory_kwargs = {"device": device, "dtype": dtype} then be removed? |
Should In general, should Should Also what about |
In |
@cr-xu please review carefully. This seems too easy. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me. Since all the tests are passing let's go with it. I guess we'll see in practice if anything breaks.
Description
Implements that elements and beams can easily be moved to different devices and their beam types can be changes, like you would expect from a normal
torch.nn.Module
.Motivation and Context
This is normal for
torch.nn.Module
and would make working with different devices and dtypes in Cheetah much easier. Closes #113.Types of changes
Checklist
flake8
(required).pytest
tests pass (required).pytest
on a machine with a CUDA GPU and made sure all tests pass (required).Note: We are using a maximum length of 88 characters per line