refactor evaluation, upgrade trl to 074

This commit is contained in:
hiyouga
2023-11-13 22:20:35 +08:00
parent 528d91192a
commit 442aefb925
21 changed files with 341 additions and 247 deletions

View File

@@ -226,7 +226,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
replace_model(unwrapped_model, target="default")
return rewards
@PPODecorators.empty_cuda_cache()
@PPODecorators.empty_device_cache()
def batched_forward_pass(
self,
model: "AutoModelForCausalLMWithValueHead",