Dissecting Kronos

Model structure

Printing the model gives the following module tree:

```
Kronos(
  (token_drop): Dropout(p=0.0, inplace=False)
  (embedding): HierarchicalEmbedding(
    (emb_s1): Embedding(1024, 832)
    (emb_s2): Embedding(1024, 832)
    (fusion_proj): Linear(in_features=1664, out_features=832, bias=True)
  )
  (time_emb): TemporalEmbedding(
    (minute_embed): Embedding(60, 832)
    (hour_embed): Embedding(24, 832)
    (weekday_embed): Embedding(7, 832)
    (day_embed): Embedding(32, 832)
    (month_embed): Embedding(13, 832)
  )
  (transformer): ModuleList(
    (0-11): 12 x TransformerBlock(
      (norm1): RMSNorm()
      (self_attn): MultiHeadAttentionWithRoPE(
        (q_proj): Linear(in_features=832, out_features=832, bias=True)
        (k_proj): Linear(in_features=832, out_features=832, bias=True)
        (v_proj): Linear(in_features=832, out_features=832, bias=True)
        (out_proj): Linear(in_features=832, out_features=832, bias=True)
        (rotary): RotaryPositionalEmbedding()
        (resid_dropout): Dropout(p=0.2, inplace=False)
      )
      (norm2): RMSNorm()
      (ffn): FeedForward(
        (w1): Linear(in_features=832, out_features=2048, bias=False)
        (w3): Linear(in_features=832, out_features=2048, bias=False)
        (w2): Linear(in_features=2048, out_features=832, bias=False)
        (ffn_dropout): Dropout(p=0.2, inplace=False)
      )
    )
  )
  (norm): RMSNorm()
  (dep_layer): DependencyAwareLayer(
    (cross_attn): MultiHeadCrossAttentionWithRoPE(
      (q_proj): Linear(in_features=832, out_features=832, bias=True)
      (k_proj): Linear(in_features=832, out_features=832, bias=True)
      (v_proj): Linear(in_features=832, out_features=832, bias=True)
      (out_proj): Linear(in_features=832, out_features=832, bias=True)
      (rotary): RotaryPositionalEmbedding()
      (resid_dropout): Dropout(p=0.0, inplace=False)
    )
    (norm): RMSNorm()
  )
  (head): DualHead(
    (proj_s1): Linear(in_features=832, out_features=1024, bias=True)
    (proj_s2): Linear(in_features=832, out_features=1024, bias=True)
  )
)
```
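For reference, a minimal sketch of how such a dump is produced, assuming the public Kronos repo layout (`from model import Kronos`) and checkpoint name, neither of which is verified here:

```python
# Minimal sketch -- the import path and checkpoint name are assumptions
# taken from the public Kronos project, not verified in this post.
from model import Kronos

model = Kronos.from_pretrained("NeoQuasar/Kronos-base")
print(model)  # prints the module tree above
```

Note that printing `model.parameters` without the call parentheses would wrap this same tree in `<bound method Module.parameters of ...>`, since it prints the bound method itself rather than iterating the parameters.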

```python
def forward(self, s1_ids, s2_ids, stamp=None, padding_mask=None,
            use_teacher_forcing=False, s1_targets=None):
```

After tokenization, s1_ids and s2_ids each have shape [1, 400].

```python
x = self.embedding([s1_ids, s2_ids])
```

HierarchicalEmbedding

The HierarchicalEmbedding docstring says: "token_ids (torch.Tensor): Composite token IDs of shape [batch_size, seq_len] or [N], each in range [0, 2^(s1_bits + s2_bits) - 1]." Where does 2^(s1_bits + s2_bits) - 1 come from? Let me trace it back to the tokenizer.
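Before looking at the quantizer, here is a purely hypothetical sketch of how one composite ID could split into an (s1, s2) pair. The only grounded fact is that 2^10 = 1024 matches both embedding tables; the bit layout itself is an assumption:

```python
# Hypothetical split -- assumes the s1 code occupies the low bits, which is
# consistent with bits_to_indices weighting the first bit by 2**0 (see below).
s1_bits = s2_bits = 10                     # 2**10 = 1024, matching Embedding(1024, 832)
composite = 0b1010111100_0110010111        # some 20-bit composite token ID
s1_id = composite & ((1 << s1_bits) - 1)   # low 10 bits
s2_id = composite >> s1_bits               # high 10 bits
assert 0 <= composite <= 2 ** (s1_bits + s2_bits) - 1
print(s1_id, s2_id)
```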
BSQuantizer
```python
def bits_to_indices(self, bits):
    bits = (bits >= 0).to(torch.long)
    indices = 2 ** torch.arange(
        0, bits.shape[-1], 1, dtype=torch.long, device=bits.device,
    )
    return (bits * indices).sum(-1)
```

So bits_to_indices(bits) ∈ [0, 2^N − 1] with N = bits.shape[-1]: each sign maps to a bit in {0, 1}, and the bit vector is read as a little-endian binary number. That is exactly where the [0, 2^(s1_bits + s2_bits) − 1] range of the composite IDs comes from.
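A quick standalone check with the method body inlined and toy values:

```python
import torch

# Signs >= 0 map to 1, negatives to 0; the N bits are then read as a
# little-endian binary number, so the result lies in [0, 2**N - 1].
bits = torch.tensor([[ 0.3, -1.2,  0.7],    # -> 1,0,1 -> 1*1 + 0*2 + 1*4 = 5
                     [-0.5, -0.1, -2.0]])   # -> 0,0,0 -> 0
hard = (bits >= 0).to(torch.long)
weights = 2 ** torch.arange(bits.shape[-1], dtype=torch.long)
print((hard * weights).sum(-1))  # tensor([5, 0])
```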

Back in HierarchicalEmbedding, the forward scales each embedding by √d_model and fuses the two streams with a concat plus linear projection:

```python
s1_emb = self.emb_s1(s1_ids) * math.sqrt(self.d_model)
s2_emb = self.emb_s2(s2_ids) * math.sqrt(self.d_model)
return self.fusion_proj(torch.cat([s1_emb, s2_emb], dim=-1))
```
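A self-contained re-implementation sketch, with all sizes inferred from the module printout (not the official code), plus a shape check matching the [1, 400] inputs:

```python
import math
import torch
import torch.nn as nn

class HierarchicalEmbeddingSketch(nn.Module):
    """Sketch of HierarchicalEmbedding; sizes taken from the printout above."""
    def __init__(self, s1_vocab=1024, s2_vocab=1024, d_model=832):
        super().__init__()
        self.d_model = d_model
        self.emb_s1 = nn.Embedding(s1_vocab, d_model)
        self.emb_s2 = nn.Embedding(s2_vocab, d_model)
        self.fusion_proj = nn.Linear(2 * d_model, d_model)  # 1664 -> 832

    def forward(self, s1_ids, s2_ids):
        s1 = self.emb_s1(s1_ids) * math.sqrt(self.d_model)  # Transformer-style scaling
        s2 = self.emb_s2(s2_ids) * math.sqrt(self.d_model)
        return self.fusion_proj(torch.cat([s1, s2], dim=-1))

emb = HierarchicalEmbeddingSketch()
s1_ids = torch.randint(0, 1024, (1, 400))
s2_ids = torch.randint(0, 1024, (1, 400))
print(emb(s1_ids, s2_ids).shape)  # torch.Size([1, 400, 832])
```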
Back in forward: if a timestamp is provided, a temporal embedding is added on top.

```python
if stamp is not None:
    time_embedding = self.time_emb(stamp)
    x = x + time_embedding
```
TemporalEmbedding
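A sketch consistent with the printed sub-embeddings; the field order inside `stamp` and the plain summation are assumptions:

```python
import torch
import torch.nn as nn

class TemporalEmbeddingSketch(nn.Module):
    """Sketch of TemporalEmbedding; table sizes from the printout, the
    (minute, hour, weekday, day, month) field order is an assumption."""
    def __init__(self, d_model=832):
        super().__init__()
        self.minute_embed  = nn.Embedding(60, d_model)
        self.hour_embed    = nn.Embedding(24, d_model)
        self.weekday_embed = nn.Embedding(7, d_model)
        self.day_embed     = nn.Embedding(32, d_model)   # days 1-31 plus a spare slot
        self.month_embed   = nn.Embedding(13, d_model)   # months 1-12 plus a spare slot

    def forward(self, stamp):  # stamp: [B, T, 5] long tensor
        return (self.minute_embed(stamp[..., 0]) + self.hour_embed(stamp[..., 1])
                + self.weekday_embed(stamp[..., 2]) + self.day_embed(stamp[..., 3])
                + self.month_embed(stamp[..., 4]))
```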
```python
x = self.token_drop(x)
for layer in self.transformer:
    x = layer(x, key_padding_mask=padding_mask)
x = self.norm(x)
```
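Inside each TransformerBlock, the w1/w3/w2 naming with bias=False in the printout matches the LLaMA-style SwiGLU feed-forward, so a plausible reconstruction (an inference, not taken from the Kronos source) is:

```python
import torch.nn as nn
import torch.nn.functional as F

class FeedForwardSketch(nn.Module):
    """Plausible SwiGLU FFN: w2(silu(w1(x)) * w3(x)) -- inferred, not confirmed."""
    def __init__(self, d_model=832, hidden=2048, dropout=0.2):
        super().__init__()
        self.w1 = nn.Linear(d_model, hidden, bias=False)  # gate branch
        self.w3 = nn.Linear(d_model, hidden, bias=False)  # value branch
        self.w2 = nn.Linear(hidden, d_model, bias=False)  # down-projection
        self.ffn_dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
```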
```python
s1_logits = self.head(x)
```
DualHead
```python
if use_teacher_forcing:
    sibling_embed = self.embedding.emb_s1(s1_targets)
else:
    s1_probs = F.softmax(s1_logits.detach(), dim=-1)
    sample_s1_ids = torch.multinomial(
        s1_probs.view(-1, self.s1_vocab_size), 1
    ).view(s1_ids.shape)
    sibling_embed = self.embedding.emb_s1(sample_s1_ids)

# Dependency Aware Layer: condition on s1 embeddings
x2 = self.dep_layer(x, sibling_embed, key_padding_mask=padding_mask)
s2_logits = self.head.cond_forward(x2)
return s1_logits, s2_logits
```

The DependencyAwareLayer uses cross-attention to let the sibling (sub-token) representation attend to the main sequence's hidden states and absorb their dependency information, explicitly modeling the structural dependency between the two sub-token streams.
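A toy run of the sampling branch (B=1, T=4, V=1024 are arbitrary), just to make the reshaping concrete:

```python
import torch
import torch.nn.functional as F

B, T, V = 1, 4, 1024
s1_logits = torch.randn(B, T, V)
# Sample one s1 id per position from the detached distribution, as above.
s1_probs = F.softmax(s1_logits.detach(), dim=-1)
sample_s1_ids = torch.multinomial(s1_probs.view(-1, V), 1).view(B, T)
print(sample_s1_ids.shape)  # torch.Size([1, 4])
```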
Loss computation

```python
def compute_loss(self, s1_logits, s2_logits, s1_targets, s2_targets, padding_mask=None):
    if padding_mask is not None:
        valid_mask = (padding_mask == 0)
        s1_logits = s1_logits[valid_mask]
        s2_logits = s2_logits[valid_mask]
        s1_targets = s1_targets[valid_mask]
        s2_targets = s2_targets[valid_mask]
        ce_s1 = F.cross_entropy(s1_logits, s1_targets)
        ce_s2 = F.cross_entropy(s2_logits, s2_targets)
    else:
        ce_s1 = F.cross_entropy(s1_logits.reshape(-1, self.vocab_s1), s1_targets.reshape(-1))
        ce_s2 = F.cross_entropy(s2_logits.reshape(-1, self.vocab_s2), s2_targets.reshape(-1))
    ce_loss = (ce_s1 + ce_s2) / 2
    return ce_loss, ce_s1, ce_s2
```

The loss is simply the average of the two cross-entropies over the s1 and s2 streams.
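A toy check of the masked branch; here padding_mask == 1 marks padding, which is the convention implied by `valid_mask = (padding_mask == 0)`:

```python
import torch
import torch.nn.functional as F

B, T, V = 2, 5, 1024
s1_logits, s2_logits = torch.randn(B, T, V), torch.randn(B, T, V)
s1_targets = torch.randint(0, V, (B, T))
s2_targets = torch.randint(0, V, (B, T))
padding_mask = torch.zeros(B, T, dtype=torch.long)
padding_mask[:, -1] = 1                      # pretend the last position is padding

valid = padding_mask == 0                    # boolean mask, [B, T]
ce_s1 = F.cross_entropy(s1_logits[valid], s1_targets[valid])   # [8, V] vs [8]
ce_s2 = F.cross_entropy(s2_logits[valid], s2_targets[valid])
print(((ce_s1 + ce_s2) / 2).item())
```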
decode_s1
```python
def decode_s1(self, s1_ids, s2_ids, stamp=None, padding_mask=None):
    """
    Decodes only the s1 tokens.

    This method performs a forward pass to predict only s1 tokens. It returns
    the s1 logits and the context representation from the Transformer, which
    can be used for subsequent s2 decoding.

    Args:
        s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len]
        s2_ids (torch.Tensor): Input tensor of s2 token IDs. Shape: [batch_size, seq_len]
        stamp (torch.Tensor, optional): Temporal stamp tensor. Shape: [batch_size, seq_len]. Defaults to None.
        padding_mask (torch.Tensor, optional): Mask for padding tokens. Shape: [batch_size, seq_len]. Defaults to None.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
            - s1 logits: Logits for s1 token predictions. Shape: [batch_size, seq_len, s1_vocab_size]
            - context: Context representation from the Transformer. Shape: [batch_size, seq_len, d_model]
    """
    x = self.embedding([s1_ids, s2_ids])
    if stamp is not None:
        time_embedding = self.time_emb(stamp)
        x = x + time_embedding
    x = self.token_drop(x)
    for layer in self.transformer:
        x = layer(x, key_padding_mask=padding_mask)
    x = self.norm(x)
    s1_logits = self.head(x)
    return s1_logits, x
```
decode_s2
```python
def decode_s2(self, context, s1_ids, padding_mask=None):
    """
    Decodes the s2 tokens, conditioned on the context and s1 tokens.

    This method decodes s2 tokens based on a pre-computed context representation
    (typically from `decode_s1`) and the s1 token IDs. It uses the
    dependency-aware layer and the conditional s2 head to predict s2 tokens.

    Args:
        context (torch.Tensor): Context representation from the transformer
            (output of decode_s1). Shape: [batch_size, seq_len, d_model]
        s1_ids (torch.Tensor): Input tensor of s1 token IDs. Shape: [batch_size, seq_len]
        padding_mask (torch.Tensor, optional): Mask for padding tokens.
            Shape: [batch_size, seq_len]. Defaults to None.

    Returns:
        torch.Tensor: s2 logits. Shape: [batch_size, seq_len, s2_vocab_size]
    """
    sibling_embed = self.embedding.emb_s1(s1_ids)
    x2 = self.dep_layer(context, sibling_embed, key_padding_mask=padding_mask)
    return self.head.cond_forward(x2)
```
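Putting the two halves together, a hedged sketch of one sampling step. `model`, `s1_ids`, `s2_ids`, and `stamp` are assumed to exist as above, and the index bookkeeping is my reading of the training-time shift (s1_targets being s1_ids shifted left by one), not the repo's actual inference loop:

```python
import torch
import torch.nn.functional as F

with torch.no_grad():
    # Stage 1: predict s1 and keep the context for s2 decoding.
    s1_logits, context = model.decode_s1(s1_ids, s2_ids, stamp=stamp)
    next_s1 = torch.multinomial(F.softmax(s1_logits[:, -1], dim=-1), 1)  # [B, 1]

    # Stage 2: condition s2 on the s1 targets, i.e. s1_ids shifted left by
    # one with the freshly sampled token appended (length stays seq_len).
    s1_cond = torch.cat([s1_ids[:, 1:], next_s1], dim=1)
    s2_logits = model.decode_s2(context, s1_cond)
    next_s2 = torch.multinomial(F.softmax(s2_logits[:, -1], dim=-1), 1)
```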
