Skip to content

about start_token #1

@threeColorFr

Description

@threeColorFr

why don't y_ids and lm_labels have a start_token like <s>
May this affect decode during the testing phase?

def train(epoch, tokenizer, model, device, loader, optimizer):

    """
    用于训练的方法
    Function to be called for training with the parameters passed from main function

    """

    model.train()
    time1=time.time()
    for _, data in enumerate(loader, 0):
        y = data["target_ids"].to(device, dtype=torch.long)
        y_ids = y[:, :-1].contiguous() # target, from start to end(except end of token, <EOS>). e.g. "你好吗?"
        lm_labels = y[:, 1:].clone().detach() # target, for second to end.e.g."好吗?<EOS>"
        lm_labels[y[:, 1:] == tokenizer.pad_token_id] = -100 # releted to pad_token and loss. for detail, check here: https://github.com/Shivanandroy/T5-Finetuning-PyTorch/issues/3
        ids = data["source_ids"].to(device, dtype=torch.long) # input. e.g. "how are you?"
        mask = data["source_mask"].to(device, dtype=torch.long)

        outputs = model(
            input_ids=ids,
            attention_mask=mask,
            decoder_input_ids=y_ids,
            labels=lm_labels,
        )
        loss = outputs[0]
        # 每100步打印日志
        if _ % 100 == 0 and _!=0:
            time2=time.time()
            print(_,"epoch:"+str(epoch)+"-loss:"+str(loss)+";each step's time spent:"+str(float(time2-time1)/float(_+0.0001)))
            # training_logger.add_row(str(epoch), str(_), str(loss))
            # console.print(training_logger)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

this is the official code in T5ForConditionalGeneration; i think we should only pass labels like 你好吗?<EOS>, then the code below will do shift.

 if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions