1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package android.view.textclassifier; 18 19 import android.annotation.NonNull; 20 import android.annotation.Nullable; 21 import android.annotation.WorkerThread; 22 import android.app.RemoteAction; 23 import android.content.Context; 24 import android.content.Intent; 25 import android.icu.util.ULocale; 26 import android.os.Bundle; 27 import android.os.LocaleList; 28 import android.os.ParcelFileDescriptor; 29 import android.util.ArrayMap; 30 import android.util.ArraySet; 31 import android.util.Pair; 32 import android.view.textclassifier.ActionsModelParamsSupplier.ActionsModelParams; 33 import android.view.textclassifier.intent.ClassificationIntentFactory; 34 import android.view.textclassifier.intent.LabeledIntent; 35 import android.view.textclassifier.intent.LegacyClassificationIntentFactory; 36 import android.view.textclassifier.intent.TemplateClassificationIntentFactory; 37 import android.view.textclassifier.intent.TemplateIntentFactory; 38 39 import com.android.internal.annotations.GuardedBy; 40 import com.android.internal.util.IndentingPrintWriter; 41 import com.android.internal.util.Preconditions; 42 43 import com.google.android.textclassifier.ActionsSuggestionsModel; 44 import com.google.android.textclassifier.AnnotatorModel; 45 import com.google.android.textclassifier.LangIdModel; 46 import 
com.google.android.textclassifier.LangIdModel.LanguageResult; 47 48 import java.io.File; 49 import java.io.FileNotFoundException; 50 import java.io.IOException; 51 import java.time.Instant; 52 import java.time.ZonedDateTime; 53 import java.util.ArrayList; 54 import java.util.Collection; 55 import java.util.Collections; 56 import java.util.List; 57 import java.util.Locale; 58 import java.util.Map; 59 import java.util.Objects; 60 import java.util.Set; 61 import java.util.function.Supplier; 62 63 /** 64 * Default implementation of the {@link TextClassifier} interface. 65 * 66 * <p>This class uses machine learning to recognize entities in text. 67 * Unless otherwise stated, methods of this class are blocking operations and should most 68 * likely not be called on the UI thread. 69 * 70 * @hide 71 */ 72 public final class TextClassifierImpl implements TextClassifier { 73 74 private static final String LOG_TAG = DEFAULT_LOG_TAG; 75 76 private static final boolean DEBUG = false; 77 78 private static final File FACTORY_MODEL_DIR = new File("/etc/textclassifier/"); 79 // Annotator 80 private static final String ANNOTATOR_FACTORY_MODEL_FILENAME_REGEX = 81 "textclassifier\\.(.*)\\.model"; 82 private static final File ANNOTATOR_UPDATED_MODEL_FILE = 83 new File("/data/misc/textclassifier/textclassifier.model"); 84 85 // LangID 86 private static final String LANG_ID_FACTORY_MODEL_FILENAME_REGEX = "lang_id.model"; 87 private static final File UPDATED_LANG_ID_MODEL_FILE = 88 new File("/data/misc/textclassifier/lang_id.model"); 89 90 // Actions 91 private static final String ACTIONS_FACTORY_MODEL_FILENAME_REGEX = 92 "actions_suggestions\\.(.*)\\.model"; 93 private static final File UPDATED_ACTIONS_MODEL = 94 new File("/data/misc/textclassifier/actions_suggestions.model"); 95 96 private final Context mContext; 97 private final TextClassifier mFallback; 98 private final GenerateLinksLogger mGenerateLinksLogger; 99 100 private final Object mLock = new Object(); 101 102 
@GuardedBy("mLock") 103 private ModelFileManager.ModelFile mAnnotatorModelInUse; 104 @GuardedBy("mLock") 105 private AnnotatorModel mAnnotatorImpl; 106 107 @GuardedBy("mLock") 108 private ModelFileManager.ModelFile mLangIdModelInUse; 109 @GuardedBy("mLock") 110 private LangIdModel mLangIdImpl; 111 112 @GuardedBy("mLock") 113 private ModelFileManager.ModelFile mActionModelInUse; 114 @GuardedBy("mLock") 115 private ActionsSuggestionsModel mActionsImpl; 116 117 private final SelectionSessionLogger mSessionLogger = new SelectionSessionLogger(); 118 private final TextClassifierEventTronLogger mTextClassifierEventTronLogger = 119 new TextClassifierEventTronLogger(); 120 121 private final TextClassificationConstants mSettings; 122 123 private final ModelFileManager mAnnotatorModelFileManager; 124 private final ModelFileManager mLangIdModelFileManager; 125 private final ModelFileManager mActionsModelFileManager; 126 127 private final ClassificationIntentFactory mClassificationIntentFactory; 128 private final TemplateIntentFactory mTemplateIntentFactory; 129 private final Supplier<ActionsModelParams> mActionsModelParamsSupplier; 130 TextClassifierImpl( Context context, TextClassificationConstants settings, TextClassifier fallback)131 public TextClassifierImpl( 132 Context context, TextClassificationConstants settings, TextClassifier fallback) { 133 mContext = Preconditions.checkNotNull(context); 134 mFallback = Preconditions.checkNotNull(fallback); 135 mSettings = Preconditions.checkNotNull(settings); 136 mGenerateLinksLogger = new GenerateLinksLogger(mSettings.getGenerateLinksLogSampleRate()); 137 mAnnotatorModelFileManager = new ModelFileManager( 138 new ModelFileManager.ModelFileSupplierImpl( 139 FACTORY_MODEL_DIR, 140 ANNOTATOR_FACTORY_MODEL_FILENAME_REGEX, 141 ANNOTATOR_UPDATED_MODEL_FILE, 142 AnnotatorModel::getVersion, 143 AnnotatorModel::getLocales)); 144 mLangIdModelFileManager = new ModelFileManager( 145 new ModelFileManager.ModelFileSupplierImpl( 146 
FACTORY_MODEL_DIR, 147 LANG_ID_FACTORY_MODEL_FILENAME_REGEX, 148 UPDATED_LANG_ID_MODEL_FILE, 149 LangIdModel::getVersion, 150 fd -> ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT)); 151 mActionsModelFileManager = new ModelFileManager( 152 new ModelFileManager.ModelFileSupplierImpl( 153 FACTORY_MODEL_DIR, 154 ACTIONS_FACTORY_MODEL_FILENAME_REGEX, 155 UPDATED_ACTIONS_MODEL, 156 ActionsSuggestionsModel::getVersion, 157 ActionsSuggestionsModel::getLocales)); 158 159 mTemplateIntentFactory = new TemplateIntentFactory(); 160 mClassificationIntentFactory = mSettings.isTemplateIntentFactoryEnabled() 161 ? new TemplateClassificationIntentFactory( 162 mTemplateIntentFactory, new LegacyClassificationIntentFactory()) 163 : new LegacyClassificationIntentFactory(); 164 mActionsModelParamsSupplier = new ActionsModelParamsSupplier(mContext, 165 () -> { 166 synchronized (mLock) { 167 // Clear mActionsImpl here, so that we will create a new 168 // ActionsSuggestionsModel object with the new flag in the next request. 
169 mActionsImpl = null; 170 mActionModelInUse = null; 171 } 172 }); 173 } 174 TextClassifierImpl(Context context, TextClassificationConstants settings)175 public TextClassifierImpl(Context context, TextClassificationConstants settings) { 176 this(context, settings, TextClassifier.NO_OP); 177 } 178 179 /** @inheritDoc */ 180 @Override 181 @WorkerThread suggestSelection(TextSelection.Request request)182 public TextSelection suggestSelection(TextSelection.Request request) { 183 Preconditions.checkNotNull(request); 184 Utils.checkMainThread(); 185 try { 186 final int rangeLength = request.getEndIndex() - request.getStartIndex(); 187 final String string = request.getText().toString(); 188 if (string.length() > 0 189 && rangeLength <= mSettings.getSuggestSelectionMaxRangeLength()) { 190 final String localesString = concatenateLocales(request.getDefaultLocales()); 191 final String detectLanguageTags = detectLanguageTagsFromText(request.getText()); 192 final ZonedDateTime refTime = ZonedDateTime.now(); 193 final AnnotatorModel annotatorImpl = getAnnotatorImpl(request.getDefaultLocales()); 194 final int start; 195 final int end; 196 if (mSettings.isModelDarkLaunchEnabled() && !request.isDarkLaunchAllowed()) { 197 start = request.getStartIndex(); 198 end = request.getEndIndex(); 199 } else { 200 final int[] startEnd = annotatorImpl.suggestSelection( 201 string, request.getStartIndex(), request.getEndIndex(), 202 new AnnotatorModel.SelectionOptions(localesString, detectLanguageTags)); 203 start = startEnd[0]; 204 end = startEnd[1]; 205 } 206 if (start < end 207 && start >= 0 && end <= string.length() 208 && start <= request.getStartIndex() && end >= request.getEndIndex()) { 209 final TextSelection.Builder tsBuilder = new TextSelection.Builder(start, end); 210 final AnnotatorModel.ClassificationResult[] results = 211 annotatorImpl.classifyText( 212 string, start, end, 213 new AnnotatorModel.ClassificationOptions( 214 refTime.toInstant().toEpochMilli(), 215 
refTime.getZone().getId(), 216 localesString, 217 detectLanguageTags), 218 // Passing null here to suppress intent generation 219 // TODO: Use an explicit flag to suppress it. 220 /* appContext */ null, 221 /* deviceLocales */null); 222 final int size = results.length; 223 for (int i = 0; i < size; i++) { 224 tsBuilder.setEntityType(results[i].getCollection(), results[i].getScore()); 225 } 226 return tsBuilder.setId(createId( 227 string, request.getStartIndex(), request.getEndIndex())) 228 .build(); 229 } else { 230 // We can not trust the result. Log the issue and ignore the result. 231 Log.d(LOG_TAG, "Got bad indices for input text. Ignoring result."); 232 } 233 } 234 } catch (Throwable t) { 235 // Avoid throwing from this method. Log the error. 236 Log.e(LOG_TAG, 237 "Error suggesting selection for text. No changes to selection suggested.", 238 t); 239 } 240 // Getting here means something went wrong, return a NO_OP result. 241 return mFallback.suggestSelection(request); 242 } 243 244 /** @inheritDoc */ 245 @Override 246 @WorkerThread classifyText(TextClassification.Request request)247 public TextClassification classifyText(TextClassification.Request request) { 248 Preconditions.checkNotNull(request); 249 Utils.checkMainThread(); 250 try { 251 final int rangeLength = request.getEndIndex() - request.getStartIndex(); 252 final String string = request.getText().toString(); 253 if (string.length() > 0 && rangeLength <= mSettings.getClassifyTextMaxRangeLength()) { 254 final String localesString = concatenateLocales(request.getDefaultLocales()); 255 final String detectLanguageTags = detectLanguageTagsFromText(request.getText()); 256 final ZonedDateTime refTime = request.getReferenceTime() != null 257 ? 
request.getReferenceTime() : ZonedDateTime.now(); 258 final AnnotatorModel.ClassificationResult[] results = 259 getAnnotatorImpl(request.getDefaultLocales()) 260 .classifyText( 261 string, request.getStartIndex(), request.getEndIndex(), 262 new AnnotatorModel.ClassificationOptions( 263 refTime.toInstant().toEpochMilli(), 264 refTime.getZone().getId(), 265 localesString, 266 detectLanguageTags), 267 mContext, 268 getResourceLocalesString() 269 ); 270 if (results.length > 0) { 271 return createClassificationResult( 272 results, string, 273 request.getStartIndex(), request.getEndIndex(), refTime.toInstant()); 274 } 275 } 276 } catch (Throwable t) { 277 // Avoid throwing from this method. Log the error. 278 Log.e(LOG_TAG, "Error getting text classification info.", t); 279 } 280 // Getting here means something went wrong, return a NO_OP result. 281 return mFallback.classifyText(request); 282 } 283 284 /** @inheritDoc */ 285 @Override 286 @WorkerThread generateLinks(@onNull TextLinks.Request request)287 public TextLinks generateLinks(@NonNull TextLinks.Request request) { 288 Preconditions.checkNotNull(request); 289 Utils.checkTextLength(request.getText(), getMaxGenerateLinksTextLength()); 290 Utils.checkMainThread(); 291 292 if (!mSettings.isSmartLinkifyEnabled() && request.isLegacyFallback()) { 293 return Utils.generateLegacyLinks(request); 294 } 295 296 final String textString = request.getText().toString(); 297 final TextLinks.Builder builder = new TextLinks.Builder(textString); 298 299 try { 300 final long startTimeMs = System.currentTimeMillis(); 301 final ZonedDateTime refTime = ZonedDateTime.now(); 302 final Collection<String> entitiesToIdentify = request.getEntityConfig() != null 303 ? 
request.getEntityConfig().resolveEntityListModifications( 304 getEntitiesForHints(request.getEntityConfig().getHints())) 305 : mSettings.getEntityListDefault(); 306 final String localesString = concatenateLocales(request.getDefaultLocales()); 307 final String detectLanguageTags = detectLanguageTagsFromText(request.getText()); 308 final AnnotatorModel annotatorImpl = 309 getAnnotatorImpl(request.getDefaultLocales()); 310 final boolean isSerializedEntityDataEnabled = 311 ExtrasUtils.isSerializedEntityDataEnabled(request); 312 final AnnotatorModel.AnnotatedSpan[] annotations = 313 annotatorImpl.annotate( 314 textString, 315 new AnnotatorModel.AnnotationOptions( 316 refTime.toInstant().toEpochMilli(), 317 refTime.getZone().getId(), 318 localesString, 319 detectLanguageTags, 320 entitiesToIdentify, 321 AnnotatorModel.AnnotationUsecase.SMART.getValue(), 322 isSerializedEntityDataEnabled)); 323 for (AnnotatorModel.AnnotatedSpan span : annotations) { 324 final AnnotatorModel.ClassificationResult[] results = 325 span.getClassification(); 326 if (results.length == 0 327 || !entitiesToIdentify.contains(results[0].getCollection())) { 328 continue; 329 } 330 final Map<String, Float> entityScores = new ArrayMap<>(); 331 for (int i = 0; i < results.length; i++) { 332 entityScores.put(results[i].getCollection(), results[i].getScore()); 333 } 334 Bundle extras = new Bundle(); 335 if (isSerializedEntityDataEnabled) { 336 ExtrasUtils.putEntities(extras, results); 337 } 338 builder.addLink(span.getStartIndex(), span.getEndIndex(), entityScores, extras); 339 } 340 final TextLinks links = builder.build(); 341 final long endTimeMs = System.currentTimeMillis(); 342 final String callingPackageName = request.getCallingPackageName() == null 343 ? mContext.getPackageName() // local (in process) TC. 
344 : request.getCallingPackageName(); 345 mGenerateLinksLogger.logGenerateLinks( 346 request.getText(), links, callingPackageName, endTimeMs - startTimeMs); 347 return links; 348 } catch (Throwable t) { 349 // Avoid throwing from this method. Log the error. 350 Log.e(LOG_TAG, "Error getting links info.", t); 351 } 352 return mFallback.generateLinks(request); 353 } 354 355 /** @inheritDoc */ 356 @Override getMaxGenerateLinksTextLength()357 public int getMaxGenerateLinksTextLength() { 358 return mSettings.getGenerateLinksMaxTextLength(); 359 } 360 getEntitiesForHints(Collection<String> hints)361 private Collection<String> getEntitiesForHints(Collection<String> hints) { 362 final boolean editable = hints.contains(HINT_TEXT_IS_EDITABLE); 363 final boolean notEditable = hints.contains(HINT_TEXT_IS_NOT_EDITABLE); 364 365 // Use the default if there is no hint, or conflicting ones. 366 final boolean useDefault = editable == notEditable; 367 if (useDefault) { 368 return mSettings.getEntityListDefault(); 369 } else if (editable) { 370 return mSettings.getEntityListEditable(); 371 } else { // notEditable 372 return mSettings.getEntityListNotEditable(); 373 } 374 } 375 376 /** @inheritDoc */ 377 @Override onSelectionEvent(SelectionEvent event)378 public void onSelectionEvent(SelectionEvent event) { 379 mSessionLogger.writeEvent(event); 380 } 381 382 @Override onTextClassifierEvent(TextClassifierEvent event)383 public void onTextClassifierEvent(TextClassifierEvent event) { 384 if (DEBUG) { 385 Log.d(DEFAULT_LOG_TAG, "onTextClassifierEvent() called with: event = [" + event + "]"); 386 } 387 try { 388 final SelectionEvent selEvent = event.toSelectionEvent(); 389 if (selEvent != null) { 390 mSessionLogger.writeEvent(selEvent); 391 } else { 392 mTextClassifierEventTronLogger.writeEvent(event); 393 } 394 } catch (Exception e) { 395 Log.e(LOG_TAG, "Error writing event", e); 396 } 397 } 398 399 /** @inheritDoc */ 400 @Override detectLanguage(@onNull TextLanguage.Request request)401 
public TextLanguage detectLanguage(@NonNull TextLanguage.Request request) { 402 Preconditions.checkNotNull(request); 403 Utils.checkMainThread(); 404 try { 405 final TextLanguage.Builder builder = new TextLanguage.Builder(); 406 final LangIdModel.LanguageResult[] langResults = 407 getLangIdImpl().detectLanguages(request.getText().toString()); 408 for (int i = 0; i < langResults.length; i++) { 409 builder.putLocale( 410 ULocale.forLanguageTag(langResults[i].getLanguage()), 411 langResults[i].getScore()); 412 } 413 return builder.build(); 414 } catch (Throwable t) { 415 // Avoid throwing from this method. Log the error. 416 Log.e(LOG_TAG, "Error detecting text language.", t); 417 } 418 return mFallback.detectLanguage(request); 419 } 420 421 @Override suggestConversationActions(ConversationActions.Request request)422 public ConversationActions suggestConversationActions(ConversationActions.Request request) { 423 Preconditions.checkNotNull(request); 424 Utils.checkMainThread(); 425 try { 426 ActionsSuggestionsModel actionsImpl = getActionsImpl(); 427 if (actionsImpl == null) { 428 // Actions model is optional, fallback if it is not available. 
429 return mFallback.suggestConversationActions(request); 430 } 431 ActionsSuggestionsModel.ConversationMessage[] nativeMessages = 432 ActionsSuggestionsHelper.toNativeMessages( 433 request.getConversation(), this::detectLanguageTagsFromText); 434 if (nativeMessages.length == 0) { 435 return mFallback.suggestConversationActions(request); 436 } 437 ActionsSuggestionsModel.Conversation nativeConversation = 438 new ActionsSuggestionsModel.Conversation(nativeMessages); 439 440 ActionsSuggestionsModel.ActionSuggestion[] nativeSuggestions = 441 actionsImpl.suggestActionsWithIntents( 442 nativeConversation, 443 null, 444 mContext, 445 getResourceLocalesString(), 446 getAnnotatorImpl(LocaleList.getDefault())); 447 return createConversationActionResult(request, nativeSuggestions); 448 } catch (Throwable t) { 449 // Avoid throwing from this method. Log the error. 450 Log.e(LOG_TAG, "Error suggesting conversation actions.", t); 451 } 452 return mFallback.suggestConversationActions(request); 453 } 454 455 /** 456 * Returns the {@link ConversationAction} result, with a non-null extras. 457 * <p> 458 * Whenever the RemoteAction is non-null, you can expect its corresponding intent 459 * with a non-null component name is in the extras. 
460 */ createConversationActionResult( ConversationActions.Request request, ActionsSuggestionsModel.ActionSuggestion[] nativeSuggestions)461 private ConversationActions createConversationActionResult( 462 ConversationActions.Request request, 463 ActionsSuggestionsModel.ActionSuggestion[] nativeSuggestions) { 464 Collection<String> expectedTypes = resolveActionTypesFromRequest(request); 465 List<ConversationAction> conversationActions = new ArrayList<>(); 466 for (ActionsSuggestionsModel.ActionSuggestion nativeSuggestion : nativeSuggestions) { 467 String actionType = nativeSuggestion.getActionType(); 468 if (!expectedTypes.contains(actionType)) { 469 continue; 470 } 471 LabeledIntent.Result labeledIntentResult = 472 ActionsSuggestionsHelper.createLabeledIntentResult( 473 mContext, 474 mTemplateIntentFactory, 475 nativeSuggestion); 476 RemoteAction remoteAction = null; 477 Bundle extras = new Bundle(); 478 if (labeledIntentResult != null) { 479 remoteAction = labeledIntentResult.remoteAction; 480 ExtrasUtils.putActionIntent(extras, labeledIntentResult.resolvedIntent); 481 } 482 ExtrasUtils.putSerializedEntityData(extras, nativeSuggestion.getSerializedEntityData()); 483 ExtrasUtils.putEntitiesExtras( 484 extras, 485 TemplateIntentFactory.nameVariantsToBundle(nativeSuggestion.getEntityData())); 486 conversationActions.add( 487 new ConversationAction.Builder(actionType) 488 .setConfidenceScore(nativeSuggestion.getScore()) 489 .setTextReply(nativeSuggestion.getResponseText()) 490 .setAction(remoteAction) 491 .setExtras(extras) 492 .build()); 493 } 494 conversationActions = 495 ActionsSuggestionsHelper.removeActionsWithDuplicates(conversationActions); 496 if (request.getMaxSuggestions() >= 0 497 && conversationActions.size() > request.getMaxSuggestions()) { 498 conversationActions = conversationActions.subList(0, request.getMaxSuggestions()); 499 } 500 String resultId = ActionsSuggestionsHelper.createResultId( 501 mContext, 502 request.getConversation(), 503 
mActionModelInUse.getVersion(), 504 mActionModelInUse.getSupportedLocales()); 505 return new ConversationActions(conversationActions, resultId); 506 } 507 508 @Nullable detectLanguageTagsFromText(CharSequence text)509 private String detectLanguageTagsFromText(CharSequence text) { 510 if (!mSettings.isDetectLanguagesFromTextEnabled()) { 511 return null; 512 } 513 final float threshold = getLangIdThreshold(); 514 if (threshold < 0 || threshold > 1) { 515 Log.w(LOG_TAG, 516 "[detectLanguageTagsFromText] unexpected threshold is found: " + threshold); 517 return null; 518 } 519 TextLanguage.Request request = new TextLanguage.Request.Builder(text).build(); 520 TextLanguage textLanguage = detectLanguage(request); 521 int localeHypothesisCount = textLanguage.getLocaleHypothesisCount(); 522 List<String> languageTags = new ArrayList<>(); 523 for (int i = 0; i < localeHypothesisCount; i++) { 524 ULocale locale = textLanguage.getLocale(i); 525 if (textLanguage.getConfidenceScore(locale) < threshold) { 526 break; 527 } 528 languageTags.add(locale.toLanguageTag()); 529 } 530 if (languageTags.isEmpty()) { 531 return null; 532 } 533 return String.join(",", languageTags); 534 } 535 resolveActionTypesFromRequest(ConversationActions.Request request)536 private Collection<String> resolveActionTypesFromRequest(ConversationActions.Request request) { 537 List<String> defaultActionTypes = 538 request.getHints().contains(ConversationActions.Request.HINT_FOR_NOTIFICATION) 539 ? mSettings.getNotificationConversationActionTypes() 540 : mSettings.getInAppConversationActionTypes(); 541 return request.getTypeConfig().resolveEntityListModifications(defaultActionTypes); 542 } 543 getAnnotatorImpl(LocaleList localeList)544 private AnnotatorModel getAnnotatorImpl(LocaleList localeList) 545 throws FileNotFoundException { 546 synchronized (mLock) { 547 localeList = localeList == null ? 
LocaleList.getDefault() : localeList; 548 final ModelFileManager.ModelFile bestModel = 549 mAnnotatorModelFileManager.findBestModelFile(localeList); 550 if (bestModel == null) { 551 throw new FileNotFoundException( 552 "No annotator model for " + localeList.toLanguageTags()); 553 } 554 if (mAnnotatorImpl == null || !Objects.equals(mAnnotatorModelInUse, bestModel)) { 555 Log.d(DEFAULT_LOG_TAG, "Loading " + bestModel); 556 final ParcelFileDescriptor pfd = ParcelFileDescriptor.open( 557 new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY); 558 try { 559 if (pfd != null) { 560 // The current annotator model may be still used by another thread / model. 561 // Do not call close() here, and let the GC to clean it up when no one else 562 // is using it. 563 mAnnotatorImpl = new AnnotatorModel(pfd.getFd()); 564 mAnnotatorModelInUse = bestModel; 565 } 566 } finally { 567 maybeCloseAndLogError(pfd); 568 } 569 } 570 return mAnnotatorImpl; 571 } 572 } 573 getLangIdImpl()574 private LangIdModel getLangIdImpl() throws FileNotFoundException { 575 synchronized (mLock) { 576 final ModelFileManager.ModelFile bestModel = 577 mLangIdModelFileManager.findBestModelFile(null); 578 if (bestModel == null) { 579 throw new FileNotFoundException("No LangID model is found"); 580 } 581 if (mLangIdImpl == null || !Objects.equals(mLangIdModelInUse, bestModel)) { 582 Log.d(DEFAULT_LOG_TAG, "Loading " + bestModel); 583 final ParcelFileDescriptor pfd = ParcelFileDescriptor.open( 584 new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY); 585 try { 586 if (pfd != null) { 587 mLangIdImpl = new LangIdModel(pfd.getFd()); 588 mLangIdModelInUse = bestModel; 589 } 590 } finally { 591 maybeCloseAndLogError(pfd); 592 } 593 } 594 return mLangIdImpl; 595 } 596 } 597 598 @Nullable getActionsImpl()599 private ActionsSuggestionsModel getActionsImpl() throws FileNotFoundException { 600 synchronized (mLock) { 601 // TODO: Use LangID to determine the locale we should use here? 
602 final ModelFileManager.ModelFile bestModel = 603 mActionsModelFileManager.findBestModelFile(LocaleList.getDefault()); 604 if (bestModel == null) { 605 return null; 606 } 607 if (mActionsImpl == null || !Objects.equals(mActionModelInUse, bestModel)) { 608 Log.d(DEFAULT_LOG_TAG, "Loading " + bestModel); 609 final ParcelFileDescriptor pfd = ParcelFileDescriptor.open( 610 new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY); 611 try { 612 if (pfd == null) { 613 Log.d(LOG_TAG, "Failed to read the model file: " + bestModel.getPath()); 614 return null; 615 } 616 ActionsModelParams params = mActionsModelParamsSupplier.get(); 617 mActionsImpl = new ActionsSuggestionsModel( 618 pfd.getFd(), params.getSerializedPreconditions(bestModel)); 619 mActionModelInUse = bestModel; 620 } finally { 621 maybeCloseAndLogError(pfd); 622 } 623 } 624 return mActionsImpl; 625 } 626 } 627 createId(String text, int start, int end)628 private String createId(String text, int start, int end) { 629 synchronized (mLock) { 630 return SelectionSessionLogger.createId(text, start, end, mContext, 631 mAnnotatorModelInUse.getVersion(), 632 mAnnotatorModelInUse.getSupportedLocales()); 633 } 634 } 635 concatenateLocales(@ullable LocaleList locales)636 private static String concatenateLocales(@Nullable LocaleList locales) { 637 return (locales == null) ? 
"" : locales.toLanguageTags(); 638 } 639 createClassificationResult( AnnotatorModel.ClassificationResult[] classifications, String text, int start, int end, @Nullable Instant referenceTime)640 private TextClassification createClassificationResult( 641 AnnotatorModel.ClassificationResult[] classifications, 642 String text, int start, int end, @Nullable Instant referenceTime) { 643 final String classifiedText = text.substring(start, end); 644 final TextClassification.Builder builder = new TextClassification.Builder() 645 .setText(classifiedText); 646 647 final int typeCount = classifications.length; 648 AnnotatorModel.ClassificationResult highestScoringResult = 649 typeCount > 0 ? classifications[0] : null; 650 for (int i = 0; i < typeCount; i++) { 651 builder.setEntityType(classifications[i]); 652 if (classifications[i].getScore() > highestScoringResult.getScore()) { 653 highestScoringResult = classifications[i]; 654 } 655 } 656 657 final Pair<Bundle, Bundle> languagesBundles = generateLanguageBundles(text, start, end); 658 final Bundle textLanguagesBundle = languagesBundles.first; 659 final Bundle foreignLanguageBundle = languagesBundles.second; 660 builder.setForeignLanguageExtra(foreignLanguageBundle); 661 662 boolean isPrimaryAction = true; 663 final List<LabeledIntent> labeledIntents = mClassificationIntentFactory.create( 664 mContext, 665 classifiedText, 666 foreignLanguageBundle != null, 667 referenceTime, 668 highestScoringResult); 669 final LabeledIntent.TitleChooser titleChooser = 670 (labeledIntent, resolveInfo) -> labeledIntent.titleWithoutEntity; 671 672 for (LabeledIntent labeledIntent : labeledIntents) { 673 final LabeledIntent.Result result = 674 labeledIntent.resolve(mContext, titleChooser, textLanguagesBundle); 675 if (result == null) { 676 continue; 677 } 678 679 final Intent intent = result.resolvedIntent; 680 final RemoteAction action = result.remoteAction; 681 if (isPrimaryAction) { 682 // For O backwards compatibility, the first RemoteAction 
is also written to the 683 // legacy API fields. 684 builder.setIcon(action.getIcon().loadDrawable(mContext)); 685 builder.setLabel(action.getTitle().toString()); 686 builder.setIntent(intent); 687 builder.setOnClickListener(TextClassification.createIntentOnClickListener( 688 TextClassification.createPendingIntent( 689 mContext, intent, labeledIntent.requestCode))); 690 isPrimaryAction = false; 691 } 692 builder.addAction(action, intent); 693 } 694 return builder.setId(createId(text, start, end)).build(); 695 } 696 697 /** 698 * Returns a bundle pair with language detection information for extras. 699 * <p> 700 * Pair.first = textLanguagesBundle - A bundle containing information about all detected 701 * languages in the text. May be null if language detection fails or is disabled. This is 702 * typically expected to be added to a textClassifier generated remote action intent. 703 * See {@link ExtrasUtils#putTextLanguagesExtra(Bundle, Bundle)}. 704 * See {@link ExtrasUtils#getTopLanguage(Intent)}. 705 * <p> 706 * Pair.second = foreignLanguageBundle - A bundle with the language and confidence score if the 707 * system finds the text to be in a foreign language. Otherwise is null. 708 * See {@link TextClassification.Builder#setForeignLanguageExtra(Bundle)}. 709 * 710 * @param context the context of the text to detect languages for 711 * @param start the start index of the text 712 * @param end the end index of the text 713 */ 714 // TODO: Revisit this algorithm. 715 // TODO: Consider making this public API. 
/**
 * Computes the language-related extras for the span {@code [start, end)} of {@code context}.
 *
 * <p>This is best effort: any failure is logged and reported as "no language information".
 * The returned pair itself is never {@code null}; either of its elements may be.
 *
 * @param context the full text that contains the span of interest
 * @param start the start index of the span whose language is to be detected
 * @param end the end index of the span whose language is to be detected
 * @return a pair whose {@code first} element is the bundle filled by
 *     {@code ExtrasUtils.putTopLanguageScores} (or {@code null} if no scores were produced),
 *     and whose {@code second} element is the foreign-language extra (or {@code null} when
 *     the top score is below the threshold, the detected language matches one of the default
 *     locales, detection is disabled, or an error occurred)
 */
private Pair<Bundle, Bundle> generateLanguageBundles(String context, int start, int end) {
    if (!mSettings.isTranslateInClassificationEnabled()) {
        // Return an empty pair rather than null so that, like every other "no result" path
        // below, callers can always read .first/.second without a null check.
        return Pair.create(null, null);
    }
    try {
        final float threshold = getLangIdThreshold();
        if (threshold < 0 || threshold > 1) {
            Log.w(LOG_TAG,
                    "[detectForeignLanguage] unexpected threshold is found: " + threshold);
            return Pair.create(null, null);
        }

        final EntityConfidence languageScores = detectLanguages(context, start, end);
        if (languageScores.getEntities().isEmpty()) {
            return Pair.create(null, null);
        }

        final Bundle textLanguagesBundle = new Bundle();
        ExtrasUtils.putTopLanguageScores(textLanguagesBundle, languageScores);

        // NOTE(review): assumes getEntities() is ordered by descending score so that index 0
        // is the most likely language — confirm against EntityConfidence.
        final String language = languageScores.getEntities().get(0);
        final float score = languageScores.getConfidenceScore(language);
        if (score < threshold) {
            // Not confident enough to flag the text as being in a foreign language.
            return Pair.create(textLanguagesBundle, null);
        }

        Log.v(LOG_TAG, String.format(
                Locale.US, "Language detected: <%s:%.2f>", language, score));

        final Locale detected = new Locale(language);
        final LocaleList deviceLocales = LocaleList.getDefault();
        final int size = deviceLocales.size();
        for (int i = 0; i < size; i++) {
            if (deviceLocales.get(i).getLanguage().equals(detected.getLanguage())) {
                // The language is one of the default locales' languages, so it is not foreign.
                return Pair.create(textLanguagesBundle, null);
            }
        }
        final Bundle foreignLanguageBundle = ExtrasUtils.createForeignLanguageExtra(
                detected.getLanguage(), score, getLangIdImpl().getVersion());
        return Pair.create(textLanguagesBundle, foreignLanguageBundle);
    } catch (Throwable t) {
        // Best effort: language extras are optional, so detection failures must not
        // propagate to the caller.
        Log.e(LOG_TAG, "Error generating language bundles.", t);
    }
    return Pair.create(null, null);
}

/**
 * Detect the language of a piece of text by taking surrounding text into consideration.
 *
 * <p>The span {@code [start, end)} is first scored on its own. The span's own scores are
 * returned unchanged when it is at least the configured minimum size, when it already covers
 * all of {@code text}, or when the subject weight times the penalty ratio is at least 1.
 * Otherwise its scores are blended with the scores of a window of {@code text} grown around
 * the span, and the blend is scaled by the penalty ratio. All three tuning values come from
 * {@code mSettings.getLangIdContextSettings()}.
 *
 * @param text text providing context for the text for which its language is to be detected
 * @param start the start index of the text to detect its language
 * @param end the end index of the text to detect its language
 * @return per-language scores; empty when the span is shorter than the minimum size and the
 *     penalty ratio is not positive
 * @throws FileNotFoundException propagated from {@code getLangIdImpl()}
 * @throws IllegalArgumentException if {@code start}/{@code end} are out of range
 */
// TODO: Revisit this algorithm.
private EntityConfidence detectLanguages(String text, int start, int end)
        throws FileNotFoundException {
    Preconditions.checkArgument(start >= 0, "start must be non-negative");
    Preconditions.checkArgument(end <= text.length(), "end must not exceed the text length");
    Preconditions.checkArgument(start <= end, "start must not exceed end");

    final float[] langIdContextSettings = mSettings.getLangIdContextSettings();
    // The minimum size of text to prefer for detection.
    final int minimumTextSize = (int) langIdContextSettings[0];
    // For reducing the score when text is less than the preferred size.
    final float penalizeRatio = langIdContextSettings[1];
    // Original detection score to surrounding text detection score ratios.
    final float subjectTextScoreRatio = langIdContextSettings[2];
    final float moreTextScoreRatio = 1f - subjectTextScoreRatio;
    Log.v(LOG_TAG,
            String.format(Locale.US, "LangIdContextSettings: "
                    + "minimumTextSize=%d, penalizeRatio=%.2f, "
                    + "subjectTextScoreRatio=%.2f, moreTextScoreRatio=%.2f",
                    minimumTextSize, penalizeRatio, subjectTextScoreRatio, moreTextScoreRatio));

    if (end - start < minimumTextSize && penalizeRatio <= 0) {
        // Too short, and any blended score would be scaled by a non-positive penalty.
        return new EntityConfidence(Collections.emptyMap());
    }

    final String subject = text.substring(start, end);
    final EntityConfidence scores = detectLanguages(subject);

    if (subject.length() >= minimumTextSize
            || subject.length() == text.length()
            || subjectTextScoreRatio * penalizeRatio >= 1) {
        // Surrounding text would not help: the span is long enough, there is no more text
        // to add, or the weighting does not reduce the span's own score.
        return scores;
    }

    final EntityConfidence moreTextScores;
    if (moreTextScoreRatio >= 0) {
        // Attempt to grow the detection text to be at least minimumTextSize long.
        final String moreText = Utils.getSubString(text, start, end, minimumTextSize);
        moreTextScores = detectLanguages(moreText);
    } else {
        moreTextScores = new EntityConfidence(Collections.emptyMap());
    }

    // Combine the original detection scores with those returned after including more text.
    final Map<String, Float> newScores = new ArrayMap<>();
    final Set<String> languages = new ArraySet<>();
    languages.addAll(scores.getEntities());
    languages.addAll(moreTextScores.getEntities());
    for (String language : languages) {
        final float score = (subjectTextScoreRatio * scores.getConfidenceScore(language)
                + moreTextScoreRatio * moreTextScores.getConfidenceScore(language))
                * penalizeRatio;
        newScores.put(language, score);
    }
    return new EntityConfidence(newScores);
}

/**
 * Detect languages for the specified text.
 *
 * @param text the text to run the language-id model on
 * @return one score per language result reported by the model
 * @throws FileNotFoundException propagated from {@code getLangIdImpl()}
 */
private EntityConfidence detectLanguages(String text) throws FileNotFoundException {
    final LangIdModel langId = getLangIdImpl();
    final LangIdModel.LanguageResult[] langResults = langId.detectLanguages(text);
    final Map<String, Float> languagesMap = new ArrayMap<>();
    for (LanguageResult langResult : langResults) {
        languagesMap.put(langResult.getLanguage(), langResult.getScore());
    }
    return new EntityConfidence(languagesMap);
}

/**
 * Returns the score the top detected language must reach to be treated as foreign.
 *
 * <p>A non-negative settings override wins. Otherwise the threshold reported by the
 * language-id model is used. If the model cannot be loaded ({@link FileNotFoundException}),
 * the method falls back to 0.5.
 */
private float getLangIdThreshold() {
    try {
        // Read the override exactly once: reading it separately for the check and the return
        // could observe two different values if the setting changes in between.
        final float thresholdOverride = mSettings.getLangIdThresholdOverride();
        return thresholdOverride >= 0
                ? thresholdOverride
                : getLangIdImpl().getLangIdThreshold();
    } catch (FileNotFoundException e) {
        final float defaultThreshold = 0.5f;
        Log.v(LOG_TAG, "Using default foreign language threshold: " + defaultThreshold);
        return defaultThreshold;
    }
}

/**
 * Prints the annotator, LangID and actions model files listed by their model file managers,
 * followed by the fallback classifier, while holding {@code mLock}.
 */
@Override
public void dump(@NonNull IndentingPrintWriter printWriter) {
    synchronized (mLock) {
        printWriter.println("TextClassifierImpl:");
        printWriter.increaseIndent();
        printWriter.println("Annotator model file(s):");
        printWriter.increaseIndent();
        for (ModelFileManager.ModelFile modelFile :
                mAnnotatorModelFileManager.listModelFiles()) {
            printWriter.println(modelFile.toString());
        }
        printWriter.decreaseIndent();
        printWriter.println("LangID model file(s):");
        printWriter.increaseIndent();
        for (ModelFileManager.ModelFile modelFile :
                mLangIdModelFileManager.listModelFiles()) {
            printWriter.println(modelFile.toString());
        }
        printWriter.decreaseIndent();
        printWriter.println("Actions model file(s):");
        printWriter.increaseIndent();
        for (ModelFileManager.ModelFile modelFile :
                mActionsModelFileManager.listModelFiles()) {
            printWriter.println(modelFile.toString());
        }
        printWriter.decreaseIndent();
        printWriter.printPair("mFallback", mFallback);
        printWriter.decreaseIndent();
        printWriter.println();
    }
}

/**
 * Closes the ParcelFileDescriptor, if non-null, and logs any errors that occur.
 *
 * @param fd the descriptor to close; {@code null} is ignored
 */
private static void maybeCloseAndLogError(@Nullable ParcelFileDescriptor fd) {
    if (fd == null) {
        return;
    }

    try {
        fd.close();
    } catch (IOException e) {
        // Closing is cleanup only; log and carry on rather than failing the caller.
        Log.e(LOG_TAG, "Error closing file.", e);
    }
}

/**
 * Returns the locales string for the current resources configuration.
 *
 * <p>Falls back to the language tags of {@link LocaleList#getDefault()} if any step of the
 * lookup through {@code mContext} throws a {@link NullPointerException}.
 */
private String getResourceLocalesString() {
    try {
        return mContext.getResources().getConfiguration().getLocales().toLanguageTags();
    } catch (NullPointerException e) {
        // NPE is unexpected. Erring on the side of caution.
        return LocaleList.getDefault().toLanguageTags();
    }
}
}