圖像相似度估計(jì) | 結(jié)合三元組損失的暹羅網(wǎng)絡(luò)
在機(jī)器學(xué)習(xí)領(lǐng)域,確定圖像之間的相似度在各種應(yīng)用中至關(guān)重要,從檢測(cè)重復(fù)項(xiàng)到面部識(shí)別。解決這個(gè)問(wèn)題的一個(gè)強(qiáng)大方法是使用暹羅網(wǎng)絡(luò)結(jié)合三元組損失函數(shù)。在本文中,我們將探索如何構(gòu)建和訓(xùn)練暹羅網(wǎng)絡(luò)以估計(jì)圖像相似度,并通過(guò)一個(gè)來(lái)自GitHub倉(cāng)庫(kù)的實(shí)際示例進(jìn)行說(shuō)明。
什么是暹羅網(wǎng)絡(luò)?
暹羅網(wǎng)絡(luò)是一種包含兩個(gè)或更多相同子網(wǎng)絡(luò)的神經(jīng)網(wǎng)絡(luò)架構(gòu)。這些子網(wǎng)絡(luò)旨在為每個(gè)輸入生成特征向量,然后可以比較這些向量以估計(jì)相似度。關(guān)鍵思想是使用相同的網(wǎng)絡(luò)處理每個(gè)輸入,確保輸出一致且可比較。
這種架構(gòu)特別適合于檢測(cè)重復(fù)項(xiàng)、尋找異常和面部識(shí)別等任務(wù)。在我們將要探索的實(shí)現(xiàn)中,網(wǎng)絡(luò)設(shè)置有三個(gè)相同的子網(wǎng)絡(luò)。每個(gè)網(wǎng)絡(luò)處理三張圖像中的一張:錨點(diǎn)圖像、正樣本(與錨點(diǎn)相似)和負(fù)樣本(與錨點(diǎn)無(wú)關(guān))。
什么是三元組損失?
為了有效地訓(xùn)練暹羅網(wǎng)絡(luò),我們使用三元組損失函數(shù)。這種損失函數(shù)鼓勵(lì)網(wǎng)絡(luò)在特征空間中拉近錨點(diǎn)和正樣本的距離,同時(shí)將錨點(diǎn)和負(fù)樣本推得更遠(yuǎn)。損失函數(shù)定義如下:
L(A, P, N) = max(‖f(A) — f(P)‖2 — ‖f(A) — f(N)‖2 + margin, 0)
這里,A是錨點(diǎn)圖像,P是正圖像,N是負(fù)圖像。函數(shù)f(x)代表網(wǎng)絡(luò)生成的embedding,而margin是一個(gè)小的正值,有助于確保網(wǎng)絡(luò)不會(huì)將所有嵌入壓縮到同一點(diǎn)。
設(shè)置暹羅網(wǎng)絡(luò)
在這次實(shí)現(xiàn)中,我們首先加載Totally Looks Like數(shù)據(jù)集,其中包含我們用來(lái)創(chuàng)建訓(xùn)練網(wǎng)絡(luò)的三元組圖像。
1. 數(shù)據(jù)準(zhǔn)備
使用TensorFlow的tf.data API處理數(shù)據(jù)集以創(chuàng)建圖像三元組。這涉及到設(shè)置一個(gè)數(shù)據(jù)管道,其中每個(gè)三元組由錨點(diǎn)、正樣本和負(fù)樣本圖像組成。通過(guò)調(diào)整圖像大小到目標(biāo)形狀并歸一化像素值來(lái)預(yù)處理圖像。
def preprocess_image(filename):
image_string = tf.io.read_file(filename)
image = tf.image.decode_jpeg(image_string, channels=3)
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize(image, target_shape)
return image
def preprocess_triplets(anchor, positive, negative):
return (
preprocess_image(anchor),
preprocess_image(positive),
preprocess_image(negative),
)
以下是從數(shù)據(jù)集中生成的三元組示例,每行的前兩張圖像相似(錨點(diǎn)和正樣本),第三張不同(負(fù)樣本):
圖1:在數(shù)據(jù)準(zhǔn)備期間生成的三元組。每行的前兩張圖像相似(錨點(diǎn)和正樣本),第三張不同(負(fù)樣本)
2.構(gòu)建 embedding 生成器
我們暹羅網(wǎng)絡(luò)的核心是嵌入生成器,它使用在ImageNet上預(yù)訓(xùn)練的ResNet50模型構(gòu)建。通過(guò)凍結(jié)ResNet50中的大部分層的權(quán)重,并且僅微調(diào)最后幾層,我們可以利用遷移學(xué)習(xí)來(lái)減少訓(xùn)練時(shí)間并提高性能。
base_cnn = resnet.ResNet50(
weights="imagenet", input_shape=target_shape + (3,), include_top=False
)
flatten = layers.Flatten()(base_cnn.output)
dense1 = layers.Dense(512, activation="relu")(flatten)
dense1 = layers.BatchNormalization()(dense1)
dense2 = layers.Dense(256, activation="relu")(dense1)
dense2 = layers.BatchNormalization()(dense2)
output = layers.Dense(256)(dense2)
embedding = Model(base_cnn.input, output, name="Embedding")
# Freeze all layers until the layer conv5_block1_out
trainable = False
for layer in base_cnn.layers:
if layer.name == "conv5_block1_out":
trainable = True
layer.trainable = trainable
3.構(gòu)建暹羅網(wǎng)絡(luò)
暹羅網(wǎng)絡(luò)設(shè)置為一次輸入三張圖像(錨點(diǎn)、正樣本和負(fù)樣本)。自定義的DistanceLayer計(jì)算錨點(diǎn)-正樣本對(duì)和錨點(diǎn)-負(fù)樣本對(duì)之間的距離。然后訓(xùn)練模型以最小化相似圖像之間的距離,并最大化不相似圖像之間的距離。
class DistanceLayer(layers.Layer):
def call(self, anchor, positive, negative):
ap_distance = tf.reduce_sum(tf.square(anchor - positive), -1)
an_distance = tf.reduce_sum(tf.square(anchor - negative), -1)
return (ap_distance, an_distance)
anchor_input = layers.Input(name="anchor", shape=target_shape + (3,))
positive_input = layers.Input(name="positive", shape=target_shape + (3,))
negative_input = layers.Input(name="negative", shape=target_shape + (3,))
distances = DistanceLayer()(
embedding(resnet.preprocess_input(anchor_input)),
embedding(resnet.preprocess_input(positive_input)),
embedding(resnet.preprocess_input(negative_input)),
)
siamese_network = Model(
inputs=[anchor_input, positive_input, negative_input], outputs=distances
)
4.訓(xùn)練和評(píng)估
模型使用自定義訓(xùn)練循環(huán)進(jìn)行訓(xùn)練,其中計(jì)算三元組損失并用于更新網(wǎng)絡(luò)的權(quán)重。仔細(xì)監(jiān)控訓(xùn)練過(guò)程,并通過(guò)對(duì)學(xué)習(xí)到的嵌入進(jìn)行檢查來(lái)評(píng)估模型的性能。
class SiameseModel(Model):
def __init__(self, siamese_network, margin=0.5):
super(SiameseModel, self).__init__()
self.siamese_network = siamese_network
self.margin = margin
self.loss_tracker = metrics.Mean(name="loss")
def train_step(self, data):
with tf.GradientTape() as tape:
loss = self._compute_loss(data)
gradients = tape.gradient(loss, self.siamese_network.trainable_weights)
self.optimizer.apply_gradients(
zip(gradients, self.siamese_network.trainable_weights)
)
self.loss_tracker.update_state(loss)
return {"loss": self.loss_tracker.result()}
def _compute_loss(self, data):
ap_distance, an_distance = self.siamese_network(data)
loss = ap_distance - an_distance
loss = tf.maximum(loss + self.margin, 0.0)
return loss
5.檢查結(jié)果
訓(xùn)練完成后,我們可以通過(guò)比較錨點(diǎn)-正樣本對(duì)和錨點(diǎn)-負(fù)樣本對(duì)的嵌入之間的余弦相似度來(lái)評(píng)估網(wǎng)絡(luò)學(xué)習(xí)分離相似和不相似圖像的能力。
cosine_similarity = metrics.CosineSimilarity()
positive_similarity = cosine_similarity(anchor_embedding, positive_embedding)
print("Positive similarity:", positive_similarity.numpy())
negative_similarity = cosine_similarity(anchor_embedding, negative_embedding)
print("Negative similarity:", negative_similarity.numpy())
以下是經(jīng)過(guò)訓(xùn)練的模型評(píng)估的三元組示例。網(wǎng)絡(luò)成功識(shí)別出圖像之間的相似性和差異:
圖2:經(jīng)過(guò)訓(xùn)練的暹羅網(wǎng)絡(luò)的輸出,其中每行的前兩張圖像被模型識(shí)別為相似,第三張為不同
結(jié)論
本文展示了使用三元組損失的暹羅網(wǎng)絡(luò)如何有效地估計(jì)圖像相似度。通過(guò)使用預(yù)訓(xùn)練的ResNet50模型并微調(diào)其層,我們可以創(chuàng)建一個(gè)可以應(yīng)用于需要相似度估計(jì)的各種任務(wù)。
完整代碼和解釋?zhuān)瑓⒖迹篽ttps://github.com/elcaiseri/Siamese-Network