improve lora+ impl.

This commit is contained in:
hiyouga
2024-03-13 23:32:51 +08:00
parent 4e5e99af43
commit 72367307df
12 changed files with 165 additions and 169 deletions

View File

@@ -109,10 +109,6 @@ def load_model(
if not is_trainable:
model.requires_grad_(False)
if not getattr(model, "quantization_method", None):
for param in filter(lambda p: p.device.type == "cuda", model.parameters()):
param.data = param.data.to(model_args.compute_dtype)
model.eval()
else:
model.train()