1// Copyright 2018 Google Inc. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package paths
16
17import (
18	"context"
19	"encoding/gob"
20	"fmt"
21	"io/ioutil"
22	"net"
23	"os"
24	"path/filepath"
25	"runtime"
26	"sync"
27	"syscall"
28	"time"
29)
30
// LogProcess describes one process in the chain of parents recorded in a
// LogEntry.
type LogProcess struct {
	Pid     int    // process ID
	Command string // command text for the process; format is chosen by the sender
}
35
36type LogEntry struct {
37	Basename string
38	Args     []string
39	Parents  []LogProcess
40}
41
// timeoutDuration bounds SendLog's connect-and-write exchange and the time the
// listener allows each accepted connection to deliver its entry.
const timeoutDuration = 100 * time.Millisecond
43
// socketAddrFunc maps a desired unix socket path to an address that can be
// handed to dial or listen, together with a cleanup function to call once that
// address has been used. The implementations in this file always return a
// non-nil cleanup, even alongside an error, and callers defer it unconditionally.
type socketAddrFunc func(name string) (addr string, cleanup func(), err error)
45
// procFallback shortens name by opening its parent directory and addressing
// the final path element through that descriptor's /proc/self/fd entry (a
// Linux-only mechanism; getSocketAddr only uses it on linux).
//
// The returned cleanup closes the directory and must not run until the
// address has been used. On error the address is empty and cleanup is a no-op.
func procFallback(name string) (string, func(), error) {
	dir, openErr := os.Open(filepath.Dir(name))
	if openErr != nil {
		return "", func() {}, openErr
	}

	addr := fmt.Sprintf("/proc/self/fd/%d/%s", dir.Fd(), filepath.Base(name))
	closeDir := func() {
		dir.Close()
	}
	return addr, closeDir, nil
}
56
// tmpFallback shortens name by creating a fresh directory under /tmp that
// holds a symlink "d" pointing at the absolute path of name's parent
// directory. The returned address is <tmpdir>/d/<base of name>.
//
// If the temporary directory cannot be created, cleanup is a no-op. Once it
// exists, cleanup removes it (together with the symlink) even when a later
// step fails, so callers should always invoke it.
func tmpFallback(name string) (addr string, cleanup func(), err error) {
	tmp, err := ioutil.TempDir("/tmp", "log_sock")
	if err != nil {
		return "", func() {}, err
	}
	removeTmp := func() {
		os.RemoveAll(tmp)
	}

	absDir, err := filepath.Abs(filepath.Dir(name))
	if err != nil {
		return "", removeTmp, err
	}

	link := filepath.Join(tmp, "d")
	if err = os.Symlink(absDir, link); err != nil {
		return "", removeTmp, err
	}

	return filepath.Join(link, filepath.Base(name)), removeTmp, nil
}
83
84func getSocketAddr(name string) (string, func(), error) {
85	maxNameLen := len(syscall.RawSockaddrUnix{}.Path)
86
87	if len(name) < maxNameLen {
88		return name, func() {}, nil
89	}
90
91	if runtime.GOOS == "linux" {
92		addr, cleanup, err := procFallback(name)
93		if err == nil {
94			if len(addr) < maxNameLen {
95				return addr, cleanup, nil
96			}
97		}
98		cleanup()
99	}
100
101	addr, cleanup, err := tmpFallback(name)
102	if err == nil {
103		if len(addr) < maxNameLen {
104			return addr, cleanup, nil
105		}
106	}
107	cleanup()
108
109	return name, func() {}, fmt.Errorf("Path to socket is still over size limit, fallbacks failed.")
110}
111
112func dial(name string, lookup socketAddrFunc, timeout time.Duration) (net.Conn, error) {
113	socket, cleanup, err := lookup(name)
114	defer cleanup()
115	if err != nil {
116		return nil, err
117	}
118
119	dialer := &net.Dialer{
120		Timeout: timeout,
121	}
122	return dialer.Dial("unix", socket)
123}
124
125func listen(name string, lookup socketAddrFunc) (net.Listener, error) {
126	socket, cleanup, err := lookup(name)
127	defer cleanup()
128	if err != nil {
129		return nil, err
130	}
131
132	return net.Listen("unix", socket)
133}
134
135func SendLog(logSocket string, entry *LogEntry, done chan interface{}) {
136	sendLog(logSocket, getSocketAddr, timeoutDuration, entry, done)
137}
138
139func sendLog(logSocket string, lookup socketAddrFunc, timeout time.Duration, entry *LogEntry, done chan interface{}) {
140	defer close(done)
141
142	conn, err := dial(logSocket, lookup, timeout)
143	if err != nil {
144		return
145	}
146	defer conn.Close()
147
148	if timeout != 0 {
149		conn.SetDeadline(time.Now().Add(timeout))
150	}
151
152	enc := gob.NewEncoder(conn)
153	enc.Encode(entry)
154}
155
156func LogListener(ctx context.Context, logSocket string) (chan *LogEntry, error) {
157	return logListener(ctx, logSocket, getSocketAddr)
158}
159
160func logListener(ctx context.Context, logSocket string, lookup socketAddrFunc) (chan *LogEntry, error) {
161	ret := make(chan *LogEntry, 5)
162
163	if err := os.Remove(logSocket); err != nil && !os.IsNotExist(err) {
164		return nil, err
165	}
166
167	ln, err := listen(logSocket, lookup)
168	if err != nil {
169		return nil, err
170	}
171
172	go func() {
173		for {
174			select {
175			case <-ctx.Done():
176				ln.Close()
177			}
178		}
179	}()
180
181	go func() {
182		var wg sync.WaitGroup
183		defer func() {
184			wg.Wait()
185			close(ret)
186		}()
187
188		for {
189			conn, err := ln.Accept()
190			if err != nil {
191				ln.Close()
192				break
193			}
194			conn.SetDeadline(time.Now().Add(timeoutDuration))
195			wg.Add(1)
196
197			go func() {
198				defer wg.Done()
199				defer conn.Close()
200
201				dec := gob.NewDecoder(conn)
202				entry := &LogEntry{}
203				if err := dec.Decode(entry); err != nil {
204					return
205				}
206				ret <- entry
207			}()
208		}
209	}()
210	return ret, nil
211}
212