Skip to content

Commit d39d70c

Browse files
committed
release
0 parents  commit d39d70c

20 files changed

+1561
-0
lines changed

.gitignore

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
*.egg-info
2+
__pycache__
3+
*.ipynb
4+
output
5+
build

LICENSE

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2025 Yuchen Lin
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

README.md

+87
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
# A PyTorch Implementation of MLS-MPM (Moving Least Squares Material Point Method)
2+
This repository provides a PyTorch implementation of the MLS-MPM (Moving Least Squares Material Point Method). The algorithm is implemented using **a few lines of tensor operations in PyTorch**, making it naturally differentiable and optimized for GPU acceleration.
3+
The code is vectorized without any explicit loops, which makes it efficient for large-scale simulations.
4+
5+
[Gradient Checkpointing](https://pytorch.org/docs/stable/checkpoint.html) is highly recommended when integrating the MLS-MPM into a trainable deep learning framework. See [OmniPhysGS](https://github.com/wgsxm/omniphysgs) (ICLR 2025) for an example.
6+
## Installation
7+
### From source
8+
```bash
9+
git clone https://github.com/wgsxm/MPM-PyTorch.git
10+
cd MPM-PyTorch
11+
pip install .
12+
```
13+
## Quick Start
14+
Run the following code to try a simple example of MLS-MPM simulation. The code simulates a 3D elastic jelly falling onto a rigid floor. By default, the code will produce a video of the simulation in the `output` directory.
15+
```bash
16+
python simulate.py --config examples/jelly.yaml
17+
```
18+
<img src="assets/jelly.gif" width=400>
19+
<img src="assets/sand.gif" width=400>
20+
21+
## Usage
22+
Refer to `simulate.py` for more details.
23+
```python
24+
from mpm_pytorch import MPMSolver, set_boundary_conditions
25+
from mpm_pytorch.constitutive_models import *
26+
particles = ... # Particle positions to simulate
27+
# Create a MPM solver with default parameters
28+
mpm_solver = MPMSolver(particles)
29+
# Set boundary conditions (optional)
30+
boundary_conditions = ... # Refer to example configs
31+
set_boundary_conditions(mpm_solver, boundary_conditions)
32+
# Create constitutive models
33+
elasticity = CorotatedElasticity(E=2e6)
34+
plasicity = IdentityPlasticity()
35+
# Init particle state
36+
x = particles
37+
v = torch.zeros_like(x)
38+
C = torch.zeros((x.shape[0], 3, 3), device=x.device)
39+
F = torch.eye(3, device=x.device).unsqueeze(0).repeat(x.shape[0], 1, 1)
40+
# Start simulation for T steps
41+
for i in range(T):
42+
# Update stress
43+
stress = elasticity(F)
44+
# Particle to grid, grid update, grid to particle
45+
grid_update = mpm_solver(x, v, C, F, stress)
46+
# Plasticity correction
47+
F = plasticity(F)
48+
```
49+
50+
## Fast Batched SVD
51+
Batched SVD is a common operation in constitutive models. Original PyTorch SVD is not optimized for batched computation. We adopt a `warp-lang` implementation of differentiable batched SVD in `mpm_pytorch.constitutive_models.warp_svd` from [NCLaw](https://github.com/PingchuanMa/NCLaw/tree/main/nclaw/warp). It can run on CPU or GPU with CUDA.
52+
53+
The result of the decomposed matrices is not guaranteed to be the same as the original PyTorch SVD. You can also use the original PyTorch SVD or other batched SVD implementations if there is any environment conflict (the version of `warp-lang` requires `numpy<2`).
54+
55+
We provide a script to benchmark the batched SVD implementation.
56+
```bash
57+
python benchmark_svd.py
58+
```
59+
If you encounter error as `ModuleNotFoundError: No module named 'imp'` when using higher version of Python, simply change the import statement from `import imp` to `import importlib as imp`.
60+
61+
## Citation
62+
If you find our work helpful, please consider citing:
63+
```
64+
@inproceedings{
65+
lin2025omniphysgs,
66+
title={OmniPhys{GS}: 3D Constitutive Gaussians for General Physics-Based Dynamics Generation},
67+
author={Yuchen Lin and Chenguo Lin and Jianjin Xu and Yadong MU},
68+
booktitle={The Thirteenth International Conference on Learning Representations},
69+
year={2025},
70+
}
71+
```
72+
```
73+
@article{hu2018moving,
74+
title={A moving least squares material point method with displacement discontinuity and two-way rigid body coupling},
75+
author={Hu, Yuanming and Fang, Yu and Ge, Ziheng and Qu, Ziyin and Zhu, Yixin and Pradhana, Andre and Jiang, Chenfanfu},
76+
journal={ACM Transactions on Graphics (TOG)},
77+
year={2018},
78+
}
79+
```
80+
```
81+
@article{stomakhin2013material,
82+
title={A material point method for snow simulation},
83+
author={Stomakhin, Alexey and Schroeder, Craig and Chai, Lawrence and Teran, Joseph and Selle, Andrew},
84+
journal={ACM Transactions on Graphics (TOG)},
85+
year={2013},
86+
}
87+
```

assets/jelly.gif

2.59 MB
Loading

assets/sand.gif

5.05 MB
Loading

benchmark_svd.py

+111
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import time
2+
import torch
3+
from tqdm import tqdm
4+
5+
from mpm_pytorch.constitutive_models.warp_svd import SVD
6+
7+
wp_svd = SVD()
8+
wp_svd(torch.randn(1, 3, 3))
9+
def warp_svd(x):
10+
return wp_svd(x)
11+
12+
def torch_svd(x):
13+
U, s, Vh = torch.svd(x)
14+
return U, s, Vh.transpose(-2, -1)
15+
16+
if __name__ == '__main__':
17+
test_time = 10
18+
n = 100000
19+
atol = 1e-5
20+
21+
print("\nTest correctness")
22+
Is = torch.eye(3).unsqueeze(0).repeat(n, 1, 1)
23+
for _ in tqdm(range(test_time), desc="Test orthogonality of U and V"):
24+
x = torch.randn(n, 3, 3)
25+
U, s, V = warp_svd(x)
26+
UU = U @ U.transpose(-2, -1)
27+
VV = V @ V.transpose(-2, -1)
28+
assert torch.allclose(UU, Is, atol=1e-5)
29+
assert torch.allclose(VV, Is, atol=1e-5)
30+
bar = tqdm(range(test_time), desc="Test correctness of decomposition")
31+
for _ in bar:
32+
x = torch.randn(n, 3, 3)
33+
U, s, V = warp_svd(x)
34+
warp_pred = U @ torch.diag_embed(s) @ V
35+
warp_mae = torch.abs(warp_pred - x).mean()
36+
U, s, V = torch_svd(x)
37+
torch_pred = U @ torch.diag_embed(s) @ V
38+
torch_mae = torch.abs(torch_pred - x).mean()
39+
bar.set_postfix(warp_mae=warp_mae.item(), torch_mae=torch_mae.item())
40+
assert warp_mae < atol
41+
42+
print("\nTest differentiability")
43+
bar = tqdm(range(test_time), desc="Test differentiability")
44+
for _ in bar:
45+
x = torch.randn(n, 3, 3, requires_grad=True)
46+
U, s, V = warp_svd(x)
47+
warp_pred = U @ torch.diag_embed(s) @ V
48+
warp_pred.sum().backward()
49+
assert x.grad is not None
50+
warp_mae = torch.abs(x.grad - torch.ones_like(x)).mean()
51+
x.grad = None
52+
U, s, V = torch_svd(x)
53+
torch_pred = U @ torch.diag_embed(s) @ V
54+
torch_pred.sum().backward()
55+
torch_mae = torch.abs(x.grad - torch.ones_like(x)).mean()
56+
bar.set_postfix(warp_mae=warp_mae.item(), torch_mae=torch_mae.item())
57+
assert warp_mae < atol
58+
59+
print("\nTest speed")
60+
x = torch.randn(n, 3, 3)
61+
start = time.time()
62+
for _ in tqdm(range(test_time), desc="Test warp_svd forward (CPU)"):
63+
U, s, V = warp_svd(x)
64+
end = time.time()
65+
print("warp_svd: ", (end - start) / test_time, "seconds")
66+
start = time.time()
67+
for _ in tqdm(range(test_time), desc="Test torch_svd forward (CPU)"):
68+
U, s, V = torch_svd(x)
69+
end = time.time()
70+
print("torch_svd: ", (end - start) / test_time, "seconds")
71+
72+
print()
73+
x = torch.randn(n, 3, 3, requires_grad=True)
74+
start = time.time()
75+
for _ in tqdm(range(test_time), desc="Test warp_svd forward + backward (CPU)"):
76+
U, s, V = warp_svd(x)
77+
warp_pred = U @ torch.diag_embed(s) @ V
78+
warp_pred.sum().backward()
79+
end = time.time()
80+
print("warp_svd: ", (end - start) / test_time, "seconds")
81+
x.grad = None
82+
start = time.time()
83+
for _ in tqdm(range(test_time), desc="Test torch_svd forward + backward (CPU)"):
84+
U, s, V = torch_svd(x)
85+
torch_pred = U @ torch.diag_embed(s) @ V
86+
torch_pred.sum().backward()
87+
end = time.time()
88+
print("torch_svd: ", (end - start) / test_time, "seconds")
89+
90+
if torch.cuda.is_available():
91+
print()
92+
x = torch.randn(n, 3, 3).cuda().requires_grad_()
93+
start = time.time()
94+
for _ in tqdm(range(test_time), desc="Test warp_svd forward + backward (GPU)"):
95+
U, s, V = warp_svd(x)
96+
warp_pred = U @ torch.diag_embed(s) @ V
97+
warp_pred.sum().backward()
98+
end = time.time()
99+
print("warp_svd: ", (end - start) / test_time, "seconds")
100+
x.grad = None
101+
start = time.time()
102+
for _ in tqdm(range(test_time), desc="Test torch_svd forward + backward (GPU)"):
103+
U, s, V = torch_svd(x)
104+
torch_pred = U @ torch.diag_embed(s) @ V
105+
torch_pred.sum().backward()
106+
end = time.time()
107+
print("torch_svd: ", (end - start) / test_time, "seconds")
108+
109+
110+
111+

examples/jelly.yaml

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
output_dir: './output'
2+
tag: 'jelly'
3+
4+
material:
5+
elasticity: 'CorotatedElasticity'
6+
plasticity: 'IdentityPlasticity'
7+
color: 'blue'
8+
9+
sim:
10+
num_frames: 150
11+
steps_per_frame: 10
12+
initial_velocity: [0.0, 0.0, -0.5]
13+
boundary_conditions:
14+
- type: 'surface_collider'
15+
point: [1.0, 1.0, 0.02]
16+
normal: [0.0, 0.0, 1.0]
17+
surface: 'sticky'
18+
friction: 0.0
19+
start_time: 0.0
20+
end_time: 1e3

examples/sand.yaml

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
output_dir: './output'
2+
tag: 'sand'
3+
4+
material:
5+
elasticity: 'StVKElasticity'
6+
plasticity: 'DruckerPragerPlasticity'
7+
color: 'orange'
8+
9+
sim:
10+
num_frames: 150
11+
steps_per_frame: 10
12+
initial_velocity: [0.0, 0.0, -0.5]
13+
boundary_conditions:
14+
- type: 'surface_collider'
15+
point: [1.0, 1.0, 0.02]
16+
normal: [0.0, 0.0, 1.0]
17+
surface: 'sticky'
18+
friction: 0.0
19+
start_time: 0.0
20+
end_time: 1e3

mpm_pytorch/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .mpm_solver import MPMSolver
2+
from .boundary_condition import set_boundary_conditions
3+
from .constitutive_models import get_constitutive

0 commit comments

Comments
 (0)