 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+
 class Nerf(nn.Module):
     def __init__(
         self,
@@ -44,7 +45,7 @@ def __init__(
         transformer_hidden_size: int,
         max_freqs: int,
         mlp_ratio: int,
-        eps=1e-6,
+        eps=1e-6,
     ):
         super().__init__()
         self.nerf_embedder = NerfEmbedder(
@@ -69,6 +70,7 @@ def __init__(
             eps=eps,
         )
         self.transformer_hidden_size = transformer_hidden_size
+
     def __call__(
         self,
         pixels,
@@ -77,20 +79,20 @@ def __call__(
         num_patches,
     ):
         batch_size, channels, height, width = pixels.shape
-
+
         pixels = nn.functional.unfold(pixels, kernel_size=patch_size, stride=patch_size)
         pixels = pixels.transpose(1, 2)
-
+
         hidden = latents.reshape(batch_size * num_patches, self.transformer_hidden_size)
         pixels = pixels.reshape(batch_size * num_patches, channels, patch_size**2).transpose(1, 2)
-
+
         # Get pixel embeddings
         latents_dct = self.nerf_embedder(pixels)
-
+
         # Pass through blocks
         for block in self.blocks:
             latents_dct = block(latents_dct, hidden)
-
+
         latents_dct = latents_dct.transpose(1, 2).reshape(batch_size, num_patches, -1).transpose(1, 2)
         latents_dct = nn.functional.fold(
             latents_dct,
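For reference, the `__call__` body above relies on `torch.nn.functional.unfold` and `fold` being exact inverses when `stride` equals `kernel_size`, so each patch can be processed independently and the image reassembled afterwards. A minimal sketch of that round trip, separate from the diff and with illustrative shapes:

```python
# Sketch, not part of the diff: the unfold/fold round trip used by Nerf.__call__.
# With stride == kernel_size the patches tile the image exactly, so fold() undoes unfold().
import torch
import torch.nn.functional as F

batch_size, channels, height, width = 2, 3, 32, 32
patch_size = 16
pixels = torch.randn(batch_size, channels, height, width)

# (B, C, H, W) -> (B, C * patch_size**2, num_patches)
patches = F.unfold(pixels, kernel_size=patch_size, stride=patch_size)

# ... per-patch processing would happen here ...

# (B, C * patch_size**2, num_patches) -> (B, C, H, W)
restored = F.fold(patches, output_size=(height, width), kernel_size=patch_size, stride=patch_size)
assert torch.allclose(pixels, restored)
```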
@@ -100,6 +102,7 @@ def __call__(
         )
         return self.final_layer(latents_dct)
 
+
 class NerfEmbedder(nn.Module):
     def __init__(
         self,
@@ -111,6 +114,7 @@ def __init__(
         self.max_freqs = max_freqs
         self.hidden_size = hidden_size
         self.embedder = nn.Sequential(nn.Linear(in_channels + max_freqs**2, hidden_size))
+
     def fetch_pos(self, patch_size) -> torch.Tensor:
         pos_x = torch.linspace(0, 1, patch_size)
         pos_y = torch.linspace(0, 1, patch_size)
@@ -123,8 +127,9 @@ def fetch_pos(self, patch_size) -> torch.Tensor:
         coeffs = (1 + freqs_x * freqs_y) ** -1
         dct_x = torch.cos(pos_x * freqs_x * torch.pi)
         dct_y = torch.cos(pos_y * freqs_y * torch.pi)
-        dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs**2)
+        dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs**2)
         return dct
+
     def __call__(self, inputs: torch.Tensor) -> torch.Tensor:
         batch, pixels, channels = inputs.shape
         patch_size = int(pixels**0.5)
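The meshgrid/broadcast setup in the middle of `fetch_pos` falls between the hunks above, so the sketch below is a hedged reconstruction of the positional feature it returns, inferred from the visible lines: every pixel position in a patch is paired with every frequency pair (fx, fy), giving cos(pi*x*fx) * cos(pi*y*fy) / (1 + fx*fy). The exact meshgrid ordering and the `dct_positional_features` helper name are assumptions, not part of the source:

```python
# Hedged reconstruction of the positional features returned by NerfEmbedder.fetch_pos;
# the broadcasting details are not visible in this diff.
import torch

def dct_positional_features(patch_size: int, max_freqs: int) -> torch.Tensor:
    pos = torch.linspace(0, 1, patch_size)
    pos_y, pos_x = torch.meshgrid(pos, pos, indexing="ij")          # pixel coordinates in [0, 1]
    freqs = torch.arange(max_freqs, dtype=torch.float32)
    freqs_y, freqs_x = torch.meshgrid(freqs, freqs, indexing="ij")  # all (fy, fx) frequency pairs

    px = pos_x.reshape(-1, 1, 1)  # (patch_size**2, 1, 1), broadcast against (max_freqs, max_freqs)
    py = pos_y.reshape(-1, 1, 1)
    coeffs = (1 + freqs_x * freqs_y) ** -1  # damp high-frequency products, as in the diff
    dct = torch.cos(px * freqs_x * torch.pi) * torch.cos(py * freqs_y * torch.pi) * coeffs
    return dct.reshape(1, patch_size**2, max_freqs**2)  # same shape as fetch_pos' .view(1, -1, max_freqs**2)

print(dct_positional_features(patch_size=16, max_freqs=8).shape)  # torch.Size([1, 256, 64])
```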
@@ -134,13 +139,15 @@ def __call__(self, inputs: torch.Tensor) -> torch.Tensor:
         inputs = torch.cat((inputs, dct), dim=-1)
         return self.embedder(inputs)
 
+
 class NerfGLUBlock(nn.Module):
     def __init__(self, transformer_hidden_size: int, nerf_hidden_size: int, mlp_ratio, eps):
         super().__init__()
         total_params = 3 * nerf_hidden_size**2 * mlp_ratio
         self.param_generator = nn.Linear(transformer_hidden_size, total_params)
         self.norm = RMSNorm(nerf_hidden_size, eps=eps)
         self.mlp_ratio = mlp_ratio
+
     def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
         batch_size, num_x, hidden_size_x = x.shape
         mlp_params = self.param_generator(s)
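`NerfGLUBlock` acts as a small hypernetwork: `param_generator` projects each patch's transformer hidden state into the weights of a per-patch SwiGLU MLP, which is then applied to that patch's pixel tokens with batched matmuls. The split and reshape of `mlp_params` into `fc1_gate`, `fc1_value`, and `fc2` happens in lines not shown in this diff, so the layout below is an assumption; the sketch also omits the `RMSNorm` and the `res_x` residual that the real block applies:

```python
# Sketch of the hypernetwork-style SwiGLU in NerfGLUBlock.forward; the split/reshape
# of mlp_params is assumed, and norm/residual are omitted for brevity.
import torch
import torch.nn.functional as F

batch, num_pixels, hidden, mlp_ratio = 4, 256, 64, 2
s = torch.randn(batch, 128)                 # per-patch transformer hidden state
x = torch.randn(batch, num_pixels, hidden)  # per-pixel features within that patch

param_generator = torch.nn.Linear(128, 3 * hidden**2 * mlp_ratio)
mlp_params = param_generator(s)             # (batch, 3 * hidden**2 * mlp_ratio)

# Assumed layout: gate and value projections (hidden -> hidden * mlp_ratio)
# plus an output projection (hidden * mlp_ratio -> hidden).
chunk = hidden * hidden * mlp_ratio
fc1_gate, fc1_value, fc2 = mlp_params.split([chunk, chunk, chunk], dim=-1)
fc1_gate = fc1_gate.reshape(batch, hidden, hidden * mlp_ratio)
fc1_value = fc1_value.reshape(batch, hidden, hidden * mlp_ratio)
fc2 = fc2.reshape(batch, hidden * mlp_ratio, hidden)

# Matches the torch.bmm line in the diff: SwiGLU with per-sample generated weights.
out = torch.bmm(F.silu(torch.bmm(x, fc1_gate)) * torch.bmm(x, fc1_value), fc2)
print(out.shape)  # torch.Size([4, 256, 64])
```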
@@ -156,6 +163,7 @@ def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
         x = torch.bmm(torch.nn.functional.silu(torch.bmm(x, fc1_gate)) * torch.bmm(x, fc1_value), fc2)
         return x + res_x
 
+
 class NerfFinalLayer(nn.Module):
     def __init__(self, hidden_size: int, out_channels: int, eps):
         super().__init__()
@@ -166,9 +174,11 @@ def __init__(self, hidden_size: int, out_channels: int, eps):
             kernel_size=3,
             padding=1,
         )
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.conv(self.norm(x.movedim(1, -1)).movedim(-1, 1))
 
+
 class ChromaAdaLayerNormZeroPruned(nn.Module):
     r"""
     Norm layer adaptive layer norm zero (adaLN-Zero).
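For context on the docstring above, adaLN-Zero conditions a block by projecting an embedding into shift, scale, and gate terms; the body of `ChromaAdaLayerNormZeroPruned` itself is not part of this diff, so the following is only a generic sketch of that modulation pattern:

```python
# Minimal sketch of the adaLN-Zero pattern the docstring names: conditioning supplies
# shift/scale/gate terms, the normalized hidden state is modulated by shift and scale,
# and gate later scales the sub-layer output before the residual add.
import torch
import torch.nn.functional as F

x = torch.randn(2, 16, 32)                  # (batch, tokens, dim)
shift, scale, gate = torch.randn(3, 2, 32)  # per-sample modulation terms

modulated = F.layer_norm(x, x.shape[-1:]) * (1 + scale[:, None]) + shift[:, None]
sublayer_out = modulated                    # stand-in for an attention/MLP output
x = x + gate[:, None] * sublayer_out
print(x.shape)  # torch.Size([2, 16, 32])
```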
@@ -658,7 +668,7 @@ def forward(
             logger.warning(
                 "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
             )
-
+
         hidden_states = self.x_embedder(hidden_states)
 
         timestep = timestep.to(hidden_states.dtype) * 1000
@@ -773,6 +783,7 @@ def forward(
 
         return Transformer2DModelOutput(sample=output)
 
+
 class ChromaRadianceTransformer2DModel(
     ModelMixin,
     ConfigMixin,
@@ -850,7 +861,7 @@ def __init__(
             hidden_dim=approximator_hidden_dim,
             n_layers=approximator_layers,
         )
-
+
         self.nerf = Nerf(
             in_channels,
             nerf_layers,
@@ -859,7 +870,7 @@ def __init__(
             nerf_max_freqs,
             nerf_mlp_ratio,
         )
-
+
         self.x_embedder_patch = nn.Conv2d(
             in_channels,
             self.inner_dim,
@@ -932,7 +943,6 @@ def forward(
             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
             `tuple` where the first element is the sample tensor.
         """
-        print(self.device)
         pixels = hidden_states.to(self.device)
         if joint_attention_kwargs is not None:
             joint_attention_kwargs = joint_attention_kwargs.copy()