Skip to content

Commit d539f4b

Browse files
committed
Added types for Doctrine
1 parent 71ce2e7 commit d539f4b

File tree

7 files changed

+258
-1
lines changed

7 files changed

+258
-1
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## 0.2.2 (unreleased)
2+
3+
- Added experimental support for Doctrine
4+
15
## 0.2.1 (2025-02-01)
26

37
- Added support for `SplFixedArray`

composer.json

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,20 @@
1515
"autoload": {
1616
"psr-4": {
1717
"Pgvector\\": "src/",
18+
"Pgvector\\Doctrine\\": "src/doctrine/",
1819
"Pgvector\\Laravel\\": "src/laravel/"
1920
}
2021
},
2122
"require": {
2223
"php": ">= 8.1"
2324
},
2425
"require-dev": {
26+
"doctrine/dbal": "^4",
27+
"doctrine/orm": "^3",
2528
"phpunit/phpunit": "^10",
2629
"illuminate/database": ">= 10",
27-
"laravel/serializable-closure": "^1.3"
30+
"laravel/serializable-closure": "^1.3",
31+
"symfony/cache": "^6"
2832
},
2933
"extra": {
3034
"laravel": {

src/doctrine/HalfVectorType.php

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
<?php
2+
3+
namespace Pgvector\Doctrine;
4+
5+
use Doctrine\DBAL\Types\Type;
6+
use Doctrine\DBAL\Platforms\AbstractPlatform;
7+
use Pgvector\HalfVector;
8+
9+
class HalfVectorType extends Type
10+
{
11+
public function getName(): string
12+
{
13+
return 'halfvec';
14+
}
15+
16+
public function getSQLDeclaration(array $fieldDeclaration, AbstractPlatform $platform): string
17+
{
18+
$length = $fieldDeclaration['length'];
19+
return is_null($length) ? 'halfvec' : sprintf('halfvec(%d)', $length);
20+
}
21+
22+
public function convertToPHPValue(mixed $value, AbstractPlatform $platform): ?HalfVector
23+
{
24+
if (is_null($value)) {
25+
return null;
26+
}
27+
28+
return new HalfVector($value);
29+
}
30+
31+
public function convertToDatabaseValue(mixed $value, AbstractPlatform $platform): ?string
32+
{
33+
if (is_null($value)) {
34+
return null;
35+
}
36+
37+
if (!($value instanceof HalfVector)) {
38+
$value = new HalfVector($value);
39+
}
40+
41+
return (string) $value;
42+
}
43+
}

src/doctrine/SparseVectorType.php

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
<?php
2+
3+
namespace Pgvector\Doctrine;
4+
5+
use Doctrine\DBAL\Types\Type;
6+
use Doctrine\DBAL\Platforms\AbstractPlatform;
7+
use Pgvector\SparseVector;
8+
9+
class SparseVectorType extends Type
10+
{
11+
public function getName(): string
12+
{
13+
return 'sparsevec';
14+
}
15+
16+
public function getSQLDeclaration(array $fieldDeclaration, AbstractPlatform $platform): string
17+
{
18+
$length = $fieldDeclaration['length'];
19+
return is_null($length) ? 'sparsevec' : sprintf('sparsevec(%d)', $length);
20+
}
21+
22+
public function convertToPHPValue(mixed $value, AbstractPlatform $platform): ?SparseVector
23+
{
24+
if (is_null($value)) {
25+
return null;
26+
}
27+
28+
return new SparseVector($value);
29+
}
30+
31+
public function convertToDatabaseValue(mixed $value, AbstractPlatform $platform): ?string
32+
{
33+
if (is_null($value)) {
34+
return null;
35+
}
36+
37+
if (!($value instanceof SparseVector)) {
38+
$value = new SparseVector($value);
39+
}
40+
41+
return (string) $value;
42+
}
43+
}

src/doctrine/VectorType.php

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
<?php
2+
3+
namespace Pgvector\Doctrine;
4+
5+
use Doctrine\DBAL\Types\Type;
6+
use Doctrine\DBAL\Platforms\AbstractPlatform;
7+
use Pgvector\Vector;
8+
9+
class VectorType extends Type
10+
{
11+
public function getName(): string
12+
{
13+
return 'vector';
14+
}
15+
16+
public function getSQLDeclaration(array $fieldDeclaration, AbstractPlatform $platform): string
17+
{
18+
$length = $fieldDeclaration['length'];
19+
return is_null($length) ? 'vector' : sprintf('vector(%d)', $length);
20+
}
21+
22+
public function convertToPHPValue(mixed $value, AbstractPlatform $platform): ?Vector
23+
{
24+
if (is_null($value)) {
25+
return null;
26+
}
27+
28+
return new Vector($value);
29+
}
30+
31+
public function convertToDatabaseValue(mixed $value, AbstractPlatform $platform): ?string
32+
{
33+
if (is_null($value)) {
34+
return null;
35+
}
36+
37+
if (!($value instanceof Vector)) {
38+
$value = new Vector($value);
39+
}
40+
41+
return (string) $value;
42+
}
43+
}

tests/DoctrineTest.php

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
<?php
2+
3+
use PHPUnit\Framework\TestCase;
4+
5+
use Doctrine\DBAL\DriverManager;
6+
use Doctrine\DBAL\Exception\TableNotFoundException;
7+
use Doctrine\DBAL\Types\Type;
8+
use Doctrine\ORM\EntityManager;
9+
use Doctrine\ORM\ORMSetup;
10+
use Doctrine\ORM\Tools\SchemaTool;
11+
use Pgvector\HalfVector;
12+
use Pgvector\SparseVector;
13+
use Pgvector\Vector;
14+
15+
require_once __DIR__ . '/models/DoctrineItem.php';
16+
17+
final class DoctrineTest extends TestCase
18+
{
19+
public function testTypes()
20+
{
21+
$config = ORMSetup::createAttributeMetadataConfiguration(
22+
paths: [__DIR__ . '/models'],
23+
isDevMode: true
24+
);
25+
26+
$connection = DriverManager::getConnection([
27+
'driver' => 'pgsql',
28+
'dbname' => 'pgvector_php_test'
29+
], $config);
30+
31+
$entityManager = new EntityManager($connection, $config);
32+
33+
Type::addType('vector', 'Pgvector\Doctrine\VectorType');
34+
Type::addType('halfvec', 'Pgvector\Doctrine\HalfVectorType');
35+
Type::addType('sparsevec', 'Pgvector\Doctrine\SparseVectorType');
36+
37+
$platform = $entityManager->getConnection()->getDatabasePlatform();
38+
$platform->registerDoctrineTypeMapping('vector', 'vector');
39+
$platform->registerDoctrineTypeMapping('halfvec', 'halfvec');
40+
$platform->registerDoctrineTypeMapping('sparsevec', 'sparsevec');
41+
42+
$schemaManager = $entityManager->getConnection()->createSchemaManager();
43+
try {
44+
$schemaManager->dropTable('doctrine_items');
45+
} catch (TableNotFoundException $e) {
46+
// do nothing
47+
}
48+
49+
$schemaTool = new SchemaTool($entityManager);
50+
$schemaTool->createSchema([$entityManager->getClassMetadata('DoctrineItem')]);
51+
52+
$item = new DoctrineItem();
53+
$item->setEmbedding(new Vector([1, 2, 3]));
54+
$item->setHalfEmbedding(new HalfVector([4, 5, 6]));
55+
$item->setSparseEmbedding(new SparseVector([7, 8, 9]));
56+
$entityManager->persist($item);
57+
$entityManager->flush();
58+
59+
$itemRepository = $entityManager->getRepository('DoctrineItem');
60+
$item = $itemRepository->find(1);
61+
$this->assertEquals([1, 2, 3], $item->getEmbedding()->toArray());
62+
$this->assertEquals([4, 5, 6], $item->getHalfEmbedding()->toArray());
63+
$this->assertEquals([7, 8, 9], $item->getSparseEmbedding()->toArray());
64+
}
65+
}

tests/models/DoctrineItem.php

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
<?php
2+
3+
use Doctrine\ORM\Mapping as ORM;
4+
use Pgvector\HalfVector;
5+
use Pgvector\SparseVector;
6+
use Pgvector\Vector;
7+
8+
#[ORM\Entity]
9+
#[ORM\Table(name: 'doctrine_items')]
10+
class DoctrineItem
11+
{
12+
#[ORM\Id]
13+
#[ORM\Column(type: 'integer')]
14+
#[ORM\GeneratedValue]
15+
private int|null $id = null;
16+
17+
#[ORM\Column(type: 'vector', length: 3)]
18+
private Vector $embedding;
19+
20+
#[ORM\Column(type: 'halfvec', length: 3)]
21+
private HalfVector $halfEmbedding;
22+
23+
#[ORM\Column(type: 'sparsevec', length: 3)]
24+
private SparseVector $sparseEmbedding;
25+
26+
public function getEmbedding(): Vector
27+
{
28+
return $this->embedding;
29+
}
30+
31+
public function setEmbedding(Vector $embedding): void
32+
{
33+
$this->embedding = $embedding;
34+
}
35+
36+
public function getHalfEmbedding(): HalfVector
37+
{
38+
return $this->halfEmbedding;
39+
}
40+
41+
public function setHalfEmbedding(HalfVector $embedding): void
42+
{
43+
$this->halfEmbedding = $embedding;
44+
}
45+
46+
public function getSparseEmbedding(): SparseVector
47+
{
48+
return $this->sparseEmbedding;
49+
}
50+
51+
public function setSparseEmbedding(SparseVector $embedding): void
52+
{
53+
$this->sparseEmbedding = $embedding;
54+
}
55+
}

0 commit comments

Comments
 (0)