@@ -570,6 +570,255 @@ static void* nmmo3_encoder_create_weights(void* self) {
// Releases the NMMO3 encoder weights block by handing it to free().
static void nmmo3_encoder_free_weights(void* weights) {
    free(weights);
}
// Releases the NMMO3 encoder activations block by handing it to free().
static void nmmo3_encoder_free_activations(void* activations) {
    free(activations);
}
572572
573+ // ---- Boxoban positional encoder ----
574+
// Compile-time dimensions shared by the Boxoban encoder kernels and buffers.
// Number of per-type observation planes summed into each cell embedding.
static constexpr int BOX_NUM_TYPES = 4;
// Number of board cells per observation plane (presumably a 10x10 grid -- confirm).
static constexpr int BOX_NUM_CELLS = 100;
// Width of the embedding vector produced for every cell.
static constexpr int BOX_EMBED_DIM = 8;
// Length of one sample's flattened cell embeddings (cells * embedding width).
static constexpr int BOX_EMBED_FLAT = BOX_NUM_CELLS * BOX_EMBED_DIM;
579+
// In-place bias add followed by ReLU.
//
// Launch: 1D grid, one thread per element; threads with index >= total exit.
//   data  : `total` elements, overwritten with relu(data[i] + bias[i % dim]).
//   bias  : `dim` elements, indexed by the element's position modulo `dim`.
__global__ void box_bias_relu_kernel(
        precision_t* __restrict__ data, const precision_t* __restrict__ bias, int total, int dim) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < total) {
        const float biased = to_float(data[i]) + to_float(bias[i % dim]);
        data[i] = from_float(relu(biased));
    }
}
586+
// In-place ReLU backward pass.
//
// Launch: 1D grid, one thread per element; threads with index >= total exit.
//   grad : `total` elements; each is replaced by relu_backward(out[i], grad[i]).
//   out  : the forward output of the same layer, read only.
__global__ void box_relu_backward_kernel(
        precision_t* __restrict__ grad, const precision_t* __restrict__ out, int total) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < total) {
        const float fwd = to_float(out[i]);
        const float upstream = to_float(grad[i]);
        grad[i] = from_float(relu_backward(fwd, upstream));
    }
}
593+
// Builds the per-cell embedding for every sample.
//
// Launch: 1D grid, one thread per output element
// (B * BOX_NUM_CELLS * BOX_EMBED_DIM threads do work; the rest exit).
//
// For sample b, cell c and embedding component d:
//   out[b][c][d] = pos_embed[c][d]
//                + sum over t of obs[b*obs_size + t*BOX_NUM_CELLS + c] * type_embed[t][d]
// i.e. the observation is read as BOX_NUM_TYPES consecutive planes of
// BOX_NUM_CELLS values at the start of each sample's obs_size-long row.
__global__ void boxoban_cell_embed_kernel(
        precision_t* __restrict__ out, const precision_t* __restrict__ obs,
        const precision_t* __restrict__ type_embed, const precision_t* __restrict__ pos_embed,
        int B, int obs_size) {
    const int gid = blockIdx.x * blockDim.x + threadIdx.x;
    if (gid >= B * BOX_NUM_CELLS * BOX_EMBED_DIM) return;

    // Decompose the flat index into (sample, cell, component).
    const int sample = gid / (BOX_NUM_CELLS * BOX_EMBED_DIM);
    const int within = gid % (BOX_NUM_CELLS * BOX_EMBED_DIM);
    const int cell = within / BOX_EMBED_DIM;
    const int comp = within % BOX_EMBED_DIM;

    // Start from the positional term, then accumulate the type planes in order.
    float acc = to_float(pos_embed[cell * BOX_EMBED_DIM + comp]);
    const int first_plane = sample * obs_size + cell;
    for (int type = 0; type < BOX_NUM_TYPES; ++type) {
        const float occ = to_float(obs[first_plane + type * BOX_NUM_CELLS]);
        acc += occ * to_float(type_embed[type * BOX_EMBED_DIM + comp]);
    }
    out[gid] = from_float(acc);
}
612+
// Accumulates the gradient of the type-embedding table in float precision.
//
// Launch: 1D grid, one thread per (sample, cell, type, component) tuple
// (B * BOX_NUM_CELLS * BOX_NUM_TYPES * BOX_EMBED_DIM threads do work).
//
//   type_grad_f : BOX_NUM_TYPES * BOX_EMBED_DIM floats, accumulated with
//                 atomicAdd; the caller must zero it before the launch.
//   grad_cell   : gradient w.r.t. the flattened cell embeddings,
//                 BOX_EMBED_FLAT values per sample.
//   obs         : observations, read at b*obs_size + t*BOX_NUM_CELLS + cell.
//
// Every thread targets one of only BOX_NUM_TYPES * BOX_EMBED_DIM addresses, so
// the atomics are heavily contended. A thread whose occupancy value is exactly
// zero would add 0 * grad, which leaves the accumulator unchanged for any
// finite gradient, so it exits before loading the gradient or issuing the
// atomic. (The only observable difference from always adding is that a
// non-finite gradient on an unoccupied plane no longer propagates.)
__global__ void boxoban_type_embed_backward_kernel(
        float* __restrict__ type_grad_f, const precision_t* __restrict__ grad_cell,
        const precision_t* __restrict__ obs, int B, int obs_size) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= B * BOX_NUM_CELLS * BOX_NUM_TYPES * BOX_EMBED_DIM) return;

    // Decompose the flat index into (sample, cell, type, component).
    int b = idx / (BOX_NUM_CELLS * BOX_NUM_TYPES * BOX_EMBED_DIM);
    int rem = idx % (BOX_NUM_CELLS * BOX_NUM_TYPES * BOX_EMBED_DIM);
    int cell = rem / (BOX_NUM_TYPES * BOX_EMBED_DIM);
    rem %= (BOX_NUM_TYPES * BOX_EMBED_DIM);
    int t = rem / BOX_EMBED_DIM;
    int d = rem % BOX_EMBED_DIM;

    float occ = to_float(obs[b * obs_size + t * BOX_NUM_CELLS + cell]);
    if (occ == 0.0f) return;  // zero contribution: skip the contended atomic

    float g = occ * to_float(grad_cell[b * BOX_EMBED_FLAT + cell * BOX_EMBED_DIM + d]);
    atomicAdd(&type_grad_f[t * BOX_EMBED_DIM + d], g);
}
628+
// Accumulates the gradient of the positional-embedding table in float precision.
//
// Launch: 1D grid, one thread per element of grad_cell
// (B * BOX_NUM_CELLS * BOX_EMBED_DIM threads do work; the rest exit).
//
//   pos_grad_f : BOX_NUM_CELLS * BOX_EMBED_DIM floats, accumulated with
//                atomicAdd; the caller must zero it before the launch.
//   grad_cell  : gradient w.r.t. the flattened cell embeddings; each element is
//                added to the table slot with the same (cell, component).
__global__ void boxoban_pos_embed_backward_kernel(
        float* __restrict__ pos_grad_f, const precision_t* __restrict__ grad_cell, int B) {
    const int gid = blockIdx.x * blockDim.x + threadIdx.x;
    if (gid >= B * BOX_NUM_CELLS * BOX_EMBED_DIM) return;

    // Dropping the sample index leaves the (cell, component) slot in the table.
    const int within = gid % (BOX_NUM_CELLS * BOX_EMBED_DIM);
    const int cell = within / BOX_EMBED_DIM;
    const int comp = within % BOX_EMBED_DIM;
    float* slot = &pos_grad_f[cell * BOX_EMBED_DIM + comp];
    atomicAdd(slot, to_float(grad_cell[gid]));
}
638+
// Trainable parameters and configuration of the Boxoban encoder.
// Tensor shapes are assigned in boxoban_encoder_reg_params.
struct BoxobanEncoderWeights {
    // Embedding tables: {BOX_NUM_TYPES, BOX_EMBED_DIM} and {BOX_NUM_CELLS, BOX_EMBED_DIM}.
    PrecisionTensor type_embed, pos_embed;
    // Three dense layers: w1 {2*hidden, BOX_EMBED_FLAT}, w2 {hidden, 2*hidden},
    // w3 {hidden, hidden}, each followed by its bias vector.
    PrecisionTensor w1, b1, w2, b2, w3, b3;
    // Per-sample observation stride and the encoder's output width.
    int obs_size, hidden;
};
644+
// Scratch buffers used by the Boxoban encoder's forward and backward passes.
// Shapes are assigned in boxoban_encoder_reg_train / boxoban_encoder_reg_rollout.
struct BoxobanEncoderActivations {
    // Forward buffers: flattened cell embeddings, the matmul result and the
    // bias+ReLU result of each layer, the final output, and a copy of the
    // observations (only registered by boxoban_encoder_reg_train).
    PrecisionTensor cell_flat, l1_preact, l1_act, l2_preact, l2_act, l3_preact, out, saved_obs;
    // Gradients flowing back through the layers (training only).
    PrecisionTensor grad_cell_flat, grad_l1, grad_l2;
    // Parameter gradients, one per tensor in BoxobanEncoderWeights (training only).
    PrecisionTensor type_embed_wgrad, pos_embed_wgrad, w1_wgrad, b1_wgrad, w2_wgrad, b2_wgrad, w3_wgrad, b3_wgrad;
    // Float accumulators for the embedding gradients; filled with atomicAdd
    // and then cast into the *_wgrad tensors above.
    FloatTensor type_embed_wgrad_f, pos_embed_wgrad_f;
};
651+
// Allocates a zero-initialised BoxobanEncoderWeights and records its sizes.
//
//   obs_size : per-sample observation stride used when indexing observations.
//   hidden   : output width of the encoder (layer 1 uses 2 * hidden).
//
// Returns the new struct, to be released with free(), or nullptr if the
// allocation fails. Tensor members are left zeroed; their shapes are set later
// by boxoban_encoder_reg_params.
static BoxobanEncoderWeights* boxoban_encoder_create(int obs_size, int hidden) {
    BoxobanEncoderWeights* ew = (BoxobanEncoderWeights*)calloc(1, sizeof(BoxobanEncoderWeights));
    if (ew == nullptr) {
        // Out of memory: report failure instead of writing through a null pointer.
        return nullptr;
    }
    ew->obs_size = obs_size;
    ew->hidden = hidden;
    return ew;
}
658+
// Forward pass of the Boxoban encoder.
//
//   w           : BoxobanEncoderWeights*
//   activations : BoxobanEncoderActivations* with buffers already registered
//   input       : observations; input.shape[0] is taken as the batch size
//   stream      : every copy, matmul and kernel is issued on this stream
//
// Pipeline: cell embedding -> three (matmul, bias, ReLU) layers of widths
// 2*hidden, hidden, hidden. Each matmul result is kept in the *_preact buffer
// and the bias+ReLU is applied to a copy. Returns the `out` activation tensor.
static PrecisionTensor boxoban_encoder_forward(void* w, void* activations, PrecisionTensor input, cudaStream_t stream) {
    auto* weights = static_cast<BoxobanEncoderWeights*>(w);
    auto* acts = static_cast<BoxobanEncoderActivations*>(activations);
    const int batch = input.shape[0];
    const int hidden = weights->hidden;
    const int wide = 2 * hidden;
    const int n_embed = batch * BOX_NUM_CELLS * BOX_EMBED_DIM;
    const int n_wide = batch * wide;
    const int n_hidden = batch * hidden;

    // Keep the observations for the backward pass, but only when a buffer was
    // registered for them (boxoban_encoder_reg_train does, reg_rollout does not).
    if (acts->saved_obs.data) {
        puf_copy(&acts->saved_obs, &input, stream);
    }

    // Cell embeddings: positional term plus occupancy-weighted type terms.
    boxoban_cell_embed_kernel<<<grid_size(n_embed), BLOCK_SIZE, 0, stream>>>(
        acts->cell_flat.data, input.data, weights->type_embed.data, weights->pos_embed.data,
        batch, weights->obs_size);

    // Layer 1: BOX_EMBED_FLAT -> 2 * hidden.
    puf_mm(&acts->cell_flat, &weights->w1, &acts->l1_preact, stream);
    puf_copy(&acts->l1_act, &acts->l1_preact, stream);
    box_bias_relu_kernel<<<grid_size(n_wide), BLOCK_SIZE, 0, stream>>>(
        acts->l1_act.data, weights->b1.data, n_wide, wide);

    // Layer 2: 2 * hidden -> hidden.
    puf_mm(&acts->l1_act, &weights->w2, &acts->l2_preact, stream);
    puf_copy(&acts->l2_act, &acts->l2_preact, stream);
    box_bias_relu_kernel<<<grid_size(n_hidden), BLOCK_SIZE, 0, stream>>>(
        acts->l2_act.data, weights->b2.data, n_hidden, hidden);

    // Layer 3: hidden -> hidden, written to the returned output buffer.
    puf_mm(&acts->l2_act, &weights->w3, &acts->l3_preact, stream);
    puf_copy(&acts->out, &acts->l3_preact, stream);
    box_bias_relu_kernel<<<grid_size(n_hidden), BLOCK_SIZE, 0, stream>>>(
        acts->out.data, weights->b3.data, n_hidden, hidden);

    return acts->out;
}
686+
// Backward pass of the Boxoban encoder.
//
//   w           : BoxobanEncoderWeights*
//   activations : BoxobanEncoderActivations* populated by the forward pass
//   grad        : gradient w.r.t. the encoder output; grad.shape[0] is taken
//                 as the batch size. NOTE: grad.data is overwritten in place
//                 by the first ReLU-backward step.
//   stream      : every matmul and kernel is issued on this stream
//
// Walks the three dense layers in reverse, writing bias/weight gradients into
// the *_wgrad tensors, then reduces the embedding-table gradients into float
// accumulators and casts them into type_embed_wgrad / pos_embed_wgrad.
static void boxoban_encoder_backward(void* w, void* activations, PrecisionTensor grad, cudaStream_t stream) {
    auto* weights = static_cast<BoxobanEncoderWeights*>(w);
    auto* acts = static_cast<BoxobanEncoderActivations*>(activations);
    const int batch = grad.shape[0];
    const int hidden = weights->hidden;
    const int wide = 2 * hidden;
    const int n_hidden = batch * hidden;
    const int n_wide = batch * wide;

    // ---- Layer 3 ----
    box_relu_backward_kernel<<<grid_size(n_hidden), BLOCK_SIZE, 0, stream>>>(
        grad.data, acts->out.data, n_hidden);
    bias_grad_kernel<<<hidden, 256, 0, stream>>>(acts->b3_wgrad.data, grad.data, batch, hidden);
    puf_mm_tn(&grad, &acts->l2_act, &acts->w3_wgrad, stream);

    // ---- Layer 2 ----
    puf_mm_nn(&grad, &weights->w3, &acts->grad_l2, stream);
    box_relu_backward_kernel<<<grid_size(n_hidden), BLOCK_SIZE, 0, stream>>>(
        acts->grad_l2.data, acts->l2_act.data, n_hidden);
    bias_grad_kernel<<<hidden, 256, 0, stream>>>(acts->b2_wgrad.data, acts->grad_l2.data, batch, hidden);
    puf_mm_tn(&acts->grad_l2, &acts->l1_act, &acts->w2_wgrad, stream);

    // ---- Layer 1 ----
    puf_mm_nn(&acts->grad_l2, &weights->w2, &acts->grad_l1, stream);
    box_relu_backward_kernel<<<grid_size(n_wide), BLOCK_SIZE, 0, stream>>>(
        acts->grad_l1.data, acts->l1_act.data, n_wide);
    bias_grad_kernel<<<wide, 256, 0, stream>>>(acts->b1_wgrad.data, acts->grad_l1.data, batch, wide);
    puf_mm_tn(&acts->grad_l1, &acts->cell_flat, &acts->w1_wgrad, stream);
    puf_mm_nn(&acts->grad_l1, &weights->w1, &acts->grad_cell_flat, stream);

    // ---- Embedding tables ----
    // The reduction kernels accumulate with atomicAdd, so clear the float
    // accumulators first.
    puf_zero(&acts->type_embed_wgrad_f, stream);
    puf_zero(&acts->pos_embed_wgrad_f, stream);

    const int n_type_terms = batch * BOX_NUM_CELLS * BOX_NUM_TYPES * BOX_EMBED_DIM;
    const int n_pos_terms = batch * BOX_NUM_CELLS * BOX_EMBED_DIM;
    boxoban_type_embed_backward_kernel<<<grid_size(n_type_terms), BLOCK_SIZE, 0, stream>>>(
        acts->type_embed_wgrad_f.data, acts->grad_cell_flat.data, acts->saved_obs.data,
        batch, weights->obs_size);
    boxoban_pos_embed_backward_kernel<<<grid_size(n_pos_terms), BLOCK_SIZE, 0, stream>>>(
        acts->pos_embed_wgrad_f.data, acts->grad_cell_flat.data, batch);

    // Convert the float accumulators into the precision-typed gradient tensors.
    const int type_elems = BOX_NUM_TYPES * BOX_EMBED_DIM;
    const int pos_elems = BOX_NUM_CELLS * BOX_EMBED_DIM;
    cast<<<grid_size(type_elems), BLOCK_SIZE, 0, stream>>>(
        acts->type_embed_wgrad.data, acts->type_embed_wgrad_f.data, type_elems);
    cast<<<grid_size(pos_elems), BLOCK_SIZE, 0, stream>>>(
        acts->pos_embed_wgrad.data, acts->pos_embed_wgrad_f.data, pos_elems);
}
721+
// Initialises every Boxoban encoder parameter on `stream`.
//
//   w      : BoxobanEncoderWeights* whose tensors already have storage
//   seed   : running seed; read and post-incremented once per random tensor,
//            in the order type_embed, pos_embed, w1, w2, w3
//   stream : stream on which the init calls and memsets are issued
//
// Embedding tables use puf_normal_init with 1.0f, the dense weights use
// puf_kaiming_init with gain 1.0f, and the three biases are zero-filled.
static void boxoban_encoder_init_weights(void* w, uint64_t* seed, cudaStream_t stream) {
    auto* weights = static_cast<BoxobanEncoderWeights*>(w);
    const int hidden = weights->hidden;
    const int wide = 2 * hidden;

    puf_normal_init(&weights->type_embed, 1.0f, (*seed)++, stream);
    puf_normal_init(&weights->pos_embed, 1.0f, (*seed)++, stream);

    // Kaiming-initialise a weight through an explicit {rows, cols} view of its storage.
    auto kaiming_2d = [&](PrecisionTensor& target, int rows, int cols, float gain) {
        PrecisionTensor view = {.data = target.data, .shape = {rows, cols}};
        puf_kaiming_init(&view, gain, (*seed)++, stream);
    };
    kaiming_2d(weights->w1, wide, BOX_EMBED_FLAT, 1.0f);
    kaiming_2d(weights->w2, hidden, wide, 1.0f);
    kaiming_2d(weights->w3, hidden, hidden, 1.0f);

    // Biases start at zero (an all-zero byte pattern).
    PrecisionTensor* biases[] = {&weights->b1, &weights->b2, &weights->b3};
    for (PrecisionTensor* bias : biases) {
        cudaMemsetAsync(bias->data, 0, numel(bias->shape) * sizeof(precision_t), stream);
    }
}
737+
// Declares the shape of every Boxoban encoder parameter and registers it with
// `alloc`. Registration order is type_embed, pos_embed, w1, b1, w2, b2, w3, b3.
//
//   w     : BoxobanEncoderWeights* with `hidden` already set
//   alloc : allocator that receives the parameter tensors
static void boxoban_encoder_reg_params(void* w, Allocator* alloc) {
    auto* weights = static_cast<BoxobanEncoderWeights*>(w);
    const int hidden = weights->hidden;
    const int wide = 2 * hidden;

    // Embedding tables.
    weights->type_embed = {.shape = {BOX_NUM_TYPES, BOX_EMBED_DIM}};
    alloc_register(alloc, &weights->type_embed);
    weights->pos_embed = {.shape = {BOX_NUM_CELLS, BOX_EMBED_DIM}};
    alloc_register(alloc, &weights->pos_embed);

    // Layer 1: BOX_EMBED_FLAT -> 2 * hidden.
    weights->w1 = {.shape = {wide, BOX_EMBED_FLAT}};
    alloc_register(alloc, &weights->w1);
    weights->b1 = {.shape = {wide}};
    alloc_register(alloc, &weights->b1);

    // Layer 2: 2 * hidden -> hidden.
    weights->w2 = {.shape = {hidden, wide}};
    alloc_register(alloc, &weights->w2);
    weights->b2 = {.shape = {hidden}};
    alloc_register(alloc, &weights->b2);

    // Layer 3: hidden -> hidden.
    weights->w3 = {.shape = {hidden, hidden}};
    alloc_register(alloc, &weights->w3);
    weights->b3 = {.shape = {hidden}};
    alloc_register(alloc, &weights->b3);
}
754+
// Declares and registers every buffer the encoder needs for training.
//
//   w           : BoxobanEncoderWeights* with `hidden` and `obs_size` set
//   activations : BoxobanEncoderActivations*, reset to all-zero first
//   acts        : allocator for forward/backward buffers and the float
//                 embedding-gradient accumulators
//   grads       : allocator for the parameter gradients, registered in the
//                 same order as the parameters in boxoban_encoder_reg_params
//   B_TT        : leading (batch) dimension of every per-sample buffer
static void boxoban_encoder_reg_train(void* w, void* activations, Allocator* acts, Allocator* grads, int B_TT) {
    auto* weights = static_cast<BoxobanEncoderWeights*>(w);
    auto* a = static_cast<BoxobanEncoderActivations*>(activations);
    const int hidden = weights->hidden;
    const int wide = 2 * hidden;

    *a = {};

    // Forward buffers.
    a->cell_flat = {.shape = {B_TT, BOX_EMBED_FLAT}};
    alloc_register(acts, &a->cell_flat);
    a->l1_preact = {.shape = {B_TT, wide}};
    alloc_register(acts, &a->l1_preact);
    a->l1_act = {.shape = {B_TT, wide}};
    alloc_register(acts, &a->l1_act);
    a->l2_preact = {.shape = {B_TT, hidden}};
    alloc_register(acts, &a->l2_preact);
    a->l2_act = {.shape = {B_TT, hidden}};
    alloc_register(acts, &a->l2_act);
    a->l3_preact = {.shape = {B_TT, hidden}};
    alloc_register(acts, &a->l3_preact);
    a->out = {.shape = {B_TT, hidden}};
    alloc_register(acts, &a->out);

    // Observation copy consumed by the type-embedding backward kernel.
    a->saved_obs = {.shape = {B_TT, weights->obs_size}};
    alloc_register(acts, &a->saved_obs);

    // Gradients with respect to intermediate activations.
    a->grad_cell_flat = {.shape = {B_TT, BOX_EMBED_FLAT}};
    alloc_register(acts, &a->grad_cell_flat);
    a->grad_l1 = {.shape = {B_TT, wide}};
    alloc_register(acts, &a->grad_l1);
    a->grad_l2 = {.shape = {B_TT, hidden}};
    alloc_register(acts, &a->grad_l2);

    // Parameter gradients; each shape mirrors the corresponding parameter.
    a->type_embed_wgrad = {.shape = {BOX_NUM_TYPES, BOX_EMBED_DIM}};
    alloc_register(grads, &a->type_embed_wgrad);
    a->pos_embed_wgrad = {.shape = {BOX_NUM_CELLS, BOX_EMBED_DIM}};
    alloc_register(grads, &a->pos_embed_wgrad);
    a->w1_wgrad = {.shape = {wide, BOX_EMBED_FLAT}};
    alloc_register(grads, &a->w1_wgrad);
    a->b1_wgrad = {.shape = {wide}};
    alloc_register(grads, &a->b1_wgrad);
    a->w2_wgrad = {.shape = {hidden, wide}};
    alloc_register(grads, &a->w2_wgrad);
    a->b2_wgrad = {.shape = {hidden}};
    alloc_register(grads, &a->b2_wgrad);
    a->w3_wgrad = {.shape = {hidden, hidden}};
    alloc_register(grads, &a->w3_wgrad);
    a->b3_wgrad = {.shape = {hidden}};
    alloc_register(grads, &a->b3_wgrad);

    // Float accumulators for the atomically reduced embedding gradients.
    a->type_embed_wgrad_f = {.shape = {BOX_NUM_TYPES, BOX_EMBED_DIM}};
    alloc_register(acts, &a->type_embed_wgrad_f);
    a->pos_embed_wgrad_f = {.shape = {BOX_NUM_CELLS, BOX_EMBED_DIM}};
    alloc_register(acts, &a->pos_embed_wgrad_f);
}
797+
// Declares and registers the forward-only buffers used during rollouts.
//
//   w           : BoxobanEncoderWeights* with `hidden` set
//   activations : BoxobanEncoderActivations*, reset to all-zero first
//   alloc       : allocator that receives the forward buffers
//   B           : leading (batch) dimension of every buffer
//
// The struct is cleared before use, as boxoban_encoder_reg_train does.
// boxoban_encoder_forward decides whether to save the observations by testing
// saved_obs.data, and this function never registers saved_obs, so that pointer
// must be null here rather than whatever the caller's memory contained.
static void boxoban_encoder_reg_rollout(void* w, void* activations, Allocator* alloc, int B) {
    BoxobanEncoderWeights* ew = (BoxobanEncoderWeights*)w;
    BoxobanEncoderActivations* a = (BoxobanEncoderActivations*)activations;
    *a = {};
    a->cell_flat = {.shape = {B, BOX_EMBED_FLAT}};
    a->l1_preact = {.shape = {B, 2 * ew->hidden}};
    a->l1_act = {.shape = {B, 2 * ew->hidden}};
    a->l2_preact = {.shape = {B, ew->hidden}};
    a->l2_act = {.shape = {B, ew->hidden}};
    a->l3_preact = {.shape = {B, ew->hidden}};
    a->out = {.shape = {B, ew->hidden}};
    alloc_register(alloc, &a->cell_flat);
    alloc_register(alloc, &a->l1_preact); alloc_register(alloc, &a->l1_act);
    alloc_register(alloc, &a->l2_preact); alloc_register(alloc, &a->l2_act);
    alloc_register(alloc, &a->l3_preact); alloc_register(alloc, &a->out);
}
813+
// Encoder hook: builds the Boxoban weights struct from the Encoder's
// dimensions, using in_dim as the observation size and out_dim as the hidden
// width. `self` must point to an Encoder.
static void* boxoban_encoder_create_weights(void* self) {
    const Encoder* enc = static_cast<const Encoder*>(self);
    const int obs_size = enc->in_dim;
    const int hidden = enc->out_dim;
    return boxoban_encoder_create(obs_size, hidden);
}
818+
// Releases the Boxoban encoder weights block by handing it to free().
static void boxoban_encoder_free_weights(void* weights) {
    free(weights);
}
// Releases the Boxoban encoder activations block by handing it to free().
static void boxoban_encoder_free_activations(void* activations) {
    free(activations);
}
821+
573822// Override encoder vtable for known ocean environments. No-op for unknown envs.
574823static void create_custom_encoder (const std::string& env_name, Encoder* enc) {
575824 if (env_name == " nmmo3" ) {
@@ -586,5 +835,19 @@ static void create_custom_encoder(const std::string& env_name, Encoder* enc) {
586835 .in_dim = enc->in_dim , .out_dim = enc->out_dim ,
587836 .activation_size = sizeof (NMMO3EncoderActivations),
588837 };
838+ } else if (env_name == " boxoban" ) {
839+ *enc = Encoder{
840+ .forward = boxoban_encoder_forward,
841+ .backward = boxoban_encoder_backward,
842+ .init_weights = boxoban_encoder_init_weights,
843+ .reg_params = boxoban_encoder_reg_params,
844+ .reg_train = boxoban_encoder_reg_train,
845+ .reg_rollout = boxoban_encoder_reg_rollout,
846+ .create_weights = boxoban_encoder_create_weights,
847+ .free_weights = boxoban_encoder_free_weights,
848+ .free_activations = boxoban_encoder_free_activations,
849+ .in_dim = enc->in_dim , .out_dim = enc->out_dim ,
850+ .activation_size = sizeof (BoxobanEncoderActivations),
851+ };
589852 }
590853}
0 commit comments