Adapt BAdam optimizer creation for DeepSpeed ZeRO-3
This commit is contained in:
@@ -309,6 +309,12 @@ def _create_badam_optimizer(
|
||||
dict(params=decay_params, weight_decay=training_args.weight_decay),
|
||||
]
|
||||
|
||||
ds_zero3_enabled = False
|
||||
if hasattr(training_args, "deepspeed_plugin") and training_args.deepspeed_plugin is not None:
|
||||
assert training_args.deepspeed_plugin.zero_stage == 3, f"BAdam only supports deepspeed ZeRO-3 stage, got {training_args.deepspeed_plugin.zero_stage}"
|
||||
assert finetuning_args.badam_mode == "layer", "BAdam only supports layer-wise update in ZeRO-3 stage"
|
||||
ds_zero3_enabled = True
|
||||
|
||||
if finetuning_args.badam_mode == "layer":
|
||||
from badam import BlockOptimizer
|
||||
|
||||
@@ -321,6 +327,7 @@ def _create_badam_optimizer(
|
||||
start_block=finetuning_args.badam_start_block,
|
||||
switch_mode=finetuning_args.badam_switch_mode,
|
||||
verbose=finetuning_args.badam_verbose,
|
||||
ds_zero3_enabled=ds_zero3_enabled
|
||||
)
|
||||
logger.info(
|
||||
f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
|
||||
|
||||
Reference in New Issue
Block a user