ronnie
2022-10-14 1504bb53e29d3d46222c0b3ea994fc494b48e153
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
 
// +build ignore
 
// This program is run via "go generate" (via a directive in sort.go)
// to generate zfuncversion.go.
//
// It copies sort.go to zfuncversion.go, only retaining funcs which
// take a "data Interface" parameter, and renaming each to have a
// "_func" suffix and taking a "data lessSwap" instead. It then rewrites
// each internal function call to the appropriate _func variants.
 
package main
 
import (
   "bytes"
   "go/ast"
   "go/format"
   "go/parser"
   "go/token"
   "io/ioutil"
   "log"
   "regexp"
)
 
var fset = token.NewFileSet()
 
func main() {
   af, err := parser.ParseFile(fset, "sort.go", nil, 0)
   if err != nil {
       log.Fatal(err)
   }
   af.Doc = nil
   af.Imports = nil
   af.Comments = nil
 
   var newDecl []ast.Decl
   for _, d := range af.Decls {
       fd, ok := d.(*ast.FuncDecl)
       if !ok {
           continue
       }
       if fd.Recv != nil || fd.Name.IsExported() {
           continue
       }
       typ := fd.Type
       if len(typ.Params.List) < 1 {
           continue
       }
       arg0 := typ.Params.List[0]
       arg0Name := arg0.Names[0].Name
       arg0Type := arg0.Type.(*ast.Ident)
       if arg0Name != "data" || arg0Type.Name != "Interface" {
           continue
       }
       arg0Type.Name = "lessSwap"
 
       newDecl = append(newDecl, fd)
   }
   af.Decls = newDecl
   ast.Walk(visitFunc(rewriteCalls), af)
 
   var out bytes.Buffer
   if err := format.Node(&out, fset, af); err != nil {
       log.Fatalf("format.Node: %v", err)
   }
 
   // Get rid of blank lines after removal of comments.
   src := regexp.MustCompile(`\n{2,}`).ReplaceAll(out.Bytes(), []byte("\n"))
 
   // Add comments to each func, for the lost reader.
   // This is so much easier than adding comments via the AST
   // and trying to get position info correct.
   src = regexp.MustCompile(`(?m)^func (\w+)`).ReplaceAll(src, []byte("\n// Auto-generated variant of sort.go:$1\nfunc ${1}_func"))
 
   // Final gofmt.
   src, err = format.Source(src)
   if err != nil {
       log.Fatalf("format.Source: %v on\n%s", err, src)
   }
 
   out.Reset()
   out.WriteString(`// Code generated from sort.go using genzfunc.go; DO NOT EDIT.
 
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
 
`)
   out.Write(src)
 
   const target = "zfuncversion.go"
   if err := ioutil.WriteFile(target, out.Bytes(), 0644); err != nil {
       log.Fatal(err)
   }
}
 
type visitFunc func(ast.Node) ast.Visitor
 
func (f visitFunc) Visit(n ast.Node) ast.Visitor { return f(n) }
 
func rewriteCalls(n ast.Node) ast.Visitor {
   ce, ok := n.(*ast.CallExpr)
   if ok {
       rewriteCall(ce)
   }
   return visitFunc(rewriteCalls)
}
 
func rewriteCall(ce *ast.CallExpr) {
   ident, ok := ce.Fun.(*ast.Ident)
   if !ok {
       // e.g. skip SelectorExpr (data.Less(..) calls)
       return
   }
   // skip casts
   if ident.Name == "int" || ident.Name == "uint" {
       return
   }
   if len(ce.Args) < 1 {
       return
   }
   ident.Name += "_func"
}