Fix: DeepSpeed ZeRO-3 raises an inflight-parameter error when `model.eval()` is called inside the reference-model context
This commit is contained in:
hiyouga
2024-06-13 02:25:50 +08:00
parent 2ed8270112
commit cf9f2d6c42
4 changed files with 12 additions and 17 deletions

View File

@@ -1,4 +1,3 @@
from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
import torch
@@ -19,7 +18,6 @@ if is_galore_available():
if TYPE_CHECKING:
from accelerate import Accelerator
from transformers import PreTrainedModel, Seq2SeqTrainingArguments
from trl import AutoModelForCausalLMWithValueHead
@@ -154,17 +152,6 @@ def create_reward_model(
return reward_model
@contextmanager
def get_ref_context(accelerator: "Accelerator", model: "PreTrainedModel"):
    r"""
    Gets adapter context for the reference model.

    Temporarily disables the (PEFT) adapters on the unwrapped model so the
    body of the ``with`` block sees the base (reference) model.

    Args:
        accelerator: the ``accelerate.Accelerator`` used to unwrap the
            possibly DeepSpeed/DDP-wrapped model.
        model: the wrapped model whose adapters should be disabled.

    Yields:
        None. The adapters are re-enabled automatically on exit.

    Note:
        We deliberately do NOT call ``model.eval()`` / ``model.train()``
        here: under DeepSpeed ZeRO-3, parameters are gathered lazily and
        toggling the module mode inside this context raises an
        "inflight param" error (see the commit message). Callers that need
        eval mode should manage it outside this context.
    """
    with accelerator.unwrap_model(model).disable_adapter():
        yield
def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]:
r"""
Returns a list of names of parameters with weight decay. (weights in non-layernorm layers)