2023-12-11 d2ccde1c8e90d38cee87a1b0309ad2827f3fd30d
kernel/net/tls/tls_sw.c
@@ -4,6 +4,7 @@
  * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
  * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
  * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
+ * Copyright (c) 2018, Covalent IO, Inc. http://covalent.io
  *
  * This software is available to you under a choice of one of two
  * licenses. You may choose to be licensed under the terms of the GNU
@@ -34,244 +35,1323 @@
  * SOFTWARE.
  */

+#include <linux/bug.h>
 #include <linux/sched/signal.h>
 #include <linux/module.h>
+#include <linux/splice.h>
 #include <crypto/aead.h>

 #include <net/strparser.h>
 #include <net/tls.h>

-#define MAX_IV_SIZE	TLS_CIPHER_AES_GCM_128_IV_SIZE
+noinline void tls_err_abort(struct sock *sk, int err)
+{
+	WARN_ON_ONCE(err >= 0);
+	/* sk->sk_err should contain a positive error code. */
+	sk->sk_err = -err;
+	sk->sk_error_report(sk);
+}
+
+static int __skb_nsg(struct sk_buff *skb, int offset, int len,
+		     unsigned int recursion_level)
+{
+	int start = skb_headlen(skb);
+	int i, chunk = start - offset;
+	struct sk_buff *frag_iter;
+	int elt = 0;
+
+	if (unlikely(recursion_level >= 24))
+		return -EMSGSIZE;
+
+	if (chunk > 0) {
+		if (chunk > len)
+			chunk = len;
+		elt++;
+		len -= chunk;
+		if (len == 0)
+			return elt;
+		offset += chunk;
+	}
+
+	for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
+		int end;
+
+		WARN_ON(start > offset + len);
+
+		end = start + skb_frag_size(&skb_shinfo(skb)->frags[i]);
+		chunk = end - offset;
+		if (chunk > 0) {
+			if (chunk > len)
+				chunk = len;
+			elt++;
+			len -= chunk;
+			if (len == 0)
+				return elt;
+			offset += chunk;
+		}
+		start = end;
+	}
+
+	if (unlikely(skb_has_frag_list(skb))) {
+		skb_walk_frags(skb, frag_iter) {
+			int end, ret;
+
+			WARN_ON(start > offset + len);
+
+			end = start + frag_iter->len;
+			chunk = end - offset;
+			if (chunk > 0) {
+				if (chunk > len)
+					chunk = len;
+				ret = __skb_nsg(frag_iter, offset - start, chunk,
+						recursion_level + 1);
+				if (unlikely(ret < 0))
+					return ret;
+				elt += ret;
+				len -= chunk;
+				if (len == 0)
+					return elt;
+				offset += chunk;
+			}
+			start = end;
+		}
+	}
+	BUG_ON(len);
+	return elt;
+}
+
+/* Return the number of scatterlist elements required to completely map the
+ * skb, or -EMSGSIZE if the recursion depth is exceeded.
+ */
+static int skb_nsg(struct sk_buff *skb, int offset, int len)
+{
+	return __skb_nsg(skb, offset, len, 0);
+}
+
+static int padding_length(struct tls_sw_context_rx *ctx,
+			  struct tls_prot_info *prot, struct sk_buff *skb)
+{
+	struct strp_msg *rxm = strp_msg(skb);
+	int sub = 0;
+
+	/* Determine zero-padding length */
+	if (prot->version == TLS_1_3_VERSION) {
+		char content_type = 0;
+		int err;
+		int back = 17;
+
+		while (content_type == 0) {
+			if (back > rxm->full_len - prot->prepend_size)
+				return -EBADMSG;
+			err = skb_copy_bits(skb,
+					    rxm->offset + rxm->full_len - back,
+					    &content_type, 1);
+			if (err)
+				return err;
+			if (content_type)
+				break;
+			sub++;
+			back++;
+		}
+		ctx->control = content_type;
+	}
+	return sub;
+}
+
+static void tls_decrypt_done(struct crypto_async_request *req, int err)
+{
+	struct aead_request *aead_req = (struct aead_request *)req;
+	struct scatterlist *sgout = aead_req->dst;
+	struct scatterlist *sgin = aead_req->src;
+	struct tls_sw_context_rx *ctx;
+	struct tls_context *tls_ctx;
+	struct tls_prot_info *prot;
+	struct scatterlist *sg;
+	struct sk_buff *skb;
+	unsigned int pages;
+	int pending;
+
+	skb = (struct sk_buff *)req->data;
+	tls_ctx = tls_get_ctx(skb->sk);
+	ctx = tls_sw_ctx_rx(tls_ctx);
+	prot = &tls_ctx->prot_info;
+
+	/* Propagate if there was an err */
+	if (err) {
+		if (err == -EBADMSG)
+			TLS_INC_STATS(sock_net(skb->sk),
+				      LINUX_MIB_TLSDECRYPTERROR);
+		ctx->async_wait.err = err;
+		tls_err_abort(skb->sk, err);
+	} else {
+		struct strp_msg *rxm = strp_msg(skb);
+		int pad;
+
+		pad = padding_length(ctx, prot, skb);
+		if (pad < 0) {
+			ctx->async_wait.err = pad;
+			tls_err_abort(skb->sk, pad);
+		} else {
+			rxm->full_len -= pad;
+			rxm->offset += prot->prepend_size;
+			rxm->full_len -= prot->overhead_size;
+		}
+	}
+
+	/* After using skb->sk to propagate sk through crypto async callback
+	 * we need to NULL it again.
+	 */
+	skb->sk = NULL;
+
+
+	/* Free the destination pages if skb was not decrypted inplace */
+	if (sgout != sgin) {
+		/* Skip the first S/G entry as it points to AAD */
+		for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
+			if (!sg)
+				break;
+			put_page(sg_page(sg));
+		}
+	}
+
+	kfree(aead_req);
+
+	spin_lock_bh(&ctx->decrypt_compl_lock);
+	pending = atomic_dec_return(&ctx->decrypt_pending);
+
+	if (!pending && ctx->async_notify)
+		complete(&ctx->async_wait.completion);
+	spin_unlock_bh(&ctx->decrypt_compl_lock);
+}

 static int tls_do_decryption(struct sock *sk,
+			     struct sk_buff *skb,
 			     struct scatterlist *sgin,
 			     struct scatterlist *sgout,
 			     char *iv_recv,
 			     size_t data_len,
-			     struct aead_request *aead_req)
+			     struct aead_request *aead_req,
+			     bool async)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
 	int ret;

 	aead_request_set_tfm(aead_req, ctx->aead_recv);
-	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
+	aead_request_set_ad(aead_req, prot->aad_size);
 	aead_request_set_crypt(aead_req, sgin, sgout,
-			       data_len + tls_ctx->rx.tag_size,
+			       data_len + prot->tag_size,
 			       (u8 *)iv_recv);
-	aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
-				  crypto_req_done, &ctx->async_wait);

-	ret = crypto_wait_req(crypto_aead_decrypt(aead_req), &ctx->async_wait);
+	if (async) {
+		/* Using skb->sk to push sk through to crypto async callback
+		 * handler. This allows propagating errors up to the socket
+		 * if needed. It _must_ be cleared in the async handler
+		 * before consume_skb is called. We _know_ skb->sk is NULL
+		 * because it is a clone from strparser.
+		 */
+		skb->sk = sk;
+		aead_request_set_callback(aead_req,
+					  CRYPTO_TFM_REQ_MAY_BACKLOG,
+					  tls_decrypt_done, skb);
+		atomic_inc(&ctx->decrypt_pending);
+	} else {
+		aead_request_set_callback(aead_req,
+					  CRYPTO_TFM_REQ_MAY_BACKLOG,
+					  crypto_req_done, &ctx->async_wait);
+	}
+
+	ret = crypto_aead_decrypt(aead_req);
+	if (ret == -EINPROGRESS) {
+		if (async)
+			return ret;
+
+		ret = crypto_wait_req(ret, &ctx->async_wait);
+	}
+
+	if (async)
+		atomic_dec(&ctx->decrypt_pending);
+
 	return ret;
 }

-static void trim_sg(struct sock *sk, struct scatterlist *sg,
-		    int *sg_num_elem, unsigned int *sg_size, int target_size)
-{
-	int i = *sg_num_elem - 1;
-	int trim = *sg_size - target_size;
-
-	if (trim <= 0) {
-		WARN_ON(trim < 0);
-		return;
-	}
-
-	*sg_size = target_size;
-	while (trim >= sg[i].length) {
-		trim -= sg[i].length;
-		sk_mem_uncharge(sk, sg[i].length);
-		put_page(sg_page(&sg[i]));
-		i--;
-
-		if (i < 0)
-			goto out;
-	}
-
-	sg[i].length -= trim;
-	sk_mem_uncharge(sk, trim);
-
-out:
-	*sg_num_elem = i + 1;
-}
-
-static void trim_both_sgl(struct sock *sk, int target_size)
+static void tls_trim_both_msgs(struct sock *sk, int target_size)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct tls_rec *rec = ctx->open_rec;

-	trim_sg(sk, ctx->sg_plaintext_data,
-		&ctx->sg_plaintext_num_elem,
-		&ctx->sg_plaintext_size,
-		target_size);
-
+	sk_msg_trim(sk, &rec->msg_plaintext, target_size);
 	if (target_size > 0)
-		target_size += tls_ctx->tx.overhead_size;
-
-	trim_sg(sk, ctx->sg_encrypted_data,
-		&ctx->sg_encrypted_num_elem,
-		&ctx->sg_encrypted_size,
-		target_size);
+		target_size += prot->overhead_size;
+	sk_msg_trim(sk, &rec->msg_encrypted, target_size);
 }

-static int alloc_encrypted_sg(struct sock *sk, int len)
+static int tls_alloc_encrypted_msg(struct sock *sk, int len)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
-	int rc = 0;
+	struct tls_rec *rec = ctx->open_rec;
+	struct sk_msg *msg_en = &rec->msg_encrypted;

-	rc = sk_alloc_sg(sk, len,
-			 ctx->sg_encrypted_data, 0,
-			 &ctx->sg_encrypted_num_elem,
-			 &ctx->sg_encrypted_size, 0);
-
-	if (rc == -ENOSPC)
-		ctx->sg_encrypted_num_elem = ARRAY_SIZE(ctx->sg_encrypted_data);
-
-	return rc;
+	return sk_msg_alloc(sk, msg_en, len, 0);
 }

-static int alloc_plaintext_sg(struct sock *sk, int len)
+static int tls_clone_plaintext_msg(struct sock *sk, int required)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
+	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct tls_rec *rec = ctx->open_rec;
+	struct sk_msg *msg_pl = &rec->msg_plaintext;
+	struct sk_msg *msg_en = &rec->msg_encrypted;
+	int skip, len;
+
+	/* We add page references worth len bytes from encrypted sg
+	 * at the end of plaintext sg. It is guaranteed that msg_en
+	 * has enough required room (ensured by caller).
+	 */
+	len = required - msg_pl->sg.size;
+
+	/* Skip initial bytes in msg_en's data to be able to use
+	 * same offset of both plain and encrypted data.
+	 */
+	skip = prot->prepend_size + msg_pl->sg.size;
+
+	return sk_msg_clone(sk, msg_pl, msg_en, skip, len);
+}
+
+static struct tls_rec *tls_get_rec(struct sock *sk)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
+	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct sk_msg *msg_pl, *msg_en;
+	struct tls_rec *rec;
+	int mem_size;
+
+	mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send);
+
+	rec = kzalloc(mem_size, sk->sk_allocation);
+	if (!rec)
+		return NULL;
+
+	msg_pl = &rec->msg_plaintext;
+	msg_en = &rec->msg_encrypted;
+
+	sk_msg_init(msg_pl);
+	sk_msg_init(msg_en);
+
+	sg_init_table(rec->sg_aead_in, 2);
+	sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, prot->aad_size);
+	sg_unmark_end(&rec->sg_aead_in[1]);
+
+	sg_init_table(rec->sg_aead_out, 2);
+	sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, prot->aad_size);
+	sg_unmark_end(&rec->sg_aead_out[1]);
+
+	return rec;
+}
+
+static void tls_free_rec(struct sock *sk, struct tls_rec *rec)
+{
+	sk_msg_free(sk, &rec->msg_encrypted);
+	sk_msg_free(sk, &rec->msg_plaintext);
+	kfree(rec);
+}
+
+static void tls_free_open_rec(struct sock *sk)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
-	int rc = 0;
+	struct tls_rec *rec = ctx->open_rec;

-	rc = sk_alloc_sg(sk, len, ctx->sg_plaintext_data, 0,
-			 &ctx->sg_plaintext_num_elem, &ctx->sg_plaintext_size,
-			 tls_ctx->pending_open_record_frags);
-
-	if (rc == -ENOSPC)
-		ctx->sg_plaintext_num_elem = ARRAY_SIZE(ctx->sg_plaintext_data);
-
-	return rc;
-}
-
-static void free_sg(struct sock *sk, struct scatterlist *sg,
-		    int *sg_num_elem, unsigned int *sg_size)
-{
-	int i, n = *sg_num_elem;
-
-	for (i = 0; i < n; ++i) {
-		sk_mem_uncharge(sk, sg[i].length);
-		put_page(sg_page(&sg[i]));
+	if (rec) {
+		tls_free_rec(sk, rec);
+		ctx->open_rec = NULL;
 	}
-	*sg_num_elem = 0;
-	*sg_size = 0;
 }

-static void tls_free_both_sg(struct sock *sk)
+int tls_tx_records(struct sock *sk, int flags)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct tls_rec *rec, *tmp;
+	struct sk_msg *msg_en;
+	int tx_flags, rc = 0;

-	free_sg(sk, ctx->sg_encrypted_data, &ctx->sg_encrypted_num_elem,
-		&ctx->sg_encrypted_size);
+	if (tls_is_partially_sent_record(tls_ctx)) {
+		rec = list_first_entry(&ctx->tx_list,
+				       struct tls_rec, list);

-	free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem,
-		&ctx->sg_plaintext_size);
+		if (flags == -1)
+			tx_flags = rec->tx_flags;
+		else
+			tx_flags = flags;
+
+		rc = tls_push_partial_record(sk, tls_ctx, tx_flags);
+		if (rc)
+			goto tx_err;
+
+		/* Full record has been transmitted.
+		 * Remove the head of tx_list
+		 */
+		list_del(&rec->list);
+		sk_msg_free(sk, &rec->msg_plaintext);
+		kfree(rec);
+	}
+
+	/* Tx all ready records */
+	list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
+		if (READ_ONCE(rec->tx_ready)) {
+			if (flags == -1)
+				tx_flags = rec->tx_flags;
+			else
+				tx_flags = flags;
+
+			msg_en = &rec->msg_encrypted;
+			rc = tls_push_sg(sk, tls_ctx,
+					 &msg_en->sg.data[msg_en->sg.curr],
+					 0, tx_flags);
+			if (rc)
+				goto tx_err;
+
+			list_del(&rec->list);
+			sk_msg_free(sk, &rec->msg_plaintext);
+			kfree(rec);
+		} else {
+			break;
+		}
+	}
+
+tx_err:
+	if (rc < 0 && rc != -EAGAIN)
+		tls_err_abort(sk, -EBADMSG);
+
+	return rc;
 }

-static int tls_do_encryption(struct tls_context *tls_ctx,
+static void tls_encrypt_done(struct crypto_async_request *req, int err)
+{
+	struct aead_request *aead_req = (struct aead_request *)req;
+	struct sock *sk = req->data;
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
+	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct scatterlist *sge;
+	struct sk_msg *msg_en;
+	struct tls_rec *rec;
+	bool ready = false;
+	int pending;
+
+	rec = container_of(aead_req, struct tls_rec, aead_req);
+	msg_en = &rec->msg_encrypted;
+
+	sge = sk_msg_elem(msg_en, msg_en->sg.curr);
+	sge->offset -= prot->prepend_size;
+	sge->length += prot->prepend_size;
+
+	/* Check if error is previously set on socket */
+	if (err || sk->sk_err) {
+		rec = NULL;
+
+		/* If err is already set on socket, return the same code */
+		if (sk->sk_err) {
+			ctx->async_wait.err = -sk->sk_err;
+		} else {
+			ctx->async_wait.err = err;
+			tls_err_abort(sk, err);
+		}
+	}
+
+	if (rec) {
+		struct tls_rec *first_rec;
+
+		/* Mark the record as ready for transmission */
+		smp_store_mb(rec->tx_ready, true);
+
+		/* If received record is at head of tx_list, schedule tx */
+		first_rec = list_first_entry(&ctx->tx_list,
+					     struct tls_rec, list);
+		if (rec == first_rec)
+			ready = true;
+	}
+
+	spin_lock_bh(&ctx->encrypt_compl_lock);
+	pending = atomic_dec_return(&ctx->encrypt_pending);
+
+	if (!pending && ctx->async_notify)
+		complete(&ctx->async_wait.completion);
+	spin_unlock_bh(&ctx->encrypt_compl_lock);
+
+	if (!ready)
+		return;
+
+	/* Schedule the transmission */
+	if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
+		schedule_delayed_work(&ctx->tx_work.work, 1);
+}
+
+static int tls_do_encryption(struct sock *sk,
+			     struct tls_context *tls_ctx,
 			     struct tls_sw_context_tx *ctx,
 			     struct aead_request *aead_req,
-			     size_t data_len)
+			     size_t data_len, u32 start)
 {
-	int rc;
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
+	struct tls_rec *rec = ctx->open_rec;
+	struct sk_msg *msg_en = &rec->msg_encrypted;
+	struct scatterlist *sge = sk_msg_elem(msg_en, start);
+	int rc, iv_offset = 0;

-	ctx->sg_encrypted_data[0].offset += tls_ctx->tx.prepend_size;
-	ctx->sg_encrypted_data[0].length -= tls_ctx->tx.prepend_size;
+	/* For CCM based ciphers, first byte of IV is a constant */
+	if (prot->cipher_type == TLS_CIPHER_AES_CCM_128) {
+		rec->iv_data[0] = TLS_AES_CCM_IV_B0_BYTE;
+		iv_offset = 1;
+	}
+
+	memcpy(&rec->iv_data[iv_offset], tls_ctx->tx.iv,
+	       prot->iv_size + prot->salt_size);
+
+	xor_iv_with_seq(prot->version, rec->iv_data + iv_offset, tls_ctx->tx.rec_seq);
+
+	sge->offset += prot->prepend_size;
+	sge->length -= prot->prepend_size;
+
+	msg_en->sg.curr = start;

 	aead_request_set_tfm(aead_req, ctx->aead_send);
-	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
-	aead_request_set_crypt(aead_req, ctx->sg_aead_in, ctx->sg_aead_out,
-			       data_len, tls_ctx->tx.iv);
+	aead_request_set_ad(aead_req, prot->aad_size);
+	aead_request_set_crypt(aead_req, rec->sg_aead_in,
+			       rec->sg_aead_out,
+			       data_len, rec->iv_data);

 	aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
-				  crypto_req_done, &ctx->async_wait);
+				  tls_encrypt_done, sk);

-	rc = crypto_wait_req(crypto_aead_encrypt(aead_req), &ctx->async_wait);
+	/* Add the record in tx_list */
+	list_add_tail((struct list_head *)&rec->list, &ctx->tx_list);
+	atomic_inc(&ctx->encrypt_pending);

-	ctx->sg_encrypted_data[0].offset -= tls_ctx->tx.prepend_size;
-	ctx->sg_encrypted_data[0].length += tls_ctx->tx.prepend_size;
+	rc = crypto_aead_encrypt(aead_req);
+	if (!rc || rc != -EINPROGRESS) {
+		atomic_dec(&ctx->encrypt_pending);
+		sge->offset -= prot->prepend_size;
+		sge->length += prot->prepend_size;
+	}

+	if (!rc) {
+		WRITE_ONCE(rec->tx_ready, true);
+	} else if (rc != -EINPROGRESS) {
+		list_del(&rec->list);
+		return rc;
+	}
+
+	/* Unhook the record from context if encryption is not failure */
+	ctx->open_rec = NULL;
+	tls_advance_record_sn(sk, prot, &tls_ctx->tx);
 	return rc;
+}
+
+static int tls_split_open_record(struct sock *sk, struct tls_rec *from,
+				 struct tls_rec **to, struct sk_msg *msg_opl,
+				 struct sk_msg *msg_oen, u32 split_point,
+				 u32 tx_overhead_size, u32 *orig_end)
+{
+	u32 i, j, bytes = 0, apply = msg_opl->apply_bytes;
+	struct scatterlist *sge, *osge, *nsge;
+	u32 orig_size = msg_opl->sg.size;
+	struct scatterlist tmp = { };
+	struct sk_msg *msg_npl;
+	struct tls_rec *new;
+	int ret;
+
+	new = tls_get_rec(sk);
+	if (!new)
+		return -ENOMEM;
+	ret = sk_msg_alloc(sk, &new->msg_encrypted, msg_opl->sg.size +
+			   tx_overhead_size, 0);
+	if (ret < 0) {
+		tls_free_rec(sk, new);
+		return ret;
+	}
+
+	*orig_end = msg_opl->sg.end;
+	i = msg_opl->sg.start;
+	sge = sk_msg_elem(msg_opl, i);
+	while (apply && sge->length) {
+		if (sge->length > apply) {
+			u32 len = sge->length - apply;
+
+			get_page(sg_page(sge));
+			sg_set_page(&tmp, sg_page(sge), len,
+				    sge->offset + apply);
+			sge->length = apply;
+			bytes += apply;
+			apply = 0;
+		} else {
+			apply -= sge->length;
+			bytes += sge->length;
+		}
+
+		sk_msg_iter_var_next(i);
+		if (i == msg_opl->sg.end)
+			break;
+		sge = sk_msg_elem(msg_opl, i);
+	}
+
+	msg_opl->sg.end = i;
+	msg_opl->sg.curr = i;
+	msg_opl->sg.copybreak = 0;
+	msg_opl->apply_bytes = 0;
+	msg_opl->sg.size = bytes;
+
+	msg_npl = &new->msg_plaintext;
+	msg_npl->apply_bytes = apply;
+	msg_npl->sg.size = orig_size - bytes;
+
+	j = msg_npl->sg.start;
+	nsge = sk_msg_elem(msg_npl, j);
+	if (tmp.length) {
+		memcpy(nsge, &tmp, sizeof(*nsge));
+		sk_msg_iter_var_next(j);
+		nsge = sk_msg_elem(msg_npl, j);
+	}
+
+	osge = sk_msg_elem(msg_opl, i);
+	while (osge->length) {
+		memcpy(nsge, osge, sizeof(*nsge));
+		sg_unmark_end(nsge);
+		sk_msg_iter_var_next(i);
+		sk_msg_iter_var_next(j);
+		if (i == *orig_end)
+			break;
+		osge = sk_msg_elem(msg_opl, i);
+		nsge = sk_msg_elem(msg_npl, j);
+	}
+
+	msg_npl->sg.end = j;
+	msg_npl->sg.curr = j;
+	msg_npl->sg.copybreak = 0;
+
+	*to = new;
+	return 0;
+}
+
+static void tls_merge_open_record(struct sock *sk, struct tls_rec *to,
+				  struct tls_rec *from, u32 orig_end)
+{
+	struct sk_msg *msg_npl = &from->msg_plaintext;
+	struct sk_msg *msg_opl = &to->msg_plaintext;
+	struct scatterlist *osge, *nsge;
+	u32 i, j;
+
+	i = msg_opl->sg.end;
+	sk_msg_iter_var_prev(i);
+	j = msg_npl->sg.start;
+
+	osge = sk_msg_elem(msg_opl, i);
+	nsge = sk_msg_elem(msg_npl, j);
+
+	if (sg_page(osge) == sg_page(nsge) &&
+	    osge->offset + osge->length == nsge->offset) {
+		osge->length += nsge->length;
+		put_page(sg_page(nsge));
+	}
+
+	msg_opl->sg.end = orig_end;
+	msg_opl->sg.curr = orig_end;
+	msg_opl->sg.copybreak = 0;
+	msg_opl->apply_bytes = msg_opl->sg.size + msg_npl->sg.size;
+	msg_opl->sg.size += msg_npl->sg.size;
+
+	sk_msg_free(sk, &to->msg_encrypted);
+	sk_msg_xfer_full(&to->msg_encrypted, &from->msg_encrypted);
+
+	kfree(from);
 }

 static int tls_push_record(struct sock *sk, int flags,
 			   unsigned char record_type)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct tls_rec *rec = ctx->open_rec, *tmp = NULL;
+	u32 i, split_point, orig_end;
+	struct sk_msg *msg_pl, *msg_en;
 	struct aead_request *req;
+	bool split;
 	int rc;

-	req = aead_request_alloc(ctx->aead_send, sk->sk_allocation);
-	if (!req)
-		return -ENOMEM;
+	if (!rec)
+		return 0;

-	sg_mark_end(ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem - 1);
-	sg_mark_end(ctx->sg_encrypted_data + ctx->sg_encrypted_num_elem - 1);
+	msg_pl = &rec->msg_plaintext;
+	msg_en = &rec->msg_encrypted;

-	tls_make_aad(ctx->aad_space, ctx->sg_plaintext_size,
-		     tls_ctx->tx.rec_seq, tls_ctx->tx.rec_seq_size,
-		     record_type);
-
-	tls_fill_prepend(tls_ctx,
-			 page_address(sg_page(&ctx->sg_encrypted_data[0])) +
-			 ctx->sg_encrypted_data[0].offset,
-			 ctx->sg_plaintext_size, record_type);
-
-	tls_ctx->pending_open_record_frags = 0;
-	set_bit(TLS_PENDING_CLOSED_RECORD, &tls_ctx->flags);
-
-	rc = tls_do_encryption(tls_ctx, ctx, req, ctx->sg_plaintext_size);
-	if (rc < 0) {
-		/* If we are called from write_space and
-		 * we fail, we need to set this SOCK_NOSPACE
-		 * to trigger another write_space in the future.
+	split_point = msg_pl->apply_bytes;
+	split = split_point && split_point < msg_pl->sg.size;
+	if (unlikely((!split &&
+		      msg_pl->sg.size +
+		      prot->overhead_size > msg_en->sg.size) ||
+		     (split &&
+		      split_point +
+		      prot->overhead_size > msg_en->sg.size))) {
+		split = true;
+		split_point = msg_en->sg.size;
+	}
+	if (split) {
+		rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en,
+					   split_point, prot->overhead_size,
+					   &orig_end);
+		if (rc < 0)
+			return rc;
+		/* This can happen if above tls_split_open_record allocates
+		 * a single large encryption buffer instead of two smaller
+		 * ones. In this case adjust pointers and continue without
+		 * split.
 		 */
-		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
-		goto out_req;
+		if (!msg_pl->sg.size) {
+			tls_merge_open_record(sk, rec, tmp, orig_end);
+			msg_pl = &rec->msg_plaintext;
+			msg_en = &rec->msg_encrypted;
+			split = false;
+		}
+		sk_msg_trim(sk, msg_en, msg_pl->sg.size +
+			    prot->overhead_size);
 	}

-	free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem,
-		&ctx->sg_plaintext_size);
+	rec->tx_flags = flags;
+	req = &rec->aead_req;

-	ctx->sg_encrypted_num_elem = 0;
-	ctx->sg_encrypted_size = 0;
+	i = msg_pl->sg.end;
+	sk_msg_iter_var_prev(i);

-	/* Only pass through MSG_DONTWAIT and MSG_NOSIGNAL flags */
-	rc = tls_push_sg(sk, tls_ctx, ctx->sg_encrypted_data, 0, flags);
-	if (rc < 0 && rc != -EAGAIN)
-		tls_err_abort(sk, EBADMSG);
+	rec->content_type = record_type;
+	if (prot->version == TLS_1_3_VERSION) {
+		/* Add content type to end of message. No padding added */
+		sg_set_buf(&rec->sg_content_type, &rec->content_type, 1);
+		sg_mark_end(&rec->sg_content_type);
+		sg_chain(msg_pl->sg.data, msg_pl->sg.end + 1,
+			 &rec->sg_content_type);
+	} else {
+		sg_mark_end(sk_msg_elem(msg_pl, i));
+	}

-	tls_advance_record_sn(sk, &tls_ctx->tx);
-out_req:
-	aead_request_free(req);
-	return rc;
+	if (msg_pl->sg.end < msg_pl->sg.start) {
+		sg_chain(&msg_pl->sg.data[msg_pl->sg.start],
+			 MAX_SKB_FRAGS - msg_pl->sg.start + 1,
+			 msg_pl->sg.data);
+	}
+
+	i = msg_pl->sg.start;
+	sg_chain(rec->sg_aead_in, 2, &msg_pl->sg.data[i]);
+
+	i = msg_en->sg.end;
+	sk_msg_iter_var_prev(i);
+	sg_mark_end(sk_msg_elem(msg_en, i));
+
+	i = msg_en->sg.start;
+	sg_chain(rec->sg_aead_out, 2, &msg_en->sg.data[i]);
+
+	tls_make_aad(rec->aad_space, msg_pl->sg.size + prot->tail_size,
+		     tls_ctx->tx.rec_seq, prot->rec_seq_size,
+		     record_type, prot->version);
+
+	tls_fill_prepend(tls_ctx,
+			 page_address(sg_page(&msg_en->sg.data[i])) +
+			 msg_en->sg.data[i].offset,
+			 msg_pl->sg.size + prot->tail_size,
+			 record_type, prot->version);
+
+	tls_ctx->pending_open_record_frags = false;
+
+	rc = tls_do_encryption(sk, tls_ctx, ctx, req,
+			       msg_pl->sg.size + prot->tail_size, i);
+	if (rc < 0) {
+		if (rc != -EINPROGRESS) {
+			tls_err_abort(sk, -EBADMSG);
+			if (split) {
+				tls_ctx->pending_open_record_frags = true;
+				tls_merge_open_record(sk, rec, tmp, orig_end);
+			}
+		}
+		ctx->async_capable = 1;
+		return rc;
+	} else if (split) {
+		msg_pl = &tmp->msg_plaintext;
+		msg_en = &tmp->msg_encrypted;
+		sk_msg_trim(sk, msg_en, msg_pl->sg.size + prot->overhead_size);
+		tls_ctx->pending_open_record_frags = true;
+		ctx->open_rec = tmp;
+	}
+
+	return tls_tx_records(sk, flags);
+}
+
+static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
+			       bool full_record, u8 record_type,
+			       ssize_t *copied, int flags)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct sk_msg msg_redir = { };
+	struct sk_psock *psock;
+	struct sock *sk_redir;
+	struct tls_rec *rec;
+	bool enospc, policy;
+	int err = 0, send;
+	u32 delta = 0;
+
+	policy = !(flags & MSG_SENDPAGE_NOPOLICY);
+	psock = sk_psock_get(sk);
+	if (!psock || !policy) {
+		err = tls_push_record(sk, flags, record_type);
+		if (err && sk->sk_err == EBADMSG) {
+			*copied -= sk_msg_free(sk, msg);
+			tls_free_open_rec(sk);
+			err = -sk->sk_err;
+		}
+		if (psock)
+			sk_psock_put(sk, psock);
+		return err;
+	}
+more_data:
+	enospc = sk_msg_full(msg);
+	if (psock->eval == __SK_NONE) {
+		delta = msg->sg.size;
+		psock->eval = sk_psock_msg_verdict(sk, psock, msg);
+		delta -= msg->sg.size;
+	}
+	if (msg->cork_bytes && msg->cork_bytes > msg->sg.size &&
+	    !enospc && !full_record) {
+		err = -ENOSPC;
+		goto out_err;
+	}
+	msg->cork_bytes = 0;
+	send = msg->sg.size;
+	if (msg->apply_bytes && msg->apply_bytes < send)
+		send = msg->apply_bytes;
+
+	switch (psock->eval) {
+	case __SK_PASS:
+		err = tls_push_record(sk, flags, record_type);
+		if (err && sk->sk_err == EBADMSG) {
+			*copied -= sk_msg_free(sk, msg);
+			tls_free_open_rec(sk);
+			err = -sk->sk_err;
+			goto out_err;
+		}
+		break;
+	case __SK_REDIRECT:
+		sk_redir = psock->sk_redir;
+		memcpy(&msg_redir, msg, sizeof(*msg));
+		if (msg->apply_bytes < send)
+			msg->apply_bytes = 0;
+		else
+			msg->apply_bytes -= send;
+		sk_msg_return_zero(sk, msg, send);
+		msg->sg.size -= send;
+		release_sock(sk);
+		err = tcp_bpf_sendmsg_redir(sk_redir, &msg_redir, send, flags);
+		lock_sock(sk);
+		if (err < 0) {
+			*copied -= sk_msg_free_nocharge(sk, &msg_redir);
+			msg->sg.size = 0;
+		}
+		if (msg->sg.size == 0)
+			tls_free_open_rec(sk);
+		break;
+	case __SK_DROP:
+	default:
+		sk_msg_free_partial(sk, msg, send);
+		if (msg->apply_bytes < send)
+			msg->apply_bytes = 0;
+		else
+			msg->apply_bytes -= send;
+		if (msg->sg.size == 0)
+			tls_free_open_rec(sk);
+		*copied -= (send + delta);
+		err = -EACCES;
+	}
+
+	if (likely(!err)) {
+		bool reset_eval = !ctx->open_rec;
+
+		rec = ctx->open_rec;
+		if (rec) {
+			msg = &rec->msg_plaintext;
+			if (!msg->apply_bytes)
+				reset_eval = true;
+		}
+		if (reset_eval) {
+			psock->eval = __SK_NONE;
+			if (psock->sk_redir) {
+				sock_put(psock->sk_redir);
+				psock->sk_redir = NULL;
+			}
+		}
+		if (rec)
+			goto more_data;
+	}
+ out_err:
+	sk_psock_put(sk, psock);
+	return err;
 }

 static int tls_sw_push_pending_record(struct sock *sk, int flags)
 {
-	return tls_push_record(sk, flags, TLS_RECORD_TYPE_DATA);
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct tls_rec *rec = ctx->open_rec;
+	struct sk_msg *msg_pl;
+	size_t copied;
+
+	if (!rec)
+		return 0;
+
+	msg_pl = &rec->msg_plaintext;
+	copied = msg_pl->sg.size;
+	if (!copied)
+		return 0;
+
+	return bpf_exec_tx_verdict(msg_pl, sk, true, TLS_RECORD_TYPE_DATA,
+				   &copied, flags);
 }

-static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
-			      int length, int *pages_used,
-			      unsigned int *size_used,
-			      struct scatterlist *to, int to_max_pages,
-			      bool charge)
+int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 {
-	struct page *pages[MAX_SKB_FRAGS];
+	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
+	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	bool async_capable = ctx->async_capable;
+	unsigned char record_type = TLS_RECORD_TYPE_DATA;
+	bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
+	bool eor = !(msg->msg_flags & MSG_MORE);
+	size_t try_to_copy;
+	ssize_t copied = 0;
+	struct sk_msg *msg_pl, *msg_en;
+	struct tls_rec *rec;
+	int required_size;
+	int num_async = 0;
+	bool full_record;
+	int record_room;
+	int num_zc = 0;
+	int orig_size;
+	int ret = 0;
+	int pending;

-	size_t offset;
-	ssize_t copied, use;
-	int i = 0;
+	if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
+			       MSG_CMSG_COMPAT))
+		return -EOPNOTSUPP;
+
+	mutex_lock(&tls_ctx->tx_lock);
+	lock_sock(sk);
+
+	if (unlikely(msg->msg_controllen)) {
+		ret = tls_proccess_cmsg(sk, msg, &record_type);
+		if (ret) {
+			if (ret == -EINPROGRESS)
+				num_async++;
+			else if (ret != -EAGAIN)
+				goto send_end;
+		}
+	}
+
+	while (msg_data_left(msg)) {
+		if (sk->sk_err) {
+			ret = -sk->sk_err;
+			goto send_end;
+		}
+
+		if (ctx->open_rec)
+			rec = ctx->open_rec;
+		else
+			rec = ctx->open_rec = tls_get_rec(sk);
+		if (!rec) {
+			ret = -ENOMEM;
+			goto send_end;
+		}
+
+		msg_pl = &rec->msg_plaintext;
+		msg_en = &rec->msg_encrypted;
+
+		orig_size = msg_pl->sg.size;
+		full_record = false;
+		try_to_copy = msg_data_left(msg);
+		record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
+		if (try_to_copy >= record_room) {
+			try_to_copy = record_room;
+			full_record = true;
+		}
+
+		required_size = msg_pl->sg.size + try_to_copy +
+				prot->overhead_size;
+
+		if (!sk_stream_memory_free(sk))
+			goto wait_for_sndbuf;
+
+alloc_encrypted:
+		ret = tls_alloc_encrypted_msg(sk, required_size);
+		if (ret) {
+			if (ret != -ENOSPC)
+				goto wait_for_memory;
+
+			/* Adjust try_to_copy according to the amount that was
+			 * actually allocated. The difference is due
+			 * to max sg elements limit
+			 */
+			try_to_copy -= required_size - msg_en->sg.size;
+			full_record = true;
+		}
+
+		if (!is_kvec && (full_record || eor) && !async_capable) {
+			u32 first = msg_pl->sg.end;
+
+			ret = sk_msg_zerocopy_from_iter(sk, &msg->msg_iter,
+							msg_pl, try_to_copy);
+			if (ret)
+				goto fallback_to_reg_send;
+
+			num_zc++;
+			copied += try_to_copy;
+
+			sk_msg_sg_copy_set(msg_pl, first);
+			ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
+						  record_type, &copied,
+						  msg->msg_flags);
+			if (ret) {
+				if (ret == -EINPROGRESS)
+					num_async++;
+				else if (ret == -ENOMEM)
+					goto wait_for_memory;
+				else if (ctx->open_rec && ret == -ENOSPC)
+					goto rollback_iter;
+				else if (ret != -EAGAIN)
+					goto send_end;
+			}
+			continue;
+rollback_iter:
+			copied -= try_to_copy;
+			sk_msg_sg_copy_clear(msg_pl, first);
+			iov_iter_revert(&msg->msg_iter,
+					msg_pl->sg.size - orig_size);
+fallback_to_reg_send:
+			sk_msg_trim(sk, msg_pl, orig_size);
+		}
+
+		required_size = msg_pl->sg.size + try_to_copy;
+
+		ret = tls_clone_plaintext_msg(sk, required_size);
+		if (ret) {
+			if (ret != -ENOSPC)
+				goto send_end;
+
+			/* Adjust try_to_copy according to the amount that was
+			 * actually allocated. The difference is due
+			 * to max sg elements limit
+			 */
+			try_to_copy -= required_size - msg_pl->sg.size;
+			full_record = true;
+			sk_msg_trim(sk, msg_en,
+				    msg_pl->sg.size + prot->overhead_size);
+		}
+
+		if (try_to_copy) {
+			ret = sk_msg_memcopy_from_iter(sk, &msg->msg_iter,
+						       msg_pl, try_to_copy);
+			if (ret < 0)
+				goto trim_sgl;
+		}
+
+		/* Open records defined only if successfully copied, otherwise
+		 * we would trim the sg but not reset the open record frags.
+		 */
+		tls_ctx->pending_open_record_frags = true;
+		copied += try_to_copy;
+		if (full_record || eor) {
+			ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
+						  record_type, &copied,
+						  msg->msg_flags);
+			if (ret) {
+				if (ret == -EINPROGRESS)
+					num_async++;
+				else if (ret == -ENOMEM)
+					goto wait_for_memory;
+				else if (ret != -EAGAIN) {
+					if (ret == -ENOSPC)
+						ret = 0;
+					goto send_end;
+				}
+			}
+		}
+
+		continue;
+
+wait_for_sndbuf:
+		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
+wait_for_memory:
+		ret = sk_stream_wait_memory(sk, &timeo);
+		if (ret) {
+trim_sgl:
+			if (ctx->open_rec)
+				tls_trim_both_msgs(sk, orig_size);
+			goto send_end;
+		}
+
+		if (ctx->open_rec && msg_en->sg.size < required_size)
+			goto alloc_encrypted;
+	}
+
+	if (!num_async) {
+		goto send_end;
+	} else if (num_zc) {
+		/* Wait for pending encryptions to get completed */
+		spin_lock_bh(&ctx->encrypt_compl_lock);
+		ctx->async_notify = true;
+
+		pending = atomic_read(&ctx->encrypt_pending);
+		spin_unlock_bh(&ctx->encrypt_compl_lock);
+		if (pending)
+			crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
+		else
+			reinit_completion(&ctx->async_wait.completion);
+
+		/* There can be no concurrent accesses, since we have no
+		 * pending encrypt operations
+		 */
+		WRITE_ONCE(ctx->async_notify, false);
+
+		if (ctx->async_wait.err) {
+			ret = ctx->async_wait.err;
+			copied = 0;
+		}
+	}
+
+	/* Transmit if any encryptions have completed */
+	if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
+		cancel_delayed_work(&ctx->tx_work.work);
+		tls_tx_records(sk, msg->msg_flags);
+	}
+
+send_end:
+	ret = sk_stream_error(sk, msg->msg_flags, ret);
+
+	release_sock(sk);
+	mutex_unlock(&tls_ctx->tx_lock);
+	return copied > 0 ? copied : ret;
+}
+
+static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
+			      int offset, size_t size, int flags)
+{
+	long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
+	unsigned char record_type = TLS_RECORD_TYPE_DATA;
+	struct sk_msg *msg_pl;
+	struct tls_rec *rec;
+	int num_async = 0;
+	ssize_t copied = 0;
+	bool full_record;
+	int record_room;
+	int ret = 0;
+	bool eor;
+
+	eor = !(flags & MSG_SENDPAGE_NOTLAST);
+	sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
+
+	/* Call the sk_stream functions to manage the sndbuf mem. */
+	while (size > 0) {
+		size_t copy, required_size;
+
+		if (sk->sk_err) {
+			ret = -sk->sk_err;
+			goto sendpage_end;
+		}
+
+		if (ctx->open_rec)
+			rec = ctx->open_rec;
+		else
+			rec = ctx->open_rec = tls_get_rec(sk);
+		if (!rec) {
+			ret = -ENOMEM;
+			goto sendpage_end;
+		}
+
+		msg_pl = &rec->msg_plaintext;
+
+		full_record = false;
+		record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
+		copy = size;
+		if (copy >= record_room) {
+			copy = record_room;
+			full_record = true;
+		}
+
+		required_size = msg_pl->sg.size + copy + prot->overhead_size;
+
+		if (!sk_stream_memory_free(sk))
+			goto wait_for_sndbuf;
+alloc_payload:
+		ret = tls_alloc_encrypted_msg(sk, required_size);
+		if (ret) {
+			if (ret != -ENOSPC)
+				goto wait_for_memory;
+
+			/* Adjust copy according to the amount that was
+			 * actually allocated. The difference is due
+			 * to max sg elements limit
+			 */
+			copy -= required_size - msg_pl->sg.size;
+			full_record = true;
+		}
+
+		sk_msg_page_add(msg_pl, page, copy, offset);
+		sk_mem_charge(sk, copy);
+
+		offset += copy;
+		size -= copy;
+		copied += copy;
+
+		tls_ctx->pending_open_record_frags = true;
+		if (full_record || eor || sk_msg_full(msg_pl)) {
+			ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
+						  record_type, &copied, flags);
+			if (ret) {
+				if (ret == -EINPROGRESS)
+					num_async++;
+				else if (ret == -ENOMEM)
+					goto wait_for_memory;
+				else if (ret != -EAGAIN) {
+					if (ret == -ENOSPC)
+						ret = 0;
+					goto sendpage_end;
+				}
+			}
+		}
+		continue;
+wait_for_sndbuf:
+		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
+wait_for_memory:
+		ret = sk_stream_wait_memory(sk, &timeo);
+		if (ret) {
+			if (ctx->open_rec)
+				tls_trim_both_msgs(sk, msg_pl->sg.size);
+			goto sendpage_end;
+		}
+
+		if (ctx->open_rec)
+			goto alloc_payload;
+	}
+
+	if (num_async) {
+		/* Transmit if any encryptions have completed */
+		if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
+			cancel_delayed_work(&ctx->tx_work.work);
+			tls_tx_records(sk, flags);
+		}
+	}
+sendpage_end:
+	ret = sk_stream_error(sk, flags, ret);
+	return copied > 0 ? copied : ret;
+}
+
+int tls_sw_sendpage_locked(struct sock *sk, struct page *page,
+			   int offset, size_t size, int flags)
+{
+	if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
+		      MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY |
+		      MSG_NO_SHARED_FRAGS))
+		return -EOPNOTSUPP;
+
+	return tls_sw_do_sendpage(sk, page, offset, size, flags);
+}
+
+int tls_sw_sendpage(struct sock *sk, struct page *page,
+		    int offset, size_t size, int flags)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	int ret;
+
+	if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
+		      MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY))
+		return -EOPNOTSUPP;
+
+	mutex_lock(&tls_ctx->tx_lock);
+	lock_sock(sk);
+	ret = tls_sw_do_sendpage(sk, page, offset, size, flags);
+	release_sock(sk);
+	mutex_unlock(&tls_ctx->tx_lock);
+	return ret;
+}
+
+static struct sk_buff *tls_wait_data(struct sock *sk, struct sk_psock *psock,
+				     bool nonblock, long timeo, int *err)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+	struct sk_buff *skb;
+	DEFINE_WAIT_FUNC(wait, woken_wake_function);
+
+	while (!(skb = ctx->recv_pkt) && sk_psock_queue_empty(psock)) {
+		if (sk->sk_err) {
+			*err = sock_error(sk);
+			return NULL;
+		}
+
+		if (!skb_queue_empty(&sk->sk_receive_queue)) {
+			__strp_unpause(&ctx->strp);
+			if (ctx->recv_pkt)
+				return ctx->recv_pkt;
+		}
+
+		if (sk->sk_shutdown & RCV_SHUTDOWN)
+			return NULL;
+
+		if (sock_flag(sk, SOCK_DONE))
+			return NULL;
+
+		if (nonblock || !timeo) {
+			*err = -EAGAIN;
+			return NULL;
+		}
+
+		add_wait_queue(sk_sleep(sk), &wait);
+		sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
+		sk_wait_event(sk, &timeo,
+			      ctx->recv_pkt != skb ||
+			      !sk_psock_queue_empty(psock),
+			      &wait);
+		sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
+		remove_wait_queue(sk_sleep(sk), &wait);
+
+		/* Handle signals */
+		if (signal_pending(current)) {
+			*err = sock_intr_errno(timeo);
+			return NULL;
+		}
+	}
+
+	return skb;
+}
+
+static int tls_setup_from_iter(struct sock *sk, struct iov_iter *from,
+			       int length, int *pages_used,
+			       unsigned int *size_used,
+			       struct scatterlist *to,
+			       int to_max_pages)
+{
+	int rc = 0, i = 0, num_elem = *pages_used, maxpages;
+	struct page *pages[MAX_SKB_FRAGS];
 	unsigned int size = *size_used;
-	int num_elem = *pages_used;
-	int rc = 0;
-	int maxpages;
+	ssize_t copied, use;
+	size_t offset;

 	while (length > 0) {
 		i = 0;
@@ -298,17 +1378,15 @@
 			sg_set_page(&to[num_elem],
 				    pages[i], use, offset);
 			sg_unmark_end(&to[num_elem]);
-			if (charge)
-				sk_mem_charge(sk, use);
+			/* We do not uncharge memory from this API */

 			offset = 0;
 			copied -= use;

-			++i;
-			++num_elem;
+			i++;
+			num_elem++;
 		}
 	}
-
 	/* Mark the end in the last sg entry if newly added */
 	if (num_elem > *pages_used)
 		sg_mark_end(&to[num_elem - 1]);
@@ -319,348 +1397,6 @@
 	*pages_used = num_elem;

 	return rc;
-}
-
-static int memcopy_from_iter(struct sock *sk, struct iov_iter *from,
-			     int bytes)
-{
-	struct tls_context *tls_ctx = tls_get_ctx(sk);
-	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
-	struct scatterlist *sg = ctx->sg_plaintext_data;
-	int copy, i, rc = 0;
-
-	for (i = tls_ctx->pending_open_record_frags;
-	     i < ctx->sg_plaintext_num_elem; ++i) {
-		copy = sg[i].length;
-		if (copy_from_iter(
-				page_address(sg_page(&sg[i])) + sg[i].offset,
-				copy, from) != copy) {
-			rc = -EFAULT;
-			goto out;
-		}
-		bytes -= copy;
-
-		++tls_ctx->pending_open_record_frags;
-
-		if (!bytes)
-			break;
-	}
-
-out:
-	return rc;
-}
-
-int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
-{
-	struct tls_context *tls_ctx = tls_get_ctx(sk);
-	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
-	int ret;
-	int required_size;
-	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
-	bool eor = !(msg->msg_flags & MSG_MORE);
-	size_t try_to_copy, copied = 0;
-	unsigned char record_type = TLS_RECORD_TYPE_DATA;
-	int record_room;
-	bool full_record;
-	int orig_size;
-	bool is_kvec = msg->msg_iter.type & ITER_KVEC;
-
-	if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
-		return -ENOTSUPP;
-
-	lock_sock(sk);
-
-	ret = tls_complete_pending_work(sk, tls_ctx, msg->msg_flags, &timeo);
-	if (ret)
-		goto send_end;
-
-	if (unlikely(msg->msg_controllen)) {
-		ret = tls_proccess_cmsg(sk, msg, &record_type);
-		if (ret)
-			goto send_end;
-	}
-
-	while (msg_data_left(msg)) {
-		if (sk->sk_err) {
-			ret = -sk->sk_err;
-			goto send_end;
-		}
-
-		orig_size = ctx->sg_plaintext_size;
-		full_record = false;
-		try_to_copy = msg_data_left(msg);
-		record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size;
-		if (try_to_copy >= record_room) {
-			try_to_copy = record_room;
-			full_record = true;
-		}
-
-		required_size = ctx->sg_plaintext_size + try_to_copy +
-				tls_ctx->tx.overhead_size;
-
-		if (!sk_stream_memory_free(sk))
-			goto wait_for_sndbuf;
-alloc_encrypted:
-		ret = alloc_encrypted_sg(sk, required_size);
-		if (ret) {
-			if (ret != -ENOSPC)
-				goto wait_for_memory;
-
-			/* Adjust try_to_copy according to the amount that was
-			 * actually allocated. The difference is due
-			 * to max sg elements limit
-			 */
-			try_to_copy -= required_size - ctx->sg_encrypted_size;
-			full_record = true;
-		}
-		if (!is_kvec && (full_record || eor)) {
-			ret = zerocopy_from_iter(sk, &msg->msg_iter,
-				try_to_copy, &ctx->sg_plaintext_num_elem,
-				&ctx->sg_plaintext_size,
-				ctx->sg_plaintext_data,
-				ARRAY_SIZE(ctx->sg_plaintext_data),
-				true);
-			if (ret)
-				goto fallback_to_reg_send;
-
-			copied += try_to_copy;
-			ret = tls_push_record(sk, msg->msg_flags, record_type);
-			if (ret)
-				goto send_end;
-			continue;
-
-fallback_to_reg_send:
-			trim_sg(sk, ctx->sg_plaintext_data,
-				&ctx->sg_plaintext_num_elem,
-				&ctx->sg_plaintext_size,
-				orig_size);
-		}
-
-		required_size = ctx->sg_plaintext_size + try_to_copy;
-alloc_plaintext:
-		ret = alloc_plaintext_sg(sk, required_size);
-		if (ret) {
-			if (ret != -ENOSPC)
-				goto wait_for_memory;
-
-			/* Adjust try_to_copy according to the amount that was
-			 * actually allocated. The difference is due
-			 * to max sg elements limit
-			 */
-			try_to_copy -= required_size - ctx->sg_plaintext_size;
-			full_record = true;
-
-			trim_sg(sk, ctx->sg_encrypted_data,
-				&ctx->sg_encrypted_num_elem,
-				&ctx->sg_encrypted_size,
-				ctx->sg_plaintext_size +
-				tls_ctx->tx.overhead_size);
-		}
-
-		ret = memcopy_from_iter(sk, &msg->msg_iter, try_to_copy);
-		if (ret)
-			goto trim_sgl;
-
-		copied += try_to_copy;
-		if (full_record || eor) {
-push_record:
-			ret = tls_push_record(sk, msg->msg_flags, record_type);
-			if (ret) {
-				if (ret == -ENOMEM)
-					goto wait_for_memory;
-
-				goto send_end;
-			}
-		}
-
-		continue;
-
-wait_for_sndbuf:
-		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
-wait_for_memory:
-		ret = sk_stream_wait_memory(sk, &timeo);
-		if (ret) {
-trim_sgl:
-			trim_both_sgl(sk, orig_size);
-			goto send_end;
-		}
-
-		if (tls_is_pending_closed_record(tls_ctx))
-			goto push_record;
-
-		if (ctx->sg_encrypted_size < required_size)
-			goto alloc_encrypted;
-
-		goto alloc_plaintext;
-	}
-
-send_end:
-	ret = sk_stream_error(sk, msg->msg_flags, ret);
-
-	release_sock(sk);
-	return copied ? copied : ret;
-}
-
-int tls_sw_sendpage(struct sock *sk, struct page *page,
-		    int offset, size_t size, int flags)
-{
-	struct tls_context *tls_ctx = tls_get_ctx(sk);
-	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
-	int ret;
-	long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
-	bool eor;
-	size_t orig_size = size;
-	unsigned char record_type = TLS_RECORD_TYPE_DATA;
-	struct scatterlist *sg;
-	bool full_record;
-	int record_room;
-
-	if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
-		      MSG_SENDPAGE_NOTLAST))
-		return -ENOTSUPP;
-
-	/* No MSG_EOR from splice, only look at MSG_MORE */
-	eor = !(flags & (MSG_MORE | MSG_SENDPAGE_NOTLAST));
-
-	lock_sock(sk);
-
-	sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
-
-	ret = tls_complete_pending_work(sk, tls_ctx, flags, &timeo);
-	if (ret)
-		goto sendpage_end;
-
-	/* Call the sk_stream functions to manage the sndbuf mem. */
-	while (size > 0) {
-		size_t copy, required_size;
-
-		if (sk->sk_err) {
-			ret = -sk->sk_err;
-			goto sendpage_end;
-		}
-
-		full_record = false;
-		record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size;
-		copy = size;
-		if (copy >= record_room) {
-			copy = record_room;
-			full_record = true;
-		}
-		required_size = ctx->sg_plaintext_size + copy +
-			      tls_ctx->tx.overhead_size;
-
-		if (!sk_stream_memory_free(sk))
-			goto wait_for_sndbuf;
-alloc_payload:
-		ret = alloc_encrypted_sg(sk, required_size);
-		if (ret) {
-			if (ret != -ENOSPC)
-				goto wait_for_memory;
-
-			/* Adjust copy according to the amount that was
-			 * actually allocated. The difference is due
-			 * to max sg elements limit
-			 */
-			copy -= required_size - ctx->sg_plaintext_size;
-			full_record = true;
-		}
-
-		get_page(page);
-		sg = ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem;
-		sg_set_page(sg, page, copy, offset);
-		sg_unmark_end(sg);
-
-		ctx->sg_plaintext_num_elem++;
-
-		sk_mem_charge(sk, copy);
-		offset += copy;
-		size -= copy;
-		ctx->sg_plaintext_size += copy;
-		tls_ctx->pending_open_record_frags = ctx->sg_plaintext_num_elem;
-
-		if (full_record || eor ||
-		    ctx->sg_plaintext_num_elem ==
-		    ARRAY_SIZE(ctx->sg_plaintext_data)) {
-push_record:
-			ret = tls_push_record(sk, flags, record_type);
-			if (ret) {
-				if (ret == -ENOMEM)
-					goto wait_for_memory;
-
-				goto sendpage_end;
-			}
-		}
-		continue;
-wait_for_sndbuf:
-		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
-wait_for_memory:
-		ret = sk_stream_wait_memory(sk, &timeo);
-		if (ret) {
-			trim_both_sgl(sk, ctx->sg_plaintext_size);
-			goto sendpage_end;
-		}
-
-		if (tls_is_pending_closed_record(tls_ctx))
-			goto push_record;
-
-		goto alloc_payload;
-	}
-
-sendpage_end:
-	if (orig_size > size)
-		ret = orig_size - size;
-	else
-		ret = sk_stream_error(sk, flags, ret);
-
-	release_sock(sk);
-	return ret;
-}
-
-static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
-				     long timeo, int *err)
-{
-	struct tls_context *tls_ctx = tls_get_ctx(sk);
-	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
-	struct sk_buff *skb;
-	DEFINE_WAIT_FUNC(wait, woken_wake_function);
-
-	while (!(skb = ctx->recv_pkt)) {
-		if (sk->sk_err) {
-			*err = sock_error(sk);
-			return NULL;
-		}
-
-		if (!skb_queue_empty(&sk->sk_receive_queue)) {
-			__strp_unpause(&ctx->strp);
-			if (ctx->recv_pkt)
-				return ctx->recv_pkt;
-		}
-
-		if (sk->sk_shutdown & RCV_SHUTDOWN)
-			return NULL;
-
-		if (sock_flag(sk, SOCK_DONE))
-			return NULL;
-
-		if ((flags & MSG_DONTWAIT) || !timeo) {
-			*err = -EAGAIN;
-			return NULL;
-		}
-
-		add_wait_queue(sk_sleep(sk), &wait);
-		sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
-		sk_wait_event(sk, &timeo, ctx->recv_pkt != skb, &wait);
-		sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
-		remove_wait_queue(sk_sleep(sk), &wait);
-
-		/* Handle signals */
-		if (signal_pending(current)) {
-			*err = sock_intr_errno(timeo);
-			return NULL;
-		}
-	}
-
-	return skb;
 }

 /* This function decrypts the input skb into either out_iov or in out_sg
@@ -674,10 +1410,11 @@
 static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 			    struct iov_iter *out_iov,
 			    struct scatterlist *out_sg,
-			    int *chunk, bool *zc)
+			    int *chunk, bool *zc, bool async)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct strp_msg *rxm = strp_msg(skb);
 	int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
 	struct aead_request *aead_req;
@@ -685,19 +1422,23 @@
 	u8 *aad, *iv, *mem = NULL;
 	struct scatterlist *sgin = NULL;
 	struct scatterlist *sgout = NULL;
-	const int data_len = rxm->full_len - tls_ctx->rx.overhead_size;
+	const int data_len = rxm->full_len - prot->overhead_size +
+			     prot->tail_size;
+	int iv_offset = 0;

 	if (*zc && (out_iov || out_sg)) {
 		if (out_iov)
 			n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1;
 		else
 			n_sgout = sg_nents(out_sg);
+		n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size,
+				 rxm->full_len - prot->prepend_size);
 	} else {
 		n_sgout = 0;
 		*zc = false;
+		n_sgin = skb_cow_data(skb, 0, &unused);
 	}

-	n_sgin = skb_cow_data(skb, 0, &unused);
 	if (n_sgin < 1)
 		return -EBADMSG;

@@ -708,7 +1449,7 @@

 	aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
 	mem_size = aead_size + (nsg * sizeof(struct scatterlist));
-	mem_size = mem_size + TLS_AAD_SPACE_SIZE;
+	mem_size = mem_size + prot->aad_size;
 	mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv);

 	/* Allocate a single block of memory which contains
@@ -724,29 +1465,42 @@
 	sgin = (struct scatterlist *)(mem + aead_size);
 	sgout = sgin + n_sgin;
 	aad = (u8 *)(sgout + n_sgout);
-	iv = aad + TLS_AAD_SPACE_SIZE;
+	iv = aad + prot->aad_size;
+
+	/* For CCM based ciphers, first byte of nonce+iv is always '2' */
+	if (prot->cipher_type == TLS_CIPHER_AES_CCM_128) {
+		iv[0] = 2;
+		iv_offset = 1;
+	}

 	/* Prepare IV */
 	err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
-			    iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
-			    tls_ctx->rx.iv_size);
+			    iv + iv_offset + prot->salt_size,
+			    prot->iv_size);
 	if (err < 0) {
 		kfree(mem);
 		return err;
 	}
-	memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
+	if (prot->version == TLS_1_3_VERSION)
+		memcpy(iv + iv_offset, tls_ctx->rx.iv,
+		       prot->iv_size + prot->salt_size);
+	else
+		memcpy(iv + iv_offset, tls_ctx->rx.iv, prot->salt_size);
+
+	xor_iv_with_seq(prot->version, iv + iv_offset, tls_ctx->rx.rec_seq);

 	/* Prepare AAD */
-	tls_make_aad(aad, rxm->full_len - tls_ctx->rx.overhead_size,
-		     tls_ctx->rx.rec_seq, tls_ctx->rx.rec_seq_size,
-		     ctx->control);
+	tls_make_aad(aad, rxm->full_len - prot->overhead_size +
+		     prot->tail_size,
+		     tls_ctx->rx.rec_seq, prot->rec_seq_size,
+		     ctx->control, prot->version);

 	/* Prepare sgin */
 	sg_init_table(sgin, n_sgin);
-	sg_set_buf(&sgin[0], aad, TLS_AAD_SPACE_SIZE);
+	sg_set_buf(&sgin[0], aad, prot->aad_size);
 	err = skb_to_sgvec(skb, &sgin[1],
-			   rxm->offset + tls_ctx->rx.prepend_size,
-			   rxm->full_len - tls_ctx->rx.prepend_size);
+			   rxm->offset + prot->prepend_size,
+			   rxm->full_len - prot->prepend_size);
 	if (err < 0) {
 		kfree(mem);
 		return err;
....@@ -755,12 +1509,12 @@
7551509 if (n_sgout) {
7561510 if (out_iov) {
7571511 sg_init_table(sgout, n_sgout);
758
- sg_set_buf(&sgout[0], aad, TLS_AAD_SPACE_SIZE);
1512
+ sg_set_buf(&sgout[0], aad, prot->aad_size);
7591513
7601514 *chunk = 0;
761
- err = zerocopy_from_iter(sk, out_iov, data_len, &pages,
762
- chunk, &sgout[1],
763
- (n_sgout - 1), false);
1515
+ err = tls_setup_from_iter(sk, out_iov, data_len,
1516
+ &pages, chunk, &sgout[1],
1517
+ (n_sgout - 1));
7641518 if (err < 0)
7651519 goto fallback_to_reg_recv;
7661520 } else if (out_sg) {
....@@ -772,12 +1526,15 @@
7721526 fallback_to_reg_recv:
7731527 sgout = sgin;
7741528 pages = 0;
775
- *chunk = 0;
1529
+ *chunk = data_len;
7761530 *zc = false;
7771531 }
7781532
7791533 /* Prepare and submit AEAD request */
780
- err = tls_do_decryption(sk, sgin, sgout, iv, data_len, aead_req);
1534
+ err = tls_do_decryption(sk, skb, sgin, sgout, iv,
1535
+ data_len, aead_req, async);
1536
+ if (err == -EINPROGRESS)
1537
+ return err;
7811538
7821539 /* Release the pages in case iov was mapped to pages */
7831540 for (; pages > 0; pages--)
....@@ -788,31 +1545,52 @@
7881545 }
7891546
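Context for the IV handling in decrypt_internal() above: for TLS 1.2 AES-GCM the 12-byte AEAD nonce is the 4-byte salt from the key material followed by the 8-byte explicit nonce carried in each record; for TLS 1.3 the record sequence number is XORed into the full static IV (which is why both salt and IV are copied there); AES-CCM additionally needs the leading '2' flags byte set above. A minimal stand-alone sketch of the two layouts follows; the helper names are illustrative only, not kernel APIs.

#include <stdint.h>
#include <string.h>

/* Illustrative only: how a 12-byte AEAD nonce is assembled per record. */

static void tls12_gcm_nonce(uint8_t nonce[12],
			    const uint8_t salt[4],        /* from key material */
			    const uint8_t explicit_iv[8]) /* carried in the record */
{
	memcpy(nonce, salt, 4);			/* implicit part */
	memcpy(nonce + 4, explicit_iv, 8);	/* explicit part from the wire */
}

static void tls13_nonce(uint8_t nonce[12],
			const uint8_t static_iv[12],	/* from key material */
			uint64_t rec_seq)		/* per-record sequence number */
{
	int i;

	memcpy(nonce, static_iv, 12);
	/* XOR the big-endian sequence number into the last 8 bytes */
	for (i = 0; i < 8; i++)
		nonce[4 + i] ^= (uint8_t)(rec_seq >> (56 - 8 * i));
}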
7901547 static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
791
- struct iov_iter *dest, int *chunk, bool *zc)
1548
+ struct iov_iter *dest, int *chunk, bool *zc,
1549
+ bool async)
7921550 {
7931551 struct tls_context *tls_ctx = tls_get_ctx(sk);
7941552 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1553
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
7951554 struct strp_msg *rxm = strp_msg(skb);
796
- int err = 0;
1555
+ int pad, err = 0;
7971556
798
-#ifdef CONFIG_TLS_DEVICE
799
- err = tls_device_decrypted(sk, skb);
800
- if (err < 0)
801
- return err;
802
-#endif
8031557 if (!ctx->decrypted) {
804
- err = decrypt_internal(sk, skb, dest, NULL, chunk, zc);
805
- if (err < 0)
806
- return err;
1558
+ if (tls_ctx->rx_conf == TLS_HW) {
1559
+ err = tls_device_decrypted(sk, tls_ctx, skb, rxm);
1560
+ if (err < 0)
1561
+ return err;
1562
+ }
1563
+
1564
+ /* Still not decrypted after tls_device */
1565
+ if (!ctx->decrypted) {
1566
+ err = decrypt_internal(sk, skb, dest, NULL, chunk, zc,
1567
+ async);
1568
+ if (err < 0) {
1569
+ if (err == -EINPROGRESS)
1570
+ tls_advance_record_sn(sk, prot,
1571
+ &tls_ctx->rx);
1572
+ else if (err == -EBADMSG)
1573
+ TLS_INC_STATS(sock_net(sk),
1574
+ LINUX_MIB_TLSDECRYPTERROR);
1575
+ return err;
1576
+ }
1577
+ } else {
1578
+ *zc = false;
1579
+ }
1580
+
1581
+ pad = padding_length(ctx, prot, skb);
1582
+ if (pad < 0)
1583
+ return pad;
1584
+
1585
+ rxm->full_len -= pad;
1586
+ rxm->offset += prot->prepend_size;
1587
+ rxm->full_len -= prot->overhead_size;
1588
+ tls_advance_record_sn(sk, prot, &tls_ctx->rx);
1589
+ ctx->decrypted = 1;
1590
+ ctx->saved_data_ready(sk);
8071591 } else {
8081592 *zc = false;
8091593 }
810
-
811
- rxm->offset += tls_ctx->rx.prepend_size;
812
- rxm->full_len -= tls_ctx->rx.overhead_size;
813
- tls_advance_record_sn(sk, &tls_ctx->rx);
814
- ctx->decrypted = true;
815
- ctx->saved_data_ready(sk);
8161594
8171595 return err;
8181596 }
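The pad handling above depends on the TLS 1.3 inner-plaintext layout: once a record is decrypted, the buffer holds the content, then one content-type byte, then optional all-zero padding, which padding_length() walks backwards over. A small stand-alone sketch of that trimming step (a hypothetical helper, not the kernel function):

/* TLS 1.3 inner plaintext after decryption:
 *   | content ... | content_type (1 byte) | 0x00 ... 0x00 (padding) |
 * Returns the content length without the type byte, or -1 if no
 * non-zero content-type byte is found (malformed record).
 */
static int tls13_strip_padding(const unsigned char *buf, int len,
			       unsigned char *type)
{
	while (len > 0 && buf[len - 1] == 0)
		len--;			/* trailing zeros are padding */
	if (len == 0)
		return -1;		/* record was all padding */
	*type = buf[len - 1];		/* real inner record type */
	return len - 1;			/* content length */
}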
....@@ -823,7 +1601,7 @@
8231601 bool zc = true;
8241602 int chunk;
8251603
826
- return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc);
1604
+ return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc, false);
8271605 }
8281606
8291607 static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
....@@ -831,21 +1609,132 @@
8311609 {
8321610 struct tls_context *tls_ctx = tls_get_ctx(sk);
8331611 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
834
- struct strp_msg *rxm = strp_msg(skb);
8351612
836
- if (len < rxm->full_len) {
837
- rxm->offset += len;
838
- rxm->full_len -= len;
1613
+ if (skb) {
1614
+ struct strp_msg *rxm = strp_msg(skb);
8391615
840
- return false;
1616
+ if (len < rxm->full_len) {
1617
+ rxm->offset += len;
1618
+ rxm->full_len -= len;
1619
+ return false;
1620
+ }
1621
+ consume_skb(skb);
8411622 }
8421623
8431624 /* Finished with message */
8441625 ctx->recv_pkt = NULL;
845
- kfree_skb(skb);
8461626 __strp_unpause(&ctx->strp);
8471627
8481628 return true;
1629
+}
1630
+
1631
+/* This function traverses the rx_list in the tls receive context to copy the
1632
+ * decrypted records into the buffer provided by the caller when zero copy is not
1633
+ * true. Further, the records are removed from the rx_list if it is not a peek
1634
+ * case and the record has been consumed completely.
1635
+ */
1636
+static int process_rx_list(struct tls_sw_context_rx *ctx,
1637
+ struct msghdr *msg,
1638
+ u8 *control,
1639
+ bool *cmsg,
1640
+ size_t skip,
1641
+ size_t len,
1642
+ bool zc,
1643
+ bool is_peek)
1644
+{
1645
+ struct sk_buff *skb = skb_peek(&ctx->rx_list);
1646
+ u8 ctrl = *control;
1647
+ u8 msgc = *cmsg;
1648
+ struct tls_msg *tlm;
1649
+ ssize_t copied = 0;
1650
+
1651
+ /* Set the record type in 'control' if caller didn't pass it */
1652
+ if (!ctrl && skb) {
1653
+ tlm = tls_msg(skb);
1654
+ ctrl = tlm->control;
1655
+ }
1656
+
1657
+ while (skip && skb) {
1658
+ struct strp_msg *rxm = strp_msg(skb);
1659
+ tlm = tls_msg(skb);
1660
+
1661
+ /* Cannot process a record of different type */
1662
+ if (ctrl != tlm->control)
1663
+ return 0;
1664
+
1665
+ if (skip < rxm->full_len)
1666
+ break;
1667
+
1668
+ skip = skip - rxm->full_len;
1669
+ skb = skb_peek_next(skb, &ctx->rx_list);
1670
+ }
1671
+
1672
+ while (len && skb) {
1673
+ struct sk_buff *next_skb;
1674
+ struct strp_msg *rxm = strp_msg(skb);
1675
+ int chunk = min_t(unsigned int, rxm->full_len - skip, len);
1676
+
1677
+ tlm = tls_msg(skb);
1678
+
1679
+ /* Cannot process a record of different type */
1680
+ if (ctrl != tlm->control)
1681
+ return 0;
1682
+
1683
+ /* Set record type if not already done. For a non-data record,
1684
+ * do not proceed if record type could not be copied.
1685
+ */
1686
+ if (!msgc) {
1687
+ int cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
1688
+ sizeof(ctrl), &ctrl);
1689
+ msgc = true;
1690
+ if (ctrl != TLS_RECORD_TYPE_DATA) {
1691
+ if (cerr || msg->msg_flags & MSG_CTRUNC)
1692
+ return -EIO;
1693
+
1694
+ *cmsg = msgc;
1695
+ }
1696
+ }
1697
+
1698
+ if (!zc || (rxm->full_len - skip) > len) {
1699
+ int err = skb_copy_datagram_msg(skb, rxm->offset + skip,
1700
+ msg, chunk);
1701
+ if (err < 0)
1702
+ return err;
1703
+ }
1704
+
1705
+ len = len - chunk;
1706
+ copied = copied + chunk;
1707
+
1708
+ /* Consume the data from the record in the non-peek case */
1709
+ if (!is_peek) {
1710
+ rxm->offset = rxm->offset + chunk;
1711
+ rxm->full_len = rxm->full_len - chunk;
1712
+
1713
+ /* Return if there is unconsumed data in the record */
1714
+ if (rxm->full_len - skip)
1715
+ break;
1716
+ }
1717
+
1718
+ /* The remaining skip-bytes must lie in the 1st record in rx_list.
1719
+ * So from the 2nd record, 'skip' should be 0.
1720
+ */
1721
+ skip = 0;
1722
+
1723
+ if (msg)
1724
+ msg->msg_flags |= MSG_EOR;
1725
+
1726
+ next_skb = skb_peek_next(skb, &ctx->rx_list);
1727
+
1728
+ if (!is_peek) {
1729
+ skb_unlink(skb, &ctx->rx_list);
1730
+ consume_skb(skb);
1731
+ }
1732
+
1733
+ skb = next_skb;
1734
+ }
1735
+
1736
+ *control = ctrl;
1737
+ return copied;
8491738 }
8501739
8511740 int tls_sw_recvmsg(struct sock *sk,
....@@ -857,104 +1746,241 @@
8571746 {
8581747 struct tls_context *tls_ctx = tls_get_ctx(sk);
8591748 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
860
- unsigned char control;
1749
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
1750
+ struct sk_psock *psock;
1751
+ unsigned char control = 0;
1752
+ ssize_t decrypted = 0;
8611753 struct strp_msg *rxm;
1754
+ struct tls_msg *tlm;
8621755 struct sk_buff *skb;
8631756 ssize_t copied = 0;
8641757 bool cmsg = false;
8651758 int target, err = 0;
8661759 long timeo;
867
- bool is_kvec = msg->msg_iter.type & ITER_KVEC;
1760
+ bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
1761
+ bool is_peek = flags & MSG_PEEK;
1762
+ bool bpf_strp_enabled;
1763
+ int num_async = 0;
1764
+ int pending;
8681765
8691766 flags |= nonblock;
8701767
8711768 if (unlikely(flags & MSG_ERRQUEUE))
8721769 return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
8731770
1771
+ psock = sk_psock_get(sk);
8741772 lock_sock(sk);
1773
+ bpf_strp_enabled = sk_psock_strp_enabled(psock);
1774
+
1775
+ /* Process pending decrypted records. It must be non-zero-copy */
1776
+ err = process_rx_list(ctx, msg, &control, &cmsg, 0, len, false,
1777
+ is_peek);
1778
+ if (err < 0) {
1779
+ tls_err_abort(sk, err);
1780
+ goto end;
1781
+ } else {
1782
+ copied = err;
1783
+ }
1784
+
1785
+ if (len <= copied)
1786
+ goto recv_end;
8751787
8761788 target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
1789
+ len = len - copied;
8771790 timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
878
- do {
879
- bool zc = false;
880
- int chunk = 0;
8811791
882
- skb = tls_wait_data(sk, flags, timeo, &err);
883
- if (!skb)
1792
+ while (len && (decrypted + copied < target || ctx->recv_pkt)) {
1793
+ bool retain_skb = false;
1794
+ bool zc = false;
1795
+ int to_decrypt;
1796
+ int chunk = 0;
1797
+ bool async_capable;
1798
+ bool async = false;
1799
+
1800
+ skb = tls_wait_data(sk, psock, flags & MSG_DONTWAIT, timeo, &err);
1801
+ if (!skb) {
1802
+ if (psock) {
1803
+ int ret = __tcp_bpf_recvmsg(sk, psock,
1804
+ msg, len, flags);
1805
+
1806
+ if (ret > 0) {
1807
+ decrypted += ret;
1808
+ len -= ret;
1809
+ continue;
1810
+ }
1811
+ }
8841812 goto recv_end;
1813
+ } else {
1814
+ tlm = tls_msg(skb);
1815
+ if (prot->version == TLS_1_3_VERSION)
1816
+ tlm->control = 0;
1817
+ else
1818
+ tlm->control = ctx->control;
1819
+ }
8851820
8861821 rxm = strp_msg(skb);
1822
+
1823
+ to_decrypt = rxm->full_len - prot->overhead_size;
1824
+
1825
+ if (to_decrypt <= len && !is_kvec && !is_peek &&
1826
+ ctx->control == TLS_RECORD_TYPE_DATA &&
1827
+ prot->version != TLS_1_3_VERSION &&
1828
+ !bpf_strp_enabled)
1829
+ zc = true;
1830
+
1831
+ /* Do not use async mode if record is non-data */
1832
+ if (ctx->control == TLS_RECORD_TYPE_DATA && !bpf_strp_enabled)
1833
+ async_capable = ctx->async_capable;
1834
+ else
1835
+ async_capable = false;
1836
+
1837
+ err = decrypt_skb_update(sk, skb, &msg->msg_iter,
1838
+ &chunk, &zc, async_capable);
1839
+ if (err < 0 && err != -EINPROGRESS) {
1840
+ tls_err_abort(sk, -EBADMSG);
1841
+ goto recv_end;
1842
+ }
1843
+
1844
+ if (err == -EINPROGRESS) {
1845
+ async = true;
1846
+ num_async++;
1847
+ } else if (prot->version == TLS_1_3_VERSION) {
1848
+ tlm->control = ctx->control;
1849
+ }
1850
+
1851
+ /* If the type of records being processed is not known yet,
1852
+ * set it to record type just dequeued. If it is already known,
1853
+ * but does not match the record type just dequeued, go to end.
1854
+ * We always get record type here since for tls1.2, record type
1855
+ * is known just after record is dequeued from stream parser.
1856
+ * For tls1.3, we disable async.
1857
+ */
1858
+
1859
+ if (!control)
1860
+ control = tlm->control;
1861
+ else if (control != tlm->control)
1862
+ goto recv_end;
1863
+
8871864 if (!cmsg) {
8881865 int cerr;
8891866
8901867 cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
891
- sizeof(ctx->control), &ctx->control);
1868
+ sizeof(control), &control);
8921869 cmsg = true;
893
- control = ctx->control;
894
- if (ctx->control != TLS_RECORD_TYPE_DATA) {
1870
+ if (control != TLS_RECORD_TYPE_DATA) {
8951871 if (cerr || msg->msg_flags & MSG_CTRUNC) {
8961872 err = -EIO;
8971873 goto recv_end;
8981874 }
8991875 }
900
- } else if (control != ctx->control) {
901
- goto recv_end;
9021876 }
9031877
904
- if (!ctx->decrypted) {
905
- int to_copy = rxm->full_len - tls_ctx->rx.overhead_size;
906
-
907
- if (!is_kvec && to_copy <= len &&
908
- likely(!(flags & MSG_PEEK)))
909
- zc = true;
910
-
911
- err = decrypt_skb_update(sk, skb, &msg->msg_iter,
912
- &chunk, &zc);
913
- if (err < 0) {
914
- tls_err_abort(sk, EBADMSG);
915
- goto recv_end;
916
- }
917
- ctx->decrypted = true;
918
- }
1878
+ if (async)
1879
+ goto pick_next_record;
9191880
9201881 if (!zc) {
921
- chunk = min_t(unsigned int, rxm->full_len, len);
922
- err = skb_copy_datagram_msg(skb, rxm->offset, msg,
923
- chunk);
1882
+ if (bpf_strp_enabled) {
1883
+ err = sk_psock_tls_strp_read(psock, skb);
1884
+ if (err != __SK_PASS) {
1885
+ rxm->offset = rxm->offset + rxm->full_len;
1886
+ rxm->full_len = 0;
1887
+ if (err == __SK_DROP)
1888
+ consume_skb(skb);
1889
+ ctx->recv_pkt = NULL;
1890
+ __strp_unpause(&ctx->strp);
1891
+ continue;
1892
+ }
1893
+ }
1894
+
1895
+ if (rxm->full_len > len) {
1896
+ retain_skb = true;
1897
+ chunk = len;
1898
+ } else {
1899
+ chunk = rxm->full_len;
1900
+ }
1901
+
1902
+ err = skb_copy_datagram_msg(skb, rxm->offset,
1903
+ msg, chunk);
9241904 if (err < 0)
9251905 goto recv_end;
926
- }
9271906
928
- copied += chunk;
929
- len -= chunk;
930
- if (likely(!(flags & MSG_PEEK))) {
931
- u8 control = ctx->control;
932
-
933
- if (tls_sw_advance_skb(sk, skb, chunk)) {
934
- /* Return full control message to
935
- * userspace before trying to parse
936
- * another message type
937
- */
938
- msg->msg_flags |= MSG_EOR;
939
- if (control != TLS_RECORD_TYPE_DATA)
940
- goto recv_end;
1907
+ if (!is_peek) {
1908
+ rxm->offset = rxm->offset + chunk;
1909
+ rxm->full_len = rxm->full_len - chunk;
9411910 }
942
- } else {
943
- /* MSG_PEEK right now cannot look beyond current skb
944
- * from strparser, meaning we cannot advance skb here
945
- * and thus unpause strparser since we'd loose original
946
- * one.
947
- */
948
- break;
9491911 }
9501912
951
- /* If we have a new message from strparser, continue now. */
952
- if (copied >= target && !ctx->recv_pkt)
1913
+pick_next_record:
1914
+ if (chunk > len)
1915
+ chunk = len;
1916
+
1917
+ decrypted += chunk;
1918
+ len -= chunk;
1919
+
1920
+ /* For async or peek case, queue the current skb */
1921
+ if (async || is_peek || retain_skb) {
1922
+ skb_queue_tail(&ctx->rx_list, skb);
1923
+ skb = NULL;
1924
+ }
1925
+
1926
+ if (tls_sw_advance_skb(sk, skb, chunk)) {
1927
+ /* Return full control message to
1928
+ * userspace before trying to parse
1929
+ * another message type
1930
+ */
1931
+ msg->msg_flags |= MSG_EOR;
1932
+ if (control != TLS_RECORD_TYPE_DATA)
1933
+ goto recv_end;
1934
+ } else {
9531935 break;
954
- } while (len);
1936
+ }
1937
+ }
9551938
9561939 recv_end:
1940
+ if (num_async) {
1941
+ /* Wait for all previously submitted records to be decrypted */
1942
+ spin_lock_bh(&ctx->decrypt_compl_lock);
1943
+ ctx->async_notify = true;
1944
+ pending = atomic_read(&ctx->decrypt_pending);
1945
+ spin_unlock_bh(&ctx->decrypt_compl_lock);
1946
+ if (pending) {
1947
+ err = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
1948
+ if (err) {
1949
+ /* one of async decrypt failed */
1950
+ /* one of the async decrypts failed */
1951
+ copied = 0;
1952
+ decrypted = 0;
1953
+ goto end;
1954
+ }
1955
+ } else {
1956
+ reinit_completion(&ctx->async_wait.completion);
1957
+ }
1958
+
1959
+ /* There can be no concurrent accesses, since we have no
1960
+ * pending decrypt operations
1961
+ */
1962
+ WRITE_ONCE(ctx->async_notify, false);
1963
+
1964
+ /* Drain records from the rx_list & copy if required */
1965
+ if (is_peek || is_kvec)
1966
+ err = process_rx_list(ctx, msg, &control, &cmsg, copied,
1967
+ decrypted, false, is_peek);
1968
+ else
1969
+ err = process_rx_list(ctx, msg, &control, &cmsg, 0,
1970
+ decrypted, true, is_peek);
1971
+ if (err < 0) {
1972
+ tls_err_abort(sk, err);
1973
+ copied = 0;
1974
+ goto end;
1975
+ }
1976
+ }
1977
+
1978
+ copied += decrypted;
1979
+
1980
+end:
9571981 release_sock(sk);
1982
+ if (psock)
1983
+ sk_psock_put(sk, psock);
9581984 return copied ? : err;
9591985 }
9601986
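The put_cmsg(..., TLS_GET_RECORD_TYPE, ...) calls in process_rx_list() and tls_sw_recvmsg() are what userspace sees as a control message when it reads from a kTLS socket; non-data records (alerts, handshake messages) are only delivered when the caller supplies control-message space. A receive helper sketched from the documented uapi (the SOL_TLS and TLS_RECORD_TYPE_DATA fallback defines cover older headers; error handling trimmed):

#include <string.h>
#include <sys/socket.h>
#include <sys/uio.h>
#include <linux/tls.h>

#ifndef SOL_TLS
#define SOL_TLS 282			/* from include/linux/socket.h */
#endif
#ifndef TLS_RECORD_TYPE_DATA
#define TLS_RECORD_TYPE_DATA 23	/* application_data */
#endif

static ssize_t tls_recv_with_type(int fd, void *buf, size_t len,
				  unsigned char *record_type)
{
	char cbuf[CMSG_SPACE(sizeof(unsigned char))];
	struct iovec iov = { .iov_base = buf, .iov_len = len };
	struct msghdr msg = { 0 };
	struct cmsghdr *cmsg;
	ssize_t n;

	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;
	msg.msg_control = cbuf;
	msg.msg_controllen = sizeof(cbuf);

	n = recvmsg(fd, &msg, 0);

	*record_type = TLS_RECORD_TYPE_DATA;	/* default when no cmsg is attached */
	for (cmsg = CMSG_FIRSTHDR(&msg); cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
		if (cmsg->cmsg_level == SOL_TLS &&
		    cmsg->cmsg_type == TLS_GET_RECORD_TYPE)
			*record_type = *CMSG_DATA(cmsg);
	}
	return n;
}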
....@@ -975,27 +2001,24 @@
9752001
9762002 lock_sock(sk);
9772003
978
- timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
2004
+ timeo = sock_rcvtimeo(sk, flags & SPLICE_F_NONBLOCK);
9792005
980
- skb = tls_wait_data(sk, flags, timeo, &err);
2006
+ skb = tls_wait_data(sk, NULL, flags & SPLICE_F_NONBLOCK, timeo, &err);
9812007 if (!skb)
9822008 goto splice_read_end;
9832009
984
- /* splice does not support reading control messages */
985
- if (ctx->control != TLS_RECORD_TYPE_DATA) {
986
- err = -ENOTSUPP;
2010
+ err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false);
2011
+ if (err < 0) {
2012
+ tls_err_abort(sk, -EBADMSG);
9872013 goto splice_read_end;
9882014 }
9892015
990
- if (!ctx->decrypted) {
991
- err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc);
992
-
993
- if (err < 0) {
994
- tls_err_abort(sk, EBADMSG);
995
- goto splice_read_end;
996
- }
997
- ctx->decrypted = true;
2016
+ /* splice does not support reading control messages */
2017
+ if (ctx->control != TLS_RECORD_TYPE_DATA) {
2018
+ err = -EINVAL;
2019
+ goto splice_read_end;
9982020 }
2021
+
9992022 rxm = strp_msg(skb);
10002023
10012024 chunk = min_t(unsigned int, rxm->full_len, len);
....@@ -1011,29 +2034,28 @@
10112034 return copied ? : err;
10122035 }
10132036
1014
-unsigned int tls_sw_poll(struct file *file, struct socket *sock,
1015
- struct poll_table_struct *wait)
2037
+bool tls_sw_stream_read(const struct sock *sk)
10162038 {
1017
- unsigned int ret;
1018
- struct sock *sk = sock->sk;
10192039 struct tls_context *tls_ctx = tls_get_ctx(sk);
10202040 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2041
+ bool ingress_empty = true;
2042
+ struct sk_psock *psock;
10212043
1022
- /* Grab POLLOUT and POLLHUP from the underlying socket */
1023
- ret = ctx->sk_poll(file, sock, wait);
2044
+ rcu_read_lock();
2045
+ psock = sk_psock(sk);
2046
+ if (psock)
2047
+ ingress_empty = list_empty(&psock->ingress_msg);
2048
+ rcu_read_unlock();
10242049
1025
- /* Clear POLLIN bits, and set based on recv_pkt */
1026
- ret &= ~(POLLIN | POLLRDNORM);
1027
- if (ctx->recv_pkt)
1028
- ret |= POLLIN | POLLRDNORM;
1029
-
1030
- return ret;
2050
+ return !ingress_empty || ctx->recv_pkt ||
2051
+ !skb_queue_empty(&ctx->rx_list);
10312052 }
10322053
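tls_read_size() below tells the stream parser how long the next record is by linearizing and checking the 5-byte TLS record header. For reference, a sketch of the header layout and the same length extraction the code performs (nothing beyond what the reads below already do):

/* TLS record header on the wire (TLS_HEADER_SIZE == 5 bytes):
 *
 *   header[0]     ContentType (0x17 application_data, 0x15 alert, ...)
 *   header[1..2]  legacy protocol version (0x03 0x03 for both TLS 1.2 and 1.3)
 *   header[3]     payload length, high byte
 *   header[4]     payload length, low byte
 *
 * Hypothetical helper mirroring the length extraction below:
 */
static inline unsigned int tls_record_payload_len(const unsigned char *header)
{
	return (header[3] << 8) | (header[4] & 0xFF);
}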
10332054 static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
10342055 {
10352056 struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
10362057 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2058
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
10372059 char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
10382060 struct strp_msg *rxm = strp_msg(skb);
10392061 size_t cipher_overhead;
....@@ -1041,17 +2063,17 @@
10412063 int ret;
10422064
10432065 /* Verify that we have a full TLS header, or wait for more data */
1044
- if (rxm->offset + tls_ctx->rx.prepend_size > skb->len)
2066
+ if (rxm->offset + prot->prepend_size > skb->len)
10452067 return 0;
10462068
10472069 /* Sanity-check size of on-stack buffer. */
1048
- if (WARN_ON(tls_ctx->rx.prepend_size > sizeof(header))) {
2070
+ if (WARN_ON(prot->prepend_size > sizeof(header))) {
10492071 ret = -EINVAL;
10502072 goto read_failure;
10512073 }
10522074
10532075 /* Linearize header to local buffer */
1054
- ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size);
2076
+ ret = skb_copy_bits(skb, rxm->offset, header, prot->prepend_size);
10552077
10562078 if (ret < 0)
10572079 goto read_failure;
....@@ -1060,9 +2082,12 @@
10602082
10612083 data_len = ((header[4] & 0xFF) | (header[3] << 8));
10622084
1063
- cipher_overhead = tls_ctx->rx.tag_size + tls_ctx->rx.iv_size;
2085
+ cipher_overhead = prot->tag_size;
2086
+ if (prot->version != TLS_1_3_VERSION)
2087
+ cipher_overhead += prot->iv_size;
10642088
1065
- if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead) {
2089
+ if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead +
2090
+ prot->tail_size) {
10662091 ret = -EMSGSIZE;
10672092 goto read_failure;
10682093 }
....@@ -1071,16 +2096,15 @@
10712096 goto read_failure;
10722097 }
10732098
1074
- if (header[1] != TLS_VERSION_MINOR(tls_ctx->crypto_recv.info.version) ||
1075
- header[2] != TLS_VERSION_MAJOR(tls_ctx->crypto_recv.info.version)) {
2099
+ /* Note that both TLS1.3 and TLS1.2 use TLS_1_2 version here */
2100
+ if (header[1] != TLS_1_2_VERSION_MINOR ||
2101
+ header[2] != TLS_1_2_VERSION_MAJOR) {
10762102 ret = -EINVAL;
10772103 goto read_failure;
10782104 }
10792105
1080
-#ifdef CONFIG_TLS_DEVICE
1081
- handle_device_resync(strp->sk, TCP_SKB_CB(skb)->seq + rxm->offset,
1082
- *(u64*)tls_ctx->rx.rec_seq);
1083
-#endif
2106
+ tls_device_rx_resync_new_rec(strp->sk, data_len + TLS_HEADER_SIZE,
2107
+ TCP_SKB_CB(skb)->seq + rxm->offset);
10842108 return data_len + TLS_HEADER_SIZE;
10852109
10862110 read_failure:
....@@ -1094,7 +2118,7 @@
10942118 struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
10952119 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
10962120
1097
- ctx->decrypted = false;
2121
+ ctx->decrypted = 0;
10982122
10992123 ctx->recv_pkt = skb;
11002124 strp_pause(strp);
....@@ -1106,17 +2130,71 @@
11062130 {
11072131 struct tls_context *tls_ctx = tls_get_ctx(sk);
11082132 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2133
+ struct sk_psock *psock;
11092134
11102135 strp_data_ready(&ctx->strp);
2136
+
2137
+ psock = sk_psock_get(sk);
2138
+ if (psock) {
2139
+ if (!list_empty(&psock->ingress_msg))
2140
+ ctx->saved_data_ready(sk);
2141
+ sk_psock_put(sk, psock);
2142
+ }
11112143 }
11122144
1113
-void tls_sw_free_resources_tx(struct sock *sk)
2145
+void tls_sw_cancel_work_tx(struct tls_context *tls_ctx)
2146
+{
2147
+ struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
2148
+
2149
+ set_bit(BIT_TX_CLOSING, &ctx->tx_bitmask);
2150
+ set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask);
2151
+ cancel_delayed_work_sync(&ctx->tx_work.work);
2152
+}
2153
+
2154
+void tls_sw_release_resources_tx(struct sock *sk)
11142155 {
11152156 struct tls_context *tls_ctx = tls_get_ctx(sk);
11162157 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
2158
+ struct tls_rec *rec, *tmp;
2159
+ int pending;
2160
+
2161
+ /* Wait for any pending async encryptions to complete */
2162
+ spin_lock_bh(&ctx->encrypt_compl_lock);
2163
+ ctx->async_notify = true;
2164
+ pending = atomic_read(&ctx->encrypt_pending);
2165
+ spin_unlock_bh(&ctx->encrypt_compl_lock);
2166
+
2167
+ if (pending)
2168
+ crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
2169
+
2170
+ tls_tx_records(sk, -1);
2171
+
2172
+ /* Free up un-sent records in tx_list. First, free
2173
+ * the partially sent record if any at head of tx_list.
2174
+ */
2175
+ if (tls_ctx->partially_sent_record) {
2176
+ tls_free_partial_record(sk, tls_ctx);
2177
+ rec = list_first_entry(&ctx->tx_list,
2178
+ struct tls_rec, list);
2179
+ list_del(&rec->list);
2180
+ sk_msg_free(sk, &rec->msg_plaintext);
2181
+ kfree(rec);
2182
+ }
2183
+
2184
+ list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
2185
+ list_del(&rec->list);
2186
+ sk_msg_free(sk, &rec->msg_encrypted);
2187
+ sk_msg_free(sk, &rec->msg_plaintext);
2188
+ kfree(rec);
2189
+ }
11172190
11182191 crypto_free_aead(ctx->aead_send);
1119
- tls_free_both_sg(sk);
2192
+ tls_free_open_rec(sk);
2193
+}
2194
+
2195
+void tls_sw_free_ctx_tx(struct tls_context *tls_ctx)
2196
+{
2197
+ struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
11202198
11212199 kfree(ctx);
11222200 }
....@@ -1132,38 +2210,108 @@
11322210 if (ctx->aead_recv) {
11332211 kfree_skb(ctx->recv_pkt);
11342212 ctx->recv_pkt = NULL;
2213
+ skb_queue_purge(&ctx->rx_list);
11352214 crypto_free_aead(ctx->aead_recv);
11362215 strp_stop(&ctx->strp);
1137
- write_lock_bh(&sk->sk_callback_lock);
1138
- sk->sk_data_ready = ctx->saved_data_ready;
1139
- write_unlock_bh(&sk->sk_callback_lock);
1140
- release_sock(sk);
1141
- strp_done(&ctx->strp);
1142
- lock_sock(sk);
2216
+ /* If tls_sw_strparser_arm() was not called (cleanup paths)
2217
+ * we still want to strp_stop(), but sk->sk_data_ready was
2218
+ * never swapped.
2219
+ */
2220
+ if (ctx->saved_data_ready) {
2221
+ write_lock_bh(&sk->sk_callback_lock);
2222
+ sk->sk_data_ready = ctx->saved_data_ready;
2223
+ write_unlock_bh(&sk->sk_callback_lock);
2224
+ }
11432225 }
2226
+}
2227
+
2228
+void tls_sw_strparser_done(struct tls_context *tls_ctx)
2229
+{
2230
+ struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2231
+
2232
+ strp_done(&ctx->strp);
2233
+}
2234
+
2235
+void tls_sw_free_ctx_rx(struct tls_context *tls_ctx)
2236
+{
2237
+ struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
2238
+
2239
+ kfree(ctx);
11442240 }
11452241
11462242 void tls_sw_free_resources_rx(struct sock *sk)
11472243 {
11482244 struct tls_context *tls_ctx = tls_get_ctx(sk);
1149
- struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
11502245
11512246 tls_sw_release_resources_rx(sk);
2247
+ tls_sw_free_ctx_rx(tls_ctx);
2248
+}
11522249
1153
- kfree(ctx);
2250
+/* The work handler to transmit the encrypted records in tx_list */
2251
+static void tx_work_handler(struct work_struct *work)
2252
+{
2253
+ struct delayed_work *delayed_work = to_delayed_work(work);
2254
+ struct tx_work *tx_work = container_of(delayed_work,
2255
+ struct tx_work, work);
2256
+ struct sock *sk = tx_work->sk;
2257
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
2258
+ struct tls_sw_context_tx *ctx;
2259
+
2260
+ if (unlikely(!tls_ctx))
2261
+ return;
2262
+
2263
+ ctx = tls_sw_ctx_tx(tls_ctx);
2264
+ if (test_bit(BIT_TX_CLOSING, &ctx->tx_bitmask))
2265
+ return;
2266
+
2267
+ if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
2268
+ return;
2269
+ mutex_lock(&tls_ctx->tx_lock);
2270
+ lock_sock(sk);
2271
+ tls_tx_records(sk, -1);
2272
+ release_sock(sk);
2273
+ mutex_unlock(&tls_ctx->tx_lock);
2274
+}
2275
+
2276
+void tls_sw_write_space(struct sock *sk, struct tls_context *ctx)
2277
+{
2278
+ struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx);
2279
+
2280
+ /* Schedule the transmission if tx list is ready */
2281
+ if (is_tx_ready(tx_ctx) &&
2282
+ !test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask))
2283
+ schedule_delayed_work(&tx_ctx->tx_work.work, 0);
2284
+}
2285
+
2286
+void tls_sw_strparser_arm(struct sock *sk, struct tls_context *tls_ctx)
2287
+{
2288
+ struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx);
2289
+
2290
+ write_lock_bh(&sk->sk_callback_lock);
2291
+ rx_ctx->saved_data_ready = sk->sk_data_ready;
2292
+ sk->sk_data_ready = tls_data_ready;
2293
+ write_unlock_bh(&sk->sk_callback_lock);
2294
+
2295
+ strp_check_rcv(&rx_ctx->strp);
11542296 }
11552297
11562298 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
11572299 {
2300
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
2301
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
11582302 struct tls_crypto_info *crypto_info;
11592303 struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
2304
+ struct tls12_crypto_info_aes_gcm_256 *gcm_256_info;
2305
+ struct tls12_crypto_info_aes_ccm_128 *ccm_128_info;
11602306 struct tls_sw_context_tx *sw_ctx_tx = NULL;
11612307 struct tls_sw_context_rx *sw_ctx_rx = NULL;
11622308 struct cipher_context *cctx;
11632309 struct crypto_aead **aead;
11642310 struct strp_callbacks cb;
1165
- u16 nonce_size, tag_size, iv_size, rec_seq_size;
1166
- char *iv, *rec_seq;
2311
+ u16 nonce_size, tag_size, iv_size, rec_seq_size, salt_size;
2312
+ struct crypto_tfm *tfm;
2313
+ char *iv, *rec_seq, *key, *salt, *cipher_name;
2314
+ size_t keysize;
11672315 int rc = 0;
11682316
11692317 if (!ctx) {
....@@ -1199,13 +2347,19 @@
11992347
12002348 if (tx) {
12012349 crypto_init_wait(&sw_ctx_tx->async_wait);
2350
+ spin_lock_init(&sw_ctx_tx->encrypt_compl_lock);
12022351 crypto_info = &ctx->crypto_send.info;
12032352 cctx = &ctx->tx;
12042353 aead = &sw_ctx_tx->aead_send;
2354
+ INIT_LIST_HEAD(&sw_ctx_tx->tx_list);
2355
+ INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler);
2356
+ sw_ctx_tx->tx_work.sk = sk;
12052357 } else {
12062358 crypto_init_wait(&sw_ctx_rx->async_wait);
2359
+ spin_lock_init(&sw_ctx_rx->decrypt_compl_lock);
12072360 crypto_info = &ctx->crypto_recv.info;
12082361 cctx = &ctx->rx;
2362
+ skb_queue_head_init(&sw_ctx_rx->rx_list);
12092363 aead = &sw_ctx_rx->aead_recv;
12102364 }
12112365
....@@ -1220,6 +2374,45 @@
12202374 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
12212375 gcm_128_info =
12222376 (struct tls12_crypto_info_aes_gcm_128 *)crypto_info;
2377
+ keysize = TLS_CIPHER_AES_GCM_128_KEY_SIZE;
2378
+ key = gcm_128_info->key;
2379
+ salt = gcm_128_info->salt;
2380
+ salt_size = TLS_CIPHER_AES_GCM_128_SALT_SIZE;
2381
+ cipher_name = "gcm(aes)";
2382
+ break;
2383
+ }
2384
+ case TLS_CIPHER_AES_GCM_256: {
2385
+ nonce_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
2386
+ tag_size = TLS_CIPHER_AES_GCM_256_TAG_SIZE;
2387
+ iv_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
2388
+ iv = ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->iv;
2389
+ rec_seq_size = TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE;
2390
+ rec_seq =
2391
+ ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->rec_seq;
2392
+ gcm_256_info =
2393
+ (struct tls12_crypto_info_aes_gcm_256 *)crypto_info;
2394
+ keysize = TLS_CIPHER_AES_GCM_256_KEY_SIZE;
2395
+ key = gcm_256_info->key;
2396
+ salt = gcm_256_info->salt;
2397
+ salt_size = TLS_CIPHER_AES_GCM_256_SALT_SIZE;
2398
+ cipher_name = "gcm(aes)";
2399
+ break;
2400
+ }
2401
+ case TLS_CIPHER_AES_CCM_128: {
2402
+ nonce_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
2403
+ tag_size = TLS_CIPHER_AES_CCM_128_TAG_SIZE;
2404
+ iv_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
2405
+ iv = ((struct tls12_crypto_info_aes_ccm_128 *)crypto_info)->iv;
2406
+ rec_seq_size = TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE;
2407
+ rec_seq =
2408
+ ((struct tls12_crypto_info_aes_ccm_128 *)crypto_info)->rec_seq;
2409
+ ccm_128_info =
2410
+ (struct tls12_crypto_info_aes_ccm_128 *)crypto_info;
2411
+ keysize = TLS_CIPHER_AES_CCM_128_KEY_SIZE;
2412
+ key = ccm_128_info->key;
2413
+ salt = ccm_128_info->salt;
2414
+ salt_size = TLS_CIPHER_AES_CCM_128_SALT_SIZE;
2415
+ cipher_name = "ccm(aes)";
12232416 break;
12242417 }
12252418 default:
....@@ -1227,53 +2420,47 @@
12272420 goto free_priv;
12282421 }
12292422
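The cipher parameters unpacked in the switch above come from userspace via setsockopt() on a TCP socket that already has the "tls" ULP attached; that is the path which ends up in tls_set_sw_offload(). A minimal TLS 1.2 AES-GCM-128 receive-side setup sketched from the documented uapi (key material assumed to come from a completed handshake; the TCP_ULP/SOL_TLS fallback defines cover older headers; error handling trimmed):

#include <string.h>
#include <sys/socket.h>
#include <netinet/tcp.h>
#include <linux/tls.h>

#ifndef SOL_TLS
#define SOL_TLS 282		/* from include/linux/socket.h */
#endif
#ifndef TCP_ULP
#define TCP_ULP 31		/* attach an upper layer protocol */
#endif

static int ktls_enable_rx(int fd,
			  const unsigned char *key,	/* 16 bytes */
			  const unsigned char *iv,	/* 8-byte explicit nonce */
			  const unsigned char *salt,	/* 4-byte implicit salt */
			  const unsigned char *rec_seq)	/* 8-byte record sequence */
{
	struct tls12_crypto_info_aes_gcm_128 ci;

	memset(&ci, 0, sizeof(ci));
	ci.info.version = TLS_1_2_VERSION;
	ci.info.cipher_type = TLS_CIPHER_AES_GCM_128;
	memcpy(ci.key, key, TLS_CIPHER_AES_GCM_128_KEY_SIZE);
	memcpy(ci.iv, iv, TLS_CIPHER_AES_GCM_128_IV_SIZE);
	memcpy(ci.salt, salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
	memcpy(ci.rec_seq, rec_seq, TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);

	/* Attach the TLS ULP, then install the receive-side key material;
	 * this eventually reaches tls_set_sw_offload(..., tx = 0).
	 */
	if (setsockopt(fd, SOL_TCP, TCP_ULP, "tls", sizeof("tls")) < 0)
		return -1;
	return setsockopt(fd, SOL_TLS, TLS_RX, &ci, sizeof(ci));
}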
1230
- /* Sanity-check the IV size for stack allocations. */
1231
- if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE) {
2423
+ /* Sanity-check the sizes for stack allocations. */
2424
+ if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE ||
2425
+ rec_seq_size > TLS_MAX_REC_SEQ_SIZE) {
12322426 rc = -EINVAL;
12332427 goto free_priv;
12342428 }
12352429
1236
- cctx->prepend_size = TLS_HEADER_SIZE + nonce_size;
1237
- cctx->tag_size = tag_size;
1238
- cctx->overhead_size = cctx->prepend_size + cctx->tag_size;
1239
- cctx->iv_size = iv_size;
1240
- cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
1241
- GFP_KERNEL);
2430
+ if (crypto_info->version == TLS_1_3_VERSION) {
2431
+ nonce_size = 0;
2432
+ prot->aad_size = TLS_HEADER_SIZE;
2433
+ prot->tail_size = 1;
2434
+ } else {
2435
+ prot->aad_size = TLS_AAD_SPACE_SIZE;
2436
+ prot->tail_size = 0;
2437
+ }
2438
+
2439
+ prot->version = crypto_info->version;
2440
+ prot->cipher_type = crypto_info->cipher_type;
2441
+ prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
2442
+ prot->tag_size = tag_size;
2443
+ prot->overhead_size = prot->prepend_size +
2444
+ prot->tag_size + prot->tail_size;
2445
+ prot->iv_size = iv_size;
2446
+ prot->salt_size = salt_size;
2447
+ cctx->iv = kmalloc(iv_size + salt_size, GFP_KERNEL);
12422448 if (!cctx->iv) {
12432449 rc = -ENOMEM;
12442450 goto free_priv;
12452451 }
1246
- memcpy(cctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
1247
- memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
1248
- cctx->rec_seq_size = rec_seq_size;
2452
+ /* Note: the 128 and 256 bit salts are the same size */
2453
+ prot->rec_seq_size = rec_seq_size;
2454
+ memcpy(cctx->iv, salt, salt_size);
2455
+ memcpy(cctx->iv + salt_size, iv, iv_size);
12492456 cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
12502457 if (!cctx->rec_seq) {
12512458 rc = -ENOMEM;
12522459 goto free_iv;
12532460 }
12542461
1255
- if (sw_ctx_tx) {
1256
- sg_init_table(sw_ctx_tx->sg_encrypted_data,
1257
- ARRAY_SIZE(sw_ctx_tx->sg_encrypted_data));
1258
- sg_init_table(sw_ctx_tx->sg_plaintext_data,
1259
- ARRAY_SIZE(sw_ctx_tx->sg_plaintext_data));
1260
-
1261
- sg_init_table(sw_ctx_tx->sg_aead_in, 2);
1262
- sg_set_buf(&sw_ctx_tx->sg_aead_in[0], sw_ctx_tx->aad_space,
1263
- sizeof(sw_ctx_tx->aad_space));
1264
- sg_unmark_end(&sw_ctx_tx->sg_aead_in[1]);
1265
- sg_chain(sw_ctx_tx->sg_aead_in, 2,
1266
- sw_ctx_tx->sg_plaintext_data);
1267
- sg_init_table(sw_ctx_tx->sg_aead_out, 2);
1268
- sg_set_buf(&sw_ctx_tx->sg_aead_out[0], sw_ctx_tx->aad_space,
1269
- sizeof(sw_ctx_tx->aad_space));
1270
- sg_unmark_end(&sw_ctx_tx->sg_aead_out[1]);
1271
- sg_chain(sw_ctx_tx->sg_aead_out, 2,
1272
- sw_ctx_tx->sg_encrypted_data);
1273
- }
1274
-
12752462 if (!*aead) {
1276
- *aead = crypto_alloc_aead("gcm(aes)", 0, 0);
2463
+ *aead = crypto_alloc_aead(cipher_name, 0, 0);
12772464 if (IS_ERR(*aead)) {
12782465 rc = PTR_ERR(*aead);
12792466 *aead = NULL;
....@@ -1283,31 +2470,31 @@
12832470
12842471 ctx->push_pending_record = tls_sw_push_pending_record;
12852472
1286
- rc = crypto_aead_setkey(*aead, gcm_128_info->key,
1287
- TLS_CIPHER_AES_GCM_128_KEY_SIZE);
2473
+ rc = crypto_aead_setkey(*aead, key, keysize);
2474
+
12882475 if (rc)
12892476 goto free_aead;
12902477
1291
- rc = crypto_aead_setauthsize(*aead, cctx->tag_size);
2478
+ rc = crypto_aead_setauthsize(*aead, prot->tag_size);
12922479 if (rc)
12932480 goto free_aead;
12942481
12952482 if (sw_ctx_rx) {
2483
+ tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv);
2484
+
2485
+ if (crypto_info->version == TLS_1_3_VERSION)
2486
+ sw_ctx_rx->async_capable = 0;
2487
+ else
2488
+ sw_ctx_rx->async_capable =
2489
+ !!(tfm->__crt_alg->cra_flags &
2490
+ CRYPTO_ALG_ASYNC);
2491
+
12962492 /* Set up strparser */
12972493 memset(&cb, 0, sizeof(cb));
12982494 cb.rcv_msg = tls_queue;
12992495 cb.parse_msg = tls_read_size;
13002496
13012497 strp_init(&sw_ctx_rx->strp, sk, &cb);
1302
-
1303
- write_lock_bh(&sk->sk_callback_lock);
1304
- sw_ctx_rx->saved_data_ready = sk->sk_data_ready;
1305
- sk->sk_data_ready = tls_data_ready;
1306
- write_unlock_bh(&sk->sk_callback_lock);
1307
-
1308
- sw_ctx_rx->sk_poll = sk->sk_socket->ops->poll;
1309
-
1310
- strp_check_rcv(&sw_ctx_rx->strp);
13112498 }
13122499
13132500 goto out;