liyujie
2025-08-28 d9927380ed7c8366f762049be9f3fee225860833
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
 
// Calibration used to determine thresholds for using
// different algorithms.  Ideally, this would be converted
// to go generate to create thresholds.go
 
// This file prints execution times for the Mul benchmark
// given different Karatsuba thresholds. The result may be
// used to manually fine-tune the threshold constant. The
// results are somewhat fragile; use repeated runs to get
// a clear picture.
 
// Calculates lower and upper thresholds for when basicSqr
// is faster than standard multiplication.
 
// Usage: go test -run=TestCalibrate -v -calibrate
 
package big
 
import (
   "flag"
   "fmt"
   "testing"
   "time"
)
 
var calibrate = flag.Bool("calibrate", false, "run calibration test")
 
const (
   sqrModeMul       = "mul(x, x)"
   sqrModeBasic     = "basicSqr(x)"
   sqrModeKaratsuba = "karatsubaSqr(x)"
)
 
func TestCalibrate(t *testing.T) {
   if !*calibrate {
       return
   }
 
   computeKaratsubaThresholds()
 
   // compute basicSqrThreshold where overhead becomes negligible
   minSqr := computeSqrThreshold(10, 30, 1, 3, sqrModeMul, sqrModeBasic)
   // compute karatsubaSqrThreshold where karatsuba is faster
   maxSqr := computeSqrThreshold(200, 500, 10, 3, sqrModeBasic, sqrModeKaratsuba)
   if minSqr != 0 {
       fmt.Printf("found basicSqrThreshold = %d\n", minSqr)
   } else {
       fmt.Println("no basicSqrThreshold found")
   }
   if maxSqr != 0 {
       fmt.Printf("found karatsubaSqrThreshold = %d\n", maxSqr)
   } else {
       fmt.Println("no karatsubaSqrThreshold found")
   }
}
 
func karatsubaLoad(b *testing.B) {
   BenchmarkMul(b)
}
 
// measureKaratsuba returns the time to run a Karatsuba-relevant benchmark
// given Karatsuba threshold th.
func measureKaratsuba(th int) time.Duration {
   th, karatsubaThreshold = karatsubaThreshold, th
   res := testing.Benchmark(karatsubaLoad)
   karatsubaThreshold = th
   return time.Duration(res.NsPerOp())
}
 
func computeKaratsubaThresholds() {
   fmt.Printf("Multiplication times for varying Karatsuba thresholds\n")
   fmt.Printf("(run repeatedly for good results)\n")
 
   // determine Tk, the work load execution time using basic multiplication
   Tb := measureKaratsuba(1e9) // th == 1e9 => Karatsuba multiplication disabled
   fmt.Printf("Tb = %10s\n", Tb)
 
   // thresholds
   th := 4
   th1 := -1
   th2 := -1
 
   var deltaOld time.Duration
   for count := -1; count != 0 && th < 128; count-- {
       // determine Tk, the work load execution time using Karatsuba multiplication
       Tk := measureKaratsuba(th)
 
       // improvement over Tb
       delta := (Tb - Tk) * 100 / Tb
 
       fmt.Printf("th = %3d  Tk = %10s  %4d%%", th, Tk, delta)
 
       // determine break-even point
       if Tk < Tb && th1 < 0 {
           th1 = th
           fmt.Print("  break-even point")
       }
 
       // determine diminishing return
       if 0 < delta && delta < deltaOld && th2 < 0 {
           th2 = th
           fmt.Print("  diminishing return")
       }
       deltaOld = delta
 
       fmt.Println()
 
       // trigger counter
       if th1 >= 0 && th2 >= 0 && count < 0 {
           count = 10 // this many extra measurements after we got both thresholds
       }
 
       th++
   }
}
 
func measureSqr(words, nruns int, mode string) time.Duration {
   // more runs for better statistics
   initBasicSqr, initKaratsubaSqr := basicSqrThreshold, karatsubaSqrThreshold
 
   switch mode {
   case sqrModeMul:
       basicSqrThreshold = words + 1
   case sqrModeBasic:
       basicSqrThreshold, karatsubaSqrThreshold = words-1, words+1
   case sqrModeKaratsuba:
       karatsubaSqrThreshold = words - 1
   }
 
   var testval int64
   for i := 0; i < nruns; i++ {
       res := testing.Benchmark(func(b *testing.B) { benchmarkNatSqr(b, words) })
       testval += res.NsPerOp()
   }
   testval /= int64(nruns)
 
   basicSqrThreshold, karatsubaSqrThreshold = initBasicSqr, initKaratsubaSqr
 
   return time.Duration(testval)
}
 
func computeSqrThreshold(from, to, step, nruns int, lower, upper string) int {
   fmt.Printf("Calibrating threshold between %s and %s\n", lower, upper)
   fmt.Printf("Looking for a timing difference for x between %d - %d words by %d step\n", from, to, step)
   var initPos bool
   var threshold int
   for i := from; i <= to; i += step {
       baseline := measureSqr(i, nruns, lower)
       testval := measureSqr(i, nruns, upper)
       pos := baseline > testval
       delta := baseline - testval
       percent := delta * 100 / baseline
       fmt.Printf("words = %3d deltaT = %10s (%4d%%) is %s better: %v", i, delta, percent, upper, pos)
       if i == from {
           initPos = pos
       }
       if threshold == 0 && pos != initPos {
           threshold = i
           fmt.Printf("  threshold  found")
       }
       fmt.Println()
 
   }
   if threshold != 0 {
       fmt.Printf("Found threshold = %d between %d - %d\n", threshold, from, to)
   } else {
       fmt.Printf("Found NO threshold between %d - %d\n", from, to)
   }
   return threshold
}