
iCTF 2023 — Stop the model thief!

To steal an ML model, an attacker often sends 'very similar versions' of the same image, which tells the attacker how the model reacts to very small changes in the input. You realized that an attacker might be trying to steal your image classification model. You're given two files: [1] model_queries.npy, a list of images that your model received as inputs, and [2] user_query_indices.txt, a list of image indices (starting from zero) into [1] sent to your model by each user-id. In [2], each line contains the indices from a different user-id (e.g., the very first line is user-id 0, the second line is user-id 1). Can you help us find the attacker's user-ids (there are 20 of them)? Note: if there were 4 attacker user-ids (e.g., 82,54,13,36), the flag would be 'ictf{13,36,54,82}' (sorted, no quotes).

We know that each attacker user-id has sent at least 5 near-duplicate attack images.

We're given a bunch of 32x32 "query" images, as well as a list of images sent by each user. Our goal is to find the 20 users who have been sending malicious queries.

We can unzip everything and poke around with the images using a simple script like so:

Code (py):

import os

import numpy as np
import cv2

queries: np.ndarray = np.load('model_queries.npy')

# Dump every query image to disk for eyeballing.
os.makedirs('queries', exist_ok=True)
for i, q in enumerate(queries):
    cv2.imwrite(f'./queries/{i}.jpg', q)

# Also group each user's queries into their own directory.
with open('./user_query_indices.txt') as f:
    for i, l in enumerate(f.readlines()):
        os.makedirs(f'./users/{i}', exist_ok=True)

        for id in l.split(","):
            cv2.imwrite(f'./users/{i}/{id}.jpg', queries[int(id)])

After some clarification, the suspicious images we are meant to detect are slightly-tampered-with duplicates of other images:

<p align="center"> <img src="https://gist.github.com/assets/60120929/d72f7b1a-b34f-40f3-a4eb-436d87250bf9"> <img src="https://gist.github.com/assets/60120929/6633ae65-c771-4ae6-849d-5a4f0d2d7307"> <img src="https://gist.github.com/assets/60120929/cdd01c46-4591-4a2f-a256-f0dae058b6cc"> <img src="https://gist.github.com/assets/60120929/51ee800b-06ba-4dff-ade6-5a54229096c9"> <img src="https://gist.github.com/assets/60120929/140d156b-d1e4-4daf-b14d-455cb56d357c"> <img src="https://gist.github.com/assets/60120929/93f8efee-5139-4332-967d-b226368dbbfb"> <img src="https://gist.github.com/assets/60120929/00cf9cdf-1315-4c30-81ba-b2d923fcf0a3"> <img src="https://gist.github.com/assets/60120929/8b80f3e0-ab62-446b-9dd5-41198bada2e4"> </p>

Because these differences are per-pixel, the brute-force solution is to subtract each image from every other image and check whether the sum of absolute differences falls below some threshold; if it does, that pair of images is suspicious and can be labelled as such. Then do one pass through the users, checking their queries against the precomputed sus ids to determine whether each user is a malicious model-stealing agent.

Here's a rough implementation of the image similarity check:

Code (py):

THRESH = 20000

# ...

def process_query(i: int, q: np.ndarray, queries, unique_images, sus_ids):
    # Compare query i against every unique image seen so far.
    for u in unique_images:
        diff = np.sum(cv2.absdiff(q, queries[u]))

        if diff < THRESH:
            # Near-duplicate found: flag both images as suspicious.
            print(i, u, diff)
            sus_ids[i] = 1
            sus_ids[u] = 1
            break
    else:
        # No near-duplicate found: record this image as a new unique one.
        unique_images.append(i)

The threshold was obtained by running a small pass of the algorithm and printing, at each iteration, the minimum diff found between any pair of images. The value was then set an arbitrary 2000 above the highest "fake" diff and well below the diff of the most similar pair of genuinely different images: most fakes have diffs between 16000 and 17000, with some as low as 12000, while the closest pair of real images differed by around 30000.

Code:

1666 2239 16581
2307 2239 17299
1999 1590 17126
184 2239 16989
1423 102 17159
(snapshot of program output; each row shows the two image ids detected as a near-duplicate pair, followed by their diff)
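For reference, a minimal sketch of what such a calibration pass might look like (SAMPLE is an arbitrary choice, not taken from the original solve; any prefix large enough to contain a few fakes will do):

Code (py):

import numpy as np
import cv2

queries = np.load('model_queries.npy')

# For each image in a small prefix, print the smallest diff against everything
# before it. Fake pairs cluster far below legitimate pairs, which makes it easy
# to eyeball a safe THRESH value.
SAMPLE = 2000
for i in range(1, SAMPLE):
    diffs = [int(np.sum(cv2.absdiff(queries[i], queries[j]))) for j in range(i)]
    print(i, min(diffs))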

The problem with this sledgehammer solution is that comparing every image against every other is an O(n²) algorithm, where n is on the order of 10,000 images. While each image is only 32x32 pixels and diffing any individual pair is cheap, doing that same diffing over 10,000² = 100,000,000 pairs might be a bit problematic.
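One mitigation (not what was used here) would be to vectorize the inner comparison with NumPy broadcasting, diffing each image against every previously-seen unique image in one C-level operation; that only shrinks the constant factor, though, not the O(n²) growth. A rough sketch:

Code (py):

import numpy as np

def min_diff(q: np.ndarray, seen: np.ndarray):
    # seen: stack of previously-seen images, shape (k, H, W[, C]).
    # Cast to int32 so the subtraction doesn't wrap around on uint8 pixels,
    # then sum the absolute differences per image and pick the closest one.
    diffs = np.abs(seen.astype(np.int32) - q.astype(np.int32)).reshape(len(seen), -1).sum(axis=1)
    j = int(np.argmin(diffs))
    return j, int(diffs[j])

In practice, the brute-force loop was instead just parallelized across processes: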

Code (py):

from multiprocessing import Pool, Manager

import numpy as np
import cv2

PROCESSES = 8
THRESH = 20000


def process_query(i: int, q: np.ndarray, queries, unique_images, sus_ids):
    # Compare query i against every unique image seen so far; flag both images
    # if they are near-duplicates, otherwise record i as a new unique image.
    for u in unique_images:
        diff = np.sum(cv2.absdiff(q, queries[u]))

        if diff < THRESH:
            print(i, u, diff)
            sus_ids[i] = 1
            sus_ids[u] = 1
            break
    else:
        unique_images.append(i)


if __name__ == "__main__":
    queries: np.ndarray = np.load('model_queries.npy')

    # Share the growing list of unique images and the flagged ids across processes.
    manager = Manager()
    unique_images = manager.list()
    sus_ids = manager.dict()
    sus_users = []

    with Pool(PROCESSES) as pool:
        pool.starmap(process_query, [(i, q, queries, unique_images, sus_ids) for (i, q) in enumerate(queries)])

    print(sus_ids)

    # A user is malicious if at least 5 of their queries were flagged as near-duplicates.
    with open('./user_query_indices.txt') as f:
        for i, l in enumerate(f.readlines()):
            suspicious = [sus_ids.get(int(s), 0) for s in l.split(",")]
            if np.sum(suspicious) >= 5:
                sus_users.append(i)

    print(sorted(sus_users))
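The printed ids then just need to be joined into the flag format described in the problem statement; a one-liner sketch, using sus_users from the script above:

Code (py):

# Sorted, comma-separated user ids wrapped in ictf{...}, per the challenge note.
flag = 'ictf{' + ','.join(str(u) for u in sorted(sus_users)) + '}'
print(flag)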

Unfortunately, there doesn't seem to be an obvious way around the quadratic cost itself. Loading the task into multiprocessing and churning away at it on 8 processes for about two hours, we get the flag:

Code:

ictf{18,27,29,32,68,106,126,158,182,189,192,232,282,330,338,370,419,438,447,465}