|
| 1 | + |
| 2 | +.. _add_custom_plot: |
| 3 | + |
| 4 | +Add a custom plot to a benchmark |
| 5 | +================================ |
| 6 | + |
| 7 | +Benchopt provides a set of default plots to visualize the results of a benchmark. |
| 8 | +These plots can be complemented with custom plots, defined in the benchmark, |
| 9 | +to visualize the results in a different way. These plots are defined in the |
| 10 | +:code:`plots` directory, by adding python files with classes inheriting from |
| 11 | +:class:`benchopt.BasePlot`. This page details the API to generate custom |
| 12 | +visualizations for your benchmark. |
| 13 | + |
| 14 | +Structure of a custom plot |
| 15 | +-------------------------- |
| 16 | + |
| 17 | +A custom plot is defined by a class inheriting from :class:`benchopt.BasePlot` and implementing: |
| 18 | + |
| 19 | +- :code:`name`: The name of the plot title. This will be the name that appears |
| 20 | + in the plot selection menu of the HTML interface, or the name you can use to |
| 21 | + select this plot in config files for your benchmark. |
| 22 | +- :code:`type`: The type of the plot, which defines how the output of `plot` |
| 23 | + will be rendered. Supported types are :code:`"scatter"`, :code:`"bar_chart"`, |
| 24 | + :code:`"boxplot"` and :code:`"table"`. |
| 25 | +- :code:`options`: A dictionary defining the different options available for the |
| 26 | + plot. Typically, this can be used to have different plots depending on dataset's |
| 27 | + or objective's parameters, or to display customization options. The keys in the |
| 28 | + dictionary are the names of the options, associated to a list of their possible |
| 29 | + values. If a key :code:`objective/dataset/solver/objective_column` is associated |
| 30 | + with the value :code:`...`, the options are automatically inferred from the |
| 31 | + results DataFrame, as all unique values associated with this key. |
| 32 | +- :code:`plot(self, df, **kwargs)`: give the data to produce one plot, that is |
| 33 | + rendered with the plotly or matplotlib backend. The method takes the results DataFrame |
| 34 | + and the options values as arguments, and returns the plot data. The output |
| 35 | + depends on the plot's type, and are detailed below for each of them. |
| 36 | +- :code:`get_metadata(self, df, **kwargs)`: Gives global information about the plot, such |
| 37 | + as the title and axis labels. The method takes the results DataFrame and the options |
| 38 | + values as arguments, and returns the metadata of the plot, which is specific to each |
| 39 | + plot type. |
| 40 | + |
| 41 | +The :code:`get_metadata` method allow to change global properties of |
| 42 | +the resulting visualization, and the :code:`plot` method outputs the data |
| 43 | +necessary to render it. |
| 44 | +The visualization is rendered using either the ``plotly`` or ``matplotly`` backend. |
| 45 | + |
| 46 | +.. code-block:: python |
| 47 | +
|
| 48 | + from benchopt import BasePlot |
| 49 | +
|
| 50 | + class Plot(BasePlot): |
| 51 | + name = "My Custom Plot" |
| 52 | + type = "scatter" # or "bar_chart", "boxplot" or "table" |
| 53 | + options = { |
| 54 | + "dataset": ..., # Automatic options from DataFrame columns |
| 55 | + "objective": ..., |
| 56 | + "my_parameter": [1, 2], # custom options |
| 57 | + } |
| 58 | +
|
| 59 | + # The inputs args of this method correspond to `df` and |
| 60 | + # the keys in the `options` dictionary. |
| 61 | + def plot(self, df, dataset, objective, my_parameter): |
| 62 | + # ... process df ... |
| 63 | + return plot_data |
| 64 | +
|
| 65 | + def get_metadata(self, df, dataset, objective, my_parameter): |
| 66 | + return { |
| 67 | + "title": f"Plot for {dataset}", |
| 68 | + "xlabel": "X Label", |
| 69 | + "ylabel": "Y Label", |
| 70 | + } |
| 71 | +
|
| 72 | +
|
| 73 | +Plot Options |
| 74 | +------------ |
| 75 | + |
| 76 | +The :code:`options` dictionary keys define the arguments passed to |
| 77 | +:code:`plot` and :code:`get_metadata`. Special keys like |
| 78 | +:code:`dataset`, :code:`objective`, :code:`solver` will automatically |
| 79 | +try to match columns in the dataframe. Using :code:`...` as a value |
| 80 | +will populate the options with all unique values from the dataframe |
| 81 | +column :code:`{key}_name` (e.g. :code:`dataset_name`). |
| 82 | + |
| 83 | + |
| 84 | +Scatter Plot |
| 85 | +------------ |
| 86 | + |
| 87 | +For a scatter plot, the :code:`plot` method should return a list of dictionaries, where |
| 88 | +each dictionary represents a trace in the plot. Each dictionary must contain: |
| 89 | + |
| 90 | +- :code:`x`: A list of x values. |
| 91 | +- :code:`y`: A list of y values. |
| 92 | +- :code:`label`: The label of the trace |
| 93 | + |
| 94 | +Optional keys: |
| 95 | + |
| 96 | +- :code:`color`: The color of the trace. |
| 97 | +- :code:`marker`: The marker style of the trace. |
| 98 | +- :code:`x_low`, :code:`x_high`: Lists of values to display uncertainty in the plot. |
| 99 | + They will be used to display shaded area around the plot. |
| 100 | + |
| 101 | +The metadata dictionary returned by :code:`get_metadata` should contain: |
| 102 | + |
| 103 | +- :code:`title`: The title of the plot. |
| 104 | +- :code:`xlabel`: The label of the x-axis. |
| 105 | +- :code:`ylabel`: The label of the y-axis. |
| 106 | + |
| 107 | +.. code-block:: python |
| 108 | +
|
| 109 | + def plot(self, df, dataset, objective, my_parameter): |
| 110 | + # Filter the dataframe |
| 111 | + df = df.query( |
| 112 | + "dataset_name == @dataset and objective_name == @objective" |
| 113 | + ) |
| 114 | +
|
| 115 | + plot_traces = [] |
| 116 | + for solver, df_solver in df.groupby('solver_name'): |
| 117 | + # Compute the median over the repetitions |
| 118 | + curve = ( |
| 119 | + df_solver.groupby("stop_val")[["time", "'objective_value"]] |
| 120 | + .median() |
| 121 | + ) |
| 122 | + plot_traces.append({ |
| 123 | + "x": curve['time'].tolist(), |
| 124 | + "y": curve['objective_value'].tolist(), |
| 125 | + "label": solver, |
| 126 | + **self.get_style(solver) |
| 127 | + }) |
| 128 | + return plot_traces |
| 129 | +
|
| 130 | + def get_metadata(self, df, dataset, objective, my_parameter): |
| 131 | + return { |
| 132 | + "title": f"Convergence for {dataset}", |
| 133 | + "xlabel": "Time [sec]", |
| 134 | + "ylabel": "Objective value", |
| 135 | + } |
| 136 | +
|
| 137 | +.. note:: |
| 138 | + To help with consistent style accross figures, you can use |
| 139 | + the helper ``get_style``, as described in :ref:`plotting_utilities`. |
| 140 | + |
| 141 | + |
| 142 | +Bar Chart |
| 143 | +--------- |
| 144 | + |
| 145 | +For a bar chart, the :code:`plot` method should return a list of dictionaries, |
| 146 | +where each dictionary represents a bar. For each bar, the median value will be |
| 147 | +used to determine its height, while the individual values will be displayed as |
| 148 | +scatter points. The dictionary should contain: |
| 149 | + |
| 150 | +- :code:`y`: The list of values for the bar (the median will be the height of the bar). |
| 151 | +- :code:`label`: The label of the bar. |
| 152 | + |
| 153 | +Optional keys: |
| 154 | + |
| 155 | +- :code:`color`: The color of the bar. |
| 156 | +- :code:`text`: The text to display on the bar. |
| 157 | + |
| 158 | +The metadata dictionary returned by :code:`get_metadata` should contain: |
| 159 | + |
| 160 | +- :code:`title`: The title of the plot. |
| 161 | +- :code:`ylabel`: The label of the y-axis. |
| 162 | + |
| 163 | +.. code-block:: python |
| 164 | +
|
| 165 | + def plot(self, df, dataset, objective, **kwargs): |
| 166 | + df = df.query( |
| 167 | + "dataset_name == @dataset and objective_name == @objective" |
| 168 | + ) |
| 169 | + bars = [] |
| 170 | + for solver, df_solver in df.groupby('solver_name'): |
| 171 | + # Select the total runtime for each repetition |
| 172 | + runtimes = df_solver.groupby("idx_rep")["runtime"].last() |
| 173 | + bars.append({ |
| 174 | + "y": runtimes.tolist(), |
| 175 | + "label": solver, |
| 176 | + "text": "", |
| 177 | + "color": self.get_style(solver)['color'] |
| 178 | + }) |
| 179 | + return bars |
| 180 | +
|
| 181 | + def get_metadata(self, df, dataset, objective, **kwargs): |
| 182 | + return { |
| 183 | + "title": f"Average times for {objective} on {dataset}", |
| 184 | + "ylabel": "Time [sec]", |
| 185 | + } |
| 186 | +
|
| 187 | +
|
| 188 | +Box Plot |
| 189 | +-------- |
| 190 | + |
| 191 | +For a box plot, the :code:`plot` method should return a list of dictionaries, |
| 192 | +where each dictionary represents a box. Each dictionary should contain: |
| 193 | + |
| 194 | +- :code:`x`: The x coordinate. |
| 195 | +- :code:`y`: The values of the box for the corresponding x coordinate. |
| 196 | +- :code:`label`: The label of the box. |
| 197 | + |
| 198 | +Optional keys: |
| 199 | + |
| 200 | +- :code:`color`: The color of the box. |
| 201 | + |
| 202 | +The metadata dictionary returned by :code:`get_metadata` should contain: |
| 203 | + |
| 204 | +- :code:`title`: The title of the plot. |
| 205 | +- :code:`xlabel`: The label of the x-axis. |
| 206 | +- :code:`ylabel`: The label of the y-axis. |
| 207 | + |
| 208 | +.. code-block:: python |
| 209 | +
|
| 210 | + def plot(self, df, dataset, objective, **kwargs): |
| 211 | + df = df.query( |
| 212 | + "dataset_name == @dataset and objective_name == @objective" |
| 213 | + ) |
| 214 | + plot_data = [] |
| 215 | + for solver, df_solver in df.groupby('solver_name'): |
| 216 | + # Example: boxplot for the final objective values |
| 217 | + # for each solver |
| 218 | + final_objective_value = ( |
| 219 | + df_solver.groupby("idx_rep")['objective_value'].last() |
| 220 | + ) |
| 221 | + plot_data.append({ |
| 222 | + "x": [solver], |
| 223 | + "y": [final_objective_value.tolist()], |
| 224 | + "label": solver, |
| 225 | + "color": self.get_style(solver)['color'] |
| 226 | + }) |
| 227 | + return plot_data |
| 228 | +
|
| 229 | + def get_metadata(self, df, dataset, objective, **kwargs): |
| 230 | + return { |
| 231 | + "title": f"Boxplot for {objective} on {dataset}", |
| 232 | + "xlabel": "Solver", |
| 233 | + "ylabel": "Objective value", |
| 234 | + } |
| 235 | +
|
| 236 | +
|
| 237 | +Table Plot |
| 238 | +---------- |
| 239 | + |
| 240 | +For a table plot, the :code:`plot` method should return a list of lists, |
| 241 | +where each inner list represents a row in the table. |
| 242 | +The metadata dictionary returned by :code:`get_metadata` should contain: |
| 243 | + |
| 244 | +- :code:`title`: The title of the plot. |
| 245 | +- :code:`columns`: A list of column names. |
| 246 | + |
| 247 | +.. code-block:: python |
| 248 | +
|
| 249 | + def plot(self, df, dataset, objective, **kwargs): |
| 250 | + df = df.query( |
| 251 | + "dataset_name == @dataset and objective_name == @objective" |
| 252 | + ) |
| 253 | + rows = [] |
| 254 | + for solver, df_solver in df.groupby('solver_name'): |
| 255 | + # Example: table with solver name and mean time |
| 256 | + # when using `sampling_strategy = 'run_once'` |
| 257 | + rows.append([solver, df_solver['time'].mean()]) |
| 258 | + return rows |
| 259 | +
|
| 260 | + def get_metadata(self, df, dataset, objective, **kwargs): |
| 261 | + return { |
| 262 | + "title": f"Summary for {dataset}", |
| 263 | + "columns": ["Solver", "Mean Time [sec]"], |
| 264 | + } |
| 265 | +
|
| 266 | +
|
| 267 | +.. _plotting_utilities: |
| 268 | +Plotting Utilities |
| 269 | +------------------ |
| 270 | + |
| 271 | +To ensure consistency across plots (e.g., using the same color and marker for a |
| 272 | +given solver), :class:`benchopt.BasePlot` provides the helper method |
| 273 | +:code:`get_style(label)`. This method returns a dictionary with :code:`color` |
| 274 | +and :code:`marker` keys, which can be directly unpacked into the trace dictionary |
| 275 | +for scatter plots or used to select the color for other plot types. It |
| 276 | +automatically assigns a color from the default color palette based on the hash |
| 277 | +of the label, ensuring that the same solver always gets the same color. |
| 278 | + |
| 279 | +.. code-block:: python |
| 280 | +
|
| 281 | + # Usage in plot() |
| 282 | + style = self.get_style(solver_name) |
| 283 | + trace = { |
| 284 | + # ... |
| 285 | + "color": style["color"], |
| 286 | + "marker": style["marker"] |
| 287 | + } |
| 288 | + # Or simply: |
| 289 | + trace = { |
| 290 | + # ... |
| 291 | + **self.get_style(solver_name) |
| 292 | + } |
| 293 | +
|
0 commit comments