@@ -31,9 +31,10 @@ ContextBase::ContextBase(CipherSuite suite_in)
31
31
ContextBase::~ContextBase () = default ;
32
32
33
33
void
34
- ContextBase::add_key (KeyID key_id, input_bytes base_key)
34
+ ContextBase::add_key (KeyID key_id, KeyUsage usage, input_bytes base_key)
35
35
{
36
- keys.emplace (key_id, KeyAndSalt::from_base_key (suite, key_id, base_key));
36
+ keys.emplace (key_id,
37
+ KeyAndSalt::from_base_key (suite, key_id, usage, base_key));
37
38
}
38
39
39
40
static owned_bytes<KeyAndSalt::max_salt_size>
@@ -67,6 +68,9 @@ ContextBase::protect(const Header& header,
67
68
}
68
69
69
70
const auto & key_and_salt = keys.at (header.key_id );
71
+ if (key_and_salt.usage != KeyUsage::protect) {
72
+ throw invalid_key_usage_error (" Decrypt-only key used for encryption" );
73
+ }
70
74
71
75
const auto aad = form_aad (header, metadata);
72
76
const auto nonce = form_nonce (header.counter , key_and_salt.salt );
@@ -88,6 +92,9 @@ ContextBase::unprotect(const Header& header,
88
92
}
89
93
90
94
const auto & key_and_salt = keys.at (header.key_id );
95
+ if (key_and_salt.usage != KeyUsage::unprotect) {
96
+ throw invalid_key_usage_error (" Encrypt-only key used for decryption" );
97
+ }
91
98
92
99
const auto aad = form_aad (header, metadata);
93
100
const auto nonce = form_nonce (header.counter , key_and_salt.salt );
@@ -134,7 +141,10 @@ sframe_salt_label(CipherSuite suite, KeyID key_id)
134
141
}
135
142
136
143
KeyAndSalt
137
- KeyAndSalt::from_base_key (CipherSuite suite, KeyID key_id, input_bytes base_key)
144
+ KeyAndSalt::from_base_key (CipherSuite suite,
145
+ KeyID key_id,
146
+ KeyUsage usage,
147
+ input_bytes base_key)
138
148
{
139
149
auto key_size = cipher_key_size (suite);
140
150
auto nonce_size = cipher_nonce_size (suite);
@@ -147,7 +157,7 @@ KeyAndSalt::from_base_key(CipherSuite suite, KeyID key_id, input_bytes base_key)
147
157
auto key = hkdf_expand (suite, secret, key_label, key_size);
148
158
auto salt = hkdf_expand (suite, secret, salt_label, nonce_size);
149
159
150
- return KeyAndSalt{ key, salt, 0 };
160
+ return KeyAndSalt{ key, salt, usage, 0 };
151
161
}
152
162
153
163
// /
@@ -162,9 +172,9 @@ Context::Context(CipherSuite suite_in)
162
172
Context::~Context () = default ;
163
173
164
174
void
165
- Context::add_key (KeyID key_id, input_bytes base_key)
175
+ Context::add_key (KeyID key_id, KeyUsage usage, input_bytes base_key)
166
176
{
167
- ContextBase::add_key (key_id, base_key);
177
+ ContextBase::add_key (key_id, usage, base_key);
168
178
counters.emplace (key_id, 0 );
169
179
}
170
180
@@ -263,7 +273,7 @@ MLSContext::protect(EpochID epoch_id,
263
273
input_bytes metadata)
264
274
{
265
275
auto key_id = form_key_id (epoch_id, sender_id, context_id);
266
- ensure_key (key_id);
276
+ ensure_key (key_id, KeyUsage::protect );
267
277
return Context::protect (key_id, ciphertext, plaintext, metadata);
268
278
}
269
279
@@ -275,7 +285,7 @@ MLSContext::unprotect(output_bytes plaintext,
275
285
const auto header = Header::parse (ciphertext);
276
286
const auto inner_ciphertext = ciphertext.subspan (header.size ());
277
287
278
- ensure_key (header.key_id );
288
+ ensure_key (header.key_id , KeyUsage::unprotect );
279
289
return ContextBase::unprotect (header, plaintext, inner_ciphertext, metadata);
280
290
}
281
291
@@ -358,7 +368,7 @@ MLSContext::form_key_id(EpochID epoch_id,
358
368
}
359
369
360
370
void
361
- MLSContext::ensure_key (KeyID key_id)
371
+ MLSContext::ensure_key (KeyID key_id, KeyUsage usage )
362
372
{
363
373
// If the required key already exists, we are done
364
374
const auto epoch_index = key_id & epoch_mask;
@@ -374,7 +384,7 @@ MLSContext::ensure_key(KeyID key_id)
374
384
375
385
// Otherwise, derive a key and implant it
376
386
const auto sender_id = key_id >> epoch_bits;
377
- Context::add_key (key_id, epoch->base_key (suite, sender_id));
387
+ Context::add_key (key_id, usage, epoch->base_key (suite, sender_id));
378
388
return ;
379
389
}
380
390
0 commit comments