Skip to content

Commit 247a34d

Browse files
fix
1 parent cc6d1bc commit 247a34d

1 file changed

Lines changed: 17 additions & 17 deletions

File tree

tests/test_modules.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import contextlib
12
import inspect
3+
import logging
24

35
import pytest
46
import torch
@@ -8,6 +10,12 @@
810
from tests.helpers import get_available_devices, id_formatter, is_supported_on_hpu
911

1012

13+
@contextlib.contextmanager
14+
def caplog_at_level(caplog, level, logger_name):
15+
with caplog.at_level(level, logger=logger_name):
16+
yield
17+
18+
1119
class MockArgs:
1220
def __init__(self, initial_data):
1321
for key in initial_data:
@@ -453,46 +461,38 @@ def test_embedding_error(device, embedding_class, input_shape, embedding_dim, qu
453461

454462

455463
@pytest.mark.parametrize("device", get_available_devices())
456-
def test_4bit_linear_warnings(device):
464+
def test_4bit_linear_warnings(device, caplog):
457465
dim1 = 64
458466

459-
with pytest.warns(UserWarning, match=r"inference or training"):
460-
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4") for i in range(10)])
461-
net = net.to(device)
462-
inp = torch.rand(10, dim1, device=device, dtype=torch.float16)
463-
net(inp)
464-
with pytest.warns(UserWarning, match=r"inference."):
465-
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4") for i in range(10)])
466-
net = net.to(device)
467-
inp = torch.rand(1, dim1, device=device, dtype=torch.float16)
468-
net(inp)
469-
470-
with pytest.warns(UserWarning) as record:
467+
with caplog_at_level(caplog, logging.WARNING, "bitsandbytes.nn.modules"):
471468
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4") for i in range(10)])
472469
net = net.to(device)
473470
inp = torch.rand(10, dim1, device=device, dtype=torch.float16)
474471
net(inp)
472+
assert any("inference or training" in msg for msg in caplog.messages)
475473

474+
caplog.clear()
475+
with caplog_at_level(caplog, logging.WARNING, "bitsandbytes.nn.modules"):
476476
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4") for i in range(10)])
477477
net = net.to(device)
478478
inp = torch.rand(1, dim1, device=device, dtype=torch.float16)
479479
net(inp)
480-
481-
assert len(record) == 2
480+
assert any("inference." in msg for msg in caplog.messages)
482481

483482

484483
@pytest.mark.parametrize("device", get_available_devices())
485-
def test_4bit_embedding_warnings(device):
484+
def test_4bit_embedding_warnings(device, caplog):
486485
num_embeddings = 128
487486
default_block_size = 64
488487

489-
with pytest.warns(UserWarning, match=r"inference."):
488+
with caplog_at_level(caplog, logging.WARNING, "bitsandbytes.nn.modules"):
490489
net = bnb.nn.Embedding4bit(
491490
num_embeddings=num_embeddings, embedding_dim=default_block_size + 1, quant_type="nf4"
492491
)
493492
net.to(device)
494493
inp = torch.randint(low=0, high=num_embeddings, size=(1,), device=device)
495494
net(inp)
495+
assert any("inference" in msg for msg in caplog.messages)
496496

497497

498498
def test_4bit_embedding_weight_fsdp_fix(requires_cuda):

0 commit comments

Comments
 (0)