fix shift short attention
@@ -103,7 +103,6 @@ def load_model_and_tokenizer(
                logger.info("Using dynamic NTK scaling.")

        elif hasattr(config, "rope_scaling"): # for LLaMA and Falcon models
            require_version("transformers>=4.31.0", "RoPE scaling requires transformers>=4.31.0")
            if is_trainable:
                if model_args.rope_scaling == "dynamic":
                    logger.warning(
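The `rope_scaling` branch shown above builds on the RoPE-scaling support introduced in transformers 4.31.0, which is driven entirely by a dict on the model config rather than by anything set on the model object itself. A minimal sketch of that mechanism, independent of this repository (the checkpoint name and scaling factor below are placeholders):

    from transformers import AutoConfig, AutoModelForCausalLM

    name = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint, illustration only
    config = AutoConfig.from_pretrained(name)
    # transformers>=4.31.0 reads {"type": "linear" | "dynamic", "factor": > 1.0}
    config.rope_scaling = {"type": "linear", "factor": 2.0}
    model = AutoModelForCausalLM.from_pretrained(name, config=config)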
@@ -128,7 +127,7 @@ def load_model_and_tokenizer(
        else:
            logger.warning("Current model does not support RoPE scaling.")

-    # Set FlashAttention-2 and S^2-Attn
+    # Set FlashAttention-2
    if model_args.flash_attn:
        if getattr(config, "model_type", None) == "llama":
            LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
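The renamed comment reflects that this `flash_attn` branch now handles only the FlashAttention-2 patch; S^2-Attn moves to its own block later in the diff. The patch itself works by reassigning the module-level `LlamaAttention` class before the model is instantiated, so every decoder layer built by `from_pretrained` resolves to the replacement. A generic sketch of that monkey-patching pattern (the subclass here is a hypothetical stand-in, not the repo's `LlamaPatches` implementation):

    import transformers.models.llama.modeling_llama as LlamaModule

    class PatchedLlamaAttention(LlamaModule.LlamaAttention):  # hypothetical stand-in
        def forward(self, *args, **kwargs):
            # A real patch would implement a faster attention kernel here.
            return super().forward(*args, **kwargs)

    # Reassign the class before AutoModelForCausalLM.from_pretrained(); decoder
    # layers constructed afterwards look up LlamaAttention in this module
    # namespace and therefore pick up the patched class.
    LlamaModule.LlamaAttention = PatchedLlamaAttention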
@@ -136,12 +135,22 @@ def load_model_and_tokenizer(
                LlamaPatches._prepare_decoder_attention_mask
            )
            logger.info("Using FlashAttention-2 for faster training and inference.")
        elif getattr(config, "model_type", None) == "qwen":
            logger.info("Qwen models automatically enable FlashAttention if installed.")
        else:
            logger.warning("Current model does not support FlashAttention-2.")
+    elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
+        LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention
+        logger.warning("Using `--flash_attn` for faster training in large context length.")

+    # Set shift short attention (S^2-Attn)
+    if is_trainable and model_args.shift_attn:
+        if getattr(config, "model_type", None) == "llama":
+            setattr(config, "group_size_ratio", 0.25)
+            logger.info("Using shift short attention with group_size_ratio=1/4.")
+        else:
+            logger.warning("Current model does not support shift short attention.")

    # Quantization configurations (using bitsandbytes library).
    is_mergeable = True
    if model_args.quantization_bit is not None:
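The added block is the substance of the fix: the group size for shift short attention is now recorded on the config (as `group_size_ratio`) before the model is created, instead of on the model object afterwards, so the patched attention layers can find it through the config they all share. A small illustration of why an attribute set this way propagates (the checkpoint name is a placeholder and this is not the repo's code):

    from transformers import AutoConfig, AutoModelForCausalLM

    name = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
    config = AutoConfig.from_pretrained(name)
    setattr(config, "group_size_ratio", 0.25)  # S^2-Attn: attend within groups of 1/4 sequence length

    model = AutoModelForCausalLM.from_pretrained(name, config=config)
    # Any submodule holding a reference to the config sees the extra attribute:
    assert getattr(model.config, "group_size_ratio", None) == 0.25

Because the attribute travels with the config, a patched attention forward can test for it at training time instead of probing the top-level model object.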
@@ -176,14 +185,6 @@ def load_model_and_tokenizer(
        **config_kwargs
    )

-    # Set shift short attention (S^2-Attn)
-    if is_trainable and model_args.shift_attn:
-        if getattr(config, "model_type", None) == "llama":
-            setattr(model, "shift_ratio", 0.25)
-            logger.info("Using shift short attention proposed by LongLoRA.")
-        else:
-            logger.warning("Current model does not support shift short attention.")
-
    # Disable custom generate method (for Qwen and Baichuan2)
    if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
        model.generate = MethodType(PreTrainedModel.generate, model)
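The block removed here set `shift_ratio` on the already-built model, an attribute the patched attention layers apparently never consulted, which is what the config-level `group_size_ratio` above replaces. In the unchanged tail of the hunk, the loader also discards the custom `generate` that some remote-code models (Qwen, Baichuan2) ship by rebinding the stock `PreTrainedModel.generate` to the instance with `types.MethodType`. A toy example of that instance-level rebinding pattern, using made-up classes rather than transformers:

    from types import MethodType

    class Base:
        def greet(self):
            return "base"

    class Child(Base):
        def greet(self):  # custom override we want to bypass on one instance
            return "child"

    obj = Child()
    obj.greet = MethodType(Base.greet, obj)  # bind the base implementation to this instance
    assert obj.greet() == "base"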