Skip to content

Commit c4c8fe3

Browse files
committed
Improved Doctrine setup
1 parent 8be5962 commit c4c8fe3

File tree

3 files changed

+20
-19
lines changed

3 files changed

+20
-19
lines changed

README.md

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -117,18 +117,12 @@ Install the package
117117
composer require pgvector/pgvector
118118
```
119119

120-
Register the distance functions
120+
Register the types and distance functions
121121

122122
```php
123123
use Pgvector\Doctrine\PgvectorSetup;
124124

125-
PgvectorSetup::registerFunctions($config);
126-
```
127-
128-
And the types
129-
130-
```php
131-
PgvectorSetup::registerTypes($entityManager);
125+
PgvectorSetup::register($entityManager);
132126
```
133127

134128
Enable the extension

src/doctrine/PgvectorSetup.php

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,37 @@
22

33
namespace Pgvector\Doctrine;
44

5+
use Doctrine\DBAL\Platforms\AbstractPlatform;
56
use Doctrine\DBAL\Types\Type;
67
use Doctrine\ORM\Configuration;
78
use Doctrine\ORM\EntityManager;
89

910
abstract class PgvectorSetup
1011
{
11-
public static function registerTypes(?EntityManager $entityManager = null): void
12+
public static function register(EntityManager $entityManager): void
13+
{
14+
self::registerTypes();
15+
self::registerPlatformTypes($entityManager->getConnection()->getDatabasePlatform());
16+
self::registerFunctions($entityManager->getConfiguration());
17+
}
18+
19+
private static function registerTypes(): void
1220
{
1321
Type::addType('vector', 'Pgvector\Doctrine\VectorType');
1422
Type::addType('halfvec', 'Pgvector\Doctrine\HalfVectorType');
1523
Type::addType('bit', 'Pgvector\Doctrine\BitType');
1624
Type::addType('sparsevec', 'Pgvector\Doctrine\SparseVectorType');
25+
}
1726

18-
if (!is_null($entityManager)) {
19-
$platform = $entityManager->getConnection()->getDatabasePlatform();
20-
$platform->registerDoctrineTypeMapping('vector', 'vector');
21-
$platform->registerDoctrineTypeMapping('halfvec', 'halfvec');
22-
$platform->registerDoctrineTypeMapping('bit', 'bit');
23-
$platform->registerDoctrineTypeMapping('sparsevec', 'sparsevec');
24-
}
27+
private static function registerPlatformTypes(AbstractPlatform $platform): void
28+
{
29+
$platform->registerDoctrineTypeMapping('vector', 'vector');
30+
$platform->registerDoctrineTypeMapping('halfvec', 'halfvec');
31+
$platform->registerDoctrineTypeMapping('bit', 'bit');
32+
$platform->registerDoctrineTypeMapping('sparsevec', 'sparsevec');
2533
}
2634

27-
public static function registerFunctions(Configuration $config): void
35+
private static function registerFunctions(Configuration $config): void
2836
{
2937
$config->addCustomNumericFunction('l2_distance', 'Pgvector\Doctrine\L2Distance');
3038
$config->addCustomNumericFunction('max_inner_product', 'Pgvector\Doctrine\MaxInnerProduct');

tests/DoctrineTest.php

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ public static function setUpBeforeClass(): void
2626
paths: [__DIR__ . '/models'],
2727
isDevMode: true
2828
);
29-
PgvectorSetup::registerFunctions($config);
3029

3130
$connection = DriverManager::getConnection([
3231
'driver' => 'pgsql',
@@ -35,7 +34,7 @@ public static function setUpBeforeClass(): void
3534

3635
$entityManager = new EntityManager($connection, $config);
3736
$entityManager->getConnection()->executeStatement('CREATE EXTENSION IF NOT EXISTS vector');
38-
PgvectorSetup::registerTypes($entityManager);
37+
PgvectorSetup::register($entityManager);
3938

4039
$schemaManager = $entityManager->getConnection()->createSchemaManager();
4140
try {

0 commit comments

Comments
 (0)