from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import torch
import transformers

old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__

def ntk_scaled_init(self, dim, max_position_embeddings=2048, base=10000, device=None):
    # The method is just these three lines
    max_position_embeddings = 16384
    a = 8  # Alpha value
    base = base * a ** (dim / (dim - 2))  # Base change formula (NTK-aware scaling)

    old_init(self, dim, max_position_embeddings, base, device)

transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = ntk_scaled_init
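
# --- Minimal usage sketch. Assumptions not in the original: the checkpoint
# name "huggyllama/llama-7b" is a hypothetical stand-in for any LLaMA model,
# and the prompt/generation settings are illustrative only. Note the patch
# must run BEFORE from_pretrained(), since LlamaRotaryEmbedding.__init__ is
# invoked while the model is being constructed.
model_name = "huggyllama/llama-7b"  # hypothetical example checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,  # half precision helps fit long contexts
    device_map="auto",          # requires the `accelerate` package
)

# In practice the prompt would be far longer than the original 2048-token
# limit; with alpha = 8 the patched base targets roughly 16k positions.
prompt = "A long document..."
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

gen_config = GenerationConfig(max_new_tokens=128, do_sample=False)
output_ids = model.generate(**inputs, generation_config=gen_config)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))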