Instructions to use NoteDance/CLIP-Keras with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Keras
How to use NoteDance/CLIP-Keras with Keras:
```python
# Available backend options are: "jax", "torch", "tensorflow".
import os
os.environ["KERAS_BACKEND"] = "jax"
import keras
model = keras.saving.load_model("hf://NoteDance/CLIP-Keras")
```
- Notebooks
- Google Colab
- Kaggle
from typing import Tuple, Union

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import (
    AveragePooling2D,
    BatchNormalization,
    Conv2D,
    Dense,
    Identity,
    LayerNormalization,
    MultiHeadAttention,
    ZeroPadding2D,
)
# Backward-compatible alias: the original code imported the non-existent name ``Conv2d``.
from tensorflow.keras.layers import Conv2D as Conv2d
class Bottleneck(tf.keras.layers.Layer):
    """ResNet bottleneck block with average-pool based ("anti-aliased") downsampling.

    Main path: 1x1 conv -> 3x3 conv -> avgpool (only when ``stride > 1``) ->
    1x1 conv expanding to ``planes * expansion`` channels. A projection
    shortcut (avgpool -> 1x1 conv -> BN) is used when the spatial size or the
    channel count changes. Inputs are assumed channels-last ``(N, H, W, C)``.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, **kwargs):
        """
        Args:
            inplanes: number of input channels; decides whether the shortcut needs a projection.
            planes: bottleneck width; the block outputs ``planes * expansion`` channels.
            stride: spatial downsampling factor, applied by average pooling.
            **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
        """
        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        super().__init__(**kwargs)
        self.conv1 = Conv2D(planes, 1, use_bias=False)
        self.bn1 = BatchNormalization()
        self.relu1 = tf.nn.relu
        # Explicit 1-pixel zero padding followed by a 'valid' 3x3 conv emulates PyTorch's padding=1.
        self.zeropadding2d = ZeroPadding2D(padding=1)
        self.conv2 = Conv2D(planes, 3, use_bias=False)
        self.bn2 = BatchNormalization()
        self.relu2 = tf.nn.relu
        self.avgpool = AveragePooling2D(stride, stride, 'VALID') if stride > 1 else Identity()
        self.conv3 = Conv2D(planes * self.expansion, 1, use_bias=False)
        self.bn3 = BatchNormalization()
        self.relu3 = tf.nn.relu
        self.downsample = None
        self.stride = stride
        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = tf.keras.Sequential()
            self.downsample.add(AveragePooling2D(stride, stride, 'VALID'))
            self.downsample.add(Conv2D(planes * self.expansion, 1, strides=1, use_bias=False))
            self.downsample.add(BatchNormalization())

    def call(self, x, training=None):
        """Apply the block; ``training`` is forwarded to every BatchNormalization."""
        identity = x
        out = self.relu1(self.bn1(self.conv1(x), training=training))
        out = self.zeropadding2d(out)
        out = self.relu2(self.bn2(self.conv2(out), training=training))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out), training=training)
        if self.downsample is not None:
            identity = self.downsample(x, training=training)
        out = out + identity
        return self.relu3(out)
class AttentionPool2d(tf.keras.layers.Layer):
    """Multi-head attention pooling over a channels-last feature map.

    The mean over all spatial positions is prepended as an extra token. After
    learned positional embeddings are added, that token alone acts as the
    query in one multi-head attention step over all ``HW + 1`` tokens. The
    result is projected by ``c_proj`` to ``output_dim`` (default
    ``embed_dim``).
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None, **kwargs):
        """
        Args:
            spacial_dim: side length of the (square) input feature map; the input must have
                ``spacial_dim ** 2`` spatial positions.
            embed_dim: channel count of the input feature map.
            num_heads: number of attention heads; must divide ``embed_dim``.
            output_dim: size of the pooled output vector (defaults to ``embed_dim``).
            **kwargs: forwarded to ``tf.keras.layers.Layer``.

        Raises:
            ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
        """
        super().__init__(**kwargs)
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")
        self.spacial_dim = spacial_dim
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.positional_embedding = self.add_weight(
            name='positional_embedding',
            shape=[spacial_dim ** 2 + 1, embed_dim],
            initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=1. / embed_dim ** 0.5),
            trainable=True,
        )
        self.k_proj = Dense(embed_dim)
        self.q_proj = Dense(embed_dim)
        self.v_proj = Dense(embed_dim)
        self.c_proj = Dense(output_dim or embed_dim)

    def _split_heads(self, t):
        """Reshape ``(N, L, embed_dim)`` to ``(N, num_heads, L, head_dim)``."""
        batch_size = tf.shape(t)[0]
        t = tf.reshape(t, [batch_size, -1, self.num_heads, self.embed_dim // self.num_heads])
        return tf.transpose(t, [0, 2, 1, 3])

    def call(self, x):
        """Pool ``x`` of shape ``(N, H, W, embed_dim)`` into ``(N, output_dim)``."""
        batch_size = tf.shape(x)[0]
        x = tf.reshape(x, [batch_size, -1, self.embed_dim])                    # (N, HW, C)
        x = tf.concat([tf.reduce_mean(x, axis=1, keepdims=True), x], axis=1)  # (N, HW+1, C)
        x = x + tf.cast(self.positional_embedding[None, :, :], x.dtype)       # (N, HW+1, C)

        head_dim = self.embed_dim // self.num_heads
        # Projections run batch-first so the head split below keeps tokens of different images apart.
        query = self._split_heads(self.q_proj(x[:, :1]))  # (N, heads, 1, head_dim)
        key = self._split_heads(self.k_proj(x))           # (N, heads, HW+1, head_dim)
        value = self._split_heads(self.v_proj(x))         # (N, heads, HW+1, head_dim)

        # Scaled dot-product attention: scale by 1/sqrt(head_dim), as in PyTorch's MHA.
        query = query * tf.cast(head_dim ** -0.5, query.dtype)
        weights = tf.nn.softmax(tf.matmul(query, key, transpose_b=True), axis=-1)  # (N, heads, 1, HW+1)
        pooled = tf.matmul(weights, value)                                          # (N, heads, 1, head_dim)
        pooled = tf.reshape(tf.transpose(pooled, [0, 2, 1, 3]), [batch_size, self.embed_dim])
        return self.c_proj(pooled)
class ModifiedResNet(tf.keras.layers.Layer):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool

    Subclassing ``tf.keras.layers.Layer`` lets the owning Keras model track the
    weights of every sub-layer. Inputs are assumed to be channels-last images
    of size ``input_resolution``.
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64, **kwargs):
        """
        Args:
            layers: number of Bottleneck blocks in each of the four stages.
            output_dim: size of the final pooled embedding.
            heads: number of attention heads in the attention-pooling head.
            input_resolution: input image side length; the attention pool expects
                ``input_resolution // 32`` spatial positions per side.
            width: base channel width of the stem.
            **kwargs: forwarded to ``tf.keras.layers.Layer``.
        """
        super().__init__(**kwargs)
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem; padding=1 + 'valid' convs emulate PyTorch's padding=1
        self.zeropadding2d = ZeroPadding2D(padding=1)  # stateless, reused before each stem conv
        self.conv1 = Conv2D(width // 2, kernel_size=3, strides=2, use_bias=False)
        self.bn1 = BatchNormalization()
        self.relu1 = tf.nn.relu
        self.conv2 = Conv2D(width // 2, kernel_size=3, use_bias=False)
        self.bn2 = BatchNormalization()
        self.relu2 = tf.nn.relu
        self.conv3 = Conv2D(width, kernel_size=3, use_bias=False)
        self.bn3 = BatchNormalization()
        self.relu3 = tf.nn.relu
        self.avgpool = AveragePooling2D(2, 2, 'VALID')

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        """Build one stage: a (possibly strided) Bottleneck followed by ``blocks - 1`` more."""
        layers = tf.keras.Sequential()
        layers.add(Bottleneck(self._inplanes, planes, stride))
        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.add(Bottleneck(self._inplanes, planes))
        return layers

    def _stem(self, x, training=None):
        """Three padded 3x3 conv/BN/ReLU stages (the first strided by 2), then a 2x2 average pool."""
        x = self.zeropadding2d(x)
        x = self.relu1(self.bn1(self.conv1(x), training=training))
        x = self.zeropadding2d(x)
        x = self.relu2(self.bn2(self.conv2(x), training=training))
        x = self.zeropadding2d(x)
        x = self.relu3(self.bn3(self.conv3(x), training=training))
        return self.avgpool(x)

    def call(self, x, training=None):
        """Encode channels-last images into ``(N, output_dim)`` embeddings."""
        x = self._stem(x, training=training)
        x = self.layer1(x, training=training)
        x = self.layer2(x, training=training)
        x = self.layer3(x, training=training)
        x = self.layer4(x, training=training)
        return self.attnpool(x)
class LayerNorm(tf.keras.layers.Layer):
    """Layer normalization computed in float32, with the result cast back to the input dtype.

    Port of CLIP's fp16-safe wrapper around PyTorch's ``nn.LayerNorm``. The
    default epsilon is 1e-5 to match PyTorch; Keras's own default is 1e-3.
    """

    def __init__(self, input_size, epsilon=1e-5, **kwargs):
        """
        Args:
            input_size: size of the normalized (last) axis. It is kept for API parity;
                the actual size is inferred from the input on first call.
            epsilon: variance epsilon (PyTorch ``nn.LayerNorm`` default).
            **kwargs: forwarded to ``tf.keras.layers.Layer``.
        """
        # Do not let Keras autocast inputs: the caller's dtype must be restored on output.
        kwargs.setdefault('autocast', False)
        super().__init__(**kwargs)
        self.input_size = input_size
        # dtype='float32' keeps the normalization in fp32 even under a mixed-precision policy.
        self.layer_norm = LayerNormalization(epsilon=epsilon, dtype='float32')

    def call(self, x):
        """Normalize ``x`` over its last axis in float32 and return it in ``x``'s dtype."""
        orig_type = x.dtype
        ret = self.layer_norm(tf.cast(x, tf.float32))
        return tf.cast(ret, orig_type)
class QuickGELU(tf.keras.layers.Layer):
    """Sigmoid approximation of GELU used by CLIP: ``x * sigmoid(1.702 * x)``."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x):
        # Implemented in ``call`` (not ``__call__``) so Keras' layer machinery still runs.
        return x * tf.nn.sigmoid(1.702 * x)
class ResidualAttentionBlock(tf.keras.layers.Layer):
    """Pre-norm transformer block: ``x + Attn(LN(x))`` followed by ``x + MLP(LN(x))``.

    Inputs and outputs are sequence-first ``(L, N, d_model)``, the layout the
    callers in this file use. Internally the attention runs batch-first,
    because Keras ``MultiHeadAttention`` treats axis 0 as the batch.
    """

    def __init__(self, d_model: int, n_head: int, attn_mask=None, **kwargs):
        """
        Args:
            d_model: model width.
            n_head: number of attention heads; must divide ``d_model``.
            attn_mask: optional ``(L, L)`` mask. It may be an additive float mask in the
                PyTorch convention (0 = may attend, large negative = blocked), or a boolean
                mask (True = may attend).
            **kwargs: forwarded to ``tf.keras.layers.Layer``.

        Raises:
            ValueError: if ``d_model`` is not divisible by ``n_head``.
        """
        super().__init__(**kwargs)
        if d_model % n_head != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_head ({n_head})")
        # key_dim is the *per-head* size; PyTorch's nn.MultiheadAttention(d_model, n_head)
        # splits d_model across the heads.
        self.attn = MultiHeadAttention(num_heads=n_head, key_dim=d_model // n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = tf.keras.Sequential()
        self.mlp.add(Dense(d_model * 4))
        self.mlp.add(QuickGELU())
        self.mlp.add(Dense(d_model))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def _keras_mask(self):
        """Return ``self.attn_mask`` as a Keras boolean keep-mask (True = attend), or None."""
        if self.attn_mask is None:
            return None
        mask = tf.convert_to_tensor(self.attn_mask)
        if mask.dtype != tf.bool:
            # Additive mask: positions holding 0 are allowed; negative entries are blocked.
            mask = tf.equal(mask, tf.zeros_like(mask))
        if mask.shape.rank == 2:
            mask = mask[tf.newaxis]  # (1, L, L): broadcast over the batch
        return mask

    def attention(self, x):
        """Self-attention on sequence-first ``x`` of shape ``(L, N, d_model)``."""
        x_batch_first = tf.transpose(x, (1, 0, 2))  # LND -> NLD
        # Keras MHA returns just the output tensor (no (output, weights) tuple).
        out = self.attn(x_batch_first, x_batch_first, attention_mask=self._keras_mask())
        return tf.transpose(out, (1, 0, 2))  # NLD -> LND

    def call(self, x):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
class Transformer(tf.keras.layers.Layer):
    """A stack of ``layers`` ResidualAttentionBlocks of width ``width``.

    Subclassing ``tf.keras.layers.Layer`` lets the owning model track the
    blocks' weights. Inputs are passed straight to the blocks, so they use
    the blocks' sequence-first ``(L, N, width)`` layout.
    """

    def __init__(self, width: int, layers: int, heads: int, attn_mask=None, **kwargs):
        """
        Args:
            width: model width of every block.
            layers: number of stacked blocks.
            heads: attention heads per block.
            attn_mask: optional mask shared by every block.
            **kwargs: forwarded to ``tf.keras.layers.Layer``.
        """
        super().__init__(**kwargs)
        self.width = width
        self.layers = layers
        self.resblocks = tf.keras.Sequential()
        for _ in range(layers):
            self.resblocks.add(ResidualAttentionBlock(width, heads, attn_mask))

    def call(self, x):
        # An empty stack is the identity; calling an empty Sequential may raise.
        if not self.resblocks.layers:
            return x
        return self.resblocks(x)
class VisionTransformer(tf.keras.layers.Layer):
    """ViT image encoder used by CLIP.

    A strided convolution splits a channels-last image into
    ``patch_size x patch_size`` patches. A learned class token is prepended
    and positional embeddings are added. The sequence then passes through
    LayerNorm and the transformer, and the class-token output is projected to
    ``output_dim``.
    """

    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int,
                 output_dim: int, **kwargs):
        """
        Args:
            input_resolution: input image side length.
            patch_size: patch side length (also the conv kernel size and stride).
            width: transformer width.
            layers: number of transformer blocks.
            heads: attention heads per block.
            output_dim: size of the projected image embedding.
            **kwargs: forwarded to ``tf.keras.layers.Layer``.
        """
        super().__init__(**kwargs)
        self.input_resolution = input_resolution
        self.patch_size = patch_size
        self.width = width
        self.output_dim = output_dim
        self.conv1 = Conv2D(width, kernel_size=patch_size, strides=patch_size, use_bias=False)
        scale = width ** -0.5
        self.scale = scale
        self.class_embedding = self.add_weight(
            name='class_embedding',
            shape=[width],
            initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=scale),
            trainable=True,
        )
        self.positional_embedding = self.add_weight(
            name='positional_embedding',
            shape=[(input_resolution // patch_size) ** 2 + 1, width],
            initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=scale),
            trainable=True,
        )
        self.ln_pre = LayerNorm(width)
        self.transformer = Transformer(width, layers, heads)
        self.ln_post = LayerNorm(width)
        # Equivalent to scale * randn(width, output_dim) in the PyTorch original.
        self.proj = self.add_weight(
            name='proj',
            shape=[width, output_dim],
            initializer=tf.keras.initializers.RandomNormal(mean=0., stddev=scale),
            trainable=True,
        )

    def call(self, x, train_flag=True):
        """Encode channels-last images into ``(N, output_dim)`` embeddings.

        ``train_flag`` is unused; it is kept so existing call sites keep working.
        """
        x = self.conv1(x)  # shape = [*, grid, grid, width] (channels-last)
        batch_size = tf.shape(x)[0]
        x = tf.reshape(x, [batch_size, -1, self.width])  # shape = [*, grid ** 2, width]
        class_token = tf.cast(self.class_embedding, x.dtype) + tf.zeros([batch_size, 1, self.width], dtype=x.dtype)
        x = tf.concat([class_token, x], axis=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + tf.cast(self.positional_embedding, x.dtype)
        x = self.ln_pre(x)
        x = tf.transpose(x, (1, 0, 2))  # NLD -> LND
        x = self.transformer(x)
        x = tf.transpose(x, (1, 0, 2))  # LND -> NLD
        x = self.ln_post(x[:, 0, :])  # class-token output
        if self.proj is not None:
            x = tf.matmul(x, self.proj)
        return x
class CLIP(Model):
    """Contrastive image/text model (Keras port of OpenAI CLIP).

    Images are encoded by a ModifiedResNet (when ``vision_layers`` is a
    tuple/list) or by a VisionTransformer (when it is an int). Text is encoded
    by a causally masked Transformer. The model returns scaled cosine-similarity
    logits between every image and every text in the batch.
    """

    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super(CLIP, self).__init__()
        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = self.add_weight(
            name='token_embedding',
            shape=(vocab_size, transformer_width),
            initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
            trainable=True
        )
        self.positional_embedding = self.add_weight(
            name='positional_embedding',
            shape=(self.context_length, transformer_width),
            initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
            trainable=True
        )
        self.ln_final = LayerNorm(transformer_width)
        self.text_projection = self.add_weight(
            name='text_projection',
            shape=(transformer_width, embed_dim),
            initializer=tf.keras.initializers.RandomNormal(stddev=transformer_width ** -0.5),
            trainable=True
        )
        # Learnable temperature, stored in log space and initialized to log(1 / 0.07).
        self.logit_scale = self.add_weight(
            name='logit_scale',
            shape=[],
            initializer=tf.keras.initializers.Constant(np.log(1 / 0.07)),
            trainable=True
        )

    def build_attention_mask(self):
        """Return an additive causal mask of shape ``(context_length, context_length)``.

        Entries on and below the diagonal are 0, so a token may attend to itself
        and to earlier tokens. Entries strictly above the diagonal are -1e9
        (PyTorch CLIP: ``fill_(-inf).triu_(1)``).
        """
        ones = tf.ones((self.context_length, self.context_length))
        # band_part(ones, -1, 0) is the lower triangle *including* the diagonal;
        # subtracting it from ones leaves only the strictly upper triangle.
        strictly_upper = ones - tf.linalg.band_part(ones, -1, 0)
        return strictly_upper * -1e9

    def encode_image(self, image):
        """Encode a batch of images with the vision tower."""
        return self.visual(image)

    def encode_text(self, text):
        """Encode integer token ids of shape ``(batch_size, context_length)``.

        The feature at each sequence's end-of-text token is returned. That token
        is taken to be the highest id in each row (found with argmax).
        """
        x = tf.gather(self.token_embedding, text)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding
        x = tf.transpose(x, (1, 0, 2))  # NLD -> LND
        x = self.transformer(x)
        x = tf.transpose(x, (1, 0, 2))  # LND -> NLD
        x = self.ln_final(x)
        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        batch_size = tf.shape(x)[0]  # dynamic, so an unknown batch dimension also works
        eot_positions = tf.argmax(text, axis=-1, output_type=tf.int32)
        indices = tf.stack([tf.range(batch_size, dtype=tf.int32), eot_positions], axis=1)
        return tf.matmul(tf.gather_nd(x, indices), self.text_projection)

    def call(self, image, text):
        """Return ``(logits_per_image, logits_per_text)``, each ``[batch, batch]``."""
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / tf.norm(image_features, axis=1, keepdims=True)
        text_features = text_features / tf.norm(text_features, axis=1, keepdims=True)

        # cosine similarity as logits
        logit_scale = tf.math.exp(self.logit_scale)
        logits_per_image = tf.matmul(logit_scale * image_features, tf.transpose(text_features))
        logits_per_text = tf.transpose(logits_per_image)

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text