hc
2024-01-31 f70575805708cabdedea7498aaa3f710fde4d920
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
/* SPDX-License-Identifier: GPL-2.0 */
/*
 * NH - ε-almost-universal hash function, x86_64 AVX2 accelerated
 *
 * Copyright 2018 Google LLC
 *
 * Author: Eric Biggers <ebiggers@google.com>
 */
 
#include <linux/linkage.h>
 
#define        PASS0_SUMS    %ymm0
#define        PASS1_SUMS    %ymm1
#define        PASS2_SUMS    %ymm2
#define        PASS3_SUMS    %ymm3
#define        K0        %ymm4
#define        K0_XMM        %xmm4
#define        K1        %ymm5
#define        K1_XMM        %xmm5
#define        K2        %ymm6
#define        K2_XMM        %xmm6
#define        K3        %ymm7
#define        K3_XMM        %xmm7
#define        T0        %ymm8
#define        T1        %ymm9
#define        T2        %ymm10
#define        T2_XMM        %xmm10
#define        T3        %ymm11
#define        T3_XMM        %xmm11
#define        T4        %ymm12
#define        T5        %ymm13
#define        T6        %ymm14
#define        T7        %ymm15
#define        KEY        %rdi
#define        MESSAGE        %rsi
#define        MESSAGE_LEN    %rdx
#define        HASH        %rcx
 
.macro _nh_2xstride    k0, k1, k2, k3
 
   // Add message words to key words
   vpaddd        \k0, T3, T0
   vpaddd        \k1, T3, T1
   vpaddd        \k2, T3, T2
   vpaddd        \k3, T3, T3
 
   // Multiply 32x32 => 64 and accumulate
   vpshufd        $0x10, T0, T4
   vpshufd        $0x32, T0, T0
   vpshufd        $0x10, T1, T5
   vpshufd        $0x32, T1, T1
   vpshufd        $0x10, T2, T6
   vpshufd        $0x32, T2, T2
   vpshufd        $0x10, T3, T7
   vpshufd        $0x32, T3, T3
   vpmuludq    T4, T0, T0
   vpmuludq    T5, T1, T1
   vpmuludq    T6, T2, T2
   vpmuludq    T7, T3, T3
   vpaddq        T0, PASS0_SUMS, PASS0_SUMS
   vpaddq        T1, PASS1_SUMS, PASS1_SUMS
   vpaddq        T2, PASS2_SUMS, PASS2_SUMS
   vpaddq        T3, PASS3_SUMS, PASS3_SUMS
.endm
 
/*
 * void nh_avx2(const u32 *key, const u8 *message, size_t message_len,
 *        u8 hash[NH_HASH_BYTES])
 *
 * It's guaranteed that message_len % 16 == 0.
 */
SYM_FUNC_START(nh_avx2)
 
   vmovdqu        0x00(KEY), K0
   vmovdqu        0x10(KEY), K1
   add        $0x20, KEY
   vpxor        PASS0_SUMS, PASS0_SUMS, PASS0_SUMS
   vpxor        PASS1_SUMS, PASS1_SUMS, PASS1_SUMS
   vpxor        PASS2_SUMS, PASS2_SUMS, PASS2_SUMS
   vpxor        PASS3_SUMS, PASS3_SUMS, PASS3_SUMS
 
   sub        $0x40, MESSAGE_LEN
   jl        .Lloop4_done
.Lloop4:
   vmovdqu        (MESSAGE), T3
   vmovdqu        0x00(KEY), K2
   vmovdqu        0x10(KEY), K3
   _nh_2xstride    K0, K1, K2, K3
 
   vmovdqu        0x20(MESSAGE), T3
   vmovdqu        0x20(KEY), K0
   vmovdqu        0x30(KEY), K1
   _nh_2xstride    K2, K3, K0, K1
 
   add        $0x40, MESSAGE
   add        $0x40, KEY
   sub        $0x40, MESSAGE_LEN
   jge        .Lloop4
 
.Lloop4_done:
   and        $0x3f, MESSAGE_LEN
   jz        .Ldone
 
   cmp        $0x20, MESSAGE_LEN
   jl        .Llast
 
   // 2 or 3 strides remain; do 2 more.
   vmovdqu        (MESSAGE), T3
   vmovdqu        0x00(KEY), K2
   vmovdqu        0x10(KEY), K3
   _nh_2xstride    K0, K1, K2, K3
   add        $0x20, MESSAGE
   add        $0x20, KEY
   sub        $0x20, MESSAGE_LEN
   jz        .Ldone
   vmovdqa        K2, K0
   vmovdqa        K3, K1
.Llast:
   // Last stride.  Zero the high 128 bits of the message and keys so they
   // don't affect the result when processing them like 2 strides.
   vmovdqu        (MESSAGE), T3_XMM
   vmovdqa        K0_XMM, K0_XMM
   vmovdqa        K1_XMM, K1_XMM
   vmovdqu        0x00(KEY), K2_XMM
   vmovdqu        0x10(KEY), K3_XMM
   _nh_2xstride    K0, K1, K2, K3
 
.Ldone:
   // Sum the accumulators for each pass, then store the sums to 'hash'
 
   // PASS0_SUMS is (0A 0B 0C 0D)
   // PASS1_SUMS is (1A 1B 1C 1D)
   // PASS2_SUMS is (2A 2B 2C 2D)
   // PASS3_SUMS is (3A 3B 3C 3D)
   // We need the horizontal sums:
   //     (0A + 0B + 0C + 0D,
   //    1A + 1B + 1C + 1D,
   //    2A + 2B + 2C + 2D,
   //    3A + 3B + 3C + 3D)
   //
 
   vpunpcklqdq    PASS1_SUMS, PASS0_SUMS, T0    // T0 = (0A 1A 0C 1C)
   vpunpckhqdq    PASS1_SUMS, PASS0_SUMS, T1    // T1 = (0B 1B 0D 1D)
   vpunpcklqdq    PASS3_SUMS, PASS2_SUMS, T2    // T2 = (2A 3A 2C 3C)
   vpunpckhqdq    PASS3_SUMS, PASS2_SUMS, T3    // T3 = (2B 3B 2D 3D)
 
   vinserti128    $0x1, T2_XMM, T0, T4        // T4 = (0A 1A 2A 3A)
   vinserti128    $0x1, T3_XMM, T1, T5        // T5 = (0B 1B 2B 3B)
   vperm2i128    $0x31, T2, T0, T0        // T0 = (0C 1C 2C 3C)
   vperm2i128    $0x31, T3, T1, T1        // T1 = (0D 1D 2D 3D)
 
   vpaddq        T5, T4, T4
   vpaddq        T1, T0, T0
   vpaddq        T4, T0, T0
   vmovdqu        T0, (HASH)
   RET
SYM_FUNC_END(nh_avx2)