|
2 | 2 |
|
3 | 3 | namespace Pgvector\Doctrine; |
4 | 4 |
|
| 5 | +use Doctrine\DBAL\Platforms\AbstractPlatform; |
5 | 6 | use Doctrine\DBAL\Types\Type; |
6 | 7 | use Doctrine\ORM\Configuration; |
7 | 8 | use Doctrine\ORM\EntityManager; |
8 | 9 |
|
9 | 10 | abstract class PgvectorSetup |
10 | 11 | { |
11 | | - public static function registerTypes(?EntityManager $entityManager = null): void |
| 12 | + public static function register(EntityManager $entityManager): void |
| 13 | + { |
| 14 | + self::registerTypes(); |
| 15 | + self::registerPlatformTypes($entityManager->getConnection()->getDatabasePlatform()); |
| 16 | + self::registerFunctions($entityManager->getConfiguration()); |
| 17 | + } |
| 18 | + |
| 19 | + private static function registerTypes(): void |
12 | 20 | { |
13 | 21 | Type::addType('vector', 'Pgvector\Doctrine\VectorType'); |
14 | 22 | Type::addType('halfvec', 'Pgvector\Doctrine\HalfVectorType'); |
15 | 23 | Type::addType('bit', 'Pgvector\Doctrine\BitType'); |
16 | 24 | Type::addType('sparsevec', 'Pgvector\Doctrine\SparseVectorType'); |
| 25 | + } |
17 | 26 |
|
18 | | - if (!is_null($entityManager)) { |
19 | | - $platform = $entityManager->getConnection()->getDatabasePlatform(); |
20 | | - $platform->registerDoctrineTypeMapping('vector', 'vector'); |
21 | | - $platform->registerDoctrineTypeMapping('halfvec', 'halfvec'); |
22 | | - $platform->registerDoctrineTypeMapping('bit', 'bit'); |
23 | | - $platform->registerDoctrineTypeMapping('sparsevec', 'sparsevec'); |
24 | | - } |
| 27 | + private static function registerPlatformTypes(AbstractPlatform $platform): void |
| 28 | + { |
| 29 | + $platform->registerDoctrineTypeMapping('vector', 'vector'); |
| 30 | + $platform->registerDoctrineTypeMapping('halfvec', 'halfvec'); |
| 31 | + $platform->registerDoctrineTypeMapping('bit', 'bit'); |
| 32 | + $platform->registerDoctrineTypeMapping('sparsevec', 'sparsevec'); |
25 | 33 | } |
26 | 34 |
|
27 | | - public static function registerFunctions(Configuration $config): void |
| 35 | + private static function registerFunctions(Configuration $config): void |
28 | 36 | { |
29 | 37 | $config->addCustomNumericFunction('l2_distance', 'Pgvector\Doctrine\L2Distance'); |
30 | 38 | $config->addCustomNumericFunction('max_inner_product', 'Pgvector\Doctrine\MaxInnerProduct'); |
|
0 commit comments