refactor evaluation, upgrade trl to 074
This commit is contained in:
@@ -226,7 +226,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
replace_model(unwrapped_model, target="default")
|
||||
return rewards
|
||||
|
||||
@PPODecorators.empty_cuda_cache()
|
||||
@PPODecorators.empty_device_cache()
|
||||
def batched_forward_pass(
|
||||
self,
|
||||
model: "AutoModelForCausalLMWithValueHead",
|
||||
|
||||
Reference in New Issue
Block a user