danieldk HF Staff committed on
Commit
e32aaaa
·
verified ·
1 Parent(s): e9deed8

Benchmarks uploaded using `kernels`.

Browse files
Files changed (1) hide show
  1. benchmarks/benchmark.py +239 -0
benchmarks/benchmark.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from kernels.benchmark import Benchmark
4
+
5
+
6
def lsh_weighted_cumulation_reference(
    query_mask: torch.Tensor,
    query_hash_code: torch.Tensor,
    query_weight: torch.Tensor,
    key_mask: torch.Tensor,
    key_hash_code: torch.Tensor,
    key_weight: torch.Tensor,
    value: torch.Tensor,
    hashtable_capacity: int,
) -> torch.Tensor:
    """Pure-PyTorch reference for LSH weighted cumulation.

    For each (batch, weight-column) pair, every key scatter-adds its
    masked, weighted value row into one bucket per hash function; every
    query then gathers its bucket from each hash function's table, sums
    the gathers, averages over the number of hash functions, and scales
    by its own mask and weight.

    Args:
        query_mask: ``[batch, num_query]`` validity mask (nonzero = valid).
        query_hash_code: ``[batch, num_query, num_hash_f]`` integer hash
            codes; values must lie in ``[0, hashtable_capacity)``.
        query_weight: ``[batch, num_query, weight_dim]`` per-query weights.
        key_mask: ``[batch, num_key]`` validity mask (nonzero = valid).
        key_hash_code: ``[batch, num_key, num_hash_f]`` integer hash codes;
            values must lie in ``[0, hashtable_capacity)``.
        key_weight: ``[batch, num_key, weight_dim]`` per-key weights.
        value: ``[batch, num_key, value_dim]`` values to accumulate.
        hashtable_capacity: number of buckets per hashtable.

    Returns:
        ``[batch, num_query, value_dim]`` accumulated output, same device
        and dtype as ``value``.
    """
    batch_size, num_query, num_hash_f = query_hash_code.shape
    _, num_key, value_dim = value.shape
    weight_dim = query_weight.shape[2]
    device = value.device
    dtype = value.dtype

    output = torch.zeros(batch_size, num_query, value_dim, device=device, dtype=dtype)

    for b in range(batch_size):
        for weight_idx in range(weight_dim):
            # One hashtable per hash function.
            hashtables = torch.zeros(
                num_hash_f, hashtable_capacity, value_dim, device=device, dtype=dtype
            )

            # Cast masks to the value dtype (not .float()) so scatter_add_
            # and the final accumulation keep matching dtypes for
            # fp16/fp64 inputs as well as fp32.
            k_mask = key_mask[b, :].to(dtype)  # [num_key]
            k_weight_val = key_weight[b, :, weight_idx]  # [num_key]

            # Hoisted out of the hash-function loop: the weighted values
            # do not depend on the hash function index h.
            weighted_values = (
                k_mask.unsqueeze(-1) * k_weight_val.unsqueeze(-1) * value[b]
            )  # [num_key, value_dim]

            for h in range(num_hash_f):
                k_hash = key_hash_code[b, :, h].long()  # [num_key]
                k_hash_expanded = k_hash.unsqueeze(-1).expand(-1, value_dim)
                hashtables[h].scatter_add_(0, k_hash_expanded, weighted_values)

            # Query side: sum the gathered buckets over all hash functions.
            q_mask = query_mask[b, :].to(dtype)  # [num_query]
            q_weight_val = query_weight[b, :, weight_idx]  # [num_query]

            sum_val = torch.zeros(num_query, value_dim, device=device, dtype=dtype)
            for h in range(num_hash_f):
                q_hash = query_hash_code[b, :, h].long()  # [num_query]
                sum_val += hashtables[h][q_hash]  # [num_query, value_dim]

            # Apply query mask/weight and average over hash functions.
            output[b] += (
                q_mask.unsqueeze(-1) * q_weight_val.unsqueeze(-1) * sum_val / num_hash_f
            )

    return output
59
+
60
+
61
class YosoBenchmark(Benchmark):
    """Benchmarks for the YOSO LSH weighted-cumulation kernel.

    Two problem sizes are exercised: a small "base" configuration
    (batch 2, 128 queries/keys, dim 64) and a "large" one (batch 4,
    512 queries/keys, dim 128).  Each setup builds masks, vectors,
    weights, and values, precomputes hash codes with the kernel's
    ``fast_hash``, and allocates an output buffer; each benchmark runs
    ``lsh_weighted_cumulation``; each verify recomputes the expected
    result with the pure-PyTorch reference above.
    """

    seed: int = 42

    def _setup_inputs(
        self, batch_size: int, num_query: int, num_key: int, dim: int
    ) -> None:
        """Build all inputs shared by the base and large configurations.

        Factored out of ``setup``/``setup_large``, which previously
        duplicated this body verbatim and differed only in sizes.
        """
        self.num_hash_f = 32
        self.hash_code_len = 9
        self.weight_dim = self.num_hash_f
        self.value_dim = dim
        # Hash codes are hash_code_len bits wide, so the tables need
        # 2**hash_code_len buckets.
        self.hashtable_capacity = 1 << self.hash_code_len

        self.query_mask = torch.ones(
            batch_size, num_query, device=self.device, dtype=torch.int32
        )
        self.query_vector = torch.randn(
            batch_size, num_query, dim, device=self.device, dtype=torch.float32
        )
        self.key_mask = torch.ones(
            batch_size, num_key, device=self.device, dtype=torch.int32
        )
        self.key_vector = torch.randn(
            batch_size, num_key, dim, device=self.device, dtype=torch.float32
        )
        self.value = torch.randn(
            batch_size, num_key, self.value_dim, device=self.device, dtype=torch.float32
        )
        self.query_weight = torch.randn(
            batch_size,
            num_query,
            self.weight_dim,
            device=self.device,
            dtype=torch.float32,
        )
        self.key_weight = torch.randn(
            batch_size,
            num_key,
            self.weight_dim,
            device=self.device,
            dtype=torch.float32,
        )

        # Pre-compute hash codes so the cumulation benchmarks measure
        # only the cumulation kernel, not hashing.
        hash_result = self.kernel.fast_hash(
            self.query_mask,
            self.query_vector,
            self.key_mask,
            self.key_vector,
            self.num_hash_f,
            self.hash_code_len,
            True,
            1,
        )
        self.query_hash_code = hash_result[0]
        self.key_hash_code = hash_result[1]

        self.out = torch.empty(
            batch_size,
            num_query,
            self.value_dim,
            device=self.device,
            dtype=torch.float32,
        )

    def _run_cumulation(self) -> None:
        """Run the kernel under test, storing the result in ``self.out``."""
        self.out = self.kernel.lsh_weighted_cumulation(
            self.query_mask,
            self.query_hash_code,
            self.query_weight,
            self.key_mask,
            self.key_hash_code,
            self.key_weight,
            self.value,
            self.hashtable_capacity,
            True,
            1,
        )

    def _reference(self) -> torch.Tensor:
        """Compute the expected output with the pure-PyTorch reference."""
        return lsh_weighted_cumulation_reference(
            self.query_mask,
            self.query_hash_code,
            self.query_weight,
            self.key_mask,
            self.key_hash_code,
            self.key_weight,
            self.value,
            self.hashtable_capacity,
        )

    def setup(self):
        self._setup_inputs(batch_size=2, num_query=128, num_key=128, dim=64)

    def benchmark_base(self):
        self._run_cumulation()

    def verify_base(self) -> torch.Tensor:
        return self._reference()

    def setup_large(self):
        self._setup_inputs(batch_size=4, num_query=512, num_key=512, dim=128)

    def benchmark_large(self):
        self._run_cumulation()

    def verify_large(self) -> torch.Tensor:
        return self._reference()