From aeff02382cad2f5563c3495cd3cda1c980c511f9 Mon Sep 17 00:00:00 2001 From: Hamza Date: Thu, 7 May 2026 18:02:28 +0200 Subject: [PATCH] feat(classifier): only train on user set important messages Signed-off-by: Hamza --- lib/BackgroundJob/FollowUpClassifierJob.php | 2 + lib/Contracts/IMailManager.php | 6 +- lib/Db/MessageTags.php | 15 +++++ lib/Db/TagMapper.php | 43 +++++++++++++- .../Version5900Date20260507120000.php | 59 +++++++++++++++++++ .../Classification/ImportanceClassifier.php | 28 ++++++++- .../Classification/NewMessagesClassifier.php | 3 +- lib/Service/MailManager.php | 9 +-- tests/Unit/Job/FollowUpClassifierJobTest.php | 7 ++- tests/Unit/Service/MailManagerTest.php | 4 ++ 10 files changed, 163 insertions(+), 13 deletions(-) create mode 100644 lib/Migration/Version5900Date20260507120000.php diff --git a/lib/BackgroundJob/FollowUpClassifierJob.php b/lib/BackgroundJob/FollowUpClassifierJob.php index 4856aa7d5d..1a030a0b18 100644 --- a/lib/BackgroundJob/FollowUpClassifierJob.php +++ b/lib/BackgroundJob/FollowUpClassifierJob.php @@ -11,6 +11,7 @@ use OCA\Mail\Contracts\IMailManager; use OCA\Mail\Db\Message; +use OCA\Mail\Db\MessageTags; use OCA\Mail\Db\ThreadMapper; use OCA\Mail\Exception\ClientException; use OCA\Mail\Service\AccountService; @@ -102,6 +103,7 @@ public function run($argument): void { $message, $tag, true, + MessageTags::TYPE_CLASSIFIER, ); } } diff --git a/lib/Contracts/IMailManager.php b/lib/Contracts/IMailManager.php index c18bf8f4a1..586825a321 100644 --- a/lib/Contracts/IMailManager.php +++ b/lib/Contracts/IMailManager.php @@ -14,6 +14,7 @@ use OCA\Mail\Attachment; use OCA\Mail\Db\Mailbox; use OCA\Mail\Db\Message; +use OCA\Mail\Db\MessageTags; use OCA\Mail\Db\Tag; use OCA\Mail\Exception\ClientException; use OCA\Mail\Exception\ServiceException; @@ -183,11 +184,14 @@ public function flagMessage(Account $account, string $mailbox, int $uid, string * @param Message $message * @param Tag $tag * @param bool $value + * @param string $type Source of the tag, one of MessageTags::TYPE_USER (default) + * or MessageTags::TYPE_CLASSIFIER. Used so the importance + * classifier does not train on its own predictions. * * @throws ClientException * @throws ServiceException */ - public function tagMessage(Account $account, string $mailbox, Message $message, Tag $tag, bool $value): void; + public function tagMessage(Account $account, string $mailbox, Message $message, Tag $tag, bool $value, string $type = MessageTags::TYPE_USER): void; /** * @param Account $account diff --git a/lib/Db/MessageTags.php b/lib/Db/MessageTags.php index ac678827c1..2488087baa 100644 --- a/lib/Db/MessageTags.php +++ b/lib/Db/MessageTags.php @@ -18,13 +18,27 @@ * @method void setImapMessageId(string $imapMessageId) * @method int getTagId() * @method void setTagId(int $tagId) + * @method string getType() + * @method void setType(string $type) */ class MessageTags extends Entity implements JsonSerializable { + /** + * Tag was applied by the user (manual interaction). + */ + public const TYPE_USER = 'user'; + + /** + * Tag was applied by the automatic importance classifier. + */ + public const TYPE_CLASSIFIER = 'classifier'; + protected $imapMessageId; protected $tagId; + protected $type = self::TYPE_USER; public function __construct() { $this->addType('tagId', 'integer'); + $this->addType('type', 'string'); } #[\Override] @@ -34,6 +48,7 @@ public function jsonSerialize() { 'id' => $this->getId(), 'imapMessageId' => $this->getImapMessageId(), 'tagId' => $this->getTagId(), + 'type' => $this->getType(), ]; } } diff --git a/lib/Db/TagMapper.php b/lib/Db/TagMapper.php index 0c187d1441..c9b9ea6936 100644 --- a/lib/Db/TagMapper.php +++ b/lib/Db/TagMapper.php @@ -76,7 +76,7 @@ public function getAllTagsForUser(string $userId): array { * * To tag (flag) a message on IMAP, @see \OCA\Mail\Service\MailManager::tagMessage */ - public function tagMessage(Tag $tag, string $messageId, string $userId): void { + public function tagMessage(Tag $tag, string $messageId, string $userId, string $type = MessageTags::TYPE_USER): void { try { $tag = $this->getTagByImapLabel($tag->getImapLabel(), $userId); } catch (DoesNotExistException $e) { @@ -86,7 +86,8 @@ public function tagMessage(Tag $tag, string $messageId, string $userId): void { $qb = $this->db->getQueryBuilder(); $qb->insert('mail_message_tags') ->setValue('imap_message_id', $qb->createNamedParameter($messageId)) - ->setValue('tag_id', $qb->createNamedParameter($tag->getId(), IQueryBuilder::PARAM_INT)); + ->setValue('tag_id', $qb->createNamedParameter($tag->getId(), IQueryBuilder::PARAM_INT)) + ->setValue('type', $qb->createNamedParameter($type)); $qb->executeStatement(); } @@ -143,6 +144,44 @@ public function getAllTagsForMessages(array $messages, string $userId): array { return $tags; } + /** + * Return the IMAP message IDs of messages whose given tag was applied by + * the automatic classifier (rather than the user). Useful for excluding + * classifier-applied labels from the importance classifier's training + * set so it does not reinforce its own predictions. + * + * @param Message[] $messages + * @return string[] + */ + public function getClassifierTaggedMessageIds(array $messages, string $userId, string $imapLabel): array { + if ($messages === []) { + return []; + } + $ids = array_map(static fn (Message $message) => $message->getMessageId(), $messages); + + $qb = $this->db->getQueryBuilder(); + $query = $qb->selectDistinct('mt.imap_message_id') + ->from($this->getTableName(), 't') + ->join('t', 'mail_message_tags', 'mt', $qb->expr()->eq('t.id', 'mt.tag_id', IQueryBuilder::PARAM_INT)) + ->where( + $qb->expr()->in('mt.imap_message_id', $qb->createParameter('ids'), IQueryBuilder::PARAM_STR_ARRAY), + $qb->expr()->eq('t.user_id', $qb->createNamedParameter($userId, IQueryBuilder::PARAM_STR)), + $qb->expr()->eq('t.imap_label', $qb->createNamedParameter($imapLabel, IQueryBuilder::PARAM_STR)), + $qb->expr()->eq('mt.type', $qb->createNamedParameter(MessageTags::TYPE_CLASSIFIER, IQueryBuilder::PARAM_STR)), + ); + + $messageIds = []; + foreach (array_chunk($ids, 1000) as $chunk) { + $query->setParameter('ids', $chunk, IQueryBuilder::PARAM_STR_ARRAY); + $result = $query->executeQuery(); + while (($row = $result->fetch()) !== false) { + $messageIds[] = $row['imap_message_id']; + } + $result->closeCursor(); + } + return $messageIds; + } + /** * @param Message[] $messages * @param string $userId diff --git a/lib/Migration/Version5900Date20260507120000.php b/lib/Migration/Version5900Date20260507120000.php new file mode 100644 index 0000000000..65199b4321 --- /dev/null +++ b/lib/Migration/Version5900Date20260507120000.php @@ -0,0 +1,59 @@ +hasTable('mail_message_tags')) { + return null; + } + + $table = $schema->getTable('mail_message_tags'); + if (!$table->hasColumn('type')) { + $table->addColumn('type', Types::STRING, [ + 'notnull' => true, + 'length' => 16, + 'default' => MessageTags::TYPE_USER, + ]); + } + + return $schema; + } +} diff --git a/lib/Service/Classification/ImportanceClassifier.php b/lib/Service/Classification/ImportanceClassifier.php index bf8ee00b9e..4e82ad0eae 100644 --- a/lib/Service/Classification/ImportanceClassifier.php +++ b/lib/Service/Classification/ImportanceClassifier.php @@ -16,6 +16,8 @@ use OCA\Mail\Db\MailboxMapper; use OCA\Mail\Db\Message; use OCA\Mail\Db\MessageMapper; +use OCA\Mail\Db\Tag; +use OCA\Mail\Db\TagMapper; use OCA\Mail\Exception\ClassifierTrainingException; use OCA\Mail\Exception\ServiceException; use OCA\Mail\Model\Classifier; @@ -109,18 +111,22 @@ class ImportanceClassifier { private ContainerInterface $container; + private TagMapper $tagMapper; + public function __construct(MailboxMapper $mailboxMapper, MessageMapper $messageMapper, PersistenceService $persistenceService, PerformanceLogger $performanceLogger, ImportanceRulesClassifier $rulesClassifier, - ContainerInterface $container) { + ContainerInterface $container, + TagMapper $tagMapper) { $this->mailboxMapper = $mailboxMapper; $this->messageMapper = $messageMapper; $this->persistenceService = $persistenceService; $this->performanceLogger = $performanceLogger; $this->rulesClassifier = $rulesClassifier; $this->container = $container; + $this->tagMapper = $tagMapper; } private static function createDefaultEstimator(): KNearestNeighbors { @@ -180,10 +186,28 @@ public function buildDataSet( $this->messageMapper->findLatestMessages($account->getUserId(), $mailboxIds, self::MAX_TRAINING_SET_SIZE), [$this, 'filterMessageHasSenderEmail'] ); + + // Drop messages whose importance flag was set by the classifier itself. + // We have no ground truth for these, so including them would let the + // classifier reinforce its own predictions over time. + $classifierTaggedIds = array_flip($this->tagMapper->getClassifierTaggedMessageIds( + $messages, + $account->getUserId(), + Tag::LABEL_IMPORTANT, + )); + $autoTaggedDropped = 0; + $messages = array_filter($messages, static function (Message $message) use ($classifierTaggedIds, &$autoTaggedDropped) { + if (isset($classifierTaggedIds[$message->getMessageId()])) { + $autoTaggedDropped++; + return false; + } + return true; + }); + $importantMessages = array_filter($messages, static fn (Message $message) => $message->getFlagImportant() === true); $nMessages = count($messages); $nImportant = count($importantMessages); - $logger->debug("found $nMessages messages of which $nImportant are important"); + $logger->debug("found $nMessages messages of which $nImportant are important (dropped $autoTaggedDropped classifier-tagged messages)"); if (count($importantMessages) < self::COLD_START_THRESHOLD) { $logger->info('not enough messages to train a classifier'); return null; diff --git a/lib/Service/Classification/NewMessagesClassifier.php b/lib/Service/Classification/NewMessagesClassifier.php index 449157531f..581904942e 100644 --- a/lib/Service/Classification/NewMessagesClassifier.php +++ b/lib/Service/Classification/NewMessagesClassifier.php @@ -14,6 +14,7 @@ use OCA\Mail\Contracts\IMailManager; use OCA\Mail\Db\Mailbox; use OCA\Mail\Db\Message; +use OCA\Mail\Db\MessageTags; use OCA\Mail\Db\Tag; use OCA\Mail\Db\TagMapper; use OCA\Mail\Exception\ClientException; @@ -89,7 +90,7 @@ public function classifyNewMessages( if ($prediction) { $message->setFlagImportant(true); $this->mailManager->flagMessage($account, $mailbox->getName(), $message->getUid(), Tag::LABEL_IMPORTANT, true); - $this->mailManager->tagMessage($account, $mailbox->getName(), $message, $importantTag, true); + $this->mailManager->tagMessage($account, $mailbox->getName(), $message, $importantTag, true, MessageTags::TYPE_CLASSIFIER); } } } catch (ServiceException $e) { diff --git a/lib/Service/MailManager.php b/lib/Service/MailManager.php index 1b94bf8cf0..238b3fa690 100644 --- a/lib/Service/MailManager.php +++ b/lib/Service/MailManager.php @@ -21,6 +21,7 @@ use OCA\Mail\Db\MailboxMapper; use OCA\Mail\Db\Message; use OCA\Mail\Db\MessageMapper as DbMessageMapper; +use OCA\Mail\Db\MessageTags; use OCA\Mail\Db\MessageTagsMapper; use OCA\Mail\Db\Tag; use OCA\Mail\Db\TagMapper; @@ -493,7 +494,7 @@ public function flagMessage(Account $account, string $mailbox, int $uid, string * @throws ClientException * @throws ServiceException */ - public function tagMessagesWithClient(Horde_Imap_Client_Socket $client, Account $account, Mailbox $mailbox, array $messages, Tag $tag, bool $value):void { + public function tagMessagesWithClient(Horde_Imap_Client_Socket $client, Account $account, Mailbox $mailbox, array $messages, Tag $tag, bool $value, string $type = MessageTags::TYPE_USER):void { if ($this->isPermflagsEnabled($client, $account, $mailbox->getName()) === true) { $messageIds = array_map(static fn (Message $message) => $message->getUid(), $messages); try { @@ -514,7 +515,7 @@ public function tagMessagesWithClient(Horde_Imap_Client_Socket $client, Account if ($value) { foreach ($messages as $message) { - $this->tagMapper->tagMessage($tag, $message->getMessageId(), $account->getUserId()); + $this->tagMapper->tagMessage($tag, $message->getMessageId(), $account->getUserId(), $type); } } else { foreach ($messages as $message) { @@ -540,7 +541,7 @@ public function tagMessagesWithClient(Horde_Imap_Client_Socket $client, Account * @link https://github.com/nextcloud/mail/issues/25 */ #[\Override] - public function tagMessage(Account $account, string $mailbox, Message $message, Tag $tag, bool $value): void { + public function tagMessage(Account $account, string $mailbox, Message $message, Tag $tag, bool $value, string $type = MessageTags::TYPE_USER): void { try { $mb = $this->mailboxMapper->find($account, $mailbox); } catch (DoesNotExistException $e) { @@ -548,7 +549,7 @@ public function tagMessage(Account $account, string $mailbox, Message $message, } $client = $this->imapClientFactory->getClient($account); try { - $this->tagMessagesWithClient($client, $account, $mb, [$message], $tag, $value); + $this->tagMessagesWithClient($client, $account, $mb, [$message], $tag, $value, $type); } finally { $client->logout(); } diff --git a/tests/Unit/Job/FollowUpClassifierJobTest.php b/tests/Unit/Job/FollowUpClassifierJobTest.php index d16f6a7331..f3f85bfc9e 100644 --- a/tests/Unit/Job/FollowUpClassifierJobTest.php +++ b/tests/Unit/Job/FollowUpClassifierJobTest.php @@ -16,6 +16,7 @@ use OCA\Mail\Db\MailAccount; use OCA\Mail\Db\Mailbox; use OCA\Mail\Db\Message; +use OCA\Mail\Db\MessageTags; use OCA\Mail\Db\Tag; use OCA\Mail\Db\ThreadMapper; use OCA\Mail\Service\AccountService; @@ -115,7 +116,7 @@ public function testRun(): void { ->willReturn($tag); $this->mailManager->expects(self::once()) ->method('tagMessage') - ->with($account, 'sent', $message, $tag, true); + ->with($account, 'sent', $message, $tag, true, MessageTags::TYPE_CLASSIFIER); $this->job->run($argument); } @@ -248,7 +249,7 @@ public function testRunMultipleMessages(): void { ->willReturn($tag); $this->mailManager->expects(self::once()) ->method('tagMessage') - ->with($account, 'sent', $message, $tag, true); + ->with($account, 'sent', $message, $tag, true, MessageTags::TYPE_CLASSIFIER); $this->job->run($argument); } @@ -302,7 +303,7 @@ public function testRunCreateTag(): void { ->willReturn($tag); $this->mailManager->expects(self::once()) ->method('tagMessage') - ->with($account, 'sent', $message, $tag, true); + ->with($account, 'sent', $message, $tag, true, MessageTags::TYPE_CLASSIFIER); $this->job->run($argument); } diff --git a/tests/Unit/Service/MailManagerTest.php b/tests/Unit/Service/MailManagerTest.php index 5bd24acc40..1048d40250 100644 --- a/tests/Unit/Service/MailManagerTest.php +++ b/tests/Unit/Service/MailManagerTest.php @@ -18,6 +18,7 @@ use OCA\Mail\Db\MailboxMapper; use OCA\Mail\Db\Message; use OCA\Mail\Db\MessageMapper as DbMessageMapper; +use OCA\Mail\Db\MessageTags; use OCA\Mail\Db\MessageTagsMapper; use OCA\Mail\Db\Tag; use OCA\Mail\Db\TagMapper; @@ -450,6 +451,9 @@ public function testTagMessage(): void { $account->expects($this->once()) ->method('getUserId') ->willReturn('test'); + $this->tagMapper->expects($this->once()) + ->method('tagMessage') + ->with($tag, $message->getMessageId(), 'test', MessageTags::TYPE_USER); $this->manager->tagMessage($account, 'INBOX', $message, $tag, true); }