1// Copyright 2018 Google Inc. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package symbol_inject
16
17import (
18	"bytes"
19	"encoding/binary"
20	"fmt"
21	"io"
22	"math"
23)
24
// maxUint64 is the largest uint64 value. findSymbol returns it for both the
// offset and the size whenever it also returns an error.
var maxUint64 = uint64(math.MaxUint64)
26
// cantParseError wraps an error to mark it as a failure to parse the input in
// one particular binary format. OpenFile and DumpSymbols detect it with a type
// assertion and, when they see it, fall through to the next format.
type cantParseError struct {
	error
}
30
31func OpenFile(r io.ReaderAt) (*File, error) {
32	file, err := elfSymbolsFromFile(r)
33	if elfError, ok := err.(cantParseError); ok {
34		// Try as a mach-o file
35		file, err = machoSymbolsFromFile(r)
36		if _, ok := err.(cantParseError); ok {
37			// Try as a windows PE file
38			file, err = peSymbolsFromFile(r)
39			if _, ok := err.(cantParseError); ok {
40				// Can't parse as elf, macho, or PE, return the elf error
41				return nil, elfError
42			}
43		}
44	}
45	if err != nil {
46		return nil, err
47	}
48
49	file.r = r
50
51	return file, err
52}
53
54func InjectStringSymbol(file *File, w io.Writer, symbol, value, from string) error {
55	offset, size, err := findSymbol(file, symbol)
56	if err != nil {
57		return err
58	}
59
60	if uint64(len(value))+1 > size {
61		return fmt.Errorf("value length %d overflows symbol size %d", len(value), size)
62	}
63
64	if from != "" {
65		// Read the exsting symbol contents and verify they match the expected value
66		expected := make([]byte, size)
67		existing := make([]byte, size)
68		copy(expected, from)
69		_, err := file.r.ReadAt(existing, int64(offset))
70		if err != nil {
71			return err
72		}
73		if bytes.Compare(existing, expected) != 0 {
74			return fmt.Errorf("existing symbol contents %q did not match expected value %q",
75				string(existing), string(expected))
76		}
77	}
78
79	buf := make([]byte, size)
80	copy(buf, value)
81
82	return copyAndInject(file.r, w, offset, buf)
83}
84
85func InjectUint64Symbol(file *File, w io.Writer, symbol string, value uint64) error {
86	offset, size, err := findSymbol(file, symbol)
87	if err != nil {
88		return err
89	}
90
91	if size != 8 {
92		return fmt.Errorf("symbol %q is not a uint64, it is %d bytes long", symbol, size)
93	}
94
95	buf := make([]byte, 8)
96	binary.LittleEndian.PutUint64(buf, value)
97
98	return copyAndInject(file.r, w, offset, buf)
99}
100
// copyAndInject streams the contents of r to w, substituting buf for the
// len(buf) bytes of r that start at offset.
//
// The output is produced in three sequential writes: the first offset bytes of
// r, then buf, then everything in r after offset+len(buf). It returns
// io.ErrUnexpectedEOF if r ends before offset, and otherwise the first error
// encountered while reading or writing.
func copyAndInject(r io.ReaderAt, w io.Writer, offset uint64, buf []byte) (err error) {
	start := int64(offset)

	// Copy the first bytes up to the symbol offset. io.Copy treats EOF as
	// success, so it would not notice a source that is shorter than offset and
	// buf would then be written at the wrong position. io.CopyN reports io.EOF
	// when fewer than start bytes are available.
	_, err = io.CopyN(w, io.NewSectionReader(r, 0, start), start)

	// Write the injected value in the output file
	if err == nil {
		_, err = w.Write(buf)
	}

	// Write the remainder of the file. The section length is chosen so that
	// pos plus the length never exceeds the int64 range.
	pos := start + int64(len(buf))
	if err == nil {
		_, err = io.Copy(w, io.NewSectionReader(r, pos, math.MaxInt64-pos))
	}

	// Running out of input part-way through means the source is truncated.
	if err == io.EOF {
		err = io.ErrUnexpectedEOF
	}

	return err
}
122
123func findSymbol(file *File, symbolName string) (uint64, uint64, error) {
124	for i, symbol := range file.Symbols {
125		if symbol.Name == symbolName {
126			// Find the next symbol (n the same section with a higher address
127			var n int
128			for n = i; n < len(file.Symbols); n++ {
129				if file.Symbols[n].Section != symbol.Section {
130					n = len(file.Symbols)
131					break
132				}
133				if file.Symbols[n].Addr > symbol.Addr {
134					break
135				}
136			}
137
138			size := symbol.Size
139			if size == 0 {
140				var end uint64
141				if n < len(file.Symbols) {
142					end = file.Symbols[n].Addr
143				} else {
144					end = symbol.Section.Size
145				}
146
147				if end <= symbol.Addr || end > symbol.Addr+4096 {
148					return maxUint64, maxUint64, fmt.Errorf("symbol end address does not seem valid, %x:%x", symbol.Addr, end)
149				}
150
151				size = end - symbol.Addr
152			}
153
154			offset := symbol.Section.Offset + symbol.Addr
155
156			return uint64(offset), uint64(size), nil
157		}
158	}
159
160	return maxUint64, maxUint64, fmt.Errorf("symbol not found")
161}
162
// File holds the symbols and sections of a binary together with the reader the
// binary is read from.
type File struct {
	// r is the underlying input. OpenFile sets it, and the Inject* functions
	// read the original contents from it.
	r io.ReaderAt
	// Symbols is scanned in order by findSymbol, which looks at the entries
	// following a symbol to find the next one in the same section.
	// NOTE(review): presumably kept sorted by section and address — confirm
	// against the format-specific parsers.
	Symbols []*Symbol
	// Sections lists the sections of the binary.
	Sections []*Section
}
168
// Symbol describes one named symbol of a File. findSymbol computes the
// symbol's file offset as Section.Offset + Addr.
type Symbol struct {
	Name string
	Addr uint64 // Address of the symbol inside the section.
	// Size of the symbol, if known. When it is 0, findSymbol infers the size
	// from the next symbol in the same section or from the section's Size.
	Size    uint64
	Section *Section // Section that contains the symbol.
}
175
// Section describes one section of a File.
type Section struct {
	Name   string
	Addr   uint64 // Virtual address of the start of the section.
	Offset uint64 // Offset into the file of the start of the section.
	// Size is used by findSymbol as the end address (relative to the section)
	// of the last symbol in the section when that symbol records no size.
	Size uint64
}
182
183func DumpSymbols(r io.ReaderAt) error {
184	err := dumpElfSymbols(r)
185	if elfError, ok := err.(cantParseError); ok {
186		// Try as a mach-o file
187		err = dumpMachoSymbols(r)
188		if _, ok := err.(cantParseError); ok {
189			// Try as a windows PE file
190			err = dumpPESymbols(r)
191			if _, ok := err.(cantParseError); ok {
192				// Can't parse as elf, macho, or PE, return the elf error
193				return elfError
194			}
195		}
196	}
197	return err
198}
199