---
description: "PyTorch: neural networks, model training, GPU optimization"
globs: ["*.py", "requirements.txt", "pyproject.toml"]
alwaysApply: true
---
# PyTorch Rules
## Model Architecture
- Use torch.nn.Module base class for all models
- Initialize layers in __init__, define forward pass in forward()
- Use meaningful layer names for debugging and model inspection
- Initialize weights explicitly (e.g., with torch.nn.init) rather than relying on defaults when training stability matters
- Use torch.nn.Sequential for simple sequential models
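The architecture rules above can be sketched as follows. This is a minimal illustration, not a prescribed design: `SmallClassifier`, its layer sizes, and the choice of Kaiming initialization are assumptions made for the example.

```python
import torch
from torch import nn


class SmallClassifier(nn.Module):
    """Hypothetical two-stage classifier used only for illustration."""

    def __init__(self, in_features: int = 16, hidden: int = 32, num_classes: int = 4):
        super().__init__()
        # Descriptive attribute names appear in state_dict keys and print(model).
        self.encoder = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
        )
        self.classifier_head = nn.Linear(hidden, num_classes)
        # Apply the initializer to this module and every submodule.
        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(module: nn.Module) -> None:
        # Kaiming init is one common choice for ReLU networks (an assumption here).
        if isinstance(module, nn.Linear):
            nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
            nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier_head(self.encoder(x))


model = SmallClassifier()
logits = model(torch.randn(8, 16))  # batch of 8 samples, 16 features each
```

Because the layers are named attributes, `model.state_dict()` exposes keys such as `classifier_head.weight`, which makes checkpoints and debugging output easier to read.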
## Data Handling
- Use torch.utils.data.Dataset for custom datasets
- Implement __len__ and __getitem__ methods properly
- Use DataLoader with appropriate batch_size and num_workers
- Apply transforms consistently using torchvision.transforms
- Handle data augmentation in dataset transform pipeline
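A minimal sketch of the data-handling rules, assuming a synthetic in-memory dataset. `ToyDataset` and `scale_to_unit_range` are hypothetical names; in an image pipeline a `torchvision.transforms.Compose` object would typically be passed in the same `transform` slot.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical dataset of random feature vectors and integer labels."""

    def __init__(self, num_samples: int = 100, transform=None):
        generator = torch.Generator().manual_seed(0)
        self.features = torch.randn(num_samples, 16, generator=generator)
        self.labels = torch.randint(0, 4, (num_samples,), generator=generator)
        self.transform = transform

    def __len__(self) -> int:
        return len(self.features)

    def __getitem__(self, index: int):
        x = self.features[index]
        # Transforms (including augmentation) run per sample, inside the dataset.
        if self.transform is not None:
            x = self.transform(x)
        return x, self.labels[index]


def scale_to_unit_range(x: torch.Tensor) -> torch.Tensor:
    # Stand-in transform: rescale each sample to [0, 1].
    return (x - x.min()) / (x.max() - x.min() + 1e-8)


dataset = ToyDataset(transform=scale_to_unit_range)
# num_workers=0 keeps the sketch portable; raise it for real, I/O-bound datasets.
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)
batch_x, batch_y = next(iter(loader))
```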
## Training Loop
- Move model and data to same device (CPU/GPU)
- Call model.eval() and wrap validation and inference in torch.no_grad()
- Clear gradients with optimizer.zero_grad() before backward pass
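- Use mixed precision training (torch.autocast with a GradScaler; exposed under torch.cuda.amp in older releases) for efficiency
- Checkpoint both model.state_dict() and optimizer.state_dict(), plus the current epoch or step, so training can resume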
## Memory Management
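- Use torch.cuda.empty_cache() to return cached, unused GPU memory to the driver when needed; it does not free memory held by live tensors
- Use in-place operations (tensor.add_() vs tensor.add()) only where they do not overwrite values autograd needs for the backward pass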
- Use gradient accumulation for large effective batch sizes
- Release references to GPU tensors in exception handlers (e.g., after an out-of-memory error) before retrying
- Monitor GPU memory usage with torch.cuda.memory_stats()
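Gradient accumulation and memory monitoring might look like the sketch below. The micro-batch size, `accumulation_steps`, and the model are illustrative assumptions. The CUDA-only calls are guarded so the example also runs on a CPU-only machine.

```python
import torch
from torch import nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(16, 4).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

accumulation_steps = 4  # effective batch size = 4 micro-batches of 8 = 32
micro_batches = [(torch.randn(8, 16), torch.randint(0, 4, (8,))) for _ in range(8)]

optimizer_steps = 0
optimizer.zero_grad()
for i, (x, y) in enumerate(micro_batches, start=1):
    loss = loss_fn(model(x.to(device)), y.to(device))
    # Divide so the accumulated gradient matches the mean over the large batch.
    (loss / accumulation_steps).backward()
    if i % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        optimizer_steps += 1

if device.type == "cuda":
    stats = torch.cuda.memory_stats()
    peak_bytes = stats["allocated_bytes.all.peak"]  # peak bytes held by tensors
    torch.cuda.empty_cache()  # hand cached, unused blocks back to the driver
else:
    peak_bytes = 0  # no CUDA allocator to query on CPU
```

Eight micro-batches with `accumulation_steps = 4` produce two optimizer steps, each computed from gradients summed over four backward passes.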
## Model Deployment
- Use torch.jit.script or torch.jit.trace for production models
- Save models with torch.save(model.state_dict(), path)
- Use torch.hub for model sharing and distribution
- Implement proper error handling for device compatibility
- Test models on target deployment hardware
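As a rough sketch of the deployment rules, the example below saves a `state_dict`, exports a traced module with `torch.jit.trace`, and reloads it onto whichever device is available. The model, file names, and temporary directory are placeholders. Tracing records only the operations executed for the example input, so models with data-dependent control flow may need `torch.jit.script` instead.

```python
import os
import tempfile

import torch
from torch import nn

model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4)).eval()
example_input = torch.randn(1, 16)
out_dir = tempfile.mkdtemp()

# Weights only: portable across code versions that define the same architecture.
weights_path = os.path.join(out_dir, "weights.pt")
torch.save(model.state_dict(), weights_path)

# Traced module: self-contained artifact that does not need the Python class.
traced = torch.jit.trace(model, example_input)
traced_path = os.path.join(out_dir, "traced.pt")
traced.save(traced_path)

# Fall back to CPU when the target machine has no GPU.
target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loaded = torch.jit.load(traced_path, map_location=target_device)

with torch.no_grad():
    expected = model(example_input)
    actual = loaded(example_input.to(target_device)).cpu()
```

Comparing `expected` and `actual` on the deployment hardware is a cheap smoke test that the exported artifact behaves like the original model.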
## Best Practices
- Use torch.manual_seed() (and seed Python's random and NumPy if used) for reproducible results
- Validate tensor shapes throughout the pipeline
- Use appropriate loss functions and optimizers for your task
- Implement learning rate scheduling for better convergence
- Use tensorboard or wandb for training visualization
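Several of these practices fit into one small sketch: seeding, a shape check, and a learning rate schedule. `seed_everything` and `check_shape` are hypothetical helpers, and `StepLR` with `step_size=2, gamma=0.5` is just one possible schedule picked for illustration.

```python
import random

import torch
from torch import nn


def seed_everything(seed: int) -> None:
    # Seeds Python's RNG and PyTorch's RNGs; add numpy.random.seed if NumPy is used.
    random.seed(seed)
    torch.manual_seed(seed)


def check_shape(tensor: torch.Tensor, expected: tuple, name: str) -> None:
    # Fail early with a readable message instead of a cryptic downstream error.
    if tuple(tensor.shape) != tuple(expected):
        raise ValueError(f"{name}: expected shape {expected}, got {tuple(tensor.shape)}")


seed_everything(0)
first = torch.randn(3)
seed_everything(0)
second = torch.randn(3)  # identical to `first` because the seed was reset

model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Halve the learning rate every 2 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

lrs = []
for epoch in range(4):
    logits = model(torch.randn(8, 16))
    check_shape(logits, (8, 4), "logits")
    logits.sum().backward()  # placeholder loss for the sketch
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()  # step the scheduler after the optimizer
    lrs.append(scheduler.get_last_lr()[0])
```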