-
Notifications
You must be signed in to change notification settings - Fork 536
Expand file tree
/
Copy pathloss_utils.py
More file actions
109 lines (90 loc) · 3.5 KB
/
loss_utils.py
File metadata and controls
109 lines (90 loc) · 3.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import inspect
from typing import Optional
from typing import Tuple
import torch
import torch.nn as nn
import liger_kernel.transformers.functional as F
from liger_kernel.transformers.functional import CrossEntropyOutput
def unpack_cross_entropy_result(
    result,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Normalize a cross-entropy result into a fixed four-slot tuple.

    Args:
        result: Either a ``CrossEntropyOutput``, a tuple whose leading entries
            are ``(loss, z_loss, token_accuracy, predicted_tokens)``, or any
            other object, which is treated as the loss itself.

    Returns:
        ``(loss, z_loss, token_accuracy, predicted_tokens)``. Slots that the
        input does not provide are ``None``. Tuple entries beyond the fourth
        are ignored.
    """
    if isinstance(result, CrossEntropyOutput):
        return (
            result.loss,
            result.z_loss,
            result.token_accuracy,
            result.predicted_tokens,
        )

    if not isinstance(result, tuple):
        # Anything that is neither structured output nor a tuple is the loss.
        return result, None, None, None

    # The first entry is mandatory; the remaining three are optional and
    # default to None when the tuple is too short to contain them.
    loss = result[0]
    size = len(result)
    trailing = tuple(result[pos] if size > pos else None for pos in (1, 2, 3))
    return (loss, *trailing)
def fixed_fused_linear_cross_entropy(
    hidden_states: torch.Tensor,
    lm_head_weight: torch.Tensor,
    target: torch.Tensor,
    num_items_in_batch: Optional[int] = None,
    ignore_index: int = -100,
    final_logit_softcapping: Optional[float] = None,
    accum_dtype: Optional[torch.dtype] = None,
    return_token_accuracy: bool = False,
    return_predicted_tokens: bool = False,
    **kwargs,
):
    """Run the fused linear + cross-entropy kernel and normalize the loss.

    Args:
        hidden_states: Forwarded as the first positional argument of
            ``F.liger_fused_linear_cross_entropy``.
        lm_head_weight: Forwarded as the second positional argument.
        target: Forwarded as the third positional argument.
        num_items_in_batch: When given and no explicit ``reduction`` is
            supplied, the kernel is asked for a summed loss which is then
            divided by this value. When ``None``, the default reduction is
            ``"mean"``.
        ignore_index: Forwarded to the kernel as ``ignore_index``.
        final_logit_softcapping: Forwarded to the kernel as ``softcap``.
        accum_dtype: Forwarded to the kernel as ``accum_dtype``.
        return_token_accuracy: Forwarded to the kernel; also selects the
            structured return type (see Returns).
        return_predicted_tokens: Forwarded to the kernel; also selects the
            structured return type (see Returns).
        **kwargs: Extra keyword arguments forwarded to the kernel. A
            ``"reduction"`` entry, if present and not ``None``, overrides the
            default reduction and is not forwarded twice.

    Returns:
        The loss alone, or a ``CrossEntropyOutput`` carrying ``loss``,
        ``token_accuracy`` and ``predicted_tokens`` when either
        ``return_token_accuracy`` or ``return_predicted_tokens`` is true.
        Any ``z_loss`` produced by the kernel is discarded.
    """
    reduction = kwargs.pop("reduction", None)
    if reduction is None:
        reduction = "sum" if num_items_in_batch is not None else "mean"

    result = F.liger_fused_linear_cross_entropy(
        hidden_states,
        lm_head_weight,
        target,
        reduction=reduction,
        ignore_index=ignore_index,
        softcap=final_logit_softcapping,
        accum_dtype=accum_dtype,
        return_token_accuracy=return_token_accuracy,
        return_predicted_tokens=return_predicted_tokens,
        **kwargs,
    )
    loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)

    # Normalize only when there is a count to divide by. A caller may request
    # reduction="sum" explicitly without supplying num_items_in_batch, in
    # which case the raw sum is returned instead of failing on `loss / None`.
    if reduction == "sum" and num_items_in_batch is not None:
        # NOTE(review): the annotation says int, but training loops may pass
        # a tensor count; keep it on the loss device so the division does not
        # fail under model parallelism.
        if isinstance(num_items_in_batch, torch.Tensor):
            num_items_in_batch = num_items_in_batch.to(loss.device)
        loss = loss / num_items_in_batch

    if return_token_accuracy or return_predicted_tokens:
        return CrossEntropyOutput(loss=loss, token_accuracy=token_accuracy, predicted_tokens=predicted_tokens)
    return loss
def LigerForCausalLMLoss(
    hidden_states,
    lm_head_weight,
    labels,
    hidden_size: int,
    num_items_in_batch: Optional[int] = None,
    ignore_index: int = -100,
    shift_labels: Optional[torch.Tensor] = None,
    final_logit_softcapping: Optional[float] = None,
    return_token_accuracy: bool = False,
    return_predicted_tokens: bool = False,
    **kwargs,
):
    """Causal-LM loss computed through the fused linear cross-entropy path.

    Args:
        hidden_states: Reshaped to ``(-1, hidden_size)`` before the loss call.
        lm_head_weight: Passed through unchanged to
            ``fixed_fused_linear_cross_entropy``.
        labels: Used only when ``shift_labels`` is ``None``; padded on the
            right of the last dimension with ``ignore_index`` and shifted
            left by one position.
        hidden_size: Size of the last dimension used to flatten
            ``hidden_states``.
        num_items_in_batch: Passed through to the loss helper.
        ignore_index: Padding value for the label shift; also passed through.
        shift_labels: Pre-shifted labels; when given, ``labels`` is not read.
        final_logit_softcapping: Passed through to the loss helper.
        return_token_accuracy: Passed through to the loss helper.
        return_predicted_tokens: Passed through to the loss helper.
        **kwargs: Only entries whose names appear in the signature of
            ``F.liger_fused_linear_cross_entropy`` are forwarded; the rest
            are dropped.

    Returns:
        Whatever ``fixed_fused_linear_cross_entropy`` returns.
    """
    # Filter out kwargs that liger_fused_linear_cross_entropy does not accept
    accepted = inspect.signature(F.liger_fused_linear_cross_entropy).parameters
    forwarded = {name: value for name, value in kwargs.items() if name in accepted}

    # Skip upcast since intermediate values for the loss are all fp32 in kernel
    if shift_labels is None:
        # Shift so that tokens < n predict n
        padded_labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
        shift_labels = padded_labels[..., 1:].contiguous()

    # Flatten the tokens, then place the targets on the same device as the
    # hidden states to enable model parallelism
    flat_hidden = hidden_states.view(-1, hidden_size)
    flat_targets = shift_labels.view(-1).to(flat_hidden.device)

    return fixed_fused_linear_cross_entropy(
        flat_hidden,
        lm_head_weight,
        flat_targets,
        num_items_in_batch,
        ignore_index,
        final_logit_softcapping,
        return_token_accuracy=return_token_accuracy,
        return_predicted_tokens=return_predicted_tokens,
        **forwarded,
    )