taki0112 / Tensorflow-Cookbook
- Monday, February 18, 2019, 00:16:54
Python
A simple TensorFlow cookbook for easy use.
For now, this repo contains general architectures and functions that are useful for GANs; I will keep adding useful things for other areas as well.
Pull requests and issues are always welcome. If you describe what you want implemented in an issue, I'll implement it.
How to use: import `ops.py` (the network operations) and `utils.py` (image loading and processing).
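The examples below assume star-imports of both modules, which is a reasonable reading of how the snippets are written (an assumption, not stated explicitly here):

```python
from ops import *
from utils import *
```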
Network skeleton:

```python
def network(x, is_training=True, reuse=False, scope="network"):
    with tf.variable_scope(scope, reuse=reuse):
        x = conv(...)
        ...
        return logit
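```

To make the skeleton concrete, here is a minimal sketch of a discriminator-style network composed from the cookbook ops. The layer layout and scope names are illustrative assumptions, not a prescribed architecture:

```python
def network(x, is_training=True, reuse=False, scope="network"):
    with tf.variable_scope(scope, reuse=reuse):
        # downsampling conv with reflect padding and spectral normalization
        x = conv(x, channels=64, kernel=3, stride=2, pad=1, pad_type='reflect',
                 use_bias=True, sn=True, scope='conv_0')
        x = batch_norm(x, is_training=is_training, scope='batch_norm_0')
        x = lrelu(x, alpha=0.2)
        x = global_avg_pooling(x)  # [N, H, W, C] -> [N, C]
        # note: `fully_conneted` is the op's spelling in ops.py
        logit = fully_conneted(x, units=1, use_bias=True, sn=True, scope='logit')
        return logit
```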
Data pipeline:

```python
# `trainA` is assumed to be a tf.data.Dataset of image paths built earlier
# (e.g. via tf.data.Dataset.from_tensor_slices); it is not constructed here.
Image_Data_Class = ImageData(img_size, img_ch, augment_flag)

trainA = trainA.map(Image_Data_Class.image_processing, num_parallel_calls=16)
trainA = trainA.shuffle(buffer_size=10000).prefetch(buffer_size=batch_size).batch(batch_size).repeat()

trainA_iterator = trainA.make_one_shot_iterator()
data_A = trainA_iterator.get_next()

logit = network(data_A)
```
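From here, a TF1-style training loop might look like the following. The loss choice, optimizer settings, `label`, and `num_steps` are all assumptions for illustration:

```python
loss = L2_loss(logit, label)  # `label` is assumed to come from the same pipeline
optim = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss)

num_steps = 100000  # hypothetical
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(num_steps):
        _, loss_val = sess.run([optim, loss])
```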
Common options used throughout the ops:
- `padding='SAME'`: TensorFlow-style implicit padding; the output spatial size is `ceil(input / stride)`.
- `pad_type`: explicit padding type, `'zero'` or `'reflect'`.
- `sn`: whether to apply spectral normalization to the layer's weights.
- `Ra`: whether to use the relativistic-average formulation of the GAN loss.
- `loss_func`: adversarial loss type, e.g. `'gan'`, `'lsgan'`, `'hinge'`, `'wgan-gp'`.
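As a quick sanity check on the padding options (shapes assumed NHWC; `sn=False` assumed to be an accepted value; this snippet is illustrative, not from the repo):

```python
x = tf.placeholder(tf.float32, [None, 64, 64, 3])

# explicit padding: reflect-pad by 1, then a valid 3x3 / stride-2 conv,
# so 64 -> floor((64 + 2*1 - 3) / 2) + 1 = 32
y = conv(x, channels=32, kernel=3, stride=2, pad=1, pad_type='reflect',
         use_bias=True, sn=False, scope='conv_shape_check')
print(y.shape)  # (?, 32, 32, 32)
```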
Default weight initialization and regularization:

```python
weight_init = tf.truncated_normal_initializer(mean=0.0, stddev=0.02)
weight_regularizer = tf.contrib.layers.l2_regularizer(0.0001)
weight_regularizer_fully = tf.contrib.layers.l2_regularizer(0.0001)
```
Initialization options:
- Xavier: `tf.contrib.layers.xavier_initializer()`
- He: `tf.contrib.layers.variance_scaling_initializer()`
- Normal: `tf.random_normal_initializer(mean=0.0, stddev=0.02)`
- Truncated normal: `tf.truncated_normal_initializer(mean=0.0, stddev=0.02)`
- Orthogonal: `tf.orthogonal_initializer(1.0)` (use gain `sqrt(2)` for ReLU, `1.0` otherwise)

Regularization options:
- l2_decay: `tf.contrib.layers.l2_regularizer(0.0001)`
- orthogonal_regularizer: `orthogonal_regularizer(0.0001)` & `orthogonal_regularizer_fully(0.0001)`

Convolution / fully-connected:

```python
x = conv(x, channels=64, kernel=3, stride=2, pad=1, pad_type='reflect', use_bias=True, sn=True, scope='conv')
x = partial_conv(x, channels=64, kernel=3, stride=2, use_bias=True, padding='SAME', sn=True, scope='partial_conv')
x = dilate_conv(x, channels=64, kernel=3, rate=2, use_bias=True, padding='SAME', sn=True, scope='dilate_conv')
x = deconv(x, channels=64, kernel=3, stride=2, padding='SAME', use_bias=True, sn=True, scope='deconv')
x = fully_conneted(x, units=64, use_bias=True, sn=True, scope='fully_connected')
x = conv_pixel_shuffle_down(x, scale_factor=2, use_bias=True, sn=True, scope='pixel_shuffle_down')
x = conv_pixel_shuffle_up(x, scale_factor=2, use_bias=True, sn=True, scope='pixel_shuffle_up')
```
down ===> `[height, width] -> [height // scale_factor, width // scale_factor]`
up ===> `[height, width] -> [height * scale_factor, width * scale_factor]`
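For example, an illustrative shape trace (the output channel counts depend on the op's implementation, so they are left as `?`):

```python
x = tf.placeholder(tf.float32, [None, 64, 64, 32])
d = conv_pixel_shuffle_down(x, scale_factor=2, use_bias=True, sn=True, scope='down')
u = conv_pixel_shuffle_up(d, scale_factor=2, use_bias=True, sn=True, scope='up')
print(d.shape, u.shape)  # (?, 32, 32, ?) and (?, 64, 64, ?)
```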
ResBlock:

```python
x = resblock(x, channels=64, is_training=is_training, use_bias=True, sn=True, scope='residual_block')
x = resblock_down(x, channels=64, is_training=is_training, use_bias=True, sn=True, scope='residual_block_down')
x = resblock_up(x, channels=64, is_training=is_training, use_bias=True, sn=True, scope='residual_block_up')
```
down ===> `[height, width] -> [height // 2, width // 2]`
up ===> `[height, width] -> [height * 2, width * 2]`
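A sketch of how the resblocks compose into an upsampling stack, in the style of a GAN generator tail (the channel counts and scope names are assumptions):

```python
def generator_tail(x, is_training=True):
    for i in range(3):
        x = resblock_up(x, channels=64, is_training=is_training,
                        use_bias=True, sn=True, scope='resblock_up_' + str(i))
    x = batch_norm(x, is_training=is_training, scope='bn_out')
    x = relu(x)
    x = conv(x, channels=3, kernel=3, stride=1, pad=1, pad_type='reflect',
             use_bias=True, sn=False, scope='to_rgb')
    return tanh(x)
```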
Attention:

```python
x = self_attention(x, channels=64, use_bias=True, sn=True, scope='self_attention')
x = self_attention_with_pooling(x, channels=64, use_bias=True, sn=True, scope='self_attention_version_2')
x = squeeze_excitation(x, channels=64, ratio=16, use_bias=True, sn=True, scope='squeeze_excitation')
x = convolution_block_attention(x, channels=64, ratio=16, use_bias=True, sn=True, scope='convolution_block_attention')
```
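Where to insert attention is a design choice; SAGAN-style models typically place one self-attention block at a mid-level resolution, for example (illustrative):

```python
x = resblock_up(x, channels=64, is_training=is_training, use_bias=True, sn=True, scope='up_32x32')
x = self_attention_with_pooling(x, channels=64, use_bias=True, sn=True, scope='attention_32x32')
x = resblock_up(x, channels=32, is_training=is_training, use_bias=True, sn=True, scope='up_64x64')
```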
Normalization:

```python
x = batch_norm(x, is_training=is_training, scope='batch_norm')
x = instance_norm(x, scope='instance_norm')
x = layer_norm(x, scope='layer_norm')
x = group_norm(x, groups=32, scope='group_norm')
x = pixel_norm(x)
x = batch_instance_norm(x, scope='batch_instance_norm')
x = condition_batch_norm(x, z, is_training=is_training, scope='condition_batch_norm')
x = adaptive_instance_norm(x, gamma, beta)
```

`condition_batch_norm` conditions the normalization on a vector `z` (e.g. a noise or class embedding), and `adaptive_instance_norm` takes externally predicted `gamma`/`beta`.
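For example, a minimal AdaIN hookup, assuming `z_style` is a style vector and that `gamma`/`beta` should broadcast over the spatial dimensions (the expected shapes are an assumption about the op):

```python
channels = x.get_shape().as_list()[-1]
style = fully_conneted(z_style, units=channels * 2, scope='style_fc')
gamma, beta = tf.split(style, num_or_size_splits=2, axis=-1)
gamma = tf.reshape(gamma, [-1, 1, 1, channels])  # broadcast over height/width
beta = tf.reshape(beta, [-1, 1, 1, channels])
x = adaptive_instance_norm(x, gamma, beta)
```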
Activation:

```python
x = relu(x)
x = lrelu(x, alpha=0.2)
x = tanh(x)
x = sigmoid(x)
x = swish(x)
```
Pooling & resize:

```python
x = up_sample(x, scale_factor=2)
x = max_pooling(x, pool_size=2)
x = avg_pooling(x, pool_size=2)
x = global_max_pooling(x)
x = global_avg_pooling(x)
x = flatten(x)
x = hw_flatten(x)
```
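As an example of how these compose, a classification head might look like this (hypothetical; `num_classes` and `label` are assumed names):

```python
x = global_avg_pooling(x)  # [N, H, W, C] -> [N, C]
logit = fully_conneted(x, units=num_classes, use_bias=True, sn=False, scope='classifier')
loss, accuracy = classification_loss(logit, label)
```

`hw_flatten` flattens the spatial dimensions (`[N, H, W, C] -> [N, H*W, C]`), which is the reshape the attention ops need internally.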
Loss:

```python
loss, accuracy = classification_loss(logit, label)
loss = L1_loss(x, y)
loss = L2_loss(x, y)
loss = huber_loss(x, y)
loss = histogram_loss(x, y)
```
`histogram_loss` measures the difference between the color distributions of the two images' pixel values.

GAN loss:

```python
d_loss = discriminator_loss(Ra=True, loss_func='wgan-gp', real=real_logit, fake=fake_logit)
g_loss = generator_loss(Ra=True, loss_func='wgan-gp', real=real_logit, fake=fake_logit)
```
Note that the gradient penalty for `'wgan-gp'` is not included in these loss functions and has to be computed and added separately.
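A standard WGAN-GP gradient penalty sketch (not part of ops.py; `batch_size`, `real_images`, `fake_images`, and the `discriminator` signature are assumed names):

```python
# penalize the critic's gradient norm on interpolates between real and fake
alpha = tf.random_uniform([batch_size, 1, 1, 1], 0.0, 1.0)
interpolated = alpha * real_images + (1.0 - alpha) * fake_images
inter_logit = discriminator(interpolated, reuse=True)

grad = tf.gradients(inter_logit, interpolated)[0]
grad_norm = tf.sqrt(tf.reduce_sum(tf.square(grad), axis=[1, 2, 3]))
gp = 10.0 * tf.reduce_mean(tf.square(grad_norm - 1.0))

d_loss += gp
```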
```python
loss = kl_loss(mean, logvar)
```
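For context, a typical VAE-style use of `kl_loss`, with `feat`, `z_dim`, `x_recon`, and `x_target` as hypothetical names:

```python
mean = fully_conneted(feat, units=z_dim, scope='mean')
logvar = fully_conneted(feat, units=z_dim, scope='logvar')
z = mean + tf.exp(0.5 * logvar) * tf.random_normal(tf.shape(mean))  # reparameterization

loss = L1_loss(x_recon, x_target) + kl_loss(mean, logvar)
```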
Author: Junho Kim