Source code for cosense3d.modules.plugin.transformer
import warnings, copy
from typing import List, Optional
import torch
from torch import nn
import torch.utils.checkpoint as cp
from cosense3d.modules.utils import build_torch_module
from cosense3d.modules.utils.norm import build_norm_layer
from cosense3d.modules.utils.init import xavier_init
try:
    from cosense3d.modules.plugin.flash_attn import FlashMHA
except ImportError:
    from cosense3d.modules.plugin.flash_attn_new import FlashMHA
from cosense3d.modules.utils.amp import auto_fp16
def build_module(cfg):
cfg_ = copy.deepcopy(cfg)
attn_typ = cfg_.pop('type')
return globals()[attn_typ](**cfg_)
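# Illustrative sketch of how ``build_module`` resolves a config dict: the
# ``type`` key selects a class defined in this module and the remaining keys
# are passed as keyword arguments. The config values below are assumptions
# chosen only for demonstration.
#
#   ffn_cfg = dict(type='FFN', embed_dims=256, feedforward_channels=512)
#   ffn = build_module(ffn_cfg)   # equivalent to FFN(embed_dims=256, feedforward_channels=512)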
class FFN(nn.Module):
"""Implements feed-forward networks (FFNs) with residual connection.
"""
def __init__(self,
embed_dims: int,
feedforward_channels: int,
num_fcs: int=2,
act_cfg: dict=dict(type='ReLU', inplace=True),
dropout: float=0.0,
add_residual: bool=True):
"""
:param embed_dims: The feature dimension. Same as
`MultiheadAttention`.
:param feedforward_channels: The hidden dimension of FFNs.
        :param num_fcs: The number of fully-connected layers in FFNs. Defaults to 2.
:param act_cfg: activation config.
:param dropout: Probability of an element to be
zeroed. Default 0.0.
        :param add_residual: Whether to add a residual connection.
Defaults to True.
"""
super(FFN, self).__init__()
assert num_fcs >= 2, 'num_fcs should be no less ' \
f'than 2. got {num_fcs}.'
self.embed_dims = embed_dims
self.feedforward_channels = feedforward_channels
self.num_fcs = num_fcs
self.act_cfg = act_cfg
self.dropout = dropout
self.activate = build_torch_module(act_cfg)
layers = nn.ModuleList()
in_channels = embed_dims
for _ in range(num_fcs - 1):
layers.append(
nn.Sequential(
nn.Linear(in_channels, feedforward_channels),
self.activate,
nn.Dropout(dropout)))
in_channels = feedforward_channels
layers.append(nn.Linear(feedforward_channels, embed_dims))
self.layers = nn.Sequential(*layers)
self.dropout = nn.Dropout(dropout)
self.add_residual = add_residual
    def forward(self, x, residual=None):
"""Forward function for `FFN`."""
out = self.layers(x)
if not self.add_residual:
return out
if residual is None:
residual = x
return residual + self.dropout(out)
def __repr__(self):
"""str: a string that describes the module"""
repr_str = self.__class__.__name__
repr_str += f'(embed_dims={self.embed_dims}, '
repr_str += f'feedforward_channels={self.feedforward_channels}, '
repr_str += f'num_fcs={self.num_fcs}, '
repr_str += f'act_cfg={self.act_cfg}, '
repr_str += f'dropout={self.dropout}, '
repr_str += f'add_residual={self.add_residual})'
return repr_str
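# A minimal usage sketch of ``FFN``; the shapes below are assumptions for
# illustration. The output keeps the input shape, and with ``add_residual=True``
# (the default) the input is added back after dropout.
#
#   ffn = FFN(embed_dims=256, feedforward_channels=512)
#   x = torch.rand(2, 100, 256)          # (bs, num_queries, embed_dims)
#   out = ffn(x)                         # out = x + dropout(MLP(x)), shape (2, 100, 256)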
class MultiheadFlashAttention(nn.Module):
    r"""A wrapper for flash-attention-based multi-head attention (``FlashMHA``).
    This module implements multi-head attention with an identity connection,
    and positional encoding is also passed as input.
    """
def __init__(self,
embed_dims: int,
num_heads: int,
attn_drop: float=0.,
proj_drop: float=0.,
dropout: float=None,
batch_first: bool=True,
cache_attn_weights: bool=False,
**kwargs):
"""
:param embed_dims: The embedding dimension.
:param num_heads: Parallel attention heads.
:param attn_drop: A Dropout layer on attn_output_weights. Default: 0.0.
:param proj_drop: A Dropout layer after `nn.MultiheadAttention`. Default: 0.0.
:param dropout: united dropout for both attention and projection layer.
        :param batch_first: When it is True, Key, Query and Value are shape of
            (batch, n, embed_dim), otherwise (n, batch, embed_dim).
            Defaults to True; this wrapper always uses batch-first inputs.
:param cache_attn_weights: whether to cache the intermediate attention weights.
:param kwargs:
"""
super(MultiheadFlashAttention, self).__init__()
if dropout is not None:
attn_drop = dropout
proj_drop = dropout
self.embed_dims = embed_dims
self.num_heads = num_heads
        self.batch_first = True  # this wrapper always expects batch-first inputs
self.cache_attn_weights = cache_attn_weights
self.attn_weights = None
self.attn = FlashMHA(embed_dims, num_heads, attn_drop, dtype=torch.float16, device='cuda',
**kwargs)
self.proj_drop = nn.Dropout(proj_drop)
self.dropout_layer = nn.Dropout(attn_drop)
    def forward(self,
query,
key=None,
value=None,
identity=None,
query_pos=None,
key_pos=None,
attn_mask=None,
key_padding_mask=None,
**kwargs):
"""
Forward function for `MultiheadAttention`.
        :param query: The input query with shape [num_queries, bs, embed_dims] if self.batch_first is False,
            else [bs, num_queries, embed_dims].
:param key: The key tensor with shape [num_keys, bs, embed_dims] if self.batch_first is False, else
[bs, num_keys, embed_dims]. If None, the ``query`` will be used. Defaults to None.
:param value: The value tensor with same shape as `key`. Same in `nn.MultiheadAttention.forward`.
Defaults to None. If None, the `key` will be used.
:param identity: This tensor, with the same shape as x, will be used for the identity link.
If None, `x` will be used. Defaults to None.
:param query_pos: The positional encoding for query, with the same shape as `x`. If not None, it will
be added to `x` before forward function. Defaults to None.
:param key_pos: The positional encoding for `key`, with the same shape as `key`. Defaults to None.
If not None, it will be added to `key` before forward function. If None, and `query_pos` has the same
shape as `key`, then `query_pos` will be used for `key_pos`. Defaults to None.
:param attn_mask: ByteTensor mask with shape [num_queries, num_keys].
Same in `nn.MultiheadAttention.forward`. Defaults to None.
:param key_padding_mask: ByteTensor with shape [bs, num_keys]. Defaults to None.
:param kwargs: allow passing a more general data flow when combining with
other operations in `transformerlayer`.
        :return: forwarded results with shape [num_queries, bs, embed_dims] if self.batch_first is False, else
            [bs, num_queries, embed_dims].
"""
if key is None:
key = query
if value is None:
value = key
if identity is None:
identity = query
if key_pos is None:
if query_pos is not None:
# use query_pos if key_pos is not available
if query_pos.shape == key.shape:
key_pos = query_pos
else:
                    warnings.warn(f'position encoding of key is '
                                  f'missing in {self.__class__.__name__}.')
if query_pos is not None:
query = query + query_pos
if key_pos is not None:
key = key + key_pos
# Because the dataflow('key', 'query', 'value') of
# ``torch.nn.MultiheadAttention`` is (num_query, batch,
# embed_dims), We should adjust the shape of dataflow from
# batch_first (batch, num_query, embed_dims) to num_query_first
# (num_query ,batch, embed_dims), and recover ``attn_output``
# from num_query_first to batch_first.
if self.batch_first:
query = query.transpose(0, 1)
key = key.transpose(0, 1)
value = value.transpose(0, 1)
with torch.autocast(device_type='cuda', dtype=torch.float16):
            # flash attention only supports fp16
out, attn_weights = self.attn(
q=query,
k=key,
v=value,
key_padding_mask=None)
if self.cache_attn_weights:
self.attn_weights = attn_weights
if self.batch_first:
out = out.transpose(0, 1)
return identity + self.dropout_layer(self.proj_drop(out))
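# A minimal usage sketch of ``MultiheadFlashAttention``, assuming a CUDA device
# and an available flash-attention extension (``FlashMHA``). Inputs are always
# batch-first; key/value default to the query (self-attention), and the result
# includes the identity connection. Shapes are assumptions for illustration.
#
#   attn = MultiheadFlashAttention(embed_dims=256, num_heads=8)
#   q = torch.rand(2, 100, 256, device='cuda')   # (bs, num_queries, embed_dims)
#   out = attn(query=q)                          # shape (2, 100, 256)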
class MultiHeadAttentionWrapper(nn.MultiheadAttention):
def __init__(self, *args, **kwargs):
super(MultiHeadAttentionWrapper, self).__init__(*args, **kwargs)
self.fp16_enabled = True
    @auto_fp16(out_fp32=True)
def forward_fp16(self, *args, **kwargs):
return super(MultiHeadAttentionWrapper, self).forward(*args, **kwargs)
    def forward_fp32(self, *args, **kwargs):
return super(MultiHeadAttentionWrapper, self).forward(*args, **kwargs)
    def forward(self, *args, **kwargs):
if self.fp16_enabled and self.training:
return self.forward_fp16(*args, **kwargs)
else:
return self.forward_fp32(*args, **kwargs)
class MultiheadAttention(nn.Module):
r"""A wrapper for ``torch.nn.MultiheadAttention``.
This module implements MultiheadAttention with identity connection,
and positional encoding is also passed as input.
"""
def __init__(self,
embed_dims: int,
num_heads: int,
dropout: float=0.1,
batch_first: bool=False,
cache_attn_weights: bool=False,
fp16: bool=False,
**kwargs):
"""
:param embed_dims: The embedding dimension.
:param num_heads: Parallel attention heads.
        :param dropout: probability of the dropout layers. Defaults to 0.1.
:param batch_first: When it is True, Key, Query and Value are shape of
(batch, n, embed_dim), otherwise (n, batch, embed_dim).
Default to False.
        :param cache_attn_weights: whether to cache the attention weights.
        :param fp16: whether to run the attention in float16 precision.
:param kwargs:
"""
super(MultiheadAttention, self).__init__()
self.embed_dims = embed_dims
self.num_heads = num_heads
self.batch_first = batch_first
self.cache_attn_weights = cache_attn_weights
self.attn_weights = None
self.fp16_enabled = fp16
if fp16:
self.attn = MultiHeadAttentionWrapper(embed_dims, num_heads, dropout, **kwargs)
else:
self.attn = nn.MultiheadAttention(embed_dims, num_heads, dropout, **kwargs)
self.proj_drop = nn.Dropout(dropout)
self.dropout_layer = nn.Dropout(dropout)
    def forward(self,
query,
key=None,
value=None,
identity=None,
query_pos=None,
key_pos=None,
attn_mask=None,
key_padding_mask=None,
**kwargs):
"""
Forward function for `MultiheadAttention`.
        :param query: The input query with shape [num_queries, bs, embed_dims] if self.batch_first is False,
            else [bs, num_queries, embed_dims].
:param key: The key tensor with shape [num_keys, bs, embed_dims] if self.batch_first is False,
else [bs, num_keys, embed_dims]. If None, the ``query`` will be used. Defaults to None.
:param value: The value tensor with same shape as `key`. Same in `nn.MultiheadAttention.forward`.
Defaults to None. If None, the `key` will be used.
:param identity: This tensor, with the same shape as x, will be used for the identity link.
If None, `x` will be used. Defaults to None.
:param query_pos: The positional encoding for query, with the same shape as `x`.
If not None, it will be added to `x` before forward function. Defaults to None.
        :param key_pos: The positional encoding for `key`, with the same shape as `key`. If not None, it will
            be added to `key` before the forward function. If None and `query_pos` has the same shape as `key`,
            then `query_pos` will be used as `key_pos`. Defaults to None.
:param attn_mask: ByteTensor mask with shape [num_queries, num_keys].
Same in `nn.MultiheadAttention.forward`. Defaults to None.
:param key_padding_mask: ByteTensor with shape [bs, num_keys]. Defaults to None.
:param kwargs: allow passing a more general data flow when combining with other operations in `transformerlayer`.
        :return: forwarded results with shape [num_queries, bs, embed_dims] if self.batch_first is False,
            else [bs, num_queries, embed_dims].
"""
if key is None:
key = query
if value is None:
value = key
if identity is None:
identity = query
if key_pos is None:
if query_pos is not None:
# use query_pos if key_pos is not available
if query_pos.shape == key.shape:
key_pos = query_pos
else:
                    warnings.warn(f'position encoding of key is '
                                  f'missing in {self.__class__.__name__}.')
if query_pos is not None:
query = query + query_pos
if key_pos is not None:
key = key + key_pos
# Because the dataflow('key', 'query', 'value') of
# ``torch.nn.MultiheadAttention`` is (num_query, batch,
# embed_dims), We should adjust the shape of dataflow from
# batch_first (batch, num_query, embed_dims) to num_query_first
# (num_query ,batch, embed_dims), and recover ``attn_output``
# from num_query_first to batch_first.
if self.batch_first:
query = query.transpose(0, 1).contiguous()
key = key.transpose(0, 1).contiguous()
value = value.transpose(0, 1).contiguous()
out, attn_weights = self.attn(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask)
if self.batch_first:
out = out.transpose(0, 1).contiguous()
if self.cache_attn_weights:
self.attn_weights = attn_weights
return identity + self.dropout_layer(self.proj_drop(out))
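# A minimal usage sketch of ``MultiheadAttention`` with batch-first inputs and
# positional encodings; all tensor shapes are assumptions for illustration.
#
#   attn = MultiheadAttention(embed_dims=256, num_heads=8, dropout=0.1, batch_first=True)
#   q = torch.rand(2, 100, 256)                  # (bs, num_queries, embed_dims)
#   kv = torch.rand(2, 50, 256)                  # (bs, num_keys, embed_dims)
#   q_pos, k_pos = torch.rand_like(q), torch.rand_like(kv)
#   out = attn(query=q, key=kv, value=kv, query_pos=q_pos, key_pos=k_pos)   # (2, 100, 256)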
class TransformerDecoderLayer(nn.Module):
def __init__(self,
attn_cfgs=None,
ffn_cfgs=None,
operation_order=None,
norm_cfg=dict(type='LN'),
batch_first=False,
with_cp=True,
**kwargs):
super().__init__()
        assert set(operation_order) & {
            'self_attn', 'norm', 'ffn', 'cross_attn'} == \
            set(operation_order), f'The operation_order of ' \
            f'{self.__class__.__name__} should only ' \
            f'contain operation types from ' \
            f"{['self_attn', 'norm', 'ffn', 'cross_attn']}."
num_attn = operation_order.count('self_attn') + operation_order.count('cross_attn')
if isinstance(attn_cfgs, dict):
attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
else:
            assert num_attn == len(attn_cfgs), f'The length ' \
                f'of attn_cfgs {len(attn_cfgs)} is ' \
                f'not consistent with the number of attentions ' \
                f'in operation_order {operation_order}.'
self.batch_first = batch_first
self.num_attn = num_attn
self.operation_order = operation_order
self.norm_cfg = norm_cfg
self.pre_norm = operation_order[0] == 'norm'
self.use_checkpoint = with_cp
self._init_layers(operation_order, attn_cfgs, ffn_cfgs, norm_cfg)
def _init_layers(self, operation_order, attn_cfgs, ffn_cfgs, norm_cfg):
self.attentions = nn.ModuleList()
index = 0
for operation_name in operation_order:
if operation_name in ['self_attn', 'cross_attn']:
if 'batch_first' in attn_cfgs[index]:
assert self.batch_first == attn_cfgs[index]['batch_first']
else:
attn_cfgs[index]['batch_first'] = self.batch_first
attention = build_module(attn_cfgs[index])
# Some custom attentions used as `self_attn`
# or `cross_attn` can have different behavior.
attention.operation_name = operation_name
self.attentions.append(attention)
index += 1
self.embed_dims = self.attentions[0].embed_dims
self.ffns = nn.ModuleList()
num_ffns = operation_order.count('ffn')
if isinstance(ffn_cfgs, dict):
ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)]
assert len(ffn_cfgs) == num_ffns
for ffn_index in range(num_ffns):
if 'embed_dims' not in ffn_cfgs[ffn_index]:
ffn_cfgs[ffn_index]['embed_dims'] = self.embed_dims
else:
assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
self.ffns.append(build_module(ffn_cfgs[ffn_index]))
self.norms = nn.ModuleList()
num_norms = operation_order.count('norm')
for _ in range(num_norms):
self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])
def _forward(self,
query,
key=None,
value=None,
query_pos=None,
key_pos=None,
temp_memory=None,
temp_pos=None,
attn_masks: List[torch.Tensor]=None,
query_key_padding_mask=None,
key_padding_mask=None,
**kwargs):
"""
Forward function for `TransformerDecoderLayer`.
        :param query: The input query with shape [num_queries, bs, embed_dims] if self.batch_first is False,
            else [bs, num_queries, embed_dims].
:param key: The key tensor with shape [num_keys, bs, embed_dims] if self.batch_first is False,
else [bs, num_keys, embed_dims].
:param value: The value tensor with same shape as `key`.
:param query_pos: The positional encoding for `query`. Default: None.
:param key_pos: The positional encoding for `key`. Default: None.
        :param temp_memory: Temporal memory (e.g. queries propagated from previous frames) that is concatenated
            to `query` and used as additional keys/values in the `self_attn` operation. Default: None.
        :param temp_pos: The positional encoding for `temp_memory`, with the same shape. Default: None.
:param attn_masks: 2D Tensor used in calculation of corresponding attention. The length of it should equal
to the number of `attention` in `operation_order`. Default: None.
:param query_key_padding_mask: ByteTensor for `query`, with shape [bs, num_queries]. Only used in `self_attn`
layer. Defaults to None.
        :param key_padding_mask: ByteTensor for `key`, with shape [bs, num_keys]. Default: None.
:param kwargs: contains some specific arguments of attentions.
:return: forwarded results with shape [num_queries, bs, embed_dims].
"""
norm_index = 0
attn_index = 0
ffn_index = 0
identity = query
if attn_masks is None:
attn_masks = [None for _ in range(self.num_attn)]
elif isinstance(attn_masks, torch.Tensor):
attn_masks = [
copy.deepcopy(attn_masks) for _ in range(self.num_attn)
]
            warnings.warn(f'Use the same attn_mask for all attentions in '
                          f'{self.__class__.__name__}.')
else:
assert len(attn_masks) == self.num_attn, f'The length of ' \
f'attn_masks {len(attn_masks)} must be equal ' \
f'to the number of attention in ' \
f'operation_order {self.num_attn}'
for layer in self.operation_order:
if layer == 'self_attn':
if temp_memory is not None:
temp_key = temp_value = torch.cat([query, temp_memory], dim=0)
temp_pos = torch.cat([query_pos, temp_pos], dim=0)
else:
temp_key = temp_value = query
temp_pos = query_pos
query = self.attentions[attn_index](
query,
temp_key,
temp_value,
identity if self.pre_norm else None,
query_pos=query_pos,
key_pos=temp_pos,
attn_mask=attn_masks[attn_index],
key_padding_mask=query_key_padding_mask,
**kwargs)
attn_index += 1
identity = query
elif layer == 'norm':
query = self.norms[norm_index](query)
norm_index += 1
elif layer == 'cross_attn':
query = self.attentions[attn_index](
query,
key,
value,
identity if self.pre_norm else None,
query_pos=query_pos,
key_pos=key_pos,
attn_mask=attn_masks[attn_index],
key_padding_mask=key_padding_mask,
**kwargs)
attn_index += 1
identity = query
elif layer == 'ffn':
query = self.ffns[ffn_index](
query, identity if self.pre_norm else None)
ffn_index += 1
return query
    def forward(self,
query,
key=None,
value=None,
query_pos=None,
key_pos=None,
temp_memory=None,
temp_pos=None,
attn_masks=None,
query_key_padding_mask=None,
key_padding_mask=None,
**kwargs
):
"""Forward function for `TransformerCoder`.
:returns: Tensor: forwarded results with shape [num_query, bs, embed_dims].
"""
if self.use_checkpoint and self.training:
x = cp.checkpoint(
self._forward,
query,
key,
value,
query_pos,
key_pos,
temp_memory,
temp_pos,
attn_masks,
query_key_padding_mask,
key_padding_mask,
)
else:
x = self._forward(
query,
key,
value,
query_pos,
key_pos,
temp_memory,
temp_pos,
attn_masks,
query_key_padding_mask,
key_padding_mask,
)
return x
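# An illustrative config for ``TransformerDecoderLayer``; all values below are
# assumptions for demonstration. ``operation_order`` interleaves self-attention,
# cross-attention, normalization and FFN; ``with_cp`` is disabled here to skip
# gradient checkpointing in this sketch.
#
#   layer_cfg = dict(
#       type='TransformerDecoderLayer',
#       attn_cfgs=dict(type='MultiheadAttention', embed_dims=256, num_heads=8),
#       ffn_cfgs=dict(type='FFN', embed_dims=256, feedforward_channels=512),
#       operation_order=('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm'),
#       with_cp=False)
#   layer = build_module(layer_cfg)
#   q = torch.rand(100, 2, 256)          # (num_queries, bs, embed_dims), batch_first=False
#   kv = torch.rand(50, 2, 256)          # (num_keys, bs, embed_dims)
#   out = layer(q, key=kv, value=kv)     # (100, 2, 256)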
class TransformerLayerSequence(nn.Module):
"""
Base class for TransformerEncoder and TransformerDecoder in vision
transformer.
As base-class of Encoder and Decoder in vision transformer.
Support customization such as specifying different kind
of `transformer_layer` in `transformer_coder`.
"""
def __init__(self, transformerlayers=None, num_layers=None):
"""
:param transformerlayers: (list[obj:`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict`)
Config of transformerlayer in TransformerCoder. If it is obj:`mmcv.ConfigDict`,
it would be repeated `num_layer` times to a list[`mmcv.ConfigDict`]. Default: None.
:param num_layers: The number of `TransformerLayer`. Default: None.
"""
super().__init__()
if isinstance(transformerlayers, dict):
transformerlayers = [
copy.deepcopy(transformerlayers) for _ in range(num_layers)
]
else:
assert isinstance(transformerlayers, list) and \
len(transformerlayers) == num_layers
self.num_layers = num_layers
self.layers = nn.ModuleList()
for i in range(num_layers):
self.layers.append(build_module(transformerlayers[i]))
self.embed_dims = self.layers[0].embed_dims
self.pre_norm = self.layers[0].pre_norm
    def forward(self,
query,
key,
value,
query_pos=None,
key_pos=None,
attn_masks=None,
query_key_padding_mask=None,
key_padding_mask=None,
**kwargs):
"""Forward function for `TransformerCoder`.
:param query: (Tensor) Input query with shape `(num_queries, bs, embed_dims)`.
:param key: (Tensor) The key tensor with shape `(num_keys, bs, embed_dims)`.
:param value: (Tensor) The value tensor with shape `(num_keys, bs, embed_dims)`.
:param query_pos: (Tensor) The positional encoding for `query`. Default: None.
:param key_pos: (Tensor) The positional encoding for `key`. Default: None.
:param attn_masks: (List[Tensor], optional) Each element is 2D Tensor which is
used in calculation of corresponding attention in operation_order. Default: None.
        :param query_key_padding_mask: (Tensor) ByteTensor for `query`, with shape [bs, num_queries].
            Only used in self-attention. Default: None.
        :param key_padding_mask: (Tensor) ByteTensor for `key`, with shape [bs, num_keys]. Default: None.
:returns: results with shape [num_queries, bs, embed_dims].
"""
for layer in self.layers:
query = layer(
query,
key,
value,
query_pos=query_pos,
key_pos=key_pos,
attn_masks=attn_masks,
query_key_padding_mask=query_key_padding_mask,
key_padding_mask=key_padding_mask,
**kwargs)
return query
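# An illustrative sketch of stacking several decoder layers with
# ``TransformerLayerSequence``; ``layer_cfg``, ``q`` and ``kv`` refer to the
# assumed example values sketched above, and the layer config is repeated
# ``num_layers`` times.
#
#   seq = build_module(dict(
#       type='TransformerLayerSequence',
#       transformerlayers=layer_cfg,
#       num_layers=3))
#   out = seq(q, kv, kv)                 # (100, 2, 256)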
class TransformerDecoder(TransformerLayerSequence):
"""Implements the decoder in DETR transformer."""
def __init__(self,
*args,
post_norm_cfg=dict(type='LN'),
return_intermediate=False,
**kwargs):
"""
:param args:
:param post_norm_cfg: Config of last normalization layer. Default: `LN`.
:param return_intermediate: Whether to return intermediate outputs.
:param kwargs:
"""
super(TransformerDecoder, self).__init__(*args, **kwargs)
self.return_intermediate = return_intermediate
if post_norm_cfg is not None:
self.post_norm = build_norm_layer(post_norm_cfg, self.embed_dims)[1]
else:
self.post_norm = None
    def forward(self, query, *args, **kwargs):
"""Forward function for `TransformerDecoder`.
:param query: (Tensor) Input query with shape `(num_query, bs, embed_dims)`.
        :return: Tensor: Results with shape [1, num_query, bs, embed_dims] when
return_intermediate is `False`, otherwise it has shape [num_layers, num_query, bs, embed_dims].
"""
if not self.return_intermediate:
x = super().forward(query, *args, **kwargs)
if self.post_norm:
x = self.post_norm(x)[None]
return x
intermediate = []
for layer in self.layers:
query = layer(query, *args, **kwargs)
if self.return_intermediate:
if self.post_norm is not None:
intermediate.append(self.post_norm(query))
else:
intermediate.append(query)
# if torch.isnan(query).any():
# print('TransfromerDecoder: Found nan in query.')
# if torch.isnan(intermediate[-1]).any():
# print('TransfromerDecoder: Found nan in intermediate result.')
return torch.stack(intermediate)
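# An illustrative sketch of ``TransformerDecoder`` with intermediate outputs;
# ``layer_cfg``, ``q`` and ``kv`` are the assumed examples from the sketches above.
#
#   decoder = build_module(dict(
#       type='TransformerDecoder',
#       transformerlayers=layer_cfg,
#       num_layers=3,
#       return_intermediate=True))
#   out = decoder(q, kv, kv)             # (num_layers, num_queries, bs, embed_dims) = (3, 100, 2, 256)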
class PETRTemporalTransformer(nn.Module):
"""Implements the DETR transformer.
    Following the official DETR implementation, this module is copy-pasted
    from torch.nn.Transformer with modifications:
* positional encodings are passed in MultiheadAttention
* extra LN at the end of encoder is removed
* decoder returns a stack of activations from all decoding layers
See `paper: End-to-End Object Detection with Transformers
<https://arxiv.org/pdf/2005.12872>`_ for details.
"""
def __init__(self, encoder=None, decoder=None, cross=False):
"""
:param encoder: (`mmcv.ConfigDict` | Dict) Config of
TransformerEncoder. Defaults to None.
        :param decoder: (`mmcv.ConfigDict` | Dict) Config of
TransformerDecoder. Defaults to None.
:param cross: whether to use cross-attention.
"""
super(PETRTemporalTransformer, self).__init__()
if encoder is not None:
self.encoder = build_module(encoder)
else:
self.encoder = None
self.decoder = build_module(decoder)
self.embed_dims = self.decoder.embed_dims
self.cross = cross
def init_weights(self):
# follow the official DETR to init parameters
for m in self.modules():
if hasattr(m, 'weight') and m.weight.dim() > 1:
xavier_init(m, distribution='uniform')
self._is_init = True
def forward(self, memory, tgt, query_pos, pos_embed, attn_masks, temp_memory=None, temp_pos=None,
mask=None, query_mask=None, reg_branch=None):
"""Forward function for `Transformer`.
"""
memory = memory.transpose(0, 1).contiguous()
query_pos = query_pos.transpose(0, 1).contiguous()
pos_embed = pos_embed.transpose(0, 1).contiguous()
n, bs, c = memory.shape
if tgt is None:
tgt = torch.zeros_like(query_pos)
else:
tgt = tgt.transpose(0, 1).contiguous()
if temp_memory is not None:
temp_memory = temp_memory.transpose(0, 1).contiguous()
temp_pos = temp_pos.transpose(0, 1).contiguous()
# out_dec: [num_layers, num_query, bs, dim]
if not isinstance(attn_masks, list):
attn_masks = [attn_masks, None]
out_dec = self.decoder(
query=tgt,
key=memory,
value=memory,
key_pos=pos_embed,
query_pos=query_pos,
temp_memory=temp_memory,
temp_pos=temp_pos,
query_key_padding_mask=query_mask,
key_padding_mask=mask,
attn_masks=attn_masks,
reg_branch=reg_branch,
)
out_dec = out_dec.transpose(1, 2).contiguous()
memory = memory.reshape(-1, bs, c).transpose(0, 1).contiguous()
return out_dec, memory
class PETRTransformer(nn.Module):
"""
Implements the DETR transformer.
    Following the official DETR implementation, this module is copy-pasted
    from torch.nn.Transformer with modifications:
* positional encodings are passed in MultiheadAttention
* extra LN at the end of encoder is removed
* decoder returns a stack of activations from all decoding layers
See `paper: End-to-End Object Detection with Transformers
<https://arxiv.org/pdf/2005.12872>`_ for details.
"""
def __init__(self, encoder=None, decoder=None, cross=False):
super(PETRTransformer, self).__init__()
if encoder is not None:
self.encoder = build_module(encoder)
else:
self.encoder = None
self.decoder = build_module(decoder)
self.embed_dims = self.decoder.embed_dims
self.cross = cross
    def init_weights(self):
# follow the official DETR to init parameters
for m in self.modules():
if hasattr(m, 'weight') and m.weight.dim() > 1:
xavier_init(m, distribution='uniform')
self._is_init = True
    def forward(self, memory, tgt, query_pos, pos_embed, attn_masks=None,
mask=None, query_mask=None):
"""Forward function for `Transformer`.
"""
memory = memory.transpose(0, 1).contiguous()
query_pos = query_pos.transpose(0, 1).contiguous()
pos_embed = pos_embed.transpose(0, 1).contiguous()
n, bs, c = memory.shape
if tgt is None:
tgt = torch.zeros_like(query_pos)
else:
tgt = tgt.transpose(0, 1).contiguous()
# out_dec: [num_layers, num_query, bs, dim]
if not isinstance(attn_masks, list):
attn_masks = [attn_masks]
assert len(attn_masks) == self.decoder.layers[0].num_attn
out_dec = self.decoder(
query=tgt,
key=memory,
value=memory,
key_pos=pos_embed,
query_pos=query_pos,
query_key_padding_mask=query_mask,
key_padding_mask=mask,
attn_masks=attn_masks,
)
out_dec = out_dec.transpose(1, 2).contiguous()
memory = memory.reshape(-1, bs, c).transpose(0, 1).contiguous()
return out_dec, memory
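# An illustrative sketch of ``PETRTransformer``; shapes and the decoder config
# are assumptions for demonstration, with ``layer_cfg`` referring to the example
# layer config sketched above. ``attn_masks`` must be a list with one entry per
# attention in the decoder layer (here: self_attn + cross_attn).
#
#   petr = PETRTransformer(decoder=dict(
#       type='TransformerDecoder',
#       transformerlayers=layer_cfg,
#       num_layers=3,
#       return_intermediate=True))
#   memory = torch.rand(2, 50, 256)      # (bs, num_keys, embed_dims)
#   query_pos = torch.rand(2, 100, 256)  # (bs, num_queries, embed_dims)
#   pos_embed = torch.rand(2, 50, 256)
#   out_dec, mem = petr(memory, None, query_pos, pos_embed, attn_masks=[None, None])
#   # out_dec: (num_layers, bs, num_queries, embed_dims); mem: (bs, num_keys, embed_dims)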
class PETRTemporalTransformer(nn.Module):
r"""
Implements the DETR transformer.
    Following the official DETR implementation, this module is copy-pasted
    from torch.nn.Transformer with modifications:
* positional encodings are passed in MultiheadAttention
* extra LN at the end of encoder is removed
* decoder returns a stack of activations from all decoding layers
See `paper: End-to-End Object Detection with Transformers
<https://arxiv.org/pdf/2005.12872>`_ for details.
"""
def __init__(self, encoder=None, decoder=None, cross=False):
super(PETRTemporalTransformer, self).__init__()
if encoder is not None:
self.encoder = build_module(encoder)
else:
self.encoder = None
self.decoder = build_module(decoder)
self.embed_dims = self.decoder.embed_dims
self.cross = cross
    def init_weights(self):
# follow the official DETR to init parameters
for m in self.modules():
if hasattr(m, 'weight') and m.weight.dim() > 1:
xavier_init(m, distribution='uniform')
self._is_init = True
    def forward(self, memory, tgt, query_pos, pos_embed, attn_masks, temp_memory=None, temp_pos=None,
mask=None, query_mask=None, reg_branch=None):
"""Forward function for `Transformer`.
"""
query_pos = query_pos.transpose(0, 1).contiguous()
if memory is not None:
memory = memory.transpose(0, 1).contiguous()
n, bs, c = memory.shape
if pos_embed is not None:
pos_embed = pos_embed.transpose(0, 1).contiguous()
if tgt is None:
tgt = torch.zeros_like(query_pos)
else:
tgt = tgt.transpose(0, 1).contiguous()
if temp_memory is not None:
temp_memory = temp_memory.transpose(0, 1).contiguous()
temp_pos = temp_pos.transpose(0, 1).contiguous()
# out_dec: [num_layers, num_query, bs, dim]
if not isinstance(attn_masks, list):
attn_masks = [attn_masks]
assert len(attn_masks) == self.decoder.layers[0].num_attn
out_dec = self.decoder(
query=tgt,
key=memory,
value=memory,
key_pos=pos_embed,
query_pos=query_pos,
temp_memory=temp_memory,
temp_pos=temp_pos,
query_key_padding_mask=query_mask,
key_padding_mask=mask,
attn_masks=attn_masks,
)
out_dec = out_dec.transpose(1, 2).contiguous()
if memory is not None:
memory = memory.reshape(-1, bs, c).transpose(0, 1).contiguous()
return out_dec, memory
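# An illustrative sketch of the temporal variant: propagated queries from
# previous frames are passed as ``temp_memory``/``temp_pos`` and concatenated to
# the current queries inside the self-attention. The decoder config and shapes
# below are assumptions, reusing ``layer_cfg``, ``memory``, ``query_pos`` and
# ``pos_embed`` from the sketches above.
#
#   temporal = PETRTemporalTransformer(decoder=dict(
#       type='TransformerDecoder',
#       transformerlayers=layer_cfg,
#       num_layers=3,
#       return_intermediate=True))
#   temp_mem = torch.rand(2, 20, 256)    # (bs, num_temporal_queries, embed_dims)
#   temp_pos = torch.rand(2, 20, 256)
#   out_dec, mem = temporal(memory, None, query_pos, pos_embed,
#                           attn_masks=[None, None],
#                           temp_memory=temp_mem, temp_pos=temp_pos)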