-
Notifications
You must be signed in to change notification settings - Fork 291
Add neo chat #1161
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add neo chat #1161
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -110,6 +110,8 @@ def _init_custom(self): | |||||||||||||||||||||||||||||||||||||
| rope_scaling = self.config.get("rope_scaling", None) | ||||||||||||||||||||||||||||||||||||||
| if rope_scaling is None: | ||||||||||||||||||||||||||||||||||||||
| self._init_to_get_rotary() | ||||||||||||||||||||||||||||||||||||||
| if "rope_theta_hw" in self.config: | ||||||||||||||||||||||||||||||||||||||
| self._init_to_get_hw_rotary() | ||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if "rope_type" in rope_scaling: | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -132,6 +134,8 @@ def _init_custom(self): | |||||||||||||||||||||||||||||||||||||
| self._init_to_get_rotary() | ||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||
| raise ValueError(f"Unknown RoPE scaling type {scaling_type}") | ||||||||||||||||||||||||||||||||||||||
| if "rope_theta_hw" in self.config: | ||||||||||||||||||||||||||||||||||||||
| self._init_to_get_hw_rotary() | ||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def _init_weights(self): | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -178,7 +182,7 @@ def _init_to_get_rotary(self, default_base=10000): | |||||||||||||||||||||||||||||||||||||
| rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| base = self.config.get("rope_theta", float(default_base)) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| print(f"base is {base}") | ||||||||||||||||||||||||||||||||||||||
| if "max_sequence_length" in self.config: | ||||||||||||||||||||||||||||||||||||||
| max_seq_len = self.config["max_sequence_length"] | ||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -211,6 +215,47 @@ def _init_to_get_rotary(self, default_base=10000): | |||||||||||||||||||||||||||||||||||||
| self._sin_cached = torch.sin(freqs).to(self.data_type).cuda() | ||||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
def _init_to_get_hw_rotary(self, default_base=10000):
    """Precompute the rotary cos/sin tables built from ``rope_theta_hw``.

    The tables are stored on ``self._hw_cos_cached`` and
    ``self._hw_sin_cached`` (cast to ``self.data_type`` and moved to CUDA).

    Args:
        default_base: RoPE base used when the config has no
            ``rope_theta_hw`` entry.

    Environment:
        LIGHTLLM_NTK_ALPHA: optional NTK scaling factor. Values that are
            unparsable or below 1 are ignored with a warning instead of
            being silently swallowed.
    """
    # Half of the (partial) head dim: the table is built for a half-width rotary.
    partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_ // 2)

    # "rope_scaling" may be missing or explicitly None; both mean "no scaling".
    rope_scaling = self.config.get("rope_scaling", {})
    rope_scaling_factor = 1.0 if rope_scaling is None else rope_scaling.get("factor", 1.0)

    base = self.config.get("rope_theta_hw", float(default_base))
    logger.info(f"hw_base is {base}")

    if "max_sequence_length" in self.config:
        max_seq_len = self.config["max_sequence_length"]
    else:
        max_position_embeddings = self.config.get(
            "max_position_embeddings_hw", 2048 if base <= 10000.0 + 1e-5 else 16384
        )
        max_seq_len = max_position_embeddings * rope_scaling_factor

    # NTK: this stays best-effort (a bad setting never aborts start-up), but
    # every ignored setting is now reported instead of hidden by a bare except.
    raw_alpha = os.environ.get("LIGHTLLM_NTK_ALPHA", 1)
    try:
        ntk_alpha = float(raw_alpha)
    except (TypeError, ValueError):
        logger.warning(f"Ignoring invalid LIGHTLLM_NTK_ALPHA={raw_alpha!r}: not a number")
        ntk_alpha = 1.0
    if ntk_alpha < 1:
        logger.warning(f"Ignoring LIGHTLLM_NTK_ALPHA={ntk_alpha}: must be >= 1")
        ntk_alpha = 1.0
    if ntk_alpha > 1:
        logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}")
        max_seq_len *= ntk_alpha
        try:
            # Base change formula
            base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2)))
        except (ZeroDivisionError, OverflowError) as err:
            # partial_head_dim == 2 makes the exponent undefined; keep the old base.
            logger.warning(f"NTK base change skipped for partial_head_dim={partial_head_dim}: {err}")

    inv_freq = 1.0 / (
        base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)
    )
    # Build extra headroom beyond max_seq_len, but never fewer rows than self.max_seq_length.
    t = (
        torch.arange(max(max_seq_len + 1024 * 128, self.max_seq_length), device="cpu", dtype=torch.float32)
        / rope_scaling_factor
    )
    freqs = torch.outer(t, inv_freq)

    self._hw_cos_cached = torch.cos(freqs).to(self.data_type).cuda()
    self._hw_sin_cached = torch.sin(freqs).to(self.data_type).cuda()
    return
|
|
||||||||||||||||||||||||||||||||||||||
| def _init_to_get_dynamic_ntk_rotary(self): | ||||||||||||||||||||||||||||||||||||||
| partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_) | ||||||||||||||||||||||||||||||||||||||
| max_position_embeddings = self.config.get("max_position_embeddings", 2048) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,99 @@ | ||
| from typing import Optional, List | ||
| import torch | ||
| import numpy as np | ||
| from lightllm.models.llama.infer_struct import LlamaInferStateInfo | ||
| from lightllm.common.req_manager import ReqManager | ||
| from lightllm.models.neo_chat.triton_kernel.get_neo_position import get_neo_position_triton | ||
| from lightllm.models.llama.model import LlamaTpPartModel | ||
|
|
||
|
|
||
class NeoChatInferStateInfo(LlamaInferStateInfo):
    """Llama inference state extended with 3-row position ids.

    Row 0 of ``position_ids`` indexes the model's ``_cos_cached`` /
    ``_sin_cached`` tables; rows 1 and 2 index the ``_hw_cos_cached`` /
    ``_hw_sin_cached`` tables (stored here as the ``*_h`` and ``*_w`` lookups).
    """

    def __init__(self):
        super().__init__()
        self.position_cos = None
        self.position_sin = None
        self.position_cos_h = None
        self.position_sin_h = None
        self.position_cos_w = None
        self.position_sin_w = None

    def init_some_extra_state(self, model: LlamaTpPartModel, input_ids: torch.Tensor):
        """Run the Llama state init, then build position ids and rotary lookups.

        Args:
            model: model providing the ``_cos_cached`` / ``_sin_cached`` and
                ``_hw_cos_cached`` / ``_hw_sin_cached`` tables.
            input_ids: forwarded unchanged to the parent implementation.
        """
        LlamaInferStateInfo.init_some_extra_state(self, model, input_ids)
        if self.is_prefill:
            self.position_ids = self.get_neo_position(self.multimodal_params)
        else:
            # Decode: shift each request's position by the summed 4th entry of
            # every image's "grid_thwd".
            # NOTE(review): presumably that entry is the per-image position delta - confirm.
            b_position_delta = [0] * self.b_seq_len.shape[0]
            for batch_idx, p in enumerate(self.multimodal_params):
                b_position_delta[batch_idx] = sum(image["grid_thwd"][3] for image in p["images"])
            position_ids = self.position_ids + torch.tensor(b_position_delta, device=self.position_ids.device)
            self.position_ids = position_ids.unsqueeze(0).expand(3, -1).clone()
            # Rows 1 and 2 are zeroed for decode tokens.
            self.position_ids[1:].zero_()

        self.position_ids = self.position_ids.contiguous()
        self.position_cos = model._cos_cached[self.position_ids[0]]
        self.position_sin = model._sin_cached[self.position_ids[0]]
        self.position_cos_h = model._hw_cos_cached[self.position_ids[1]]
        self.position_sin_h = model._hw_sin_cached[self.position_ids[1]]
        self.position_cos_w = model._hw_cos_cached[self.position_ids[2]]
        self.position_sin_w = model._hw_sin_cached[self.position_ids[2]]
        return

    def _text_only_position_ids(self) -> torch.Tensor:
        """Return a (3, N) tensor whose row 0 copies ``self.position_ids`` and whose other rows are zero."""
        position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0)))
        position_ids[0].copy_(self.position_ids)
        return position_ids

    def get_neo_position(self, multimodal_params: List[dict]) -> torch.Tensor:
        """Build the (3, N) position ids for a prefill batch.

        Args:
            multimodal_params: one dict per request; each may carry an
                "images" list whose items provide "start_idx", "token_num"
                and "grid_thwd".

        Returns:
            A tensor of shape ``(3, self.position_ids.size(0))``. Row 0 starts
            as a copy of ``self.position_ids``; when the batch has images the
            tensor is then filled in by ``get_neo_position_triton``.
        """
        if not multimodal_params:
            return self._text_only_position_ids()

        b_image_start_idx = []
        b_image_nums = []
        b_image_start_num = []
        b_image_len = []
        image_start_num = 0
        b_image_thwd = []

        # pad multimodal_params to batch size.
        batch_size = self.b_q_seq_len.shape[0]
        multimodal_params = multimodal_params + [
            {"images": [], "audios": []} for _ in range(batch_size - len(multimodal_params))
        ]

        for p in multimodal_params:
            images = p.get("images", [])
            for img in images:
                b_image_start_idx.append(img["start_idx"])
                b_image_len.append(img["token_num"])
                b_image_thwd.append(img["grid_thwd"])
            b_image_nums.append(len(images))
            b_image_start_num.append(image_start_num)
            image_start_num += len(images)

        # No images in the whole batch.
        if image_start_num == 0:
            return self._text_only_position_ids().contiguous()

        b_image_start_idx = torch.tensor(b_image_start_idx, device="cpu").cuda(non_blocking=True)
        b_image_thwd = torch.tensor(b_image_thwd, device="cpu").cuda(non_blocking=True)  # image_num x 4
        b_image_nums = torch.tensor(b_image_nums, device="cpu").cuda(non_blocking=True)
        b_image_start_num = torch.tensor(b_image_start_num, device="cpu").cuda(non_blocking=True)
        b_image_len = torch.tensor(b_image_len, device="cpu").cuda(non_blocking=True)

        position_ids = self._text_only_position_ids()

        get_neo_position_triton(
            b_image_start_idx=b_image_start_idx,
            b_image_thwd=b_image_thwd,
            b_image_nums=b_image_nums,
            b_image_start_num=b_image_start_num,
            b_image_len=b_image_len,
            position_ids=position_ids,
            b_ready_cache_len=self.b_ready_cache_len,
            b_q_seq_len=self.b_q_seq_len,
            b_start_loc=self.b_start_loc,
        )
        return position_ids
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,159 @@ | ||
| import torch | ||
| from functools import partial | ||
| from typing import Tuple | ||
| from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward | ||
| from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd | ||
| from lightllm.models.neo_chat.infer_struct import NeoChatInferStateInfo | ||
| from lightllm.models.neo_chat.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo | ||
| from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd | ||
| from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd | ||
| from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer | ||
| from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight | ||
| from lightllm.distributed import all_reduce | ||
| import torch.distributed as dist | ||
| from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer | ||
| from lightllm.models.qwen3.triton_kernel.qk_norm import qk_rmsnorm_forward | ||
|
|
||
|
|
||
class NeoChatMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer):
    """Qwen3-MoE transformer layer whose q/k carry two extra rotary halves.

    ``_get_qkv`` concatenates the regular q/k with separately projected,
    normalized and rotated "h" and "w" halves, so q and k end up with a last
    dimension of ``2 * head_dim_``; v is zero-padded to the same width.
    """

    def __init__(self, data_type, network_config, mode):
        super().__init__(data_type, network_config, mode)
        return

    def _bind_attention(self):
        """Select the attention / kv-copy callables used by this layer."""
        # NOTE(review): this assigns the attribute to itself; it only pins the
        # bound method on the instance and looks like a leftover - confirm.
        self._context_attention_kernel = self._context_attention_kernel
        self._token_attention_kernel = self._token_decode_attention_normal
        self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_normal
        return

    def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatMOETransformerLayerWeight):
        """Project, normalize and rotate q/k/v for one layer.

        Returns:
            ``(q, cache_kv)`` where ``q`` is 2-D (tokens x flattened heads) and
            ``cache_kv`` stacks k then v along the head axis, each head having
            a last dimension of ``2 * head_dim_``.
        """
        input = input.view(-1, self.embed_dim_)
        q = layer_weight.q_proj.mm(input)  # [T, Hq*D]

        # Extra "hw" projection for q, split on the last dim into h and w halves.
        q_hw = layer_weight.q_hw_proj.mm(input)
        q_hw = q_hw.view(-1, self.tp_q_head_num_, self.head_dim_)
        q_h, q_w = q_hw.chunk(2, dim=-1)

        # Same split for k.
        k_hw = layer_weight.k_hw_proj.mm(input)
        k_hw = k_hw.view(-1, self.tp_k_head_num_, self.head_dim_)
        k_h, k_w = k_hw.chunk(2, dim=-1)

        cache_kv = layer_weight.kv_proj.mm(input)  # [T, (Hk+Hv)*D]

        # The return value of qk_rmsnorm_forward is not used anywhere in this
        # method - presumably it normalizes its first argument in place; confirm.
        qk_rmsnorm_forward(q, weight=layer_weight.q_norm_weight_.weight, eps=self.eps_)

        # NOTE(review): q_h / q_w are chunk() slices of q_hw, so reshape() here
        # likely produces copies; the normalized values are carried forward
        # through q_h_2d / q_w_2d, not through q_hw.
        q_h_2d = q_h.reshape(q.shape[0], -1)
        q_w_2d = q_w.reshape(q.shape[0], -1)
        qk_rmsnorm_forward(q_h_2d, weight=layer_weight.q_norm_h_weight_.weight, eps=self.eps_)
        qk_rmsnorm_forward(q_w_2d, weight=layer_weight.q_norm_w_weight_.weight, eps=self.eps_)
        q_h = q_h_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2)
        q_w = q_w_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2)

        # Normalize only the k columns of the fused kv projection.
        qk_rmsnorm_forward(
            cache_kv[:, : self.tp_k_head_num_ * self.head_dim_],
            weight=layer_weight.k_norm_weight_.weight,
            eps=self.eps_,
        )

        k_h_2d = k_h.reshape(q.shape[0], -1)  # [T, Hk*(D/2)]
        k_w_2d = k_w.reshape(q.shape[0], -1)
        qk_rmsnorm_forward(k_h_2d, weight=layer_weight.k_norm_h_weight_.weight, eps=self.eps_)
        qk_rmsnorm_forward(k_w_2d, weight=layer_weight.k_norm_w_weight_.weight, eps=self.eps_)
        k_h = k_h_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2)
        k_w = k_w_2d.view(q.shape[0], self.tp_k_head_num_, self.head_dim_ // 2)

        cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)

        # Three rotary passes, each with its own cos/sin lookup from infer_state:
        # the regular q/k, then the h halves, then the w halves.
        rotary_emb_fwd(
            q.view(-1, self.tp_q_head_num_, self.head_dim_),
            cache_kv[:, : self.tp_k_head_num_, :],
            infer_state.position_cos,
            infer_state.position_sin,
        )
        rotary_emb_fwd(
            q_h,
            k_h,
            infer_state.position_cos_h,
            infer_state.position_sin_h,
        )
        rotary_emb_fwd(
            q_w,
            k_w,
            infer_state.position_cos_w,
            infer_state.position_sin_w,
        )

        # q per head becomes [regular | h | w], i.e. last dim 2 * head_dim_.
        q3 = q.view(-1, self.tp_q_head_num_, self.head_dim_)
        q3 = torch.cat([q3, q_h, q_w], dim=-1)
        q = q3.reshape(q3.shape[0], -1)

        # k gets the same [regular | h | w] layout.
        k = cache_kv[:, : self.tp_k_head_num_, :]
        k = torch.cat([k, k_h, k_w], dim=-1)

        # v has no h/w part; pad with zeros so k and v can share one tensor.
        v = cache_kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :]
        v_pad = torch.zeros((v.shape[0], v.shape[1], self.head_dim_), device=v.device, dtype=v.dtype)
        v = torch.cat([v, v_pad], dim=-1)

        cache_kv = torch.cat([k, v], dim=1)
        return q, cache_kv

    def _context_attention_kernel(
        self, q, kv, infer_state: NeoChatInferStateInfo, layer_weight, out=None
    ) -> torch.Tensor:
        """Run prefill attention and return only the first ``head_dim_`` outputs per head."""
        o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
        # The ``kv`` argument is discarded; k/v are read from this layer's kv buffer.
        kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
        context_attention_fwd_neo(
            q.view(-1, self.tp_q_head_num_, self.head_dim_ * 2),
            kv[:, 0 : self.tp_k_head_num_, :],
            kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
            o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2),
            infer_state.position_ids[0],  # [0,0,1,2,3,3,3,4]
            infer_state.b_req_idx,
            infer_state.b_start_loc,
            infer_state.b_seq_len,
            infer_state.b_ready_cache_len,
            infer_state.max_len_in_batch,
            infer_state.req_manager.req_to_token_indexs,
        )
        # Keep the first head_dim_ columns of each head (the half of v that is
        # not zero padding) and flatten back to 2-D.
        o3 = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_ * 2)
        o3 = o3[:, :, : self.head_dim_].contiguous()
        return o3.view(o3.shape[0], -1)

    def _token_decode_attention_normal(self, q, infer_state: NeoChatInferStateInfo, layer_weight, out=None):
        """Run decode attention in two steps: q.k scores, then softmax and reduce over v."""
        total_token_num = infer_state.total_token_num
        batch_size = infer_state.batch_size

        q_3d = q.view(batch_size, self.tp_q_head_num_, self.head_dim_ * 2)

        # Score buffer, one row per q head.
        att_m_tensor = self.alloc_tensor((self.tp_q_head_num_, total_token_num), torch.float32)

        k_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :]
        token_att_fwd(
            q_3d,
            k_3d,
            att_m_tensor,
            infer_state.req_manager.req_to_token_indexs,
            infer_state.b_req_idx,
            infer_state.b_start_loc,
            infer_state.b_seq_len,
            infer_state.max_len_in_batch,
        )

        # NOTE(review): function-scope import; consider moving it to the module imports.
        from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import token_softmax_reducev_fwd

        # Only the first head_dim_ columns of v are read (the rest was zero-padded in _get_qkv).
        v_3d = infer_state.mem_manager.kv_buffer[self.layer_num_][
            :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : self.head_dim_
        ]

        o_3d = self.alloc_tensor((batch_size, self.tp_q_head_num_, self.head_dim_), q.dtype) if out is None else out

        token_softmax_reducev_fwd(
            att_m_tensor,
            v_3d,
            o_3d,
            infer_state.req_manager.req_to_token_indexs,
            infer_state.b_req_idx,
            infer_state.b_start_loc,
            infer_state.b_seq_len,
        )
        return o_3d.view(batch_size, -1)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| import torch | ||
| import numpy as np | ||
| from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight | ||
|
|
||
| # add key: language_model.xxx -> xxx | ||
| # only change keys at PreAndPostLayerWeight load, TransformLayerWeight is correct now | ||
def rename_weight_keys(weights):
    """Strip the ``language_model.`` marker from weight names, in place.

    Every key containing the marker (anywhere in the key) is re-inserted
    under the name with all occurrences of the marker removed; other entries
    are left untouched. Returns ``None``.
    """
    marker = "language_model."
    # Snapshot the matching names first: the dict is mutated while renaming.
    matching = [name for name in weights if marker in name]
    for name in matching:
        weights[name.replace(marker, "")] = weights.pop(name)
|
|
||
|
|
||
class NeoChatMOEPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight):
    """Qwen2 pre/post layer weights that rename checkpoint keys before loading."""

    def __init__(self, data_type, network_config, mode):
        """Forward all constructor arguments to the Qwen2 base class."""
        super().__init__(data_type, network_config, mode)

    def load_hf_weights(self, weights):
        """Rename keys in ``weights`` via ``rename_weight_keys``, then defer to the base loader."""
        rename_weight_keys(weights)
        super().load_hf_weights(weights)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This `print` statement appears to be for debugging purposes and should be removed before merging.