2024-10-22 8ac6c7a54ed1b98d142dce24b11c6de6a1e239a5
kernel/net/tls/tls_sw.c
@@ -4,6 +4,7 @@
  * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
  * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
  * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
+ * Copyright (c) 2018, Covalent IO, Inc. http://covalent.io
  *
  * This software is available to you under a choice of one of two
  * licenses. You may choose to be licensed under the terms of the GNU
@@ -34,244 +35,1327 @@
  * SOFTWARE.
  */
 
+#include <linux/bug.h>
 #include <linux/sched/signal.h>
 #include <linux/module.h>
+#include <linux/splice.h>
 #include <crypto/aead.h>
 
 #include <net/strparser.h>
 #include <net/tls.h>
 
-#define MAX_IV_SIZE TLS_CIPHER_AES_GCM_128_IV_SIZE
+noinline void tls_err_abort(struct sock *sk, int err)
+{
+	WARN_ON_ONCE(err >= 0);
+	/* sk->sk_err should contain a positive error code. */
+	sk->sk_err = -err;
+	sk->sk_error_report(sk);
+}
+
+static int __skb_nsg(struct sk_buff *skb, int offset, int len,
+		     unsigned int recursion_level)
+{
+	int start = skb_headlen(skb);
+	int i, chunk = start - offset;
+	struct sk_buff *frag_iter;
+	int elt = 0;
+
+	if (unlikely(recursion_level >= 24))
+		return -EMSGSIZE;
+
+	if (chunk > 0) {
+		if (chunk > len)
+			chunk = len;
+		elt++;
+		len -= chunk;
+		if (len == 0)
+			return elt;
+		offset += chunk;
+	}
+
+	for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
+		int end;
+
+		WARN_ON(start > offset + len);
+
+		end = start + skb_frag_size(&skb_shinfo(skb)->frags[i]);
+		chunk = end - offset;
+		if (chunk > 0) {
+			if (chunk > len)
+				chunk = len;
+			elt++;
+			len -= chunk;
+			if (len == 0)
+				return elt;
+			offset += chunk;
+		}
+		start = end;
+	}
+
+	if (unlikely(skb_has_frag_list(skb))) {
+		skb_walk_frags(skb, frag_iter) {
+			int end, ret;
+
+			WARN_ON(start > offset + len);
+
+			end = start + frag_iter->len;
+			chunk = end - offset;
+			if (chunk > 0) {
+				if (chunk > len)
+					chunk = len;
+				ret = __skb_nsg(frag_iter, offset - start, chunk,
+						recursion_level + 1);
+				if (unlikely(ret < 0))
+					return ret;
+				elt += ret;
+				len -= chunk;
+				if (len == 0)
+					return elt;
+				offset += chunk;
+			}
+			start = end;
+		}
+	}
+	BUG_ON(len);
+	return elt;
+}
+
+/* Return the number of scatterlist elements required to completely map the
+ * skb, or -EMSGSIZE if the recursion depth is exceeded.
+ */
+static int skb_nsg(struct sk_buff *skb, int offset, int len)
+{
+	return __skb_nsg(skb, offset, len, 0);
+}
+
+static int padding_length(struct tls_sw_context_rx *ctx,
+			  struct tls_prot_info *prot, struct sk_buff *skb)
+{
+	struct strp_msg *rxm = strp_msg(skb);
+	int sub = 0;
+
+	/* Determine zero-padding length */
+	if (prot->version == TLS_1_3_VERSION) {
+		char content_type = 0;
+		int err;
+		int back = 17;
+
+		while (content_type == 0) {
+			if (back > rxm->full_len - prot->prepend_size)
+				return -EBADMSG;
+			err = skb_copy_bits(skb,
+					    rxm->offset + rxm->full_len - back,
+					    &content_type, 1);
+			if (err)
+				return err;
+			if (content_type)
+				break;
+			sub++;
+			back++;
+		}
+		ctx->control = content_type;
+	}
+	return sub;
+}
+
+static void tls_decrypt_done(struct crypto_async_request *req, int err)
+{
+	struct aead_request *aead_req = (struct aead_request *)req;
+	struct scatterlist *sgout = aead_req->dst;
+	struct scatterlist *sgin = aead_req->src;
+	struct tls_sw_context_rx *ctx;
+	struct tls_context *tls_ctx;
+	struct tls_prot_info *prot;
+	struct scatterlist *sg;
+	struct sk_buff *skb;
+	unsigned int pages;
+	int pending;
+
+	skb = (struct sk_buff *)req->data;
+	tls_ctx = tls_get_ctx(skb->sk);
+	ctx = tls_sw_ctx_rx(tls_ctx);
+	prot = &tls_ctx->prot_info;
+
+	/* Propagate if there was an err */
+	if (err) {
+		if (err == -EBADMSG)
+			TLS_INC_STATS(sock_net(skb->sk),
+				      LINUX_MIB_TLSDECRYPTERROR);
+		ctx->async_wait.err = err;
+		tls_err_abort(skb->sk, err);
+	} else {
+		struct strp_msg *rxm = strp_msg(skb);
+		int pad;
+
+		pad = padding_length(ctx, prot, skb);
+		if (pad < 0) {
+			ctx->async_wait.err = pad;
+			tls_err_abort(skb->sk, pad);
+		} else {
+			rxm->full_len -= pad;
+			rxm->offset += prot->prepend_size;
+			rxm->full_len -= prot->overhead_size;
+		}
+	}
+
+	/* After using skb->sk to propagate sk through crypto async callback
+	 * we need to NULL it again.
+	 */
+	skb->sk = NULL;
+
+
+	/* Free the destination pages if skb was not decrypted inplace */
+	if (sgout != sgin) {
+		/* Skip the first S/G entry as it points to AAD */
+		for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
+			if (!sg)
+				break;
+			put_page(sg_page(sg));
+		}
+	}
+
+	kfree(aead_req);
+
+	spin_lock_bh(&ctx->decrypt_compl_lock);
+	pending = atomic_dec_return(&ctx->decrypt_pending);
+
+	if (!pending && ctx->async_notify)
+		complete(&ctx->async_wait.completion);
+	spin_unlock_bh(&ctx->decrypt_compl_lock);
+}
 
 static int tls_do_decryption(struct sock *sk,
+			     struct sk_buff *skb,
 			     struct scatterlist *sgin,
 			     struct scatterlist *sgout,
 			     char *iv_recv,
 			     size_t data_len,
-			     struct aead_request *aead_req)
+			     struct aead_request *aead_req,
+			     bool async)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
 	int ret;
 
 	aead_request_set_tfm(aead_req, ctx->aead_recv);
-	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
+	aead_request_set_ad(aead_req, prot->aad_size);
 	aead_request_set_crypt(aead_req, sgin, sgout,
-			       data_len + tls_ctx->rx.tag_size,
+			       data_len + prot->tag_size,
 			       (u8 *)iv_recv);
-	aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
-				  crypto_req_done, &ctx->async_wait);
 
-	ret = crypto_wait_req(crypto_aead_decrypt(aead_req), &ctx->async_wait);
+	if (async) {
+		/* Using skb->sk to push sk through to crypto async callback
+		 * handler. This allows propagating errors up to the socket
+		 * if needed. It _must_ be cleared in the async handler
+		 * before consume_skb is called. We _know_ skb->sk is NULL
+		 * because it is a clone from strparser.
+		 */
+		skb->sk = sk;
+		aead_request_set_callback(aead_req,
+					  CRYPTO_TFM_REQ_MAY_BACKLOG,
+					  tls_decrypt_done, skb);
+		atomic_inc(&ctx->decrypt_pending);
+	} else {
+		aead_request_set_callback(aead_req,
+					  CRYPTO_TFM_REQ_MAY_BACKLOG,
+					  crypto_req_done, &ctx->async_wait);
+	}
+
+	ret = crypto_aead_decrypt(aead_req);
+	if (ret == -EINPROGRESS) {
+		if (async)
+			return ret;
+
+		ret = crypto_wait_req(ret, &ctx->async_wait);
+	}
+
+	if (async)
+		atomic_dec(&ctx->decrypt_pending);
+
 	return ret;
 }
 
69
-static void trim_sg(struct sock *sk, struct scatterlist *sg,
70
- int *sg_num_elem, unsigned int *sg_size, int target_size)
71
-{
72
- int i = *sg_num_elem - 1;
73
- int trim = *sg_size - target_size;
74
-
75
- if (trim <= 0) {
76
- WARN_ON(trim < 0);
77
- return;
78
- }
79
-
80
- *sg_size = target_size;
81
- while (trim >= sg[i].length) {
82
- trim -= sg[i].length;
83
- sk_mem_uncharge(sk, sg[i].length);
84
- put_page(sg_page(&sg[i]));
85
- i--;
86
-
87
- if (i < 0)
88
- goto out;
89
- }
90
-
91
- sg[i].length -= trim;
92
- sk_mem_uncharge(sk, trim);
93
-
94
-out:
95
- *sg_num_elem = i + 1;
96
-}
97
-
98
-static void trim_both_sgl(struct sock *sk, int target_size)
279
+static void tls_trim_both_msgs(struct sock *sk, int target_size)
99280 {
100281 struct tls_context *tls_ctx = tls_get_ctx(sk);
282
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
101283 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
284
+ struct tls_rec *rec = ctx->open_rec;
102285
103
- trim_sg(sk, ctx->sg_plaintext_data,
104
- &ctx->sg_plaintext_num_elem,
105
- &ctx->sg_plaintext_size,
106
- target_size);
107
-
286
+ sk_msg_trim(sk, &rec->msg_plaintext, target_size);
108287 if (target_size > 0)
109
- target_size += tls_ctx->tx.overhead_size;
110
-
111
- trim_sg(sk, ctx->sg_encrypted_data,
112
- &ctx->sg_encrypted_num_elem,
113
- &ctx->sg_encrypted_size,
114
- target_size);
288
+ target_size += prot->overhead_size;
289
+ sk_msg_trim(sk, &rec->msg_encrypted, target_size);
115290 }
116291
117
-static int alloc_encrypted_sg(struct sock *sk, int len)
292
+static int tls_alloc_encrypted_msg(struct sock *sk, int len)
118293 {
119294 struct tls_context *tls_ctx = tls_get_ctx(sk);
120295 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
121
- int rc = 0;
296
+ struct tls_rec *rec = ctx->open_rec;
297
+ struct sk_msg *msg_en = &rec->msg_encrypted;
122298
123
- rc = sk_alloc_sg(sk, len,
124
- ctx->sg_encrypted_data, 0,
125
- &ctx->sg_encrypted_num_elem,
126
- &ctx->sg_encrypted_size, 0);
127
-
128
- if (rc == -ENOSPC)
129
- ctx->sg_encrypted_num_elem = ARRAY_SIZE(ctx->sg_encrypted_data);
130
-
131
- return rc;
299
+ return sk_msg_alloc(sk, msg_en, len, 0);
132300 }
133301
134
-static int alloc_plaintext_sg(struct sock *sk, int len)
302
+static int tls_clone_plaintext_msg(struct sock *sk, int required)
303
+{
304
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
305
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
306
+ struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
307
+ struct tls_rec *rec = ctx->open_rec;
308
+ struct sk_msg *msg_pl = &rec->msg_plaintext;
309
+ struct sk_msg *msg_en = &rec->msg_encrypted;
310
+ int skip, len;
311
+
312
+ /* We add page references worth len bytes from encrypted sg
313
+ * at the end of plaintext sg. It is guaranteed that msg_en
314
+ * has enough required room (ensured by caller).
315
+ */
316
+ len = required - msg_pl->sg.size;
317
+
318
+ /* Skip initial bytes in msg_en's data to be able to use
319
+ * same offset of both plain and encrypted data.
320
+ */
321
+ skip = prot->prepend_size + msg_pl->sg.size;
322
+
323
+ return sk_msg_clone(sk, msg_pl, msg_en, skip, len);
324
+}
325
+
326
+static struct tls_rec *tls_get_rec(struct sock *sk)
327
+{
328
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
329
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
330
+ struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
331
+ struct sk_msg *msg_pl, *msg_en;
332
+ struct tls_rec *rec;
333
+ int mem_size;
334
+
335
+ mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send);
336
+
337
+ rec = kzalloc(mem_size, sk->sk_allocation);
338
+ if (!rec)
339
+ return NULL;
340
+
341
+ msg_pl = &rec->msg_plaintext;
342
+ msg_en = &rec->msg_encrypted;
343
+
344
+ sk_msg_init(msg_pl);
345
+ sk_msg_init(msg_en);
346
+
347
+ sg_init_table(rec->sg_aead_in, 2);
348
+ sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, prot->aad_size);
349
+ sg_unmark_end(&rec->sg_aead_in[1]);
350
+
351
+ sg_init_table(rec->sg_aead_out, 2);
352
+ sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, prot->aad_size);
353
+ sg_unmark_end(&rec->sg_aead_out[1]);
354
+
355
+ return rec;
356
+}
357
+
358
+static void tls_free_rec(struct sock *sk, struct tls_rec *rec)
359
+{
360
+ sk_msg_free(sk, &rec->msg_encrypted);
361
+ sk_msg_free(sk, &rec->msg_plaintext);
362
+ kfree(rec);
363
+}
364
+
365
+static void tls_free_open_rec(struct sock *sk)
135366 {
136367 struct tls_context *tls_ctx = tls_get_ctx(sk);
137368 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
138
- int rc = 0;
369
+ struct tls_rec *rec = ctx->open_rec;
139370
140
- rc = sk_alloc_sg(sk, len, ctx->sg_plaintext_data, 0,
141
- &ctx->sg_plaintext_num_elem, &ctx->sg_plaintext_size,
142
- tls_ctx->pending_open_record_frags);
143
-
144
- if (rc == -ENOSPC)
145
- ctx->sg_plaintext_num_elem = ARRAY_SIZE(ctx->sg_plaintext_data);
146
-
147
- return rc;
148
-}
149
-
150
-static void free_sg(struct sock *sk, struct scatterlist *sg,
151
- int *sg_num_elem, unsigned int *sg_size)
152
-{
153
- int i, n = *sg_num_elem;
154
-
155
- for (i = 0; i < n; ++i) {
156
- sk_mem_uncharge(sk, sg[i].length);
157
- put_page(sg_page(&sg[i]));
371
+ if (rec) {
372
+ tls_free_rec(sk, rec);
373
+ ctx->open_rec = NULL;
158374 }
159
- *sg_num_elem = 0;
160
- *sg_size = 0;
161375 }
162376
163
-static void tls_free_both_sg(struct sock *sk)
377
+int tls_tx_records(struct sock *sk, int flags)
164378 {
165379 struct tls_context *tls_ctx = tls_get_ctx(sk);
166380 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
381
+ struct tls_rec *rec, *tmp;
382
+ struct sk_msg *msg_en;
383
+ int tx_flags, rc = 0;
167384
168
- free_sg(sk, ctx->sg_encrypted_data, &ctx->sg_encrypted_num_elem,
169
- &ctx->sg_encrypted_size);
385
+ if (tls_is_partially_sent_record(tls_ctx)) {
386
+ rec = list_first_entry(&ctx->tx_list,
387
+ struct tls_rec, list);
170388
171
- free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem,
172
- &ctx->sg_plaintext_size);
389
+ if (flags == -1)
390
+ tx_flags = rec->tx_flags;
391
+ else
392
+ tx_flags = flags;
393
+
394
+ rc = tls_push_partial_record(sk, tls_ctx, tx_flags);
395
+ if (rc)
396
+ goto tx_err;
397
+
398
+ /* Full record has been transmitted.
399
+ * Remove the head of tx_list
400
+ */
401
+ list_del(&rec->list);
402
+ sk_msg_free(sk, &rec->msg_plaintext);
403
+ kfree(rec);
404
+ }
405
+
406
+ /* Tx all ready records */
407
+ list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
408
+ if (READ_ONCE(rec->tx_ready)) {
409
+ if (flags == -1)
410
+ tx_flags = rec->tx_flags;
411
+ else
412
+ tx_flags = flags;
413
+
414
+ msg_en = &rec->msg_encrypted;
415
+ rc = tls_push_sg(sk, tls_ctx,
416
+ &msg_en->sg.data[msg_en->sg.curr],
417
+ 0, tx_flags);
418
+ if (rc)
419
+ goto tx_err;
420
+
421
+ list_del(&rec->list);
422
+ sk_msg_free(sk, &rec->msg_plaintext);
423
+ kfree(rec);
424
+ } else {
425
+ break;
426
+ }
427
+ }
428
+
429
+tx_err:
430
+ if (rc < 0 && rc != -EAGAIN)
431
+ tls_err_abort(sk, -EBADMSG);
432
+
433
+ return rc;
173434 }
174435
175
-static int tls_do_encryption(struct tls_context *tls_ctx,
436
+static void tls_encrypt_done(struct crypto_async_request *req, int err)
437
+{
438
+ struct aead_request *aead_req = (struct aead_request *)req;
439
+ struct sock *sk = req->data;
440
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
441
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
442
+ struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
443
+ struct scatterlist *sge;
444
+ struct sk_msg *msg_en;
445
+ struct tls_rec *rec;
446
+ bool ready = false;
447
+ int pending;
448
+
449
+ rec = container_of(aead_req, struct tls_rec, aead_req);
450
+ msg_en = &rec->msg_encrypted;
451
+
452
+ sge = sk_msg_elem(msg_en, msg_en->sg.curr);
453
+ sge->offset -= prot->prepend_size;
454
+ sge->length += prot->prepend_size;
455
+
456
+ /* Check if error is previously set on socket */
457
+ if (err || sk->sk_err) {
458
+ rec = NULL;
459
+
460
+ /* If err is already set on socket, return the same code */
461
+ if (sk->sk_err) {
462
+ ctx->async_wait.err = -sk->sk_err;
463
+ } else {
464
+ ctx->async_wait.err = err;
465
+ tls_err_abort(sk, err);
466
+ }
467
+ }
468
+
469
+ if (rec) {
470
+ struct tls_rec *first_rec;
471
+
472
+ /* Mark the record as ready for transmission */
473
+ smp_store_mb(rec->tx_ready, true);
474
+
475
+ /* If received record is at head of tx_list, schedule tx */
476
+ first_rec = list_first_entry(&ctx->tx_list,
477
+ struct tls_rec, list);
478
+ if (rec == first_rec)
479
+ ready = true;
480
+ }
481
+
482
+ spin_lock_bh(&ctx->encrypt_compl_lock);
483
+ pending = atomic_dec_return(&ctx->encrypt_pending);
484
+
485
+ if (!pending && ctx->async_notify)
486
+ complete(&ctx->async_wait.completion);
487
+ spin_unlock_bh(&ctx->encrypt_compl_lock);
488
+
489
+ if (!ready)
490
+ return;
491
+
492
+ /* Schedule the transmission */
493
+ if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
494
+ schedule_delayed_work(&ctx->tx_work.work, 1);
495
+}
496
+
497
+static int tls_do_encryption(struct sock *sk,
498
+ struct tls_context *tls_ctx,
176499 struct tls_sw_context_tx *ctx,
177500 struct aead_request *aead_req,
178
- size_t data_len)
501
+ size_t data_len, u32 start)
179502 {
180
- int rc;
503
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
504
+ struct tls_rec *rec = ctx->open_rec;
505
+ struct sk_msg *msg_en = &rec->msg_encrypted;
506
+ struct scatterlist *sge = sk_msg_elem(msg_en, start);
507
+ int rc, iv_offset = 0;
181508
182
- ctx->sg_encrypted_data[0].offset += tls_ctx->tx.prepend_size;
183
- ctx->sg_encrypted_data[0].length -= tls_ctx->tx.prepend_size;
509
+ /* For CCM based ciphers, first byte of IV is a constant */
510
+ if (prot->cipher_type == TLS_CIPHER_AES_CCM_128) {
511
+ rec->iv_data[0] = TLS_AES_CCM_IV_B0_BYTE;
512
+ iv_offset = 1;
513
+ }
514
+
515
+ memcpy(&rec->iv_data[iv_offset], tls_ctx->tx.iv,
516
+ prot->iv_size + prot->salt_size);
517
+
518
+ xor_iv_with_seq(prot->version, rec->iv_data + iv_offset, tls_ctx->tx.rec_seq);
519
+
520
+ sge->offset += prot->prepend_size;
521
+ sge->length -= prot->prepend_size;
522
+
523
+ msg_en->sg.curr = start;
184524
185525 aead_request_set_tfm(aead_req, ctx->aead_send);
186
- aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
187
- aead_request_set_crypt(aead_req, ctx->sg_aead_in, ctx->sg_aead_out,
188
- data_len, tls_ctx->tx.iv);
526
+ aead_request_set_ad(aead_req, prot->aad_size);
527
+ aead_request_set_crypt(aead_req, rec->sg_aead_in,
528
+ rec->sg_aead_out,
529
+ data_len, rec->iv_data);
189530
190531 aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
191
- crypto_req_done, &ctx->async_wait);
532
+ tls_encrypt_done, sk);
192533
193
- rc = crypto_wait_req(crypto_aead_encrypt(aead_req), &ctx->async_wait);
534
+ /* Add the record in tx_list */
535
+ list_add_tail((struct list_head *)&rec->list, &ctx->tx_list);
536
+ atomic_inc(&ctx->encrypt_pending);
194537
195
- ctx->sg_encrypted_data[0].offset -= tls_ctx->tx.prepend_size;
196
- ctx->sg_encrypted_data[0].length += tls_ctx->tx.prepend_size;
538
+ rc = crypto_aead_encrypt(aead_req);
539
+ if (!rc || rc != -EINPROGRESS) {
540
+ atomic_dec(&ctx->encrypt_pending);
541
+ sge->offset -= prot->prepend_size;
542
+ sge->length += prot->prepend_size;
543
+ }
197544
545
+ if (!rc) {
546
+ WRITE_ONCE(rec->tx_ready, true);
547
+ } else if (rc != -EINPROGRESS) {
548
+ list_del(&rec->list);
549
+ return rc;
550
+ }
551
+
552
+ /* Unhook the record from context if encryption is not failure */
553
+ ctx->open_rec = NULL;
554
+ tls_advance_record_sn(sk, prot, &tls_ctx->tx);
198555 return rc;
556
+}
557
+
558
+static int tls_split_open_record(struct sock *sk, struct tls_rec *from,
559
+ struct tls_rec **to, struct sk_msg *msg_opl,
560
+ struct sk_msg *msg_oen, u32 split_point,
561
+ u32 tx_overhead_size, u32 *orig_end)
562
+{
563
+ u32 i, j, bytes = 0, apply = msg_opl->apply_bytes;
564
+ struct scatterlist *sge, *osge, *nsge;
565
+ u32 orig_size = msg_opl->sg.size;
566
+ struct scatterlist tmp = { };
567
+ struct sk_msg *msg_npl;
568
+ struct tls_rec *new;
569
+ int ret;
570
+
571
+ new = tls_get_rec(sk);
572
+ if (!new)
573
+ return -ENOMEM;
574
+ ret = sk_msg_alloc(sk, &new->msg_encrypted, msg_opl->sg.size +
575
+ tx_overhead_size, 0);
576
+ if (ret < 0) {
577
+ tls_free_rec(sk, new);
578
+ return ret;
579
+ }
580
+
581
+ *orig_end = msg_opl->sg.end;
582
+ i = msg_opl->sg.start;
583
+ sge = sk_msg_elem(msg_opl, i);
584
+ while (apply && sge->length) {
585
+ if (sge->length > apply) {
586
+ u32 len = sge->length - apply;
587
+
588
+ get_page(sg_page(sge));
589
+ sg_set_page(&tmp, sg_page(sge), len,
590
+ sge->offset + apply);
591
+ sge->length = apply;
592
+ bytes += apply;
593
+ apply = 0;
594
+ } else {
595
+ apply -= sge->length;
596
+ bytes += sge->length;
597
+ }
598
+
599
+ sk_msg_iter_var_next(i);
600
+ if (i == msg_opl->sg.end)
601
+ break;
602
+ sge = sk_msg_elem(msg_opl, i);
603
+ }
604
+
605
+ msg_opl->sg.end = i;
606
+ msg_opl->sg.curr = i;
607
+ msg_opl->sg.copybreak = 0;
608
+ msg_opl->apply_bytes = 0;
609
+ msg_opl->sg.size = bytes;
610
+
611
+ msg_npl = &new->msg_plaintext;
612
+ msg_npl->apply_bytes = apply;
613
+ msg_npl->sg.size = orig_size - bytes;
614
+
615
+ j = msg_npl->sg.start;
616
+ nsge = sk_msg_elem(msg_npl, j);
617
+ if (tmp.length) {
618
+ memcpy(nsge, &tmp, sizeof(*nsge));
619
+ sk_msg_iter_var_next(j);
620
+ nsge = sk_msg_elem(msg_npl, j);
621
+ }
622
+
623
+ osge = sk_msg_elem(msg_opl, i);
624
+ while (osge->length) {
625
+ memcpy(nsge, osge, sizeof(*nsge));
626
+ sg_unmark_end(nsge);
627
+ sk_msg_iter_var_next(i);
628
+ sk_msg_iter_var_next(j);
629
+ if (i == *orig_end)
630
+ break;
631
+ osge = sk_msg_elem(msg_opl, i);
632
+ nsge = sk_msg_elem(msg_npl, j);
633
+ }
634
+
635
+ msg_npl->sg.end = j;
636
+ msg_npl->sg.curr = j;
637
+ msg_npl->sg.copybreak = 0;
638
+
639
+ *to = new;
640
+ return 0;
641
+}
642
+
643
+static void tls_merge_open_record(struct sock *sk, struct tls_rec *to,
644
+ struct tls_rec *from, u32 orig_end)
645
+{
646
+ struct sk_msg *msg_npl = &from->msg_plaintext;
647
+ struct sk_msg *msg_opl = &to->msg_plaintext;
648
+ struct scatterlist *osge, *nsge;
649
+ u32 i, j;
650
+
651
+ i = msg_opl->sg.end;
652
+ sk_msg_iter_var_prev(i);
653
+ j = msg_npl->sg.start;
654
+
655
+ osge = sk_msg_elem(msg_opl, i);
656
+ nsge = sk_msg_elem(msg_npl, j);
657
+
658
+ if (sg_page(osge) == sg_page(nsge) &&
659
+ osge->offset + osge->length == nsge->offset) {
660
+ osge->length += nsge->length;
661
+ put_page(sg_page(nsge));
662
+ }
663
+
664
+ msg_opl->sg.end = orig_end;
665
+ msg_opl->sg.curr = orig_end;
666
+ msg_opl->sg.copybreak = 0;
667
+ msg_opl->apply_bytes = msg_opl->sg.size + msg_npl->sg.size;
668
+ msg_opl->sg.size += msg_npl->sg.size;
669
+
670
+ sk_msg_free(sk, &to->msg_encrypted);
671
+ sk_msg_xfer_full(&to->msg_encrypted, &from->msg_encrypted);
672
+
673
+ kfree(from);
199674 }
200675
201676 static int tls_push_record(struct sock *sk, int flags,
202677 unsigned char record_type)
203678 {
204679 struct tls_context *tls_ctx = tls_get_ctx(sk);
680
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
205681 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
682
+ struct tls_rec *rec = ctx->open_rec, *tmp = NULL;
683
+ u32 i, split_point, orig_end;
684
+ struct sk_msg *msg_pl, *msg_en;
206685 struct aead_request *req;
686
+ bool split;
207687 int rc;
208688
209
- req = aead_request_alloc(ctx->aead_send, sk->sk_allocation);
210
- if (!req)
211
- return -ENOMEM;
689
+ if (!rec)
690
+ return 0;
212691
213
- sg_mark_end(ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem - 1);
214
- sg_mark_end(ctx->sg_encrypted_data + ctx->sg_encrypted_num_elem - 1);
692
+ msg_pl = &rec->msg_plaintext;
693
+ msg_en = &rec->msg_encrypted;
215694
216
- tls_make_aad(ctx->aad_space, ctx->sg_plaintext_size,
217
- tls_ctx->tx.rec_seq, tls_ctx->tx.rec_seq_size,
218
- record_type);
219
-
220
- tls_fill_prepend(tls_ctx,
221
- page_address(sg_page(&ctx->sg_encrypted_data[0])) +
222
- ctx->sg_encrypted_data[0].offset,
223
- ctx->sg_plaintext_size, record_type);
224
-
225
- tls_ctx->pending_open_record_frags = 0;
226
- set_bit(TLS_PENDING_CLOSED_RECORD, &tls_ctx->flags);
227
-
228
- rc = tls_do_encryption(tls_ctx, ctx, req, ctx->sg_plaintext_size);
229
- if (rc < 0) {
230
- /* If we are called from write_space and
231
- * we fail, we need to set this SOCK_NOSPACE
232
- * to trigger another write_space in the future.
695
+ split_point = msg_pl->apply_bytes;
696
+ split = split_point && split_point < msg_pl->sg.size;
697
+ if (unlikely((!split &&
698
+ msg_pl->sg.size +
699
+ prot->overhead_size > msg_en->sg.size) ||
700
+ (split &&
701
+ split_point +
702
+ prot->overhead_size > msg_en->sg.size))) {
703
+ split = true;
704
+ split_point = msg_en->sg.size;
705
+ }
706
+ if (split) {
707
+ rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en,
708
+ split_point, prot->overhead_size,
709
+ &orig_end);
710
+ if (rc < 0)
711
+ return rc;
712
+ /* This can happen if above tls_split_open_record allocates
713
+ * a single large encryption buffer instead of two smaller
714
+ * ones. In this case adjust pointers and continue without
715
+ * split.
233716 */
234
- set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
235
- goto out_req;
717
+ if (!msg_pl->sg.size) {
718
+ tls_merge_open_record(sk, rec, tmp, orig_end);
719
+ msg_pl = &rec->msg_plaintext;
720
+ msg_en = &rec->msg_encrypted;
721
+ split = false;
722
+ }
723
+ sk_msg_trim(sk, msg_en, msg_pl->sg.size +
724
+ prot->overhead_size);
236725 }
237726
238
- free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem,
239
- &ctx->sg_plaintext_size);
727
+ rec->tx_flags = flags;
728
+ req = &rec->aead_req;
240729
241
- ctx->sg_encrypted_num_elem = 0;
242
- ctx->sg_encrypted_size = 0;
730
+ i = msg_pl->sg.end;
731
+ sk_msg_iter_var_prev(i);
243732
244
- /* Only pass through MSG_DONTWAIT and MSG_NOSIGNAL flags */
245
- rc = tls_push_sg(sk, tls_ctx, ctx->sg_encrypted_data, 0, flags);
246
- if (rc < 0 && rc != -EAGAIN)
247
- tls_err_abort(sk, EBADMSG);
733
+ rec->content_type = record_type;
734
+ if (prot->version == TLS_1_3_VERSION) {
735
+ /* Add content type to end of message. No padding added */
736
+ sg_set_buf(&rec->sg_content_type, &rec->content_type, 1);
737
+ sg_mark_end(&rec->sg_content_type);
738
+ sg_chain(msg_pl->sg.data, msg_pl->sg.end + 1,
739
+ &rec->sg_content_type);
740
+ } else {
741
+ sg_mark_end(sk_msg_elem(msg_pl, i));
742
+ }
248743
249
- tls_advance_record_sn(sk, &tls_ctx->tx);
250
-out_req:
251
- aead_request_free(req);
252
- return rc;
744
+ if (msg_pl->sg.end < msg_pl->sg.start) {
745
+ sg_chain(&msg_pl->sg.data[msg_pl->sg.start],
746
+ MAX_SKB_FRAGS - msg_pl->sg.start + 1,
747
+ msg_pl->sg.data);
748
+ }
749
+
750
+ i = msg_pl->sg.start;
751
+ sg_chain(rec->sg_aead_in, 2, &msg_pl->sg.data[i]);
752
+
753
+ i = msg_en->sg.end;
754
+ sk_msg_iter_var_prev(i);
755
+ sg_mark_end(sk_msg_elem(msg_en, i));
756
+
757
+ i = msg_en->sg.start;
758
+ sg_chain(rec->sg_aead_out, 2, &msg_en->sg.data[i]);
759
+
760
+ tls_make_aad(rec->aad_space, msg_pl->sg.size + prot->tail_size,
761
+ tls_ctx->tx.rec_seq, prot->rec_seq_size,
762
+ record_type, prot->version);
763
+
764
+ tls_fill_prepend(tls_ctx,
765
+ page_address(sg_page(&msg_en->sg.data[i])) +
766
+ msg_en->sg.data[i].offset,
767
+ msg_pl->sg.size + prot->tail_size,
768
+ record_type, prot->version);
769
+
770
+ tls_ctx->pending_open_record_frags = false;
771
+
772
+ rc = tls_do_encryption(sk, tls_ctx, ctx, req,
773
+ msg_pl->sg.size + prot->tail_size, i);
774
+ if (rc < 0) {
775
+ if (rc != -EINPROGRESS) {
776
+ tls_err_abort(sk, -EBADMSG);
777
+ if (split) {
778
+ tls_ctx->pending_open_record_frags = true;
779
+ tls_merge_open_record(sk, rec, tmp, orig_end);
780
+ }
781
+ }
782
+ ctx->async_capable = 1;
783
+ return rc;
784
+ } else if (split) {
785
+ msg_pl = &tmp->msg_plaintext;
786
+ msg_en = &tmp->msg_encrypted;
787
+ sk_msg_trim(sk, msg_en, msg_pl->sg.size + prot->overhead_size);
788
+ tls_ctx->pending_open_record_frags = true;
789
+ ctx->open_rec = tmp;
790
+ }
791
+
792
+ return tls_tx_records(sk, flags);
793
+}
794
+
795
+static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
796
+ bool full_record, u8 record_type,
797
+ ssize_t *copied, int flags)
798
+{
799
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
800
+ struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
801
+ struct sk_msg msg_redir = { };
802
+ struct sk_psock *psock;
803
+ struct sock *sk_redir;
804
+ struct tls_rec *rec;
805
+ bool enospc, policy;
806
+ int err = 0, send;
807
+ u32 delta = 0;
808
+
809
+ policy = !(flags & MSG_SENDPAGE_NOPOLICY);
810
+ psock = sk_psock_get(sk);
811
+ if (!psock || !policy) {
812
+ err = tls_push_record(sk, flags, record_type);
813
+ if (err && err != -EINPROGRESS && sk->sk_err == EBADMSG) {
814
+ *copied -= sk_msg_free(sk, msg);
815
+ tls_free_open_rec(sk);
816
+ err = -sk->sk_err;
817
+ }
818
+ if (psock)
819
+ sk_psock_put(sk, psock);
820
+ return err;
821
+ }
822
+more_data:
823
+ enospc = sk_msg_full(msg);
824
+ if (psock->eval == __SK_NONE) {
825
+ delta = msg->sg.size;
826
+ psock->eval = sk_psock_msg_verdict(sk, psock, msg);
827
+ delta -= msg->sg.size;
828
+ }
829
+ if (msg->cork_bytes && msg->cork_bytes > msg->sg.size &&
830
+ !enospc && !full_record) {
831
+ err = -ENOSPC;
832
+ goto out_err;
833
+ }
834
+ msg->cork_bytes = 0;
835
+ send = msg->sg.size;
836
+ if (msg->apply_bytes && msg->apply_bytes < send)
837
+ send = msg->apply_bytes;
838
+
839
+ switch (psock->eval) {
840
+ case __SK_PASS:
841
+ err = tls_push_record(sk, flags, record_type);
842
+ if (err && err != -EINPROGRESS && sk->sk_err == EBADMSG) {
843
+ *copied -= sk_msg_free(sk, msg);
844
+ tls_free_open_rec(sk);
845
+ err = -sk->sk_err;
846
+ goto out_err;
847
+ }
848
+ break;
849
+ case __SK_REDIRECT:
850
+ sk_redir = psock->sk_redir;
851
+ memcpy(&msg_redir, msg, sizeof(*msg));
852
+ if (msg->apply_bytes < send)
853
+ msg->apply_bytes = 0;
854
+ else
855
+ msg->apply_bytes -= send;
856
+ sk_msg_return_zero(sk, msg, send);
857
+ msg->sg.size -= send;
858
+ release_sock(sk);
859
+ err = tcp_bpf_sendmsg_redir(sk_redir, &msg_redir, send, flags);
860
+ lock_sock(sk);
861
+ if (err < 0) {
862
+ *copied -= sk_msg_free_nocharge(sk, &msg_redir);
863
+ msg->sg.size = 0;
864
+ }
865
+ if (msg->sg.size == 0)
866
+ tls_free_open_rec(sk);
867
+ break;
868
+ case __SK_DROP:
869
+ default:
870
+ sk_msg_free_partial(sk, msg, send);
871
+ if (msg->apply_bytes < send)
872
+ msg->apply_bytes = 0;
873
+ else
874
+ msg->apply_bytes -= send;
875
+ if (msg->sg.size == 0)
876
+ tls_free_open_rec(sk);
877
+ *copied -= (send + delta);
878
+ err = -EACCES;
879
+ }
880
+
881
+ if (likely(!err)) {
882
+ bool reset_eval = !ctx->open_rec;
883
+
884
+ rec = ctx->open_rec;
885
+ if (rec) {
886
+ msg = &rec->msg_plaintext;
887
+ if (!msg->apply_bytes)
888
+ reset_eval = true;
889
+ }
890
+ if (reset_eval) {
891
+ psock->eval = __SK_NONE;
892
+ if (psock->sk_redir) {
893
+ sock_put(psock->sk_redir);
894
+ psock->sk_redir = NULL;
895
+ }
896
+ }
897
+ if (rec)
898
+ goto more_data;
899
+ }
900
+ out_err:
901
+ sk_psock_put(sk, psock);
902
+ return err;
253903 }
254904
255905 static int tls_sw_push_pending_record(struct sock *sk, int flags)
256906 {
257
- return tls_push_record(sk, flags, TLS_RECORD_TYPE_DATA);
907
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
908
+ struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
909
+ struct tls_rec *rec = ctx->open_rec;
910
+ struct sk_msg *msg_pl;
911
+ size_t copied;
912
+
913
+ if (!rec)
914
+ return 0;
915
+
916
+ msg_pl = &rec->msg_plaintext;
917
+ copied = msg_pl->sg.size;
918
+ if (!copied)
919
+ return 0;
920
+
921
+ return bpf_exec_tx_verdict(msg_pl, sk, true, TLS_RECORD_TYPE_DATA,
922
+ &copied, flags);
258923 }
259924
260
-static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
261
- int length, int *pages_used,
262
- unsigned int *size_used,
263
- struct scatterlist *to, int to_max_pages,
264
- bool charge)
925
+int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
265926 {
266
- struct page *pages[MAX_SKB_FRAGS];
927
+ long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
928
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
929
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
930
+ struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
931
+ bool async_capable = ctx->async_capable;
932
+ unsigned char record_type = TLS_RECORD_TYPE_DATA;
933
+ bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
934
+ bool eor = !(msg->msg_flags & MSG_MORE);
935
+ size_t try_to_copy;
936
+ ssize_t copied = 0;
937
+ struct sk_msg *msg_pl, *msg_en;
938
+ struct tls_rec *rec;
939
+ int required_size;
940
+ int num_async = 0;
941
+ bool full_record;
942
+ int record_room;
943
+ int num_zc = 0;
944
+ int orig_size;
945
+ int ret = 0;
946
+ int pending;
267947
268
- size_t offset;
269
- ssize_t copied, use;
270
- int i = 0;
948
+ if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
949
+ MSG_CMSG_COMPAT))
950
+ return -EOPNOTSUPP;
951
+
952
+ ret = mutex_lock_interruptible(&tls_ctx->tx_lock);
953
+ if (ret)
954
+ return ret;
955
+ lock_sock(sk);
956
+
957
+ if (unlikely(msg->msg_controllen)) {
958
+ ret = tls_proccess_cmsg(sk, msg, &record_type);
959
+ if (ret) {
960
+ if (ret == -EINPROGRESS)
961
+ num_async++;
962
+ else if (ret != -EAGAIN)
963
+ goto send_end;
964
+ }
965
+ }
966
+
967
+ while (msg_data_left(msg)) {
968
+ if (sk->sk_err) {
969
+ ret = -sk->sk_err;
970
+ goto send_end;
971
+ }
972
+
973
+ if (ctx->open_rec)
974
+ rec = ctx->open_rec;
975
+ else
976
+ rec = ctx->open_rec = tls_get_rec(sk);
977
+ if (!rec) {
978
+ ret = -ENOMEM;
979
+ goto send_end;
980
+ }
981
+
982
+ msg_pl = &rec->msg_plaintext;
983
+ msg_en = &rec->msg_encrypted;
984
+
985
+ orig_size = msg_pl->sg.size;
986
+ full_record = false;
987
+ try_to_copy = msg_data_left(msg);
988
+ record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
989
+ if (try_to_copy >= record_room) {
990
+ try_to_copy = record_room;
991
+ full_record = true;
992
+ }
993
+
994
+ required_size = msg_pl->sg.size + try_to_copy +
995
+ prot->overhead_size;
996
+
997
+ if (!sk_stream_memory_free(sk))
998
+ goto wait_for_sndbuf;
999
+
1000
+alloc_encrypted:
1001
+ ret = tls_alloc_encrypted_msg(sk, required_size);
1002
+ if (ret) {
1003
+ if (ret != -ENOSPC)
1004
+ goto wait_for_memory;
1005
+
1006
+ /* Adjust try_to_copy according to the amount that was
1007
+ * actually allocated. The difference is due
1008
+ * to max sg elements limit
1009
+ */
1010
+ try_to_copy -= required_size - msg_en->sg.size;
1011
+ full_record = true;
1012
+ }
1013
+
1014
+ if (!is_kvec && (full_record || eor) && !async_capable) {
1015
+ u32 first = msg_pl->sg.end;
1016
+
1017
+ ret = sk_msg_zerocopy_from_iter(sk, &msg->msg_iter,
1018
+ msg_pl, try_to_copy);
1019
+ if (ret)
1020
+ goto fallback_to_reg_send;
1021
+
1022
+ num_zc++;
1023
+ copied += try_to_copy;
1024
+
1025
+ sk_msg_sg_copy_set(msg_pl, first);
1026
+ ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
1027
+ record_type, &copied,
1028
+ msg->msg_flags);
1029
+ if (ret) {
1030
+ if (ret == -EINPROGRESS)
1031
+ num_async++;
1032
+ else if (ret == -ENOMEM)
1033
+ goto wait_for_memory;
1034
+ else if (ctx->open_rec && ret == -ENOSPC)
1035
+ goto rollback_iter;
1036
+ else if (ret != -EAGAIN)
1037
+ goto send_end;
1038
+ }
1039
+ continue;
1040
+rollback_iter:
1041
+ copied -= try_to_copy;
1042
+ sk_msg_sg_copy_clear(msg_pl, first);
1043
+ iov_iter_revert(&msg->msg_iter,
1044
+ msg_pl->sg.size - orig_size);
1045
+fallback_to_reg_send:
1046
+ sk_msg_trim(sk, msg_pl, orig_size);
1047
+ }
1048
+
1049
+ required_size = msg_pl->sg.size + try_to_copy;
1050
+
1051
+ ret = tls_clone_plaintext_msg(sk, required_size);
1052
+ if (ret) {
1053
+ if (ret != -ENOSPC)
1054
+ goto send_end;
1055
+
1056
+ /* Adjust try_to_copy according to the amount that was
1057
+ * actually allocated. The difference is due
1058
+ * to max sg elements limit
1059
+ */
1060
+ try_to_copy -= required_size - msg_pl->sg.size;
1061
+ full_record = true;
1062
+ sk_msg_trim(sk, msg_en,
1063
+ msg_pl->sg.size + prot->overhead_size);
1064
+ }
1065
+
1066
+ if (try_to_copy) {
1067
+ ret = sk_msg_memcopy_from_iter(sk, &msg->msg_iter,
1068
+ msg_pl, try_to_copy);
1069
+ if (ret < 0)
1070
+ goto trim_sgl;
1071
+ }
1072
+
1073
+ /* Open records defined only if successfully copied, otherwise
1074
+ * we would trim the sg but not reset the open record frags.
1075
+ */
1076
+ tls_ctx->pending_open_record_frags = true;
1077
+ copied += try_to_copy;
1078
+ if (full_record || eor) {
1079
+ ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
1080
+ record_type, &copied,
1081
+ msg->msg_flags);
1082
+ if (ret) {
1083
+ if (ret == -EINPROGRESS)
1084
+ num_async++;
1085
+ else if (ret == -ENOMEM)
1086
+ goto wait_for_memory;
1087
+ else if (ret != -EAGAIN) {
1088
+ if (ret == -ENOSPC)
1089
+ ret = 0;
1090
+ goto send_end;
1091
+ }
1092
+ }
1093
+ }
1094
+
1095
+ continue;
1096
+
1097
+wait_for_sndbuf:
1098
+ set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
1099
+wait_for_memory:
1100
+ ret = sk_stream_wait_memory(sk, &timeo);
1101
+ if (ret) {
1102
+trim_sgl:
1103
+ if (ctx->open_rec)
1104
+ tls_trim_both_msgs(sk, orig_size);
1105
+ goto send_end;
1106
+ }
1107
+
1108
+ if (ctx->open_rec && msg_en->sg.size < required_size)
1109
+ goto alloc_encrypted;
1110
+ }
1111
+
1112
+ if (!num_async) {
1113
+ goto send_end;
1114
+ } else if (num_zc) {
1115
+ /* Wait for pending encryptions to get completed */
1116
+ spin_lock_bh(&ctx->encrypt_compl_lock);
1117
+ ctx->async_notify = true;
1118
+
1119
+ pending = atomic_read(&ctx->encrypt_pending);
1120
+ spin_unlock_bh(&ctx->encrypt_compl_lock);
1121
+ if (pending)
1122
+ crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
1123
+ else
1124
+ reinit_completion(&ctx->async_wait.completion);
1125
+
1126
+ /* There can be no concurrent accesses, since we have no
1127
+ * pending encrypt operations
1128
+ */
1129
+ WRITE_ONCE(ctx->async_notify, false);
1130
+
1131
+ if (ctx->async_wait.err) {
1132
+ ret = ctx->async_wait.err;
1133
+ copied = 0;
1134
+ }
1135
+ }
1136
+
1137
+ /* Transmit if any encryptions have completed */
1138
+ if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
1139
+ cancel_delayed_work(&ctx->tx_work.work);
1140
+ tls_tx_records(sk, msg->msg_flags);
1141
+ }
1142
+
1143
+send_end:
1144
+ ret = sk_stream_error(sk, msg->msg_flags, ret);
1145
+
1146
+ release_sock(sk);
1147
+ mutex_unlock(&tls_ctx->tx_lock);
1148
+ return copied > 0 ? copied : ret;
1149
+}
1150
+
1151
+static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
1152
+ int offset, size_t size, int flags)
1153
+{
1154
+ long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
1155
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
1156
+ struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
1157
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
1158
+ unsigned char record_type = TLS_RECORD_TYPE_DATA;
1159
+ struct sk_msg *msg_pl;
1160
+ struct tls_rec *rec;
1161
+ int num_async = 0;
1162
+ ssize_t copied = 0;
1163
+ bool full_record;
1164
+ int record_room;
1165
+ int ret = 0;
1166
+ bool eor;
1167
+
1168
+ eor = !(flags & MSG_SENDPAGE_NOTLAST);
1169
+ sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
1170
+
1171
+ /* Call the sk_stream functions to manage the sndbuf mem. */
1172
+ while (size > 0) {
1173
+ size_t copy, required_size;
1174
+
1175
+ if (sk->sk_err) {
1176
+ ret = -sk->sk_err;
1177
+ goto sendpage_end;
1178
+ }
1179
+
1180
+ if (ctx->open_rec)
1181
+ rec = ctx->open_rec;
1182
+ else
1183
+ rec = ctx->open_rec = tls_get_rec(sk);
1184
+ if (!rec) {
1185
+ ret = -ENOMEM;
1186
+ goto sendpage_end;
1187
+ }
1188
+
1189
+ msg_pl = &rec->msg_plaintext;
1190
+
1191
+ full_record = false;
1192
+ record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
1193
+ copy = size;
1194
+ if (copy >= record_room) {
1195
+ copy = record_room;
1196
+ full_record = true;
1197
+ }
1198
+
1199
+ required_size = msg_pl->sg.size + copy + prot->overhead_size;
1200
+
1201
+ if (!sk_stream_memory_free(sk))
1202
+ goto wait_for_sndbuf;
1203
+alloc_payload:
1204
+ ret = tls_alloc_encrypted_msg(sk, required_size);
1205
+ if (ret) {
1206
+ if (ret != -ENOSPC)
1207
+ goto wait_for_memory;
1208
+
1209
+ /* Adjust copy according to the amount that was
1210
+ * actually allocated. The difference is due
1211
+ * to max sg elements limit
1212
+ */
1213
+ copy -= required_size - msg_pl->sg.size;
1214
+ full_record = true;
1215
+ }
1216
+
1217
+ sk_msg_page_add(msg_pl, page, copy, offset);
1218
+ sk_mem_charge(sk, copy);
1219
+
1220
+ offset += copy;
1221
+ size -= copy;
1222
+ copied += copy;
1223
+
1224
+ tls_ctx->pending_open_record_frags = true;
1225
+ if (full_record || eor || sk_msg_full(msg_pl)) {
1226
+ ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
1227
+ record_type, &copied, flags);
1228
+ if (ret) {
1229
+ if (ret == -EINPROGRESS)
1230
+ num_async++;
1231
+ else if (ret == -ENOMEM)
1232
+ goto wait_for_memory;
1233
+ else if (ret != -EAGAIN) {
1234
+ if (ret == -ENOSPC)
1235
+ ret = 0;
1236
+ goto sendpage_end;
1237
+ }
1238
+ }
1239
+ }
1240
+ continue;
1241
+wait_for_sndbuf:
1242
+ set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
1243
+wait_for_memory:
1244
+ ret = sk_stream_wait_memory(sk, &timeo);
1245
+ if (ret) {
1246
+ if (ctx->open_rec)
1247
+ tls_trim_both_msgs(sk, msg_pl->sg.size);
1248
+ goto sendpage_end;
1249
+ }
1250
+
1251
+ if (ctx->open_rec)
1252
+ goto alloc_payload;
1253
+ }
1254
+
1255
+ if (num_async) {
1256
+ /* Transmit if any encryptions have completed */
1257
+ if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
1258
+ cancel_delayed_work(&ctx->tx_work.work);
1259
+ tls_tx_records(sk, flags);
1260
+ }
1261
+ }
1262
+sendpage_end:
1263
+ ret = sk_stream_error(sk, flags, ret);
1264
+ return copied > 0 ? copied : ret;
1265
+}
1266
+
1267
+int tls_sw_sendpage_locked(struct sock *sk, struct page *page,
1268
+ int offset, size_t size, int flags)
1269
+{
1270
+ if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
1271
+ MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY |
1272
+ MSG_NO_SHARED_FRAGS))
1273
+ return -EOPNOTSUPP;
1274
+
1275
+ return tls_sw_do_sendpage(sk, page, offset, size, flags);
1276
+}
1277
+
1278
+int tls_sw_sendpage(struct sock *sk, struct page *page,
1279
+ int offset, size_t size, int flags)
1280
+{
1281
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
1282
+ int ret;
1283
+
1284
+ if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
1285
+ MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY))
1286
+ return -EOPNOTSUPP;
1287
+
1288
+ ret = mutex_lock_interruptible(&tls_ctx->tx_lock);
1289
+ if (ret)
1290
+ return ret;
1291
+ lock_sock(sk);
1292
+ ret = tls_sw_do_sendpage(sk, page, offset, size, flags);
1293
+ release_sock(sk);
1294
+ mutex_unlock(&tls_ctx->tx_lock);
1295
+ return ret;
1296
+}
1297
+
1298
+static struct sk_buff *tls_wait_data(struct sock *sk, struct sk_psock *psock,
1299
+ bool nonblock, long timeo, int *err)
1300
+{
1301
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
1302
+ struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1303
+ struct sk_buff *skb;
1304
+ DEFINE_WAIT_FUNC(wait, woken_wake_function);
1305
+
1306
+ while (!(skb = ctx->recv_pkt) && sk_psock_queue_empty(psock)) {
1307
+ if (sk->sk_err) {
1308
+ *err = sock_error(sk);
1309
+ return NULL;
1310
+ }
1311
+
1312
+ if (!skb_queue_empty(&sk->sk_receive_queue)) {
1313
+ __strp_unpause(&ctx->strp);
1314
+ if (ctx->recv_pkt)
1315
+ return ctx->recv_pkt;
1316
+ }
1317
+
1318
+ if (sk->sk_shutdown & RCV_SHUTDOWN)
1319
+ return NULL;
1320
+
1321
+ if (sock_flag(sk, SOCK_DONE))
1322
+ return NULL;
1323
+
1324
+ if (nonblock || !timeo) {
1325
+ *err = -EAGAIN;
1326
+ return NULL;
1327
+ }
1328
+
1329
+ add_wait_queue(sk_sleep(sk), &wait);
1330
+ sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
1331
+ sk_wait_event(sk, &timeo,
1332
+ ctx->recv_pkt != skb ||
1333
+ !sk_psock_queue_empty(psock),
1334
+ &wait);
1335
+ sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
1336
+ remove_wait_queue(sk_sleep(sk), &wait);
1337
+
1338
+ /* Handle signals */
1339
+ if (signal_pending(current)) {
1340
+ *err = sock_intr_errno(timeo);
1341
+ return NULL;
1342
+ }
1343
+ }
1344
+
1345
+ return skb;
1346
+}
1347
+
1348
+static int tls_setup_from_iter(struct sock *sk, struct iov_iter *from,
1349
+ int length, int *pages_used,
1350
+ unsigned int *size_used,
1351
+ struct scatterlist *to,
1352
+ int to_max_pages)
1353
+{
1354
+ int rc = 0, i = 0, num_elem = *pages_used, maxpages;
1355
+ struct page *pages[MAX_SKB_FRAGS];
2711356 unsigned int size = *size_used;
272
- int num_elem = *pages_used;
273
- int rc = 0;
274
- int maxpages;
1357
+ ssize_t copied, use;
1358
+ size_t offset;
2751359
2761360 while (length > 0) {
2771361 i = 0;
@@ -298,17 +1382,15 @@
2981382 sg_set_page(&to[num_elem],
2991383 pages[i], use, offset);
3001384 sg_unmark_end(&to[num_elem]);
301
- if (charge)
302
- sk_mem_charge(sk, use);
1385
+ /* We do not uncharge memory from this API */
3031386
3041387 offset = 0;
3051388 copied -= use;
3061389
307
- ++i;
308
- ++num_elem;
1390
+ i++;
1391
+ num_elem++;
3091392 }
3101393 }
311
-
3121394 /* Mark the end in the last sg entry if newly added */
3131395 if (num_elem > *pages_used)
3141396 sg_mark_end(&to[num_elem - 1]);
@@ -319,348 +1401,6 @@
3191401 *pages_used = num_elem;
3201402
3211403 return rc;
322
-}
323
-
324
-static int memcopy_from_iter(struct sock *sk, struct iov_iter *from,
325
- int bytes)
326
-{
327
- struct tls_context *tls_ctx = tls_get_ctx(sk);
328
- struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
329
- struct scatterlist *sg = ctx->sg_plaintext_data;
330
- int copy, i, rc = 0;
331
-
332
- for (i = tls_ctx->pending_open_record_frags;
333
- i < ctx->sg_plaintext_num_elem; ++i) {
334
- copy = sg[i].length;
335
- if (copy_from_iter(
336
- page_address(sg_page(&sg[i])) + sg[i].offset,
337
- copy, from) != copy) {
338
- rc = -EFAULT;
339
- goto out;
340
- }
341
- bytes -= copy;
342
-
343
- ++tls_ctx->pending_open_record_frags;
344
-
345
- if (!bytes)
346
- break;
347
- }
348
-
349
-out:
350
- return rc;
351
-}
352
-
353
-int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
354
-{
355
- struct tls_context *tls_ctx = tls_get_ctx(sk);
356
- struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
357
- int ret;
358
- int required_size;
359
- long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
360
- bool eor = !(msg->msg_flags & MSG_MORE);
361
- size_t try_to_copy, copied = 0;
362
- unsigned char record_type = TLS_RECORD_TYPE_DATA;
363
- int record_room;
364
- bool full_record;
365
- int orig_size;
366
- bool is_kvec = msg->msg_iter.type & ITER_KVEC;
367
-
368
- if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
369
- return -ENOTSUPP;
370
-
371
- lock_sock(sk);
372
-
373
- ret = tls_complete_pending_work(sk, tls_ctx, msg->msg_flags, &timeo);
374
- if (ret)
375
- goto send_end;
376
-
377
- if (unlikely(msg->msg_controllen)) {
378
- ret = tls_proccess_cmsg(sk, msg, &record_type);
379
- if (ret)
380
- goto send_end;
381
- }
382
-
383
- while (msg_data_left(msg)) {
384
- if (sk->sk_err) {
385
- ret = -sk->sk_err;
386
- goto send_end;
387
- }
388
-
389
- orig_size = ctx->sg_plaintext_size;
390
- full_record = false;
391
- try_to_copy = msg_data_left(msg);
392
- record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size;
393
- if (try_to_copy >= record_room) {
394
- try_to_copy = record_room;
395
- full_record = true;
396
- }
397
-
398
- required_size = ctx->sg_plaintext_size + try_to_copy +
399
- tls_ctx->tx.overhead_size;
400
-
401
- if (!sk_stream_memory_free(sk))
402
- goto wait_for_sndbuf;
403
-alloc_encrypted:
404
- ret = alloc_encrypted_sg(sk, required_size);
405
- if (ret) {
406
- if (ret != -ENOSPC)
407
- goto wait_for_memory;
408
-
409
- /* Adjust try_to_copy according to the amount that was
410
- * actually allocated. The difference is due
411
- * to max sg elements limit
412
- */
413
- try_to_copy -= required_size - ctx->sg_encrypted_size;
414
- full_record = true;
415
- }
416
- if (!is_kvec && (full_record || eor)) {
417
- ret = zerocopy_from_iter(sk, &msg->msg_iter,
418
- try_to_copy, &ctx->sg_plaintext_num_elem,
419
- &ctx->sg_plaintext_size,
420
- ctx->sg_plaintext_data,
421
- ARRAY_SIZE(ctx->sg_plaintext_data),
422
- true);
423
- if (ret)
424
- goto fallback_to_reg_send;
425
-
426
- copied += try_to_copy;
427
- ret = tls_push_record(sk, msg->msg_flags, record_type);
428
- if (ret)
429
- goto send_end;
430
- continue;
431
-
432
-fallback_to_reg_send:
433
- trim_sg(sk, ctx->sg_plaintext_data,
434
- &ctx->sg_plaintext_num_elem,
435
- &ctx->sg_plaintext_size,
436
- orig_size);
437
- }
438
-
439
- required_size = ctx->sg_plaintext_size + try_to_copy;
440
-alloc_plaintext:
441
- ret = alloc_plaintext_sg(sk, required_size);
442
- if (ret) {
443
- if (ret != -ENOSPC)
444
- goto wait_for_memory;
445
-
446
- /* Adjust try_to_copy according to the amount that was
447
- * actually allocated. The difference is due
448
- * to max sg elements limit
449
- */
450
- try_to_copy -= required_size - ctx->sg_plaintext_size;
451
- full_record = true;
452
-
453
- trim_sg(sk, ctx->sg_encrypted_data,
454
- &ctx->sg_encrypted_num_elem,
455
- &ctx->sg_encrypted_size,
456
- ctx->sg_plaintext_size +
457
- tls_ctx->tx.overhead_size);
458
- }
459
-
460
- ret = memcopy_from_iter(sk, &msg->msg_iter, try_to_copy);
461
- if (ret)
462
- goto trim_sgl;
463
-
464
- copied += try_to_copy;
465
- if (full_record || eor) {
466
-push_record:
467
- ret = tls_push_record(sk, msg->msg_flags, record_type);
468
- if (ret) {
469
- if (ret == -ENOMEM)
470
- goto wait_for_memory;
471
-
472
- goto send_end;
473
- }
474
- }
475
-
476
- continue;
477
-
478
-wait_for_sndbuf:
479
- set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
480
-wait_for_memory:
481
- ret = sk_stream_wait_memory(sk, &timeo);
482
- if (ret) {
483
-trim_sgl:
484
- trim_both_sgl(sk, orig_size);
485
- goto send_end;
486
- }
487
-
488
- if (tls_is_pending_closed_record(tls_ctx))
489
- goto push_record;
490
-
491
- if (ctx->sg_encrypted_size < required_size)
492
- goto alloc_encrypted;
493
-
494
- goto alloc_plaintext;
495
- }
496
-
497
-send_end:
498
- ret = sk_stream_error(sk, msg->msg_flags, ret);
499
-
500
- release_sock(sk);
501
- return copied ? copied : ret;
502
-}
503
-
504
-int tls_sw_sendpage(struct sock *sk, struct page *page,
505
- int offset, size_t size, int flags)
506
-{
507
- struct tls_context *tls_ctx = tls_get_ctx(sk);
508
- struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
509
- int ret;
510
- long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
511
- bool eor;
512
- size_t orig_size = size;
513
- unsigned char record_type = TLS_RECORD_TYPE_DATA;
514
- struct scatterlist *sg;
515
- bool full_record;
516
- int record_room;
517
-
518
- if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
519
- MSG_SENDPAGE_NOTLAST))
520
- return -ENOTSUPP;
521
-
522
- /* No MSG_EOR from splice, only look at MSG_MORE */
523
- eor = !(flags & (MSG_MORE | MSG_SENDPAGE_NOTLAST));
524
-
525
- lock_sock(sk);
526
-
527
- sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
528
-
529
- ret = tls_complete_pending_work(sk, tls_ctx, flags, &timeo);
530
- if (ret)
531
- goto sendpage_end;
532
-
533
- /* Call the sk_stream functions to manage the sndbuf mem. */
534
- while (size > 0) {
535
- size_t copy, required_size;
536
-
537
- if (sk->sk_err) {
538
- ret = -sk->sk_err;
539
- goto sendpage_end;
540
- }
541
-
542
- full_record = false;
543
- record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size;
544
- copy = size;
545
- if (copy >= record_room) {
546
- copy = record_room;
547
- full_record = true;
548
- }
549
- required_size = ctx->sg_plaintext_size + copy +
550
- tls_ctx->tx.overhead_size;
551
-
552
- if (!sk_stream_memory_free(sk))
553
- goto wait_for_sndbuf;
554
-alloc_payload:
555
- ret = alloc_encrypted_sg(sk, required_size);
556
- if (ret) {
557
- if (ret != -ENOSPC)
558
- goto wait_for_memory;
559
-
560
- /* Adjust copy according to the amount that was
561
- * actually allocated. The difference is due
562
- * to max sg elements limit
563
- */
564
- copy -= required_size - ctx->sg_plaintext_size;
565
- full_record = true;
566
- }
567
-
568
- get_page(page);
569
- sg = ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem;
570
- sg_set_page(sg, page, copy, offset);
571
- sg_unmark_end(sg);
572
-
573
- ctx->sg_plaintext_num_elem++;
574
-
575
- sk_mem_charge(sk, copy);
576
- offset += copy;
577
- size -= copy;
578
- ctx->sg_plaintext_size += copy;
579
- tls_ctx->pending_open_record_frags = ctx->sg_plaintext_num_elem;
580
-
581
- if (full_record || eor ||
582
- ctx->sg_plaintext_num_elem ==
583
- ARRAY_SIZE(ctx->sg_plaintext_data)) {
584
-push_record:
585
- ret = tls_push_record(sk, flags, record_type);
586
- if (ret) {
587
- if (ret == -ENOMEM)
588
- goto wait_for_memory;
589
-
590
- goto sendpage_end;
591
- }
592
- }
593
- continue;
594
-wait_for_sndbuf:
595
- set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
596
-wait_for_memory:
597
- ret = sk_stream_wait_memory(sk, &timeo);
598
- if (ret) {
599
- trim_both_sgl(sk, ctx->sg_plaintext_size);
600
- goto sendpage_end;
601
- }
602
-
603
- if (tls_is_pending_closed_record(tls_ctx))
604
- goto push_record;
605
-
606
- goto alloc_payload;
607
- }
608
-
609
-sendpage_end:
610
- if (orig_size > size)
611
- ret = orig_size - size;
612
- else
613
- ret = sk_stream_error(sk, flags, ret);
614
-
615
- release_sock(sk);
616
- return ret;
617
-}
618
-
619
-static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
620
- long timeo, int *err)
621
-{
622
- struct tls_context *tls_ctx = tls_get_ctx(sk);
623
- struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
624
- struct sk_buff *skb;
625
- DEFINE_WAIT_FUNC(wait, woken_wake_function);
626
-
627
- while (!(skb = ctx->recv_pkt)) {
628
- if (sk->sk_err) {
629
- *err = sock_error(sk);
630
- return NULL;
631
- }
632
-
633
- if (!skb_queue_empty(&sk->sk_receive_queue)) {
634
- __strp_unpause(&ctx->strp);
635
- if (ctx->recv_pkt)
636
- return ctx->recv_pkt;
637
- }
638
-
639
- if (sk->sk_shutdown & RCV_SHUTDOWN)
640
- return NULL;
641
-
642
- if (sock_flag(sk, SOCK_DONE))
643
- return NULL;
644
-
645
- if ((flags & MSG_DONTWAIT) || !timeo) {
646
- *err = -EAGAIN;
647
- return NULL;
648
- }
649
-
650
- add_wait_queue(sk_sleep(sk), &wait);
651
- sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
652
- sk_wait_event(sk, &timeo, ctx->recv_pkt != skb, &wait);
653
- sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
654
- remove_wait_queue(sk_sleep(sk), &wait);
655
-
656
- /* Handle signals */
657
- if (signal_pending(current)) {
658
- *err = sock_intr_errno(timeo);
659
- return NULL;
660
- }
661
- }
662
-
663
- return skb;
6641404 }
6651405
6661406 /* This function decrypts the input skb into either out_iov or in out_sg
@@ -674,10 +1414,11 @@
6741414 static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
6751415 struct iov_iter *out_iov,
6761416 struct scatterlist *out_sg,
677
- int *chunk, bool *zc)
1417
+ int *chunk, bool *zc, bool async)
6781418 {
6791419 struct tls_context *tls_ctx = tls_get_ctx(sk);
6801420 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1421
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
6811422 struct strp_msg *rxm = strp_msg(skb);
6821423 int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
6831424 struct aead_request *aead_req;
@@ -685,19 +1426,23 @@
6851426 u8 *aad, *iv, *mem = NULL;
6861427 struct scatterlist *sgin = NULL;
6871428 struct scatterlist *sgout = NULL;
688
- const int data_len = rxm->full_len - tls_ctx->rx.overhead_size;
1429
+ const int data_len = rxm->full_len - prot->overhead_size +
1430
+ prot->tail_size;
1431
+ int iv_offset = 0;
6891432
6901433 if (*zc && (out_iov || out_sg)) {
6911434 if (out_iov)
6921435 n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1;
6931436 else
6941437 n_sgout = sg_nents(out_sg);
1438
+ n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size,
1439
+ rxm->full_len - prot->prepend_size);
6951440 } else {
6961441 n_sgout = 0;
6971442 *zc = false;
1443
+ n_sgin = skb_cow_data(skb, 0, &unused);
6981444 }
6991445
700
- n_sgin = skb_cow_data(skb, 0, &unused);
7011446 if (n_sgin < 1)
7021447 return -EBADMSG;
7031448
@@ -708,7 +1453,7 @@
7081453
7091454 aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
7101455 mem_size = aead_size + (nsg * sizeof(struct scatterlist));
711
- mem_size = mem_size + TLS_AAD_SPACE_SIZE;
1456
+ mem_size = mem_size + prot->aad_size;
7121457 mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv);
7131458
7141459 /* Allocate a single block of memory which contains
@@ -724,29 +1469,42 @@
7241469 sgin = (struct scatterlist *)(mem + aead_size);
7251470 sgout = sgin + n_sgin;
7261471 aad = (u8 *)(sgout + n_sgout);
727
- iv = aad + TLS_AAD_SPACE_SIZE;
1472
+ iv = aad + prot->aad_size;
1473
+
1474
+ /* For CCM based ciphers, first byte of nonce+iv is always '2' */
1475
+ if (prot->cipher_type == TLS_CIPHER_AES_CCM_128) {
1476
+ iv[0] = 2;
1477
+ iv_offset = 1;
1478
+ }
7281479
7291480 /* Prepare IV */
7301481 err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
731
- iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
732
- tls_ctx->rx.iv_size);
1482
+ iv + iv_offset + prot->salt_size,
1483
+ prot->iv_size);
7331484 if (err < 0) {
7341485 kfree(mem);
7351486 return err;
7361487 }
737
- memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
1488
+ if (prot->version == TLS_1_3_VERSION)
1489
+ memcpy(iv + iv_offset, tls_ctx->rx.iv,
1490
+ prot->iv_size + prot->salt_size);
1491
+ else
1492
+ memcpy(iv + iv_offset, tls_ctx->rx.iv, prot->salt_size);
1493
+
1494
+ xor_iv_with_seq(prot->version, iv + iv_offset, tls_ctx->rx.rec_seq);
7381495
7391496 /* Prepare AAD */
740
- tls_make_aad(aad, rxm->full_len - tls_ctx->rx.overhead_size,
741
- tls_ctx->rx.rec_seq, tls_ctx->rx.rec_seq_size,
742
- ctx->control);
1497
+ tls_make_aad(aad, rxm->full_len - prot->overhead_size +
1498
+ prot->tail_size,
1499
+ tls_ctx->rx.rec_seq, prot->rec_seq_size,
1500
+ ctx->control, prot->version);
7431501
7441502 /* Prepare sgin */
7451503 sg_init_table(sgin, n_sgin);
746
- sg_set_buf(&sgin[0], aad, TLS_AAD_SPACE_SIZE);
1504
+ sg_set_buf(&sgin[0], aad, prot->aad_size);
7471505 err = skb_to_sgvec(skb, &sgin[1],
748
- rxm->offset + tls_ctx->rx.prepend_size,
749
- rxm->full_len - tls_ctx->rx.prepend_size);
1506
+ rxm->offset + prot->prepend_size,
1507
+ rxm->full_len - prot->prepend_size);
7501508 if (err < 0) {
7511509 kfree(mem);
7521510 return err;
....@@ -755,12 +1513,12 @@
7551513 if (n_sgout) {
7561514 if (out_iov) {
7571515 sg_init_table(sgout, n_sgout);
758
- sg_set_buf(&sgout[0], aad, TLS_AAD_SPACE_SIZE);
1516
+ sg_set_buf(&sgout[0], aad, prot->aad_size);
7591517
7601518 *chunk = 0;
761
- err = zerocopy_from_iter(sk, out_iov, data_len, &pages,
762
- chunk, &sgout[1],
763
- (n_sgout - 1), false);
1519
+ err = tls_setup_from_iter(sk, out_iov, data_len,
1520
+ &pages, chunk, &sgout[1],
1521
+ (n_sgout - 1));
7641522 if (err < 0)
7651523 goto fallback_to_reg_recv;
7661524 } else if (out_sg) {
....@@ -772,12 +1530,15 @@
7721530 fallback_to_reg_recv:
7731531 sgout = sgin;
7741532 pages = 0;
775
- *chunk = 0;
1533
+ *chunk = data_len;
7761534 *zc = false;
7771535 }
7781536
7791537 /* Prepare and submit AEAD request */
780
- err = tls_do_decryption(sk, sgin, sgout, iv, data_len, aead_req);
1538
+ err = tls_do_decryption(sk, skb, sgin, sgout, iv,
1539
+ data_len, aead_req, async);
1540
+ if (err == -EINPROGRESS)
1541
+ return err;
7811542
7821543 /* Release the pages in case iov was mapped to pages */
7831544 for (; pages > 0; pages--)
....@@ -788,31 +1549,52 @@
7881549 }
7891550
7901551 static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
791
- struct iov_iter *dest, int *chunk, bool *zc)
1552
+ struct iov_iter *dest, int *chunk, bool *zc,
1553
+ bool async)
7921554 {
7931555 struct tls_context *tls_ctx = tls_get_ctx(sk);
7941556 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1557
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
7951558 struct strp_msg *rxm = strp_msg(skb);
796
- int err = 0;
1559
+ int pad, err = 0;
7971560
798
-#ifdef CONFIG_TLS_DEVICE
799
- err = tls_device_decrypted(sk, skb);
800
- if (err < 0)
801
- return err;
802
-#endif
8031561 if (!ctx->decrypted) {
804
- err = decrypt_internal(sk, skb, dest, NULL, chunk, zc);
805
- if (err < 0)
806
- return err;
1562
+ if (tls_ctx->rx_conf == TLS_HW) {
1563
+ err = tls_device_decrypted(sk, tls_ctx, skb, rxm);
1564
+ if (err < 0)
1565
+ return err;
1566
+ }
1567
+
1568
+ /* Still not decrypted after tls_device */
1569
+ if (!ctx->decrypted) {
1570
+ err = decrypt_internal(sk, skb, dest, NULL, chunk, zc,
1571
+ async);
1572
+ if (err < 0) {
1573
+ if (err == -EINPROGRESS)
1574
+ tls_advance_record_sn(sk, prot,
1575
+ &tls_ctx->rx);
1576
+ else if (err == -EBADMSG)
1577
+ TLS_INC_STATS(sock_net(sk),
1578
+ LINUX_MIB_TLSDECRYPTERROR);
1579
+ return err;
1580
+ }
1581
+ } else {
1582
+ *zc = false;
1583
+ }
1584
+
1585
+ pad = padding_length(ctx, prot, skb);
1586
+ if (pad < 0)
1587
+ return pad;
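+ /* Trim the TLS 1.3 zero padding and the record overhead so rxm
+ * describes only the plaintext, then advance the expected
+ * record sequence number.
+ */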
1588
+
1589
+ rxm->full_len -= pad;
1590
+ rxm->offset += prot->prepend_size;
1591
+ rxm->full_len -= prot->overhead_size;
1592
+ tls_advance_record_sn(sk, prot, &tls_ctx->rx);
1593
+ ctx->decrypted = 1;
1594
+ ctx->saved_data_ready(sk);
8071595 } else {
8081596 *zc = false;
8091597 }
810
-
811
- rxm->offset += tls_ctx->rx.prepend_size;
812
- rxm->full_len -= tls_ctx->rx.overhead_size;
813
- tls_advance_record_sn(sk, &tls_ctx->rx);
814
- ctx->decrypted = true;
815
- ctx->saved_data_ready(sk);
8161598
8171599 return err;
8181600 }
....@@ -823,7 +1605,7 @@
8231605 bool zc = true;
8241606 int chunk;
8251607
826
- return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc);
1608
+ return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc, false);
8271609 }
8281610
8291611 static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
....@@ -831,21 +1613,132 @@
8311613 {
8321614 struct tls_context *tls_ctx = tls_get_ctx(sk);
8331615 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
834
- struct strp_msg *rxm = strp_msg(skb);
8351616
836
- if (len < rxm->full_len) {
837
- rxm->offset += len;
838
- rxm->full_len -= len;
1617
+ if (skb) {
1618
+ struct strp_msg *rxm = strp_msg(skb);
8391619
840
- return false;
1620
+ if (len < rxm->full_len) {
1621
+ rxm->offset += len;
1622
+ rxm->full_len -= len;
1623
+ return false;
1624
+ }
1625
+ consume_skb(skb);
8411626 }
8421627
8431628 /* Finished with message */
8441629 ctx->recv_pkt = NULL;
845
- kfree_skb(skb);
8461630 __strp_unpause(&ctx->strp);
8471631
8481632 return true;
1633
+}
1634
+
1635
+/* This function traverses the rx_list in the TLS receive context, copying the
1636
+ * decrypted records into the buffer provided by the caller when zero copy is
1637
+ * not true. Further, the records are removed from the rx_list if it is not a
1638
+ * peek case and the record has been consumed completely.
1639
+ */
1640
+static int process_rx_list(struct tls_sw_context_rx *ctx,
1641
+ struct msghdr *msg,
1642
+ u8 *control,
1643
+ bool *cmsg,
1644
+ size_t skip,
1645
+ size_t len,
1646
+ bool zc,
1647
+ bool is_peek)
1648
+{
1649
+ struct sk_buff *skb = skb_peek(&ctx->rx_list);
1650
+ u8 ctrl = *control;
1651
+ u8 msgc = *cmsg;
1652
+ struct tls_msg *tlm;
1653
+ ssize_t copied = 0;
1654
+
1655
+ /* Set the record type in 'control' if caller didn't pass it */
1656
+ if (!ctrl && skb) {
1657
+ tlm = tls_msg(skb);
1658
+ ctrl = tlm->control;
1659
+ }
1660
+
1661
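+ /* Skip over the leading bytes of rx_list that were already
+ * handled earlier in this receive call.
+ */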
+ while (skip && skb) {
1662
+ struct strp_msg *rxm = strp_msg(skb);
1663
+ tlm = tls_msg(skb);
1664
+
1665
+ /* Cannot process a record of different type */
1666
+ if (ctrl != tlm->control)
1667
+ return 0;
1668
+
1669
+ if (skip < rxm->full_len)
1670
+ break;
1671
+
1672
+ skip = skip - rxm->full_len;
1673
+ skb = skb_peek_next(skb, &ctx->rx_list);
1674
+ }
1675
+
1676
+ while (len && skb) {
1677
+ struct sk_buff *next_skb;
1678
+ struct strp_msg *rxm = strp_msg(skb);
1679
+ int chunk = min_t(unsigned int, rxm->full_len - skip, len);
1680
+
1681
+ tlm = tls_msg(skb);
1682
+
1683
+ /* Cannot process a record of different type */
1684
+ if (ctrl != tlm->control)
1685
+ return 0;
1686
+
1687
+ /* Set record type if not already done. For a non-data record,
1688
+ * do not proceed if record type could not be copied.
1689
+ */
1690
+ if (!msgc) {
1691
+ int cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
1692
+ sizeof(ctrl), &ctrl);
1693
+ msgc = true;
1694
+ if (ctrl != TLS_RECORD_TYPE_DATA) {
1695
+ if (cerr || msg->msg_flags & MSG_CTRUNC)
1696
+ return -EIO;
1697
+
1698
+ *cmsg = msgc;
1699
+ }
1700
+ }
1701
+
1702
+ if (!zc || (rxm->full_len - skip) > len) {
1703
+ int err = skb_copy_datagram_msg(skb, rxm->offset + skip,
1704
+ msg, chunk);
1705
+ if (err < 0)
1706
+ return err;
1707
+ }
1708
+
1709
+ len = len - chunk;
1710
+ copied = copied + chunk;
1711
+
1712
+ /* Consume the data from the record in the non-peek case */
1713
+ if (!is_peek) {
1714
+ rxm->offset = rxm->offset + chunk;
1715
+ rxm->full_len = rxm->full_len - chunk;
1716
+
1717
+ /* Return if there is unconsumed data in the record */
1718
+ if (rxm->full_len - skip)
1719
+ break;
1720
+ }
1721
+
1722
+ /* Any remaining skip-bytes must lie in the first record in rx_list,
1723
+ * so from the second record onward 'skip' must be 0.
1724
+ */
1725
+ skip = 0;
1726
+
1727
+ if (msg)
1728
+ msg->msg_flags |= MSG_EOR;
1729
+
1730
+ next_skb = skb_peek_next(skb, &ctx->rx_list);
1731
+
1732
+ if (!is_peek) {
1733
+ skb_unlink(skb, &ctx->rx_list);
1734
+ consume_skb(skb);
1735
+ }
1736
+
1737
+ skb = next_skb;
1738
+ }
1739
+
1740
+ *control = ctrl;
1741
+ return copied;
8491742 }
8501743
8511744 int tls_sw_recvmsg(struct sock *sk,
....@@ -857,104 +1750,241 @@
8571750 {
8581751 struct tls_context *tls_ctx = tls_get_ctx(sk);
8591752 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
860
- unsigned char control;
1753
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
1754
+ struct sk_psock *psock;
1755
+ unsigned char control = 0;
1756
+ ssize_t decrypted = 0;
8611757 struct strp_msg *rxm;
1758
+ struct tls_msg *tlm;
8621759 struct sk_buff *skb;
8631760 ssize_t copied = 0;
8641761 bool cmsg = false;
8651762 int target, err = 0;
8661763 long timeo;
867
- bool is_kvec = msg->msg_iter.type & ITER_KVEC;
1764
+ bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
1765
+ bool is_peek = flags & MSG_PEEK;
1766
+ bool bpf_strp_enabled;
1767
+ int num_async = 0;
1768
+ int pending;
8681769
8691770 flags |= nonblock;
8701771
8711772 if (unlikely(flags & MSG_ERRQUEUE))
8721773 return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
8731774
1775
+ psock = sk_psock_get(sk);
8741776 lock_sock(sk);
1777
+ bpf_strp_enabled = sk_psock_strp_enabled(psock);
1778
+
1779
+ /* Process pending decrypted records. It must be non-zero-copy */
1780
+ err = process_rx_list(ctx, msg, &control, &cmsg, 0, len, false,
1781
+ is_peek);
1782
+ if (err < 0) {
1783
+ tls_err_abort(sk, err);
1784
+ goto end;
1785
+ } else {
1786
+ copied = err;
1787
+ }
1788
+
1789
+ if (len <= copied)
1790
+ goto recv_end;
8751791
8761792 target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
1793
+ len = len - copied;
8771794 timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
878
- do {
879
- bool zc = false;
880
- int chunk = 0;
8811795
882
- skb = tls_wait_data(sk, flags, timeo, &err);
883
- if (!skb)
1796
+ while (len && (decrypted + copied < target || ctx->recv_pkt)) {
1797
+ bool retain_skb = false;
1798
+ bool zc = false;
1799
+ int to_decrypt;
1800
+ int chunk = 0;
1801
+ bool async_capable;
1802
+ bool async = false;
1803
+
1804
+ skb = tls_wait_data(sk, psock, flags & MSG_DONTWAIT, timeo, &err);
1805
+ if (!skb) {
1806
+ if (psock) {
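+ /* No TLS record is ready; fall back to data a BPF
+ * verdict program may have redirected to this socket.
+ */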
1807
+ int ret = __tcp_bpf_recvmsg(sk, psock,
1808
+ msg, len, flags);
1809
+
1810
+ if (ret > 0) {
1811
+ decrypted += ret;
1812
+ len -= ret;
1813
+ continue;
1814
+ }
1815
+ }
8841816 goto recv_end;
1817
+ } else {
1818
+ tlm = tls_msg(skb);
1819
+ if (prot->version == TLS_1_3_VERSION)
1820
+ tlm->control = 0;
1821
+ else
1822
+ tlm->control = ctx->control;
1823
+ }
8851824
8861825 rxm = strp_msg(skb);
1826
+
1827
+ to_decrypt = rxm->full_len - prot->overhead_size;
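+ /* Decrypt straight into the user buffer only when the whole
+ * record fits, the iterator is a plain iovec, we are not
+ * peeking, it is a TLS 1.2 data record and no BPF strparser
+ * needs to inspect it first.
+ */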
1828
+
1829
+ if (to_decrypt <= len && !is_kvec && !is_peek &&
1830
+ ctx->control == TLS_RECORD_TYPE_DATA &&
1831
+ prot->version != TLS_1_3_VERSION &&
1832
+ !bpf_strp_enabled)
1833
+ zc = true;
1834
+
1835
+ /* Do not use async mode if record is non-data */
1836
+ if (ctx->control == TLS_RECORD_TYPE_DATA && !bpf_strp_enabled)
1837
+ async_capable = ctx->async_capable;
1838
+ else
1839
+ async_capable = false;
1840
+
1841
+ err = decrypt_skb_update(sk, skb, &msg->msg_iter,
1842
+ &chunk, &zc, async_capable);
1843
+ if (err < 0 && err != -EINPROGRESS) {
1844
+ tls_err_abort(sk, -EBADMSG);
1845
+ goto recv_end;
1846
+ }
1847
+
1848
+ if (err == -EINPROGRESS) {
1849
+ async = true;
1850
+ num_async++;
1851
+ } else if (prot->version == TLS_1_3_VERSION) {
1852
+ tlm->control = ctx->control;
1853
+ }
1854
+
1855
+ /* If the type of records being processed is not known yet,
1856
+ * set it to record type just dequeued. If it is already known,
1857
+ * but does not match the record type just dequeued, go to end.
1858
+ * We always get record type here since for tls1.2, record type
1859
+ * is known just after record is dequeued from stream parser.
1860
+ * For tls1.3, we disable async.
1861
+ */
1862
+
1863
+ if (!control)
1864
+ control = tlm->control;
1865
+ else if (control != tlm->control)
1866
+ goto recv_end;
1867
+
8871868 if (!cmsg) {
8881869 int cerr;
8891870
8901871 cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
891
- sizeof(ctx->control), &ctx->control);
1872
+ sizeof(control), &control);
8921873 cmsg = true;
893
- control = ctx->control;
894
- if (ctx->control != TLS_RECORD_TYPE_DATA) {
1874
+ if (control != TLS_RECORD_TYPE_DATA) {
8951875 if (cerr || msg->msg_flags & MSG_CTRUNC) {
8961876 err = -EIO;
8971877 goto recv_end;
8981878 }
8991879 }
900
- } else if (control != ctx->control) {
901
- goto recv_end;
9021880 }
9031881
904
- if (!ctx->decrypted) {
905
- int to_copy = rxm->full_len - tls_ctx->rx.overhead_size;
906
-
907
- if (!is_kvec && to_copy <= len &&
908
- likely(!(flags & MSG_PEEK)))
909
- zc = true;
910
-
911
- err = decrypt_skb_update(sk, skb, &msg->msg_iter,
912
- &chunk, &zc);
913
- if (err < 0) {
914
- tls_err_abort(sk, EBADMSG);
915
- goto recv_end;
916
- }
917
- ctx->decrypted = true;
918
- }
1882
+ if (async)
1883
+ goto pick_next_record;
9191884
9201885 if (!zc) {
921
- chunk = min_t(unsigned int, rxm->full_len, len);
922
- err = skb_copy_datagram_msg(skb, rxm->offset, msg,
923
- chunk);
1886
+ if (bpf_strp_enabled) {
1887
+ err = sk_psock_tls_strp_read(psock, skb);
1888
+ if (err != __SK_PASS) {
1889
+ rxm->offset = rxm->offset + rxm->full_len;
1890
+ rxm->full_len = 0;
1891
+ if (err == __SK_DROP)
1892
+ consume_skb(skb);
1893
+ ctx->recv_pkt = NULL;
1894
+ __strp_unpause(&ctx->strp);
1895
+ continue;
1896
+ }
1897
+ }
1898
+
1899
+ if (rxm->full_len > len) {
1900
+ retain_skb = true;
1901
+ chunk = len;
1902
+ } else {
1903
+ chunk = rxm->full_len;
1904
+ }
1905
+
1906
+ err = skb_copy_datagram_msg(skb, rxm->offset,
1907
+ msg, chunk);
9241908 if (err < 0)
9251909 goto recv_end;
926
- }
9271910
928
- copied += chunk;
929
- len -= chunk;
930
- if (likely(!(flags & MSG_PEEK))) {
931
- u8 control = ctx->control;
932
-
933
- if (tls_sw_advance_skb(sk, skb, chunk)) {
934
- /* Return full control message to
935
- * userspace before trying to parse
936
- * another message type
937
- */
938
- msg->msg_flags |= MSG_EOR;
939
- if (control != TLS_RECORD_TYPE_DATA)
940
- goto recv_end;
1911
+ if (!is_peek) {
1912
+ rxm->offset = rxm->offset + chunk;
1913
+ rxm->full_len = rxm->full_len - chunk;
9411914 }
942
- } else {
943
- /* MSG_PEEK right now cannot look beyond current skb
944
- * from strparser, meaning we cannot advance skb here
945
- * and thus unpause strparser since we'd loose original
946
- * one.
947
- */
948
- break;
9491915 }
9501916
951
- /* If we have a new message from strparser, continue now. */
952
- if (copied >= target && !ctx->recv_pkt)
1917
+pick_next_record:
1918
+ if (chunk > len)
1919
+ chunk = len;
1920
+
1921
+ decrypted += chunk;
1922
+ len -= chunk;
1923
+
1924
+ /* For async or peek case, queue the current skb */
1925
+ if (async || is_peek || retain_skb) {
1926
+ skb_queue_tail(&ctx->rx_list, skb);
1927
+ skb = NULL;
1928
+ }
1929
+
1930
+ if (tls_sw_advance_skb(sk, skb, chunk)) {
1931
+ /* Return full control message to
1932
+ * userspace before trying to parse
1933
+ * another message type
1934
+ */
1935
+ msg->msg_flags |= MSG_EOR;
1936
+ if (control != TLS_RECORD_TYPE_DATA)
1937
+ goto recv_end;
1938
+ } else {
9531939 break;
954
- } while (len);
1940
+ }
1941
+ }
9551942
9561943 recv_end:
1944
+ if (num_async) {
1945
+ /* Wait for all previously submitted records to be decrypted */
1946
+ spin_lock_bh(&ctx->decrypt_compl_lock);
1947
+ ctx->async_notify = true;
1948
+ pending = atomic_read(&ctx->decrypt_pending);
1949
+ spin_unlock_bh(&ctx->decrypt_compl_lock);
1950
+ if (pending) {
1951
+ err = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
1952
+ if (err) {
1953
+ /* one of async decrypt failed */
1954
+ tls_err_abort(sk, err);
1955
+ copied = 0;
1956
+ decrypted = 0;
1957
+ goto end;
1958
+ }
1959
+ } else {
1960
+ reinit_completion(&ctx->async_wait.completion);
1961
+ }
1962
+
1963
+ /* There can be no concurrent accesses, since we have no
1964
+ * pending decrypt operations
1965
+ */
1966
+ WRITE_ONCE(ctx->async_notify, false);
1967
+
1968
+ /* Drain records from the rx_list & copy if required */
1969
+ if (is_peek || is_kvec)
1970
+ err = process_rx_list(ctx, msg, &control, &cmsg, copied,
1971
+ decrypted, false, is_peek);
1972
+ else
1973
+ err = process_rx_list(ctx, msg, &control, &cmsg, 0,
1974
+ decrypted, true, is_peek);
1975
+ if (err < 0) {
1976
+ tls_err_abort(sk, err);
1977
+ copied = 0;
1978
+ goto end;
1979
+ }
1980
+ }
1981
+
1982
+ copied += decrypted;
1983
+
1984
+end:
9571985 release_sock(sk);
1986
+ if (psock)
1987
+ sk_psock_put(sk, psock);
9581988 return copied ? : err;
9591989 }
9601990
....@@ -975,27 +2005,24 @@
9752005
9762006 lock_sock(sk);
9772007
978
- timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
2008
+ timeo = sock_rcvtimeo(sk, flags & SPLICE_F_NONBLOCK);
9792009
980
- skb = tls_wait_data(sk, flags, timeo, &err);
2010
+ skb = tls_wait_data(sk, NULL, flags & SPLICE_F_NONBLOCK, timeo, &err);
9812011 if (!skb)
9822012 goto splice_read_end;
9832013
984
- /* splice does not support reading control messages */
985
- if (ctx->control != TLS_RECORD_TYPE_DATA) {
986
- err = -ENOTSUPP;
2014
+ err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false);
2015
+ if (err < 0) {
2016
+ tls_err_abort(sk, -EBADMSG);
9872017 goto splice_read_end;
9882018 }
9892019
990
- if (!ctx->decrypted) {
991
- err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc);
992
-
993
- if (err < 0) {
994
- tls_err_abort(sk, EBADMSG);
995
- goto splice_read_end;
996
- }
997
- ctx->decrypted = true;
2020
+ /* splice does not support reading control messages */
2021
+ if (ctx->control != TLS_RECORD_TYPE_DATA) {
2022
+ err = -EINVAL;
2023
+ goto splice_read_end;
9982024 }
2025
+
9992026 rxm = strp_msg(skb);
10002027
10012028 chunk = min_t(unsigned int, rxm->full_len, len);
....@@ -1011,29 +2038,28 @@
10112038 return copied ? : err;
10122039 }
10132040
1014
-unsigned int tls_sw_poll(struct file *file, struct socket *sock,
1015
- struct poll_table_struct *wait)
2041
+bool tls_sw_stream_read(const struct sock *sk)
10162042 {
1017
- unsigned int ret;
1018
- struct sock *sk = sock->sk;
10192043 struct tls_context *tls_ctx = tls_get_ctx(sk);
10202044 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2045
+ bool ingress_empty = true;
2046
+ struct sk_psock *psock;
10212047
1022
- /* Grab POLLOUT and POLLHUP from the underlying socket */
1023
- ret = ctx->sk_poll(file, sock, wait);
2048
+ rcu_read_lock();
2049
+ psock = sk_psock(sk);
2050
+ if (psock)
2051
+ ingress_empty = list_empty(&psock->ingress_msg);
2052
+ rcu_read_unlock();
10242053
1025
- /* Clear POLLIN bits, and set based on recv_pkt */
1026
- ret &= ~(POLLIN | POLLRDNORM);
1027
- if (ctx->recv_pkt)
1028
- ret |= POLLIN | POLLRDNORM;
1029
-
1030
- return ret;
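+ /* Readable if BPF redirected data, a parsed record, or
+ * previously decrypted records are queued.
+ */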
2054
+ return !ingress_empty || ctx->recv_pkt ||
2055
+ !skb_queue_empty(&ctx->rx_list);
10312056 }
10322057
10332058 static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
10342059 {
10352060 struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
10362061 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2062
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
10372063 char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
10382064 struct strp_msg *rxm = strp_msg(skb);
10392065 size_t cipher_overhead;
....@@ -1041,17 +2067,17 @@
10412067 int ret;
10422068
10432069 /* Verify that we have a full TLS header, or wait for more data */
1044
- if (rxm->offset + tls_ctx->rx.prepend_size > skb->len)
2070
+ if (rxm->offset + prot->prepend_size > skb->len)
10452071 return 0;
10462072
10472073 /* Sanity-check size of on-stack buffer. */
1048
- if (WARN_ON(tls_ctx->rx.prepend_size > sizeof(header))) {
2074
+ if (WARN_ON(prot->prepend_size > sizeof(header))) {
10492075 ret = -EINVAL;
10502076 goto read_failure;
10512077 }
10522078
10532079 /* Linearize header to local buffer */
1054
- ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size);
2080
+ ret = skb_copy_bits(skb, rxm->offset, header, prot->prepend_size);
10552081
10562082 if (ret < 0)
10572083 goto read_failure;
....@@ -1060,9 +2086,12 @@
10602086
10612087 data_len = ((header[4] & 0xFF) | (header[3] << 8));
10622088
1063
- cipher_overhead = tls_ctx->rx.tag_size + tls_ctx->rx.iv_size;
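+ /* TLS 1.3 carries no explicit per-record IV, so only the AEAD tag
+ * counts toward the cipher overhead here.
+ */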
2089
+ cipher_overhead = prot->tag_size;
2090
+ if (prot->version != TLS_1_3_VERSION)
2091
+ cipher_overhead += prot->iv_size;
10642092
1065
- if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead) {
2093
+ if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead +
2094
+ prot->tail_size) {
10662095 ret = -EMSGSIZE;
10672096 goto read_failure;
10682097 }
....@@ -1071,16 +2100,15 @@
10712100 goto read_failure;
10722101 }
10732102
1074
- if (header[1] != TLS_VERSION_MINOR(tls_ctx->crypto_recv.info.version) ||
1075
- header[2] != TLS_VERSION_MAJOR(tls_ctx->crypto_recv.info.version)) {
2103
+ /* Note that both TLS1.3 and TLS1.2 use TLS_1_2 version here */
2104
+ if (header[1] != TLS_1_2_VERSION_MINOR ||
2105
+ header[2] != TLS_1_2_VERSION_MAJOR) {
10762106 ret = -EINVAL;
10772107 goto read_failure;
10782108 }
10792109
1080
-#ifdef CONFIG_TLS_DEVICE
1081
- handle_device_resync(strp->sk, TCP_SKB_CB(skb)->seq + rxm->offset,
1082
- *(u64*)tls_ctx->rx.rec_seq);
1083
-#endif
2110
+ tls_device_rx_resync_new_rec(strp->sk, data_len + TLS_HEADER_SIZE,
2111
+ TCP_SKB_CB(skb)->seq + rxm->offset);
10842112 return data_len + TLS_HEADER_SIZE;
10852113
10862114 read_failure:
....@@ -1094,7 +2122,7 @@
10942122 struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
10952123 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
10962124
1097
- ctx->decrypted = false;
2125
+ ctx->decrypted = 0;
10982126
10992127 ctx->recv_pkt = skb;
11002128 strp_pause(strp);
....@@ -1106,17 +2134,71 @@
11062134 {
11072135 struct tls_context *tls_ctx = tls_get_ctx(sk);
11082136 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2137
+ struct sk_psock *psock;
11092138
11102139 strp_data_ready(&ctx->strp);
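+ /* Also wake the original data_ready if a BPF verdict program has
+ * queued data on the psock, so recvmsg() can drain it.
+ */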
2140
+
2141
+ psock = sk_psock_get(sk);
2142
+ if (psock) {
2143
+ if (!list_empty(&psock->ingress_msg))
2144
+ ctx->saved_data_ready(sk);
2145
+ sk_psock_put(sk, psock);
2146
+ }
11112147 }
11122148
1113
-void tls_sw_free_resources_tx(struct sock *sk)
2149
+void tls_sw_cancel_work_tx(struct tls_context *tls_ctx)
2150
+{
2151
+ struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
2152
+
2153
+ set_bit(BIT_TX_CLOSING, &ctx->tx_bitmask);
2154
+ set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask);
2155
+ cancel_delayed_work_sync(&ctx->tx_work.work);
2156
+}
2157
+
2158
+void tls_sw_release_resources_tx(struct sock *sk)
11142159 {
11152160 struct tls_context *tls_ctx = tls_get_ctx(sk);
11162161 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
2162
+ struct tls_rec *rec, *tmp;
2163
+ int pending;
2164
+
2165
+ /* Wait for any pending async encryptions to complete */
2166
+ spin_lock_bh(&ctx->encrypt_compl_lock);
2167
+ ctx->async_notify = true;
2168
+ pending = atomic_read(&ctx->encrypt_pending);
2169
+ spin_unlock_bh(&ctx->encrypt_compl_lock);
2170
+
2171
+ if (pending)
2172
+ crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
2173
+
2174
+ tls_tx_records(sk, -1);
2175
+
2176
+ /* Free up un-sent records in tx_list. First, free
2177
+ * the partially sent record if any at head of tx_list.
2178
+ */
2179
+ if (tls_ctx->partially_sent_record) {
2180
+ tls_free_partial_record(sk, tls_ctx);
2181
+ rec = list_first_entry(&ctx->tx_list,
2182
+ struct tls_rec, list);
2183
+ list_del(&rec->list);
2184
+ sk_msg_free(sk, &rec->msg_plaintext);
2185
+ kfree(rec);
2186
+ }
2187
+
2188
+ list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
2189
+ list_del(&rec->list);
2190
+ sk_msg_free(sk, &rec->msg_encrypted);
2191
+ sk_msg_free(sk, &rec->msg_plaintext);
2192
+ kfree(rec);
2193
+ }
11172194
11182195 crypto_free_aead(ctx->aead_send);
1119
- tls_free_both_sg(sk);
2196
+ tls_free_open_rec(sk);
2197
+}
2198
+
2199
+void tls_sw_free_ctx_tx(struct tls_context *tls_ctx)
2200
+{
2201
+ struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
11202202
11212203 kfree(ctx);
11222204 }
....@@ -1132,38 +2214,116 @@
11322214 if (ctx->aead_recv) {
11332215 kfree_skb(ctx->recv_pkt);
11342216 ctx->recv_pkt = NULL;
2217
+ skb_queue_purge(&ctx->rx_list);
11352218 crypto_free_aead(ctx->aead_recv);
11362219 strp_stop(&ctx->strp);
1137
- write_lock_bh(&sk->sk_callback_lock);
1138
- sk->sk_data_ready = ctx->saved_data_ready;
1139
- write_unlock_bh(&sk->sk_callback_lock);
1140
- release_sock(sk);
1141
- strp_done(&ctx->strp);
1142
- lock_sock(sk);
2220
+ /* If tls_sw_strparser_arm() was not called (cleanup paths)
2221
+ * we still want to strp_stop(), but sk->sk_data_ready was
2222
+ * never swapped.
2223
+ */
2224
+ if (ctx->saved_data_ready) {
2225
+ write_lock_bh(&sk->sk_callback_lock);
2226
+ sk->sk_data_ready = ctx->saved_data_ready;
2227
+ write_unlock_bh(&sk->sk_callback_lock);
2228
+ }
11432229 }
2230
+}
2231
+
2232
+void tls_sw_strparser_done(struct tls_context *tls_ctx)
2233
+{
2234
+ struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2235
+
2236
+ strp_done(&ctx->strp);
2237
+}
2238
+
2239
+void tls_sw_free_ctx_rx(struct tls_context *tls_ctx)
2240
+{
2241
+ struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2242
+
2243
+ kfree(ctx);
11442244 }
11452245
11462246 void tls_sw_free_resources_rx(struct sock *sk)
11472247 {
11482248 struct tls_context *tls_ctx = tls_get_ctx(sk);
1149
- struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
11502249
11512250 tls_sw_release_resources_rx(sk);
2251
+ tls_sw_free_ctx_rx(tls_ctx);
2252
+}
11522253
1153
- kfree(ctx);
2254
+/* The work handler to transmit the encrypted records in tx_list */
2255
+static void tx_work_handler(struct work_struct *work)
2256
+{
2257
+ struct delayed_work *delayed_work = to_delayed_work(work);
2258
+ struct tx_work *tx_work = container_of(delayed_work,
2259
+ struct tx_work, work);
2260
+ struct sock *sk = tx_work->sk;
2261
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
2262
+ struct tls_sw_context_tx *ctx;
2263
+
2264
+ if (unlikely(!tls_ctx))
2265
+ return;
2266
+
2267
+ ctx = tls_sw_ctx_tx(tls_ctx);
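+ /* tls_sw_cancel_work_tx() sets BIT_TX_CLOSING before flushing
+ * this work, so bail out instead of racing with teardown.
+ */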
2268
+ if (test_bit(BIT_TX_CLOSING, &ctx->tx_bitmask))
2269
+ return;
2270
+
2271
+ if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
2272
+ return;
2273
+
2274
+ if (mutex_trylock(&tls_ctx->tx_lock)) {
2275
+ lock_sock(sk);
2276
+ tls_tx_records(sk, -1);
2277
+ release_sock(sk);
2278
+ mutex_unlock(&tls_ctx->tx_lock);
2279
+ } else if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
2280
+ /* Someone is holding the tx_lock, they will likely run Tx
2281
+ * and cancel the work on their way out of the lock section.
2282
+ * Schedule a long delay just in case.
2283
+ */
2284
+ schedule_delayed_work(&ctx->tx_work.work, msecs_to_jiffies(10));
2285
+ }
2286
+}
2287
+
2288
+void tls_sw_write_space(struct sock *sk, struct tls_context *ctx)
2289
+{
2290
+ struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx);
2291
+
2292
+ /* Schedule the transmission if tx list is ready */
2293
+ if (is_tx_ready(tx_ctx) &&
2294
+ !test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask))
2295
+ schedule_delayed_work(&tx_ctx->tx_work.work, 0);
2296
+}
2297
+
2298
+void tls_sw_strparser_arm(struct sock *sk, struct tls_context *tls_ctx)
2299
+{
2300
+ struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx);
2301
+
2302
+ write_lock_bh(&sk->sk_callback_lock);
2303
+ rx_ctx->saved_data_ready = sk->sk_data_ready;
2304
+ sk->sk_data_ready = tls_data_ready;
2305
+ write_unlock_bh(&sk->sk_callback_lock);
2306
+
2307
+ strp_check_rcv(&rx_ctx->strp);
11542308 }
11552309
11562310 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
11572311 {
2312
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
2313
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
11582314 struct tls_crypto_info *crypto_info;
11592315 struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
2316
+ struct tls12_crypto_info_aes_gcm_256 *gcm_256_info;
2317
+ struct tls12_crypto_info_aes_ccm_128 *ccm_128_info;
11602318 struct tls_sw_context_tx *sw_ctx_tx = NULL;
11612319 struct tls_sw_context_rx *sw_ctx_rx = NULL;
11622320 struct cipher_context *cctx;
11632321 struct crypto_aead **aead;
11642322 struct strp_callbacks cb;
1165
- u16 nonce_size, tag_size, iv_size, rec_seq_size;
1166
- char *iv, *rec_seq;
2323
+ u16 nonce_size, tag_size, iv_size, rec_seq_size, salt_size;
2324
+ struct crypto_tfm *tfm;
2325
+ char *iv, *rec_seq, *key, *salt, *cipher_name;
2326
+ size_t keysize;
11672327 int rc = 0;
11682328
11692329 if (!ctx) {
....@@ -1199,13 +2359,19 @@
11992359
12002360 if (tx) {
12012361 crypto_init_wait(&sw_ctx_tx->async_wait);
2362
+ spin_lock_init(&sw_ctx_tx->encrypt_compl_lock);
12022363 crypto_info = &ctx->crypto_send.info;
12032364 cctx = &ctx->tx;
12042365 aead = &sw_ctx_tx->aead_send;
2366
+ INIT_LIST_HEAD(&sw_ctx_tx->tx_list);
2367
+ INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler);
2368
+ sw_ctx_tx->tx_work.sk = sk;
12052369 } else {
12062370 crypto_init_wait(&sw_ctx_rx->async_wait);
2371
+ spin_lock_init(&sw_ctx_rx->decrypt_compl_lock);
12072372 crypto_info = &ctx->crypto_recv.info;
12082373 cctx = &ctx->rx;
2374
+ skb_queue_head_init(&sw_ctx_rx->rx_list);
12092375 aead = &sw_ctx_rx->aead_recv;
12102376 }
12112377
....@@ -1220,6 +2386,45 @@
12202386 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
12212387 gcm_128_info =
12222388 (struct tls12_crypto_info_aes_gcm_128 *)crypto_info;
2389
+ keysize = TLS_CIPHER_AES_GCM_128_KEY_SIZE;
2390
+ key = gcm_128_info->key;
2391
+ salt = gcm_128_info->salt;
2392
+ salt_size = TLS_CIPHER_AES_GCM_128_SALT_SIZE;
2393
+ cipher_name = "gcm(aes)";
2394
+ break;
2395
+ }
2396
+ case TLS_CIPHER_AES_GCM_256: {
2397
+ nonce_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
2398
+ tag_size = TLS_CIPHER_AES_GCM_256_TAG_SIZE;
2399
+ iv_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
2400
+ iv = ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->iv;
2401
+ rec_seq_size = TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE;
2402
+ rec_seq =
2403
+ ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->rec_seq;
2404
+ gcm_256_info =
2405
+ (struct tls12_crypto_info_aes_gcm_256 *)crypto_info;
2406
+ keysize = TLS_CIPHER_AES_GCM_256_KEY_SIZE;
2407
+ key = gcm_256_info->key;
2408
+ salt = gcm_256_info->salt;
2409
+ salt_size = TLS_CIPHER_AES_GCM_256_SALT_SIZE;
2410
+ cipher_name = "gcm(aes)";
2411
+ break;
2412
+ }
2413
+ case TLS_CIPHER_AES_CCM_128: {
2414
+ nonce_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
2415
+ tag_size = TLS_CIPHER_AES_CCM_128_TAG_SIZE;
2416
+ iv_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
2417
+ iv = ((struct tls12_crypto_info_aes_ccm_128 *)crypto_info)->iv;
2418
+ rec_seq_size = TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE;
2419
+ rec_seq =
2420
+ ((struct tls12_crypto_info_aes_ccm_128 *)crypto_info)->rec_seq;
2421
+ ccm_128_info =
2422
+ (struct tls12_crypto_info_aes_ccm_128 *)crypto_info;
2423
+ keysize = TLS_CIPHER_AES_CCM_128_KEY_SIZE;
2424
+ key = ccm_128_info->key;
2425
+ salt = ccm_128_info->salt;
2426
+ salt_size = TLS_CIPHER_AES_CCM_128_SALT_SIZE;
2427
+ cipher_name = "ccm(aes)";
12232428 break;
12242429 }
12252430 default:
....@@ -1227,53 +2432,47 @@
12272432 goto free_priv;
12282433 }
12292434
1230
- /* Sanity-check the IV size for stack allocations. */
1231
- if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE) {
2435
+ /* Sanity-check the sizes for stack allocations. */
2436
+ if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE ||
2437
+ rec_seq_size > TLS_MAX_REC_SEQ_SIZE) {
12322438 rc = -EINVAL;
12332439 goto free_priv;
12342440 }
12352441
1236
- cctx->prepend_size = TLS_HEADER_SIZE + nonce_size;
1237
- cctx->tag_size = tag_size;
1238
- cctx->overhead_size = cctx->prepend_size + cctx->tag_size;
1239
- cctx->iv_size = iv_size;
1240
- cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
1241
- GFP_KERNEL);
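+ /* TLS 1.3 records carry no explicit per-record IV, authenticate
+ * only the 5-byte record header, and append a one-byte inner
+ * content type at the tail.
+ */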
2442
+ if (crypto_info->version == TLS_1_3_VERSION) {
2443
+ nonce_size = 0;
2444
+ prot->aad_size = TLS_HEADER_SIZE;
2445
+ prot->tail_size = 1;
2446
+ } else {
2447
+ prot->aad_size = TLS_AAD_SPACE_SIZE;
2448
+ prot->tail_size = 0;
2449
+ }
2450
+
2451
+ prot->version = crypto_info->version;
2452
+ prot->cipher_type = crypto_info->cipher_type;
2453
+ prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
2454
+ prot->tag_size = tag_size;
2455
+ prot->overhead_size = prot->prepend_size +
2456
+ prot->tag_size + prot->tail_size;
2457
+ prot->iv_size = iv_size;
2458
+ prot->salt_size = salt_size;
2459
+ cctx->iv = kmalloc(iv_size + salt_size, GFP_KERNEL);
12422460 if (!cctx->iv) {
12432461 rc = -ENOMEM;
12442462 goto free_priv;
12452463 }
1246
- memcpy(cctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
1247
- memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
1248
- cctx->rec_seq_size = rec_seq_size;
2464
+ /* Note: 128 & 256 bit salt are the same size */
2465
+ prot->rec_seq_size = rec_seq_size;
2466
+ memcpy(cctx->iv, salt, salt_size);
2467
+ memcpy(cctx->iv + salt_size, iv, iv_size);
12492468 cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
12502469 if (!cctx->rec_seq) {
12512470 rc = -ENOMEM;
12522471 goto free_iv;
12532472 }
12542473
1255
- if (sw_ctx_tx) {
1256
- sg_init_table(sw_ctx_tx->sg_encrypted_data,
1257
- ARRAY_SIZE(sw_ctx_tx->sg_encrypted_data));
1258
- sg_init_table(sw_ctx_tx->sg_plaintext_data,
1259
- ARRAY_SIZE(sw_ctx_tx->sg_plaintext_data));
1260
-
1261
- sg_init_table(sw_ctx_tx->sg_aead_in, 2);
1262
- sg_set_buf(&sw_ctx_tx->sg_aead_in[0], sw_ctx_tx->aad_space,
1263
- sizeof(sw_ctx_tx->aad_space));
1264
- sg_unmark_end(&sw_ctx_tx->sg_aead_in[1]);
1265
- sg_chain(sw_ctx_tx->sg_aead_in, 2,
1266
- sw_ctx_tx->sg_plaintext_data);
1267
- sg_init_table(sw_ctx_tx->sg_aead_out, 2);
1268
- sg_set_buf(&sw_ctx_tx->sg_aead_out[0], sw_ctx_tx->aad_space,
1269
- sizeof(sw_ctx_tx->aad_space));
1270
- sg_unmark_end(&sw_ctx_tx->sg_aead_out[1]);
1271
- sg_chain(sw_ctx_tx->sg_aead_out, 2,
1272
- sw_ctx_tx->sg_encrypted_data);
1273
- }
1274
-
12752474 if (!*aead) {
1276
- *aead = crypto_alloc_aead("gcm(aes)", 0, 0);
2475
+ *aead = crypto_alloc_aead(cipher_name, 0, 0);
12772476 if (IS_ERR(*aead)) {
12782477 rc = PTR_ERR(*aead);
12792478 *aead = NULL;
....@@ -1283,31 +2482,31 @@
12832482
12842483 ctx->push_pending_record = tls_sw_push_pending_record;
12852484
1286
- rc = crypto_aead_setkey(*aead, gcm_128_info->key,
1287
- TLS_CIPHER_AES_GCM_128_KEY_SIZE);
2485
+ rc = crypto_aead_setkey(*aead, key, keysize);
2486
+
12882487 if (rc)
12892488 goto free_aead;
12902489
1291
- rc = crypto_aead_setauthsize(*aead, cctx->tag_size);
2490
+ rc = crypto_aead_setauthsize(*aead, prot->tag_size);
12922491 if (rc)
12932492 goto free_aead;
12942493
12952494 if (sw_ctx_rx) {
2495
+ tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv);
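+ /* Async decrypt is only used with TLS 1.2; a 1.3 record's type is
+ * unknown until decryption completes, so it is handled inline.
+ */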
2496
+
2497
+ if (crypto_info->version == TLS_1_3_VERSION)
2498
+ sw_ctx_rx->async_capable = 0;
2499
+ else
2500
+ sw_ctx_rx->async_capable =
2501
+ !!(tfm->__crt_alg->cra_flags &
2502
+ CRYPTO_ALG_ASYNC);
2503
+
12962504 /* Set up strparser */
12972505 memset(&cb, 0, sizeof(cb));
12982506 cb.rcv_msg = tls_queue;
12992507 cb.parse_msg = tls_read_size;
13002508
13012509 strp_init(&sw_ctx_rx->strp, sk, &cb);
1302
-
1303
- write_lock_bh(&sk->sk_callback_lock);
1304
- sw_ctx_rx->saved_data_ready = sk->sk_data_ready;
1305
- sk->sk_data_ready = tls_data_ready;
1306
- write_unlock_bh(&sk->sk_callback_lock);
1307
-
1308
- sw_ctx_rx->sk_poll = sk->sk_socket->ops->poll;
1309
-
1310
- strp_check_rcv(&sw_ctx_rx->strp);
13112510 }
13122511
13132512 goto out;