Skip to content

Commit

Permalink
EntityPreloader: add support for preloading many to many associations (
Browse files Browse the repository at this point in the history
…#4)

* EntityPreloader: add support for preloading many to many associations

* EntityPreloader: refactor preloadOneToMany

* EntityPreloader: refactor preloading has many associations

* add EntityPreloadBlogManyHasManyInversedTest
  • Loading branch information
JanTvrdik authored Oct 8, 2024
1 parent e78a8c6 commit f93c7c1
Show file tree
Hide file tree
Showing 5 changed files with 362 additions and 38 deletions.
3 changes: 1 addition & 2 deletions composer.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
],
"require": {
"php": "^8.1",
"doctrine/orm": "^3",
"doctrine/persistence": "^3.1"
"doctrine/orm": "^3"
},
"require-dev": {
"doctrine/collections": "^2.2",
Expand Down
184 changes: 155 additions & 29 deletions src/EntityPreloader.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
use Doctrine\ORM\Mapping\ClassMetadata;
use Doctrine\ORM\PersistentCollection;
use Doctrine\ORM\QueryBuilder;
use Doctrine\Persistence\Proxy;
use LogicException;
use ReflectionProperty;
use function array_chunk;
use function array_keys;
use function array_values;
use function count;
use function get_parent_class;
Expand All @@ -21,7 +20,7 @@
class EntityPreloader
{

private const BATCH_SIZE = 1_000;
private const PRELOAD_ENTITY_DEFAULT_BATCH_SIZE = 1_000;
private const PRELOAD_COLLECTION_DEFAULT_BATCH_SIZE = 100;

public function __construct(
Expand Down Expand Up @@ -65,14 +64,15 @@ public function preload(
}

$maxFetchJoinSameFieldCount ??= 1;
$sourceEntities = $this->loadProxies($sourceClassMetadata, $sourceEntities, $batchSize ?? self::BATCH_SIZE, $maxFetchJoinSameFieldCount);
$sourceEntities = $this->loadProxies($sourceClassMetadata, $sourceEntities, $batchSize ?? self::PRELOAD_ENTITY_DEFAULT_BATCH_SIZE, $maxFetchJoinSameFieldCount);

return match ($associationMapping->type()) {
ClassMetadata::ONE_TO_MANY => $this->preloadOneToMany($sourceEntities, $sourceClassMetadata, $sourcePropertyName, $targetClassMetadata, $batchSize, $maxFetchJoinSameFieldCount),
ClassMetadata::ONE_TO_ONE,
ClassMetadata::MANY_TO_ONE => $this->preloadToOne($sourceEntities, $sourceClassMetadata, $sourcePropertyName, $targetClassMetadata, $batchSize, $maxFetchJoinSameFieldCount),
$preloader = match (true) {
$associationMapping->isToOne() => $this->preloadToOne(...),
$associationMapping->isToMany() => $this->preloadToMany(...),
default => throw new LogicException("Unsupported association mapping type {$associationMapping->type()}"),
};

return $preloader($sourceEntities, $sourceClassMetadata, $sourcePropertyName, $targetClassMetadata, $batchSize, $maxFetchJoinSameFieldCount);
}

/**
Expand Down Expand Up @@ -135,7 +135,7 @@ private function loadProxies(
$entityKey = (string) $entityId;
$uniqueEntities[$entityKey] = $entity;

if ($entity instanceof Proxy && !$entity->__isInitialized()) {
if ($this->entityManager->isUninitializedObject($entity)) {
$uninitializedIds[$entityKey] = $entityId;
}
}
Expand All @@ -157,7 +157,7 @@ private function loadProxies(
* @template S of E
* @template T of E
*/
private function preloadOneToMany(
private function preloadToMany(
array $sourceEntities,
ClassMetadata $sourceClassMetadata,
string $sourcePropertyName,
Expand All @@ -168,50 +168,170 @@ private function preloadOneToMany(
{
$sourceIdentifierReflection = $sourceClassMetadata->getSingleIdReflectionProperty(); // e.g. Order::$id reflection
$sourcePropertyReflection = $sourceClassMetadata->getReflectionProperty($sourcePropertyName); // e.g. Order::$items reflection
$targetPropertyName = $sourceClassMetadata->getAssociationMappedByTargetField($sourcePropertyName); // e.g. 'order'
$targetPropertyReflection = $targetClassMetadata->getReflectionProperty($targetPropertyName); // e.g. Item::$order reflection
$targetIdentifierReflection = $targetClassMetadata->getSingleIdReflectionProperty();

if ($sourceIdentifierReflection === null || $sourcePropertyReflection === null || $targetPropertyReflection === null) {
if ($sourceIdentifierReflection === null || $sourcePropertyReflection === null || $targetIdentifierReflection === null) {
throw new LogicException('Doctrine should use RuntimeReflectionService which never returns null.');
}

$batchSize ??= self::PRELOAD_COLLECTION_DEFAULT_BATCH_SIZE;

$targetEntities = [];
$uninitializedSourceEntityIds = [];
$uninitializedCollections = [];

foreach ($sourceEntities as $sourceEntity) {
$sourceEntityId = (string) $sourceIdentifierReflection->getValue($sourceEntity);
$sourceEntityId = $sourceIdentifierReflection->getValue($sourceEntity);
$sourceEntityKey = (string) $sourceEntityId;
$sourceEntityCollection = $sourcePropertyReflection->getValue($sourceEntity);

if (
$sourceEntityCollection instanceof PersistentCollection
&& !$sourceEntityCollection->isInitialized()
&& !$sourceEntityCollection->isDirty() // preloading dirty collection is too hard to handle
) {
$uninitializedCollections[$sourceEntityId] = $sourceEntityCollection;
$uninitializedSourceEntityIds[$sourceEntityKey] = $sourceEntityId;
$uninitializedCollections[$sourceEntityKey] = $sourceEntityCollection;
continue;
}

foreach ($sourceEntityCollection as $targetEntity) {
$targetEntities[] = $targetEntity;
$targetEntityKey = (string) $targetIdentifierReflection->getValue($targetEntity);
$targetEntities[$targetEntityKey] = $targetEntity;
}
}

foreach (array_chunk($uninitializedCollections, $batchSize, true) as $chunk) {
$targetEntitiesChunk = $this->loadEntitiesBy($targetClassMetadata, $targetPropertyName, array_keys($chunk), $maxFetchJoinSameFieldCount);
$innerLoader = match ($sourceClassMetadata->getAssociationMapping($sourcePropertyName)->type()) {
ClassMetadata::ONE_TO_MANY => $this->preloadOneToManyInner(...),
ClassMetadata::MANY_TO_MANY => $this->preloadManyToManyInner(...),
default => throw new LogicException('Unsupported association mapping type'),
};

foreach ($targetEntitiesChunk as $targetEntity) {
$sourceEntity = $targetPropertyReflection->getValue($targetEntity);
$sourceEntityId = (string) $sourceIdentifierReflection->getValue($sourceEntity);
$uninitializedCollections[$sourceEntityId]->add($targetEntity);
$targetEntities[] = $targetEntity;
foreach (array_chunk($uninitializedSourceEntityIds, $batchSize, preserve_keys: true) as $uninitializedSourceEntityIdsChunk) {
$targetEntitiesChunk = $innerLoader(
sourceClassMetadata: $sourceClassMetadata,
sourceIdentifierReflection: $sourceIdentifierReflection,
sourcePropertyName: $sourcePropertyName,
targetClassMetadata: $targetClassMetadata,
targetIdentifierReflection: $targetIdentifierReflection,
uninitializedSourceEntityIdsChunk: array_values($uninitializedSourceEntityIdsChunk),
uninitializedCollections: $uninitializedCollections,
maxFetchJoinSameFieldCount: $maxFetchJoinSameFieldCount,
);

foreach ($targetEntitiesChunk as $targetEntityKey => $targetEntity) {
$targetEntities[$targetEntityKey] = $targetEntity;
}
}

foreach ($uninitializedCollections as $sourceEntityCollection) {
$sourceEntityCollection->setInitialized(true);
$sourceEntityCollection->takeSnapshot();
}

return array_values($targetEntities);
}

/**
* @param ClassMetadata<S> $sourceClassMetadata
* @param ClassMetadata<T> $targetClassMetadata
* @param list<mixed> $uninitializedSourceEntityIdsChunk
* @param array<string, PersistentCollection<int, T>> $uninitializedCollections
* @param non-negative-int $maxFetchJoinSameFieldCount
* @return array<string, T>
* @template S of E
* @template T of E
*/
private function preloadOneToManyInner(
ClassMetadata $sourceClassMetadata,
ReflectionProperty $sourceIdentifierReflection,
string $sourcePropertyName,
ClassMetadata $targetClassMetadata,
ReflectionProperty $targetIdentifierReflection,
array $uninitializedSourceEntityIdsChunk,
array $uninitializedCollections,
int $maxFetchJoinSameFieldCount,
): array
{
$targetPropertyName = $sourceClassMetadata->getAssociationMappedByTargetField($sourcePropertyName); // e.g. 'order'
$targetPropertyReflection = $targetClassMetadata->getReflectionProperty($targetPropertyName); // e.g. Item::$order reflection
$targetEntities = [];

if ($targetPropertyReflection === null) {
throw new LogicException('Doctrine should use RuntimeReflectionService which never returns null.');
}

foreach ($this->loadEntitiesBy($targetClassMetadata, $targetPropertyName, $uninitializedSourceEntityIdsChunk, $maxFetchJoinSameFieldCount) as $targetEntity) {
$sourceEntity = $targetPropertyReflection->getValue($targetEntity);
$sourceEntityKey = (string) $sourceIdentifierReflection->getValue($sourceEntity);
$uninitializedCollections[$sourceEntityKey]->add($targetEntity);

$targetEntityKey = (string) $targetIdentifierReflection->getValue($targetEntity);
$targetEntities[$targetEntityKey] = $targetEntity;
}

foreach ($chunk as $sourceEntityCollection) {
$sourceEntityCollection->setInitialized(true);
$sourceEntityCollection->takeSnapshot();
return $targetEntities;
}

/**
* @param ClassMetadata<S> $sourceClassMetadata
* @param ClassMetadata<T> $targetClassMetadata
* @param list<mixed> $uninitializedSourceEntityIdsChunk
* @param array<string, PersistentCollection<int, T>> $uninitializedCollections
* @param non-negative-int $maxFetchJoinSameFieldCount
* @return array<string, T>
* @template S of E
* @template T of E
*/
private function preloadManyToManyInner(
ClassMetadata $sourceClassMetadata,
ReflectionProperty $sourceIdentifierReflection,
string $sourcePropertyName,
ClassMetadata $targetClassMetadata,
ReflectionProperty $targetIdentifierReflection,
array $uninitializedSourceEntityIdsChunk,
array $uninitializedCollections,
int $maxFetchJoinSameFieldCount,
): array
{
$sourceIdentifierName = $sourceClassMetadata->getSingleIdentifierFieldName();
$targetIdentifierName = $targetClassMetadata->getSingleIdentifierFieldName();

$manyToManyRows = $this->entityManager->createQueryBuilder()
->select("source.{$sourceIdentifierName} AS sourceId", "target.{$targetIdentifierName} AS targetId")
->from($sourceClassMetadata->getName(), 'source')
->join("source.{$sourcePropertyName}", 'target')
->andWhere('source IN (:sourceEntityIds)')
->setParameter('sourceEntityIds', $uninitializedSourceEntityIdsChunk)
->getQuery()
->getResult();

$targetEntities = [];
$uninitializedTargetEntityIds = [];

foreach ($manyToManyRows as $manyToManyRow) {
$targetEntityId = $manyToManyRow['targetId'];
$targetEntityKey = (string) $targetEntityId;

/** @var T|false $targetEntity */
$targetEntity = $this->entityManager->getUnitOfWork()->tryGetById($targetEntityId, $targetClassMetadata->getName());

if ($targetEntity !== false && !$this->entityManager->isUninitializedObject($targetEntity)) {
$targetEntities[$targetEntityKey] = $targetEntity;
continue;
}

$uninitializedTargetEntityIds[$targetEntityKey] = $targetEntityId;
}

foreach ($this->loadEntitiesBy($targetClassMetadata, $targetIdentifierName, array_values($uninitializedTargetEntityIds), $maxFetchJoinSameFieldCount) as $targetEntity) {
$targetEntityKey = (string) $targetIdentifierReflection->getValue($targetEntity);
$targetEntities[$targetEntityKey] = $targetEntity;
}

foreach ($manyToManyRows as $manyToManyRow) {
$sourceEntityKey = (string) $manyToManyRow['sourceId'];
$targetEntityKey = (string) $manyToManyRow['targetId'];
$uninitializedCollections[$sourceEntityKey]->add($targetEntities[$targetEntityKey]);
}

return $targetEntities;
Expand All @@ -237,12 +357,14 @@ private function preloadToOne(
): array
{
$sourcePropertyReflection = $sourceClassMetadata->getReflectionProperty($sourcePropertyName); // e.g. Item::$order reflection
$targetEntities = [];

if ($sourcePropertyReflection === null) {
throw new LogicException('Doctrine should use RuntimeReflectionService which never returns null.');
}

$batchSize ??= self::PRELOAD_ENTITY_DEFAULT_BATCH_SIZE;
$targetEntities = [];

foreach ($sourceEntities as $sourceEntity) {
$targetEntity = $sourcePropertyReflection->getValue($sourceEntity);

Expand All @@ -253,7 +375,7 @@ private function preloadToOne(
$targetEntities[] = $targetEntity;
}

return $this->loadProxies($targetClassMetadata, $targetEntities, $batchSize ?? self::BATCH_SIZE, $maxFetchJoinSameFieldCount);
return $this->loadProxies($targetClassMetadata, $targetEntities, $batchSize, $maxFetchJoinSameFieldCount);
}

/**
Expand All @@ -270,6 +392,10 @@ private function loadEntitiesBy(
int $maxFetchJoinSameFieldCount,
): array
{
if (count($fieldValues) === 0) {
return [];
}

$rootLevelAlias = 'e';

$queryBuilder = $this->entityManager->createQueryBuilder()
Expand Down
93 changes: 93 additions & 0 deletions tests/EntityPreloadBlogManyHasManyInversedTest.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
<?php declare(strict_types = 1);

namespace ShipMonkTests\DoctrineEntityPreloader;

use Doctrine\ORM\Mapping\ClassMetadata;
use ShipMonkTests\DoctrineEntityPreloader\Fixtures\Blog\Tag;
use ShipMonkTests\DoctrineEntityPreloader\Lib\TestCase;

/**
 * Compares several strategies for reading the articles of every Tag through the
 * many-to-many `Tag::articles` association (join table `article_tag`) and pins
 * the SQL queries each strategy is expected to produce.
 *
 * NOTE(review): the "Inversed" in the class name presumably means Tag is the
 * inverse side of the association - confirm against the Tag fixture mapping.
 */
class EntityPreloadBlogManyHasManyInversedTest extends TestCase
{

    /**
     * Baseline without any optimization: one query for the tags plus
     * 25 separate article queries filtered by a single tag id.
     */
    public function testManyHasManyInversedUnoptimized(): void
    {
        $this->createDummyBlogData(articleInEachCategoryCount: 5, tagForEachArticleCount: 5);

        $tagRepository = $this->getEntityManager()->getRepository(Tag::class);
        $allTags = $tagRepository->findAll();

        $this->readArticleTitles($allTags);

        $expectedQueries = [
            ['count' => 1, 'query' => 'SELECT * FROM tag t0'],
            ['count' => 25, 'query' => 'SELECT * FROM article t0 INNER JOIN article_tag ON t0.id = article_tag.article_id WHERE article_tag.tag_id = ?'],
        ];

        self::assertAggregatedQueries($expectedQueries);
    }

    /**
     * Fetch join: tags and their articles are selected together, so a single
     * query with two LEFT JOINs is expected.
     */
    public function testManyHasManyInversedWithFetchJoin(): void
    {
        $this->createDummyBlogData(articleInEachCategoryCount: 5, tagForEachArticleCount: 5);

        $fetchJoinQuery = $this->getEntityManager()->createQueryBuilder()
            ->select('tag', 'article')
            ->from(Tag::class, 'tag')
            ->leftJoin('tag.articles', 'article')
            ->getQuery();
        $tagsWithArticles = $fetchJoinQuery->getResult();

        $this->readArticleTitles($tagsWithArticles);

        $expectedQueries = [
            ['count' => 1, 'query' => 'SELECT * FROM tag t0_ LEFT JOIN article_tag a2_ ON t0_.id = a2_.tag_id LEFT JOIN article a1_ ON a1_.id = a2_.article_id'],
        ];

        self::assertAggregatedQueries($expectedQueries);
    }

    /**
     * Eager fetch mode set on the query: the expected query list still contains
     * 25 per-tag article queries, i.e. the same amount as the unoptimized case.
     */
    public function testManyHasManyInversedWithEagerFetchMode(): void
    {
        $this->createDummyBlogData(articleInEachCategoryCount: 5, tagForEachArticleCount: 5);

        // Doctrine issues one query per collection for eagerly loaded Many-To-Many associations, see
        // https://www.doctrine-project.org/projects/doctrine-orm/en/3.2/reference/working-with-objects.html#by-eager-loading
        $tagQuery = $this->getEntityManager()->createQueryBuilder()
            ->select('tag')
            ->from(Tag::class, 'tag')
            ->getQuery();
        $eagerTags = $tagQuery
            ->setFetchMode(Tag::class, 'articles', ClassMetadata::FETCH_EAGER)
            ->getResult();

        $this->readArticleTitles($eagerTags);

        $expectedQueries = [
            ['count' => 1, 'query' => 'SELECT * FROM tag t0_'],
            ['count' => 25, 'query' => 'SELECT * FROM article t0 INNER JOIN article_tag ON t0.id = article_tag.article_id WHERE article_tag.tag_id = ?'],
        ];

        self::assertAggregatedQueries($expectedQueries);
    }

    /**
     * Entity preloader: after `preload($tags, 'articles')` only three queries are
     * expected in total - the tags, the tag/article pairs for all 25 tag ids,
     * and the articles themselves loaded by id.
     */
    public function testManyHasManyInversedWithPreload(): void
    {
        $this->createDummyBlogData(articleInEachCategoryCount: 5, tagForEachArticleCount: 5);

        $tagRepository = $this->getEntityManager()->getRepository(Tag::class);
        $allTags = $tagRepository->findAll();

        $entityPreloader = $this->getEntityPreloader();
        $entityPreloader->preload($allTags, 'articles');

        $this->readArticleTitles($allTags);

        $expectedQueries = [
            ['count' => 1, 'query' => 'SELECT * FROM tag t0'],
            ['count' => 1, 'query' => 'SELECT * FROM tag t0_ INNER JOIN article_tag a2_ ON t0_.id = a2_.tag_id INNER JOIN article a1_ ON a1_.id = a2_.article_id WHERE t0_.id IN (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)'],
            ['count' => 1, 'query' => 'SELECT * FROM article a0_ WHERE a0_.id IN (?, ?, ?, ?, ?)'],
        ];

        self::assertAggregatedQueries($expectedQueries);
    }

    /**
     * Walks every article of every given tag and reads its title, so that any
     * lazily loaded data has to be fetched. The titles themselves are discarded.
     *
     * @param array<Tag> $tags
     */
    private function readArticleTitles(array $tags): void
    {
        foreach ($tags as $currentTag) {
            $articlesOfTag = $currentTag->getArticles();

            foreach ($articlesOfTag as $relatedArticle) {
                $relatedArticle->getTitle();
            }
        }
    }

}
Loading

0 comments on commit f93c7c1

Please sign in to comment.