Skip to content

Commit befe42d

Browse files
Add printoptions (#3333)
1 parent 80a1c20 commit befe42d

9 files changed

Lines changed: 174 additions & 13 deletions

File tree

docs/src/index.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ are the CPU and GPU.
3232
install
3333

3434
.. toctree::
35-
:caption: Usage
35+
:caption: Usage
3636
:maxdepth: 1
3737

3838
usage/quick_start
@@ -78,6 +78,7 @@ are the CPU and GPU.
7878
python/optimizers
7979
python/distributed
8080
python/tree_utils
81+
python/printoptions
8182

8283
.. toctree::
8384
:caption: C++ API Reference

docs/src/python/printoptions.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
Print Options
2+
===============
3+
4+
.. currentmodule:: mlx.core
5+
6+
.. autosummary::
7+
:toctree: _autosummary
8+
9+
PrintOptions
10+
set_printoptions
11+
printoptions
12+
get_printoptions

mlx/utils.cpp

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright © 2023 Apple Inc.
22

33
#include <cstdlib>
4+
#include <iomanip>
45
#include <iostream>
56
#include <sstream>
67
#include <vector>
@@ -57,23 +58,51 @@ inline void PrintFormatter::print(std::ostream& os, uint64_t val) {
5758
os << val;
5859
}
5960
inline void PrintFormatter::print(std::ostream& os, float16_t val) {
60-
os << val;
61+
if (format_options.precision == -1) {
62+
os << val;
63+
} else {
64+
os << std::fixed << std::setprecision(format_options.precision) << val;
65+
}
6166
}
6267
inline void PrintFormatter::print(std::ostream& os, bfloat16_t val) {
63-
os << val;
68+
if (format_options.precision == -1) {
69+
os << val;
70+
} else {
71+
os << std::fixed << std::setprecision(format_options.precision) << val;
72+
}
6473
}
6574
inline void PrintFormatter::print(std::ostream& os, float val) {
66-
os << val;
75+
if (format_options.precision == -1) {
76+
os << val;
77+
} else {
78+
os << std::fixed << std::setprecision(format_options.precision) << val;
79+
}
6780
}
6881
inline void PrintFormatter::print(std::ostream& os, double val) {
69-
os << val;
82+
if (format_options.precision == -1) {
83+
os << val;
84+
} else {
85+
os << std::fixed << std::setprecision(format_options.precision) << val;
86+
}
7087
}
7188
inline void PrintFormatter::print(std::ostream& os, complex64_t val) {
72-
os << val.real();
73-
if (val.imag() >= 0 || std::isnan(val.imag())) {
74-
os << "+" << val.imag() << "j";
89+
if (format_options.precision == -1) {
90+
os << val.real();
91+
if (val.imag() >= 0 || std::isnan(val.imag())) {
92+
os << "+" << val.imag() << "j";
93+
} else {
94+
os << "-" << -val.imag() << "j";
95+
}
7596
} else {
76-
os << "-" << -val.imag() << "j";
97+
os << std::fixed << std::setprecision(format_options.precision)
98+
<< val.real();
99+
if (val.imag() >= 0 || std::isnan(val.imag())) {
100+
os << "+" << std::fixed << std::setprecision(format_options.precision)
101+
<< val.imag() << "j";
102+
} else {
103+
os << "-" << std::fixed << std::setprecision(format_options.precision)
104+
<< -val.imag() << "j";
105+
}
77106
}
78107
}
79108

@@ -82,6 +111,11 @@ PrintFormatter& get_global_formatter() {
82111
return formatter;
83112
}
84113

114+
void set_printoptions(PrintOptions options) {
115+
auto& formatter = get_global_formatter();
116+
formatter.format_options = options;
117+
}
118+
85119
void abort_with_exception(const std::exception& error) {
86120
std::ostringstream msg;
87121
msg << "Terminating due to uncaught exception: " << error.what();

mlx/utils.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ struct StreamContext {
3838
Stream _stream;
3939
};
4040

41+
struct MLX_API PrintOptions {
42+
int precision{-1};
43+
};
44+
4145
struct PrintFormatter {
4246
inline void print(std::ostream& os, bool val);
4347
inline void print(std::ostream& os, int16_t val);
@@ -53,8 +57,11 @@ struct PrintFormatter {
5357
inline void print(std::ostream& os, complex64_t val);
5458

5559
bool capitalize_bool{false};
60+
PrintOptions format_options;
5661
};
5762

63+
MLX_API void set_printoptions(PrintOptions options);
64+
5865
MLX_API PrintFormatter& get_global_formatter();
5966

6067
/** Print the exception and then abort. */

python/src/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ nanobind_add_module(
2727
${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp
2828
${CMAKE_CURRENT_SOURCE_DIR}/constants.cpp
2929
${CMAKE_CURRENT_SOURCE_DIR}/trees.cpp
30-
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
30+
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
31+
${CMAKE_CURRENT_SOURCE_DIR}/print.cpp)
3132

3233
if(MLX_BUILD_PYTHON_STUBS)
3334
nanobind_add_stub(

python/src/array.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <nanobind/typing.h>
1313

1414
#include "mlx/backend/metal/metal.h"
15+
#include "mlx/utils.h"
1516
#include "python/src/buffer.h"
1617
#include "python/src/convert.h"
1718
#include "python/src/indexing.h"
@@ -97,9 +98,6 @@ class ArrayPythonIterator {
9798
};
9899

99100
void init_array(nb::module_& m) {
100-
// Set Python print formatting options
101-
mx::get_global_formatter().capitalize_bool = true;
102-
103101
// Types
104102
nb::class_<mx::Dtype>(
105103
m,

python/src/mlx.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ void init_constants(nb::module_&);
2323
void init_fast(nb::module_&);
2424
void init_distributed(nb::module_&);
2525
void init_export(nb::module_&);
26+
void init_print(nb::module_&);
2627

2728
NB_MODULE(core, m) {
2829
m.doc() = "mlx: A framework for machine learning on Apple silicon.";
@@ -46,6 +47,7 @@ NB_MODULE(core, m) {
4647
init_fast(m);
4748
init_distributed(m);
4849
init_export(m);
50+
init_print(m);
4951

5052
m.attr("__version__") = mx::version();
5153
}

python/src/print.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
#include <cstdint>
2+
#include <cstring>
3+
#include <sstream>
4+
5+
#include <nanobind/typing.h>
6+
7+
#include "mlx/utils.h"
8+
#include "python/src/utils.h"
9+
10+
#include "mlx/mlx.h"
11+
12+
namespace mx = mlx::core;
13+
namespace nb = nanobind;
14+
using namespace nb::literals;
15+
16+
struct PrintOptionsContext {
17+
mx::PrintOptions old_options;
18+
mx::PrintOptions new_options;
19+
PrintOptionsContext(mx::PrintOptions p) : new_options(p) {}
20+
PrintOptionsContext& enter() {
21+
old_options = mx::get_global_formatter().format_options;
22+
mx::set_printoptions(new_options);
23+
return *this;
24+
}
25+
void exit(nb::args) {
26+
mx::set_printoptions(old_options);
27+
}
28+
};
29+
30+
void init_print(nb::module_& m) {
31+
// Set Python print formatting options
32+
mx::get_global_formatter().capitalize_bool = true;
33+
// Expose printing options to Python: allow setting global precision.
34+
nb::class_<mx::PrintOptions>(m, "PrintOptions")
35+
.def(nb::init<int>(), "precision"_a = -1)
36+
.def_rw("precision", &mx::PrintOptions::precision);
37+
38+
m.def(
39+
"set_printoptions",
40+
[](int precision) { mx::set_printoptions({precision}); },
41+
"precision"_a = mx::get_global_formatter().format_options.precision,
42+
R"pbdoc(
43+
Set global printing precision for array formatting.
44+
45+
Example:
46+
>>> print(x) # Uses default precision
47+
>>> mx.set_printoptions(precision=3):
48+
>>> print(x) # Uses precision of 3
49+
>>> print(x) # Uses precision of 3 (again)
50+
51+
Args:
52+
precision (int): Number of decimal places.
53+
)pbdoc");
54+
m.def(
55+
"get_printoptions",
56+
[]() { return mx::get_global_formatter().format_options; },
57+
R"pbdoc(
58+
Get global printing precision for array formatting.
59+
60+
Returns:
61+
PrintOptions: The format options used for printing arrays.
62+
)pbdoc");
63+
64+
nb::class_<PrintOptionsContext>(m, "_PrintOptionsContext")
65+
.def(nb::init<mx::PrintOptions>())
66+
.def("__enter__", &PrintOptionsContext::enter)
67+
.def("__exit__", &PrintOptionsContext::exit);
68+
69+
m.def(
70+
"printoptions",
71+
[](int precision) { return PrintOptionsContext({precision}); },
72+
"precision"_a = mx::get_global_formatter().format_options.precision,
73+
R"pbdoc(
74+
Context manager for setting print options temporarily.
75+
76+
Example:
77+
>>> print(x) # Uses default precision
78+
>>> with mx.printoptions(precision=3):
79+
>>> print(x) # Uses precision of 3
80+
>>> print(x) # Back to default precision
81+
82+
83+
Args:
84+
precision (int): Number of decimal places. Use -1 for default
85+
)pbdoc");
86+
}

python/tests/test_array.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,26 @@ def test_array_repr(self):
597597
x = mx.array([1 - 1j], dtype=mx.complex64)
598598
expected = "array([1-1j], dtype=complex64)"
599599

600+
def test_array_repr_precision(self):
601+
x = mx.array([1.123456789], dtype=mx.float32)
602+
expected = "array([1.12346], dtype=float32)"
603+
self.assertEqual(str(x), expected)
604+
605+
with mx.printoptions(precision=4):
606+
expected = "array([1.1235], dtype=float32)"
607+
self.assertEqual(str(x), expected)
608+
mx.set_printoptions(precision=2)
609+
expected = "array([1.12], dtype=float32)"
610+
self.assertEqual(str(x), expected)
611+
612+
x = mx.sin(x)
613+
expected = "array([0.90], dtype=float32)"
614+
self.assertEqual(str(x), expected)
615+
616+
with mx.printoptions(precision=4):
617+
expected = "array([0.9016], dtype=float32)"
618+
self.assertEqual(str(x), expected)
619+
600620
def test_array_to_list(self):
601621
types = [mx.bool_, mx.uint32, mx.int32, mx.int64, mx.float32]
602622
for t in types:

0 commit comments

Comments
 (0)