---
description: "TensorFlow: Keras, model training, production deployment"
globs: ["*.py", "requirements.txt", "pyproject.toml"]
alwaysApply: true
---
# TensorFlow Rules
## Model Building
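- Use tf.keras.Sequential for plain layer stacks and tf.keras.Model for everything else
- When subclassing tf.keras.Model, define layers in `__init__` and the forward pass in call()
- Use tf.keras.Input to define the input shape explicitly
- Implement custom layers by subclassing tf.keras.layers.Layer, creating weights in build() and the computation in call()
- Use the functional API for complex model architectures (multiple inputs/outputs, shared layers, branches)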
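A minimal sketch of these patterns, assuming TensorFlow 2.x with tf.keras. `ScaledDense` and `TinyClassifier` are hypothetical names used only for illustration. The same custom layer is reused in a subclassed model and in a functional model whose input shape is declared with tf.keras.Input.

```python
import tensorflow as tf


class ScaledDense(tf.keras.layers.Layer):
    """Hypothetical custom layer: a bias-free projection times a learnable scalar."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # Weights are created lazily, once the last input dimension is known.
        self.w = self.add_weight(
            name="w",
            shape=(input_shape[-1], self.units),
            initializer="glorot_uniform",
            trainable=True,
        )
        self.scale = self.add_weight(name="scale", shape=(), initializer="ones", trainable=True)

    def call(self, inputs):
        return self.scale * tf.matmul(inputs, self.w)


class TinyClassifier(tf.keras.Model):
    """Subclassed model: layers are created in __init__, the forward pass lives in call()."""

    def __init__(self, num_classes=3, **kwargs):
        super().__init__(**kwargs)
        self.hidden = ScaledDense(16)
        self.classifier = tf.keras.layers.Dense(num_classes)

    def call(self, inputs, training=False):
        x = tf.nn.relu(self.hidden(inputs))
        return self.classifier(x)  # raw logits, no softmax


# Functional API: tf.keras.Input declares the input shape up front.
inputs = tf.keras.Input(shape=(8,))
x = ScaledDense(16)(inputs)
x = tf.keras.layers.Activation("relu")(x)
outputs = tf.keras.layers.Dense(3)(x)
functional_model = tf.keras.Model(inputs=inputs, outputs=outputs)

subclassed_model = TinyClassifier()
batch = tf.random.normal((4, 8))
print(tuple(functional_model(batch).shape))  # (4, 3)
print(tuple(subclassed_model(batch).shape))  # (4, 3)
```

The functional model is the better fit when the graph of layers should be inspectable (summary(), plot_model); the subclassed form trades that for arbitrary Python control flow in call().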
## Data Pipeline
- Use tf.data.Dataset for efficient data loading
- Apply transformations with map(), filter(), and batch()
- Use tf.data.AUTOTUNE for num_parallel_calls in map() and for prefetch(); tf.data.experimental.AUTOTUNE is the older alias of the same constant
- Apply data augmentation inside the pipeline (e.g., in map()) so it runs in parallel with training
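A sketch of such a pipeline, assuming random in-memory tensors in place of a real image dataset; the `augment` function and the "drop class 0" filter are illustrative choices, not requirements.

```python
import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE  # let the tf.data runtime tune parallelism and buffer sizes


def augment(image, label):
    # In-pipeline augmentation: runs per element inside map(), off the training step.
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.1)
    return image, label


# Hypothetical in-memory data standing in for a real image dataset.
images = tf.random.uniform((100, 32, 32, 3))
labels = tf.random.uniform((100,), maxval=10, dtype=tf.int32)

dataset = (
    tf.data.Dataset.from_tensor_slices((images, labels))
    .filter(lambda image, label: tf.not_equal(label, 0))  # example filter: drop class 0
    .shuffle(buffer_size=100)
    .map(augment, num_parallel_calls=AUTOTUNE)  # parallel preprocessing
    .batch(16)
    .prefetch(AUTOTUNE)  # overlap data preparation with training
)
```

Filtering before batching keeps batch sizes consistent with the surviving elements, and prefetch() at the end lets the next batch be prepared while the current one trains.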
## Training Configuration
- Enable mixed precision with tf.keras.mixed_precision.set_global_policy('mixed_float16') (constructing a Policy object alone does not apply it), and keep the output layer in float32
- Use appropriate optimizers (Adam, AdamW, SGD) with learning rate schedules
- Implement callbacks for checkpointing, early stopping, and monitoring
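- Use tf.keras.utils.plot_model for architecture visualization (requires pydot and Graphviz)
- Match the loss to the model's output (e.g., from_logits=True when the final layer has no softmax) and choose metrics that reflect the task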
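A minimal training setup, assuming TensorFlow 2.11+ (for tf.keras.optimizers.AdamW). Mixed precision is enabled only when a GPU is visible, the output layer stays float32, AdamW uses a cosine learning-rate schedule, and checkpointing, early stopping and TensorBoard callbacks are attached. The random data, layer sizes and hyperparameters are assumptions for illustration.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# float16 compute mainly pays off on GPUs, so only switch the global policy there.
if tf.config.list_physical_devices("GPU"):
    tf.keras.mixed_precision.set_global_policy("mixed_float16")

inputs = tf.keras.Input(shape=(20,))
x = tf.keras.layers.Dense(64, activation="relu")(inputs)
# Keep the final layer in float32 so the logits and the loss stay numerically stable.
outputs = tf.keras.layers.Dense(5, dtype="float32")(x)
model = tf.keras.Model(inputs, outputs)

schedule = tf.keras.optimizers.schedules.CosineDecay(initial_learning_rate=1e-3, decay_steps=1000)
optimizer = tf.keras.optimizers.AdamW(learning_rate=schedule, weight_decay=1e-4)

model.compile(
    optimizer=optimizer,
    # The last layer has no softmax, so the loss must be told it receives logits.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

log_dir = tempfile.mkdtemp()
# The .weights.h5 suffix works for weights-only checkpoints in both Keras 2 and Keras 3.
ckpt_path = os.path.join(log_dir, "best.weights.h5")
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        ckpt_path, monitor="val_loss", save_best_only=True, save_weights_only=True
    ),
    tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=3, restore_best_weights=True),
    tf.keras.callbacks.TensorBoard(log_dir=os.path.join(log_dir, "tb")),
]

# Hypothetical random data, only to exercise the training loop.
rng = np.random.default_rng(0)
x_train = rng.normal(size=(256, 20)).astype("float32")
y_train = rng.integers(0, 5, size=(256,)).astype("int32")

history = model.fit(
    x_train, y_train,
    validation_split=0.2, epochs=3, batch_size=32,
    callbacks=callbacks, verbose=0,
)
```

With the mixed_float16 policy active, compile() wraps the optimizer for loss scaling, so no extra code is needed for the standard fit() path.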
## Memory & Performance
- Use tf.function decorator for graph compilation
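- Inside tf.function, avoid Python loops over Python values (they are unrolled at trace time); loop over tensors or tf.range so AutoGraph converts them to tf.while_loop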
- Use tf.GradientTape for custom training loops
- Implement gradient clipping for training stability
- Use a tf.distribute.Strategy (e.g., MirroredStrategy for multiple GPUs on one machine) and create the model and optimizer inside strategy.scope()
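A sketch of a custom training step under these rules, assuming TensorFlow 2.x: tf.function compiles the step, tf.GradientTape records the forward pass, and tf.clip_by_global_norm rescales gradients before they are applied. The `sum_of_squares` helper is a hypothetical example of a tensor-driven loop, and the final part shows the usual MirroredStrategy pattern, which falls back to a single CPU replica when no GPU is visible.

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(10,)),
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(1),
])
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = tf.keras.losses.MeanSquaredError()


@tf.function  # traced once per input signature, then executed as a graph
def train_step(x, y):
    with tf.GradientTape() as tape:
        predictions = model(x, training=True)
        loss = loss_fn(y, predictions)
    grads = tape.gradient(loss, model.trainable_variables)
    # Rescale all gradients together if their global norm exceeds 1.0.
    grads, _ = tf.clip_by_global_norm(grads, clip_norm=1.0)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss


@tf.function
def sum_of_squares(n):
    # Iterating over tf.range (not Python range) makes AutoGraph emit a tf.while_loop.
    total = tf.constant(0, dtype=tf.int32)
    for i in tf.range(n):
        total += i * i
    return total


# Hypothetical random regression data.
x = tf.random.normal((64, 10))
y = tf.random.normal((64, 1))
losses = [float(train_step(x, y)) for _ in range(5)]

# Data-parallel training: variables must be created inside the strategy scope.
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    dist_model = tf.keras.Sequential([tf.keras.Input(shape=(10,)), tf.keras.layers.Dense(1)])
    dist_model.compile(optimizer="adam", loss="mse")
dist_model.fit(x, y, epochs=1, batch_size=16, verbose=0)
```

Clipping by global norm preserves the relative direction of all gradients, unlike clipping each tensor independently.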
## Model Persistence
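- Save complete models (architecture, weights, optimizer state) with model.save(), using the native .keras format in recent TensorFlow versions
- Export a SavedModel (model.export() in TF 2.13+, or tf.saved_model.save()) for production deployment
- Export models to TensorFlow Lite for mobile/edge deployment
- Version exported models in numbered subdirectories (e.g., my_model/1, my_model/2), the layout TensorFlow Serving expects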
- Use TensorFlow Serving for production inference
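A sketch of the three persistence paths, assuming TensorFlow 2.13+ (for Model.export). The directory name `my_model` and version `1` are hypothetical; everything is written to a temporary directory.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(2),
])
model.compile(optimizer="adam", loss="mse")

base_dir = tempfile.mkdtemp()

# 1. Full Keras model (architecture, weights, optimizer state) in the native .keras format.
keras_path = os.path.join(base_dir, "model.keras")
model.save(keras_path)
restored = tf.keras.models.load_model(keras_path)

# 2. SavedModel for TensorFlow Serving, which loads numbered version subdirectories.
#    "my_model" and version "1" are hypothetical names.
export_path = os.path.join(base_dir, "my_model", "1")
model.export(export_path)

# 3. TensorFlow Lite flatbuffer for mobile/edge, converted from the exported SavedModel.
converter = tf.lite.TFLiteConverter.from_saved_model(export_path)
tflite_bytes = converter.convert()
with open(os.path.join(base_dir, "model.tflite"), "wb") as f:
    f.write(tflite_bytes)

# The reloaded Keras model should reproduce the original predictions.
sample = np.ones((1, 4), dtype="float32")
np.testing.assert_allclose(model(sample), restored(sample), rtol=1e-6)
```

With this layout, a new export goes to my_model/2 alongside the old one; by default TensorFlow Serving serves the highest version number it finds under the model's base path.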
## Best Practices
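- Set random seeds for reproducibility: tf.keras.utils.set_random_seed() seeds Python, NumPy, and TensorFlow at once (tf.random.set_seed() covers only TensorFlow)
- Call tf.keras.backend.clear_session() to reset Keras global state when building many models in one process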
- Implement proper input validation and preprocessing
- Use TensorBoard for training visualization and debugging
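- Enable GPU memory growth (tf.config.experimental.set_memory_growth) so TensorFlow allocates GPU memory on demand instead of reserving it all at startup; it does not prevent OOM errors from models that are simply too large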
- Use tf.debugging assertions for runtime validation
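A sketch combining these practices, assuming TensorFlow 2.7+ (for tf.keras.utils.set_random_seed). `validate_batch` is a hypothetical helper, and the loop over `units` stands in for a hyperparameter sweep.

```python
import tensorflow as tf

# Memory growth must be set before any GPU is initialized; it only stops TensorFlow
# from reserving all GPU memory up front.
for gpu in tf.config.list_physical_devices("GPU"):
    tf.config.experimental.set_memory_growth(gpu, True)

# Seeds Python's random module, NumPy, and TensorFlow in one call.
tf.keras.utils.set_random_seed(42)
a = tf.random.uniform((3,))
tf.keras.utils.set_random_seed(42)
b = tf.random.uniform((3,))  # re-seeding resets TensorFlow's internal op-seed counter


def validate_batch(x):
    # Fail fast on malformed input instead of propagating NaNs into training.
    tf.debugging.assert_rank(x, 2, message="expected a (batch, features) tensor")
    tf.debugging.assert_all_finite(x, message="input contains NaN or Inf")
    return x


# Hypothetical sweep: clear Keras global state before building each new model.
param_counts = []
for units in (8, 16):
    tf.keras.backend.clear_session()
    model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(units)])
    model(validate_batch(tf.ones((2, 3))))
    param_counts.append(model.count_params())

print(param_counts)  # [32, 64]
```

Without clear_session(), each new model in the loop would keep adding to Keras's global state (layer name counters and graph resources), which grows memory use over long sweeps.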