diff --git a/src/Provider/Doctrine/DoctrineProvider.php b/src/Provider/Doctrine/DoctrineProvider.php index 0fdda41..4c94abc 100644 --- a/src/Provider/Doctrine/DoctrineProvider.php +++ b/src/Provider/Doctrine/DoctrineProvider.php @@ -16,10 +16,12 @@ use DH\Auditor\Provider\Doctrine\Persistence\Event\CreateSchemaListener; use DH\Auditor\Provider\Doctrine\Persistence\Event\TableSchemaListener; use DH\Auditor\Provider\Doctrine\Persistence\Helper\DoctrineHelper; +use DH\Auditor\Provider\Doctrine\Persistence\Reader\Query; use DH\Auditor\Provider\Doctrine\Service\AuditingService; -use DH\Auditor\Provider\Doctrine\Service\StorageService; +use DH\Auditor\Provider\Doctrine\Service\DoctrineService; use DH\Auditor\Provider\ProviderInterface; use DH\Auditor\Provider\Service\AuditingServiceInterface; +use DH\Auditor\Provider\Service\StorageServiceInterface; use DH\Auditor\Tests\Provider\Doctrine\DoctrineProviderTest; use Doctrine\ORM\EntityManagerInterface; use Doctrine\ORM\Events; @@ -31,23 +33,6 @@ */ final class DoctrineProvider extends AbstractProvider { - /** - * @var array - */ - private const FIELDS = [ - 'type' => '?', - 'object_id' => '?', - 'discriminator' => '?', - 'transaction_hash' => '?', - 'diffs' => '?', - 'blame_id' => '?', - 'blame_user' => '?', - 'blame_user_fqdn' => '?', - 'blame_user_firewall' => '?', - 'ip' => '?', - 'created_at' => '?', - ]; - private readonly TransactionManager $transactionManager; public function __construct(ConfigurationInterface $configuration) @@ -102,7 +87,7 @@ public function getAuditingServiceForEntity(string $entity): AuditingService throw new InvalidArgumentException(\sprintf('Auditing service not found for "%s".', $entity)); } - public function getStorageServiceForEntity(string $entity): StorageService + public function getStorageServiceForEntity(string $entity): DoctrineService { $this->checkStorageMapper(); @@ -111,10 +96,13 @@ public function getStorageServiceForEntity(string $entity): StorageService if (null === $storageMapper || 1 === \count($this->getStorageServices())) { // No mapper and only 1 storage entity manager - /** @var array $services */ + /** @var array $services */ $services = $this->getStorageServices(); + $service = array_values($services)[0]; - return array_values($services)[0]; + \assert($service instanceof DoctrineService); // helps PHPStan + + return $service; } if (\is_string($storageMapper) && class_exists($storageMapper)) { @@ -126,34 +114,22 @@ public function getStorageServiceForEntity(string $entity): StorageService return $storageMapper($entity, $this->getStorageServices()); } + public function getEntityAuditTableName(string $entity): string + { + \assert($this->configuration instanceof Configuration); // helps PHPStan + + return $this->getStorageServiceForEntity($entity)->getEntityAuditTableName($this->configuration, $entity); + } + public function persist(LifecycleEvent $event): void { $payload = $event->getPayload(); - $auditTable = $payload['table']; $entity = $payload['entity']; - unset($payload['table'], $payload['entity']); - - $keys = array_keys(self::FIELDS); - $query = \sprintf( - 'INSERT INTO %s (%s) VALUES (%s)', - $auditTable, - implode(', ', $keys), - implode(', ', array_values(self::FIELDS)) - ); - - /** @var StorageService $storageService */ - $storageService = $this->getStorageServiceForEntity($entity); - $statement = $storageService->getEntityManager()->getConnection()->prepare($query); - - foreach ($payload as $key => $value) { - $statement->bindValue(array_search($key, $keys, true) + 1, $value); - } - - $statement->executeStatement(); + $id = $this->getStorageServiceForEntity($entity)->persist($event); // let's get the last inserted ID from the database so other providers can use that info $payload = $event->getPayload(); - $payload['id'] = (int) $storageService->getEntityManager()->getConnection()->lastInsertId(); + $payload['id'] = $id; $event->setPayload($payload); } @@ -267,6 +243,17 @@ public function loadAnnotations(EntityManagerInterface $entityManager, array $en return $this; } + public function createBaseQuery(string $entity): Query + { + \assert($this->configuration instanceof Configuration); // helps PHPStan + + return $this->getStorageServiceForEntity($entity)->createBaseQuery( + $this->configuration, + $entity, + $this->getAuditor()->getConfiguration()->getTimezone() + ); + } + private function checkStorageMapper(): self { \assert($this->configuration instanceof Configuration); // helps PHPStan diff --git a/src/Provider/Doctrine/Persistence/Reader/Query.php b/src/Provider/Doctrine/Persistence/Reader/Query.php index 8fa99c3..458aeea 100644 --- a/src/Provider/Doctrine/Persistence/Reader/Query.php +++ b/src/Provider/Doctrine/Persistence/Reader/Query.php @@ -67,10 +67,11 @@ final class Query private readonly \DateTimeZone $timezone; - public function __construct(private readonly string $table, private readonly Connection $connection, string $timezone) + public function __construct(private readonly string $table, private readonly Connection $connection, string $timezone, private ?array $supportedFilters = null) { $this->timezone = new \DateTimeZone($timezone); + $this->supportedFilters = ($supportedFilters ?? array_keys(SchemaHelper::getAuditTableIndices('fake'))); foreach ($this->getSupportedFilters() as $filterType) { $this->filters[$filterType] = []; } @@ -162,7 +163,7 @@ public function limit(int $limit, int $offset = 0): self public function getSupportedFilters(): array { - return array_keys(SchemaHelper::getAuditTableIndices('fake')); + return $this->supportedFilters ?? []; } public function getFilters(): array diff --git a/src/Provider/Doctrine/Persistence/Reader/Reader.php b/src/Provider/Doctrine/Persistence/Reader/Reader.php index cf9a7fa..54669a8 100644 --- a/src/Provider/Doctrine/Persistence/Reader/Reader.php +++ b/src/Provider/Doctrine/Persistence/Reader/Reader.php @@ -45,10 +45,7 @@ public function createQuery(string $entity, array $options = []): Query $this->configureOptions($resolver); $config = $resolver->resolve($options); - $connection = $this->provider->getStorageServiceForEntity($entity)->getEntityManager()->getConnection(); - $timezone = $this->provider->getAuditor()->getConfiguration()->getTimezone(); - - $query = new Query($this->getEntityAuditTableName($entity), $connection, $timezone); + $query = $this->provider->createBaseQuery($entity); $query ->addOrderBy(Query::CREATED_AT, 'DESC') ->addOrderBy(Query::ID, 'DESC') @@ -159,24 +156,7 @@ public function getEntityTableName(string $entity): string */ public function getEntityAuditTableName(string $entity): string { - /** @var Configuration $configuration */ - $configuration = $this->provider->getConfiguration(); - - /** @var AuditingService $auditingService */ - $auditingService = $this->provider->getAuditingServiceForEntity($entity); - $entityManager = $auditingService->getEntityManager(); - $schema = ''; - if ($entityManager->getClassMetadata($entity)->getSchemaName()) { - $schema = $entityManager->getClassMetadata($entity)->getSchemaName().'.'; - } - - return \sprintf( - '%s%s%s%s', - $schema, - $configuration->getTablePrefix(), - $this->getEntityTableName($entity), - $configuration->getTableSuffix() - ); + return $this->provider->getEntityAuditTableName($entity); } private function configureOptions(OptionsResolver $resolver): void diff --git a/src/Provider/Doctrine/Service/DoctrineService.php b/src/Provider/Doctrine/Service/DoctrineService.php index 66cc58b..b9add4f 100644 --- a/src/Provider/Doctrine/Service/DoctrineService.php +++ b/src/Provider/Doctrine/Service/DoctrineService.php @@ -4,11 +4,31 @@ namespace DH\Auditor\Provider\Doctrine\Service; +use DH\Auditor\Event\LifecycleEvent; +use DH\Auditor\Provider\Doctrine\Configuration; +use DH\Auditor\Provider\Doctrine\Persistence\Reader\Query; use DH\Auditor\Provider\Service\AbstractService; use Doctrine\ORM\EntityManagerInterface; abstract class DoctrineService extends AbstractService { + /** + * @var array + */ + private const FIELDS = [ + 'type' => '?', + 'object_id' => '?', + 'discriminator' => '?', + 'transaction_hash' => '?', + 'diffs' => '?', + 'blame_id' => '?', + 'blame_user' => '?', + 'blame_user_fqdn' => '?', + 'blame_user_firewall' => '?', + 'ip' => '?', + 'created_at' => '?', + ]; + public function __construct(string $name, private readonly EntityManagerInterface $entityManager) { parent::__construct($name); @@ -18,4 +38,55 @@ public function getEntityManager(): EntityManagerInterface { return $this->entityManager; } + + public function createBaseQuery(Configuration $configuration, string $entity, string $timezone): Query + { + $connection = $this->getEntityManager()->getConnection(); + + return new Query($this->getEntityAuditTableName($configuration, $entity), $connection, $timezone, ['id', ...array_keys(self::FIELDS)]); + } + + public function persist(LifecycleEvent $event): int + { + $payload = $event->getPayload(); + $auditTable = $payload['table']; + unset($payload['table'], $payload['entity']); + + $keys = array_keys(self::FIELDS); + $query = \sprintf( + 'INSERT INTO %s (%s) VALUES (%s)', + $auditTable, + implode(', ', $keys), + implode(', ', array_values(self::FIELDS)) + ); + + $statement = $this->getEntityManager()->getConnection()->prepare($query); + + foreach ($payload as $key => $value) { + $statement->bindValue(array_search($key, $keys, true) + 1, $value); + } + + $statement->executeStatement(); + + return (int) $this->getEntityManager()->getConnection()->lastInsertId(); + } + + /** + * Returns the audit table name for $entity. + */ + public function getEntityAuditTableName(Configuration $configuration, string $entity): string + { + $schema = ''; + if ($this->entityManager->getClassMetadata($entity)->getSchemaName()) { + $schema = $this->entityManager->getClassMetadata($entity)->getSchemaName().'.'; + } + + return \sprintf( + '%s%s%s%s', + $schema, + $configuration->getTablePrefix(), + $this->entityManager->getClassMetadata($entity)->getTableName(), + $configuration->getTableSuffix() + ); + } } diff --git a/src/Provider/Doctrine/Service/SingleTableDoctrineService.php b/src/Provider/Doctrine/Service/SingleTableDoctrineService.php new file mode 100644 index 0000000..6cbcede --- /dev/null +++ b/src/Provider/Doctrine/Service/SingleTableDoctrineService.php @@ -0,0 +1,70 @@ + + */ + private const FIELDS = [ + 'type' => '?', + 'object_fqdn' => '?', + 'object_id' => '?', + 'discriminator' => '?', + 'transaction_hash' => '?', + 'diffs' => '?', + 'blame_id' => '?', + 'blame_user' => '?', + 'blame_user_fqdn' => '?', + 'blame_user_firewall' => '?', + 'ip' => '?', + 'created_at' => '?', + ]; + + public function __construct(DoctrineService $doctrineService, private string $auditTableName = 'audit') + { + parent::__construct($doctrineService->getName(), $doctrineService->getEntityManager()); + } + + public function createBaseQuery(Configuration $configuration, string $entity, string $timezone): Query + { + $connection = $this->getEntityManager()->getConnection(); + $query = new Query($this->auditTableName, $connection, $timezone, ['id', ...array_keys(self::FIELDS)]); + $query->addFilter(new SimpleFilter('object_fqdn', $entity)); + + return $query; + } + + public function persist(LifecycleEvent $event): int + { + $payload = $event->getPayload(); + $entity = $payload['entity']; + $payload['object_fqdn'] = $entity; + unset($payload['table'], $payload['entity']); + + $keys = array_keys(self::FIELDS); + $query = \sprintf( + 'INSERT INTO %s (%s) VALUES (%s)', + $this->auditTableName, + implode(', ', $keys), + implode(', ', array_values(self::FIELDS)) + ); + + $statement = $this->getEntityManager()->getConnection()->prepare($query); + + foreach ($payload as $key => $value) { + $statement->bindValue(array_search($key, $keys, true) + 1, $value); + } + + $statement->executeStatement(); + + return (int) $this->getEntityManager()->getConnection()->lastInsertId(); + } +}