support BLOOM models

This commit is contained in:
hiyouga
2023-05-31 16:54:06 +08:00
parent a72492e649
commit 740a5daf56
16 changed files with 134 additions and 90 deletions

View File

@@ -58,7 +58,7 @@ def cast_layernorm_dtype(
return model, layer_norm_state_dict
class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
class PPOPeftTrainer(PPOTrainer, PeftTrainer):
r"""
Inherits PPOTrainer.
"""
@@ -130,7 +130,7 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
unwrapped_model.gradient_checkpointing_disable()
unwrapped_model.config.use_cache = True
# Get response from LLaMA
# Get response from model
query_tensors: torch.Tensor = batch["input_ids"]
response_tensors = self.generate(batch, length_sampler=output_length_sampler, return_prompt=False, **gen_kwargs)