1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.view.textclassifier;
18 
19 import android.annotation.NonNull;
20 import android.annotation.Nullable;
21 import android.annotation.WorkerThread;
22 import android.app.RemoteAction;
23 import android.content.Context;
24 import android.content.Intent;
25 import android.icu.util.ULocale;
26 import android.os.Bundle;
27 import android.os.LocaleList;
28 import android.os.ParcelFileDescriptor;
29 import android.util.ArrayMap;
30 import android.util.ArraySet;
31 import android.util.Pair;
32 import android.view.textclassifier.ActionsModelParamsSupplier.ActionsModelParams;
33 import android.view.textclassifier.intent.ClassificationIntentFactory;
34 import android.view.textclassifier.intent.LabeledIntent;
35 import android.view.textclassifier.intent.LegacyClassificationIntentFactory;
36 import android.view.textclassifier.intent.TemplateClassificationIntentFactory;
37 import android.view.textclassifier.intent.TemplateIntentFactory;
38 
39 import com.android.internal.annotations.GuardedBy;
40 import com.android.internal.util.IndentingPrintWriter;
41 import com.android.internal.util.Preconditions;
42 
43 import com.google.android.textclassifier.ActionsSuggestionsModel;
44 import com.google.android.textclassifier.AnnotatorModel;
45 import com.google.android.textclassifier.LangIdModel;
46 import com.google.android.textclassifier.LangIdModel.LanguageResult;
47 
48 import java.io.File;
49 import java.io.FileNotFoundException;
50 import java.io.IOException;
51 import java.time.Instant;
52 import java.time.ZonedDateTime;
53 import java.util.ArrayList;
54 import java.util.Collection;
55 import java.util.Collections;
56 import java.util.List;
57 import java.util.Locale;
58 import java.util.Map;
59 import java.util.Objects;
60 import java.util.Set;
61 import java.util.function.Supplier;
62 
63 /**
64  * Default implementation of the {@link TextClassifier} interface.
65  *
66  * <p>This class uses machine learning to recognize entities in text.
67  * Unless otherwise stated, methods of this class are blocking operations and should most
68  * likely not be called on the UI thread.
69  *
70  * @hide
71  */
72 public final class TextClassifierImpl implements TextClassifier {
73 
    /** Log tag used throughout this class; an alias for {@code DEFAULT_LOG_TAG}. */
    private static final String LOG_TAG = DEFAULT_LOG_TAG;

    /** When true, {@link #onTextClassifierEvent} logs every event it receives. */
    private static final boolean DEBUG = false;

    // Directory handed to every ModelFileManager below as the factory-model location.
    private static final File FACTORY_MODEL_DIR = new File("/etc/textclassifier/");
    // Annotator
    // Filename pattern for factory annotator models, and the path of an updated annotator model.
    private static final String ANNOTATOR_FACTORY_MODEL_FILENAME_REGEX =
            "textclassifier\\.(.*)\\.model";
    private static final File ANNOTATOR_UPDATED_MODEL_FILE =
            new File("/data/misc/textclassifier/textclassifier.model");

    // LangID
    // NOTE(review): unlike the other patterns the '.' here is not escaped, so if this is used as
    // a regex (as the name suggests) it matches any character, not only a literal dot.
    private static final String LANG_ID_FACTORY_MODEL_FILENAME_REGEX = "lang_id.model";
    private static final File UPDATED_LANG_ID_MODEL_FILE =
            new File("/data/misc/textclassifier/lang_id.model");

    // Actions
    // Filename pattern for factory actions models, and the path of an updated actions model.
    private static final String ACTIONS_FACTORY_MODEL_FILENAME_REGEX =
            "actions_suggestions\\.(.*)\\.model";
    private static final File UPDATED_ACTIONS_MODEL =
            new File("/data/misc/textclassifier/actions_suggestions.model");

    // Context handed to the native models/helpers; also supplies the package name of
    // in-process callers when logging generateLinks().
    private final Context mContext;
    // Classifier consulted whenever this implementation fails or declines to produce a result.
    private final TextClassifier mFallback;
    // Logs generateLinks() calls, sampled at the rate taken from mSettings.
    private final GenerateLinksLogger mGenerateLinksLogger;

    // Guards the lazily loaded models below and the model files they were loaded from.
    private final Object mLock = new Object();

    // Model file that mAnnotatorImpl was loaded from; compared against the current best model
    // to decide whether the annotator must be reloaded.
    @GuardedBy("mLock")
    private ModelFileManager.ModelFile mAnnotatorModelInUse;
    @GuardedBy("mLock")
    private AnnotatorModel mAnnotatorImpl;

    // Model file that mLangIdImpl was loaded from.
    @GuardedBy("mLock")
    private ModelFileManager.ModelFile mLangIdModelInUse;
    @GuardedBy("mLock")
    private LangIdModel mLangIdImpl;

    // Model file that mActionsImpl was loaded from. Both fields are reset to null by the
    // ActionsModelParamsSupplier callback so the next request reloads the actions model.
    @GuardedBy("mLock")
    private ModelFileManager.ModelFile mActionModelInUse;
    @GuardedBy("mLock")
    private ActionsSuggestionsModel mActionsImpl;

    // Writes SelectionEvents (including those converted from TextClassifierEvents).
    private final SelectionSessionLogger mSessionLogger = new SelectionSessionLogger();
    // Writes TextClassifierEvents that have no SelectionEvent equivalent.
    private final TextClassifierEventTronLogger mTextClassifierEventTronLogger =
            new TextClassifierEventTronLogger();

    // Tunable limits, feature flags and entity lists.
    private final TextClassificationConstants mSettings;

    // Choose the best available model file for each model family.
    private final ModelFileManager mAnnotatorModelFileManager;
    private final ModelFileManager mLangIdModelFileManager;
    private final ModelFileManager mActionsModelFileManager;

    // Selected in the constructor according to mSettings.isTemplateIntentFactoryEnabled().
    private final ClassificationIntentFactory mClassificationIntentFactory;
    // Used to build labeled intents for conversation actions.
    private final TemplateIntentFactory mTemplateIntentFactory;
    // Supplies the parameters used when loading the actions model.
    private final Supplier<ActionsModelParams> mActionsModelParamsSupplier;
130 
TextClassifierImpl( Context context, TextClassificationConstants settings, TextClassifier fallback)131     public TextClassifierImpl(
132             Context context, TextClassificationConstants settings, TextClassifier fallback) {
133         mContext = Preconditions.checkNotNull(context);
134         mFallback = Preconditions.checkNotNull(fallback);
135         mSettings = Preconditions.checkNotNull(settings);
136         mGenerateLinksLogger = new GenerateLinksLogger(mSettings.getGenerateLinksLogSampleRate());
137         mAnnotatorModelFileManager = new ModelFileManager(
138                 new ModelFileManager.ModelFileSupplierImpl(
139                         FACTORY_MODEL_DIR,
140                         ANNOTATOR_FACTORY_MODEL_FILENAME_REGEX,
141                         ANNOTATOR_UPDATED_MODEL_FILE,
142                         AnnotatorModel::getVersion,
143                         AnnotatorModel::getLocales));
144         mLangIdModelFileManager = new ModelFileManager(
145                 new ModelFileManager.ModelFileSupplierImpl(
146                         FACTORY_MODEL_DIR,
147                         LANG_ID_FACTORY_MODEL_FILENAME_REGEX,
148                         UPDATED_LANG_ID_MODEL_FILE,
149                         LangIdModel::getVersion,
150                         fd -> ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT));
151         mActionsModelFileManager = new ModelFileManager(
152                 new ModelFileManager.ModelFileSupplierImpl(
153                         FACTORY_MODEL_DIR,
154                         ACTIONS_FACTORY_MODEL_FILENAME_REGEX,
155                         UPDATED_ACTIONS_MODEL,
156                         ActionsSuggestionsModel::getVersion,
157                         ActionsSuggestionsModel::getLocales));
158 
159         mTemplateIntentFactory = new TemplateIntentFactory();
160         mClassificationIntentFactory = mSettings.isTemplateIntentFactoryEnabled()
161                 ? new TemplateClassificationIntentFactory(
162                 mTemplateIntentFactory, new LegacyClassificationIntentFactory())
163                 : new LegacyClassificationIntentFactory();
164         mActionsModelParamsSupplier = new ActionsModelParamsSupplier(mContext,
165                 () -> {
166                     synchronized (mLock) {
167                         // Clear mActionsImpl here, so that we will create a new
168                         // ActionsSuggestionsModel object with the new flag in the next request.
169                         mActionsImpl = null;
170                         mActionModelInUse = null;
171                     }
172                 });
173     }
174 
TextClassifierImpl(Context context, TextClassificationConstants settings)175     public TextClassifierImpl(Context context, TextClassificationConstants settings) {
176         this(context, settings, TextClassifier.NO_OP);
177     }
178 
179     /** @inheritDoc */
180     @Override
181     @WorkerThread
suggestSelection(TextSelection.Request request)182     public TextSelection suggestSelection(TextSelection.Request request) {
183         Preconditions.checkNotNull(request);
184         Utils.checkMainThread();
185         try {
186             final int rangeLength = request.getEndIndex() - request.getStartIndex();
187             final String string = request.getText().toString();
188             if (string.length() > 0
189                     && rangeLength <= mSettings.getSuggestSelectionMaxRangeLength()) {
190                 final String localesString = concatenateLocales(request.getDefaultLocales());
191                 final String detectLanguageTags = detectLanguageTagsFromText(request.getText());
192                 final ZonedDateTime refTime = ZonedDateTime.now();
193                 final AnnotatorModel annotatorImpl = getAnnotatorImpl(request.getDefaultLocales());
194                 final int start;
195                 final int end;
196                 if (mSettings.isModelDarkLaunchEnabled() && !request.isDarkLaunchAllowed()) {
197                     start = request.getStartIndex();
198                     end = request.getEndIndex();
199                 } else {
200                     final int[] startEnd = annotatorImpl.suggestSelection(
201                             string, request.getStartIndex(), request.getEndIndex(),
202                             new AnnotatorModel.SelectionOptions(localesString, detectLanguageTags));
203                     start = startEnd[0];
204                     end = startEnd[1];
205                 }
206                 if (start < end
207                         && start >= 0 && end <= string.length()
208                         && start <= request.getStartIndex() && end >= request.getEndIndex()) {
209                     final TextSelection.Builder tsBuilder = new TextSelection.Builder(start, end);
210                     final AnnotatorModel.ClassificationResult[] results =
211                             annotatorImpl.classifyText(
212                                     string, start, end,
213                                     new AnnotatorModel.ClassificationOptions(
214                                             refTime.toInstant().toEpochMilli(),
215                                             refTime.getZone().getId(),
216                                             localesString,
217                                             detectLanguageTags),
218                                     // Passing null here to suppress intent generation
219                                     // TODO: Use an explicit flag to suppress it.
220                                     /* appContext */ null,
221                                     /* deviceLocales */null);
222                     final int size = results.length;
223                     for (int i = 0; i < size; i++) {
224                         tsBuilder.setEntityType(results[i].getCollection(), results[i].getScore());
225                     }
226                     return tsBuilder.setId(createId(
227                             string, request.getStartIndex(), request.getEndIndex()))
228                             .build();
229                 } else {
230                     // We can not trust the result. Log the issue and ignore the result.
231                     Log.d(LOG_TAG, "Got bad indices for input text. Ignoring result.");
232                 }
233             }
234         } catch (Throwable t) {
235             // Avoid throwing from this method. Log the error.
236             Log.e(LOG_TAG,
237                     "Error suggesting selection for text. No changes to selection suggested.",
238                     t);
239         }
240         // Getting here means something went wrong, return a NO_OP result.
241         return mFallback.suggestSelection(request);
242     }
243 
244     /** @inheritDoc */
245     @Override
246     @WorkerThread
classifyText(TextClassification.Request request)247     public TextClassification classifyText(TextClassification.Request request) {
248         Preconditions.checkNotNull(request);
249         Utils.checkMainThread();
250         try {
251             final int rangeLength = request.getEndIndex() - request.getStartIndex();
252             final String string = request.getText().toString();
253             if (string.length() > 0 && rangeLength <= mSettings.getClassifyTextMaxRangeLength()) {
254                 final String localesString = concatenateLocales(request.getDefaultLocales());
255                 final String detectLanguageTags = detectLanguageTagsFromText(request.getText());
256                 final ZonedDateTime refTime = request.getReferenceTime() != null
257                         ? request.getReferenceTime() : ZonedDateTime.now();
258                 final AnnotatorModel.ClassificationResult[] results =
259                         getAnnotatorImpl(request.getDefaultLocales())
260                                 .classifyText(
261                                         string, request.getStartIndex(), request.getEndIndex(),
262                                         new AnnotatorModel.ClassificationOptions(
263                                                 refTime.toInstant().toEpochMilli(),
264                                                 refTime.getZone().getId(),
265                                                 localesString,
266                                                 detectLanguageTags),
267                                         mContext,
268                                         getResourceLocalesString()
269                                 );
270                 if (results.length > 0) {
271                     return createClassificationResult(
272                             results, string,
273                             request.getStartIndex(), request.getEndIndex(), refTime.toInstant());
274                 }
275             }
276         } catch (Throwable t) {
277             // Avoid throwing from this method. Log the error.
278             Log.e(LOG_TAG, "Error getting text classification info.", t);
279         }
280         // Getting here means something went wrong, return a NO_OP result.
281         return mFallback.classifyText(request);
282     }
283 
284     /** @inheritDoc */
285     @Override
286     @WorkerThread
generateLinks(@onNull TextLinks.Request request)287     public TextLinks generateLinks(@NonNull TextLinks.Request request) {
288         Preconditions.checkNotNull(request);
289         Utils.checkTextLength(request.getText(), getMaxGenerateLinksTextLength());
290         Utils.checkMainThread();
291 
292         if (!mSettings.isSmartLinkifyEnabled() && request.isLegacyFallback()) {
293             return Utils.generateLegacyLinks(request);
294         }
295 
296         final String textString = request.getText().toString();
297         final TextLinks.Builder builder = new TextLinks.Builder(textString);
298 
299         try {
300             final long startTimeMs = System.currentTimeMillis();
301             final ZonedDateTime refTime = ZonedDateTime.now();
302             final Collection<String> entitiesToIdentify = request.getEntityConfig() != null
303                     ? request.getEntityConfig().resolveEntityListModifications(
304                     getEntitiesForHints(request.getEntityConfig().getHints()))
305                     : mSettings.getEntityListDefault();
306             final String localesString = concatenateLocales(request.getDefaultLocales());
307             final String detectLanguageTags = detectLanguageTagsFromText(request.getText());
308             final AnnotatorModel annotatorImpl =
309                     getAnnotatorImpl(request.getDefaultLocales());
310             final boolean isSerializedEntityDataEnabled =
311                     ExtrasUtils.isSerializedEntityDataEnabled(request);
312             final AnnotatorModel.AnnotatedSpan[] annotations =
313                     annotatorImpl.annotate(
314                             textString,
315                             new AnnotatorModel.AnnotationOptions(
316                                     refTime.toInstant().toEpochMilli(),
317                                     refTime.getZone().getId(),
318                                     localesString,
319                                     detectLanguageTags,
320                                     entitiesToIdentify,
321                                     AnnotatorModel.AnnotationUsecase.SMART.getValue(),
322                                     isSerializedEntityDataEnabled));
323             for (AnnotatorModel.AnnotatedSpan span : annotations) {
324                 final AnnotatorModel.ClassificationResult[] results =
325                         span.getClassification();
326                 if (results.length == 0
327                         || !entitiesToIdentify.contains(results[0].getCollection())) {
328                     continue;
329                 }
330                 final Map<String, Float> entityScores = new ArrayMap<>();
331                 for (int i = 0; i < results.length; i++) {
332                     entityScores.put(results[i].getCollection(), results[i].getScore());
333                 }
334                 Bundle extras = new Bundle();
335                 if (isSerializedEntityDataEnabled) {
336                     ExtrasUtils.putEntities(extras, results);
337                 }
338                 builder.addLink(span.getStartIndex(), span.getEndIndex(), entityScores, extras);
339             }
340             final TextLinks links = builder.build();
341             final long endTimeMs = System.currentTimeMillis();
342             final String callingPackageName = request.getCallingPackageName() == null
343                     ? mContext.getPackageName()  // local (in process) TC.
344                     : request.getCallingPackageName();
345             mGenerateLinksLogger.logGenerateLinks(
346                     request.getText(), links, callingPackageName, endTimeMs - startTimeMs);
347             return links;
348         } catch (Throwable t) {
349             // Avoid throwing from this method. Log the error.
350             Log.e(LOG_TAG, "Error getting links info.", t);
351         }
352         return mFallback.generateLinks(request);
353     }
354 
355     /** @inheritDoc */
356     @Override
getMaxGenerateLinksTextLength()357     public int getMaxGenerateLinksTextLength() {
358         return mSettings.getGenerateLinksMaxTextLength();
359     }
360 
getEntitiesForHints(Collection<String> hints)361     private Collection<String> getEntitiesForHints(Collection<String> hints) {
362         final boolean editable = hints.contains(HINT_TEXT_IS_EDITABLE);
363         final boolean notEditable = hints.contains(HINT_TEXT_IS_NOT_EDITABLE);
364 
365         // Use the default if there is no hint, or conflicting ones.
366         final boolean useDefault = editable == notEditable;
367         if (useDefault) {
368             return mSettings.getEntityListDefault();
369         } else if (editable) {
370             return mSettings.getEntityListEditable();
371         } else {  // notEditable
372             return mSettings.getEntityListNotEditable();
373         }
374     }
375 
376     /** @inheritDoc */
377     @Override
onSelectionEvent(SelectionEvent event)378     public void onSelectionEvent(SelectionEvent event) {
379         mSessionLogger.writeEvent(event);
380     }
381 
382     @Override
onTextClassifierEvent(TextClassifierEvent event)383     public void onTextClassifierEvent(TextClassifierEvent event) {
384         if (DEBUG) {
385             Log.d(DEFAULT_LOG_TAG, "onTextClassifierEvent() called with: event = [" + event + "]");
386         }
387         try {
388             final SelectionEvent selEvent = event.toSelectionEvent();
389             if (selEvent != null) {
390                 mSessionLogger.writeEvent(selEvent);
391             } else {
392                 mTextClassifierEventTronLogger.writeEvent(event);
393             }
394         } catch (Exception e) {
395             Log.e(LOG_TAG, "Error writing event", e);
396         }
397     }
398 
399     /** @inheritDoc */
400     @Override
detectLanguage(@onNull TextLanguage.Request request)401     public TextLanguage detectLanguage(@NonNull TextLanguage.Request request) {
402         Preconditions.checkNotNull(request);
403         Utils.checkMainThread();
404         try {
405             final TextLanguage.Builder builder = new TextLanguage.Builder();
406             final LangIdModel.LanguageResult[] langResults =
407                     getLangIdImpl().detectLanguages(request.getText().toString());
408             for (int i = 0; i < langResults.length; i++) {
409                 builder.putLocale(
410                         ULocale.forLanguageTag(langResults[i].getLanguage()),
411                         langResults[i].getScore());
412             }
413             return builder.build();
414         } catch (Throwable t) {
415             // Avoid throwing from this method. Log the error.
416             Log.e(LOG_TAG, "Error detecting text language.", t);
417         }
418         return mFallback.detectLanguage(request);
419     }
420 
421     @Override
suggestConversationActions(ConversationActions.Request request)422     public ConversationActions suggestConversationActions(ConversationActions.Request request) {
423         Preconditions.checkNotNull(request);
424         Utils.checkMainThread();
425         try {
426             ActionsSuggestionsModel actionsImpl = getActionsImpl();
427             if (actionsImpl == null) {
428                 // Actions model is optional, fallback if it is not available.
429                 return mFallback.suggestConversationActions(request);
430             }
431             ActionsSuggestionsModel.ConversationMessage[] nativeMessages =
432                     ActionsSuggestionsHelper.toNativeMessages(
433                             request.getConversation(), this::detectLanguageTagsFromText);
434             if (nativeMessages.length == 0) {
435                 return mFallback.suggestConversationActions(request);
436             }
437             ActionsSuggestionsModel.Conversation nativeConversation =
438                     new ActionsSuggestionsModel.Conversation(nativeMessages);
439 
440             ActionsSuggestionsModel.ActionSuggestion[] nativeSuggestions =
441                     actionsImpl.suggestActionsWithIntents(
442                             nativeConversation,
443                             null,
444                             mContext,
445                             getResourceLocalesString(),
446                             getAnnotatorImpl(LocaleList.getDefault()));
447             return createConversationActionResult(request, nativeSuggestions);
448         } catch (Throwable t) {
449             // Avoid throwing from this method. Log the error.
450             Log.e(LOG_TAG, "Error suggesting conversation actions.", t);
451         }
452         return mFallback.suggestConversationActions(request);
453     }
454 
455     /**
456      * Returns the {@link ConversationAction} result, with a non-null extras.
457      * <p>
458      * Whenever the RemoteAction is non-null, you can expect its corresponding intent
459      * with a non-null component name is in the extras.
460      */
createConversationActionResult( ConversationActions.Request request, ActionsSuggestionsModel.ActionSuggestion[] nativeSuggestions)461     private ConversationActions createConversationActionResult(
462             ConversationActions.Request request,
463             ActionsSuggestionsModel.ActionSuggestion[] nativeSuggestions) {
464         Collection<String> expectedTypes = resolveActionTypesFromRequest(request);
465         List<ConversationAction> conversationActions = new ArrayList<>();
466         for (ActionsSuggestionsModel.ActionSuggestion nativeSuggestion : nativeSuggestions) {
467             String actionType = nativeSuggestion.getActionType();
468             if (!expectedTypes.contains(actionType)) {
469                 continue;
470             }
471             LabeledIntent.Result labeledIntentResult =
472                     ActionsSuggestionsHelper.createLabeledIntentResult(
473                             mContext,
474                             mTemplateIntentFactory,
475                             nativeSuggestion);
476             RemoteAction remoteAction = null;
477             Bundle extras = new Bundle();
478             if (labeledIntentResult != null) {
479                 remoteAction = labeledIntentResult.remoteAction;
480                 ExtrasUtils.putActionIntent(extras, labeledIntentResult.resolvedIntent);
481             }
482             ExtrasUtils.putSerializedEntityData(extras, nativeSuggestion.getSerializedEntityData());
483             ExtrasUtils.putEntitiesExtras(
484                     extras,
485                     TemplateIntentFactory.nameVariantsToBundle(nativeSuggestion.getEntityData()));
486             conversationActions.add(
487                     new ConversationAction.Builder(actionType)
488                             .setConfidenceScore(nativeSuggestion.getScore())
489                             .setTextReply(nativeSuggestion.getResponseText())
490                             .setAction(remoteAction)
491                             .setExtras(extras)
492                             .build());
493         }
494         conversationActions =
495                 ActionsSuggestionsHelper.removeActionsWithDuplicates(conversationActions);
496         if (request.getMaxSuggestions() >= 0
497                 && conversationActions.size() > request.getMaxSuggestions()) {
498             conversationActions = conversationActions.subList(0, request.getMaxSuggestions());
499         }
500         String resultId = ActionsSuggestionsHelper.createResultId(
501                 mContext,
502                 request.getConversation(),
503                 mActionModelInUse.getVersion(),
504                 mActionModelInUse.getSupportedLocales());
505         return new ConversationActions(conversationActions, resultId);
506     }
507 
508     @Nullable
detectLanguageTagsFromText(CharSequence text)509     private String detectLanguageTagsFromText(CharSequence text) {
510         if (!mSettings.isDetectLanguagesFromTextEnabled()) {
511             return null;
512         }
513         final float threshold = getLangIdThreshold();
514         if (threshold < 0 || threshold > 1) {
515             Log.w(LOG_TAG,
516                     "[detectLanguageTagsFromText] unexpected threshold is found: " + threshold);
517             return null;
518         }
519         TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
520         TextLanguage textLanguage = detectLanguage(request);
521         int localeHypothesisCount = textLanguage.getLocaleHypothesisCount();
522         List<String> languageTags = new ArrayList<>();
523         for (int i = 0; i < localeHypothesisCount; i++) {
524             ULocale locale = textLanguage.getLocale(i);
525             if (textLanguage.getConfidenceScore(locale) < threshold) {
526                 break;
527             }
528             languageTags.add(locale.toLanguageTag());
529         }
530         if (languageTags.isEmpty()) {
531             return null;
532         }
533         return String.join(",", languageTags);
534     }
535 
resolveActionTypesFromRequest(ConversationActions.Request request)536     private Collection<String> resolveActionTypesFromRequest(ConversationActions.Request request) {
537         List<String> defaultActionTypes =
538                 request.getHints().contains(ConversationActions.Request.HINT_FOR_NOTIFICATION)
539                         ? mSettings.getNotificationConversationActionTypes()
540                         : mSettings.getInAppConversationActionTypes();
541         return request.getTypeConfig().resolveEntityListModifications(defaultActionTypes);
542     }
543 
getAnnotatorImpl(LocaleList localeList)544     private AnnotatorModel getAnnotatorImpl(LocaleList localeList)
545             throws FileNotFoundException {
546         synchronized (mLock) {
547             localeList = localeList == null ? LocaleList.getDefault() : localeList;
548             final ModelFileManager.ModelFile bestModel =
549                     mAnnotatorModelFileManager.findBestModelFile(localeList);
550             if (bestModel == null) {
551                 throw new FileNotFoundException(
552                         "No annotator model for " + localeList.toLanguageTags());
553             }
554             if (mAnnotatorImpl == null || !Objects.equals(mAnnotatorModelInUse, bestModel)) {
555                 Log.d(DEFAULT_LOG_TAG, "Loading " + bestModel);
556                 final ParcelFileDescriptor pfd = ParcelFileDescriptor.open(
557                         new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
558                 try {
559                     if (pfd != null) {
560                         // The current annotator model may be still used by another thread / model.
561                         // Do not call close() here, and let the GC to clean it up when no one else
562                         // is using it.
563                         mAnnotatorImpl = new AnnotatorModel(pfd.getFd());
564                         mAnnotatorModelInUse = bestModel;
565                     }
566                 } finally {
567                     maybeCloseAndLogError(pfd);
568                 }
569             }
570             return mAnnotatorImpl;
571         }
572     }
573 
getLangIdImpl()574     private LangIdModel getLangIdImpl() throws FileNotFoundException {
575         synchronized (mLock) {
576             final ModelFileManager.ModelFile bestModel =
577                     mLangIdModelFileManager.findBestModelFile(null);
578             if (bestModel == null) {
579                 throw new FileNotFoundException("No LangID model is found");
580             }
581             if (mLangIdImpl == null || !Objects.equals(mLangIdModelInUse, bestModel)) {
582                 Log.d(DEFAULT_LOG_TAG, "Loading " + bestModel);
583                 final ParcelFileDescriptor pfd = ParcelFileDescriptor.open(
584                         new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
585                 try {
586                     if (pfd != null) {
587                         mLangIdImpl = new LangIdModel(pfd.getFd());
588                         mLangIdModelInUse = bestModel;
589                     }
590                 } finally {
591                     maybeCloseAndLogError(pfd);
592                 }
593             }
594             return mLangIdImpl;
595         }
596     }
597 
598     @Nullable
getActionsImpl()599     private ActionsSuggestionsModel getActionsImpl() throws FileNotFoundException {
600         synchronized (mLock) {
601             // TODO: Use LangID to determine the locale we should use here?
602             final ModelFileManager.ModelFile bestModel =
603                     mActionsModelFileManager.findBestModelFile(LocaleList.getDefault());
604             if (bestModel == null) {
605                 return null;
606             }
607             if (mActionsImpl == null || !Objects.equals(mActionModelInUse, bestModel)) {
608                 Log.d(DEFAULT_LOG_TAG, "Loading " + bestModel);
609                 final ParcelFileDescriptor pfd = ParcelFileDescriptor.open(
610                         new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
611                 try {
612                     if (pfd == null) {
613                         Log.d(LOG_TAG, "Failed to read the model file: " + bestModel.getPath());
614                         return null;
615                     }
616                     ActionsModelParams params = mActionsModelParamsSupplier.get();
617                     mActionsImpl = new ActionsSuggestionsModel(
618                             pfd.getFd(), params.getSerializedPreconditions(bestModel));
619                     mActionModelInUse = bestModel;
620                 } finally {
621                     maybeCloseAndLogError(pfd);
622                 }
623             }
624             return mActionsImpl;
625         }
626     }
627 
createId(String text, int start, int end)628     private String createId(String text, int start, int end) {
629         synchronized (mLock) {
630             return SelectionSessionLogger.createId(text, start, end, mContext,
631                     mAnnotatorModelInUse.getVersion(),
632                     mAnnotatorModelInUse.getSupportedLocales());
633         }
634     }
635 
concatenateLocales(@ullable LocaleList locales)636     private static String concatenateLocales(@Nullable LocaleList locales) {
637         return (locales == null) ? "" : locales.toLanguageTags();
638     }
639 
createClassificationResult( AnnotatorModel.ClassificationResult[] classifications, String text, int start, int end, @Nullable Instant referenceTime)640     private TextClassification createClassificationResult(
641             AnnotatorModel.ClassificationResult[] classifications,
642             String text, int start, int end, @Nullable Instant referenceTime) {
643         final String classifiedText = text.substring(start, end);
644         final TextClassification.Builder builder = new TextClassification.Builder()
645                 .setText(classifiedText);
646 
647         final int typeCount = classifications.length;
648         AnnotatorModel.ClassificationResult highestScoringResult =
649                 typeCount > 0 ? classifications[0] : null;
650         for (int i = 0; i < typeCount; i++) {
651             builder.setEntityType(classifications[i]);
652             if (classifications[i].getScore() > highestScoringResult.getScore()) {
653                 highestScoringResult = classifications[i];
654             }
655         }
656 
657         final Pair<Bundle, Bundle> languagesBundles = generateLanguageBundles(text, start, end);
658         final Bundle textLanguagesBundle = languagesBundles.first;
659         final Bundle foreignLanguageBundle = languagesBundles.second;
660         builder.setForeignLanguageExtra(foreignLanguageBundle);
661 
662         boolean isPrimaryAction = true;
663         final List<LabeledIntent> labeledIntents = mClassificationIntentFactory.create(
664                 mContext,
665                 classifiedText,
666                 foreignLanguageBundle != null,
667                 referenceTime,
668                 highestScoringResult);
669         final LabeledIntent.TitleChooser titleChooser =
670                 (labeledIntent, resolveInfo) -> labeledIntent.titleWithoutEntity;
671 
672         for (LabeledIntent labeledIntent : labeledIntents) {
673             final LabeledIntent.Result result =
674                     labeledIntent.resolve(mContext, titleChooser, textLanguagesBundle);
675             if (result == null) {
676                 continue;
677             }
678 
679             final Intent intent = result.resolvedIntent;
680             final RemoteAction action = result.remoteAction;
681             if (isPrimaryAction) {
682                 // For O backwards compatibility, the first RemoteAction is also written to the
683                 // legacy API fields.
684                 builder.setIcon(action.getIcon().loadDrawable(mContext));
685                 builder.setLabel(action.getTitle().toString());
686                 builder.setIntent(intent);
687                 builder.setOnClickListener(TextClassification.createIntentOnClickListener(
688                         TextClassification.createPendingIntent(
689                                 mContext, intent, labeledIntent.requestCode)));
690                 isPrimaryAction = false;
691             }
692             builder.addAction(action, intent);
693         }
694         return builder.setId(createId(text, start, end)).build();
695     }
696 
697     /**
698      * Returns a bundle pair with language detection information for extras.
699      * <p>
700      * Pair.first = textLanguagesBundle - A bundle containing information about all detected
701      * languages in the text. May be null if language detection fails or is disabled. This is
702      * typically expected to be added to a textClassifier generated remote action intent.
703      * See {@link ExtrasUtils#putTextLanguagesExtra(Bundle, Bundle)}.
704      * See {@link ExtrasUtils#getTopLanguage(Intent)}.
705      * <p>
706      * Pair.second = foreignLanguageBundle - A bundle with the language and confidence score if the
707      * system finds the text to be in a foreign language. Otherwise is null.
708      * See {@link TextClassification.Builder#setForeignLanguageExtra(Bundle)}.
709      *
710      * @param context the context of the text to detect languages for
711      * @param start the start index of the text
712      * @param end the end index of the text
713      */
714     // TODO: Revisit this algorithm.
715     // TODO: Consider making this public API.
generateLanguageBundles(String context, int start, int end)716     private Pair<Bundle, Bundle> generateLanguageBundles(String context, int start, int end) {
717         if (!mSettings.isTranslateInClassificationEnabled()) {
718             return null;
719         }
720         try {
721             final float threshold = getLangIdThreshold();
722             if (threshold < 0 || threshold > 1) {
723                 Log.w(LOG_TAG,
724                         "[detectForeignLanguage] unexpected threshold is found: " + threshold);
725                 return Pair.create(null, null);
726             }
727 
728             final EntityConfidence languageScores = detectLanguages(context, start, end);
729             if (languageScores.getEntities().isEmpty()) {
730                 return Pair.create(null, null);
731             }
732 
733             final Bundle textLanguagesBundle = new Bundle();
734             ExtrasUtils.putTopLanguageScores(textLanguagesBundle, languageScores);
735 
736             final String language = languageScores.getEntities().get(0);
737             final float score = languageScores.getConfidenceScore(language);
738             if (score < threshold) {
739                 return Pair.create(textLanguagesBundle, null);
740             }
741 
742             Log.v(LOG_TAG, String.format(
743                     Locale.US, "Language detected: <%s:%.2f>", language, score));
744 
745             final Locale detected = new Locale(language);
746             final LocaleList deviceLocales = LocaleList.getDefault();
747             final int size = deviceLocales.size();
748             for (int i = 0; i < size; i++) {
749                 if (deviceLocales.get(i).getLanguage().equals(detected.getLanguage())) {
750                     return Pair.create(textLanguagesBundle, null);
751                 }
752             }
753             final Bundle foreignLanguageBundle = ExtrasUtils.createForeignLanguageExtra(
754                     detected.getLanguage(), score, getLangIdImpl().getVersion());
755             return Pair.create(textLanguagesBundle, foreignLanguageBundle);
756         } catch (Throwable t) {
757             Log.e(LOG_TAG, "Error generating language bundles.", t);
758         }
759         return Pair.create(null, null);
760     }
761 
762     /**
763      * Detect the language of a piece of text by taking surrounding text into consideration.
764      *
765      * @param text text providing context for the text for which its language is to be detected
766      * @param start the start index of the text to detect its language
767      * @param end the end index of the text to detect its language
768      */
769     // TODO: Revisit this algorithm.
detectLanguages(String text, int start, int end)770     private EntityConfidence detectLanguages(String text, int start, int end)
771             throws FileNotFoundException {
772         Preconditions.checkArgument(start >= 0);
773         Preconditions.checkArgument(end <= text.length());
774         Preconditions.checkArgument(start <= end);
775 
776         final float[] langIdContextSettings = mSettings.getLangIdContextSettings();
777         // The minimum size of text to prefer for detection.
778         final int minimumTextSize = (int) langIdContextSettings[0];
779         // For reducing the score when text is less than the preferred size.
780         final float penalizeRatio = langIdContextSettings[1];
781         // Original detection score to surrounding text detection score ratios.
782         final float subjectTextScoreRatio = langIdContextSettings[2];
783         final float moreTextScoreRatio = 1f - subjectTextScoreRatio;
784         Log.v(LOG_TAG,
785                 String.format(Locale.US, "LangIdContextSettings: "
786                                 + "minimumTextSize=%d, penalizeRatio=%.2f, "
787                                 + "subjectTextScoreRatio=%.2f, moreTextScoreRatio=%.2f",
788                         minimumTextSize, penalizeRatio, subjectTextScoreRatio, moreTextScoreRatio));
789 
790         if (end - start < minimumTextSize && penalizeRatio <= 0) {
791             return new EntityConfidence(Collections.emptyMap());
792         }
793 
794         final String subject = text.substring(start, end);
795         final EntityConfidence scores = detectLanguages(subject);
796 
797         if (subject.length() >= minimumTextSize
798                 || subject.length() == text.length()
799                 || subjectTextScoreRatio * penalizeRatio >= 1) {
800             return scores;
801         }
802 
803         final EntityConfidence moreTextScores;
804         if (moreTextScoreRatio >= 0) {
805             // Attempt to grow the detection text to be at least minimumTextSize long.
806             final String moreText = Utils.getSubString(text, start, end, minimumTextSize);
807             moreTextScores = detectLanguages(moreText);
808         } else {
809             moreTextScores = new EntityConfidence(Collections.emptyMap());
810         }
811 
812         // Combine the original detection scores with the those returned after including more text.
813         final Map<String, Float> newScores = new ArrayMap<>();
814         final Set<String> languages = new ArraySet<>();
815         languages.addAll(scores.getEntities());
816         languages.addAll(moreTextScores.getEntities());
817         for (String language : languages) {
818             final float score = (subjectTextScoreRatio * scores.getConfidenceScore(language)
819                     + moreTextScoreRatio * moreTextScores.getConfidenceScore(language))
820                     * penalizeRatio;
821             newScores.put(language, score);
822         }
823         return new EntityConfidence(newScores);
824     }
825 
826     /**
827      * Detect languages for the specified text.
828      */
detectLanguages(String text)829     private EntityConfidence detectLanguages(String text) throws FileNotFoundException {
830         final LangIdModel langId = getLangIdImpl();
831         final LangIdModel.LanguageResult[] langResults = langId.detectLanguages(text);
832         final Map<String, Float> languagesMap = new ArrayMap<>();
833         for (LanguageResult langResult : langResults) {
834             languagesMap.put(langResult.getLanguage(), langResult.getScore());
835         }
836         return new EntityConfidence(languagesMap);
837     }
838 
getLangIdThreshold()839     private float getLangIdThreshold() {
840         try {
841             return mSettings.getLangIdThresholdOverride() >= 0
842                     ? mSettings.getLangIdThresholdOverride()
843                     : getLangIdImpl().getLangIdThreshold();
844         } catch (FileNotFoundException e) {
845             final float defaultThreshold = 0.5f;
846             Log.v(LOG_TAG, "Using default foreign language threshold: " + defaultThreshold);
847             return defaultThreshold;
848         }
849     }
850 
851     @Override
dump(@onNull IndentingPrintWriter printWriter)852     public void dump(@NonNull IndentingPrintWriter printWriter) {
853         synchronized (mLock) {
854             printWriter.println("TextClassifierImpl:");
855             printWriter.increaseIndent();
856             printWriter.println("Annotator model file(s):");
857             printWriter.increaseIndent();
858             for (ModelFileManager.ModelFile modelFile :
859                     mAnnotatorModelFileManager.listModelFiles()) {
860                 printWriter.println(modelFile.toString());
861             }
862             printWriter.decreaseIndent();
863             printWriter.println("LangID model file(s):");
864             printWriter.increaseIndent();
865             for (ModelFileManager.ModelFile modelFile :
866                     mLangIdModelFileManager.listModelFiles()) {
867                 printWriter.println(modelFile.toString());
868             }
869             printWriter.decreaseIndent();
870             printWriter.println("Actions model file(s):");
871             printWriter.increaseIndent();
872             for (ModelFileManager.ModelFile modelFile :
873                     mActionsModelFileManager.listModelFiles()) {
874                 printWriter.println(modelFile.toString());
875             }
876             printWriter.decreaseIndent();
877             printWriter.printPair("mFallback", mFallback);
878             printWriter.decreaseIndent();
879             printWriter.println();
880         }
881     }
882 
883     /**
884      * Closes the ParcelFileDescriptor, if non-null, and logs any errors that occur.
885      */
maybeCloseAndLogError(@ullable ParcelFileDescriptor fd)886     private static void maybeCloseAndLogError(@Nullable ParcelFileDescriptor fd) {
887         if (fd == null) {
888             return;
889         }
890 
891         try {
892             fd.close();
893         } catch (IOException e) {
894             Log.e(LOG_TAG, "Error closing file.", e);
895         }
896     }
897 
898     /**
899      * Returns the locales string for the current resources configuration.
900      */
getResourceLocalesString()901     private String getResourceLocalesString() {
902         try {
903             return mContext.getResources().getConfiguration().getLocales().toLanguageTags();
904         } catch (NullPointerException e) {
905             // NPE is unexpected. Erring on the side of caution.
906             return LocaleList.getDefault().toLanguageTags();
907         }
908     }
909 }
910