Skip to content

Commit cd5fc9e

Browse files
authored
Merge pull request #59 from cisco/unidirectional
Make keys unidirectional
2 parents 917511e + 9892e00 commit cd5fc9e

File tree

4 files changed

+41
-17
lines changed

4 files changed

+41
-17
lines changed

include/sframe/sframe.h

+17-3
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ struct invalid_parameter_error : std::runtime_error
3232
using parent::parent;
3333
};
3434

35+
struct invalid_key_usage_error : std::runtime_error
36+
{
37+
using parent = std::runtime_error;
38+
using parent::parent;
39+
};
40+
3541
enum class CipherSuite : uint16_t
3642
{
3743
AES_128_CTR_HMAC_SHA256_80 = 1,
@@ -219,17 +225,25 @@ class Header
219225
Header(KeyID key_id_in, Counter counter_in, input_bytes encoded_in);
220226
};
221227

228+
enum struct KeyUsage
229+
{
230+
protect,
231+
unprotect,
232+
};
233+
222234
struct KeyAndSalt
223235
{
224236
static KeyAndSalt from_base_key(CipherSuite suite,
225237
KeyID key_id,
238+
KeyUsage usage,
226239
input_bytes base_key);
227240

228241
static constexpr size_t max_key_size = 48;
229242
static constexpr size_t max_salt_size = 12;
230243

231244
owned_bytes<max_key_size> key;
232245
owned_bytes<max_salt_size> salt;
246+
KeyUsage usage;
233247
Counter counter;
234248
};
235249

@@ -246,7 +260,7 @@ class ContextBase
246260
ContextBase(CipherSuite suite_in);
247261
virtual ~ContextBase();
248262

249-
void add_key(KeyID kid, input_bytes key);
263+
void add_key(KeyID kid, KeyUsage usage, input_bytes key);
250264

251265
output_bytes protect(const Header& header,
252266
output_bytes ciphertext,
@@ -275,7 +289,7 @@ class Context : protected ContextBase
275289
Context(CipherSuite suite);
276290
virtual ~Context();
277291

278-
void add_key(KeyID kid, input_bytes key);
292+
void add_key(KeyID kid, KeyUsage usage, input_bytes key);
279293

280294
output_bytes protect(KeyID key_id,
281295
output_bytes ciphertext,
@@ -349,7 +363,7 @@ class MLSContext : protected Context
349363
KeyID form_key_id(EpochID epoch_id,
350364
SenderID sender_id,
351365
ContextID context_id) const;
352-
void ensure_key(KeyID key_id);
366+
void ensure_key(KeyID key_id, KeyUsage usage);
353367

354368
const size_t epoch_bits;
355369
const size_t epoch_mask;

src/sframe.cpp

+20-10
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,10 @@ ContextBase::ContextBase(CipherSuite suite_in)
3131
ContextBase::~ContextBase() = default;
3232

3333
void
34-
ContextBase::add_key(KeyID key_id, input_bytes base_key)
34+
ContextBase::add_key(KeyID key_id, KeyUsage usage, input_bytes base_key)
3535
{
36-
keys.emplace(key_id, KeyAndSalt::from_base_key(suite, key_id, base_key));
36+
keys.emplace(key_id,
37+
KeyAndSalt::from_base_key(suite, key_id, usage, base_key));
3738
}
3839

3940
static owned_bytes<KeyAndSalt::max_salt_size>
@@ -67,6 +68,9 @@ ContextBase::protect(const Header& header,
6768
}
6869

6970
const auto& key_and_salt = keys.at(header.key_id);
71+
if (key_and_salt.usage != KeyUsage::protect) {
72+
throw invalid_key_usage_error("Decrypt-only key used for encryption");
73+
}
7074

7175
const auto aad = form_aad(header, metadata);
7276
const auto nonce = form_nonce(header.counter, key_and_salt.salt);
@@ -88,6 +92,9 @@ ContextBase::unprotect(const Header& header,
8892
}
8993

9094
const auto& key_and_salt = keys.at(header.key_id);
95+
if (key_and_salt.usage != KeyUsage::unprotect) {
96+
throw invalid_key_usage_error("Encrypt-only key used for decryption");
97+
}
9198

9299
const auto aad = form_aad(header, metadata);
93100
const auto nonce = form_nonce(header.counter, key_and_salt.salt);
@@ -134,7 +141,10 @@ sframe_salt_label(CipherSuite suite, KeyID key_id)
134141
}
135142

136143
KeyAndSalt
137-
KeyAndSalt::from_base_key(CipherSuite suite, KeyID key_id, input_bytes base_key)
144+
KeyAndSalt::from_base_key(CipherSuite suite,
145+
KeyID key_id,
146+
KeyUsage usage,
147+
input_bytes base_key)
138148
{
139149
auto key_size = cipher_key_size(suite);
140150
auto nonce_size = cipher_nonce_size(suite);
@@ -147,7 +157,7 @@ KeyAndSalt::from_base_key(CipherSuite suite, KeyID key_id, input_bytes base_key)
147157
auto key = hkdf_expand(suite, secret, key_label, key_size);
148158
auto salt = hkdf_expand(suite, secret, salt_label, nonce_size);
149159

150-
return KeyAndSalt{ key, salt, 0 };
160+
return KeyAndSalt{ key, salt, usage, 0 };
151161
}
152162

153163
///
@@ -162,9 +172,9 @@ Context::Context(CipherSuite suite_in)
162172
Context::~Context() = default;
163173

164174
void
165-
Context::add_key(KeyID key_id, input_bytes base_key)
175+
Context::add_key(KeyID key_id, KeyUsage usage, input_bytes base_key)
166176
{
167-
ContextBase::add_key(key_id, base_key);
177+
ContextBase::add_key(key_id, usage, base_key);
168178
counters.emplace(key_id, 0);
169179
}
170180

@@ -263,7 +273,7 @@ MLSContext::protect(EpochID epoch_id,
263273
input_bytes metadata)
264274
{
265275
auto key_id = form_key_id(epoch_id, sender_id, context_id);
266-
ensure_key(key_id);
276+
ensure_key(key_id, KeyUsage::protect);
267277
return Context::protect(key_id, ciphertext, plaintext, metadata);
268278
}
269279

@@ -275,7 +285,7 @@ MLSContext::unprotect(output_bytes plaintext,
275285
const auto header = Header::parse(ciphertext);
276286
const auto inner_ciphertext = ciphertext.subspan(header.size());
277287

278-
ensure_key(header.key_id);
288+
ensure_key(header.key_id, KeyUsage::unprotect);
279289
return ContextBase::unprotect(header, plaintext, inner_ciphertext, metadata);
280290
}
281291

@@ -358,7 +368,7 @@ MLSContext::form_key_id(EpochID epoch_id,
358368
}
359369

360370
void
361-
MLSContext::ensure_key(KeyID key_id)
371+
MLSContext::ensure_key(KeyID key_id, KeyUsage usage)
362372
{
363373
// If the required key already exists, we are done
364374
const auto epoch_index = key_id & epoch_mask;
@@ -374,7 +384,7 @@ MLSContext::ensure_key(KeyID key_id)
374384

375385
// Otherwise, derive a key and implant it
376386
const auto sender_id = key_id >> epoch_bits;
377-
Context::add_key(key_id, epoch->base_key(suite, sender_id));
387+
Context::add_key(key_id, usage, epoch->base_key(suite, sender_id));
378388
return;
379389
}
380390

test/sframe.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,10 @@ TEST_CASE("SFrame Round-Trip")
5252
auto& key = pair.second;
5353

5454
auto send = Context(suite);
55-
send.add_key(kid, key);
55+
send.add_key(kid, KeyUsage::protect, key);
5656

5757
auto recv = Context(suite);
58-
recv.add_key(kid, key);
58+
recv.add_key(kid, KeyUsage::unprotect, key);
5959

6060
for (int i = 0; i < rounds; i++) {
6161
auto encrypted = to_bytes(send.protect(kid, ct_out, plaintext, {}));

test/vectors.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ struct SFrameTestVector
144144
{
145145
// Protect
146146
auto send_ctx = Context(cipher_suite);
147-
send_ctx.add_key(kid, base_key);
147+
send_ctx.add_key(kid, KeyUsage::protect, base_key);
148148

149149
auto ct_data = owned_bytes<128>();
150150
auto next_ctr = uint64_t(0);
@@ -163,7 +163,7 @@ struct SFrameTestVector
163163

164164
// Unprotect
165165
auto recv_ctx = Context(cipher_suite);
166-
recv_ctx.add_key(kid, base_key);
166+
recv_ctx.add_key(kid, KeyUsage::unprotect, base_key);
167167

168168
auto pt_data = owned_bytes<128>();
169169
auto pt_out = recv_ctx.unprotect(pt_data, ct, metadata);

0 commit comments

Comments
 (0)