diff --git a/.gitattributes b/.gitattributes index 68e9101..d036ae4 100644 --- a/.gitattributes +++ b/.gitattributes @@ -4,3 +4,4 @@ outputs/train/aloha_sim_transfer_cube_human/model_10000.safetensors filter=lfs diff=lfs merge=lfs -text resnet18-f37072fd.safetensors filter=lfs diff=lfs merge=lfs -text outputs/train/act_aloha_sim_transfer_cube_human/model_10000.safetensors filter=lfs diff=lfs merge=lfs -text +*.ipynb filter=lfs diff=lfs merge=lfs -text diff --git a/.ipynb_checkpoints/modeling_act-checkpoint.ipynb b/.ipynb_checkpoints/modeling_act-checkpoint.ipynb index 0760951..af94033 100644 --- a/.ipynb_checkpoints/modeling_act-checkpoint.ipynb +++ b/.ipynb_checkpoints/modeling_act-checkpoint.ipynb @@ -1,1738 +1,3 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "0441978e-0b8a-4dc5-9300-f04249e4e2e3", - "metadata": {}, - "source": [] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "0c941a80-30a4-4680-9f76-5ac76c67bae1", - "metadata": {}, - "outputs": [], - "source": [ - "import math\n", - "from collections import deque\n", - "from itertools import chain\n", - "from typing import Callable\n", - "\n", - "import numpy as np\n", - "import tinygrad\n", - "from tinygrad import Tensor, nn, dtypes" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "88efb4dc-05e4-4ec7-8c46-59657e281141", - "metadata": {}, - "outputs": [], - "source": [ - "def get_activation_fn(activation: str) -> Callable:\n", - " \"\"\"Return an activation function given a string.\"\"\"\n", - " if activation == \"relu\":\n", - " return Tensor.relu\n", - " if activation == \"gelu\":\n", - " return Tensor.gelu\n", - " raise RuntimeError(f\"activation should be relu/gelu/glu, not {activation}.\")" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "e6ce0e70-2b63-4240-878b-215a80e856a4", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "get_activation_fn('relu')" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "aac9f722-55ce-4bca-9008-cf0f68152dd6", - "metadata": {}, - "outputs": [], - "source": [ - "class ACTSinusoidalPositionEmbedding2d:\n", - " \"\"\"2D sinusoidal positional embeddings similar to what's presented in Attention Is All You Need.\n", - "\n", - " The variation is that the position indices are normalized in [0, 2π] (not quite: the lower bound is 1/H\n", - " for the vertical direction, and 1/W for the horizontal direction.\n", - " \"\"\"\n", - "\n", - " def __init__(self, dimension: int):\n", - " \"\"\"\n", - " Args:\n", - " dimension: The desired dimension of the embeddings.\n", - " \"\"\"\n", - " super().__init__()\n", - " self.dimension = dimension\n", - " self._two_pi = 2 * math.pi\n", - " self._eps = 1e-6\n", - " # Inverse \"common ratio\" for the geometric progression in sinusoid frequencies.\n", - " self._temperature = 10000\n", - "\n", - " def __call__(self, x: Tensor) -> Tensor:\n", - " \"\"\"\n", - " Args:\n", - " x: A (B, C, H, W) batch of 2D feature map to generate the embeddings for.\n", - " Returns:\n", - " A (1, C, H, W) batch of corresponding sinusoidal positional embeddings.\n", - " \"\"\"\n", - " not_mask = Tensor.ones_like(x[0, :1]) # (1, H, W)\n", - " # Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations\n", - " # they would be range(0, H) and range(0, W). 
Keeping it at as is to match the original code.\n", - " y_range = not_mask.cumsum(1).cast(dtype=dtypes.float32)\n", - " x_range = not_mask.cumsum(2).cast(dtype=dtypes.float32)\n", - "\n", - " # \"Normalize\" the position index such that it ranges in [0, 2π].\n", - " # Note: Adding epsilon on the denominator should not be needed as all values of y_embed and x_range\n", - " # are non-zero by construction. This is an artifact of the original code.\n", - " y_range = y_range / (y_range[:, -1:, :] + self._eps) * self._two_pi\n", - " x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi\n", - "\n", - " inverse_frequency = Tensor(self._temperature ** (\n", - " 2 * (np.arange(self.dimension, dtype='f') // 2) / self.dimension\n", - " ))\n", - "\n", - " x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)\n", - " y_range = y_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)\n", - "\n", - " print(x_range)\n", - " print(y_range)\n", - "\n", - " # Note: this stack then flatten operation results in interleaved sine and cosine terms.\n", - " # pos_embed_x and pos_embed_y are (1, H, W, C // 2).\n", - " x_range_sin = x_range[..., 0::2].sin()\n", - " x_range_cos = x_range[..., 1::2].cos()\n", - " y_range_sin = y_range[..., 0::2].sin()\n", - " y_range_cos = y_range[..., 1::2].cos()\n", - " print(f'x_range[..., 0::2].sin(): {x_range_sin}')\n", - " print(f'x_range[..., 1::2].cos(): {x_range_cos}')\n", - " pos_embed_x = x_range_sin.stack(x_range_cos, dim=-1).flatten(3)\n", - " pos_embed_y = y_range_sin.stack(y_range_cos, dim=-1).flatten(3)\n", - " pos_embed = pos_embed_y.cat(pos_embed_x, dim=3).permute(0, 3, 1, 2) # (1, C, H, W)\n", - "\n", - " return pos_embed" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "7692187e-822c-4b68-924b-2788e1cf7723", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - ", None)> on METAL with grad None>\n", - ", None)> on METAL with grad None>\n", - "x_range[..., 0::2].sin(): , None)> on METAL with grad None>\n", - "x_range[..., 1::2].cos(): , None)> on METAL with grad None>\n" - ] - }, - { - "data": { - "text/plain": [ - " on METAL with grad None>" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "actSin = ACTSinusoidalPositionEmbedding2d(10)\n", - "actSin(Tensor.zeros(4,4,4,4))" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "2a0ef1f0-ddf7-40d7-a6bf-2b3e7480ef82", - "metadata": {}, - "outputs": [], - "source": [ - "def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tensor:\n", - " \"\"\"1D sinusoidal positional embeddings as in Attention is All You Need.\n", - "\n", - " Args:\n", - " num_positions: Number of token positions required.\n", - " Returns: (num_positions, dimension) position embeddings (the first dimension is the batch dimension).\n", - "\n", - " \"\"\"\n", - "\n", - " def get_position_angle_vec(position):\n", - " return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)]\n", - "\n", - " sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)], dtype='f')\n", - " sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i\n", - " sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1\n", - " return Tensor(sinusoid_table).float()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "19404914-6edd-4bcb-8731-dc1855149fce", - "metadata": {}, - "outputs": [ - { - 
"data": { - "text/plain": [ - "array([[ 0. , 1. , 0. ],\n", - " [ 0.841471 , 0.5403023 , 0.00215443],\n", - " [ 0.9092974 , -0.4161468 , 0.00430886]], dtype=float32)" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "create_sinusoidal_pos_embedding(3, 3).numpy()" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "f8dd33d3-a35f-403d-a46c-6f228739c928", - "metadata": {}, - "outputs": [], - "source": [ - "from tinygrad import Tensor, nn\n", - "from typing import Optional, Union, Literal\n", - "from tinygrad.ops import Variable\n", - "\n", - "class MultiheadAttention:\n", - " def __init__(self, embed_dim, num_heads, dropout=0.0):\n", - " self.embed_dim = embed_dim\n", - " self.num_heads = num_heads\n", - " self.head_dim = embed_dim // num_heads\n", - " assert self.head_dim * num_heads == embed_dim, \"n_state must be divisible by n_head\"\n", - "\n", - " self.query = nn.Linear(embed_dim, embed_dim)\n", - " self.key = nn.Linear(embed_dim, embed_dim)\n", - " self.value = nn.Linear(embed_dim, embed_dim)\n", - " self.out = nn.Linear(embed_dim, embed_dim)\n", - "\n", - " self.scaling = self.head_dim ** -0.5 \n", - " self.dropout = dropout\n", - "\n", - " def __call__(self, q: Tensor, k: Tensor, v: Tensor, key_padding_mask: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None, training: bool = True):\n", - " batch_size, tgt_len, _ = q.shape\n", - " src_len = k.shape[1]\n", - "\n", - " # Apply linear transformations\n", - " q = self.query(q)\n", - " k = self.key(k)\n", - " v = self.value(v)\n", - "\n", - " # Reshape and transpose for multi-head attention\n", - " q = q.reshape(batch_size, tgt_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)\n", - " k = k.reshape(batch_size, src_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)\n", - " v = v.reshape(batch_size, src_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)\n", - " \n", - " # Calculate attention scores\n", - " attn_scores = (q @ k.transpose(-2, -1)) * self.scaling\n", - "\n", - " # Apply key padding mask if provided\n", - " if key_padding_mask is not None:\n", - " print(f'(q,k,v): {q.shape}, {k.shape}, {v.shape}')\n", - " print(f'(key_padding_mask): {key_padding_mask.shape}')\n", - " # Reshape and expand key_padding_mask to match attn_scores dimensions\n", - " key_padding_mask = key_padding_mask.squeeze(1).squeeze(1) # Remove extra dimensions\n", - " key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(1) # Add dimensions for heads and query length\n", - " key_padding_mask = key_padding_mask.expand(batch_size, self.num_heads, tgt_len, src_len)\n", - " attn_scores = attn_scores.masked_fill(key_padding_mask, float('-inf')) \n", - " \n", - " # Apply softmax to get attention weights\n", - " attn_weights = attn_scores.softmax(axis=-1)\n", - "\n", - " # Apply dropout\n", - " if self.dropout > 0:\n", - " attn_weights = attn_weights.dropout(p=self.dropout)\n", - "\n", - " # Apply attention to values\n", - " attn_output = attn_weights @ v\n", - "\n", - " # Reshape and combine heads\n", - " attn_output = attn_output.permute(0, 2, 1, 3).reshape(batch_size, tgt_len, self.embed_dim)\n", - "\n", - " # Final projection\n", - " attn_output = self.out(attn_output)\n", - "\n", - " return attn_output\n" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "6a91632b-0a49-4bec-b500-215fc7b34ac5", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - ", None)> on METAL with grad None>" - ] - }, - "execution_count": 9, - 
"metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "mha = MultiheadAttention(9, 9)\n", - "mha(Tensor.zeros(9, 9, 9), Tensor.zeros(9, 9, 9), Tensor.zeros(9, 9, 9))" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "db7c53ee-be78-41ba-b463-4083cf12d92c", - "metadata": {}, - "outputs": [], - "source": [ - "from dataclasses import dataclass, field\n", - "\n", - "@dataclass\n", - "class ACTConfig:\n", - " \"\"\"Configuration class for the Action Chunking Transformers policy.\n", - "\n", - " Defaults are configured for training on bimanual Aloha tasks like \"insertion\" or \"transfer\".\n", - "\n", - " The parameters you will most likely need to change are the ones which depend on the environment / sensors.\n", - " Those are: `input_shapes` and 'output_shapes`.\n", - "\n", - " Notes on the inputs and outputs:\n", - " - Either:\n", - " - At least one key starting with \"observation.image is required as an input.\n", - " AND/OR\n", - " - The key \"observation.environment_state\" is required as input.\n", - " - If there are multiple keys beginning with \"observation.images.\" they are treated as multiple camera\n", - " views. Right now we only support all images having the same shape.\n", - " - May optionally work without an \"observation.state\" key for the proprioceptive robot state.\n", - " - \"action\" is required as an output key.\n", - "\n", - " Args:\n", - " n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the\n", - " current step and additional steps going back).\n", - " chunk_size: The size of the action prediction \"chunks\" in units of environment steps.\n", - " n_action_steps: The number of action steps to run in the environment for one invocation of the policy.\n", - " This should be no greater than the chunk size. For example, if the chunk size size 100, you may\n", - " set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the\n", - " environment, and throws the other 50 out.\n", - " input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents\n", - " the input data name, and the value is a list indicating the dimensions of the corresponding data.\n", - " For example, \"observation.image\" refers to an input from a camera with dimensions [3, 96, 96],\n", - " indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't\n", - " include batch dimension or temporal dimension.\n", - " output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents\n", - " the output data name, and the value is a list indicating the dimensions of the corresponding data.\n", - " For example, \"action\" refers to an output shape of [14], indicating 14-dimensional actions.\n", - " Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.\n", - " input_normalization_modes: A dictionary with key representing the modality (e.g. \"observation.state\"),\n", - " and the value specifies the normalization mode to apply. The two available modes are \"mean_std\"\n", - " which subtracts the mean and divides by the standard deviation and \"min_max\" which rescale in a\n", - " [-1, 1] range.\n", - " output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the\n", - " original scale. 
Note that this is also used for normalizing the training targets.\n", - " vision_backbone: Name of the torchvision resnet backbone to use for encoding images.\n", - " pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.\n", - " `None` means no pretrained weights.\n", - " replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated\n", - " convolution.\n", - " pre_norm: Whether to use \"pre-norm\" in the transformer blocks.\n", - " dim_model: The transformer blocks' main hidden dimension.\n", - " n_heads: The number of heads to use in the transformer blocks' multi-head attention.\n", - " dim_feedforward: The dimension to expand the transformer's hidden dimension to in the feed-forward\n", - " layers.\n", - " feedforward_activation: The activation to use in the transformer block's feed-forward layers.\n", - " n_encoder_layers: The number of transformer layers to use for the transformer encoder.\n", - " n_decoder_layers: The number of transformer layers to use for the transformer decoder.\n", - " use_vae: Whether to use a variational objective during training. This introduces another transformer\n", - " which is used as the VAE's encoder (not to be confused with the transformer encoder - see\n", - " documentation in the policy class).\n", - " latent_dim: The VAE's latent dimension.\n", - " n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.\n", - " temporal_ensemble_coeff: Coefficient for the exponential weighting scheme to apply for temporal\n", - " ensembling. Defaults to None which means temporal ensembling is not used. `n_action_steps` must be\n", - " 1 when using this feature, as inference needs to happen at every step to form an ensemble. For\n", - " more information on how ensembling works, please see `ACTTemporalEnsembler`.\n", - " dropout: Dropout to use in the transformer layers (see code for details).\n", - " kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective\n", - " is enabled. 
Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.\n", - " \"\"\"\n", - "\n", - " # Input / output structure.\n", - " n_obs_steps: int = 1\n", - " chunk_size: int = 100\n", - " n_action_steps: int = 100\n", - "\n", - " input_shapes: dict[str, list[int]] = field(\n", - " default_factory=lambda: {\n", - " \"observation.images.top\": [3, 480, 640],\n", - " \"observation.state\": [14],\n", - " }\n", - " )\n", - " output_shapes: dict[str, list[int]] = field(\n", - " default_factory=lambda: {\n", - " \"action\": [14],\n", - " }\n", - " )\n", - "\n", - " # Normalization / Unnormalization\n", - " input_normalization_modes: dict[str, str] = field(\n", - " default_factory=lambda: {\n", - " \"observation.images.top\": \"mean_std\",\n", - " \"observation.state\": \"mean_std\",\n", - " }\n", - " )\n", - " output_normalization_modes: dict[str, str] = field(\n", - " default_factory=lambda: {\n", - " \"action\": \"mean_std\",\n", - " }\n", - " )\n", - "\n", - " # Overrides.\n", - " override_dataset_stats: dict[str, dict[str, list[[float]]]] = field(\n", - " default_factory=lambda: {\n", - " \"observation.images.top\": {\n", - " \"mean\": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)\n", - " \"std\": [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)\n", - " }\n", - " }\n", - " )\n", - "\n", - " # Architecture.\n", - " # Vision backbone.\n", - " vision_backbone: str = \"resnet18\"\n", - " pretrained_backbone_weights: str | None = \"ResNet18_Weights.IMAGENET1K_V1\"\n", - " replace_final_stride_with_dilation: int = False\n", - " # Transformer layers.\n", - " pre_norm: bool = False\n", - " dim_model: int = 512\n", - " n_heads: int = 8\n", - " dim_feedforward: int = 3200\n", - " feedforward_activation: str = \"relu\"\n", - " n_encoder_layers: int = 4\n", - " # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code\n", - " # that means only the first layer is used. Here we match the original implementation by setting this to 1.\n", - " # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.\n", - " n_decoder_layers: int = 1\n", - " # VAE.\n", - " use_vae: bool = True\n", - " latent_dim: int = 32\n", - " n_vae_encoder_layers: int = 4\n", - "\n", - " # Inference.\n", - " # Note: the value used in ACT when temporal ensembling is enabled is 0.01.\n", - " temporal_ensemble_coeff: float | None = None\n", - "\n", - " # Training and loss computation.\n", - " dropout: float = 0.1\n", - " kl_weight: float = 10.0\n", - "\n", - " def __post_init__(self):\n", - " \"\"\"Input validation (not exhaustive).\"\"\"\n", - " if not self.vision_backbone.startswith(\"resnet\"):\n", - " raise ValueError(\n", - " f\"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}.\"\n", - " )\n", - " if self.temporal_ensemble_coeff is not None and self.n_action_steps > 1:\n", - " raise NotImplementedError(\n", - " \"`n_action_steps` must be 1 when using temporal ensembling. This is \"\n", - " \"because the policy needs to be queried every step to compute the ensembled action.\"\n", - " )\n", - " if self.n_action_steps > self.chunk_size:\n", - " raise ValueError(\n", - " f\"The chunk size is the upper bound for the number of action steps per model invocation. Got \"\n", - " f\"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`.\"\n", - " )\n", - " if self.n_obs_steps != 1:\n", - " raise ValueError(\n", - " f\"Multiple observation steps not handled yet. 
Got `nobs_steps={self.n_obs_steps}`\"\n", - " )\n", - " if (\n", - " not any(k.startswith(\"observation.image\") for k in self.input_shapes)\n", - " and \"observation.environment_state\" not in self.input_shapes\n", - " ):\n", - " raise ValueError(\"You must provide at least one image or the environment state among the inputs.\")" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "f21a9b74-1b92-4646-b093-27fedfc72513", - "metadata": {}, - "outputs": [], - "source": [ - "class ACTDecoderLayer:\n", - " def __init__(self, config: ACTConfig):\n", - " super().__init__()\n", - " self.self_attn = MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)\n", - " self.multihead_attn = MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)\n", - "\n", - " # Feed forward layers.\n", - " self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)\n", - " self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model)\n", - "\n", - " self.norm1 = nn.LayerNorm(config.dim_model)\n", - " self.norm2 = nn.LayerNorm(config.dim_model)\n", - " self.norm3 = nn.LayerNorm(config.dim_model)\n", - " self.dropout_rate = config.dropout\n", - "\n", - " self.activation = get_activation_fn(config.feedforward_activation)\n", - " self.pre_norm = config.pre_norm\n", - "\n", - " def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor:\n", - " return tensor if pos_embed is None else tensor + pos_embed\n", - "\n", - " def __call__(\n", - " self,\n", - " x: Tensor,\n", - " encoder_out: Tensor,\n", - " decoder_pos_embed: Tensor | None = None,\n", - " encoder_pos_embed: Tensor | None = None,\n", - " ) -> Tensor:\n", - " \"\"\"\n", - " Args:\n", - " x: (Decoder Sequence, Batch, Channel) tensor of input tokens.\n", - " encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are\n", - " cross-attending with.\n", - " decoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder).\n", - " encoder_pos_embed: (DS, 1, C) Positional_embedding for the queries (from the decoder).\n", - " Returns:\n", - " (DS, B, C) tensor of decoder output features.\n", - " \"\"\"\n", - " skip = x\n", - " if self.pre_norm:\n", - " x = self.norm1(x)\n", - " q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)\n", - " x = self.self_attn(q, k, x) \n", - " #x = x[0] # select just the output, not the attention weights\n", - " x = skip + x.dropout(p=self.dropout_rate)\n", - " if self.pre_norm:\n", - " skip = x\n", - " x = self.norm2(x)\n", - " else:\n", - " x = self.norm1(x)\n", - " skip = x\n", - " x = self.multihead_attn(\n", - " self.maybe_add_pos_embed(x, decoder_pos_embed),\n", - " self.maybe_add_pos_embed(encoder_out, encoder_pos_embed),\n", - " encoder_out,\n", - " )\n", - " #x = x[0] # select just the output, not the attention weights\n", - " x = skip + x.dropout(p=self.dropout_rate)\n", - " if self.pre_norm:\n", - " skip = x\n", - " x = self.norm3(x)\n", - " else:\n", - " x = self.norm2(x)\n", - " skip = x\n", - " \n", - " x = x.sequential([self.linear1, self.activation]).dropout(p=self.dropout_rate).sequential([self.linear2])\n", - " x = skip + x.dropout(p=self.dropout_rate)\n", - " if not self.pre_norm:\n", - " x = self.norm3(x)\n", - " return x\n" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "79ac1c63-ded4-4b6e-8291-3e8b0c5e5d49", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - ", None)> on METAL with grad None>" - ] - }, - "execution_count": 12, - 
"metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "actDecoder = ACTDecoderLayer(ACTConfig())\n", - "actDecoder(Tensor.zeros(3,512, 512), Tensor.zeros(3,512, 512))" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "199f37fd-b9ff-4946-9531-53c7d90e6332", - "metadata": {}, - "outputs": [], - "source": [ - "class ACTDecoder:\n", - " def __init__(self, config: ACTConfig):\n", - " \"\"\"Convenience module for running multiple decoder layers followed by normalization.\"\"\"\n", - " super().__init__()\n", - " self.layers = [ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)]\n", - " self.norm = nn.LayerNorm(config.dim_model)\n", - "\n", - " def __call__(\n", - " self,\n", - " x: Tensor,\n", - " encoder_out: Tensor,\n", - " decoder_pos_embed: Tensor | None = None,\n", - " encoder_pos_embed: Tensor | None = None,\n", - " ) -> Tensor:\n", - " for layer in self.layers:\n", - " x = layer(\n", - " x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed\n", - " )\n", - " if self.norm is not None:\n", - " x = self.norm(x)\n", - " return x" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "8fc180ee-fe07-44a8-863a-6d064a93fb32", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - ", None)> on METAL with grad None>" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "actDecode = ACTDecoder(ACTConfig())\n", - "actDecode(Tensor.zeros(3,512,512), Tensor.zeros(3,512,512))" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "5fd920a5-55bc-4f74-a4f0-896ecbc800cc", - "metadata": {}, - "outputs": [], - "source": [ - "class ACTEncoderLayer:\n", - " def __init__(self, config: ACTConfig):\n", - " super().__init__()\n", - " self.self_attn = MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)\n", - "\n", - " # Feed forward layers.\n", - " self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)\n", - " self.dropout = config.dropout\n", - " self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model)\n", - "\n", - " self.norm1 = nn.LayerNorm(config.dim_model)\n", - " self.norm2 = nn.LayerNorm(config.dim_model)\n", - "\n", - " self.activation = get_activation_fn(config.feedforward_activation)\n", - " self.pre_norm = config.pre_norm\n", - "\n", - " def __call__(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor:\n", - " skip = x\n", - " if self.pre_norm:\n", - " x = self.norm1(x)\n", - " q = k = x if pos_embed is None else x + pos_embed\n", - " x = self.self_attn(q, k, x, key_padding_mask=key_padding_mask)\n", - " # x = x[0] # note: [0] to select just the output, not the attention weights\n", - " x = skip + x.dropout(p=self.dropout)\n", - " if self.pre_norm:\n", - " skip = x\n", - " x = self.norm2(x)\n", - " else:\n", - " x = self.norm1(x)\n", - " skip = x\n", - " x = x.sequential([self.linear1, self.activation]).dropout(p=self.dropout).sequential([self.linear2])\n", - " x = skip + x.dropout(p=self.dropout)\n", - " if not self.pre_norm:\n", - " x = self.norm2(x)\n", - " return x" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "4e2839ad-6fa9-47f9-80e5-2681f7eb0d1e", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - ", None)> on METAL with grad None>" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "actEncode = 
ACTEncoderLayer(ACTConfig())\n", - "actEncode(Tensor.zeros(3, 512, 512))" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "a7ec9da4-bea0-46bc-9bb0-7ed27a1fd6d7", - "metadata": {}, - "outputs": [], - "source": [ - "class ACTEncoder:\n", - " \"\"\"Convenience module for running multiple encoder layers, maybe followed by normalization.\"\"\"\n", - "\n", - " def __init__(self, config: ACTConfig, is_vae_encoder: bool = False):\n", - " super().__init__()\n", - " self.is_vae_encoder = is_vae_encoder\n", - " num_layers = config.n_vae_encoder_layers if self.is_vae_encoder else config.n_encoder_layers\n", - " self.layers = [ACTEncoderLayer(config) for _ in range(num_layers)]\n", - " self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else lambda x: x\n", - "\n", - " def __call__(\n", - " self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None\n", - " ) -> Tensor:\n", - " for layer in self.layers:\n", - " print(f'ACTEncoder x.shape per layer: {x.shape}')\n", - " x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)\n", - " x = self.norm(x)\n", - " return x\n" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "c3747d03-b4ac-4f61-a4f3-5b9e94eb417f", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - ", None)> on METAL with grad None>" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "actEncoder = ACTEncoder(ACTConfig())\n", - "actEncode(Tensor.zeros(3, 512, 512))" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "1ad78fa2-7179-477c-88c0-ae2dea72ef09", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "ram used: 0.04 GB, layer4.1.bn2.running_var : 100%|█| \n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "loaded weights in 30.70 ms, 0.04 GB loaded at 1.46 GB/s\n" - ] - } - ], - "source": [ - "import tinygrad.nn as nn\n", - "from tinygrad import Tensor, dtypes\n", - "from tinygrad.helpers import fetch, get_child\n", - "\n", - "# allow monkeypatching in layer implementations\n", - "BatchNorm = nn.BatchNorm2d\n", - "Conv2d = nn.Conv2d\n", - "Linear = nn.Linear\n", - "\n", - "class FrozenBatchNorm2d:\n", - " def __init__(self, num_features, eps=1e-5):\n", - " super().__init__()\n", - " self.num_features = num_features\n", - " self.eps = eps\n", - " # Register buffers instead of parameters\n", - " self.weight = Tensor.ones(num_features, requires_grad=False)\n", - " self.bias = Tensor.zeros(num_features, requires_grad=False)\n", - " self.running_mean = Tensor.zeros(num_features, requires_grad=False)\n", - " self.running_var = Tensor.ones(num_features, requires_grad=False)\n", - " def __call__(self, x:Tensor) -> Tensor:\n", - " # Reshape for 2D input\n", - " scale = (self.weight / (self.running_var + self.eps).sqrt()).reshape(1, -1, 1, 1)\n", - " bias = (self.bias - self.running_mean * scale.flatten()).reshape(1, -1, 1, 1)\n", - " return x * scale + bias\n", - "\n", - "class Block:\n", - " def __init__(self, in_dims, dims, stride=1):\n", - " super().__init__()\n", - " self.conv1 = nn.Conv2d(\n", - " in_dims, dims, kernel_size=3, stride=stride, padding=1, bias=False\n", - " )\n", - " self.bn1 = FrozenBatchNorm2d(dims)\n", - " self.conv2 = nn.Conv2d(\n", - " dims, dims, kernel_size=3, stride=1, padding=1, bias=False\n", - " )\n", - " self.bn2 = FrozenBatchNorm2d(dims)\n", - " self.downsample = []\n", - " if stride != 1:\n", - " self.downsample 
= [\n", - " nn.Conv2d(in_dims, dims, kernel_size=1, stride=stride, bias=False),\n", - " FrozenBatchNorm2d(dims)\n", - " ]\n", - " def __call__(self, x):\n", - " base_operations = [\n", - " self.conv1,\n", - " self.bn1,\n", - " Tensor.relu,\n", - " self.conv2,\n", - " self.bn2\n", - " ]\n", - " out = x.sequential(base_operations)\n", - " \n", - " if self.downsample != []:\n", - " return (x.sequential(base_operations) + x.sequential(self.downsample)).relu()\n", - " else:\n", - " return x.sequential(base_operations).relu()\n", - "\n", - "class ResNet:\n", - " def __init__(self, block, num_blocks, num_classes=10):\n", - " super().__init__()\n", - " self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)\n", - " self.bn1 = FrozenBatchNorm2d(64)\n", - " self.layer1 = self._make_layer(block, 64, 64, num_blocks[0], stride=1)\n", - " self.layer2 = self._make_layer(block, 64, 128, num_blocks[1], stride=2)\n", - " self.layer3 = self._make_layer(block, 128, 256, num_blocks[2], stride=2)\n", - " self.layer4 = self._make_layer(block, 256, 512, num_blocks[3], stride=2)\n", - " #self.fc = nn.Linear(512, num_classes, requires_grad=False) # if we decide to use this someday, remove the grad\n", - " def _make_layer(self, block, in_dims, dims, num_blocks, stride):\n", - " strides = [stride] + [1] * (num_blocks - 1)\n", - " layers = []\n", - " for stride in strides:\n", - " layers.append(block(in_dims, dims, stride))\n", - " in_dims = dims\n", - " return layers\n", - " def __call__(self, x:Tensor):\n", - " x = self.bn1(self.conv1(x)).relu().max_pool2d()\n", - " x = x.sequential(self.layer1)\n", - " x = x.sequential(self.layer2 + self.layer3 + self.layer4)\n", - " \"\"\"\n", - " Commented out for now, because we're just using the output from layer4\n", - " \"\"\"\n", - " #x = x.mean([2, 3])\n", - " #x = self.fc(x)\n", - " return x\n", - "\n", - "resnet18_IMAGENET1K_V1 = ResNet(Block, [2, 2, 2, 2], num_classes=1000)\n", - "state_dict = nn.state.safe_load(\"resnet18-f37072fd.safetensors\")\n", - "nn.state.load_state_dict(resnet18_IMAGENET1K_V1, state_dict)" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "id": "e226e543-2c2f-4326-9d50-5c825834c044", - "metadata": {}, - "outputs": [], - "source": [ - "from itertools import chain\n", - "\n", - "class ACT:\n", - " \"\"\"Action Chunking Transformer: The underlying neural network for ACTPolicy.\n", - "\n", - " Note: In this code we use the terms `vae_encoder`, 'encoder', `decoder`. The meanings are as follows.\n", - " - The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the\n", - " model that encodes the target data (a sequence of actions), and the condition (the robot\n", - " joint-space).\n", - " - A transformer with an `encoder` (not the VAE encoder) and `decoder` (not the VAE decoder) with\n", - " cross-attention is used as the VAE decoder. 
For these terms, we drop the `vae_` prefix because we\n", - " have an option to train this model without the variational objective (in which case we drop the\n", - " `vae_encoder` altogether, and nothing about this model has anything to do with a VAE).\n", - "\n", - " Transformer\n", - " Used alone for inference\n", - " (acts as VAE decoder\n", - " during training)\n", - " ┌───────────────────────┐\n", - " │ Outputs │\n", - " │ ▲ │\n", - " │ ┌─────►┌───────┐ │\n", - " ┌──────┐ │ │ │Transf.│ │\n", - " │ │ │ ├─────►│decoder│ │\n", - " ┌────┴────┐ │ │ │ │ │ │\n", - " │ │ │ │ ┌───┴───┬─►│ │ │\n", - " │ VAE │ │ │ │ │ └───────┘ │\n", - " │ encoder │ │ │ │Transf.│ │\n", - " │ │ │ │ │encoder│ │\n", - " └───▲─────┘ │ │ │ │ │\n", - " │ │ │ └▲──▲─▲─┘ │\n", - " │ │ │ │ │ │ │\n", - " inputs └─────┼──┘ │ image emb. │\n", - " │ state emb. │\n", - " └───────────────────────┘\n", - " \"\"\"\n", - "\n", - " def __init__(self, config: ACTConfig):\n", - " super().__init__()\n", - " self.config = config\n", - " # BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].\n", - " # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).\n", - " self.use_robot_state = \"observation.state\" in config.input_shapes\n", - " self.use_images = any(k.startswith(\"observation.image\") for k in config.input_shapes)\n", - " self.use_env_state = \"observation.environment_state\" in config.input_shapes\n", - " if self.config.use_vae:\n", - " self.vae_encoder = ACTEncoder(config, is_vae_encoder=True)\n", - " self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)\n", - " # Projection layer for joint-space configuration to hidden dimension.\n", - " if self.use_robot_state:\n", - " self.vae_encoder_robot_state_input_proj = nn.Linear(\n", - " config.input_shapes[\"observation.state\"][0], config.dim_model\n", - " )\n", - " # Projection layer for action (joint-space target) to hidden dimension.\n", - " self.vae_encoder_action_input_proj = nn.Linear(\n", - " config.output_shapes[\"action\"][0], config.dim_model\n", - " )\n", - " # Projection layer from the VAE encoder's output to the latent distribution's parameter space.\n", - " self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)\n", - " # Fixed sinusoidal positional embedding for the input to the VAE encoder. 
Unsqueeze for batch\n", - " # dimension.\n", - " num_input_token_encoder = 1 + config.chunk_size\n", - " if self.use_robot_state:\n", - " num_input_token_encoder += 1\n", - " self.vae_encoder_pos_enc = create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0)\n", - " self.vae_encoder_pos_enc.requires_grad = False\n", - "\n", - " # Backbone for image feature extraction.\n", - " if self.use_images:\n", - " resnet18_IMAGENET1K_V1 = ResNet(Block, [2, 2, 2, 2], num_classes=1000)\n", - " state_dict = nn.state.safe_load(\"resnet18-f37072fd.safetensors\")\n", - " nn.state.load_state_dict(resnet18_IMAGENET1K_V1, state_dict)\n", - " backbone_model = resnet18_IMAGENET1K_V1\n", - " # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final\n", - " # feature map).\n", - " # Note: The forward method of this returns a dict: {\"feature_map\": output}.\n", - " self.backbone = backbone_model #IntermediateLayerGetter(backbone_model, return_layers={\"layer4\": \"feature_map\"})\n", - "\n", - " # Transformer (acts as VAE decoder when training with the variational objective).\n", - " self.encoder = ACTEncoder(config)\n", - " self.decoder = ACTDecoder(config)\n", - "\n", - " # Transformer encoder input projections. The tokens will be structured like\n", - " # [latent, (robot_state), (env_state), (image_feature_map_pixels)].\n", - " if self.use_robot_state:\n", - " self.encoder_robot_state_input_proj = nn.Linear(\n", - " config.input_shapes[\"observation.state\"][0], config.dim_model\n", - " )\n", - " if self.use_env_state:\n", - " self.encoder_env_state_input_proj = nn.Linear(\n", - " config.input_shapes[\"observation.environment_state\"][0], config.dim_model\n", - " )\n", - " self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)\n", - " if self.use_images:\n", - " self.encoder_img_feat_input_proj = nn.Conv2d(\n", - " 512, config.dim_model, kernel_size=1\n", - " )\n", - " # Transformer encoder positional embeddings.\n", - " n_1d_tokens = 1 # for the latent\n", - " if self.use_robot_state:\n", - " n_1d_tokens += 1\n", - " if self.use_env_state:\n", - " n_1d_tokens += 1\n", - " self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)\n", - " if self.use_images:\n", - " self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)\n", - "\n", - " # Transformer decoder.\n", - " # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).\n", - " self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)\n", - "\n", - " # Final action regression head on the output of the transformer's decoder.\n", - " self.action_head = nn.Linear(config.dim_model, config.output_shapes[\"action\"][0])\n", - "\n", - " self._reset_parameters()\n", - "\n", - " # CHANGE THIS WHEN RUNNING.\n", - " self.training=True\n", - "\n", - " def _reset_parameters(self):\n", - " \"\"\"Xavier-uniform initialization of the transformer parameters as in the original code.\"\"\"\n", - " for p in chain(nn.state.get_parameters(self.encoder), nn.state.get_parameters(self.decoder)):\n", - " if p.ndim > 1:\n", - " def xavier_uniform_(tensor: Tensor) -> Tensor:\n", - " fan_in, fan_out = tensor.shape[:2]\n", - " \n", - " # Calculate the range for the uniform distribution\n", - " # This is the glorot/xavier uniform initialization formula\n", - " a = math.sqrt(6.0 / (fan_in + fan_out))\n", - " \n", - " # Use uniform distribution to initialize the tensor\n", - " return 
Tensor.uniform(*tensor.shape, low=-a, high=a)\n", - " p = xavier_uniform_(p)\n", - "\n", - " def __call__(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:\n", - " \"\"\"A forward pass through the Action Chunking Transformer (with optional VAE encoder).\n", - "\n", - " `batch` should have the following structure:\n", - " {\n", - " \"observation.state\" (optional): (B, state_dim) batch of robot states.\n", - "\n", - " \"observation.images\": (B, n_cameras, C, H, W) batch of images.\n", - " AND/OR\n", - " \"observation.environment_state\": (B, env_dim) batch of environment states.\n", - "\n", - " \"action\" (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.\n", - " }\n", - "\n", - " Returns:\n", - " (B, chunk_size, action_dim) batch of action sequences\n", - " Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the\n", - " latent dimension.\n", - " \"\"\"\n", - " if self.config.use_vae and self.training:\n", - " assert (\n", - " \"action\" in batch\n", - " ), \"actions must be provided when using the variational objective in training mode.\"\n", - "\n", - " batch_size = (\n", - " batch[\"observation.images\"]\n", - " if \"observation.images\" in batch\n", - " else batch[\"observation.environment_state\"]\n", - " ).shape[0]\n", - "\n", - " print(f'batch_size: {batch_size}')\n", - "\n", - " # Prepare the latent for input to the transformer encoder.\n", - " if self.config.use_vae and \"action\" in batch:\n", - " # Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].\n", - " cls_embed = self.vae_encoder_cls_embed.weight.repeat(batch_size, 1, 1) # (B, 1, D)\n", - " if self.use_robot_state:\n", - " robot_state_embed = self.vae_encoder_robot_state_input_proj(batch[\"observation.state\"])\n", - " robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)\n", - " action_embed = self.vae_encoder_action_input_proj(batch[\"action\"]) # (B, S, D)\n", - "\n", - " if self.use_robot_state:\n", - " vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)\n", - " else:\n", - " vae_encoder_input = [cls_embed, action_embed]\n", - " vae_encoder_input = Tensor.cat(*vae_encoder_input, dim=1)\n", - "\n", - " # Prepare fixed positional embedding.\n", - " # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.\n", - " pos_embed = self.vae_encoder_pos_enc.contiguous().detach() # (1, S+2, D)\n", - "\n", - " # Prepare key padding mask for the transformer encoder. 
We have 1 or 2 extra tokens at the start of the\n", - " # sequence depending whether we use the input states or not (cls and robot state)\n", - " # False means not a padding token.\n", - " cls_joint_is_pad = Tensor.full(\n", - " shape=(batch_size, 2 if self.use_robot_state else 1),\n", - " fill_value=False\n", - " )\n", - " key_padding_mask = Tensor.cat(\n", - " cls_joint_is_pad, batch[\"action_is_pad\"], dim=1\n", - " ) # (bs, seq+1 or 2)\n", - "\n", - " print(f'vae_encoder_input.shape: {vae_encoder_input.shape}')\n", - " print(f'pos_embed.shape: {pos_embed.shape}')\n", - " print(f'key_padding_mask.shape: {key_padding_mask.shape}')\n", - "\n", - " # Forward pass through VAE encoder to get the latent PDF parameters.\n", - " cls_token_out = self.vae_encoder(\n", - " vae_encoder_input.permute(1, 0, 2),\n", - " pos_embed=pos_embed.permute(1, 0, 2),\n", - " key_padding_mask=key_padding_mask.permute(1,0),\n", - " )\n", - " print(f'cls_token_out.shape: {cls_token_out.shape}')\n", - " cls_token_out = cls_token_out[0] # select the class token, with shape (B, D)\n", - " print(f'cls_token_out[0].shape: {cls_token_out.shape}')\n", - " latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)\n", - " mu = latent_pdf_params[:, : self.config.latent_dim]\n", - " # This is 2log(sigma). Done this way to match the original implementation.\n", - " log_sigma_x2 = latent_pdf_params[:, self.config.latent_dim :]\n", - "\n", - " # Sample the latent with the reparameterization trick.\n", - " latent_sample = mu + log_sigma_x2.div(2).exp() * Tensor.randn(*(mu.shape))\n", - " else:\n", - " # When not using the VAE encoder, we set the latent to be all zeros.\n", - " mu = log_sigma_x2 = None\n", - " # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer\n", - " latent_sample = Tensor.zeros(batch_size, self.config.latent_dim, dtype=dtypes.float32)\n", - "\n", - " # Prepare transformer encoder inputs.\n", - " encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]\n", - " encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))\n", - " # Robot state token.\n", - " if self.use_robot_state:\n", - " encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch[\"observation.state\"]))\n", - " # Environment state token.\n", - " if self.use_env_state:\n", - " encoder_in_tokens.append(\n", - " self.encoder_env_state_input_proj(batch[\"observation.environment_state\"])\n", - " )\n", - "\n", - " # Camera observation features and positional embeddings.\n", - " if self.use_images:\n", - " all_cam_features = []\n", - " all_cam_pos_embeds = []\n", - "\n", - " for cam_index in range(batch[\"observation.images\"].shape[-4]):\n", - " cam_features = self.backbone(batch[\"observation.images\"][:, cam_index]) #[\"feature_map\"]\n", - " print(f'backbone output: {cam_features.shape}')\n", - " # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use\n", - " # buffer\n", - " cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).cast(dtype=cam_features.dtype)\n", - " cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)\n", - " print(f'cam_features: {cam_features.shape}')\n", - " all_cam_features.append(cam_features)\n", - " print(f'len all_cam_features: {len(all_cam_features)}')\n", - " all_cam_pos_embeds.append(cam_pos_embed)\n", - " # Concatenate camera observation feature maps and positional embeddings along the width dimension,\n", - " # and move to (sequence, batch, 
dim).\n", - " all_cam_features = Tensor.cat(*all_cam_features, dim=-1)\n", - " print(f'len all_cam_features after cat: {len(all_cam_features)}')\n", - " print(f'Before encoder_in_tokens.extend, encoder_in token len: {len(encoder_in_tokens)}')\n", - " encoder_in_tokens.extend(all_cam_features.permute(2, 3, 0, 1).reshape(-1, all_cam_features.shape[0], all_cam_features.shape[1]))\n", - " print(f'encoder_in_tokens: {len(encoder_in_tokens)}')\n", - " all_cam_pos_embeds = Tensor.cat(*all_cam_pos_embeds, dim=-1)\n", - " print(f'all_cam_pos_embeds: {all_cam_pos_embeds}')\n", - " encoder_in_pos_embed.extend(all_cam_pos_embeds.permute(2, 3, 0, 1).reshape(-1, all_cam_pos_embeds.shape[0], all_cam_pos_embeds.shape[1]))\n", - "\n", - " print(f'Before tensor.stack, encoder_in token len: {len(encoder_in_tokens)}')\n", - " print(f'Before tensor.stack, encoder_in_pos_embed token len: {len(encoder_in_pos_embed)}')\n", - " # Stack all tokens along the sequence dimension.\n", - " encoder_in_tokens = Tensor.stack(*encoder_in_tokens, dim=0)\n", - " encoder_in_pos_embed = Tensor.stack(*encoder_in_pos_embed, dim=0)\n", - "\n", - " print(f'encoder_in_tokens: {len(encoder_in_tokens)}')\n", - " print(f'encoder_in_pos_embed.shape: {encoder_in_pos_embed.shape}')\n", - "\n", - " # Forward pass through the transformer modules.\n", - " encoder_out = self.encoder(encoder_in_tokens, pos_embed=encoder_in_pos_embed)\n", - " # TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer\n", - " decoder_in = Tensor.zeros(\n", - " *(self.config.chunk_size, batch_size, self.config.dim_model),\n", - " dtype=encoder_in_pos_embed.dtype\n", - " )\n", - " print(f'encoder_out.shape: {encoder_out.shape}')\n", - " print(f'decoder_in.shape: {decoder_in.shape}')\n", - " print(f'encoder_in_pos_embed.shape: {encoder_in_pos_embed.shape}')\n", - " print(f'decoder_pos_embed.shape: {self.decoder_pos_embed.weight.shape}')\n", - " print(f'decoder_pos_embed.shape unsqueezed: {self.decoder_pos_embed.weight.unsqueeze(1).shape}')\n", - " decoder_out = self.decoder(\n", - " decoder_in.permute(1,0,2),\n", - " encoder_out.permute(1,0,2),\n", - " encoder_pos_embed=encoder_in_pos_embed.permute(1,0,2),\n", - " decoder_pos_embed=self.decoder_pos_embed.weight.unsqueeze(1).permute(1,0,2),\n", - " )\n", - "\n", - " # Move back to (B, S, C).\n", - " # decoder_out = decoder_out.transpose(0, 1)\n", - " print(f'decoder_out: {decoder_out.shape}')\n", - "\n", - " actions = self.action_head(decoder_out)\n", - "\n", - " return actions, (mu, log_sigma_x2)" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "id": "169cdbfe-7b4c-406a-ad51-276e0b5dabe8", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "ram used: 1.52 GB, layer4.1.bn2.running_var : 100%|█| \n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "loaded weights in 14.43 ms, 0.04 GB loaded at 3.10 GB/s\n" - ] - } - ], - "source": [ - "act = ACT(ACTConfig())" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "id": "ac22d861-8cd6-4287-972c-1a270205101b", - "metadata": {}, - "outputs": [], - "source": [ - "class ACTTemporalEnsembler:\n", - " def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None:\n", - " \"\"\"Temporal ensembling as described in Algorithm 2 of https://arxiv.org/abs/2304.13705.\n", - "\n", - " The weights are calculated as wᵢ = exp(-temporal_ensemble_coeff * i) where w₀ is the oldest action.\n", - " They are then normalized to sum to 1 by dividing by Σwᵢ. 
Here's some intuition around how the\n", - " coefficient works:\n", - " - Setting it to 0 uniformly weighs all actions.\n", - " - Setting it positive gives more weight to older actions.\n", - " - Setting it negative gives more weight to newer actions.\n", - " NOTE: The default value for `temporal_ensemble_coeff` used by the original ACT work is 0.01. This\n", - " results in older actions being weighed more highly than newer actions (the experiments documented in\n", - " https://github.com/huggingface/lerobot/pull/319 hint at why highly weighing new actions might be\n", - " detrimental: doing so aggressively may diminish the benefits of action chunking).\n", - "\n", - " Here we use an online method for computing the average rather than caching a history of actions in\n", - " order to compute the average offline. For a simple 1D sequence it looks something like:\n", - "\n", - " ```\n", - " import torch\n", - "\n", - " seq = torch.linspace(8, 8.5, 100)\n", - " print(seq)\n", - "\n", - " m = 0.01\n", - " exp_weights = torch.exp(-m * torch.arange(len(seq)))\n", - " print(exp_weights)\n", - "\n", - " # Calculate offline\n", - " avg = (exp_weights * seq).sum() / exp_weights.sum()\n", - " print(\"offline\", avg)\n", - "\n", - " # Calculate online\n", - " for i, item in enumerate(seq):\n", - " if i == 0:\n", - " avg = item\n", - " continue\n", - " avg *= exp_weights[:i].sum()\n", - " avg += item * exp_weights[i]\n", - " avg /= exp_weights[:i+1].sum()\n", - " print(\"online\", avg)\n", - " ```\n", - " \"\"\"\n", - " self.chunk_size = chunk_size\n", - " self.ensemble_weights = (-temporal_ensemble_coeff * Tensor.arange(chunk_size)).exp()\n", - " self.ensemble_weights_cumsum = self.ensemble_weights.cumsum(axis=0)\n", - " self.reset()\n", - "\n", - " def reset(self):\n", - " \"\"\"Resets the online computation variables.\"\"\"\n", - " self.ensembled_actions = None\n", - " # (chunk_size,) count of how many actions are in the ensemble for each time step in the sequence.\n", - " self.ensembled_actions_count = None\n", - "\n", - " def update(self, actions: Tensor) -> Tensor:\n", - " \"\"\"\n", - " Takes a (batch, chunk_size, action_dim) sequence of actions, update the temporal ensemble for all\n", - " time steps, and pop/return the next batch of actions in the sequence.\n", - " \"\"\"\n", - " if self.ensembled_actions is None:\n", - " # Initializes `self._ensembled_action` to the sequence of actions predicted during the first\n", - " # time step of the episode.\n", - " self.ensembled_actions = actions.contiguous()\n", - " # Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor\n", - " # operations later.\n", - " self.ensembled_actions_count = Tensor.ones(\n", - " *(self.chunk_size, 1), dtype=dtypes.long\n", - " )\n", - " else:\n", - " # self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). 
Compute\n", - " # the online update for those entries.\n", - " self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1]\n", - " self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]\n", - " self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count]\n", - " self.ensembled_actions_count = (self.ensembled_actions_count + 1).clamp(max_=self.chunk_size)\n", - " # The last action, which has no prior online average, needs to get concatenated onto the end.\n", - " self.ensembled_actions = Tensor.cat(*[self.ensembled_actions, actions[:, -1:]], dim=1)\n", - " self.ensembled_actions_count = Tensor.cat(\n", - " *[self.ensembled_actions_count, Tensor.ones_like(self.ensembled_actions_count[-1:])]\n", - " )\n", - " # \"Consume\" the first action.\n", - " action, self.ensembled_actions, self.ensembled_actions_count = (\n", - " self.ensembled_actions[:, 0],\n", - " self.ensembled_actions[:, 1:],\n", - " self.ensembled_actions_count[1:],\n", - " )\n", - " return action" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "id": "0da33efc-c00f-4675-949f-ddf9d32a8c91", - "metadata": {}, - "outputs": [], - "source": [ - "from normalize import *\n", - "\n", - "class ACTPolicy:\n", - " \"\"\"\n", - " Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost\n", - " Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)\n", - " \"\"\"\n", - "\n", - " name = \"act\"\n", - "\n", - " def __init__(\n", - " self,\n", - " config: ACTConfig | None = None,\n", - " dataset_stats: dict[str, dict[str, Tensor]] | None = None,\n", - " ):\n", - " \"\"\"\n", - " Args:\n", - " config: Policy configuration class instance or None, in which case the default instantiation of\n", - " the configuration class is used.\n", - " dataset_stats: Dataset statistics to be used for normalization. 
If not passed here, it is expected\n", - " that they will be passed with a call to `load_state_dict` before the policy is used.\n", - " \"\"\"\n", - " super().__init__()\n", - " if config is None:\n", - " config = ACTConfig()\n", - " self.config: ACTConfig = config\n", - "\n", - " self.normalize_inputs = Normalize(\n", - " config.input_shapes, config.input_normalization_modes, dataset_stats\n", - " )\n", - " self.normalize_targets = Normalize(\n", - " config.output_shapes, config.output_normalization_modes, dataset_stats\n", - " )\n", - " self.unnormalize_outputs = Unnormalize(\n", - " config.output_shapes, config.output_normalization_modes, dataset_stats\n", - " )\n", - "\n", - " self.model = ACT(config)\n", - "\n", - " self.expected_image_keys = [k for k in config.input_shapes if k.startswith(\"observation.image\")]\n", - "\n", - " if config.temporal_ensemble_coeff is not None:\n", - " self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)\n", - "\n", - " self.reset()\n", - "\n", - " def reset(self):\n", - " \"\"\"This should be called whenever the environment is reset.\"\"\"\n", - " if self.config.temporal_ensemble_coeff is not None:\n", - " self.temporal_ensembler.reset()\n", - " else:\n", - " self._action_queue = deque([], maxlen=self.config.n_action_steps)\n", - "\n", - " def select_action(self, batch: dict[str, Tensor]) -> Tensor:\n", - " \"\"\"Select a single action given environment observations.\n", - "\n", - " This method wraps `select_actions` in order to return one action at a time for execution in the\n", - " environment. It works by managing the actions in a queue and only calling `select_actions` when the\n", - " queue is empty.\n", - " \"\"\"\n", - " Tensor.no_grad = True\n", - " self.eval()\n", - "\n", - " batch = self.normalize_inputs(batch)\n", - " if len(self.expected_image_keys) > 0:\n", - " batch = dict(batch) # shallow copy so that adding a key doesn't modify the original\n", - " batch[\"observation.images\"] = Tensor.stack(*[batch[k] for k in self.expected_image_keys], dim=-4)\n", - "\n", - " # If we are doing temporal ensembling, do online updates where we keep track of the number of actions\n", - " # we are ensembling over.\n", - " if self.config.temporal_ensemble_coeff is not None:\n", - " actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim)\n", - " actions = self.unnormalize_outputs({\"action\": actions})[\"action\"]\n", - " action = self.temporal_ensembler.update(actions)\n", - " return action\n", - "\n", - " # Action queue logic for n_action_steps > 1. 
When the action_queue is depleted, populate it by\n", - " # querying the policy.\n", - " if len(self._action_queue) == 0:\n", - " actions = self.model(batch)[0][:, : self.config.n_action_steps]\n", - "\n", - " # TODO(rcadene): make _forward return output dictionary?\n", - " actions = self.unnormalize_outputs({\"action\": actions})[\"action\"]\n", - "\n", - " # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue\n", - " # effectively has shape (n_action_steps, batch_size, *), hence the transpose.\n", - " self._action_queue.extend(actions.transpose(0, 1))\n", - " item_to_return = self._action_queue.popleft()\n", - " Tensor.no_grad = False\n", - " return item_to_return\n", - "\n", - " def __call__(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:\n", - " \"\"\"Run the batch through the model and compute the loss for training or validation.\"\"\"\n", - " batch = self.normalize_inputs(batch)\n", - " if len(self.expected_image_keys) > 0:\n", - " batch = dict(batch) # shallow copy so that adding a key doesn't modify the original\n", - " batch[\"observation.images\"] = Tensor.stack(*[batch[k] for k in self.expected_image_keys], dim=-4)\n", - " batch = self.normalize_targets(batch)\n", - " actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)\n", - "\n", - " l1_loss = (\n", - " (batch[\"action\"] - actions_hat).abs() * batch[\"action_is_pad\"].logical_not().int().unsqueeze(-1)\n", - " ).mean()\n", - "\n", - " loss_dict = {\"l1_loss\": l1_loss.item()}\n", - " if self.config.use_vae:\n", - " # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for\n", - " # each dimension independently, we sum over the latent dimension to get the total\n", - " # KL-divergence per batch element, then take the mean over the batch.\n", - " # (See App. 
B of https://arxiv.org/abs/1312.6114 for more details).\n", - " mean_kld = (\n", - " (-0.5 * (1 + log_sigma_x2_hat - mu_hat.square() - (log_sigma_x2_hat).exp())).sum(axis=-1).mean()\n", - " )\n", - " loss_dict[\"kld_loss\"] = mean_kld.item()\n", - " loss_dict[\"loss\"] = l1_loss + mean_kld * self.config.kl_weight\n", - " else:\n", - " loss_dict[\"loss\"] = l1_loss\n", - "\n", - " return loss_dict" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "id": "88ad724a-64be-47e7-9326-085ad096ab4a", - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "7c0600c9f3764a6383b2c57f21ff1f5b", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Fetching 56 files: 0%| | 0/56 [00:00, None)> on METAL with grad None>\n", - ", None)> on METAL with grad None>\n", - "x_range[..., 0::2].sin(): , None)> on METAL with grad None>\n", - "x_range[..., 1::2].cos(): , None)> on METAL with grad None>\n", - "cam_features: (8, 512, 15, 20)\n", - "len all_cam_features: 1\n", - "len all_cam_features after cat: 8\n", - "Before encoder_in_tokens.extend, encoder_in token len: 2\n", - "encoder_in_tokens: 302\n", - "all_cam_pos_embeds: on METAL with grad None>\n", - "Before tensor.stack, encoder_in token len: 302\n", - "Before tensor.stack, encoder_in_pos_embed token len: 302\n", - "encoder_in_tokens: 302\n", - "encoder_in_pos_embed.shape: (302, 1, 512)\n", - "ACTEncoder x.shape per layer: (302, 8, 512)\n", - "ACTEncoder x.shape per layer: (302, 8, 512)\n", - "ACTEncoder x.shape per layer: (302, 8, 512)\n", - "ACTEncoder x.shape per layer: (302, 8, 512)\n", - "encoder_out.shape: (302, 8, 512)\n", - "decoder_in.shape: (100, 8, 512)\n", - "encoder_in_pos_embed.shape: (302, 1, 512)\n", - "decoder_pos_embed.shape: (100, 512)\n", - "decoder_pos_embed.shape unsqueezed: (100, 1, 512)\n", - "decoder_out: (8, 100, 512)\n", - "step: 0 loss: 106.413\n", - "batch_size: 8\n", - "vae_encoder_input.shape: (8, 102, 512)\n", - "pos_embed.shape: (1, 102, 512)\n", - "key_padding_mask.shape: (8, 102)\n", - "ACTEncoder x.shape per layer: (102, 8, 512)\n", - "(q,k,v): (102, 8, 8, 64), (102, 8, 8, 64), (102, 8, 8, 64)\n", - "(key_padding_mask): (102, 8)\n", - "ACTEncoder x.shape per layer: (102, 8, 512)\n", - "(q,k,v): (102, 8, 8, 64), (102, 8, 8, 64), (102, 8, 8, 64)\n", - "(key_padding_mask): (102, 8)\n", - "ACTEncoder x.shape per layer: (102, 8, 512)\n", - "(q,k,v): (102, 8, 8, 64), (102, 8, 8, 64), (102, 8, 8, 64)\n", - "(key_padding_mask): (102, 8)\n", - "ACTEncoder x.shape per layer: (102, 8, 512)\n", - "(q,k,v): (102, 8, 8, 64), (102, 8, 8, 64), (102, 8, 8, 64)\n", - "(key_padding_mask): (102, 8)\n", - "cls_token_out.shape: (102, 8, 512)\n", - "cls_token_out[0].shape: (8, 512)\n", - "backbone output: (8, 512, 15, 20)\n", - ", None)> on METAL with grad None>\n", - ", None)> on METAL with grad None>\n", - "x_range[..., 0::2].sin(): , None)> on METAL with grad None>\n", - "x_range[..., 1::2].cos(): , None)> on METAL with grad None>\n", - "cam_features: (8, 512, 15, 20)\n", - "len all_cam_features: 1\n", - "len all_cam_features after cat: 8\n", - "Before encoder_in_tokens.extend, encoder_in token len: 2\n", - "encoder_in_tokens: 302\n", - "all_cam_pos_embeds: on METAL with grad None>\n", - "Before tensor.stack, encoder_in token len: 302\n", - "Before tensor.stack, encoder_in_pos_embed token len: 302\n", - "encoder_in_tokens: 302\n", - "encoder_in_pos_embed.shape: (302, 1, 512)\n", - "ACTEncoder x.shape per layer: (302, 8, 512)\n", - "ACTEncoder x.shape per 
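The mean_kld term above is the closed-form KL divergence between the diagonal-Gaussian posterior N(mu, sigma^2) and a standard normal, i.e. -0.5 * (1 + log(sigma^2) - mu^2 - sigma^2) per latent dimension, summed over dimensions. A quick numpy check of that formula against a Monte-Carlo estimate (values are arbitrary):

import numpy as np

mu = np.array([0.5, -0.2])
log_sigma_x2 = np.array([0.1, -0.3])                       # log(sigma^2)

closed_form = (-0.5 * (1 + log_sigma_x2 - mu**2 - np.exp(log_sigma_x2))).sum()

rng = np.random.default_rng(0)
sigma = np.exp(0.5 * log_sigma_x2)
z = mu + sigma * rng.normal(size=(200_000, 2))             # samples from the posterior
log_q = -0.5 * (np.log(2 * np.pi) + log_sigma_x2 + ((z - mu) / sigma) ** 2)
log_p = -0.5 * (np.log(2 * np.pi) + z ** 2)
monte_carlo = (log_q - log_p).sum(axis=1).mean()

print(closed_form, monte_carlo)                            # the two agree to ~2 decimal places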
layer: (302, 8, 512)\n", - "ACTEncoder x.shape per layer: (302, 8, 512)\n", - "ACTEncoder x.shape per layer: (302, 8, 512)\n", - "encoder_out.shape: (302, 8, 512)\n", - "decoder_in.shape: (100, 8, 512)\n", - "encoder_in_pos_embed.shape: (302, 1, 512)\n", - "decoder_pos_embed.shape: (100, 512)\n", - "decoder_pos_embed.shape unsqueezed: (100, 1, 512)\n", - "decoder_out: (8, 100, 512)\n" - ] - }, - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[34], line 80\u001b[0m\n\u001b[1;32m 78\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m batch \u001b[38;5;129;01min\u001b[39;00m dataloader:\n\u001b[1;32m 79\u001b[0m batch \u001b[38;5;241m=\u001b[39m {k: Tensor(v\u001b[38;5;241m.\u001b[39mnumpy(), requires_grad\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m) \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m batch\u001b[38;5;241m.\u001b[39mitems()}\n\u001b[0;32m---> 80\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mtrain_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 82\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m step \u001b[38;5;241m%\u001b[39m log_freq \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 83\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mstep: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mstep\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m loss: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mloss\u001b[38;5;241m.\u001b[39mnumpy()\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.3f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", - "File \u001b[0;32m/opt/homebrew/Cellar/python@3.12/3.12.6/Frameworks/Python.framework/Versions/3.12/lib/python3.12/contextlib.py:81\u001b[0m, in \u001b[0;36mContextDecorator.__call__..inner\u001b[0;34m(*args, **kwds)\u001b[0m\n\u001b[1;32m 78\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(func)\n\u001b[1;32m 79\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21minner\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwds):\n\u001b[1;32m 80\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_recreate_cm():\n\u001b[0;32m---> 81\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwds\u001b[49m\u001b[43m)\u001b[49m\n", - "Cell \u001b[0;32mIn[34], line 59\u001b[0m, in \u001b[0;36mtrain_step\u001b[0;34m(batch)\u001b[0m\n\u001b[1;32m 57\u001b[0m opt_backbone\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[1;32m 58\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward()\n\u001b[0;32m---> 59\u001b[0m \u001b[43mopt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 60\u001b[0m opt_backbone\u001b[38;5;241m.\u001b[39mstep()\n\u001b[1;32m 61\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss\n", - "File \u001b[0;32m/opt/homebrew/lib/python3.12/site-packages/tinygrad/nn/optim.py:34\u001b[0m, in \u001b[0;36mOptimizer.step\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 
30\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mstep\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 31\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 32\u001b[0m \u001b[38;5;124;03m Performs a single optimization step.\u001b[39;00m\n\u001b[1;32m 33\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 34\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mrealize(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mschedule_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m)\n", - "File \u001b[0;32m/opt/homebrew/lib/python3.12/site-packages/tinygrad/nn/optim.py:42\u001b[0m, in \u001b[0;36mOptimizer.schedule_step\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 37\u001b[0m \u001b[38;5;124;03mReturns the tensors that need to be realized to perform a single optimization step.\u001b[39;00m\n\u001b[1;32m 38\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 39\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m Tensor\u001b[38;5;241m.\u001b[39mtraining, (\n\u001b[1;32m 40\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\"\"\u001b[39m\u001b[38;5;124mTensor.training=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mTensor\u001b[38;5;241m.\u001b[39mtraining\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, Tensor.training must be enabled to use the optimizer.\u001b[39m\n\u001b[1;32m 41\u001b[0m \u001b[38;5;124m - help: Consider setting Tensor.training=True before calling Optimizer.step().\u001b[39m\u001b[38;5;124m\"\"\"\u001b[39m)\n\u001b[0;32m---> 42\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m+\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparams\u001b[38;5;241m+\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbuffers\n", - "File \u001b[0;32m/opt/homebrew/lib/python3.12/site-packages/tinygrad/nn/optim.py:149\u001b[0m, in \u001b[0;36mLAMB._step\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 148\u001b[0m r \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1.0\u001b[39m\n\u001b[0;32m--> 149\u001b[0m t\u001b[38;5;241m.\u001b[39massign((t\u001b[38;5;241m.\u001b[39mdetach() \u001b[38;5;241m-\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlr\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mr\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mup\u001b[49m)\u001b[38;5;241m.\u001b[39mcast(t\u001b[38;5;241m.\u001b[39mdtype))\n\u001b[1;32m 150\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m [\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mb1_t, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mb2_t] \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mm \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mv\n", - "File \u001b[0;32m/opt/homebrew/lib/python3.12/site-packages/tinygrad/tensor.py:3231\u001b[0m, in \u001b[0;36m_metadata_wrapper.._wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 3230\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_wrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs, 
\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m-> 3231\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _METADATA\u001b[38;5;241m.\u001b[39mget() \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m: \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3233\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m TRACEMETA \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m2\u001b[39m:\n\u001b[1;32m 3234\u001b[0m caller_frame \u001b[38;5;241m=\u001b[39m sys\u001b[38;5;241m.\u001b[39m_getframe(frame \u001b[38;5;241m:=\u001b[39m \u001b[38;5;241m1\u001b[39m)\n", - "File \u001b[0;32m/opt/homebrew/lib/python3.12/site-packages/tinygrad/tensor.py:2747\u001b[0m, in \u001b[0;36mTensor.__mul__\u001b[0;34m(self, x)\u001b[0m\n\u001b[0;32m-> 2747\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__mul__\u001b[39m(\u001b[38;5;28mself\u001b[39m, x) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor: \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmul\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m/opt/homebrew/lib/python3.12/site-packages/tinygrad/tensor.py:3231\u001b[0m, in \u001b[0;36m_metadata_wrapper.._wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 3230\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_wrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m-> 3231\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _METADATA\u001b[38;5;241m.\u001b[39mget() \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m: \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3233\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m TRACEMETA \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m2\u001b[39m:\n\u001b[1;32m 3234\u001b[0m caller_frame \u001b[38;5;241m=\u001b[39m sys\u001b[38;5;241m.\u001b[39m_getframe(frame \u001b[38;5;241m:=\u001b[39m \u001b[38;5;241m1\u001b[39m)\n", - "File \u001b[0;32m/opt/homebrew/lib/python3.12/site-packages/tinygrad/tensor.py:2554\u001b[0m, in \u001b[0;36mTensor.mul\u001b[0;34m(self, x, reverse)\u001b[0m\n\u001b[1;32m 2536\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmul\u001b[39m(\u001b[38;5;28mself\u001b[39m, x:Union[Tensor, ConstType], reverse\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[1;32m 2537\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 2538\u001b[0m \u001b[38;5;124;03m Multiplies `self` and `x`.\u001b[39;00m\n\u001b[1;32m 2539\u001b[0m \u001b[38;5;124;03m Equivalent to `self * x`.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 2552\u001b[0m \u001b[38;5;124;03m ```\u001b[39;00m\n\u001b[1;32m 2553\u001b[0m \u001b[38;5;124;03m 
\"\"\"\u001b[39;00m\n\u001b[0;32m-> 2554\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m F\u001b[38;5;241m.\u001b[39mMul\u001b[38;5;241m.\u001b[39mapply(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_broadcasted\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreverse\u001b[49m\u001b[43m)\u001b[49m)\n", - "File \u001b[0;32m/opt/homebrew/lib/python3.12/site-packages/tinygrad/tensor.py:3231\u001b[0m, in \u001b[0;36m_metadata_wrapper.._wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 3230\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_wrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m-> 3231\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _METADATA\u001b[38;5;241m.\u001b[39mget() \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m: \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3233\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m TRACEMETA \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m2\u001b[39m:\n\u001b[1;32m 3234\u001b[0m caller_frame \u001b[38;5;241m=\u001b[39m sys\u001b[38;5;241m.\u001b[39m_getframe(frame \u001b[38;5;241m:=\u001b[39m \u001b[38;5;241m1\u001b[39m)\n", - "File \u001b[0;32m/opt/homebrew/lib/python3.12/site-packages/tinygrad/tensor.py:2489\u001b[0m, in \u001b[0;36mTensor._broadcasted\u001b[0;34m(self, y, reverse, match_dtype)\u001b[0m\n\u001b[1;32m 2487\u001b[0m \u001b[38;5;66;03m# broadcast\u001b[39;00m\n\u001b[1;32m 2488\u001b[0m out_shape \u001b[38;5;241m=\u001b[39m _broadcast_shape(x\u001b[38;5;241m.\u001b[39mshape, y\u001b[38;5;241m.\u001b[39mshape)\n\u001b[0;32m-> 2489\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_broadcast_to\u001b[49m\u001b[43m(\u001b[49m\u001b[43mout_shape\u001b[49m\u001b[43m)\u001b[49m, y\u001b[38;5;241m.\u001b[39m_broadcast_to(out_shape)\n", - "File \u001b[0;32m/opt/homebrew/lib/python3.12/site-packages/tinygrad/tensor.py:3231\u001b[0m, in \u001b[0;36m_metadata_wrapper.._wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 3230\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_wrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m-> 3231\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _METADATA\u001b[38;5;241m.\u001b[39mget() \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m: \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3233\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m TRACEMETA \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m2\u001b[39m:\n\u001b[1;32m 3234\u001b[0m caller_frame \u001b[38;5;241m=\u001b[39m sys\u001b[38;5;241m.\u001b[39m_getframe(frame \u001b[38;5;241m:=\u001b[39m \u001b[38;5;241m1\u001b[39m)\n", - "File 
\u001b[0;32m/opt/homebrew/lib/python3.12/site-packages/tinygrad/tensor.py:2466\u001b[0m, in \u001b[0;36mTensor._broadcast_to\u001b[0;34m(self, shape)\u001b[0m\n\u001b[1;32m 2464\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mndim \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mlen\u001b[39m(shape): \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcannot broadcast tensor to fewer dimensions. shape=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mshape\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m to \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mshape\u001b[38;5;132;01m=}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 2465\u001b[0m \u001b[38;5;66;03m# first pad left with 1s https://data-apis.org/array-api/latest/API_specification/broadcasting.html\u001b[39;00m\n\u001b[0;32m-> 2466\u001b[0m padded, _ \u001b[38;5;241m=\u001b[39m \u001b[43m_pad_left\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshape\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mshape\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2467\u001b[0m \u001b[38;5;66;03m# for each dimension, check either from_ is 1, or it does not change\u001b[39;00m\n\u001b[1;32m 2468\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28many\u001b[39m(from_ \u001b[38;5;241m!=\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m from_ \u001b[38;5;241m!=\u001b[39m to \u001b[38;5;28;01mfor\u001b[39;00m from_,to \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(padded, shape)): \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcannot broadcast from shape=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mshape\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m to \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mshape\u001b[38;5;132;01m=}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", - "File \u001b[0;32m/opt/homebrew/lib/python3.12/site-packages/tinygrad/tensor.py:88\u001b[0m, in \u001b[0;36m_pad_left\u001b[0;34m(*shapes)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_pad_left\u001b[39m(\u001b[38;5;241m*\u001b[39mshapes:Tuple[sint, \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m]) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tuple[Tuple[sint, \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m], \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m]:\n\u001b[1;32m 87\u001b[0m max_dim \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mmax\u001b[39m(\u001b[38;5;28mlen\u001b[39m(shape) \u001b[38;5;28;01mfor\u001b[39;00m shape \u001b[38;5;129;01min\u001b[39;00m shapes)\n\u001b[0;32m---> 88\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mtuple\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43mmax_dim\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[43m 
\u001b[49m\u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mshape\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mshape\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mshape\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mshapes\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m/opt/homebrew/lib/python3.12/site-packages/tinygrad/tensor.py:88\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_pad_left\u001b[39m(\u001b[38;5;241m*\u001b[39mshapes:Tuple[sint, \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m]) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tuple[Tuple[sint, \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m], \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m]:\n\u001b[1;32m 87\u001b[0m max_dim \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mmax\u001b[39m(\u001b[38;5;28mlen\u001b[39m(shape) \u001b[38;5;28;01mfor\u001b[39;00m shape \u001b[38;5;129;01min\u001b[39;00m shapes)\n\u001b[0;32m---> 88\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mtuple\u001b[39m((\u001b[38;5;241m1\u001b[39m,) \u001b[38;5;241m*\u001b[39m (max_dim \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mlen\u001b[39m(shape)) \u001b[38;5;241m+\u001b[39m shape \u001b[38;5;28;01mfor\u001b[39;00m shape \u001b[38;5;129;01min\u001b[39;00m shapes)\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " - ] - } - ], - "source": [ - "from pathlib import Path\n", - "\n", - "from lerobot.common.datasets.lerobot_dataset import LeRobotDataset\n", - "from torch.utils.data import DataLoader\n", - "import torch\n", - "\n", - "import tinygrad\n", - "from tinygrad import Tensor, nn, TinyJit\n", - "\n", - "from omegaconf import ListConfig, OmegaConf\n", - "\n", - "from tinygrad.nn.state import safe_save, safe_load, get_state_dict, load_state_dict\n", - "\n", - "# Start of training code\n", - "\n", - "# Create a directory to store the training checkpoint.\n", - "output_directory = Path(\"outputs/train/example_pusht\")\n", - "output_directory.mkdir(parents=True, exist_ok=True)\n", - "\n", - "# Number of offline training steps (we'll only do offline training for this example.)\n", - "# Adjust as you prefer. 
5000 steps are needed to get something worth evaluating.\n", - "training_steps = 100000\n", - "log_freq = 1\n", - "\n", - "# Set up the dataset.\n", - "delta_timestamps = {\n", - " \"action\": [i / 50.0 for i in range(100)],\n", - "}\n", - "dataset = LeRobotDataset('lerobot/aloha_sim_insertion_human', delta_timestamps=delta_timestamps)\n", - "print(dataset.stats)\n", - "\n", - "cfg = ACTConfig()\n", - "policy = ACTPolicy(cfg, dataset_stats=dataset.stats)\n", - "\n", - "params_not_backbone = [p for n, p in nn.state.get_state_dict(policy).items() if p.requires_grad != False and not n.startswith(\"model.backbone\")]\n", - "params_backbone = [p for n, p in nn.state.get_state_dict(policy).items() if p.requires_grad != False and n.startswith(\"model.backbone\")]\n", - "\n", - "Tensor.manual_seed(1000)\n", - "\n", - "if hasattr(cfg, 'override_dataset_stats'):\n", - " for key, stats_dict in cfg.override_dataset_stats.items():\n", - " for stats_type, listconfig in stats_dict.items():\n", - " # example of stats_type: min, max, mean, std\n", - " print(f'listconfig: {listconfig}')\n", - " dataset.stats[key][stats_type] = torch.tensor(listconfig, dtype=torch.float32)\n", - "\n", - "opt = nn.optim.AdamW(params_not_backbone, lr=1e-5, weight_decay=1e-4)\n", - "opt_backbone = nn.optim.AdamW(params_backbone, lr=1e-5, weight_decay=1e-4)\n", - "\n", - "#@TinyJit\n", - "@Tensor.train()\n", - "def train_step(batch) -> Tensor:\n", - " Tensor.training = True\n", - " output_dict = policy(batch)\n", - " loss = output_dict[\"loss\"]\n", - " opt.zero_grad()\n", - " opt_backbone.zero_grad()\n", - " loss.backward()\n", - " opt.step()\n", - " opt_backbone.step()\n", - " return loss\n", - "\n", - "print(f'Starting training loop')\n", - "# Create dataloader for offline training.\n", - "dataloader = DataLoader(\n", - " dataset,\n", - " num_workers=0,\n", - " batch_size=8,\n", - " shuffle=True,\n", - " pin_memory=False,\n", - " drop_last=True,\n", - ")\n", - "\n", - "step = 0\n", - "done = False\n", - "with Tensor.train():\n", - " while not done:\n", - " for batch in dataloader:\n", - " batch = {k: Tensor(v.numpy(), requires_grad=False) for k, v in batch.items()}\n", - " loss = train_step(batch)\n", - " \n", - " if step % log_freq == 0:\n", - " print(f\"step: {step} loss: {loss.numpy():.3f}\")\n", - " step += 1\n", - "\n", - " if step % 10000 == 0:\n", - " try:\n", - " state_dict = get_state_dict(policy)\n", - " safe_save(state_dict, f'{output_directory}/model_{step}.safetensors')\n", - " except:\n", - " print(f'Exception with safe save occured')\n", - " if step >= training_steps:\n", - " done = True\n", - " break\n", - "\n", - "# Save a policy checkpoint.\n", - "state_dict = get_state_dict(policy)\n", - "safe_save(state_dict, f'{output_directory}/model_final.safetensors')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b347771d-4f2a-4a01-8382-e8e2970bbb92", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f7601f50-41d4-493f-8957-e407ee3ababe", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6bfdb2a1-6a3e-4d79-aef0-11e90efece43", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": 
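The delta_timestamps dict above asks the dataset to attach, to every frame, the actions at t + i/50 s for i in 0..99, i.e. a 2-second, 100-step action chunk at 50 Hz; steps past the end of the episode come back flagged in action_is_pad. A rough numpy sketch of how such chunked targets can be built from a flat action log (this is not LeRobotDataset's internal implementation):

import numpy as np

fps, chunk_size, action_dim = 50, 100, 14
episode = np.random.randn(300, action_dim)          # a 6 s episode recorded at 50 Hz

def chunk_targets(t):
    idx = t + np.arange(chunk_size)                  # frame indices for timestamps t + i/fps
    is_pad = idx >= len(episode)                     # steps beyond the episode end
    idx = np.clip(idx, 0, len(episode) - 1)          # repeat the last frame where padded
    return episode[idx], is_pad                      # (chunk_size, action_dim), (chunk_size,)

actions, action_is_pad = chunk_targets(t=250)
print(actions.shape, int(action_is_pad.sum()))       # (100, 14) 50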
"python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.6" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} +version https://git-lfs.github.com/spec/v1 +oid sha256:09c9db189405883e57a98eb4926349b5aece2df7d62c7ec37c0b093531a130db +size 102213 diff --git a/01. Test/test.py b/01. Test/test.py new file mode 100644 index 0000000..01bff52 --- /dev/null +++ b/01. Test/test.py @@ -0,0 +1,21 @@ +# example.py +import imageio +import gymnasium as gym +import numpy as np +import gym_aloha + +env = gym.make("gym_aloha/AlohaInsertion-v0") +observation, info = env.reset() +frames = [] + +for _ in range(1000): + action = env.action_space.sample() + observation, reward, terminated, truncated, info = env.step(action) + image = env.render() + frames.append(image) + + if terminated or truncated: + observation, info = env.reset() + +env.close() +imageio.mimsave("example.mp4", np.stack(frames), fps=25) \ No newline at end of file diff --git a/modeling_act.ipynb b/modeling_act.ipynb index 8ba9868..18d2e33 100644 --- a/modeling_act.ipynb +++ b/modeling_act.ipynb @@ -1,1738 +1,3 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "0441978e-0b8a-4dc5-9300-f04249e4e2e3", - "metadata": {}, - "source": [] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "0c941a80-30a4-4680-9f76-5ac76c67bae1", - "metadata": {}, - "outputs": [], - "source": [ - "import math\n", - "from collections import deque\n", - "from itertools import chain\n", - "from typing import Callable\n", - "\n", - "import numpy as np\n", - "import tinygrad\n", - "from tinygrad import Tensor, nn, dtypes" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "88efb4dc-05e4-4ec7-8c46-59657e281141", - "metadata": {}, - "outputs": [], - "source": [ - "def get_activation_fn(activation: str) -> Callable:\n", - " \"\"\"Return an activation function given a string.\"\"\"\n", - " if activation == \"relu\":\n", - " return Tensor.relu\n", - " if activation == \"gelu\":\n", - " return Tensor.gelu\n", - " raise RuntimeError(f\"activation should be relu/gelu/glu, not {activation}.\")" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "e6ce0e70-2b63-4240-878b-215a80e856a4", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "get_activation_fn('relu')" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "aac9f722-55ce-4bca-9008-cf0f68152dd6", - "metadata": {}, - "outputs": [], - "source": [ - "class ACTSinusoidalPositionEmbedding2d:\n", - " \"\"\"2D sinusoidal positional embeddings similar to what's presented in Attention Is All You Need.\n", - "\n", - " The variation is that the position indices are normalized in [0, 2π] (not quite: the lower bound is 1/H\n", - " for the vertical direction, and 1/W for the horizontal direction.\n", - " \"\"\"\n", - "\n", - " def __init__(self, dimension: int):\n", - " \"\"\"\n", - " Args:\n", - " dimension: The desired dimension of the embeddings.\n", - " \"\"\"\n", - " super().__init__()\n", - " self.dimension = dimension\n", - " self._two_pi = 2 * math.pi\n", - " self._eps = 1e-6\n", - " # Inverse \"common ratio\" for the geometric progression in sinusoid frequencies.\n", - " self._temperature = 10000\n", - "\n", - " def __call__(self, x: Tensor) -> Tensor:\n", - " \"\"\"\n", - " Args:\n", - " x: A (B, C, H, W) batch of 2D feature map to generate the embeddings for.\n", - " Returns:\n", - " A (1, C, H, 
W) batch of corresponding sinusoidal positional embeddings.\n", - " \"\"\"\n", - " not_mask = Tensor.ones_like(x[0, :1]) # (1, H, W)\n", - " # Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations\n", - " # they would be range(0, H) and range(0, W). Keeping it at as is to match the original code.\n", - " y_range = not_mask.cumsum(1).cast(dtype=dtypes.float32)\n", - " x_range = not_mask.cumsum(2).cast(dtype=dtypes.float32)\n", - "\n", - " # \"Normalize\" the position index such that it ranges in [0, 2π].\n", - " # Note: Adding epsilon on the denominator should not be needed as all values of y_embed and x_range\n", - " # are non-zero by construction. This is an artifact of the original code.\n", - " y_range = y_range / (y_range[:, -1:, :] + self._eps) * self._two_pi\n", - " x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi\n", - "\n", - " inverse_frequency = Tensor(self._temperature ** (\n", - " 2 * (np.arange(self.dimension, dtype='f') // 2) / self.dimension\n", - " ))\n", - "\n", - " x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)\n", - " y_range = y_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1)\n", - "\n", - " print(x_range)\n", - " print(y_range)\n", - "\n", - " # Note: this stack then flatten operation results in interleaved sine and cosine terms.\n", - " # pos_embed_x and pos_embed_y are (1, H, W, C // 2).\n", - " x_range_sin = x_range[..., 0::2].sin()\n", - " x_range_cos = x_range[..., 1::2].cos()\n", - " y_range_sin = y_range[..., 0::2].sin()\n", - " y_range_cos = y_range[..., 1::2].cos()\n", - " print(f'x_range[..., 0::2].sin(): {x_range_sin}')\n", - " print(f'x_range[..., 1::2].cos(): {x_range_cos}')\n", - " pos_embed_x = x_range_sin.stack(x_range_cos, dim=-1).flatten(3)\n", - " pos_embed_y = y_range_sin.stack(y_range_cos, dim=-1).flatten(3)\n", - " pos_embed = pos_embed_y.cat(pos_embed_x, dim=3).permute(0, 3, 1, 2) # (1, C, H, W)\n", - "\n", - " return pos_embed" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "7692187e-822c-4b68-924b-2788e1cf7723", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - ", None)> on METAL with grad None>\n", - ", None)> on METAL with grad None>\n", - "x_range[..., 0::2].sin(): , None)> on METAL with grad None>\n", - "x_range[..., 1::2].cos(): , None)> on METAL with grad None>\n" - ] - }, - { - "data": { - "text/plain": [ - " on METAL with grad None>" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "actSin = ACTSinusoidalPositionEmbedding2d(10)\n", - "actSin(Tensor.zeros(4,4,4,4))" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "2a0ef1f0-ddf7-40d7-a6bf-2b3e7480ef82", - "metadata": {}, - "outputs": [], - "source": [ - "def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tensor:\n", - " \"\"\"1D sinusoidal positional embeddings as in Attention is All You Need.\n", - "\n", - " Args:\n", - " num_positions: Number of token positions required.\n", - " Returns: (num_positions, dimension) position embeddings (the first dimension is the batch dimension).\n", - "\n", - " \"\"\"\n", - "\n", - " def get_position_angle_vec(position):\n", - " return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)]\n", - "\n", - " sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)], dtype='f')\n", - " sinusoid_table[:, 0::2] = 
np.sin(sinusoid_table[:, 0::2]) # dim 2i\n", - " sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1\n", - " return Tensor(sinusoid_table).float()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "19404914-6edd-4bcb-8731-dc1855149fce", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[ 0. , 1. , 0. ],\n", - " [ 0.841471 , 0.5403023 , 0.00215443],\n", - " [ 0.9092974 , -0.4161468 , 0.00430886]], dtype=float32)" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "create_sinusoidal_pos_embedding(3, 3).numpy()" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "f8dd33d3-a35f-403d-a46c-6f228739c928", - "metadata": {}, - "outputs": [], - "source": [ - "from tinygrad import Tensor, nn\n", - "from typing import Optional, Union, Literal\n", - "from tinygrad.ops import Variable\n", - "\n", - "class MultiheadAttention:\n", - " def __init__(self, embed_dim, num_heads, dropout=0.0):\n", - " self.embed_dim = embed_dim\n", - " self.num_heads = num_heads\n", - " self.head_dim = embed_dim // num_heads\n", - " assert self.head_dim * num_heads == embed_dim, \"n_state must be divisible by n_head\"\n", - "\n", - " self.query = nn.Linear(embed_dim, embed_dim)\n", - " self.key = nn.Linear(embed_dim, embed_dim)\n", - " self.value = nn.Linear(embed_dim, embed_dim)\n", - " self.out = nn.Linear(embed_dim, embed_dim)\n", - "\n", - " self.scaling = self.head_dim ** -0.5 \n", - " self.dropout = dropout\n", - "\n", - " def __call__(self, q: Tensor, k: Tensor, v: Tensor, key_padding_mask: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None, training: bool = True):\n", - " batch_size, tgt_len, _ = q.shape\n", - " src_len = k.shape[1]\n", - "\n", - " # Apply linear transformations\n", - " q = self.query(q)\n", - " k = self.key(k)\n", - " v = self.value(v)\n", - "\n", - " # Reshape and transpose for multi-head attention\n", - " q = q.reshape(batch_size, tgt_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)\n", - " k = k.reshape(batch_size, src_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)\n", - " v = v.reshape(batch_size, src_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)\n", - " \n", - " # Calculate attention scores\n", - " attn_scores = (q @ k.transpose(-2, -1)) * self.scaling\n", - "\n", - " # Apply key padding mask if provided\n", - " if key_padding_mask is not None:\n", - " print(f'(q,k,v): {q.shape}, {k.shape}, {v.shape}')\n", - " print(f'(key_padding_mask): {key_padding_mask.shape}')\n", - " # Reshape and expand key_padding_mask to match attn_scores dimensions\n", - " key_padding_mask = key_padding_mask.squeeze(1).squeeze(1) # Remove extra dimensions\n", - " key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(1) # Add dimensions for heads and query length\n", - " key_padding_mask = key_padding_mask.expand(batch_size, self.num_heads, tgt_len, src_len)\n", - " attn_scores = attn_scores.masked_fill(key_padding_mask, float('-inf')) \n", - " \n", - " # Apply softmax to get attention weights\n", - " attn_weights = attn_scores.softmax(axis=-1)\n", - "\n", - " # Apply dropout\n", - " if self.dropout > 0:\n", - " attn_weights = attn_weights.dropout(p=self.dropout)\n", - "\n", - " # Apply attention to values\n", - " attn_output = attn_weights @ v\n", - "\n", - " # Reshape and combine heads\n", - " attn_output = attn_output.permute(0, 2, 1, 3).reshape(batch_size, tgt_len, self.embed_dim)\n", - "\n", - " # Final projection\n", - " attn_output = 
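The key_padding_mask handling above sets the attention scores of padded key positions to -inf before the softmax, which drives their attention weights to exactly zero. A standalone numpy illustration:

import numpy as np

scores = np.array([[2.0, 1.0, 0.5, -1.0]])                   # (queries, keys)
key_padding_mask = np.array([False, False, True, True])      # last two keys are padding

masked = np.where(key_padding_mask, -np.inf, scores)
weights = np.exp(masked - masked.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
print(weights)   # [[0.731 0.269 0.    0.   ]] -- padded keys receive no attention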
self.out(attn_output)\n", - "\n", - " return attn_output\n" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "6a91632b-0a49-4bec-b500-215fc7b34ac5", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - ", None)> on METAL with grad None>" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "mha = MultiheadAttention(9, 9)\n", - "mha(Tensor.zeros(9, 9, 9), Tensor.zeros(9, 9, 9), Tensor.zeros(9, 9, 9))" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "id": "db7c53ee-be78-41ba-b463-4083cf12d92c", - "metadata": {}, - "outputs": [], - "source": [ - "from dataclasses import dataclass, field\n", - "\n", - "@dataclass\n", - "class ACTConfig:\n", - " \"\"\"Configuration class for the Action Chunking Transformers policy.\n", - "\n", - " Defaults are configured for training on bimanual Aloha tasks like \"insertion\" or \"transfer\".\n", - "\n", - " The parameters you will most likely need to change are the ones which depend on the environment / sensors.\n", - " Those are: `input_shapes` and 'output_shapes`.\n", - "\n", - " Notes on the inputs and outputs:\n", - " - Either:\n", - " - At least one key starting with \"observation.image is required as an input.\n", - " AND/OR\n", - " - The key \"observation.environment_state\" is required as input.\n", - " - If there are multiple keys beginning with \"observation.images.\" they are treated as multiple camera\n", - " views. Right now we only support all images having the same shape.\n", - " - May optionally work without an \"observation.state\" key for the proprioceptive robot state.\n", - " - \"action\" is required as an output key.\n", - "\n", - " Args:\n", - " n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the\n", - " current step and additional steps going back).\n", - " chunk_size: The size of the action prediction \"chunks\" in units of environment steps.\n", - " n_action_steps: The number of action steps to run in the environment for one invocation of the policy.\n", - " This should be no greater than the chunk size. For example, if the chunk size size 100, you may\n", - " set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the\n", - " environment, and throws the other 50 out.\n", - " input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents\n", - " the input data name, and the value is a list indicating the dimensions of the corresponding data.\n", - " For example, \"observation.image\" refers to an input from a camera with dimensions [3, 96, 96],\n", - " indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't\n", - " include batch dimension or temporal dimension.\n", - " output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents\n", - " the output data name, and the value is a list indicating the dimensions of the corresponding data.\n", - " For example, \"action\" refers to an output shape of [14], indicating 14-dimensional actions.\n", - " Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.\n", - " input_normalization_modes: A dictionary with key representing the modality (e.g. \"observation.state\"),\n", - " and the value specifies the normalization mode to apply. 
The two available modes are \"mean_std\"\n", - " which subtracts the mean and divides by the standard deviation and \"min_max\" which rescale in a\n", - " [-1, 1] range.\n", - " output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the\n", - " original scale. Note that this is also used for normalizing the training targets.\n", - " vision_backbone: Name of the torchvision resnet backbone to use for encoding images.\n", - " pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone.\n", - " `None` means no pretrained weights.\n", - " replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated\n", - " convolution.\n", - " pre_norm: Whether to use \"pre-norm\" in the transformer blocks.\n", - " dim_model: The transformer blocks' main hidden dimension.\n", - " n_heads: The number of heads to use in the transformer blocks' multi-head attention.\n", - " dim_feedforward: The dimension to expand the transformer's hidden dimension to in the feed-forward\n", - " layers.\n", - " feedforward_activation: The activation to use in the transformer block's feed-forward layers.\n", - " n_encoder_layers: The number of transformer layers to use for the transformer encoder.\n", - " n_decoder_layers: The number of transformer layers to use for the transformer decoder.\n", - " use_vae: Whether to use a variational objective during training. This introduces another transformer\n", - " which is used as the VAE's encoder (not to be confused with the transformer encoder - see\n", - " documentation in the policy class).\n", - " latent_dim: The VAE's latent dimension.\n", - " n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.\n", - " temporal_ensemble_coeff: Coefficient for the exponential weighting scheme to apply for temporal\n", - " ensembling. Defaults to None which means temporal ensembling is not used. `n_action_steps` must be\n", - " 1 when using this feature, as inference needs to happen at every step to form an ensemble. For\n", - " more information on how ensembling works, please see `ACTTemporalEnsembler`.\n", - " dropout: Dropout to use in the transformer layers (see code for details).\n", - " kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective\n", - " is enabled. 
Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.\n", - " \"\"\"\n", - "\n", - " # Input / output structure.\n", - " n_obs_steps: int = 1\n", - " chunk_size: int = 100\n", - " n_action_steps: int = 100\n", - "\n", - " input_shapes: dict[str, list[int]] = field(\n", - " default_factory=lambda: {\n", - " \"observation.images.top\": [3, 480, 640],\n", - " \"observation.state\": [14],\n", - " }\n", - " )\n", - " output_shapes: dict[str, list[int]] = field(\n", - " default_factory=lambda: {\n", - " \"action\": [14],\n", - " }\n", - " )\n", - "\n", - " # Normalization / Unnormalization\n", - " input_normalization_modes: dict[str, str] = field(\n", - " default_factory=lambda: {\n", - " \"observation.images.top\": \"mean_std\",\n", - " \"observation.state\": \"mean_std\",\n", - " }\n", - " )\n", - " output_normalization_modes: dict[str, str] = field(\n", - " default_factory=lambda: {\n", - " \"action\": \"mean_std\",\n", - " }\n", - " )\n", - "\n", - " # Overrides.\n", - " override_dataset_stats: dict[str, dict[str, list[[float]]]] = field(\n", - " default_factory=lambda: {\n", - " \"observation.images.top\": {\n", - " \"mean\": [[[0.485]], [[0.456]], [[0.406]]], # (c,1,1)\n", - " \"std\": [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)\n", - " }\n", - " }\n", - " )\n", - "\n", - " # Architecture.\n", - " # Vision backbone.\n", - " vision_backbone: str = \"resnet18\"\n", - " pretrained_backbone_weights: str | None = \"ResNet18_Weights.IMAGENET1K_V1\"\n", - " replace_final_stride_with_dilation: int = False\n", - " # Transformer layers.\n", - " pre_norm: bool = False\n", - " dim_model: int = 512\n", - " n_heads: int = 8\n", - " dim_feedforward: int = 3200\n", - " feedforward_activation: str = \"relu\"\n", - " n_encoder_layers: int = 4\n", - " # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code\n", - " # that means only the first layer is used. Here we match the original implementation by setting this to 1.\n", - " # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.\n", - " n_decoder_layers: int = 1\n", - " # VAE.\n", - " use_vae: bool = True\n", - " latent_dim: int = 32\n", - " n_vae_encoder_layers: int = 4\n", - "\n", - " # Inference.\n", - " # Note: the value used in ACT when temporal ensembling is enabled is 0.01.\n", - " temporal_ensemble_coeff: float | None = None\n", - "\n", - " # Training and loss computation.\n", - " dropout: float = 0.1\n", - " kl_weight: float = 10.0\n", - "\n", - " def __post_init__(self):\n", - " \"\"\"Input validation (not exhaustive).\"\"\"\n", - " if not self.vision_backbone.startswith(\"resnet\"):\n", - " raise ValueError(\n", - " f\"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}.\"\n", - " )\n", - " if self.temporal_ensemble_coeff is not None and self.n_action_steps > 1:\n", - " raise NotImplementedError(\n", - " \"`n_action_steps` must be 1 when using temporal ensembling. This is \"\n", - " \"because the policy needs to be queried every step to compute the ensembled action.\"\n", - " )\n", - " if self.n_action_steps > self.chunk_size:\n", - " raise ValueError(\n", - " f\"The chunk size is the upper bound for the number of action steps per model invocation. Got \"\n", - " f\"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`.\"\n", - " )\n", - " if self.n_obs_steps != 1:\n", - " raise ValueError(\n", - " f\"Multiple observation steps not handled yet. 
Got `nobs_steps={self.n_obs_steps}`\"\n", - " )\n", - " if (\n", - " not any(k.startswith(\"observation.image\") for k in self.input_shapes)\n", - " and \"observation.environment_state\" not in self.input_shapes\n", - " ):\n", - " raise ValueError(\"You must provide at least one image or the environment state among the inputs.\")" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "f21a9b74-1b92-4646-b093-27fedfc72513", - "metadata": {}, - "outputs": [], - "source": [ - "class ACTDecoderLayer:\n", - " def __init__(self, config: ACTConfig):\n", - " super().__init__()\n", - " self.self_attn = MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)\n", - " self.multihead_attn = MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)\n", - "\n", - " # Feed forward layers.\n", - " self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)\n", - " self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model)\n", - "\n", - " self.norm1 = nn.LayerNorm(config.dim_model)\n", - " self.norm2 = nn.LayerNorm(config.dim_model)\n", - " self.norm3 = nn.LayerNorm(config.dim_model)\n", - " self.dropout_rate = config.dropout\n", - "\n", - " self.activation = get_activation_fn(config.feedforward_activation)\n", - " self.pre_norm = config.pre_norm\n", - "\n", - " def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor:\n", - " return tensor if pos_embed is None else tensor + pos_embed\n", - "\n", - " def __call__(\n", - " self,\n", - " x: Tensor,\n", - " encoder_out: Tensor,\n", - " decoder_pos_embed: Tensor | None = None,\n", - " encoder_pos_embed: Tensor | None = None,\n", - " ) -> Tensor:\n", - " \"\"\"\n", - " Args:\n", - " x: (Decoder Sequence, Batch, Channel) tensor of input tokens.\n", - " encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are\n", - " cross-attending with.\n", - " decoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder).\n", - " encoder_pos_embed: (DS, 1, C) Positional_embedding for the queries (from the decoder).\n", - " Returns:\n", - " (DS, B, C) tensor of decoder output features.\n", - " \"\"\"\n", - " skip = x\n", - " if self.pre_norm:\n", - " x = self.norm1(x)\n", - " q = k = self.maybe_add_pos_embed(x, decoder_pos_embed)\n", - " x = self.self_attn(q, k, x) \n", - " #x = x[0] # select just the output, not the attention weights\n", - " x = skip + x.dropout(p=self.dropout_rate)\n", - " if self.pre_norm:\n", - " skip = x\n", - " x = self.norm2(x)\n", - " else:\n", - " x = self.norm1(x)\n", - " skip = x\n", - " x = self.multihead_attn(\n", - " self.maybe_add_pos_embed(x, decoder_pos_embed),\n", - " self.maybe_add_pos_embed(encoder_out, encoder_pos_embed),\n", - " encoder_out,\n", - " )\n", - " #x = x[0] # select just the output, not the attention weights\n", - " x = skip + x.dropout(p=self.dropout_rate)\n", - " if self.pre_norm:\n", - " skip = x\n", - " x = self.norm3(x)\n", - " else:\n", - " x = self.norm2(x)\n", - " skip = x\n", - " \n", - " x = x.sequential([self.linear1, self.activation]).dropout(p=self.dropout_rate).sequential([self.linear2])\n", - " x = skip + x.dropout(p=self.dropout_rate)\n", - " if not self.pre_norm:\n", - " x = self.norm3(x)\n", - " return x\n" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "79ac1c63-ded4-4b6e-8291-3e8b0c5e5d49", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - ", None)> on METAL with grad None>" - ] - }, - "execution_count": 12, - 
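The __post_init__ checks above enforce, among other things, that n_action_steps never exceeds chunk_size, since the policy cannot execute more steps than it predicts per invocation. A quick illustration of that validation with the ACTConfig defined in this notebook (field values are just examples):

cfg = ACTConfig(chunk_size=100, n_action_steps=50)   # fine: execute 50 of the 100 predicted steps
print(cfg.n_action_steps <= cfg.chunk_size)          # True

try:
    ACTConfig(chunk_size=100, n_action_steps=200)    # rejected: asks for more steps than predicted
except ValueError as err:
    print(err)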
"metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "actDecoder = ACTDecoderLayer(ACTConfig())\n", - "actDecoder(Tensor.zeros(3,512, 512), Tensor.zeros(3,512, 512))" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "199f37fd-b9ff-4946-9531-53c7d90e6332", - "metadata": {}, - "outputs": [], - "source": [ - "class ACTDecoder:\n", - " def __init__(self, config: ACTConfig):\n", - " \"\"\"Convenience module for running multiple decoder layers followed by normalization.\"\"\"\n", - " super().__init__()\n", - " self.layers = [ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)]\n", - " self.norm = nn.LayerNorm(config.dim_model)\n", - "\n", - " def __call__(\n", - " self,\n", - " x: Tensor,\n", - " encoder_out: Tensor,\n", - " decoder_pos_embed: Tensor | None = None,\n", - " encoder_pos_embed: Tensor | None = None,\n", - " ) -> Tensor:\n", - " for layer in self.layers:\n", - " x = layer(\n", - " x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed\n", - " )\n", - " if self.norm is not None:\n", - " x = self.norm(x)\n", - " return x" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "8fc180ee-fe07-44a8-863a-6d064a93fb32", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - ", None)> on METAL with grad None>" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "actDecode = ACTDecoder(ACTConfig())\n", - "actDecode(Tensor.zeros(3,512,512), Tensor.zeros(3,512,512))" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "5fd920a5-55bc-4f74-a4f0-896ecbc800cc", - "metadata": {}, - "outputs": [], - "source": [ - "class ACTEncoderLayer:\n", - " def __init__(self, config: ACTConfig):\n", - " super().__init__()\n", - " self.self_attn = MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout)\n", - "\n", - " # Feed forward layers.\n", - " self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward)\n", - " self.dropout = config.dropout\n", - " self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model)\n", - "\n", - " self.norm1 = nn.LayerNorm(config.dim_model)\n", - " self.norm2 = nn.LayerNorm(config.dim_model)\n", - "\n", - " self.activation = get_activation_fn(config.feedforward_activation)\n", - " self.pre_norm = config.pre_norm\n", - "\n", - " def __call__(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor:\n", - " skip = x\n", - " if self.pre_norm:\n", - " x = self.norm1(x)\n", - " q = k = x if pos_embed is None else x + pos_embed\n", - " x = self.self_attn(q, k, x, key_padding_mask=key_padding_mask)\n", - " # x = x[0] # note: [0] to select just the output, not the attention weights\n", - " x = skip + x.dropout(p=self.dropout)\n", - " if self.pre_norm:\n", - " skip = x\n", - " x = self.norm2(x)\n", - " else:\n", - " x = self.norm1(x)\n", - " skip = x\n", - " x = x.sequential([self.linear1, self.activation]).dropout(p=self.dropout).sequential([self.linear2])\n", - " x = skip + x.dropout(p=self.dropout)\n", - " if not self.pre_norm:\n", - " x = self.norm2(x)\n", - " return x" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "4e2839ad-6fa9-47f9-80e5-2681f7eb0d1e", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - ", None)> on METAL with grad None>" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "actEncode = 
ACTEncoderLayer(ACTConfig())\n", - "actEncode(Tensor.zeros(3, 512, 512))" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "a7ec9da4-bea0-46bc-9bb0-7ed27a1fd6d7", - "metadata": {}, - "outputs": [], - "source": [ - "class ACTEncoder:\n", - " \"\"\"Convenience module for running multiple encoder layers, maybe followed by normalization.\"\"\"\n", - "\n", - " def __init__(self, config: ACTConfig, is_vae_encoder: bool = False):\n", - " super().__init__()\n", - " self.is_vae_encoder = is_vae_encoder\n", - " num_layers = config.n_vae_encoder_layers if self.is_vae_encoder else config.n_encoder_layers\n", - " self.layers = [ACTEncoderLayer(config) for _ in range(num_layers)]\n", - " self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else lambda x: x\n", - "\n", - " def __call__(\n", - " self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None\n", - " ) -> Tensor:\n", - " for layer in self.layers:\n", - " print(f'ACTEncoder x.shape per layer: {x.shape}')\n", - " x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)\n", - " x = self.norm(x)\n", - " return x\n" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "c3747d03-b4ac-4f61-a4f3-5b9e94eb417f", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - ", None)> on METAL with grad None>" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "actEncoder = ACTEncoder(ACTConfig())\n", - "actEncode(Tensor.zeros(3, 512, 512))" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "1ad78fa2-7179-477c-88c0-ae2dea72ef09", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "ram used: 0.04 GB, layer4.1.bn2.running_var : 100%|█| \n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "loaded weights in 30.70 ms, 0.04 GB loaded at 1.46 GB/s\n" - ] - } - ], - "source": [ - "import tinygrad.nn as nn\n", - "from tinygrad import Tensor, dtypes\n", - "from tinygrad.helpers import fetch, get_child\n", - "\n", - "# allow monkeypatching in layer implementations\n", - "BatchNorm = nn.BatchNorm2d\n", - "Conv2d = nn.Conv2d\n", - "Linear = nn.Linear\n", - "\n", - "class FrozenBatchNorm2d:\n", - " def __init__(self, num_features, eps=1e-5):\n", - " super().__init__()\n", - " self.num_features = num_features\n", - " self.eps = eps\n", - " # Register buffers instead of parameters\n", - " self.weight = Tensor.ones(num_features, requires_grad=False)\n", - " self.bias = Tensor.zeros(num_features, requires_grad=False)\n", - " self.running_mean = Tensor.zeros(num_features, requires_grad=False)\n", - " self.running_var = Tensor.ones(num_features, requires_grad=False)\n", - " def __call__(self, x:Tensor) -> Tensor:\n", - " # Reshape for 2D input\n", - " scale = (self.weight / (self.running_var + self.eps).sqrt()).reshape(1, -1, 1, 1)\n", - " bias = (self.bias - self.running_mean * scale.flatten()).reshape(1, -1, 1, 1)\n", - " return x * scale + bias\n", - "\n", - "class Block:\n", - " def __init__(self, in_dims, dims, stride=1):\n", - " super().__init__()\n", - " self.conv1 = nn.Conv2d(\n", - " in_dims, dims, kernel_size=3, stride=stride, padding=1, bias=False\n", - " )\n", - " self.bn1 = FrozenBatchNorm2d(dims)\n", - " self.conv2 = nn.Conv2d(\n", - " dims, dims, kernel_size=3, stride=1, padding=1, bias=False\n", - " )\n", - " self.bn2 = FrozenBatchNorm2d(dims)\n", - " self.downsample = []\n", - " if stride != 1:\n", - " self.downsample 
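FrozenBatchNorm2d above folds the frozen running statistics and the affine parameters into a single per-channel scale and bias, so inference reduces to x * scale + bias. A small numpy check that this folding matches the usual batch-norm expression (shapes and values are illustrative):

import numpy as np

rng = np.random.default_rng(0)
c, eps = 3, 1e-5
x = rng.normal(size=(1, c, 4, 4))
weight, bias = rng.normal(size=c), rng.normal(size=c)
running_mean, running_var = rng.normal(size=c), rng.uniform(0.5, 2.0, size=c)

scale = (weight / np.sqrt(running_var + eps)).reshape(1, c, 1, 1)
shift = (bias - running_mean * scale.reshape(c)).reshape(1, c, 1, 1)

reference = (x - running_mean.reshape(1, c, 1, 1)) / np.sqrt(running_var.reshape(1, c, 1, 1) + eps)
reference = reference * weight.reshape(1, c, 1, 1) + bias.reshape(1, c, 1, 1)

print(np.allclose(x * scale + shift, reference))   # True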
= [\n", - " nn.Conv2d(in_dims, dims, kernel_size=1, stride=stride, bias=False),\n", - " FrozenBatchNorm2d(dims)\n", - " ]\n", - " def __call__(self, x):\n", - " base_operations = [\n", - " self.conv1,\n", - " self.bn1,\n", - " Tensor.relu,\n", - " self.conv2,\n", - " self.bn2\n", - " ]\n", - " out = x.sequential(base_operations)\n", - " \n", - " if self.downsample != []:\n", - " return (x.sequential(base_operations) + x.sequential(self.downsample)).relu()\n", - " else:\n", - " return x.sequential(base_operations).relu()\n", - "\n", - "class ResNet:\n", - " def __init__(self, block, num_blocks, num_classes=10):\n", - " super().__init__()\n", - " self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)\n", - " self.bn1 = FrozenBatchNorm2d(64)\n", - " self.layer1 = self._make_layer(block, 64, 64, num_blocks[0], stride=1)\n", - " self.layer2 = self._make_layer(block, 64, 128, num_blocks[1], stride=2)\n", - " self.layer3 = self._make_layer(block, 128, 256, num_blocks[2], stride=2)\n", - " self.layer4 = self._make_layer(block, 256, 512, num_blocks[3], stride=2)\n", - " #self.fc = nn.Linear(512, num_classes, requires_grad=False) # if we decide to use this someday, remove the grad\n", - " def _make_layer(self, block, in_dims, dims, num_blocks, stride):\n", - " strides = [stride] + [1] * (num_blocks - 1)\n", - " layers = []\n", - " for stride in strides:\n", - " layers.append(block(in_dims, dims, stride))\n", - " in_dims = dims\n", - " return layers\n", - " def __call__(self, x:Tensor):\n", - " x = self.bn1(self.conv1(x)).relu().max_pool2d()\n", - " x = x.sequential(self.layer1)\n", - " x = x.sequential(self.layer2 + self.layer3 + self.layer4)\n", - " \"\"\"\n", - " Commented out for now, because we're just using the output from layer4\n", - " \"\"\"\n", - " #x = x.mean([2, 3])\n", - " #x = self.fc(x)\n", - " return x\n", - "\n", - "resnet18_IMAGENET1K_V1 = ResNet(Block, [2, 2, 2, 2], num_classes=1000)\n", - "state_dict = nn.state.safe_load(\"resnet18-f37072fd.safetensors\")\n", - "nn.state.load_state_dict(resnet18_IMAGENET1K_V1, state_dict)" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "id": "e226e543-2c2f-4326-9d50-5c825834c044", - "metadata": {}, - "outputs": [], - "source": [ - "from itertools import chain\n", - "\n", - "class ACT:\n", - " \"\"\"Action Chunking Transformer: The underlying neural network for ACTPolicy.\n", - "\n", - " Note: In this code we use the terms `vae_encoder`, 'encoder', `decoder`. The meanings are as follows.\n", - " - The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the\n", - " model that encodes the target data (a sequence of actions), and the condition (the robot\n", - " joint-space).\n", - " - A transformer with an `encoder` (not the VAE encoder) and `decoder` (not the VAE decoder) with\n", - " cross-attention is used as the VAE decoder. 
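The `FrozenBatchNorm2d` used by the ResNet backbone above never updates its statistics; it folds the affine parameters and the frozen running statistics into one per-channel scale and bias before applying them. A minimal numpy check (hypothetical values only) that this folded form matches the usual inference-time batch-norm formula:

```
import numpy as np

rng = np.random.default_rng(0)
C, eps = 4, 1e-5
x = rng.normal(size=(2, C, 3, 3)).astype(np.float32)      # (B, C, H, W)
gamma = rng.normal(size=C).astype(np.float32)             # weight
beta = rng.normal(size=C).astype(np.float32)              # bias
mean = rng.normal(size=C).astype(np.float32)              # running_mean
var = rng.uniform(0.5, 2.0, size=C).astype(np.float32)    # running_var

# Folded form, as in FrozenBatchNorm2d.__call__
scale = (gamma / np.sqrt(var + eps)).reshape(1, C, 1, 1)
bias = (beta - mean * scale.flatten()).reshape(1, C, 1, 1)
y_folded = x * scale + bias

# Reference: (x - mean) / sqrt(var + eps) * gamma + beta, broadcast per channel
r = lambda v: v.reshape(1, C, 1, 1)
y_ref = (x - r(mean)) / np.sqrt(r(var) + eps) * r(gamma) + r(beta)

assert np.allclose(y_folded, y_ref, atol=1e-5)
```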
For these terms, we drop the `vae_` prefix because we\n", - " have an option to train this model without the variational objective (in which case we drop the\n", - " `vae_encoder` altogether, and nothing about this model has anything to do with a VAE).\n", - "\n", - " Transformer\n", - " Used alone for inference\n", - " (acts as VAE decoder\n", - " during training)\n", - " ┌───────────────────────┐\n", - " │ Outputs │\n", - " │ ▲ │\n", - " │ ┌─────►┌───────┐ │\n", - " ┌──────┐ │ │ │Transf.│ │\n", - " │ │ │ ├─────►│decoder│ │\n", - " ┌────┴────┐ │ │ │ │ │ │\n", - " │ │ │ │ ┌───┴───┬─►│ │ │\n", - " │ VAE │ │ │ │ │ └───────┘ │\n", - " │ encoder │ │ │ │Transf.│ │\n", - " │ │ │ │ │encoder│ │\n", - " └───▲─────┘ │ │ │ │ │\n", - " │ │ │ └▲──▲─▲─┘ │\n", - " │ │ │ │ │ │ │\n", - " inputs └─────┼──┘ │ image emb. │\n", - " │ state emb. │\n", - " └───────────────────────┘\n", - " \"\"\"\n", - "\n", - " def __init__(self, config: ACTConfig):\n", - " super().__init__()\n", - " self.config = config\n", - " # BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].\n", - " # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).\n", - " self.use_robot_state = \"observation.state\" in config.input_shapes\n", - " self.use_images = any(k.startswith(\"observation.image\") for k in config.input_shapes)\n", - " self.use_env_state = \"observation.environment_state\" in config.input_shapes\n", - " if self.config.use_vae:\n", - " self.vae_encoder = ACTEncoder(config, is_vae_encoder=True)\n", - " self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)\n", - " # Projection layer for joint-space configuration to hidden dimension.\n", - " if self.use_robot_state:\n", - " self.vae_encoder_robot_state_input_proj = nn.Linear(\n", - " config.input_shapes[\"observation.state\"][0], config.dim_model\n", - " )\n", - " # Projection layer for action (joint-space target) to hidden dimension.\n", - " self.vae_encoder_action_input_proj = nn.Linear(\n", - " config.output_shapes[\"action\"][0], config.dim_model\n", - " )\n", - " # Projection layer from the VAE encoder's output to the latent distribution's parameter space.\n", - " self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)\n", - " # Fixed sinusoidal positional embedding for the input to the VAE encoder. 
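As the comment above notes, the VAE encoder's cls-token output is projected to `2 * latent_dim` values interpreted as `[*means, *log_variances]`. A small numpy sketch (hypothetical sizes) of how those parameters are split, how a latent is drawn with the reparameterization trick, and the Gaussian-vs-standard-normal KL term that later regularizes them in the policy loss:

```
import numpy as np

rng = np.random.default_rng(0)
B, latent_dim = 2, 32
latent_pdf_params = rng.normal(size=(B, 2 * latent_dim)).astype(np.float32)

mu = latent_pdf_params[:, :latent_dim]            # means
log_sigma_x2 = latent_pdf_params[:, latent_dim:]  # 2*log(sigma), i.e. the log-variance

# Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
# which keeps the sample differentiable with respect to mu and sigma.
eps = rng.normal(size=mu.shape).astype(np.float32)
z = mu + np.exp(log_sigma_x2 / 2) * eps

# KL(N(mu, sigma^2) || N(0, I)): summed over the latent dimension, averaged over the batch.
kld = (-0.5 * (1 + log_sigma_x2 - mu**2 - np.exp(log_sigma_x2))).sum(axis=-1).mean()
print(z.shape, float(kld))
```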
Unsqueeze for batch\n", - " # dimension.\n", - " num_input_token_encoder = 1 + config.chunk_size\n", - " if self.use_robot_state:\n", - " num_input_token_encoder += 1\n", - " self.vae_encoder_pos_enc = create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0)\n", - " self.vae_encoder_pos_enc.requires_grad = False\n", - "\n", - " # Backbone for image feature extraction.\n", - " if self.use_images:\n", - " resnet18_IMAGENET1K_V1 = ResNet(Block, [2, 2, 2, 2], num_classes=1000)\n", - " state_dict = nn.state.safe_load(\"resnet18-f37072fd.safetensors\")\n", - " nn.state.load_state_dict(resnet18_IMAGENET1K_V1, state_dict)\n", - " backbone_model = resnet18_IMAGENET1K_V1\n", - " # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final\n", - " # feature map).\n", - " # Note: The forward method of this returns a dict: {\"feature_map\": output}.\n", - " self.backbone = backbone_model #IntermediateLayerGetter(backbone_model, return_layers={\"layer4\": \"feature_map\"})\n", - "\n", - " # Transformer (acts as VAE decoder when training with the variational objective).\n", - " self.encoder = ACTEncoder(config)\n", - " self.decoder = ACTDecoder(config)\n", - "\n", - " # Transformer encoder input projections. The tokens will be structured like\n", - " # [latent, (robot_state), (env_state), (image_feature_map_pixels)].\n", - " if self.use_robot_state:\n", - " self.encoder_robot_state_input_proj = nn.Linear(\n", - " config.input_shapes[\"observation.state\"][0], config.dim_model\n", - " )\n", - " if self.use_env_state:\n", - " self.encoder_env_state_input_proj = nn.Linear(\n", - " config.input_shapes[\"observation.environment_state\"][0], config.dim_model\n", - " )\n", - " self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)\n", - " if self.use_images:\n", - " self.encoder_img_feat_input_proj = nn.Conv2d(\n", - " 512, config.dim_model, kernel_size=1\n", - " )\n", - " # Transformer encoder positional embeddings.\n", - " n_1d_tokens = 1 # for the latent\n", - " if self.use_robot_state:\n", - " n_1d_tokens += 1\n", - " if self.use_env_state:\n", - " n_1d_tokens += 1\n", - " self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)\n", - " if self.use_images:\n", - " self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)\n", - "\n", - " # Transformer decoder.\n", - " # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).\n", - " self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)\n", - "\n", - " # Final action regression head on the output of the transformer's decoder.\n", - " self.action_head = nn.Linear(config.dim_model, config.output_shapes[\"action\"][0])\n", - "\n", - " self._reset_parameters()\n", - "\n", - " # CHANGE THIS WHEN RUNNING.\n", - " self.training=True\n", - "\n", - " def _reset_parameters(self):\n", - " \"\"\"Xavier-uniform initialization of the transformer parameters as in the original code.\"\"\"\n", - " for p in chain(nn.state.get_parameters(self.encoder), nn.state.get_parameters(self.decoder)):\n", - " if p.ndim > 1:\n", - " def xavier_uniform_(tensor: Tensor) -> Tensor:\n", - " fan_in, fan_out = tensor.shape[:2]\n", - " \n", - " # Calculate the range for the uniform distribution\n", - " # This is the glorot/xavier uniform initialization formula\n", - " a = math.sqrt(6.0 / (fan_in + fan_out))\n", - " \n", - " # Use uniform distribution to initialize the tensor\n", - " return 
Tensor.uniform(*tensor.shape, low=-a, high=a)\n", - " p = xavier_uniform_(p)\n", - "\n", - " def __call__(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:\n", - " \"\"\"A forward pass through the Action Chunking Transformer (with optional VAE encoder).\n", - "\n", - " `batch` should have the following structure:\n", - " {\n", - " \"observation.state\" (optional): (B, state_dim) batch of robot states.\n", - "\n", - " \"observation.images\": (B, n_cameras, C, H, W) batch of images.\n", - " AND/OR\n", - " \"observation.environment_state\": (B, env_dim) batch of environment states.\n", - "\n", - " \"action\" (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.\n", - " }\n", - "\n", - " Returns:\n", - " (B, chunk_size, action_dim) batch of action sequences\n", - " Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the\n", - " latent dimension.\n", - " \"\"\"\n", - " if self.config.use_vae and self.training:\n", - " assert (\n", - " \"action\" in batch\n", - " ), \"actions must be provided when using the variational objective in training mode.\"\n", - "\n", - " batch_size = (\n", - " batch[\"observation.images\"]\n", - " if \"observation.images\" in batch\n", - " else batch[\"observation.environment_state\"]\n", - " ).shape[0]\n", - "\n", - " print(f'batch_size: {batch_size}')\n", - "\n", - " # Prepare the latent for input to the transformer encoder.\n", - " if self.config.use_vae and \"action\" in batch:\n", - " # Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence].\n", - " cls_embed = self.vae_encoder_cls_embed.weight.repeat(batch_size, 1, 1) # (B, 1, D)\n", - " if self.use_robot_state:\n", - " robot_state_embed = self.vae_encoder_robot_state_input_proj(batch[\"observation.state\"])\n", - " robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)\n", - " action_embed = self.vae_encoder_action_input_proj(batch[\"action\"]) # (B, S, D)\n", - "\n", - " if self.use_robot_state:\n", - " vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)\n", - " else:\n", - " vae_encoder_input = [cls_embed, action_embed]\n", - " vae_encoder_input = Tensor.cat(*vae_encoder_input, dim=1)\n", - "\n", - " # Prepare fixed positional embedding.\n", - " # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.\n", - " pos_embed = self.vae_encoder_pos_enc.contiguous().detach() # (1, S+2, D)\n", - "\n", - " # Prepare key padding mask for the transformer encoder. 
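One detail worth noting about `_reset_parameters` above: `p = xavier_uniform_(p)` only rebinds the local loop variable to a freshly drawn tensor, so the parameters actually held by the encoder and decoder appear to stay untouched. A sketch of an in-place variant, using the same fan computation and assuming tinygrad's `Tensor.assign`:

```
import math
from tinygrad import Tensor

def xavier_uniform_inplace(p: Tensor) -> None:
    # Same Glorot/Xavier bound as above, computed from the first two shape dims.
    fan_in, fan_out = p.shape[:2]
    a = math.sqrt(6.0 / (fan_in + fan_out))
    # Write the new values back into the existing parameter tensor.
    p.assign(Tensor.uniform(*p.shape, low=-a, high=a))

# usage sketch: for p in nn.state.get_parameters(model):
#                   if p.ndim > 1: xavier_uniform_inplace(p)
```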
We have 1 or 2 extra tokens at the start of the\n", - " # sequence depending whether we use the input states or not (cls and robot state)\n", - " # False means not a padding token.\n", - " cls_joint_is_pad = Tensor.full(\n", - " shape=(batch_size, 2 if self.use_robot_state else 1),\n", - " fill_value=False\n", - " )\n", - " key_padding_mask = Tensor.cat(\n", - " cls_joint_is_pad, batch[\"action_is_pad\"], dim=1\n", - " ) # (bs, seq+1 or 2)\n", - "\n", - " print(f'vae_encoder_input.shape: {vae_encoder_input.shape}')\n", - " print(f'pos_embed.shape: {pos_embed.shape}')\n", - " print(f'key_padding_mask.shape: {key_padding_mask.shape}')\n", - "\n", - " # Forward pass through VAE encoder to get the latent PDF parameters.\n", - " cls_token_out = self.vae_encoder(\n", - " vae_encoder_input.permute(1, 0, 2),\n", - " pos_embed=pos_embed.permute(1, 0, 2),\n", - " key_padding_mask=key_padding_mask.permute(1,0),\n", - " )\n", - " print(f'cls_token_out.shape: {cls_token_out.shape}')\n", - " cls_token_out = cls_token_out[0] # select the class token, with shape (B, D)\n", - " print(f'cls_token_out[0].shape: {cls_token_out.shape}')\n", - " latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)\n", - " mu = latent_pdf_params[:, : self.config.latent_dim]\n", - " # This is 2log(sigma). Done this way to match the original implementation.\n", - " log_sigma_x2 = latent_pdf_params[:, self.config.latent_dim :]\n", - "\n", - " # Sample the latent with the reparameterization trick.\n", - " latent_sample = mu + log_sigma_x2.div(2).exp() * Tensor.randn(*(mu.shape))\n", - " else:\n", - " # When not using the VAE encoder, we set the latent to be all zeros.\n", - " mu = log_sigma_x2 = None\n", - " # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer\n", - " latent_sample = Tensor.zeros(batch_size, self.config.latent_dim, dtype=dtypes.float32)\n", - "\n", - " # Prepare transformer encoder inputs.\n", - " encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]\n", - " encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))\n", - " # Robot state token.\n", - " if self.use_robot_state:\n", - " encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch[\"observation.state\"]))\n", - " # Environment state token.\n", - " if self.use_env_state:\n", - " encoder_in_tokens.append(\n", - " self.encoder_env_state_input_proj(batch[\"observation.environment_state\"])\n", - " )\n", - "\n", - " # Camera observation features and positional embeddings.\n", - " if self.use_images:\n", - " all_cam_features = []\n", - " all_cam_pos_embeds = []\n", - "\n", - " for cam_index in range(batch[\"observation.images\"].shape[-4]):\n", - " cam_features = self.backbone(batch[\"observation.images\"][:, cam_index]) #[\"feature_map\"]\n", - " print(f'backbone output: {cam_features.shape}')\n", - " # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use\n", - " # buffer\n", - " cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).cast(dtype=cam_features.dtype)\n", - " cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)\n", - " print(f'cam_features: {cam_features.shape}')\n", - " all_cam_features.append(cam_features)\n", - " print(f'len all_cam_features: {len(all_cam_features)}')\n", - " all_cam_pos_embeds.append(cam_pos_embed)\n", - " # Concatenate camera observation feature maps and positional embeddings along the width dimension,\n", - " # and move to (sequence, batch, 
dim).\n", - " all_cam_features = Tensor.cat(*all_cam_features, dim=-1)\n", - " print(f'len all_cam_features after cat: {len(all_cam_features)}')\n", - " print(f'Before encoder_in_tokens.extend, encoder_in token len: {len(encoder_in_tokens)}')\n", - " encoder_in_tokens.extend(all_cam_features.permute(2, 3, 0, 1).reshape(-1, all_cam_features.shape[0], all_cam_features.shape[1]))\n", - " print(f'encoder_in_tokens: {len(encoder_in_tokens)}')\n", - " all_cam_pos_embeds = Tensor.cat(*all_cam_pos_embeds, dim=-1)\n", - " print(f'all_cam_pos_embeds: {all_cam_pos_embeds}')\n", - " encoder_in_pos_embed.extend(all_cam_pos_embeds.permute(2, 3, 0, 1).reshape(-1, all_cam_pos_embeds.shape[0], all_cam_pos_embeds.shape[1]))\n", - "\n", - " print(f'Before tensor.stack, encoder_in token len: {len(encoder_in_tokens)}')\n", - " print(f'Before tensor.stack, encoder_in_pos_embed token len: {len(encoder_in_pos_embed)}')\n", - " # Stack all tokens along the sequence dimension.\n", - " encoder_in_tokens = Tensor.stack(*encoder_in_tokens, dim=0)\n", - " encoder_in_pos_embed = Tensor.stack(*encoder_in_pos_embed, dim=0)\n", - "\n", - " print(f'encoder_in_tokens: {len(encoder_in_tokens)}')\n", - " print(f'encoder_in_pos_embed.shape: {encoder_in_pos_embed.shape}')\n", - "\n", - " # Forward pass through the transformer modules.\n", - " encoder_out = self.encoder(encoder_in_tokens, pos_embed=encoder_in_pos_embed)\n", - " # TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer\n", - " decoder_in = Tensor.zeros(\n", - " *(self.config.chunk_size, batch_size, self.config.dim_model),\n", - " dtype=encoder_in_pos_embed.dtype\n", - " )\n", - " print(f'encoder_out.shape: {encoder_out.shape}')\n", - " print(f'decoder_in.shape: {decoder_in.shape}')\n", - " print(f'encoder_in_pos_embed.shape: {encoder_in_pos_embed.shape}')\n", - " print(f'decoder_pos_embed.shape: {self.decoder_pos_embed.weight.shape}')\n", - " print(f'decoder_pos_embed.shape unsqueezed: {self.decoder_pos_embed.weight.unsqueeze(1).shape}')\n", - " decoder_out = self.decoder(\n", - " decoder_in.permute(1,0,2),\n", - " encoder_out.permute(1,0,2),\n", - " encoder_pos_embed=encoder_in_pos_embed.permute(1,0,2),\n", - " decoder_pos_embed=self.decoder_pos_embed.weight.unsqueeze(1).permute(1,0,2),\n", - " )\n", - "\n", - " # Move back to (B, S, C).\n", - " # decoder_out = decoder_out.transpose(0, 1)\n", - " print(f'decoder_out: {decoder_out.shape}')\n", - "\n", - " actions = self.action_head(decoder_out)\n", - "\n", - " return actions, (mu, log_sigma_x2)" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "id": "169cdbfe-7b4c-406a-ad51-276e0b5dabe8", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "ram used: 1.52 GB, layer4.1.bn2.running_var : 100%|█| \n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "loaded weights in 14.43 ms, 0.04 GB loaded at 3.10 GB/s\n" - ] - } - ], - "source": [ - "act = ACT(ACTConfig())" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "id": "ac22d861-8cd6-4287-972c-1a270205101b", - "metadata": {}, - "outputs": [], - "source": [ - "class ACTTemporalEnsembler:\n", - " def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None:\n", - " \"\"\"Temporal ensembling as described in Algorithm 2 of https://arxiv.org/abs/2304.13705.\n", - "\n", - " The weights are calculated as wᵢ = exp(-temporal_ensemble_coeff * i) where w₀ is the oldest action.\n", - " They are then normalized to sum to 1 by dividing by Σwᵢ. 
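To make the weighting concrete, a tiny numpy illustration with the default coefficient of 0.01 and four overlapping predictions for the same time step (w₀ belongs to the oldest prediction):

```
import numpy as np

coeff, n = 0.01, 4
w = np.exp(-coeff * np.arange(n))   # w0 (oldest) ... w3 (newest)
print(w / w.sum())                  # ~[0.2538 0.2512 0.2487 0.2463]: the oldest prediction is weighted slightly more
```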
Here's some intuition around how the\n", - " coefficient works:\n", - " - Setting it to 0 uniformly weighs all actions.\n", - " - Setting it positive gives more weight to older actions.\n", - " - Setting it negative gives more weight to newer actions.\n", - " NOTE: The default value for `temporal_ensemble_coeff` used by the original ACT work is 0.01. This\n", - " results in older actions being weighed more highly than newer actions (the experiments documented in\n", - " https://github.com/huggingface/lerobot/pull/319 hint at why highly weighing new actions might be\n", - " detrimental: doing so aggressively may diminish the benefits of action chunking).\n", - "\n", - " Here we use an online method for computing the average rather than caching a history of actions in\n", - " order to compute the average offline. For a simple 1D sequence it looks something like:\n", - "\n", - " ```\n", - " import torch\n", - "\n", - " seq = torch.linspace(8, 8.5, 100)\n", - " print(seq)\n", - "\n", - " m = 0.01\n", - " exp_weights = torch.exp(-m * torch.arange(len(seq)))\n", - " print(exp_weights)\n", - "\n", - " # Calculate offline\n", - " avg = (exp_weights * seq).sum() / exp_weights.sum()\n", - " print(\"offline\", avg)\n", - "\n", - " # Calculate online\n", - " for i, item in enumerate(seq):\n", - " if i == 0:\n", - " avg = item\n", - " continue\n", - " avg *= exp_weights[:i].sum()\n", - " avg += item * exp_weights[i]\n", - " avg /= exp_weights[:i+1].sum()\n", - " print(\"online\", avg)\n", - " ```\n", - " \"\"\"\n", - " self.chunk_size = chunk_size\n", - " self.ensemble_weights = (-temporal_ensemble_coeff * Tensor.arange(chunk_size)).exp()\n", - " self.ensemble_weights_cumsum = self.ensemble_weights.cumsum(axis=0)\n", - " self.reset()\n", - "\n", - " def reset(self):\n", - " \"\"\"Resets the online computation variables.\"\"\"\n", - " self.ensembled_actions = None\n", - " # (chunk_size,) count of how many actions are in the ensemble for each time step in the sequence.\n", - " self.ensembled_actions_count = None\n", - "\n", - " def update(self, actions: Tensor) -> Tensor:\n", - " \"\"\"\n", - " Takes a (batch, chunk_size, action_dim) sequence of actions, update the temporal ensemble for all\n", - " time steps, and pop/return the next batch of actions in the sequence.\n", - " \"\"\"\n", - " if self.ensembled_actions is None:\n", - " # Initializes `self._ensembled_action` to the sequence of actions predicted during the first\n", - " # time step of the episode.\n", - " self.ensembled_actions = actions.contiguous()\n", - " # Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor\n", - " # operations later.\n", - " self.ensembled_actions_count = Tensor.ones(\n", - " *(self.chunk_size, 1), dtype=dtypes.long\n", - " )\n", - " else:\n", - " # self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). 
Compute\n", - " # the online update for those entries.\n", - " self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1]\n", - " self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]\n", - " self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count]\n", - " self.ensembled_actions_count = (self.ensembled_actions_count + 1).clamp(max_=self.chunk_size)\n", - " # The last action, which has no prior online average, needs to get concatenated onto the end.\n", - " self.ensembled_actions = Tensor.cat(*[self.ensembled_actions, actions[:, -1:]], dim=1)\n", - " self.ensembled_actions_count = Tensor.cat(\n", - " *[self.ensembled_actions_count, Tensor.ones_like(self.ensembled_actions_count[-1:])]\n", - " )\n", - " # \"Consume\" the first action.\n", - " action, self.ensembled_actions, self.ensembled_actions_count = (\n", - " self.ensembled_actions[:, 0],\n", - " self.ensembled_actions[:, 1:],\n", - " self.ensembled_actions_count[1:],\n", - " )\n", - " return action" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "id": "0da33efc-c00f-4675-949f-ddf9d32a8c91", - "metadata": {}, - "outputs": [], - "source": [ - "from normalize import *\n", - "\n", - "class ACTPolicy:\n", - " \"\"\"\n", - " Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost\n", - " Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)\n", - " \"\"\"\n", - "\n", - " name = \"act\"\n", - "\n", - " def __init__(\n", - " self,\n", - " config: ACTConfig | None = None,\n", - " dataset_stats: dict[str, dict[str, Tensor]] | None = None,\n", - " ):\n", - " \"\"\"\n", - " Args:\n", - " config: Policy configuration class instance or None, in which case the default instantiation of\n", - " the configuration class is used.\n", - " dataset_stats: Dataset statistics to be used for normalization. 
If not passed here, it is expected\n", - " that they will be passed with a call to `load_state_dict` before the policy is used.\n", - " \"\"\"\n", - " super().__init__()\n", - " if config is None:\n", - " config = ACTConfig()\n", - " self.config: ACTConfig = config\n", - "\n", - " self.normalize_inputs = Normalize(\n", - " config.input_shapes, config.input_normalization_modes, dataset_stats\n", - " )\n", - " self.normalize_targets = Normalize(\n", - " config.output_shapes, config.output_normalization_modes, dataset_stats\n", - " )\n", - " self.unnormalize_outputs = Unnormalize(\n", - " config.output_shapes, config.output_normalization_modes, dataset_stats\n", - " )\n", - "\n", - " self.model = ACT(config)\n", - "\n", - " self.expected_image_keys = [k for k in config.input_shapes if k.startswith(\"observation.image\")]\n", - "\n", - " if config.temporal_ensemble_coeff is not None:\n", - " self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)\n", - "\n", - " self.reset()\n", - "\n", - " def reset(self):\n", - " \"\"\"This should be called whenever the environment is reset.\"\"\"\n", - " if self.config.temporal_ensemble_coeff is not None:\n", - " self.temporal_ensembler.reset()\n", - " else:\n", - " self._action_queue = deque([], maxlen=self.config.n_action_steps)\n", - "\n", - " def select_action(self, batch: dict[str, Tensor]) -> Tensor:\n", - " \"\"\"Select a single action given environment observations.\n", - "\n", - " This method wraps `select_actions` in order to return one action at a time for execution in the\n", - " environment. It works by managing the actions in a queue and only calling `select_actions` when the\n", - " queue is empty.\n", - " \"\"\"\n", - " Tensor.no_grad = True\n", - " self.eval()\n", - "\n", - " batch = self.normalize_inputs(batch)\n", - " if len(self.expected_image_keys) > 0:\n", - " batch = dict(batch) # shallow copy so that adding a key doesn't modify the original\n", - " batch[\"observation.images\"] = Tensor.stack(*[batch[k] for k in self.expected_image_keys], dim=-4)\n", - "\n", - " # If we are doing temporal ensembling, do online updates where we keep track of the number of actions\n", - " # we are ensembling over.\n", - " if self.config.temporal_ensemble_coeff is not None:\n", - " actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim)\n", - " actions = self.unnormalize_outputs({\"action\": actions})[\"action\"]\n", - " action = self.temporal_ensembler.update(actions)\n", - " return action\n", - "\n", - " # Action queue logic for n_action_steps > 1. 
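A standalone sketch of this queueing scheme, with numpy standing in for the policy output and hypothetical sizes: the model is queried once for a (batch, n_action_steps, action_dim) chunk, the queue stores one entry per time step, and each environment step pops a single (batch, action_dim) action.

```
import numpy as np
from collections import deque

B, n_action_steps, action_dim = 2, 5, 4
action_queue = deque([], maxlen=n_action_steps)

# Pretend this is the policy output for one query.
chunk = np.arange(B * n_action_steps * action_dim, dtype=np.float32).reshape(B, n_action_steps, action_dim)

# The queue is indexed by time step, hence the transpose to (n_action_steps, B, action_dim).
action_queue.extend(chunk.transpose(1, 0, 2))

while action_queue:
    action = action_queue.popleft()   # (B, action_dim), one env step per pop
print(action.shape)                   # (2, 4)
```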
When the action_queue is depleted, populate it by\n", - " # querying the policy.\n", - " if len(self._action_queue) == 0:\n", - " actions = self.model(batch)[0][:, : self.config.n_action_steps]\n", - "\n", - " # TODO(rcadene): make _forward return output dictionary?\n", - " actions = self.unnormalize_outputs({\"action\": actions})[\"action\"]\n", - "\n", - " # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue\n", - " # effectively has shape (n_action_steps, batch_size, *), hence the transpose.\n", - " self._action_queue.extend(actions.transpose(0, 1))\n", - " item_to_return = self._action_queue.popleft()\n", - " Tensor.no_grad = False\n", - " return item_to_return\n", - "\n", - " def __call__(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:\n", - " \"\"\"Run the batch through the model and compute the loss for training or validation.\"\"\"\n", - " batch = self.normalize_inputs(batch)\n", - " if len(self.expected_image_keys) > 0:\n", - " batch = dict(batch) # shallow copy so that adding a key doesn't modify the original\n", - " batch[\"observation.images\"] = Tensor.stack(*[batch[k] for k in self.expected_image_keys], dim=-4)\n", - " batch = self.normalize_targets(batch)\n", - " actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)\n", - "\n", - " l1_loss = (\n", - " (batch[\"action\"] - actions_hat).abs() * batch[\"action_is_pad\"].logical_not().int().unsqueeze(-1)\n", - " ).mean()\n", - "\n", - " loss_dict = {\"l1_loss\": l1_loss.item()}\n", - " if self.config.use_vae:\n", - " # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for\n", - " # each dimension independently, we sum over the latent dimension to get the total\n", - " # KL-divergence per batch element, then take the mean over the batch.\n", - " # (See App. 
B of https://arxiv.org/abs/1312.6114 for more details).\n", - " mean_kld = (\n", - " (-0.5 * (1 + log_sigma_x2_hat - mu_hat.square() - (log_sigma_x2_hat).exp())).sum(axis=-1).mean()\n", - " )\n", - " loss_dict[\"kld_loss\"] = mean_kld.item()\n", - " loss_dict[\"loss\"] = l1_loss + mean_kld * self.config.kl_weight\n", - " else:\n", - " loss_dict[\"loss\"] = l1_loss\n", - "\n", - " return loss_dict" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "id": "88ad724a-64be-47e7-9326-085ad096ab4a", - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "7c0600c9f3764a6383b2c57f21ff1f5b", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Fetching 56 files: 0%| | 0/56 [00:00, None)> on METAL with grad None>\n", - ", None)> on METAL with grad None>\n", - "x_range[..., 0::2].sin(): , None)> on METAL with grad None>\n", - "x_range[..., 1::2].cos(): , None)> on METAL with grad None>\n", - "cam_features: (8, 512, 15, 20)\n", - "len all_cam_features: 1\n", - "len all_cam_features after cat: 8\n", - "Before encoder_in_tokens.extend, encoder_in token len: 2\n", - "encoder_in_tokens: 302\n", - "all_cam_pos_embeds: on METAL with grad None>\n", - "Before tensor.stack, encoder_in token len: 302\n", - "Before tensor.stack, encoder_in_pos_embed token len: 302\n", - "encoder_in_tokens: 302\n", - "encoder_in_pos_embed.shape: (302, 1, 512)\n", - "ACTEncoder x.shape per layer: (302, 8, 512)\n", - "ACTEncoder x.shape per layer: (302, 8, 512)\n", - "ACTEncoder x.shape per layer: (302, 8, 512)\n", - "ACTEncoder x.shape per layer: (302, 8, 512)\n", - "encoder_out.shape: (302, 8, 512)\n", - "decoder_in.shape: (100, 8, 512)\n", - "encoder_in_pos_embed.shape: (302, 1, 512)\n", - "decoder_pos_embed.shape: (100, 512)\n", - "decoder_pos_embed.shape unsqueezed: (100, 1, 512)\n", - "decoder_out: (8, 100, 512)\n", - "step: 0 loss: 106.413\n", - "batch_size: 8\n", - "vae_encoder_input.shape: (8, 102, 512)\n", - "pos_embed.shape: (1, 102, 512)\n", - "key_padding_mask.shape: (8, 102)\n", - "ACTEncoder x.shape per layer: (102, 8, 512)\n", - "(q,k,v): (102, 8, 8, 64), (102, 8, 8, 64), (102, 8, 8, 64)\n", - "(key_padding_mask): (102, 8)\n", - "ACTEncoder x.shape per layer: (102, 8, 512)\n", - "(q,k,v): (102, 8, 8, 64), (102, 8, 8, 64), (102, 8, 8, 64)\n", - "(key_padding_mask): (102, 8)\n", - "ACTEncoder x.shape per layer: (102, 8, 512)\n", - "(q,k,v): (102, 8, 8, 64), (102, 8, 8, 64), (102, 8, 8, 64)\n", - "(key_padding_mask): (102, 8)\n", - "ACTEncoder x.shape per layer: (102, 8, 512)\n", - "(q,k,v): (102, 8, 8, 64), (102, 8, 8, 64), (102, 8, 8, 64)\n", - "(key_padding_mask): (102, 8)\n", - "cls_token_out.shape: (102, 8, 512)\n", - "cls_token_out[0].shape: (8, 512)\n", - "backbone output: (8, 512, 15, 20)\n", - ", None)> on METAL with grad None>\n", - ", None)> on METAL with grad None>\n", - "x_range[..., 0::2].sin(): , None)> on METAL with grad None>\n", - "x_range[..., 1::2].cos(): , None)> on METAL with grad None>\n", - "cam_features: (8, 512, 15, 20)\n", - "len all_cam_features: 1\n", - "len all_cam_features after cat: 8\n", - "Before encoder_in_tokens.extend, encoder_in token len: 2\n", - "encoder_in_tokens: 302\n", - "all_cam_pos_embeds: on METAL with grad None>\n", - "Before tensor.stack, encoder_in token len: 302\n", - "Before tensor.stack, encoder_in_pos_embed token len: 302\n", - "encoder_in_tokens: 302\n", - "encoder_in_pos_embed.shape: (302, 1, 512)\n", - "ACTEncoder x.shape per layer: (302, 8, 512)\n", - "ACTEncoder x.shape per 
layer: (302, 8, 512)\n", - "ACTEncoder x.shape per layer: (302, 8, 512)\n", - "ACTEncoder x.shape per layer: (302, 8, 512)\n", - "encoder_out.shape: (302, 8, 512)\n", - "decoder_in.shape: (100, 8, 512)\n", - "encoder_in_pos_embed.shape: (302, 1, 512)\n", - "decoder_pos_embed.shape: (100, 512)\n", - "decoder_pos_embed.shape unsqueezed: (100, 1, 512)\n", - "decoder_out: (8, 100, 512)\n" - ] - }, - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[34], line 80\u001b[0m\n\u001b[1;32m 78\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m batch \u001b[38;5;129;01min\u001b[39;00m dataloader:\n\u001b[1;32m 79\u001b[0m batch \u001b[38;5;241m=\u001b[39m {k: Tensor(v\u001b[38;5;241m.\u001b[39mnumpy(), requires_grad\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m) \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m batch\u001b[38;5;241m.\u001b[39mitems()}\n\u001b[0;32m---> 80\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mtrain_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 82\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m step \u001b[38;5;241m%\u001b[39m log_freq \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 83\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mstep: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mstep\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m loss: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mloss\u001b[38;5;241m.\u001b[39mnumpy()\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.3f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", - "File \u001b[0;32m/opt/homebrew/Cellar/python@3.12/3.12.6/Frameworks/Python.framework/Versions/3.12/lib/python3.12/contextlib.py:81\u001b[0m, in \u001b[0;36mContextDecorator.__call__..inner\u001b[0;34m(*args, **kwds)\u001b[0m\n\u001b[1;32m 78\u001b[0m \u001b[38;5;129m@wraps\u001b[39m(func)\n\u001b[1;32m 79\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21minner\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwds):\n\u001b[1;32m 80\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_recreate_cm():\n\u001b[0;32m---> 81\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwds\u001b[49m\u001b[43m)\u001b[49m\n", - "Cell \u001b[0;32mIn[34], line 59\u001b[0m, in \u001b[0;36mtrain_step\u001b[0;34m(batch)\u001b[0m\n\u001b[1;32m 57\u001b[0m opt_backbone\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[1;32m 58\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward()\n\u001b[0;32m---> 59\u001b[0m \u001b[43mopt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 60\u001b[0m opt_backbone\u001b[38;5;241m.\u001b[39mstep()\n\u001b[1;32m 61\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss\n", - "File \u001b[0;32m/opt/homebrew/lib/python3.12/site-packages/tinygrad/nn/optim.py:34\u001b[0m, in \u001b[0;36mOptimizer.step\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 
30\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mstep\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 31\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 32\u001b[0m \u001b[38;5;124;03m Performs a single optimization step.\u001b[39;00m\n\u001b[1;32m 33\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 34\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mrealize(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mschedule_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m)\n", - "File \u001b[0;32m/opt/homebrew/lib/python3.12/site-packages/tinygrad/nn/optim.py:42\u001b[0m, in \u001b[0;36mOptimizer.schedule_step\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 37\u001b[0m \u001b[38;5;124;03mReturns the tensors that need to be realized to perform a single optimization step.\u001b[39;00m\n\u001b[1;32m 38\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 39\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m Tensor\u001b[38;5;241m.\u001b[39mtraining, (\n\u001b[1;32m 40\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\"\"\u001b[39m\u001b[38;5;124mTensor.training=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mTensor\u001b[38;5;241m.\u001b[39mtraining\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, Tensor.training must be enabled to use the optimizer.\u001b[39m\n\u001b[1;32m 41\u001b[0m \u001b[38;5;124m - help: Consider setting Tensor.training=True before calling Optimizer.step().\u001b[39m\u001b[38;5;124m\"\"\"\u001b[39m)\n\u001b[0;32m---> 42\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m+\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparams\u001b[38;5;241m+\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbuffers\n", - "File \u001b[0;32m/opt/homebrew/lib/python3.12/site-packages/tinygrad/nn/optim.py:149\u001b[0m, in \u001b[0;36mLAMB._step\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 148\u001b[0m r \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1.0\u001b[39m\n\u001b[0;32m--> 149\u001b[0m t\u001b[38;5;241m.\u001b[39massign((t\u001b[38;5;241m.\u001b[39mdetach() \u001b[38;5;241m-\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlr\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mr\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mup\u001b[49m)\u001b[38;5;241m.\u001b[39mcast(t\u001b[38;5;241m.\u001b[39mdtype))\n\u001b[1;32m 150\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m [\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mb1_t, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mb2_t] \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mm \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mv\n", - "File \u001b[0;32m/opt/homebrew/lib/python3.12/site-packages/tinygrad/tensor.py:3231\u001b[0m, in \u001b[0;36m_metadata_wrapper.._wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 3230\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_wrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs, 
\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m-> 3231\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _METADATA\u001b[38;5;241m.\u001b[39mget() \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m: \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3233\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m TRACEMETA \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m2\u001b[39m:\n\u001b[1;32m 3234\u001b[0m caller_frame \u001b[38;5;241m=\u001b[39m sys\u001b[38;5;241m.\u001b[39m_getframe(frame \u001b[38;5;241m:=\u001b[39m \u001b[38;5;241m1\u001b[39m)\n", - "File \u001b[0;32m/opt/homebrew/lib/python3.12/site-packages/tinygrad/tensor.py:2747\u001b[0m, in \u001b[0;36mTensor.__mul__\u001b[0;34m(self, x)\u001b[0m\n\u001b[0;32m-> 2747\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__mul__\u001b[39m(\u001b[38;5;28mself\u001b[39m, x) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor: \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmul\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m/opt/homebrew/lib/python3.12/site-packages/tinygrad/tensor.py:3231\u001b[0m, in \u001b[0;36m_metadata_wrapper.._wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 3230\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_wrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m-> 3231\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _METADATA\u001b[38;5;241m.\u001b[39mget() \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m: \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3233\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m TRACEMETA \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m2\u001b[39m:\n\u001b[1;32m 3234\u001b[0m caller_frame \u001b[38;5;241m=\u001b[39m sys\u001b[38;5;241m.\u001b[39m_getframe(frame \u001b[38;5;241m:=\u001b[39m \u001b[38;5;241m1\u001b[39m)\n", - "File \u001b[0;32m/opt/homebrew/lib/python3.12/site-packages/tinygrad/tensor.py:2554\u001b[0m, in \u001b[0;36mTensor.mul\u001b[0;34m(self, x, reverse)\u001b[0m\n\u001b[1;32m 2536\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmul\u001b[39m(\u001b[38;5;28mself\u001b[39m, x:Union[Tensor, ConstType], reverse\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[1;32m 2537\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 2538\u001b[0m \u001b[38;5;124;03m Multiplies `self` and `x`.\u001b[39;00m\n\u001b[1;32m 2539\u001b[0m \u001b[38;5;124;03m Equivalent to `self * x`.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 2552\u001b[0m \u001b[38;5;124;03m ```\u001b[39;00m\n\u001b[1;32m 2553\u001b[0m \u001b[38;5;124;03m 
\"\"\"\u001b[39;00m\n\u001b[0;32m-> 2554\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m F\u001b[38;5;241m.\u001b[39mMul\u001b[38;5;241m.\u001b[39mapply(\u001b[38;5;241m*\u001b[39m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_broadcasted\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreverse\u001b[49m\u001b[43m)\u001b[49m)\n", - "File \u001b[0;32m/opt/homebrew/lib/python3.12/site-packages/tinygrad/tensor.py:3231\u001b[0m, in \u001b[0;36m_metadata_wrapper.._wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 3230\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_wrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m-> 3231\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _METADATA\u001b[38;5;241m.\u001b[39mget() \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m: \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3233\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m TRACEMETA \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m2\u001b[39m:\n\u001b[1;32m 3234\u001b[0m caller_frame \u001b[38;5;241m=\u001b[39m sys\u001b[38;5;241m.\u001b[39m_getframe(frame \u001b[38;5;241m:=\u001b[39m \u001b[38;5;241m1\u001b[39m)\n", - "File \u001b[0;32m/opt/homebrew/lib/python3.12/site-packages/tinygrad/tensor.py:2489\u001b[0m, in \u001b[0;36mTensor._broadcasted\u001b[0;34m(self, y, reverse, match_dtype)\u001b[0m\n\u001b[1;32m 2487\u001b[0m \u001b[38;5;66;03m# broadcast\u001b[39;00m\n\u001b[1;32m 2488\u001b[0m out_shape \u001b[38;5;241m=\u001b[39m _broadcast_shape(x\u001b[38;5;241m.\u001b[39mshape, y\u001b[38;5;241m.\u001b[39mshape)\n\u001b[0;32m-> 2489\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_broadcast_to\u001b[49m\u001b[43m(\u001b[49m\u001b[43mout_shape\u001b[49m\u001b[43m)\u001b[49m, y\u001b[38;5;241m.\u001b[39m_broadcast_to(out_shape)\n", - "File \u001b[0;32m/opt/homebrew/lib/python3.12/site-packages/tinygrad/tensor.py:3231\u001b[0m, in \u001b[0;36m_metadata_wrapper.._wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 3230\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_wrapper\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m-> 3231\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _METADATA\u001b[38;5;241m.\u001b[39mget() \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m: \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3233\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m TRACEMETA \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m2\u001b[39m:\n\u001b[1;32m 3234\u001b[0m caller_frame \u001b[38;5;241m=\u001b[39m sys\u001b[38;5;241m.\u001b[39m_getframe(frame \u001b[38;5;241m:=\u001b[39m \u001b[38;5;241m1\u001b[39m)\n", - "File 
\u001b[0;32m/opt/homebrew/lib/python3.12/site-packages/tinygrad/tensor.py:2466\u001b[0m, in \u001b[0;36mTensor._broadcast_to\u001b[0;34m(self, shape)\u001b[0m\n\u001b[1;32m 2464\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mndim \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mlen\u001b[39m(shape): \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcannot broadcast tensor to fewer dimensions. shape=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mshape\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m to \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mshape\u001b[38;5;132;01m=}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 2465\u001b[0m \u001b[38;5;66;03m# first pad left with 1s https://data-apis.org/array-api/latest/API_specification/broadcasting.html\u001b[39;00m\n\u001b[0;32m-> 2466\u001b[0m padded, _ \u001b[38;5;241m=\u001b[39m \u001b[43m_pad_left\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mshape\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mshape\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2467\u001b[0m \u001b[38;5;66;03m# for each dimension, check either from_ is 1, or it does not change\u001b[39;00m\n\u001b[1;32m 2468\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28many\u001b[39m(from_ \u001b[38;5;241m!=\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m from_ \u001b[38;5;241m!=\u001b[39m to \u001b[38;5;28;01mfor\u001b[39;00m from_,to \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(padded, shape)): \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcannot broadcast from shape=\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mshape\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m to \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mshape\u001b[38;5;132;01m=}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", - "File \u001b[0;32m/opt/homebrew/lib/python3.12/site-packages/tinygrad/tensor.py:88\u001b[0m, in \u001b[0;36m_pad_left\u001b[0;34m(*shapes)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_pad_left\u001b[39m(\u001b[38;5;241m*\u001b[39mshapes:Tuple[sint, \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m]) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tuple[Tuple[sint, \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m], \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m]:\n\u001b[1;32m 87\u001b[0m max_dim \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mmax\u001b[39m(\u001b[38;5;28mlen\u001b[39m(shape) \u001b[38;5;28;01mfor\u001b[39;00m shape \u001b[38;5;129;01min\u001b[39;00m shapes)\n\u001b[0;32m---> 88\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mtuple\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43mmax_dim\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[43m 
\u001b[49m\u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mshape\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mshape\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mshape\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mshapes\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m/opt/homebrew/lib/python3.12/site-packages/tinygrad/tensor.py:88\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_pad_left\u001b[39m(\u001b[38;5;241m*\u001b[39mshapes:Tuple[sint, \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m]) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tuple[Tuple[sint, \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m], \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m]:\n\u001b[1;32m 87\u001b[0m max_dim \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mmax\u001b[39m(\u001b[38;5;28mlen\u001b[39m(shape) \u001b[38;5;28;01mfor\u001b[39;00m shape \u001b[38;5;129;01min\u001b[39;00m shapes)\n\u001b[0;32m---> 88\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mtuple\u001b[39m((\u001b[38;5;241m1\u001b[39m,) \u001b[38;5;241m*\u001b[39m (max_dim \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mlen\u001b[39m(shape)) \u001b[38;5;241m+\u001b[39m shape \u001b[38;5;28;01mfor\u001b[39;00m shape \u001b[38;5;129;01min\u001b[39;00m shapes)\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " - ] - } - ], - "source": [ - "from pathlib import Path\n", - "\n", - "from lerobot.common.datasets.lerobot_dataset import LeRobotDataset\n", - "from torch.utils.data import DataLoader\n", - "import torch\n", - "\n", - "import tinygrad\n", - "from tinygrad import Tensor, nn, TinyJit\n", - "\n", - "from omegaconf import ListConfig, OmegaConf\n", - "\n", - "from tinygrad.nn.state import safe_save, safe_load, get_state_dict, load_state_dict\n", - "\n", - "# Start of training code\n", - "\n", - "# Create a directory to store the training checkpoint.\n", - "output_directory = Path(\"outputs/train/example_pusht\")\n", - "output_directory.mkdir(parents=True, exist_ok=True)\n", - "\n", - "# Number of offline training steps (we'll only do offline training for this example.)\n", - "# Adjust as you prefer. 
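The `delta_timestamps` set a few lines below ask the dataset for 100 future action targets per frame. A quick sanity check of the horizon this implies (assuming the 50 fps ALOHA data that the 1/50 s spacing suggests), which lines up with the 100-step chunk the model decodes:

```
timestamps = [i / 50.0 for i in range(100)]
print(len(timestamps), timestamps[0], timestamps[-1])   # 100 0.0 1.98 -> roughly a 2-second action chunk
```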
5000 steps are needed to get something worth evaluating.\n", - "training_steps = 100000\n", - "log_freq = 1\n", - "\n", - "# Set up the dataset.\n", - "delta_timestamps = {\n", - " \"action\": [i / 50.0 for i in range(100)],\n", - "}\n", - "dataset = LeRobotDataset('lerobot/aloha_sim_insertion_human', delta_timestamps=delta_timestamps)\n", - "print(dataset.stats)\n", - "\n", - "cfg = ACTConfig()\n", - "policy = ACTPolicy(cfg, dataset_stats=dataset.stats)\n", - "\n", - "params_not_backbone = [p for n, p in nn.state.get_state_dict(policy).items() if p.requires_grad != False and not n.startswith(\"model.backbone\")]\n", - "params_backbone = [p for n, p in nn.state.get_state_dict(policy).items() if p.requires_grad != False and n.startswith(\"model.backbone\")]\n", - "\n", - "Tensor.manual_seed(1000)\n", - "\n", - "if hasattr(cfg, 'override_dataset_stats'):\n", - " for key, stats_dict in cfg.override_dataset_stats.items():\n", - " for stats_type, listconfig in stats_dict.items():\n", - " # example of stats_type: min, max, mean, std\n", - " print(f'listconfig: {listconfig}')\n", - " dataset.stats[key][stats_type] = torch.tensor(listconfig, dtype=torch.float32)\n", - "\n", - "opt = nn.optim.AdamW(params_not_backbone, lr=1e-5, weight_decay=1e-4)\n", - "opt_backbone = nn.optim.AdamW(params_backbone, lr=1e-5, weight_decay=1e-4)\n", - "\n", - "#@TinyJit\n", - "@Tensor.train()\n", - "def train_step(batch) -> Tensor:\n", - " Tensor.training = True\n", - " output_dict = policy(batch)\n", - " loss = output_dict[\"loss\"]\n", - " opt.zero_grad()\n", - " opt_backbone.zero_grad()\n", - " loss.backward()\n", - " opt.step()\n", - " opt_backbone.step()\n", - " return loss\n", - "\n", - "print(f'Starting training loop')\n", - "# Create dataloader for offline training.\n", - "dataloader = DataLoader(\n", - " dataset,\n", - " num_workers=0,\n", - " batch_size=8,\n", - " shuffle=True,\n", - " pin_memory=False,\n", - " drop_last=True,\n", - ")\n", - "\n", - "step = 0\n", - "done = False\n", - "with Tensor.train():\n", - " while not done:\n", - " for batch in dataloader:\n", - " batch = {k: Tensor(v.numpy(), requires_grad=False) for k, v in batch.items()}\n", - " loss = train_step(batch)\n", - " \n", - " if step % log_freq == 0:\n", - " print(f\"step: {step} loss: {loss.numpy():.3f}\")\n", - " step += 1\n", - "\n", - " if step % 10000 == 0:\n", - " try:\n", - " state_dict = get_state_dict(policy)\n", - " safe_save(state_dict, f'{output_directory}/model_{step}.safetensors')\n", - " except:\n", - " print(f'Exception with safe save occured')\n", - " if step >= training_steps:\n", - " done = True\n", - " break\n", - "\n", - "# Save a policy checkpoint.\n", - "state_dict = get_state_dict(policy)\n", - "safe_save(state_dict, f'{output_directory}/model_final.safetensors')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b347771d-4f2a-4a01-8382-e8e2970bbb92", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f7601f50-41d4-493f-8957-e407ee3ababe", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6bfdb2a1-6a3e-4d79-aef0-11e90efece43", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": 
"python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.6" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} +version https://git-lfs.github.com/spec/v1 +oid sha256:a9680e0d68e4aa8d6712bdae1e8e8151fa19efbc0b316415016114a1e69d95e3 +size 235585 diff --git a/networks.py b/networks.py index ff352fe..f1a453a 100644 --- a/networks.py +++ b/networks.py @@ -6,7 +6,7 @@ import numpy as np import tinygrad from tinygrad import Tensor, nn, dtypes -from tinygrad.ops import Variable +# from tinygrad.ops import Variable from utils import * diff --git a/normalize.py b/normalize.py index 0e4a2cc..fb36329 100644 --- a/normalize.py +++ b/normalize.py @@ -52,14 +52,28 @@ def create_stats_buffers( # unnormalization). See the logic here # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97. if mode == "mean_std": - buffer["mean"].assign(stats[key]["mean"].numpy()) + mean_val = stats[key]["mean"] + std_val = stats[key]["std"] + # Convert to numpy if needed (handle both numpy arrays and torch tensors) + if hasattr(mean_val, 'numpy'): + mean_val = mean_val.numpy() + if hasattr(std_val, 'numpy'): + std_val = std_val.numpy() + buffer["mean"].assign(mean_val) buffer["mean"].requires_grad = False - buffer["std"].assign(stats[key]["std"].numpy()) + buffer["std"].assign(std_val) buffer["std"].requires_grad = False elif mode == "min_max": - buffer["min"].assign(stats[key]["min"].numpy()) + min_val = stats[key]["min"] + max_val = stats[key]["max"] + # Convert to numpy if needed (handle both numpy arrays and torch tensors) + if hasattr(min_val, 'numpy'): + min_val = min_val.numpy() + if hasattr(max_val, 'numpy'): + max_val = max_val.numpy() + buffer["min"].assign(min_val) buffer["min"].requires_grad = False - buffer["max"].assign(stats[key]["max"].numpy()) + buffer["max"].assign(max_val) buffer["max"].requires_grad = False stats_buffers[key] = buffer diff --git a/outputs/train/act_aloha_sim_transfer_cube_human/model_10000.safetensors b/outputs/train/act_aloha_sim_transfer_cube_human/model_10000.safetensors deleted file mode 100755 index 8013f7b..0000000 --- a/outputs/train/act_aloha_sim_transfer_cube_human/model_10000.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:5c391031ac1980499cfb4dd2bdcf17e1e528db7c91d26461fd0ed899472e7552 -size 206771192 diff --git a/outputs/train/aloha_sim_insertion_human/model_30000.safetensors b/outputs/train/aloha_sim_insertion_human/model_30000.safetensors index 68418b0..7d81855 100755 --- a/outputs/train/aloha_sim_insertion_human/model_30000.safetensors +++ b/outputs/train/aloha_sim_insertion_human/model_30000.safetensors @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a4b06fbccfb18628d2967f498c2fc44b5fe4d5edecb3f1a4e075aaca17ec85de +oid sha256:d8764d4a4d368d9799f4977a071d1372b1d67b6f434f33109ab5ed4b74f7b815 size 206771192 diff --git a/resnet.py b/resnet.py index fcbb5b4..01802dc 100644 --- a/resnet.py +++ b/resnet.py @@ -90,8 +90,10 @@ def __call__(self, x:Tensor): class ResNetInstances: def resnet18_IMAGENET1K_V1_Generator(): + import pathlib resnet18_IMAGENET1K = ResNet(Block, [2, 2, 2, 2], num_classes=1000) - state_dict = nn.state.safe_load("resnet18-f37072fd.safetensors") + model_path = pathlib.Path(__file__).parent / "resnet18-f37072fd.safetensors" + state_dict = nn.state.safe_load(str(model_path)) nn.state.load_state_dict(resnet18_IMAGENET1K, state_dict) return resnet18_IMAGENET1K diff --git 
diff --git a/test.py b/test.py
index 3c9cac3..23672e1 100644
--- a/test.py
+++ b/test.py
@@ -16,8 +16,11 @@
 import argparse
 
 parser=argparse.ArgumentParser(description="Argument Parser for ACT testing on simulated environments")
-parser.add_argument("env_name", type=str, choices=['AlohaTransferCube-v0', 'AlohaInsertion-v0'], default='AlohaTransferCube-v0')
-parser.add_argument("model_path", type=str)
+# parser.add_argument("--env_name", type=str, choices=['AlohaTransferCube-v0', 'AlohaInsertion-v0'], default='AlohaTransferCube-v0')
+parser.add_argument("--env_name", type=str, choices=['AlohaTransferCube-v0', 'AlohaInsertion-v0'], default='AlohaInsertion-v0')
+# parser.add_argument("--model_path", type=str, default='outputs/train/aloha_sim_transfer_cube_human/model_final.safetensors')
+# parser.add_argument("--model_path", type=str, default='outputs/train/aloha_sim_transfer_cube_human/model_final.safetensors')
+parser.add_argument("--model_path", type=str, default='outputs/train/aloha_sim_insertion_human/model_30000_original.safetensors')
 args=parser.parse_args()
 
 env_name = args.env_name
@@ -115,7 +118,7 @@ def test(state:Tensor, image:Tensor) -> Tensor:
 fps = env.metadata["render_fps"]
 
 # Encode all frames into a mp4 video.
-video_path = output_directory / "rollout.mp4"
+video_path = output_directory / "rollout3.mp4"
 imageio.mimsave(str(video_path), numpy.stack(frames), fps=fps)
 print(f"Video of the evaluation is available in '{video_path}'.")
 
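The test.py hunk above turns the positional `env_name` and `model_path` arguments into optional `--env_name` / `--model_path` flags with defaults, so the script can be launched with no CLI arguments at all. A small standalone sketch of that behaviour (an illustrative parser only, not the repo's actual one; the default paths are taken from the diff):

```python
import argparse

parser = argparse.ArgumentParser(description="Illustrative parser mirroring the test.py change")
parser.add_argument("--env_name", type=str,
                    choices=["AlohaTransferCube-v0", "AlohaInsertion-v0"],
                    default="AlohaInsertion-v0")
parser.add_argument("--model_path", type=str,
                    default="outputs/train/aloha_sim_insertion_human/model_30000_original.safetensors")

# No arguments: the defaults are used.
print(parser.parse_args([]))
# Flags can still override the defaults.
print(parser.parse_args(["--env_name", "AlohaTransferCube-v0"]))
```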
diff --git a/train.py b/train.py
index b515ad6..282c8f7 100644
--- a/train.py
+++ b/train.py
@@ -1,6 +1,8 @@
 from pathlib import Path
 
-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+# from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+
+from lerobot.datasets.lerobot_dataset import LeRobotDataset
 from torch.utils.data import DataLoader
 import torch
 
@@ -17,7 +19,7 @@ import argparse
 
 # Start of training code
 parser=argparse.ArgumentParser(description="Argument Parser for ACT training on simulated environments")
-parser.add_argument("env_name", type=str, choices=['aloha_sim_transfer_cube_human', 'aloha_sim_insertion_human'], default='aloha_sim_insertion_human')
+parser.add_argument("--env_name", type=str, choices=['aloha_sim_transfer_cube_human', 'aloha_sim_insertion_human'], default='aloha_sim_insertion_human')
 parser.add_argument("--model_starting_point", type=str)
 parser.add_argument("--model_start_step_count", type=int)
 args=parser.parse_args()
@@ -30,17 +32,18 @@
 # Number of offline training steps (we'll only do offline training for this example.)
 # Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
 training_steps = 100000
-log_freq = 1
+log_freq = 10
 
 # Set up the dataset.
 delta_timestamps = {
     "action": [i / 50.0 for i in range(100)],
 }
 dataset = LeRobotDataset(f'lerobot/{env_name}', delta_timestamps=delta_timestamps)
-print(dataset.stats)
+print(dataset.meta.stats)
+
 
 cfg = ACTConfig()
-policy = ACTPolicy(cfg, dataset_stats=dataset.stats)
+policy = ACTPolicy(cfg, dataset_stats=dataset.meta.stats)
 policy.reset()
 
 step = 0
@@ -65,7 +68,7 @@
    for stats_type, listconfig in stats_dict.items():
        # example of stats_type: min, max, mean, std
        print(f'listconfig: {listconfig}')
-       dataset.stats[key][stats_type] = torch.tensor(listconfig, dtype=torch.float32)
+       dataset.meta.stats[key][stats_type] = torch.tensor(listconfig, dtype=torch.float32)
 
 if cfg.train_backbone_separately == True:
    opt = nn.optim.AdamW(params_not_backbone, lr=1e-5, weight_decay=1e-4)
@@ -89,6 +92,19 @@ def train_step(
    if cfg.train_backbone_separately:
        opt_backbone.zero_grad()
    loss.backward()
+    ########################################################################
+    # Handle unused parameters by assigning zero gradients
+    optimizers_list = [opt]
+    if cfg.train_backbone_separately:
+        optimizers_list.append(opt_backbone)
+    for optimizer in optimizers_list:
+        for param in optimizer.params:
+            if param.grad is None:
+                # Create zero gradient with same shape and device as parameter
+                param.grad = Tensor.zeros(*param.shape, device=param.device, requires_grad=False)
+
+    ########################################################################
+
    if cfg.train_backbone_separately:
        grad_norm_not_backbone = clip_grad_norm_(params_not_backbone, 10.0)
        grad_norm_backbone = clip_grad_norm_(params_backbone, 10.0)
@@ -118,7 +134,16 @@ def train_step(
 with Tensor.train():
    while not done:
        for batch in dataloader:
-            batch = {k: Tensor(v.numpy(), requires_grad=False) for k, v in batch.items()}
+            # batch = {k: Tensor(v.numpy(), requires_grad=False) for k, v in batch.items()}
+            batch_converted = {}
+            for k, v in batch.items():
+                if isinstance(v, torch.Tensor):
+                    batch_converted[k] = Tensor(v.detach().cpu().numpy(), requires_grad=False)
+                else:
+                    batch_converted[k] = v  # Keep strings, lists, etc. as-is
+
+            batch = batch_converted
+
            batch = policy.normalize_batch_inputs_and_targets(batch)
            print(f'batch: {batch}')
            info = train_step(