fix #4209
DeepSpeed ZeRO-3 raises an in-flight param error when calling model.eval()
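For context, the removed get_ref_context helper called model.eval() and model.train() inside the adapter-disabling context; under DeepSpeed ZeRO-3 this can run while partitioned parameters are still in flight (being gathered), triggering the error. A minimal sketch of a ZeRO-3-safe variant, illustrative only and not part of this commit (whether callers still need such a context afterwards is not shown here):

from contextlib import contextmanager

@contextmanager
def get_ref_context(accelerator: "Accelerator", model: "PreTrainedModel"):
    r"""
    Disables the adapter for the reference model without toggling
    model.eval()/model.train(), which is unsafe under DeepSpeed ZeRO-3.
    """
    # PEFT's disable_adapter() temporarily deactivates the LoRA adapter,
    # so the base weights serve as the reference model; no mode switch needed.
    with accelerator.unwrap_model(model).disable_adapter():
        yield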
@@ -1,4 +1,3 @@
-from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -19,7 +18,6 @@ if is_galore_available():
 
 
 if TYPE_CHECKING:
-    from accelerate import Accelerator
     from transformers import PreTrainedModel, Seq2SeqTrainingArguments
     from trl import AutoModelForCausalLMWithValueHead
 
@@ -154,17 +152,6 @@ def create_reward_model(
     return reward_model
 
 
-@contextmanager
-def get_ref_context(accelerator: "Accelerator", model: "PreTrainedModel"):
-    r"""
-    Gets adapter context for the reference model.
-    """
-    with accelerator.unwrap_model(model).disable_adapter():
-        model.eval()
-        yield
-        model.train()
-
-
 def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
     r"""
     Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)