Skip to content

Commit 4501ed4

Browse files
committed
Avoid invoking tensor.transpose(0, 1).contiguous() when the shapes already match.
Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai>
1 parent 28b3870 commit 4501ed4

2 files changed

Lines changed: 217 additions & 118 deletions

File tree

gptqmodel/utils/structure.py

Lines changed: 150 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -1921,24 +1921,11 @@ def _transform_checkpoint_tensor(
19211921

19221922
expected_shape = tuple(expected_shape)
19231923

1924-
# Some checkpoints store expert projections as (out, in) while others store
1925-
# them as (in, out). Keep both candidates and let the defused leaf shape be
1926-
# the final arbiter instead of hard-coding one model family's layout.
1927-
candidates: list[tuple[torch.Tensor, bool]] = [(tensor, False)]
1928-
if tensor.ndim == 2:
1929-
transposed = tensor.transpose(0, 1).contiguous()
1930-
if prefer_transposed is True:
1931-
candidates = [(transposed, True), (tensor, False)]
1932-
elif prefer_transposed is None and transposed.shape != tensor.shape:
1933-
candidates.append((transposed, True))
1934-
elif prefer_transposed is False and transposed.shape != tensor.shape:
1935-
candidates.append((transposed, True))
1936-
1937-
for candidate, used_transpose in candidates:
1924+
def try_candidate(candidate: torch.Tensor, *, used_transpose: bool) -> Optional[torch.Tensor]:
19381925
if split_index is None:
19391926
if tuple(candidate.shape) == expected_shape:
19401927
return candidate.contiguous()
1941-
continue
1928+
return None
19421929

19431930
preferred_dims: list[int] = []
19441931
mapped_split_dim = split_dim
@@ -1976,7 +1963,27 @@ def _transform_checkpoint_tensor(
19761963
if tuple(split_tensor.shape) == expected_shape:
19771964
return split_tensor
19781965

1979-
return None
1966+
return None
1967+
1968+
# Some checkpoints store expert projections as (out, in) while others store
1969+
# them as (in, out). Try one layout at a time so the transpose copy is only
1970+
# materialized when the first candidate cannot satisfy the target shape.
1971+
if tensor.ndim == 2 and prefer_transposed is True:
1972+
# Honor explicit layout hints first; fall back only if shape/split checks fail.
1973+
transposed = tensor.transpose(0, 1).contiguous()
1974+
matched = try_candidate(transposed, used_transpose=True)
1975+
if matched is not None:
1976+
return matched
1977+
return try_candidate(tensor, used_transpose=False)
1978+
1979+
# Most checkpoint tensors already match the shell shape; avoid an eager transpose copy.
1980+
matched = try_candidate(tensor, used_transpose=False)
1981+
if matched is not None or tensor.ndim != 2 or tensor.shape[0] == tensor.shape[1]:
1982+
return matched
1983+
1984+
# Last resort for rectangular weights whose checkpoint layout is the transpose of the shell's layout.
1985+
transposed = tensor.transpose(0, 1).contiguous()
1986+
return try_candidate(transposed, used_transpose=True)
19801987

19811988
@staticmethod
19821989
def _resolve_prefer_transposed_hint(
@@ -2170,124 +2177,149 @@ def _copy_checkpoint_tensors_into_submodule(
21702177
continue
21712178
grouped_names.setdefault(shard, []).append(("buffer", rel_name, full_name, expert_index, split_index, split_dim))
21722179

2173-
with torch.inference_mode():
2174-
for shard, entries in grouped_names.items():
2175-
shard_path = os.path.join(self.model_local_path, shard)
2176-
with safe_open(shard_path, framework="pt", device="cpu") as handler:
2177-
for kind, rel_name, full_name, expert_index, split_index, split_dim in entries:
2178-
target_tensor = t_params.get(rel_name) if kind == "param" else t_bufs.get(rel_name)
2179-
expected_shape = tuple(target_tensor.shape) if target_tensor is not None else None
2180-
prefer_transposed = self._resolve_prefer_transposed_hint(
2181-
target_model=target_model,
2182-
module_path=module_path,
2183-
rel_name=rel_name,
2184-
modules_by_name=modules_by_name,
2185-
)
2186-
checkpoint_tensor = handler.get_tensor(full_name)
2187-
tensor = self._transform_checkpoint_tensor(
2188-
checkpoint_tensor,
2189-
expert_index=expert_index,
2190-
split_index=split_index,
2191-
split_dim=split_dim,
2192-
expected_shape=expected_shape,
2193-
prefer_transposed=prefer_transposed,
2194-
)
2195-
if tensor is None:
2196-
raise RuntimeError(self._materialization_issue_message(
2197-
phase="submodule materialization",
2198-
kind=kind,
2180+
total_entries = sum(len(entries) for entries in grouped_names.values())
2181+
progress = None
2182+
loaded_entries = 0
2183+
if total_entries:
2184+
progress = log.pb(range(total_entries)).manual().set(show_left_steps=False)
2185+
module_label = module_path or "<root>"
2186+
progress.title(f"Loading checkpoint tensors ({total_entries})")
2187+
progress.subtitle(f"{module_label}: 0/{total_entries}")
2188+
progress.draw(force=True)
2189+
2190+
try:
2191+
with torch.inference_mode():
2192+
for shard, entries in grouped_names.items():
2193+
shard_path = os.path.join(self.model_local_path, shard)
2194+
with safe_open(shard_path, framework="pt", device="cpu") as handler:
2195+
for kind, rel_name, full_name, expert_index, split_index, split_dim in entries:
2196+
if progress is not None:
2197+
progress.current_iter_step = loaded_entries
2198+
progress.subtitle(f"{rel_name}: {loaded_entries + 1}/{total_entries}")
2199+
progress.draw()
2200+
target_tensor = t_params.get(rel_name) if kind == "param" else t_bufs.get(rel_name)
2201+
expected_shape = tuple(target_tensor.shape) if target_tensor is not None else None
2202+
prefer_transposed = self._resolve_prefer_transposed_hint(
2203+
target_model=target_model,
21992204
module_path=module_path,
22002205
rel_name=rel_name,
2201-
reason="checkpoint tensor could not be reshaped into the target layout",
2202-
full_name=full_name,
2203-
source_shape=tuple(checkpoint_tensor.shape),
2204-
target_shape=expected_shape,
2206+
modules_by_name=modules_by_name,
2207+
)
2208+
checkpoint_tensor = handler.get_tensor(full_name)
2209+
tensor = self._transform_checkpoint_tensor(
2210+
checkpoint_tensor,
22052211
expert_index=expert_index,
22062212
split_index=split_index,
22072213
split_dim=split_dim,
2208-
))
2209-
if kind == "param":
2210-
target_param = t_params.get(rel_name)
2211-
if target_param is None:
2212-
raise RuntimeError(self._materialization_issue_message(
2213-
phase="submodule materialization",
2214-
kind=kind,
2215-
module_path=module_path,
2216-
rel_name=rel_name,
2217-
reason="target tensor disappeared before materialization",
2218-
full_name=full_name,
2219-
source_shape=tuple(tensor.shape),
2220-
expert_index=expert_index,
2221-
split_index=split_index,
2222-
split_dim=split_dim,
2223-
))
2224-
if target_param.shape != tensor.shape:
2214+
expected_shape=expected_shape,
2215+
prefer_transposed=prefer_transposed,
2216+
)
2217+
if tensor is None:
22252218
raise RuntimeError(self._materialization_issue_message(
22262219
phase="submodule materialization",
22272220
kind=kind,
22282221
module_path=module_path,
22292222
rel_name=rel_name,
2230-
reason="target tensor shape does not match the transformed checkpoint tensor",
2223+
reason="checkpoint tensor could not be reshaped into the target layout",
22312224
full_name=full_name,
2232-
source_shape=tuple(tensor.shape),
2233-
target_shape=tuple(target_param.shape),
2225+
source_shape=tuple(checkpoint_tensor.shape),
2226+
target_shape=expected_shape,
22342227
expert_index=expert_index,
22352228
split_index=split_index,
22362229
split_dim=split_dim,
22372230
))
2238-
target_param_new = _ensure_target_storage_on_device_(target_param, device)
2239-
if target_param_new is not target_param:
2231+
if kind == "param":
2232+
target_param = t_params.get(rel_name)
2233+
if target_param is None:
2234+
raise RuntimeError(self._materialization_issue_message(
2235+
phase="submodule materialization",
2236+
kind=kind,
2237+
module_path=module_path,
2238+
rel_name=rel_name,
2239+
reason="target tensor disappeared before materialization",
2240+
full_name=full_name,
2241+
source_shape=tuple(tensor.shape),
2242+
expert_index=expert_index,
2243+
split_index=split_index,
2244+
split_dim=split_dim,
2245+
))
2246+
if target_param.shape != tensor.shape:
2247+
raise RuntimeError(self._materialization_issue_message(
2248+
phase="submodule materialization",
2249+
kind=kind,
2250+
module_path=module_path,
2251+
rel_name=rel_name,
2252+
reason="target tensor shape does not match the transformed checkpoint tensor",
2253+
full_name=full_name,
2254+
source_shape=tuple(tensor.shape),
2255+
target_shape=tuple(target_param.shape),
2256+
expert_index=expert_index,
2257+
split_index=split_index,
2258+
split_dim=split_dim,
2259+
))
2260+
target_param_new = _ensure_target_storage_on_device_(target_param, device)
2261+
if target_param_new is not target_param:
2262+
t_parent, leaf = _get_parent_and_leaf_by_path(target_submodule, rel_name)
2263+
setattr(t_parent, leaf, target_param_new)
2264+
target_param = target_param_new
2265+
source = tensor.detach()
2266+
if source.dtype != target_param.dtype:
2267+
source = source.to(dtype=target_param.dtype)
2268+
target_param.detach().copy_(source, non_blocking=(non_blocking and source.is_pinned()))
2269+
else:
2270+
target_buffer = t_bufs.get(rel_name)
22402271
t_parent, leaf = _get_parent_and_leaf_by_path(target_submodule, rel_name)
2241-
setattr(t_parent, leaf, target_param_new)
2242-
target_param = target_param_new
2243-
source = tensor.detach()
2244-
if source.dtype != target_param.dtype:
2245-
source = source.to(dtype=target_param.dtype)
2246-
target_param.detach().copy_(source, non_blocking=(non_blocking and source.is_pinned()))
2247-
continue
2248-
2249-
target_buffer = t_bufs.get(rel_name)
2250-
t_parent, leaf = _get_parent_and_leaf_by_path(target_submodule, rel_name)
2251-
persistent = leaf not in getattr(t_parent, "_non_persistent_buffers_set", set())
2252-
2253-
source = tensor.detach()
2254-
if target_buffer is None:
2255-
new_buffer = source.to(device=device)
2256-
t_parent.register_buffer(leaf, new_buffer, persistent=persistent)
2257-
t_bufs[rel_name] = new_buffer
2258-
continue
2259-
2260-
if tuple(target_buffer.shape) != tuple(source.shape):
2261-
raise RuntimeError(self._materialization_issue_message(
2262-
phase="submodule materialization",
2263-
kind=kind,
2264-
module_path=module_path,
2265-
rel_name=rel_name,
2266-
reason="target tensor shape does not match the transformed checkpoint tensor",
2267-
full_name=full_name,
2268-
source_shape=tuple(source.shape),
2269-
target_shape=tuple(target_buffer.shape),
2270-
expert_index=expert_index,
2271-
split_index=split_index,
2272-
split_dim=split_dim,
2273-
))
2274-
2275-
if getattr(target_buffer, "is_meta", False) or target_buffer.device.type == "meta":
2276-
new_buffer = torch.empty_like(target_buffer, device=device)
2277-
new_buffer.copy_(source.to(dtype=new_buffer.dtype), non_blocking=(non_blocking and source.is_pinned()))
2278-
t_parent.register_buffer(leaf, new_buffer, persistent=persistent)
2279-
t_bufs[rel_name] = new_buffer
2280-
continue
2281-
2282-
if target_buffer.device != device:
2283-
new_buffer = torch.empty_like(target_buffer, device=device)
2284-
new_buffer.copy_(source.to(dtype=new_buffer.dtype), non_blocking=(non_blocking and source.is_pinned()))
2285-
t_parent.register_buffer(leaf, new_buffer, persistent=persistent)
2286-
t_bufs[rel_name] = new_buffer
2287-
else:
2288-
if source.dtype != target_buffer.dtype:
2289-
source = source.to(dtype=target_buffer.dtype)
2290-
target_buffer.copy_(source, non_blocking=(non_blocking and source.is_pinned()))
2272+
persistent = leaf not in getattr(t_parent, "_non_persistent_buffers_set", set())
2273+
2274+
source = tensor.detach()
2275+
if target_buffer is None:
2276+
new_buffer = source.to(device=device)
2277+
t_parent.register_buffer(leaf, new_buffer, persistent=persistent)
2278+
t_bufs[rel_name] = new_buffer
2279+
else:
2280+
if tuple(target_buffer.shape) != tuple(source.shape):
2281+
raise RuntimeError(self._materialization_issue_message(
2282+
phase="submodule materialization",
2283+
kind=kind,
2284+
module_path=module_path,
2285+
rel_name=rel_name,
2286+
reason="target tensor shape does not match the transformed checkpoint tensor",
2287+
full_name=full_name,
2288+
source_shape=tuple(source.shape),
2289+
target_shape=tuple(target_buffer.shape),
2290+
expert_index=expert_index,
2291+
split_index=split_index,
2292+
split_dim=split_dim,
2293+
))
2294+
2295+
if getattr(target_buffer, "is_meta", False) or target_buffer.device.type == "meta":
2296+
new_buffer = torch.empty_like(target_buffer, device=device)
2297+
new_buffer.copy_(
2298+
source.to(dtype=new_buffer.dtype),
2299+
non_blocking=(non_blocking and source.is_pinned()),
2300+
)
2301+
t_parent.register_buffer(leaf, new_buffer, persistent=persistent)
2302+
t_bufs[rel_name] = new_buffer
2303+
elif target_buffer.device != device:
2304+
new_buffer = torch.empty_like(target_buffer, device=device)
2305+
new_buffer.copy_(
2306+
source.to(dtype=new_buffer.dtype),
2307+
non_blocking=(non_blocking and source.is_pinned()),
2308+
)
2309+
t_parent.register_buffer(leaf, new_buffer, persistent=persistent)
2310+
t_bufs[rel_name] = new_buffer
2311+
else:
2312+
if source.dtype != target_buffer.dtype:
2313+
source = source.to(dtype=target_buffer.dtype)
2314+
target_buffer.copy_(source, non_blocking=(non_blocking and source.is_pinned()))
2315+
2316+
loaded_entries += 1
2317+
if progress is not None:
2318+
progress.current_iter_step = loaded_entries
2319+
progress.draw()
2320+
finally:
2321+
if progress is not None:
2322+
progress.close()
22912323

22922324
self._restore_missing_nonpersistent_buffers(
22932325
target_model=target_model,

0 commit comments

Comments
 (0)