Skip to content

Commit 3c2cc9f

Browse files
committed
Rewrite gpu collision for spritelists
1 parent b213792 commit 3c2cc9f

3 files changed

Lines changed: 107 additions & 50 deletions

File tree

arcade/resources/system/shaders/collision/col_trans_vs.glsl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#version 330
22
// A simple passthrough shader forwarding data to the geomtry shader
33

4-
in vec3 in_pos;
4+
in vec4 in_pos;
55
in vec2 in_size;
66

77
out vec2 pos;

arcade/sprite_list/collision.py

Lines changed: 1 addition & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -134,50 +134,7 @@ def _get_nearby_sprites(
134134
sprite_count = len(sprite_list)
135135
if sprite_count == 0:
136136
return []
137-
138-
# Update the position and size to check
139-
ctx = get_window().ctx
140-
sprite_list._write_sprite_buffers_to_gpu()
141-
142-
ctx.collision_detection_program["check_pos"] = sprite.position
143-
ctx.collision_detection_program["check_size"] = sprite.width, sprite.height
144-
145-
# Ensure the result buffer can fit all the sprites (worst case)
146-
buffer = ctx.collision_buffer
147-
if buffer.size < sprite_count * 4:
148-
buffer.orphan(size=sprite_count * 4)
149-
150-
# Run the transform shader emitting sprites close to the configured position and size.
151-
# This runs in a query so we can measure the number of sprites emitted.
152-
with ctx.collision_query:
153-
sprite_list.sprite_data.geometry.transform( # type: ignore
154-
ctx.collision_detection_program,
155-
buffer,
156-
vertices=sprite_count,
157-
)
158-
159-
# Store the number of sprites emitted
160-
emit_count = ctx.collision_query.primitives_generated
161-
# print(
162-
# emit_count,
163-
# ctx.collision_query.time_elapsed,
164-
# ctx.collision_query.time_elapsed / 1_000_000_000,
165-
# )
166-
167-
# If no sprites emitted we can just return an empty list
168-
if emit_count == 0:
169-
return []
170-
171-
# # Debug block for transform data to keep around
172-
# print("emit_count", emit_count)
173-
# data = buffer.read(size=emit_count * 4)
174-
# print("bytes", data)
175-
# print("data", struct.unpack(f'{emit_count}i', data))
176-
177-
# .. otherwise build and return a list of the sprites selected by the transform
178-
return [
179-
sprite_list[i] for i in struct.unpack(f"{emit_count}i", buffer.read(size=emit_count * 4))
180-
]
137+
return sprite_list.get_nearby_sprites_gpu(sprite.position, sprite.size)
181138

182139

183140
def check_for_collision_with_list(

arcade/sprite_list/sprite_list.py

Lines changed: 105 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from __future__ import annotations
99

1010
import random
11+
import struct
1112
from abc import abstractmethod
1213
from array import array
1314
from collections import deque
@@ -139,6 +140,23 @@ def draw_hit_boxes(
139140
"""
140141
...
141142

143+
@abstractmethod
144+
def get_nearby_sprites_gpu(self, pos: Point2, size: Point2) -> list[SpriteType_co]:
145+
"""
146+
Get a list of sprites that are nearby the given position and size
147+
using the gpu. No spatial hashing is needed. This is a very fast method
148+
to find nearby sprites in large spritelists but is very expensive
149+
if the method is called many times per frame or if the sprite list
150+
is small.
151+
152+
Args:
153+
pos: The position to check for nearby sprites.
154+
size: The size of the area to check for nearby sprites.
155+
Returns:
156+
A list of sprites nearby the given position and size.
157+
"""
158+
...
159+
142160
@abstractmethod
143161
def _write_sprite_buffers_to_gpu(self) -> None: ...
144162

@@ -305,13 +323,13 @@ def _init_deferred(self) -> None:
305323

306324
# NOTE: Instantiate the appropriate spritelist data class here
307325
# Desktop GL (with geo shader)
308-
# self._spritelist_data = SpriteListBufferData(
309-
# self.ctx, capacity=self._buf_capacity, atlas=self._atlas
310-
# )
311-
# WebGL (without geo shader)
312-
self._spritelist_data = SpriteListTextureData(
326+
self._spritelist_data = SpriteListBufferData(
313327
self.ctx, capacity=self._buf_capacity, atlas=self._atlas
314328
)
329+
# WebGL (without geo shader)
330+
# self._spritelist_data = SpriteListTextureData(
331+
# self.ctx, capacity=self._buf_capacity, atlas=self._atlas
332+
# )
315333

316334
self._initialized = True
317335

@@ -981,6 +999,30 @@ def draw_hit_boxes(
981999

9821000
arcade.draw_lines(points, color=converted_color, line_width=line_thickness)
9831001

1002+
def get_nearby_sprites_gpu(self, pos: Point2, size: Point2) -> list[SpriteType]:
1003+
"""
1004+
Get a list of sprites that are nearby the given position and size
1005+
using the gpu. No spatial hashing is needed. This is a very fast method
1006+
to find nearby sprites in large spritelists but is very expensive
1007+
if the method is called many times per frame or if the sprite list
1008+
is small.
1009+
1010+
Args:
1011+
pos: The position to check for nearby sprites.
1012+
size: The size of the area to check for nearby sprites.
1013+
Returns:
1014+
A list of sprites nearby the given position and size.
1015+
"""
1016+
if not self._initialized:
1017+
self._init_deferred()
1018+
1019+
if len(self.sprite_list) == 0:
1020+
return []
1021+
1022+
self._write_sprite_buffers_to_gpu()
1023+
indices = self._spritelist_data.get_nearby_sprite_indices(pos, size, len(self.sprite_list))
1024+
return [self.sprite_list[i] for i in indices]
1025+
9841026
def _grow_sprite_buffers(self) -> None:
9851027
"""Double the internal buffer sizes"""
9861028
# Resize sprite buffers if needed
@@ -1294,6 +1336,19 @@ def render(
12941336
"""
12951337
raise NotImplementedError("This method should be implemented in subclasses.")
12961338

1339+
def get_nearby_sprite_indices(self, pos: Point2, size: Point2, length: int) -> list[int]:
1340+
"""
1341+
Get indices of sprites that are nearby the given position and size.
1342+
1343+
Args:
1344+
pos: The position to check for nearby sprites.
1345+
size: The size of the area to check for nearby sprites.
1346+
length: The number of sprites in the list.
1347+
Returns:
1348+
A list of indices of nearby sprites.
1349+
"""
1350+
raise NotImplementedError("This method should be implemented in subclasses.")
1351+
12971352

12981353
class SpriteListBufferData(SpriteListData):
12991354
"""Container for all gpu data used by the SpriteList."""
@@ -1577,6 +1632,39 @@ def render(
15771632
if blend_function is not None:
15781633
self.ctx.blend_func = prev_blend_func
15791634

1635+
def get_nearby_sprite_indices(self, pos: Point2, size: Point2, length: int) -> list[int]:
1636+
"""
1637+
Get indices of sprites that are nearby the given position and size.
1638+
1639+
Args:
1640+
pos: The position to check for nearby sprites.
1641+
size: The size of the area to check for nearby sprites.
1642+
length: The number of sprites in the spritelist.
1643+
Returns:
1644+
A list of indices of nearby sprites.
1645+
"""
1646+
ctx = self.ctx
1647+
ctx.collision_detection_program["check_pos"] = pos
1648+
ctx.collision_detection_program["check_size"] = size
1649+
1650+
# Ensure the result buffer can fit all the sprites (worst case)
1651+
buffer = ctx.collision_buffer
1652+
# NOTE: Right now the limit is 1000 hits
1653+
# Run the transform shader emitting sprites close to the configured position and size.
1654+
# This runs in a query so we can measure the number of sprites emitted.
1655+
with ctx.collision_query:
1656+
self._geometry.transform( # type: ignore
1657+
ctx.collision_detection_program,
1658+
buffer,
1659+
vertices=length,
1660+
)
1661+
1662+
# Store the number of sprites emitted
1663+
emit_count = ctx.collision_query.primitives_generated
1664+
if emit_count == 0:
1665+
return []
1666+
return [i for i in struct.unpack(f"{emit_count}i", buffer.read(size=emit_count * 4))]
1667+
15801668

15811669
class SpriteListTextureData(SpriteListData):
15821670
"""Container for all gpu data used by the SpriteList without buffers."""
@@ -1753,3 +1841,15 @@ def render(
17531841
self.ctx.disable(self.ctx.BLEND)
17541842
if blend_function is not None:
17551843
self.ctx.blend_func = prev_blend_func
1844+
1845+
# def get_nearby_sprite_indices(self, pos: Point2, size: Point2, length: int) -> list[int]:
1846+
# """
1847+
# Get indices of sprites that are nearby the given position and size.
1848+
1849+
# Args:
1850+
# pos: The position to check for nearby sprites.
1851+
# size: The size of the area to check for nearby sprites.
1852+
# length: The number of sprites in the spritelist.
1853+
# Returns:
1854+
# A list of indices of nearby sprites.
1855+
# """

0 commit comments

Comments
 (0)