@@ -38,7 +38,7 @@ public override Tensor forward(Tensor x, Tensor y)
3838 long batch_size = input_shape [ 0 ] ;
3939 long sequence_length = input_shape [ 1 ] ;
4040
41- long [ ] interim_shape = [ batch_size , - 1 , n_heads_ , d_head ] ;
41+ long [ ] interim_shape = new long [ ] { batch_size , - 1 , n_heads_ , d_head } ;
4242 Tensor q = to_q . forward ( x ) ;
4343 Tensor k = to_k . forward ( y ) ;
4444 Tensor v = to_v . forward ( y ) ;
@@ -190,7 +190,7 @@ public override Tensor forward(Tensor x, Tensor context)
190190 x = proj_in . forward ( x ) ;
191191 }
192192
193- x = x . view ( [ n , c , h * w ] ) ;
193+ x = x . view ( new long [ ] { n , c , h * w } ) ;
194194 x = x . transpose ( - 1 , - 2 ) ;
195195
196196 if ( use_linear )
@@ -208,7 +208,7 @@ public override Tensor forward(Tensor x, Tensor context)
208208 x = proj_out . forward ( x ) ;
209209 }
210210 x = x . transpose ( - 1 , - 2 ) ;
211- x = x . view ( [ n , c , h , w ] ) ;
211+ x = x . view ( new long [ ] { n , c , h , w } ) ;
212212 if ( ! use_linear )
213213 {
214214 x = proj_out . forward ( x ) ;
@@ -241,7 +241,7 @@ public Upsample(int in_channels, bool with_conv = true, Device? device = null, S
241241 }
242242 public override Tensor forward ( Tensor x )
243243 {
244- var output = functional . interpolate ( x , scale_factor : [ 2.0 , 2.0 ] , mode : InterpolationMode . Nearest ) ;
244+ var output = functional . interpolate ( x , scale_factor : new double [ ] { 2.0 , 2.0 } , mode : InterpolationMode . Nearest ) ;
245245 if ( with_conv && conv is not null )
246246 {
247247 output = conv . forward ( output ) ;
@@ -359,19 +359,19 @@ private class UNet : Module<Tensor, Tensor, Tensor, Tensor>
359359 public UNet ( int model_channels , int in_channels , int [ ] ? channel_mult = null , int num_res_blocks = 2 , int num_atten_blocks = 1 , int context_dim = 768 , int num_heads = 8 , float dropout = 0.0f , bool use_timestep = true , Device ? device = null , ScalarType ? dtype = null ) : base ( nameof ( UNet ) )
360360 {
361361 bool mask = false ;
362- channel_mult = channel_mult ?? [ 1 , 2 , 4 , 4 ] ;
362+ channel_mult = channel_mult ?? new int [ ] { 1 , 2 , 4 , 4 } ;
363363
364364 ch = model_channels ;
365365 time_embed_dim = model_channels * 4 ;
366366 this . in_channels = in_channels ;
367367 this . use_timestep = use_timestep ;
368368
369- List < int > input_block_channels = [ model_channels ] ;
369+ List < int > input_block_channels = new List < int > { model_channels } ;
370370
371371 if ( use_timestep )
372372 {
373373 // timestep embedding
374- time_embed = Sequential ( [ Linear ( model_channels , time_embed_dim , device : device , dtype : dtype ) , SiLU ( ) , Linear ( time_embed_dim , time_embed_dim , device : device , dtype : dtype ) ] ) ;
374+ time_embed = Sequential ( new Module < Tensor , Tensor > [ ] { Linear ( model_channels , time_embed_dim , device : device , dtype : dtype ) , SiLU ( ) , Linear ( time_embed_dim , time_embed_dim , device : device , dtype : dtype ) } ) ;
375375 }
376376
377377 // downsampling
@@ -462,7 +462,7 @@ public override Tensor forward(Tensor x, Tensor context, Tensor time)
462462 foreach ( TimestepEmbedSequential layers in output_blocks )
463463 {
464464 Tensor index = skip_connections . Last ( ) ;
465- x = cat ( [ x , index ] , 1 ) ;
465+ x = cat ( new Tensor [ ] { x , index } , 1 ) ;
466466 skip_connections . RemoveAt ( skip_connections . Count - 1 ) ;
467467 x = layers . forward ( x , context , time ) ;
468468 }
@@ -528,7 +528,7 @@ private class UNet : Module<Tensor, Tensor, Tensor, Tensor, Tensor>
528528
529529 public UNet ( int model_channels , int in_channels , int [ ] ? channel_mult = null , int num_res_blocks = 2 , int context_dim = 768 , int adm_in_channels = 2816 , int num_heads = 20 , float dropout = 0.0f , bool use_timestep = true , Device ? device = null , ScalarType ? dtype = null ) : base ( nameof ( SDUnet ) )
530530 {
531- channel_mult = channel_mult ?? [ 1 , 2 , 4 ] ;
531+ channel_mult = channel_mult ?? new int [ ] { 1 , 2 , 4 } ;
532532
533533 ch = model_channels ;
534534 time_embed_dim = model_channels * 4 ;
@@ -538,7 +538,7 @@ public UNet(int model_channels, int in_channels, int[]? channel_mult = null, int
538538 bool useLinear = true ;
539539 bool mask = false ;
540540
541- List < int > input_block_channels = [ model_channels ] ;
541+ List < int > input_block_channels = new List < int > { model_channels } ;
542542
543543 if ( use_timestep )
544544 {
@@ -590,10 +590,10 @@ public override Tensor forward(Tensor x, Tensor context, Tensor time, Tensor y)
590590 {
591591 int dim = 512 ;
592592 Tensor embed = time_embed . forward ( time ) ;
593- Tensor time_ids = tensor ( new float [ ] { dim , dim , 0 , 0 , dim , dim } , embed . dtype , embed . device ) . repeat ( [ 2 , 1 ] ) ;
593+ Tensor time_ids = tensor ( new float [ ] { dim , dim , 0 , 0 , dim , dim } , embed . dtype , embed . device ) . repeat ( new long [ ] { 2 , 1 } ) ;
594594 Tensor time_embeds = get_timestep_embedding ( time_ids . flatten ( ) , dim / 2 , true , 0 , 1 ) ;
595- time_embeds = time_embeds . reshape ( [ 2 , - 1 ] ) ;
596- y = cat ( [ y , time_embeds ] , dim : - 1 ) ;
595+ time_embeds = time_embeds . reshape ( new long [ ] { 2 , - 1 } ) ;
596+ y = cat ( new Tensor [ ] { y , time_embeds } , dim : - 1 ) ;
597597 Tensor label_embed = label_emb . forward ( y . to ( embed . dtype , embed . device ) ) ;
598598 embed = embed + label_embed ;
599599
@@ -607,7 +607,7 @@ public override Tensor forward(Tensor x, Tensor context, Tensor time, Tensor y)
607607 foreach ( TimestepEmbedSequential layers in output_blocks )
608608 {
609609 Tensor index = skip_connections . Last ( ) ;
610- x = cat ( [ x , index ] , 1 ) ;
610+ x = cat ( new Tensor [ ] { x , index } , 1 ) ;
611611 skip_connections . RemoveAt ( skip_connections . Count - 1 ) ;
612612 x = layers . forward ( x , context , embed ) ;
613613 }
@@ -685,12 +685,12 @@ private static Tensor get_timestep_embedding(Tensor timesteps, int embedding_dim
685685 emb = scale * emb ;
686686
687687 // concat sine and cosine embeddings
688- emb = torch . cat ( [ torch . sin ( emb ) , torch . cos ( emb ) ] , dim : - 1 ) ;
688+ emb = torch . cat ( new Tensor [ ] { torch . sin ( emb ) , torch . cos ( emb ) } , dim : - 1 ) ;
689689
690690 // flip sine and cosine embeddings
691691 if ( flip_sin_to_cos )
692692 {
693- emb = torch . cat ( [ emb [ .., half_dim ..] , emb [ .., ..half_dim ] ] , dim : - 1 ) ;
693+ emb = torch . cat ( new Tensor [ ] { emb [ .., half_dim ..] , emb [ .., ..half_dim ] } , dim : - 1 ) ;
694694 }
695695
696696 // zero pad
0 commit comments