Skip to content

Commit c9edf83

Browse files
author
Tim-phant
committed
custom positional embeddings
1 parent c2d9f3f commit c9edf83

2 files changed

Lines changed: 266 additions & 1 deletion

File tree

pufferlib/sweep.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,9 @@ def _params_from_puffer_sweep(sweep_config, only_include=None):
146146

147147
for name, param in sweep_config.items():
148148
if name in ('method', 'metric', 'metric_distribution', 'goal', 'downsample', 'use_gpu', 'prune_pareto',
149-
'sweep_only', 'max_suggestion_cost', 'early_stop_quantile', 'gpus', 'max_runs'):
149+
'sweep_only', 'max_suggestion_cost', 'early_stop_quantile', 'gpus', 'max_runs',
150+
'match_enemy_model_path', 'match_num_games',
151+
'match_enemy_hidden_size', 'match_enemy_num_layers'):
150152
continue
151153

152154
assert isinstance(param, dict), f'Param {name} is not a dict'

src/ocean.cu

Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,255 @@ static void* nmmo3_encoder_create_weights(void* self) {
570570
static void nmmo3_encoder_free_weights(void* weights) { free(weights); }
571571
static void nmmo3_encoder_free_activations(void* activations) { free(activations); }
572572

573+
// ---- Boxoban positional encoder ----
574+
575+
// Number of per-cell type planes the embedding kernels read from each observation row.
static constexpr int BOX_NUM_TYPES = 4;
576+
// Number of board cells; also the stride between consecutive type planes inside an observation row.
static constexpr int BOX_NUM_CELLS = 100;
577+
// Width of the embedding vector produced for every cell.
static constexpr int BOX_EMBED_DIM = 8;
578+
// Length of one sample's flattened cell embeddings (cells x embedding width).
static constexpr int BOX_EMBED_FLAT = BOX_NUM_CELLS * BOX_EMBED_DIM;
579+
580+
// In-place bias add followed by ReLU.
//
// Launch layout: 1D grid, one thread per element; threads past `total` exit
// without touching memory.
//
//   data  - buffer of `total` elements, overwritten with relu(data + bias)
//   bias  - `dim` bias values; element i uses bias[i % dim]
//   total - number of elements in `data`
//   dim   - length of `bias`
__global__ void box_bias_relu_kernel(
    precision_t* __restrict__ data, const precision_t* __restrict__ bias, int total, int dim) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < total) {
        // Arithmetic is done in float regardless of the storage precision.
        data[i] = from_float(relu(to_float(data[i]) + to_float(bias[i % dim])));
    }
}
586+
587+
// In-place ReLU backward pass.
//
// Launch layout: 1D grid, one thread per element; threads past `total` exit
// without touching memory.
//
//   grad  - incoming gradient, overwritten with relu_backward(out, grad)
//   out   - activation values saved from the forward pass
//   total - number of elements in both buffers
//
// NOTE(review): relu_backward is defined elsewhere; presumably it passes the
// gradient through where `out` is positive - confirm against its definition.
__global__ void box_relu_backward_kernel(
    precision_t* __restrict__ grad, const precision_t* __restrict__ out, int total) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < total) {
        const float activation = to_float(out[i]);
        const float upstream = to_float(grad[i]);
        grad[i] = from_float(relu_backward(activation, upstream));
    }
}
593+
594+
// Builds the per-cell embedding for every sample:
//   out[b, cell, d] = pos_embed[cell, d] + sum_t obs[b, t, cell] * type_embed[t, d]
//
// Launch layout: 1D grid, one thread per output element
// (B * BOX_NUM_CELLS * BOX_EMBED_DIM threads are needed); extra threads exit.
//
//   out        - B rows of BOX_EMBED_FLAT values, fully overwritten
//   obs        - B rows of `obs_size` values; plane t of a row starts at t * BOX_NUM_CELLS
//   type_embed - BOX_NUM_TYPES x BOX_EMBED_DIM table
//   pos_embed  - BOX_NUM_CELLS x BOX_EMBED_DIM table
//   B          - number of samples
//   obs_size   - row stride of `obs`; the indexing below reads up to offset
//                BOX_NUM_TYPES * BOX_NUM_CELLS - 1 within each row
__global__ void boxoban_cell_embed_kernel(
    precision_t* __restrict__ out, const precision_t* __restrict__ obs,
    const precision_t* __restrict__ type_embed, const precision_t* __restrict__ pos_embed,
    int B, int obs_size) {
    const int gid = blockIdx.x * blockDim.x + threadIdx.x;
    if (gid >= B * BOX_NUM_CELLS * BOX_EMBED_DIM) return;

    // Decompose the flat index into (sample, cell, embedding dim).
    const int sample = gid / (BOX_NUM_CELLS * BOX_EMBED_DIM);
    const int within = gid % (BOX_NUM_CELLS * BOX_EMBED_DIM);
    const int cell = within / BOX_EMBED_DIM;
    const int dim = within % BOX_EMBED_DIM;

    // Offset of this cell inside plane 0 of the sample's observation row.
    const int cell_offset = sample * obs_size + cell;

    // Start from the positional term, then add each type plane in order.
    float acc = to_float(pos_embed[cell * BOX_EMBED_DIM + dim]);
    for (int type = 0; type < BOX_NUM_TYPES; ++type) {
        const float occupancy = to_float(obs[cell_offset + type * BOX_NUM_CELLS]);
        acc += occupancy * to_float(type_embed[type * BOX_EMBED_DIM + dim]);
    }
    out[gid] = from_float(acc);
}
612+
613+
// Accumulates the gradient of the type-embedding table:
//   type_grad_f[t, d] += sum over (b, cell) of obs[b, t, cell] * grad_cell[b, cell, d]
//
// Launch layout: 1D grid, one thread per (sample, cell, type, dim) tuple
// (B * BOX_NUM_CELLS * BOX_NUM_TYPES * BOX_EMBED_DIM threads); extra threads exit.
//
//   type_grad_f - float accumulator of BOX_NUM_TYPES * BOX_EMBED_DIM values; the
//                 kernel only adds to it, so the caller must zero it beforehand
//   grad_cell   - B rows of BOX_EMBED_FLAT gradient values w.r.t. the cell embeddings
//   obs         - B rows of `obs_size` observation values (same layout as the forward pass)
//   B           - number of samples
//   obs_size    - row stride of `obs`
__global__ void boxoban_type_embed_backward_kernel(
    float* __restrict__ type_grad_f, const precision_t* __restrict__ grad_cell,
    const precision_t* __restrict__ obs, int B, int obs_size) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= B * BOX_NUM_CELLS * BOX_NUM_TYPES * BOX_EMBED_DIM) return;

    // Decompose the flat index into (sample, cell, type, embedding dim).
    int b = idx / (BOX_NUM_CELLS * BOX_NUM_TYPES * BOX_EMBED_DIM);
    int rem = idx % (BOX_NUM_CELLS * BOX_NUM_TYPES * BOX_EMBED_DIM);
    int cell = rem / (BOX_NUM_TYPES * BOX_EMBED_DIM);
    rem %= (BOX_NUM_TYPES * BOX_EMBED_DIM);
    int t = rem / BOX_EMBED_DIM;
    int d = rem % BOX_EMBED_DIM;

    float occ = to_float(obs[b * obs_size + t * BOX_NUM_CELLS + cell]);
    float g = occ * to_float(grad_cell[b * BOX_EMBED_FLAT + cell * BOX_EMBED_DIM + d]);

    // All threads funnel into only BOX_NUM_TYPES * BOX_EMBED_DIM addresses, so the
    // atomics are heavily contended. Adding an exact zero cannot change the
    // accumulator, so skip it; NaN/Inf compare unequal to zero and still propagate.
    if (g != 0.0f) {
        atomicAdd(&type_grad_f[t * BOX_EMBED_DIM + d], g);
    }
}
628+
629+
// Accumulates the gradient of the positional-embedding table by summing the
// cell-embedding gradient over the batch:
//   pos_grad_f[cell, d] += sum_b grad_cell[b, cell, d]
//
// Launch layout: 1D grid, one thread per element of `grad_cell`
// (B * BOX_NUM_CELLS * BOX_EMBED_DIM threads); extra threads exit.
//
//   pos_grad_f - float accumulator of BOX_NUM_CELLS * BOX_EMBED_DIM values; the
//                kernel only adds to it, so the caller must zero it beforehand
//   grad_cell  - B rows of BOX_EMBED_FLAT gradient values
//   B          - number of samples
__global__ void boxoban_pos_embed_backward_kernel(
    float* __restrict__ pos_grad_f, const precision_t* __restrict__ grad_cell, int B) {
    const int gid = blockIdx.x * blockDim.x + threadIdx.x;
    if (gid >= B * BOX_NUM_CELLS * BOX_EMBED_DIM) return;

    // Position of this element within its sample: (cell, embedding dim).
    const int within = gid % (BOX_NUM_CELLS * BOX_EMBED_DIM);
    const int cell = within / BOX_EMBED_DIM;
    const int dim = within % BOX_EMBED_DIM;

    const float contribution = to_float(grad_cell[gid]);
    atomicAdd(&pos_grad_f[cell * BOX_EMBED_DIM + dim], contribution);
}
638+
639+
// Learnable parameters and configuration of the Boxoban encoder.
// Member order is unchanged from the original declaration.
struct BoxobanEncoderWeights {
    // Embedding tables consumed by boxoban_cell_embed_kernel.
    PrecisionTensor type_embed;  // registered as {BOX_NUM_TYPES, BOX_EMBED_DIM}
    PrecisionTensor pos_embed;   // registered as {BOX_NUM_CELLS, BOX_EMBED_DIM}

    // Three linear layers (weight, bias) applied after the embedding.
    PrecisionTensor w1;  // registered as {2 * hidden, BOX_EMBED_FLAT}
    PrecisionTensor b1;  // registered as {2 * hidden}
    PrecisionTensor w2;  // registered as {hidden, 2 * hidden}
    PrecisionTensor b2;  // registered as {hidden}
    PrecisionTensor w3;  // registered as {hidden, hidden}
    PrecisionTensor b3;  // registered as {hidden}

    int obs_size;  // row stride of the observation tensor
    int hidden;    // encoder output width
};
644+
645+
// Scratch tensors used by the Boxoban encoder's forward and backward passes.
// Member order is unchanged from the original declaration.
struct BoxobanEncoderActivations {
    // Forward-pass buffers.
    PrecisionTensor cell_flat;   // flattened per-cell embeddings
    PrecisionTensor l1_preact;   // layer 1 matmul output
    PrecisionTensor l1_act;      // layer 1 output after bias + ReLU
    PrecisionTensor l2_preact;   // layer 2 matmul output
    PrecisionTensor l2_act;      // layer 2 output after bias + ReLU
    PrecisionTensor l3_preact;   // layer 3 matmul output
    PrecisionTensor out;         // encoder output after bias + ReLU
    PrecisionTensor saved_obs;   // copy of the input; only registered for training

    // Backward-pass gradients with respect to intermediate activations.
    PrecisionTensor grad_cell_flat;
    PrecisionTensor grad_l1;
    PrecisionTensor grad_l2;

    // Parameter gradients, one per weight tensor.
    PrecisionTensor type_embed_wgrad;
    PrecisionTensor pos_embed_wgrad;
    PrecisionTensor w1_wgrad;
    PrecisionTensor b1_wgrad;
    PrecisionTensor w2_wgrad;
    PrecisionTensor b2_wgrad;
    PrecisionTensor w3_wgrad;
    PrecisionTensor b3_wgrad;

    // Float accumulators for the atomic embedding-gradient kernels; cast into
    // the *_wgrad tensors at the end of the backward pass.
    FloatTensor type_embed_wgrad_f;
    FloatTensor pos_embed_wgrad_f;
};
651+
652+
// Allocates a zero-initialised BoxobanEncoderWeights on the host heap and
// records the configuration. Tensor members are left empty; they are given
// shapes later by boxoban_encoder_reg_params.
//
//   obs_size - row stride of the observation tensor fed to the encoder; the
//              embedding kernels read BOX_NUM_TYPES * BOX_NUM_CELLS values per row
//   hidden   - encoder output width
//
// Returns the new object (release with free()), or nullptr if allocation fails.
static BoxobanEncoderWeights* boxoban_encoder_create(int obs_size, int hidden) {
    BoxobanEncoderWeights* ew = (BoxobanEncoderWeights*)calloc(1, sizeof(BoxobanEncoderWeights));
    // calloc can fail; do not write through a null pointer.
    if (ew == nullptr) return nullptr;
    ew->obs_size = obs_size;
    ew->hidden = hidden;
    return ew;
}
658+
659+
// Forward pass of the Boxoban encoder. All work is enqueued on `stream`.
//
// Pipeline: cell embedding -> (matmul, bias + ReLU) x 3.
//
//   w           - BoxobanEncoderWeights*
//   activations - BoxobanEncoderActivations* whose tensors have been registered
//   input       - observation tensor; shape[0] is taken as the batch size
//   stream      - CUDA stream used for every copy, matmul and kernel launch
//
// Returns the `out` activation tensor (batch x hidden).
static PrecisionTensor boxoban_encoder_forward(void* w, void* activations, PrecisionTensor input, cudaStream_t stream) {
    BoxobanEncoderWeights* weights = static_cast<BoxobanEncoderWeights*>(w);
    BoxobanEncoderActivations* acts = static_cast<BoxobanEncoderActivations*>(activations);

    const int batch = input.shape[0];
    const int hidden = weights->hidden;
    const int wide = 2 * hidden;
    const int embed_total = batch * BOX_NUM_CELLS * BOX_EMBED_DIM;
    const int wide_total = batch * wide;
    const int hidden_total = batch * hidden;

    // Keep a copy of the observations for the backward pass, but only when a
    // buffer for it exists (the rollout registration does not create one).
    if (acts->saved_obs.data) {
        puf_copy(&acts->saved_obs, &input, stream);
    }

    // Type + positional embedding of every cell.
    boxoban_cell_embed_kernel<<<grid_size(embed_total), BLOCK_SIZE, 0, stream>>>(
        acts->cell_flat.data, input.data, weights->type_embed.data, weights->pos_embed.data,
        batch, weights->obs_size);

    // Layer 1: BOX_EMBED_FLAT -> 2 * hidden.
    puf_mm(&acts->cell_flat, &weights->w1, &acts->l1_preact, stream);
    puf_copy(&acts->l1_act, &acts->l1_preact, stream);
    box_bias_relu_kernel<<<grid_size(wide_total), BLOCK_SIZE, 0, stream>>>(
        acts->l1_act.data, weights->b1.data, wide_total, wide);

    // Layer 2: 2 * hidden -> hidden.
    puf_mm(&acts->l1_act, &weights->w2, &acts->l2_preact, stream);
    puf_copy(&acts->l2_act, &acts->l2_preact, stream);
    box_bias_relu_kernel<<<grid_size(hidden_total), BLOCK_SIZE, 0, stream>>>(
        acts->l2_act.data, weights->b2.data, hidden_total, hidden);

    // Layer 3: hidden -> hidden.
    puf_mm(&acts->l2_act, &weights->w3, &acts->l3_preact, stream);
    puf_copy(&acts->out, &acts->l3_preact, stream);
    box_bias_relu_kernel<<<grid_size(hidden_total), BLOCK_SIZE, 0, stream>>>(
        acts->out.data, weights->b3.data, hidden_total, hidden);

    return acts->out;
}
686+
687+
// Backward pass of the Boxoban encoder. All work is enqueued on `stream`.
//
// Walks the three layers in reverse, writing the bias and weight gradients of
// each, then pushes the gradient through to the embedding tables.
//
//   w           - BoxobanEncoderWeights*
//   activations - BoxobanEncoderActivations* populated by the forward pass;
//                 saved_obs must hold the observations used in that pass
//   grad        - gradient w.r.t. the encoder output; shape[0] is taken as the
//                 batch size. Its buffer is modified in place by the ReLU backward.
//   stream      - CUDA stream used for every matmul and kernel launch
static void boxoban_encoder_backward(void* w, void* activations, PrecisionTensor grad, cudaStream_t stream) {
    BoxobanEncoderWeights* weights = static_cast<BoxobanEncoderWeights*>(w);
    BoxobanEncoderActivations* acts = static_cast<BoxobanEncoderActivations*>(activations);

    const int batch = grad.shape[0];
    const int hidden = weights->hidden;
    const int wide = 2 * hidden;
    const int hidden_total = batch * hidden;
    const int wide_total = batch * wide;

    // ---- Layer 3 ----
    box_relu_backward_kernel<<<grid_size(hidden_total), BLOCK_SIZE, 0, stream>>>(
        grad.data, acts->out.data, hidden_total);
    bias_grad_kernel<<<hidden, 256, 0, stream>>>(acts->b3_wgrad.data, grad.data, batch, hidden);
    puf_mm_tn(&grad, &acts->l2_act, &acts->w3_wgrad, stream);

    // ---- Layer 2 ----
    puf_mm_nn(&grad, &weights->w3, &acts->grad_l2, stream);
    box_relu_backward_kernel<<<grid_size(hidden_total), BLOCK_SIZE, 0, stream>>>(
        acts->grad_l2.data, acts->l2_act.data, hidden_total);
    bias_grad_kernel<<<hidden, 256, 0, stream>>>(acts->b2_wgrad.data, acts->grad_l2.data, batch, hidden);
    puf_mm_tn(&acts->grad_l2, &acts->l1_act, &acts->w2_wgrad, stream);

    // ---- Layer 1 ----
    puf_mm_nn(&acts->grad_l2, &weights->w2, &acts->grad_l1, stream);
    box_relu_backward_kernel<<<grid_size(wide_total), BLOCK_SIZE, 0, stream>>>(
        acts->grad_l1.data, acts->l1_act.data, wide_total);
    bias_grad_kernel<<<wide, 256, 0, stream>>>(acts->b1_wgrad.data, acts->grad_l1.data, batch, wide);
    puf_mm_tn(&acts->grad_l1, &acts->cell_flat, &acts->w1_wgrad, stream);
    puf_mm_nn(&acts->grad_l1, &weights->w1, &acts->grad_cell_flat, stream);

    // ---- Embedding tables ----
    // The embedding kernels accumulate with atomicAdd into float buffers, so
    // those buffers are cleared first and cast to the storage precision last.
    puf_zero(&acts->type_embed_wgrad_f, stream);
    puf_zero(&acts->pos_embed_wgrad_f, stream);

    const int type_threads = batch * BOX_NUM_CELLS * BOX_NUM_TYPES * BOX_EMBED_DIM;
    boxoban_type_embed_backward_kernel<<<grid_size(type_threads), BLOCK_SIZE, 0, stream>>>(
        acts->type_embed_wgrad_f.data, acts->grad_cell_flat.data, acts->saved_obs.data,
        batch, weights->obs_size);

    const int pos_threads = batch * BOX_NUM_CELLS * BOX_EMBED_DIM;
    boxoban_pos_embed_backward_kernel<<<grid_size(pos_threads), BLOCK_SIZE, 0, stream>>>(
        acts->pos_embed_wgrad_f.data, acts->grad_cell_flat.data, batch);

    cast<<<grid_size(BOX_NUM_TYPES * BOX_EMBED_DIM), BLOCK_SIZE, 0, stream>>>(
        acts->type_embed_wgrad.data, acts->type_embed_wgrad_f.data, BOX_NUM_TYPES * BOX_EMBED_DIM);
    cast<<<grid_size(BOX_NUM_CELLS * BOX_EMBED_DIM), BLOCK_SIZE, 0, stream>>>(
        acts->pos_embed_wgrad.data, acts->pos_embed_wgrad_f.data, BOX_NUM_CELLS * BOX_EMBED_DIM);
}
721+
722+
// Initialises every parameter of the Boxoban encoder on `stream`.
//
// Embedding tables get puf_normal_init, the three weight matrices get
// puf_kaiming_init, and the three biases are zero-filled. Each random
// initialiser consumes one value from `*seed` (post-incremented), in the
// order: type_embed, pos_embed, w1, w2, w3.
//
//   w      - BoxobanEncoderWeights* whose tensors already have device storage
//   seed   - in/out seed counter; advanced by five
//   stream - CUDA stream the initialisation work is enqueued on
static void boxoban_encoder_init_weights(void* w, uint64_t* seed, cudaStream_t stream) {
    BoxobanEncoderWeights* weights = static_cast<BoxobanEncoderWeights*>(w);
    const int hidden = weights->hidden;
    const int wide = 2 * hidden;

    // Embedding tables.
    puf_normal_init(&weights->type_embed, 1.0f, (*seed)++, stream);
    puf_normal_init(&weights->pos_embed, 1.0f, (*seed)++, stream);

    // Kaiming-initialises `target`'s storage through a temporary rows x cols view.
    auto kaiming_2d = [&](PrecisionTensor& target, int rows, int cols, float gain) {
        PrecisionTensor view = {.data = target.data, .shape = {rows, cols}};
        puf_kaiming_init(&view, gain, (*seed)++, stream);
    };
    kaiming_2d(weights->w1, wide, BOX_EMBED_FLAT, 1.0f);
    kaiming_2d(weights->w2, hidden, wide, 1.0f);
    kaiming_2d(weights->w3, hidden, hidden, 1.0f);

    // Biases start at zero; a byte-wise memset is valid for an all-zero fill.
    PrecisionTensor* biases[] = {&weights->b1, &weights->b2, &weights->b3};
    for (PrecisionTensor* bias : biases) {
        cudaMemsetAsync(bias->data, 0, numel(bias->shape) * sizeof(precision_t), stream);
    }
}
737+
738+
// Declares the shape of every encoder parameter and registers each tensor
// with `alloc`. Registration order: type_embed, pos_embed, then (weight, bias)
// for layers 1 to 3.
//
//   w     - BoxobanEncoderWeights* with `hidden` already set
//   alloc - allocator that receives the parameter tensors
static void boxoban_encoder_reg_params(void* w, Allocator* alloc) {
    BoxobanEncoderWeights* weights = static_cast<BoxobanEncoderWeights*>(w);
    const int hidden = weights->hidden;
    const int wide = 2 * hidden;

    // Shapes.
    weights->type_embed = {.shape = {BOX_NUM_TYPES, BOX_EMBED_DIM}};
    weights->pos_embed = {.shape = {BOX_NUM_CELLS, BOX_EMBED_DIM}};
    weights->w1 = {.shape = {wide, BOX_EMBED_FLAT}};
    weights->b1 = {.shape = {wide}};
    weights->w2 = {.shape = {hidden, wide}};
    weights->b2 = {.shape = {hidden}};
    weights->w3 = {.shape = {hidden, hidden}};
    weights->b3 = {.shape = {hidden}};

    // Registration.
    alloc_register(alloc, &weights->type_embed);
    alloc_register(alloc, &weights->pos_embed);
    alloc_register(alloc, &weights->w1);
    alloc_register(alloc, &weights->b1);
    alloc_register(alloc, &weights->w2);
    alloc_register(alloc, &weights->b2);
    alloc_register(alloc, &weights->w3);
    alloc_register(alloc, &weights->b3);
}
754+
755+
// Registers every tensor needed for training: forward activations, the saved
// observations, intermediate gradients, parameter gradients and the float
// accumulators for the embedding gradients.
//
//   w           - BoxobanEncoderWeights* supplying `hidden` and `obs_size`
//   activations - BoxobanEncoderActivations*; reset to all-zero before use
//   acts        - allocator for activations, intermediate gradients and the
//                 float accumulators
//   grads       - allocator for the parameter gradients
//   B_TT        - leading dimension of every per-sample tensor
static void boxoban_encoder_reg_train(void* w, void* activations, Allocator* acts, Allocator* grads, int B_TT) {
    BoxobanEncoderWeights* weights = static_cast<BoxobanEncoderWeights*>(w);
    BoxobanEncoderActivations* state = static_cast<BoxobanEncoderActivations*>(activations);
    const int hidden = weights->hidden;
    const int wide = 2 * hidden;

    *state = {};

    // Forward activations, saved input and intermediate gradients.
    state->cell_flat = {.shape = {B_TT, BOX_EMBED_FLAT}};
    state->l1_preact = {.shape = {B_TT, wide}};
    state->l1_act = {.shape = {B_TT, wide}};
    state->l2_preact = {.shape = {B_TT, hidden}};
    state->l2_act = {.shape = {B_TT, hidden}};
    state->l3_preact = {.shape = {B_TT, hidden}};
    state->out = {.shape = {B_TT, hidden}};
    state->saved_obs = {.shape = {B_TT, weights->obs_size}};
    state->grad_cell_flat = {.shape = {B_TT, BOX_EMBED_FLAT}};
    state->grad_l1 = {.shape = {B_TT, wide}};
    state->grad_l2 = {.shape = {B_TT, hidden}};

    alloc_register(acts, &state->cell_flat);
    alloc_register(acts, &state->l1_preact);
    alloc_register(acts, &state->l1_act);
    alloc_register(acts, &state->l2_preact);
    alloc_register(acts, &state->l2_act);
    alloc_register(acts, &state->l3_preact);
    alloc_register(acts, &state->out);
    alloc_register(acts, &state->saved_obs);
    alloc_register(acts, &state->grad_cell_flat);
    alloc_register(acts, &state->grad_l1);
    alloc_register(acts, &state->grad_l2);

    // Parameter gradients mirror the parameter shapes.
    state->type_embed_wgrad = {.shape = {BOX_NUM_TYPES, BOX_EMBED_DIM}};
    state->pos_embed_wgrad = {.shape = {BOX_NUM_CELLS, BOX_EMBED_DIM}};
    state->w1_wgrad = {.shape = {wide, BOX_EMBED_FLAT}};
    state->b1_wgrad = {.shape = {wide}};
    state->w2_wgrad = {.shape = {hidden, wide}};
    state->b2_wgrad = {.shape = {hidden}};
    state->w3_wgrad = {.shape = {hidden, hidden}};
    state->b3_wgrad = {.shape = {hidden}};

    // Float accumulators used by the atomic embedding-gradient kernels.
    state->type_embed_wgrad_f = {.shape = {BOX_NUM_TYPES, BOX_EMBED_DIM}};
    state->pos_embed_wgrad_f = {.shape = {BOX_NUM_CELLS, BOX_EMBED_DIM}};

    alloc_register(grads, &state->type_embed_wgrad);
    alloc_register(grads, &state->pos_embed_wgrad);
    alloc_register(grads, &state->w1_wgrad);
    alloc_register(grads, &state->b1_wgrad);
    alloc_register(grads, &state->w2_wgrad);
    alloc_register(grads, &state->b2_wgrad);
    alloc_register(grads, &state->w3_wgrad);
    alloc_register(grads, &state->b3_wgrad);

    // The accumulators live with the activations, not the parameter gradients.
    alloc_register(acts, &state->type_embed_wgrad_f);
    alloc_register(acts, &state->pos_embed_wgrad_f);
}
797+
798+
// Registers the forward-only activation tensors used during rollouts.
// No saved observations, gradients or accumulators are created.
//
//   w           - BoxobanEncoderWeights* supplying `hidden`
//   activations - BoxobanEncoderActivations*; reset to all-zero before use
//   alloc       - allocator that receives the activation tensors
//   B           - leading dimension of every activation tensor
static void boxoban_encoder_reg_rollout(void* w, void* activations, Allocator* alloc, int B) {
    BoxobanEncoderWeights* ew = (BoxobanEncoderWeights*)w;
    BoxobanEncoderActivations* a = (BoxobanEncoderActivations*)activations;

    // Clear the struct as boxoban_encoder_reg_train does. The forward pass uses
    // `saved_obs.data != nullptr` to decide whether to copy the observations, so
    // that field must be null here rather than whatever the buffer held.
    *a = {};

    a->cell_flat = {.shape = {B, BOX_EMBED_FLAT}};
    a->l1_preact = {.shape = {B, 2 * ew->hidden}};
    a->l1_act = {.shape = {B, 2 * ew->hidden}};
    a->l2_preact = {.shape = {B, ew->hidden}};
    a->l2_act = {.shape = {B, ew->hidden}};
    a->l3_preact = {.shape = {B, ew->hidden}};
    a->out = {.shape = {B, ew->hidden}};

    alloc_register(alloc, &a->cell_flat);
    alloc_register(alloc, &a->l1_preact); alloc_register(alloc, &a->l1_act);
    alloc_register(alloc, &a->l2_preact); alloc_register(alloc, &a->l2_act);
    alloc_register(alloc, &a->l3_preact); alloc_register(alloc, &a->out);
}
813+
814+
// Encoder vtable hook: builds the Boxoban weights object from the encoder's
// input and output dimensions.
//
//   self - Encoder*; in_dim becomes obs_size and out_dim becomes hidden
//
// Returns the value produced by boxoban_encoder_create.
static void* boxoban_encoder_create_weights(void* self) {
    Encoder* encoder = static_cast<Encoder*>(self);
    const int obs_size = encoder->in_dim;
    const int hidden = encoder->out_dim;
    return boxoban_encoder_create(obs_size, hidden);
}
818+
819+
// Encoder vtable hook: releases the host-side weights object with free().
static void boxoban_encoder_free_weights(void* weights) {
    free(weights);
}
820+
// Encoder vtable hook: releases the host-side activations object with free().
static void boxoban_encoder_free_activations(void* activations) {
    free(activations);
}
821+
573822
// Override encoder vtable for known ocean environments. No-op for unknown envs.
574823
static void create_custom_encoder(const std::string& env_name, Encoder* enc) {
575824
if (env_name == "nmmo3") {
@@ -586,5 +835,19 @@ static void create_custom_encoder(const std::string& env_name, Encoder* enc) {
586835
.in_dim = enc->in_dim, .out_dim = enc->out_dim,
587836
.activation_size = sizeof(NMMO3EncoderActivations),
588837
};
838+
} else if (env_name == "boxoban") {
839+
*enc = Encoder{
840+
.forward = boxoban_encoder_forward,
841+
.backward = boxoban_encoder_backward,
842+
.init_weights = boxoban_encoder_init_weights,
843+
.reg_params = boxoban_encoder_reg_params,
844+
.reg_train = boxoban_encoder_reg_train,
845+
.reg_rollout = boxoban_encoder_reg_rollout,
846+
.create_weights = boxoban_encoder_create_weights,
847+
.free_weights = boxoban_encoder_free_weights,
848+
.free_activations = boxoban_encoder_free_activations,
849+
.in_dim = enc->in_dim, .out_dim = enc->out_dim,
850+
.activation_size = sizeof(BoxobanEncoderActivations),
851+
};
589852
}
590853
}

0 commit comments

Comments
 (0)