1 /*
2  * Copyright (C) 2009 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "edify/expr.h"
18 
19 #include <stdarg.h>
20 #include <stdio.h>
21 #include <stdlib.h>
22 #include <string.h>
23 #include <unistd.h>
24 
25 #include <memory>
26 #include <string>
27 #include <unordered_map>
28 #include <vector>
29 
30 #include <android-base/parseint.h>
31 #include <android-base/stringprintf.h>
32 #include <android-base/strings.h>
33 
34 #include "otautil/error_code.h"
35 
// Functions should:
//
//    - return a heap-allocated Value* (e.g. via StringValue()), which the caller owns
//    - if evaluating any argument fails (Evaluate() returns false or EvaluateValue() returns
//      nullptr), return nullptr.
40 
// Edify truthiness: any non-empty string counts as true; only "" is false.
static bool BooleanString(const std::string& s) {
    return s.size() != 0;
}
44 
Evaluate(State * state,const std::unique_ptr<Expr> & expr,std::string * result)45 bool Evaluate(State* state, const std::unique_ptr<Expr>& expr, std::string* result) {
46     if (result == nullptr) {
47         return false;
48     }
49 
50     std::unique_ptr<Value> v(expr->fn(expr->name.c_str(), state, expr->argv));
51     if (!v) {
52         return false;
53     }
54     if (v->type != Value::Type::STRING) {
55       ErrorAbort(state, kArgsParsingFailure, "expecting string, got value type %d", v->type);
56       return false;
57     }
58 
59     *result = v->data;
60     return true;
61 }
62 
EvaluateValue(State * state,const std::unique_ptr<Expr> & expr)63 Value* EvaluateValue(State* state, const std::unique_ptr<Expr>& expr) {
64     return expr->fn(expr->name.c_str(), state, expr->argv);
65 }
66 
StringValue(const char * str)67 Value* StringValue(const char* str) {
68     if (str == nullptr) {
69         return nullptr;
70     }
71     return new Value(Value::Type::STRING, str);
72 }
73 
StringValue(const std::string & str)74 Value* StringValue(const std::string& str) {
75     return StringValue(str.c_str());
76 }
77 
ConcatFn(const char * name,State * state,const std::vector<std::unique_ptr<Expr>> & argv)78 Value* ConcatFn(const char* name, State* state, const std::vector<std::unique_ptr<Expr>>& argv) {
79     if (argv.empty()) {
80         return StringValue("");
81     }
82     std::string result;
83     for (size_t i = 0; i < argv.size(); ++i) {
84         std::string str;
85         if (!Evaluate(state, argv[i], &str)) {
86             return nullptr;
87         }
88         result += str;
89     }
90 
91     return StringValue(result);
92 }
93 
IfElseFn(const char * name,State * state,const std::vector<std::unique_ptr<Expr>> & argv)94 Value* IfElseFn(const char* name, State* state, const std::vector<std::unique_ptr<Expr>>& argv) {
95     if (argv.size() != 2 && argv.size() != 3) {
96         state->errmsg = "ifelse expects 2 or 3 arguments";
97         return nullptr;
98     }
99 
100     std::string cond;
101     if (!Evaluate(state, argv[0], &cond)) {
102         return nullptr;
103     }
104 
105     if (!cond.empty()) {
106         return EvaluateValue(state, argv[1]);
107     } else if (argv.size() == 3) {
108         return EvaluateValue(state, argv[2]);
109     }
110 
111     return StringValue("");
112 }
113 
AbortFn(const char * name,State * state,const std::vector<std::unique_ptr<Expr>> & argv)114 Value* AbortFn(const char* name, State* state, const std::vector<std::unique_ptr<Expr>>& argv) {
115     std::string msg;
116     if (!argv.empty() && Evaluate(state, argv[0], &msg)) {
117       state->errmsg += msg;
118     } else {
119       state->errmsg += "called abort()";
120     }
121     return nullptr;
122 }
123 
AssertFn(const char * name,State * state,const std::vector<std::unique_ptr<Expr>> & argv)124 Value* AssertFn(const char* name, State* state, const std::vector<std::unique_ptr<Expr>>& argv) {
125     for (size_t i = 0; i < argv.size(); ++i) {
126         std::string result;
127         if (!Evaluate(state, argv[i], &result)) {
128             return nullptr;
129         }
130         if (result.empty()) {
131             int len = argv[i]->end - argv[i]->start;
132             state->errmsg = "assert failed: " + state->script.substr(argv[i]->start, len);
133             return nullptr;
134         }
135     }
136     return StringValue("");
137 }
138 
SleepFn(const char * name,State * state,const std::vector<std::unique_ptr<Expr>> & argv)139 Value* SleepFn(const char* name, State* state, const std::vector<std::unique_ptr<Expr>>& argv) {
140     std::string val;
141     if (!Evaluate(state, argv[0], &val)) {
142         return nullptr;
143     }
144 
145     int v;
146     if (!android::base::ParseInt(val.c_str(), &v, 0)) {
147         return nullptr;
148     }
149     sleep(v);
150 
151     return StringValue(val);
152 }
153 
StdoutFn(const char * name,State * state,const std::vector<std::unique_ptr<Expr>> & argv)154 Value* StdoutFn(const char* name, State* state, const std::vector<std::unique_ptr<Expr>>& argv) {
155     for (size_t i = 0; i < argv.size(); ++i) {
156         std::string v;
157         if (!Evaluate(state, argv[i], &v)) {
158             return nullptr;
159         }
160         fputs(v.c_str(), stdout);
161     }
162     return StringValue("");
163 }
164 
LogicalAndFn(const char * name,State * state,const std::vector<std::unique_ptr<Expr>> & argv)165 Value* LogicalAndFn(const char* name, State* state,
166                     const std::vector<std::unique_ptr<Expr>>& argv) {
167     std::string left;
168     if (!Evaluate(state, argv[0], &left)) {
169         return nullptr;
170     }
171     if (BooleanString(left)) {
172         return EvaluateValue(state, argv[1]);
173     } else {
174         return StringValue("");
175     }
176 }
177 
LogicalOrFn(const char * name,State * state,const std::vector<std::unique_ptr<Expr>> & argv)178 Value* LogicalOrFn(const char* name, State* state,
179                    const std::vector<std::unique_ptr<Expr>>& argv) {
180     std::string left;
181     if (!Evaluate(state, argv[0], &left)) {
182         return nullptr;
183     }
184     if (!BooleanString(left)) {
185         return EvaluateValue(state, argv[1]);
186     } else {
187         return StringValue(left);
188     }
189 }
190 
LogicalNotFn(const char * name,State * state,const std::vector<std::unique_ptr<Expr>> & argv)191 Value* LogicalNotFn(const char* name, State* state,
192                     const std::vector<std::unique_ptr<Expr>>& argv) {
193     std::string val;
194     if (!Evaluate(state, argv[0], &val)) {
195         return nullptr;
196     }
197 
198     return StringValue(BooleanString(val) ? "" : "t");
199 }
200 
SubstringFn(const char * name,State * state,const std::vector<std::unique_ptr<Expr>> & argv)201 Value* SubstringFn(const char* name, State* state,
202                    const std::vector<std::unique_ptr<Expr>>& argv) {
203     std::string needle;
204     if (!Evaluate(state, argv[0], &needle)) {
205         return nullptr;
206     }
207 
208     std::string haystack;
209     if (!Evaluate(state, argv[1], &haystack)) {
210         return nullptr;
211     }
212 
213     std::string result = (haystack.find(needle) != std::string::npos) ? "t" : "";
214     return StringValue(result);
215 }
216 
EqualityFn(const char * name,State * state,const std::vector<std::unique_ptr<Expr>> & argv)217 Value* EqualityFn(const char* name, State* state, const std::vector<std::unique_ptr<Expr>>& argv) {
218     std::string left;
219     if (!Evaluate(state, argv[0], &left)) {
220         return nullptr;
221     }
222     std::string right;
223     if (!Evaluate(state, argv[1], &right)) {
224         return nullptr;
225     }
226 
227     const char* result = (left == right) ? "t" : "";
228     return StringValue(result);
229 }
230 
InequalityFn(const char * name,State * state,const std::vector<std::unique_ptr<Expr>> & argv)231 Value* InequalityFn(const char* name, State* state,
232                     const std::vector<std::unique_ptr<Expr>>& argv) {
233     std::string left;
234     if (!Evaluate(state, argv[0], &left)) {
235         return nullptr;
236     }
237     std::string right;
238     if (!Evaluate(state, argv[1], &right)) {
239         return nullptr;
240     }
241 
242     const char* result = (left != right) ? "t" : "";
243     return StringValue(result);
244 }
245 
SequenceFn(const char * name,State * state,const std::vector<std::unique_ptr<Expr>> & argv)246 Value* SequenceFn(const char* name, State* state, const std::vector<std::unique_ptr<Expr>>& argv) {
247     std::unique_ptr<Value> left(EvaluateValue(state, argv[0]));
248     if (!left) {
249         return nullptr;
250     }
251     return EvaluateValue(state, argv[1]);
252 }
253 
LessThanIntFn(const char * name,State * state,const std::vector<std::unique_ptr<Expr>> & argv)254 Value* LessThanIntFn(const char* name, State* state,
255                      const std::vector<std::unique_ptr<Expr>>& argv) {
256     if (argv.size() != 2) {
257         state->errmsg = "less_than_int expects 2 arguments";
258         return nullptr;
259     }
260 
261     std::vector<std::string> args;
262     if (!ReadArgs(state, argv, &args)) {
263         return nullptr;
264     }
265 
266     // Parse up to at least long long or 64-bit integers.
267     int64_t l_int;
268     if (!android::base::ParseInt(args[0].c_str(), &l_int)) {
269         state->errmsg = "failed to parse int in " + args[0];
270         return nullptr;
271     }
272 
273     int64_t r_int;
274     if (!android::base::ParseInt(args[1].c_str(), &r_int)) {
275         state->errmsg = "failed to parse int in " + args[1];
276         return nullptr;
277     }
278 
279     return StringValue(l_int < r_int ? "t" : "");
280 }
281 
GreaterThanIntFn(const char * name,State * state,const std::vector<std::unique_ptr<Expr>> & argv)282 Value* GreaterThanIntFn(const char* name, State* state,
283                         const std::vector<std::unique_ptr<Expr>>& argv) {
284     if (argv.size() != 2) {
285         state->errmsg = "greater_than_int expects 2 arguments";
286         return nullptr;
287     }
288 
289     std::vector<std::string> args;
290     if (!ReadArgs(state, argv, &args)) {
291         return nullptr;
292     }
293 
294     // Parse up to at least long long or 64-bit integers.
295     int64_t l_int;
296     if (!android::base::ParseInt(args[0].c_str(), &l_int)) {
297         state->errmsg = "failed to parse int in " + args[0];
298         return nullptr;
299     }
300 
301     int64_t r_int;
302     if (!android::base::ParseInt(args[1].c_str(), &r_int)) {
303         state->errmsg = "failed to parse int in " + args[1];
304         return nullptr;
305     }
306 
307     return StringValue(l_int > r_int ? "t" : "");
308 }
309 
Literal(const char * name,State * state,const std::vector<std::unique_ptr<Expr>> & argv)310 Value* Literal(const char* name, State* state, const std::vector<std::unique_ptr<Expr>>& argv) {
311     return StringValue(name);
312 }
313 
314 // -----------------------------------------------------------------
315 //   the function table
316 // -----------------------------------------------------------------
317 
318 static std::unordered_map<std::string, Function> fn_table;
319 
RegisterFunction(const std::string & name,Function fn)320 void RegisterFunction(const std::string& name, Function fn) {
321     fn_table[name] = fn;
322 }
323 
FindFunction(const std::string & name)324 Function FindFunction(const std::string& name) {
325     if (fn_table.find(name) == fn_table.end()) {
326         return nullptr;
327     } else {
328         return fn_table[name];
329     }
330 }
331 
RegisterBuiltins()332 void RegisterBuiltins() {
333     RegisterFunction("ifelse", IfElseFn);
334     RegisterFunction("abort", AbortFn);
335     RegisterFunction("assert", AssertFn);
336     RegisterFunction("concat", ConcatFn);
337     RegisterFunction("is_substring", SubstringFn);
338     RegisterFunction("stdout", StdoutFn);
339     RegisterFunction("sleep", SleepFn);
340 
341     RegisterFunction("less_than_int", LessThanIntFn);
342     RegisterFunction("greater_than_int", GreaterThanIntFn);
343 }
344 
345 
346 // -----------------------------------------------------------------
347 //   convenience methods for functions
348 // -----------------------------------------------------------------
349 
350 // Evaluate the expressions in argv, and put the results of strings in args. If any expression
351 // evaluates to nullptr, return false. Return true on success.
ReadArgs(State * state,const std::vector<std::unique_ptr<Expr>> & argv,std::vector<std::string> * args)352 bool ReadArgs(State* state, const std::vector<std::unique_ptr<Expr>>& argv,
353               std::vector<std::string>* args) {
354     return ReadArgs(state, argv, args, 0, argv.size());
355 }
356 
ReadArgs(State * state,const std::vector<std::unique_ptr<Expr>> & argv,std::vector<std::string> * args,size_t start,size_t len)357 bool ReadArgs(State* state, const std::vector<std::unique_ptr<Expr>>& argv,
358               std::vector<std::string>* args, size_t start, size_t len) {
359     if (args == nullptr) {
360         return false;
361     }
362     if (start + len > argv.size()) {
363         return false;
364     }
365     for (size_t i = start; i < start + len; ++i) {
366         std::string var;
367         if (!Evaluate(state, argv[i], &var)) {
368             args->clear();
369             return false;
370         }
371         args->push_back(var);
372     }
373     return true;
374 }
375 
// Evaluate the expressions in argv, and put the resulting Value*s in args. If any expression
// evaluates to nullptr, or there are no expressions to evaluate, return false. Return true on
// success.
ReadValueArgs(State * state,const std::vector<std::unique_ptr<Expr>> & argv,std::vector<std::unique_ptr<Value>> * args)378 bool ReadValueArgs(State* state, const std::vector<std::unique_ptr<Expr>>& argv,
379                    std::vector<std::unique_ptr<Value>>* args) {
380     return ReadValueArgs(state, argv, args, 0, argv.size());
381 }
382 
ReadValueArgs(State * state,const std::vector<std::unique_ptr<Expr>> & argv,std::vector<std::unique_ptr<Value>> * args,size_t start,size_t len)383 bool ReadValueArgs(State* state, const std::vector<std::unique_ptr<Expr>>& argv,
384                    std::vector<std::unique_ptr<Value>>* args, size_t start, size_t len) {
385     if (args == nullptr) {
386         return false;
387     }
388     if (len == 0 || start + len > argv.size()) {
389         return false;
390     }
391     for (size_t i = start; i < start + len; ++i) {
392         std::unique_ptr<Value> v(EvaluateValue(state, argv[i]));
393         if (!v) {
394             args->clear();
395             return false;
396         }
397         args->push_back(std::move(v));
398     }
399     return true;
400 }
401 
402 // Use printf-style arguments to compose an error message to put into
403 // *state.  Returns nullptr.
ErrorAbort(State * state,const char * format,...)404 Value* ErrorAbort(State* state, const char* format, ...) {
405     va_list ap;
406     va_start(ap, format);
407     android::base::StringAppendV(&state->errmsg, format, ap);
408     va_end(ap);
409     return nullptr;
410 }
411 
ErrorAbort(State * state,CauseCode cause_code,const char * format,...)412 Value* ErrorAbort(State* state, CauseCode cause_code, const char* format, ...) {
413   std::string err_message;
414   va_list ap;
415   va_start(ap, format);
416   android::base::StringAppendV(&err_message, format, ap);
417   va_end(ap);
418   // Ensure that there's exactly one line break at the end of the error message.
419   state->errmsg = android::base::Trim(err_message) + "\n";
420   state->cause_code = cause_code;
421   return nullptr;
422 }
423 
State(const std::string & script,UpdaterInterface * interface)424 State::State(const std::string& script, UpdaterInterface* interface)
425     : script(script), updater(interface), error_code(kNoError), cause_code(kNoCause) {}
426