Skip to content

Commit f6a0f86

Browse files
committed
refactor color adjustment to use batch processing for improved performance
1 parent 05a4d74 commit f6a0f86

File tree

4 files changed

+64
-11
lines changed

4 files changed

+64
-11
lines changed

ajet/context_tracker/multiagent_tracking.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
)
1818
from ajet.schema.extended_msg import INVALID_LOG_PROB_VALUE
1919
from ajet.schema.trajectory import Reward
20-
from ajet.utils.color_hsl import adjust_color_hsl
20+
from ajet.utils.color_hsl import adjust_color_hsl_batch
2121
from ajet.utils.compute_madness import compute_string_madness
2222
from ajet.utils.tokenizer import ajet_apply_chat_template
2323

@@ -503,15 +503,9 @@ def generate_log(self, task_id=None, global_step="NA"):
503503
logprobs = [INVALID_LOG_PROB_VALUE] * len(
504504
tracker_tokenized["prompt_ids"]
505505
) + tracker_tokenized["response_logprobs"]
506-
# Create adjusted color array
507-
loss_mask_color_abl_arr = [
508-
(
509-
adjust_color_hsl("#09ABCF", logprob)
510-
if mask == 1
511-
else adjust_color_hsl("#D98510", logprob)
512-
)
513-
for mask, logprob in zip(tracker_tokenized["loss_mask"], logprobs)
514-
]
506+
# Create adjusted color array using batch processing for better performance
507+
base_colors = ["#09ABCF" if mask == 1 else "#D98510" for mask in tracker_tokenized["loss_mask"]]
508+
loss_mask_color_abl_arr = adjust_color_hsl_batch(base_colors, logprobs)
515509
logprob_text_arr = [
516510
(f"{logprob:.4f}" if logprob != INVALID_LOG_PROB_VALUE else "N/A")
517511
for logprob in logprobs

ajet/copilot/train-complex-blackbox/SKILL.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ Finally, you can start training.
155155

156156
Run `ajet-swarm start` to start training server (if the user has already installed agentjet swarm environment),
157157
if the user has docker environment, you can also refer to `docs/en/ajet-swarm-docker.md` to start a AgentSwarm docker container.
158+
If the user can provider the ssh connection to the GPU server / cluster, you can send the `ajet-swarm start` command to the remote server via ssh to start the swarm server, the port forward `10086` port (default agentjet swarm port) to user local machine.
158159

159160
Create a duplication of `agent_roll.py` named `agent_roll_one_episode_debug.py`, and modify it to only run one episode, this can help you debug whether the episode runner and reward function work as expected.
160161

ajet/tuner_lib/experimental/as_swarm_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ def start_engine(self):
437437
self._wait_until_status_change_to(desired_status="ENGINE.ROLLING")
438438
logger.success("Training engine is now ROLLING and ready.")
439439

440-
def _wait_until_status_change_to(self, desired_status="ENGINE.ROLLING", verbose=True, timeout=1800):
440+
def _wait_until_status_change_to(self, desired_status="ENGINE.ROLLING", verbose=True, timeout=3600):
441441
"""
442442
Poll engine status until it reaches desired_status.
443443
Reports status every 5 seconds while waiting.

ajet/utils/color_hsl.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import colorsys
2+
import numpy as np
3+
from functools import lru_cache
24

35

6+
@lru_cache(maxsize=2048)
47
def adjust_color_hsl(base_color, logprob):
58
"""
69
Adjust color saturation using the HSL color space based on log probability.
@@ -39,3 +42,58 @@ def adjust_color_hsl(base_color, logprob):
3942

4043
# Convert back to hexadecimal
4144
return f"#{int(r_adjusted*255):02x}{int(g_adjusted*255):02x}{int(b_adjusted*255):02x}"
45+
46+
47+
def adjust_color_hsl_batch(base_colors, logprobs):
48+
"""
49+
Vectorized version of adjust_color_hsl for batch processing.
50+
Args:
51+
base_colors (list[str]): List of hexadecimal color strings.
52+
logprobs (list[float]): List of log probability values.
53+
Returns:
54+
list[str]: List of adjusted hexadecimal color strings.
55+
"""
56+
if not base_colors or not logprobs:
57+
return []
58+
59+
# Constants
60+
sat_min = 0.333
61+
sat_max = 1.0
62+
lp_min = -7
63+
lp_max = 0
64+
65+
# Convert to numpy arrays for vectorized operations
66+
logprobs_arr = np.array(logprobs, dtype=np.float32)
67+
68+
# Vectorized saturation factor calculation
69+
saturation_factors = np.where(
70+
logprobs_arr <= lp_min,
71+
sat_min,
72+
np.where(
73+
logprobs_arr >= 0,
74+
sat_max,
75+
sat_min + (logprobs_arr - lp_min) / (lp_max - lp_min) * (sat_max - sat_min)
76+
)
77+
)
78+
79+
# Pre-convert unique base colors to RGB and HSL
80+
unique_colors = list(set(base_colors))
81+
color_to_hls = {}
82+
83+
for color in unique_colors:
84+
r = int(color[1:3], 16) / 255.0
85+
g = int(color[3:5], 16) / 255.0
86+
b = int(color[5:7], 16) / 255.0
87+
h, l, s = colorsys.rgb_to_hls(r, g, b)
88+
color_to_hls[color] = (h, l, s)
89+
90+
# Process each color
91+
result = []
92+
for base_color, sat_factor in zip(base_colors, saturation_factors):
93+
h, l, s = color_to_hls[base_color]
94+
s_adjusted = s * sat_factor
95+
r_adj, g_adj, b_adj = colorsys.hls_to_rgb(h, l, s_adjusted)
96+
hex_color = f"#{int(r_adj*255):02x}{int(g_adj*255):02x}{int(b_adj*255):02x}"
97+
result.append(hex_color)
98+
99+
return result

0 commit comments

Comments
 (0)