|
134 | 134 | " patch_size: int | None = None, # Render patches of the DRR in series\n", |
135 | 135 | " renderer: str = \"siddon\", # Rendering backend, either \"siddon\" or \"trilinear\"\n", |
136 | 136 | " persistent: bool = True, # Set persistent value in `torch.nn.Module.register_buffer`\n", |
| 137 | + " compile_renderer: bool = False, # Compile the renderer for performance boost\n", |
| 138 | + " checkpoint_gradients: bool = False, # Checkpoint gradients to improve memory usage\n", |
137 | 139 | " **renderer_kwargs, # Kwargs for the renderer\n", |
138 | 140 | " ):\n", |
139 | 141 | " super().__init__()\n", |
|
191 | 193 | " raise ValueError(\n", |
192 | 194 | " f\"renderer must be 'siddon' or 'trilinear', not {renderer}\"\n", |
193 | 195 | " )\n", |
| 196 | + " if compile_renderer:\n", |
| 197 | + " self.renderer = torch.compile(self.renderer, mode=\"default\")\n", |
194 | 198 | " self.reshape = reshape\n", |
195 | 199 | " self.patch_size = patch_size\n", |
| 200 | + " self.checkpoint_gradients = checkpoint_gradients\n", |
196 | 201 | "\n", |
197 | 202 | " def reshape_transform(self, img, batch_size):\n", |
198 | 203 | " if self.reshape:\n", |
|
260 | 265 | "outputs": [], |
261 | 266 | "source": [ |
262 | 267 | "#| export\n", |
| 268 | + "from torch.utils.checkpoint import checkpoint\n", |
| 269 | + "\n", |
263 | 270 | "from diffdrr.pose import RigidTransform, convert\n", |
264 | 271 | "\n", |
265 | 272 | "\n", |
|
282 | 289 | "\n", |
283 | 290 | " # Create the source / target points and render the image\n", |
284 | 291 | " source, target = self.detector(pose, calibration)\n", |
285 | | - " img = self.render(self.density, source, target, mask_to_channels, **kwargs)\n", |
| 292 | + "\n", |
| 293 | + " if self.checkpoint_gradients:\n", |
| 294 | + " img = checkpoint(\n", |
| 295 | + " self.render,\n", |
| 296 | + " self.density,\n", |
| 297 | + " source,\n", |
| 298 | + " target,\n", |
| 299 | + " mask_to_channels,\n", |
| 300 | + " **kwargs,\n", |
| 301 | + " use_reentrant=False,\n", |
| 302 | + " )\n", |
| 303 | + " else:\n", |
| 304 | + " img = self.render(self.density, source, target, mask_to_channels, **kwargs)\n", |
286 | 305 | " return self.reshape_transform(img, batch_size=len(pose))\n", |
287 | 306 | "\n", |
288 | 307 | "\n", |
|
408 | 427 | " x[..., 1] = self.detector.height - x[..., 1]\n", |
409 | 428 | " if self.detector.reverse_x_axis:\n", |
410 | 429 | " x[..., 0] = self.detector.width - x[..., 0]\n", |
411 | | - " \n", |
| 430 | + "\n", |
412 | 431 | " return x[..., :2]" |
413 | 432 | ] |
414 | 433 | }, |
|
0 commit comments