1// Mostly copied from Go's src/cmd/gofmt:
2// Copyright 2009 The Go Authors. All rights reserved.
3// Use of this source code is governed by a BSD-style
4// license that can be found in the LICENSE file.
5
6package main
7
8import (
9	"bytes"
10	"flag"
11	"fmt"
12	"io"
13	"io/ioutil"
14	"os"
15	"os/exec"
16	"path/filepath"
17	"strings"
18	"unicode"
19
20	"github.com/google/blueprint/parser"
21)
22
23var (
24	// main operation modes
25	list            = flag.Bool("l", false, "list files that would be modified by bpmodify")
26	write           = flag.Bool("w", false, "write result to (source) file instead of stdout")
27	doDiff          = flag.Bool("d", false, "display diffs instead of rewriting files")
28	sortLists       = flag.Bool("s", false, "sort touched lists, even if they were unsorted")
29	parameter       = flag.String("parameter", "deps", "name of parameter to modify on each module")
30	targetedModules = new(identSet)
31	addIdents       = new(identSet)
32	removeIdents    = new(identSet)
33)
34
35func init() {
36	flag.Var(targetedModules, "m", "comma or whitespace separated list of modules on which to operate")
37	flag.Var(addIdents, "a", "comma or whitespace separated list of identifiers to add")
38	flag.Var(removeIdents, "r", "comma or whitespace separated list of identifiers to remove")
39	flag.Usage = usage
40}
41
// exitCode is the status main hands to os.Exit; it stays 0 unless report runs.
var exitCode int

// report writes err to standard error and marks the run as failed by setting
// the exit status to 2.
func report(err error) {
	exitCode = 2
	fmt.Fprintln(os.Stderr, err)
}
50
// usage prints a one-line synopsis and then the defaults of every registered
// flag, all to the flag package's configured output.
func usage() {
	w := flag.CommandLine.Output()
	fmt.Fprintf(w, "Usage: %s [flags] [path ...]\n", os.Args[0])
	flag.PrintDefaults()
}
55
56// If in == nil, the source is the contents of the file with the given filename.
57func processFile(filename string, in io.Reader, out io.Writer) error {
58	if in == nil {
59		f, err := os.Open(filename)
60		if err != nil {
61			return err
62		}
63		defer f.Close()
64		in = f
65	}
66
67	src, err := ioutil.ReadAll(in)
68	if err != nil {
69		return err
70	}
71
72	r := bytes.NewBuffer(src)
73
74	file, errs := parser.Parse(filename, r, parser.NewScope(nil))
75	if len(errs) > 0 {
76		for _, err := range errs {
77			fmt.Fprintln(os.Stderr, err)
78		}
79		return fmt.Errorf("%d parsing errors", len(errs))
80	}
81
82	modified, errs := findModules(file)
83	if len(errs) > 0 {
84		for _, err := range errs {
85			fmt.Fprintln(os.Stderr, err)
86		}
87		fmt.Fprintln(os.Stderr, "continuing...")
88	}
89
90	if modified {
91		res, err := parser.Print(file)
92		if err != nil {
93			return err
94		}
95
96		if *list {
97			fmt.Fprintln(out, filename)
98		}
99		if *write {
100			err = ioutil.WriteFile(filename, res, 0644)
101			if err != nil {
102				return err
103			}
104		}
105		if *doDiff {
106			data, err := diff(src, res)
107			if err != nil {
108				return fmt.Errorf("computing diff: %s", err)
109			}
110			fmt.Printf("diff %s bpfmt/%s\n", filename, filename)
111			out.Write(data)
112		}
113
114		if !*list && !*write && !*doDiff {
115			_, err = out.Write(res)
116		}
117	}
118
119	return err
120}
121
122func findModules(file *parser.File) (modified bool, errs []error) {
123
124	for _, def := range file.Defs {
125		if module, ok := def.(*parser.Module); ok {
126			for _, prop := range module.Properties {
127				if prop.Name == "name" && prop.Value.Type() == parser.StringType {
128					if targetedModule(prop.Value.Eval().(*parser.String).Value) {
129						m, newErrs := processModule(module, prop.Name, file)
130						errs = append(errs, newErrs...)
131						modified = modified || m
132					}
133				}
134			}
135		}
136	}
137
138	return modified, errs
139}
140
141func processModule(module *parser.Module, moduleName string,
142	file *parser.File) (modified bool, errs []error) {
143
144	for _, prop := range module.Properties {
145		if prop.Name == *parameter {
146			modified, errs = processParameter(prop.Value, *parameter, moduleName, file)
147			return
148		}
149	}
150
151	prop := parser.Property{Name: *parameter, Value: &parser.List{}}
152	modified, errs = processParameter(prop.Value, *parameter, moduleName, file)
153
154	if modified {
155		module.Properties = append(module.Properties, &prop)
156	}
157
158	return modified, errs
159}
160
161func processParameter(value parser.Expression, paramName, moduleName string,
162	file *parser.File) (modified bool, errs []error) {
163	if _, ok := value.(*parser.Variable); ok {
164		return false, []error{fmt.Errorf("parameter %s in module %s is a variable, unsupported",
165			paramName, moduleName)}
166	}
167
168	if _, ok := value.(*parser.Operator); ok {
169		return false, []error{fmt.Errorf("parameter %s in module %s is an expression, unsupported",
170			paramName, moduleName)}
171	}
172
173	list, ok := value.(*parser.List)
174	if !ok {
175		return false, []error{fmt.Errorf("expected parameter %s in module %s to be list, found %s",
176			paramName, moduleName, value.Type().String())}
177	}
178
179	wasSorted := parser.ListIsSorted(list)
180
181	for _, a := range addIdents.idents {
182		m := parser.AddStringToList(list, a)
183		modified = modified || m
184	}
185
186	for _, r := range removeIdents.idents {
187		m := parser.RemoveStringFromList(list, r)
188		modified = modified || m
189	}
190
191	if (wasSorted || *sortLists) && modified {
192		parser.SortList(file, list)
193	}
194
195	return modified, nil
196}
197
198func targetedModule(name string) bool {
199	if targetedModules.all {
200		return true
201	}
202	for _, m := range targetedModules.idents {
203		if m == name {
204			return true
205		}
206	}
207
208	return false
209}
210
211func visitFile(path string, f os.FileInfo, err error) error {
212	if err == nil && f.Name() == "Blueprints" {
213		err = processFile(path, nil, os.Stdout)
214	}
215	if err != nil {
216		report(err)
217	}
218	return nil
219}
220
221func walkDir(path string) {
222	filepath.Walk(path, visitFile)
223}
224
225func main() {
226	defer func() {
227		if err := recover(); err != nil {
228			report(fmt.Errorf("error: %s", err))
229		}
230		os.Exit(exitCode)
231	}()
232
233	flag.Parse()
234
235	if flag.NArg() == 0 {
236		if *write {
237			report(fmt.Errorf("error: cannot use -w with standard input"))
238			return
239		}
240		if err := processFile("<standard input>", os.Stdin, os.Stdout); err != nil {
241			report(err)
242		}
243		return
244	}
245
246	if len(targetedModules.idents) == 0 {
247		report(fmt.Errorf("-m parameter is required"))
248		return
249	}
250
251	if len(addIdents.idents) == 0 && len(removeIdents.idents) == 0 {
252		report(fmt.Errorf("-a or -r parameter is required"))
253		return
254	}
255
256	for i := 0; i < flag.NArg(); i++ {
257		path := flag.Arg(i)
258		switch dir, err := os.Stat(path); {
259		case err != nil:
260			report(err)
261		case dir.IsDir():
262			walkDir(path)
263		default:
264			if err := processFile(path, nil, os.Stdout); err != nil {
265				report(err)
266			}
267		}
268	}
269}
270
// diff writes b1 and b2 to temporary files and returns the output of
// `diff -uw` comparing them; whitespace-only changes are ignored.
//
// diff exits non-zero when its inputs differ, so that status is not treated as
// an error whenever it produced output. Errors from creating or writing the
// temp files are returned instead of producing a misleading diff. Both temp
// files are closed and removed before diff returns.
func diff(b1, b2 []byte) ([]byte, error) {
	f1, err := ioutil.TempFile("", "bpfmt")
	if err != nil {
		return nil, err
	}
	defer os.Remove(f1.Name())
	defer f1.Close()

	f2, err := ioutil.TempFile("", "bpfmt")
	if err != nil {
		return nil, err
	}
	defer os.Remove(f2.Name())
	defer f2.Close()

	if _, err := f1.Write(b1); err != nil {
		return nil, err
	}
	if _, err := f2.Write(b2); err != nil {
		return nil, err
	}

	data, err := exec.Command("diff", "-uw", f1.Name(), f2.Name()).CombinedOutput()
	if len(data) > 0 {
		// diff exits with a non-zero status when the files don't match.
		// Ignore that failure as long as we get output.
		err = nil
	}
	return data, err
}
298
// identSet is a flag.Value (and flag.Getter) holding identifiers parsed from a
// comma- and/or whitespace-separated string. The single value "*" sets all,
// meaning "every identifier".
type identSet struct {
	idents []string
	all    bool
}

// String returns the identifiers joined with commas.
func (m *identSet) String() string {
	return strings.Join(m.idents, ",")
}

// Set replaces the current identifiers with those parsed from s, so the last
// occurrence of a repeated flag wins. all is recomputed on every call, so a
// later explicit list clears an earlier "*" instead of leaving it stuck on.
func (m *identSet) Set(s string) error {
	m.idents = strings.FieldsFunc(s, func(c rune) bool {
		return unicode.IsSpace(c) || c == ','
	})
	m.all = len(m.idents) == 1 && m.idents[0] == "*"
	return nil
}

// Get returns the parsed identifiers as a []string.
func (m *identSet) Get() interface{} {
	return m.idents
}
321