4 changes: 2 additions & 2 deletions lightllm/common/kv_cache_mem_manager/mem_manager.py
@@ -28,7 +28,7 @@ class MemoryManager:
def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
self.size = size
self.head_num = head_num
self.head_dim = head_dim
self.head_dim = head_dim * 2  # neo kv stores [k, k_h, k_w] concatenated, so the cached per-head size is doubled
self.layer_num = layer_num
self.always_copy = always_copy
self.dtype = dtype
@@ -60,7 +60,7 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
self.size,
dtype,
head_num,
head_dim,
self.head_dim,
layer_num,
)
self.HOLD_TOKEN_MEMINDEX = self.size
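
For orientation, a minimal arithmetic sketch of why the cached per-head size doubles, using shapes read from _get_qkv in the new transformer_layer_infer.py further down; head_dim = 128 is an assumed example value, not taken from a config in this PR:

head_dim = 128                          # assumed example size
k_h = k_w = head_dim // 2               # the two halves produced by k_hw_proj and chunk(2, dim=-1)
stored_head_dim = head_dim + k_h + k_w  # 128 + 64 + 64 = 256 == head_dim * 2
# v is zero-padded from head_dim to stored_head_dim so k and v rows share one kv buffer.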
1 change: 1 addition & 0 deletions lightllm/models/__init__.py
@@ -38,4 +38,5 @@
Tarsier2LlamaTpPartModel,
)
from lightllm.models.gpt_oss.model import GptOssTpPartModel
from lightllm.models.neo_chat.model import NeoTpPartModel
from .registry import get_model, get_model_class
47 changes: 46 additions & 1 deletion lightllm/models/llama/model.py
@@ -110,6 +110,8 @@ def _init_custom(self):
rope_scaling = self.config.get("rope_scaling", None)
if rope_scaling is None:
self._init_to_get_rotary()
if "rope_theta_hw" in self.config:
self._init_to_get_hw_rotary()
return

if "rope_type" in rope_scaling:
@@ -132,6 +134,8 @@ def _init_custom(self):
self._init_to_get_rotary()
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
if "rope_theta_hw" in self.config:
self._init_to_get_hw_rotary()
return

def _init_weights(self):
@@ -178,7 +182,7 @@ def _init_to_get_rotary(self, default_base=10000):
rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)

base = self.config.get("rope_theta", float(default_base))

print(f"base is {base}")
Contributor (severity: medium):
This print statement appears to be for debugging purposes and should be removed before merging.
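
A minimal sketch of one way to resolve this, assuming the module-level logger already used in this file's NTK block is in scope; simply deleting the line is equally valid:

logger.debug(f"rotary base is {base}")  # debug-level log instead of an unconditional print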

if "max_sequence_length" in self.config:
max_seq_len = self.config["max_sequence_length"]
else:
@@ -211,6 +215,47 @@ def _init_to_get_rotary(self, default_base=10000):
self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
return

def _init_to_get_hw_rotary(self, default_base=10000):
partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_ // 2)
if self.config.get("rope_scaling", {}) is None:
rope_scaling_factor = 1.0
else:
rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)

base = self.config.get("rope_theta_hw", float(default_base))
print(f"hw_base is {base}")
Contributor (severity: medium):
This print statement appears to be for debugging purposes and should be removed before merging.

if "max_sequence_length" in self.config:
max_seq_len = self.config["max_sequence_length"]
else:
max_position_embeddings = self.config.get(
"max_position_embeddings_hw", 2048 if base <= 10000.0 + 1e-5 else 16384
)
max_seq_len = max_position_embeddings * rope_scaling_factor

# NTK
try:
ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1))
assert ntk_alpha >= 1
if ntk_alpha > 1:
logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}")
max_seq_len *= ntk_alpha
base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2))) # Base change formula
except:
pass
Comment on lines +236 to +244
Contributor (severity: high):
Using a bare except: is generally discouraged as it can catch and silence a wide range of unexpected errors, making debugging difficult. It's better to catch specific exceptions that you expect might occur, such as ValueError or AssertionError, and log them for better diagnostics.

Suggested change
-        try:
-            ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1))
-            assert ntk_alpha >= 1
-            if ntk_alpha > 1:
-                logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}")
-            max_seq_len *= ntk_alpha
-            base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2)))  # Base change formula
-        except:
-            pass
+        try:
+            ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1))
+            assert ntk_alpha >= 1
+            if ntk_alpha > 1:
+                logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}")
+            max_seq_len *= ntk_alpha
+            base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2)))  # Base change formula
+        except (ValueError, AssertionError) as e:
+            logger.warning(f"Could not apply NTK scaling: {e}")


inv_freq = 1.0 / (
base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)
)
t = (
torch.arange(max(max_seq_len + 1024 * 128, self.max_seq_length), device="cpu", dtype=torch.float32)
/ rope_scaling_factor
)
freqs = torch.outer(t, inv_freq)

self._hw_cos_cached = torch.cos(freqs).to(self.data_type).cuda()
self._hw_sin_cached = torch.sin(freqs).to(self.data_type).cuda()
return

def _init_to_get_dynamic_ntk_rotary(self):
partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_)
max_position_embeddings = self.config.get("max_position_embeddings", 2048)
@@ -74,7 +74,8 @@ def token_att_fwd(q, k, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen
Lq, Lk = q.shape[-1], k.shape[-1]
assert Lq == Lk
assert Lk in {16, 32, 64, 128, 256}
sm_scale = 1.0 / (Lk ** 0.5)
Lk_scale = Lk // 2
sm_scale = 1.0 / (Lk_scale ** 0.5)

batch, head_num = B_req_idx.shape[0], q.shape[1]
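
Worked numerically (head_dim = 128 assumed, matching the sketch after the mem_manager.py hunk above), the new scale keeps the softmax temperature tied to the original per-head size even though the stored rows are doubled:

Lk = 256                             # doubled layout: [k, k_h, k_w] with original head_dim 128
sm_scale = 1.0 / ((Lk // 2) ** 0.5)  # == 1 / sqrt(128), the pre-concat scale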

Empty file.
99 changes: 99 additions & 0 deletions lightllm/models/neo_chat/infer_struct.py
@@ -0,0 +1,99 @@
from typing import Optional, List
import torch
import numpy as np
from lightllm.models.llama.infer_struct import LlamaInferStateInfo
from lightllm.common.req_manager import ReqManager
from lightllm.models.neo_chat.triton_kernel.get_neo_position import get_neo_position_triton
from lightllm.models.llama.model import LlamaTpPartModel


class NeoChatInferStateInfo(LlamaInferStateInfo):
def __init__(self):
super().__init__()
self.position_cos = None
self.position_sin = None
self.position_cos_h = None
self.position_sin_h = None
self.position_cos_w = None
self.position_sin_w = None

def init_some_extra_state(self, model: LlamaTpPartModel, input_ids: torch.Tensor):
LlamaInferStateInfo.init_some_extra_state(self, model, input_ids)
if self.is_prefill:
self.position_ids = self.get_neo_position(self.multimodal_params)
else:
b_position_delta = [0 for _ in range(self.b_seq_len.shape[0])]
for batch_idx, p in enumerate(self.multimodal_params):
position_delta = 0
for image in p["images"]:
position_delta += image["grid_thwd"][3]
b_position_delta[batch_idx] = position_delta
position_ids = self.position_ids + torch.tensor(b_position_delta, device=self.position_ids.device)
self.position_ids = position_ids.unsqueeze(0).expand(3, -1).clone()
self.position_ids[1:].zero_()

self.position_ids = self.position_ids.contiguous()
self.position_cos = model._cos_cached[self.position_ids[0]]
self.position_sin = model._sin_cached[self.position_ids[0]]
self.position_cos_h = model._hw_cos_cached[self.position_ids[1]]
self.position_sin_h = model._hw_sin_cached[self.position_ids[1]]
self.position_cos_w = model._hw_cos_cached[self.position_ids[2]]
self.position_sin_w = model._hw_sin_cached[self.position_ids[2]]
return

def get_neo_position(self, multimodal_params: List[dict]) -> torch.Tensor:
if len(multimodal_params) == 0:
position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0)))
position_ids[0].copy_(self.position_ids)
return position_ids
b_image_start_idx = []
b_image_nums = []
b_image_start_num = []
b_image_len = []
image_start_num = 0
b_image_thwd = []

# pad multimodal_params to batch size.
batch_size = self.b_q_seq_len.shape[0]
multimodal_params = multimodal_params + [
{"images": [], "audios": []} for _ in range(batch_size - len(multimodal_params))
]

for _, p in enumerate(multimodal_params):
images = p.get("images", [])
for img in images:
b_image_start_idx.append(img["start_idx"])
a = img["start_idx"]
print(f"img start_idx: {a}")
Comment on lines +66 to +67
Contributor (severity: medium):
These lines, including a print statement, appear to be for debugging and should be removed.

b_image_len.append(img["token_num"])
b_image_thwd.append(img["grid_thwd"])
b_image_nums.append(len(images))
b_image_start_num.append(image_start_num)
image_start_num += len(images)

# no images in this batch
if image_start_num == 0:
position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0)))
position_ids[0].copy_(self.position_ids)
return position_ids.contiguous()
b_image_start_idx = torch.tensor(b_image_start_idx, device="cpu").cuda(non_blocking=True)
b_image_thwd = torch.tensor(b_image_thwd, device="cpu").cuda(non_blocking=True) # image_num x 4
b_image_nums = torch.tensor(b_image_nums, device="cpu").cuda(non_blocking=True)
b_image_start_num = torch.tensor(b_image_start_num, device="cpu").cuda(non_blocking=True)
b_image_len = torch.tensor(b_image_len, device="cpu").cuda(non_blocking=True)

position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0)))
position_ids[0].copy_(self.position_ids)

get_neo_position_triton(
b_image_start_idx=b_image_start_idx,
b_image_thwd=b_image_thwd,
b_image_nums=b_image_nums,
b_image_start_num=b_image_start_num,
b_image_len=b_image_len,
position_ids=position_ids,
b_ready_cache_len=self.b_ready_cache_len,
b_q_seq_len=self.b_q_seq_len,
b_start_loc=self.b_start_loc,
)
return position_ids
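
For orientation, a small hypothetical sketch of the 3-row layout this method returns. The temporal row reuses the example from the inline comment in _context_attention_kernel below; the real h/w values come from get_neo_position_triton and are not reproduced here:

import torch

position_ids = torch.zeros((3, 8), dtype=torch.long)      # (3, total_tokens)
position_ids[0] = torch.tensor([0, 0, 1, 2, 3, 3, 3, 4])  # temporal row; repeated values mark tokens
#                                                           that share a position (e.g. one image)
# rows 1 and 2 hold per-token h / w indices for image tokens; at decode time
# init_some_extra_state zeroes them before indexing the hw cos/sin caches.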
Empty file.
159 changes: 159 additions & 0 deletions lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,159 @@
import torch
from functools import partial
from typing import Tuple
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
from lightllm.models.neo_chat.infer_struct import NeoChatInferStateInfo
from lightllm.models.neo_chat.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo
from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd
from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd
from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer
from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight
from lightllm.distributed import all_reduce
import torch.distributed as dist
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward


class NeoChatMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer):
def __init__(self, data_type, network_config, mode):
super().__init__(data_type, network_config, mode)
return

def _bind_attention(self):
self._context_attention_kernel = self._context_attention_kernel
self._token_attention_kernel = self._token_decode_attention_normal
self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_normal
return

def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatMOETransformerLayerWeight):
input = input.view(-1, self.embed_dim_)
q = layer_weight.q_proj.mm(input) # [T, Hq*D]

q_hw = layer_weight.q_hw_proj.mm(input)
q_hw = q_hw.view(-1, self.tp_q_head_num_, self.head_dim_)
q_h, q_w = q_hw.chunk(2, dim=-1)

k_hw = layer_weight.k_hw_proj.mm(input)
k_hw = k_hw.view(-1, self.tp_k_head_num_, self.head_dim_)
k_h, k_w = k_hw.chunk(2, dim=-1)

cache_kv = layer_weight.kv_proj.mm(input) # [T, (Hk+Hv)*D]

qk_rmsnorm_forward(q, weight=layer_weight.q_norm_weight_.weight, eps=self.eps_)

q_h_2d = q_h.reshape(q.shape[0], -1)
q_w_2d = q_w.reshape(q.shape[0], -1)
qk_rmsnorm_forward(q_h_2d, weight=layer_weight.q_norm_h_weight_.weight, eps=self.eps_)
qk_rmsnorm_forward(q_w_2d, weight=layer_weight.q_norm_w_weight_.weight, eps=self.eps_)
q_h = q_h_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2)
q_w = q_w_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2)

qk_rmsnorm_forward(
cache_kv[:, : self.tp_k_head_num_ * self.head_dim_],
weight=layer_weight.k_norm_weight_.weight,
eps=self.eps_,
)

k_h_2d = k_h.reshape(q.shape[0], -1) # [T, Hk*(D/2)]
k_w_2d = k_w.reshape(q.shape[0], -1)
qk_rmsnorm_forward(k_h_2d, weight=layer_weight.k_norm_h_weight_.weight, eps=self.eps_)
qk_rmsnorm_forward(k_w_2d, weight=layer_weight.k_norm_w_weight_.weight, eps=self.eps_)
k_h = k_h_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2)
k_w = k_w_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2)

cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)

rotary_emb_fwd(
q.view(-1, self.tp_q_head_num_, self.head_dim_),
cache_kv[:, : self.tp_k_head_num_, :],
infer_state.position_cos,
infer_state.position_sin,
)
rotary_emb_fwd(
q_h,
k_h,
infer_state.position_cos_h,
infer_state.position_sin_h,
)
rotary_emb_fwd(
q_w,
k_w,
infer_state.position_cos_w,
infer_state.position_sin_w,
)

q3 = q.view(-1, self.tp_q_head_num_, self.head_dim_)
q3 = torch.cat([q3, q_h, q_w], dim=-1)
q = q3.reshape(q3.shape[0], -1)

k = cache_kv[:, : self.tp_k_head_num_, :]
k = torch.cat([k, k_h, k_w], dim=-1)

v = cache_kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :]
v_pad = torch.zeros((v.shape[0], v.shape[1], self.head_dim_), device=v.device, dtype=v.dtype)
v = torch.cat([v, v_pad], dim=-1)

cache_kv = torch.cat([k, v], dim=1)
return q, cache_kv

def _context_attention_kernel(
self, q, kv, infer_state: NeoChatInferStateInfo, layer_weight, out=None
) -> torch.Tensor:
o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
context_attention_fwd_neo(
q.view(-1, self.tp_q_head_num_, self.head_dim_ * 2),
kv[:, 0 : self.tp_k_head_num_, :],
kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2),
infer_state.position_ids[0], # [0,0,1,2,3,3,3,4]
infer_state.b_req_idx,
infer_state.b_start_loc,
infer_state.b_seq_len,
infer_state.b_ready_cache_len,
infer_state.max_len_in_batch,
infer_state.req_manager.req_to_token_indexs,
)
o3 = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2)
o3 = o3[:, :, : self.head_dim_].contiguous()
return o3.view(o3.shape[0], -1)

def _token_decode_attention_normal(self, q, infer_state: NeoChatInferStateInfo, layer_weight, out=None):
total_token_num = infer_state.total_token_num
batch_size = infer_state.batch_size

q_3d = q.view(batch_size, self.tp_q_head_num_, self.head_dim_ * 2)

att_m_tensor = self.alloc_tensor((self.tp_q_head_num_, total_token_num), torch.float32)

k_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :]
token_att_fwd(
q_3d,
k_3d,
att_m_tensor,
infer_state.req_manager.req_to_token_indexs,
infer_state.b_req_idx,
infer_state.b_start_loc,
infer_state.b_seq_len,
infer_state.max_len_in_batch,
)

from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import token_softmax_reducev_fwd

v_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][
:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : self.head_dim_
]

o_3d = self.alloc_tensor((batch_size, self.tp_q_head_num_, self.head_dim_), q.dtype) if out is None else out

token_softmax_reducev_fwd(
att_m_tensor,
v_3d,
o_3d,
infer_state.req_manager.req_to_token_indexs,
infer_state.b_req_idx,
infer_state.b_start_loc,
infer_state.b_seq_len,
)
return o_3d.view(batch_size, -1)
Empty file.
@@ -0,0 +1,23 @@
import torch
import numpy as np
from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight

# rename keys: language_model.xxx -> xxx
# only rename keys when loading PreAndPostLayerWeight; TransformerLayerWeight keys are already correct
def rename_weight_keys(weights):
prefix = "language_model."
keys = list(weights.keys())
for k in keys:
if prefix in k:
weights[k.replace(prefix, "")] = weights.pop(k)


class NeoChatMOEPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight):
def __init__(self, data_type, network_config, mode):
super().__init__(data_type, network_config, mode)
return

def load_hf_weights(self, weights):
rename_weight_keys(weights)
super().load_hf_weights(weights)
return
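
A tiny usage sketch of rename_weight_keys; the key names below are hypothetical and chosen only to show the in-place renaming:

import torch

weights = {
    "language_model.model.embed_tokens.weight": torch.zeros(1),
    "visual.patch_embed.weight": torch.zeros(1),
}
rename_weight_keys(weights)
# keys are now {"model.embed_tokens.weight", "visual.patch_embed.weight"}:
# the "language_model." prefix is stripped in place and other keys are left untouched.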