Skip to content
4 changes: 3 additions & 1 deletion packages/core/src/driver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ import NotificationFilter from './notification-filter'
import HomeDatabaseCache from './internal/homedb-cache'
import { cacheKey } from './internal/auth-util'
import { ProtocolVersion } from './protocol-version'
import { Rules } from './mapping.highlevel'

const DEFAULT_MAX_CONNECTION_LIFETIME: number = 60 * 60 * 1000 // 1 hour

Expand Down Expand Up @@ -368,6 +369,7 @@ class QueryConfig<T = EagerResult> {
transactionConfig?: TransactionConfig
auth?: AuthToken
signal?: AbortSignal
parameterRules?: Rules

/**
* @constructor
Expand Down Expand Up @@ -630,7 +632,7 @@ class Driver {
transactionConfig: config.transactionConfig,
auth: config.auth,
signal: config.signal
}, query, parameters)
}, query, parameters, config.parameterRules)
}

/**
Expand Down
1 change: 1 addition & 0 deletions packages/core/src/graph-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import Integer from './integer'
import { stringify } from './json'
import { Rules, GenericConstructor, as } from './mapping.highlevel'

export const JSDate = Date
type StandardDate = Date
/**
* @typedef {number | Integer | bigint} NumberOrInteger
Expand Down
5 changes: 3 additions & 2 deletions packages/core/src/internal/query-executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import Result from '../result'
import ManagedTransaction from '../transaction-managed'
import { AuthToken, Query } from '../types'
import { TELEMETRY_APIS } from './constants'
import { Rules } from '../mapping.highlevel'

type SessionFactory = (config: { database?: string, bookmarkManager?: BookmarkManager, impersonatedUser?: string, auth?: AuthToken }) => Session

Expand All @@ -42,7 +43,7 @@ export default class QueryExecutor {

}

public async execute<T>(config: ExecutionConfig<T>, query: Query, parameters?: any): Promise<T> {
public async execute<T>(config: ExecutionConfig<T>, query: Query, parameters?: any, parameterRules?: Rules): Promise<T> {
const session = this._createSession({
database: config.database,
bookmarkManager: config.bookmarkManager,
Expand All @@ -65,7 +66,7 @@ export default class QueryExecutor {
: session.executeWrite.bind(session)

return await executeInTransaction(async (tx: ManagedTransaction) => {
const result = tx.run(query, parameters)
const result = tx.run(query, parameters, parameterRules)
return await config.resultTransformer(result)
}, config.transactionConfig)
} finally {
Expand Down
9 changes: 6 additions & 3 deletions packages/core/src/internal/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import Integer, { isInt, int } from '../integer'
import { NumberOrInteger } from '../graph-types'
import { EncryptionLevel } from '../types'
import { stringify } from '../json'
import { Rules, validateAndcleanParameters } from '../mapping.highlevel'

const ENCRYPTION_ON: EncryptionLevel = 'ENCRYPTION_ON'
const ENCRYPTION_OFF: EncryptionLevel = 'ENCRYPTION_OFF'
Expand Down Expand Up @@ -62,27 +63,29 @@ function isObject (obj: any): boolean {
* @throws TypeError when either given query or parameters are invalid.
*/
function validateQueryAndParameters (
query: string | String | { text: string, parameters?: any },
query: string | String | { text: string, parameters?: any, parameterRules?: Rules },
parameters?: any,
opt?: { skipAsserts: boolean }
opt?: { skipAsserts?: boolean, parameterRules?: Rules }
): {
validatedQuery: string
params: any
} {
let validatedQuery: string = ''
let params = parameters ?? {}
let parameterRules = opt?.parameterRules
const skipAsserts: boolean = opt?.skipAsserts ?? false

if (typeof query === 'string') {
validatedQuery = query
} else if (query instanceof String) {
validatedQuery = query.toString()
} else if (typeof query === 'object' && query.text != null) {
validatedQuery = query.text
params = query.parameters ?? {}
parameterRules = query.parameterRules
}

if (!skipAsserts) {
params = validateAndcleanParameters(params, parameterRules)
assertCypherQuery(validatedQuery)
assertQueryParameters(params)
}
Expand Down
46 changes: 45 additions & 1 deletion packages/core/src/mapping.highlevel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,15 @@ export interface Rule {
optional?: boolean
from?: string
convert?: (recordValue: any, field: string) => any
parameterConversion?: (objectValue: any) => any
validate?: (recordValue: any, field: string) => void
}

export type Rules = Record<string, Rule>

export let rulesRegistry: Record<string, Rules> = {}

let nameMapping: (name: string) => string = (name) => name
export let nameMapping: (name: string) => string = (name) => name

function register <T extends {} = Object> (constructor: GenericConstructor<T>, rules: Rules): void {
rulesRegistry[constructor.name] = rules
Expand Down Expand Up @@ -179,6 +180,49 @@ export function valueAs (value: unknown, field: string, rule?: Rule): unknown {

return ((rule?.convert) != null) ? rule.convert(value, field) : value
}

export function optionalParameterConversion (value: unknown, rule?: Rule): unknown {
if (rule?.optional === true && value == null) {
return value
}
return ((rule?.parameterConversion) != null) ? rule.parameterConversion(value) : value
}

export function validateAndcleanParameters (params: Record<string, any>, suppliedRules?: Rules): Record<string, any> {
const cleanedParams: Record<string, any> = {}
// @ts-expect-error
const parameterRules = getRules(params.constructor, suppliedRules)
if (parameterRules !== undefined) {
for (const key in parameterRules) {
if (!(parameterRules?.[key]?.optional === true)) {
let param = params[key]
if (parameterRules[key]?.parameterConversion !== undefined) {
param = parameterRules[key].parameterConversion(params[key])
}
if (param === undefined) {
throw newError('Parameter object did not include required parameter.')
}
if (parameterRules[key].validate != null) {
parameterRules[key].validate(param, key)
// @ts-expect-error
if (parameterRules[key].apply !== undefined) {
for (const entryKey in param) {
// @ts-expect-error
parameterRules[key].apply.validate(param[entryKey], entryKey)
}
}
}
const mappedKey = parameterRules[key].from ?? nameMapping(key)

cleanedParams[mappedKey] = param
}
}
return cleanedParams
} else {
return params
}
}

function getRules<T extends {} = Object> (constructorOrRules: Rules | GenericConstructor<T>, rules: Rules | undefined): Rules | undefined {
const rulesDefined = typeof constructorOrRules === 'object' ? constructorOrRules : rules
if (rulesDefined != null) {
Expand Down
33 changes: 30 additions & 3 deletions packages/core/src/mapping.rulesfactories.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
* limitations under the License.
*/

import { Rule, valueAs } from './mapping.highlevel'
import { StandardDate, isNode, isPath, isRelationship, isUnboundRelationship } from './graph-types'
import { Rule, valueAs, optionalParameterConversion } from './mapping.highlevel'
import { JSDate, StandardDate, isNode, isPath, isRelationship, isUnboundRelationship } from './graph-types'
import { isPoint } from './spatial-types'
import { Date, DateTime, Duration, LocalDateTime, LocalTime, Time, isDate, isDateTime, isDuration, isLocalDateTime, isLocalTime, isTime } from './temporal-types'
import Vector from './vector'
import Vector, { vector } from './vector'

/**
* @property {function(rule: ?Rule)} asString Create a {@link Rule} that validates the value is a String.
Expand Down Expand Up @@ -250,6 +250,7 @@ export const rule = Object.freeze({
}
},
convert: (value: Duration) => rule?.stringify === true ? value.toString() : value,
parameterConversion: rule?.stringify === true ? (str: string) => Duration.fromString(str) : undefined,
...rule
}
},
Expand All @@ -268,6 +269,7 @@ export const rule = Object.freeze({
}
},
convert: (value: LocalTime) => rule?.stringify === true ? value.toString() : value,
parameterConversion: rule?.stringify === true ? (str: string) => LocalTime.fromString(str) : undefined,
...rule
}
},
Expand All @@ -286,6 +288,7 @@ export const rule = Object.freeze({
}
},
convert: (value: Time) => rule?.stringify === true ? value.toString() : value,
parameterConversion: rule?.stringify === true ? (str: string) => Time.fromString(str) : undefined,
...rule
}
},
Expand All @@ -304,6 +307,7 @@ export const rule = Object.freeze({
}
},
convert: (value: Date) => convertStdDate(value, rule),
parameterConversion: rule?.stringify === true ? (str: string) => Date.fromStandardDateLocal(new JSDate(str)) : undefined,
...rule
}
},
Expand All @@ -315,13 +319,21 @@ export const rule = Object.freeze({
* @returns {Rule} A new rule for the value
*/
asLocalDateTime (rule?: Rule & { stringify?: boolean, toStandardDate?: boolean }): Rule {
let parameterConversion
if (rule?.stringify === true) {
parameterConversion = (str: string) => LocalDateTime.fromString(str)
}
if (rule?.toStandardDate === true) {
parameterConversion = (standardDate: StandardDate) => LocalDateTime.fromStandardDate(standardDate)
}
return {
validate: (value: any, field: string) => {
if (!isLocalDateTime(value)) {
throw new TypeError(`${field} should be a LocalDateTime but received ${typeof value}`)
}
},
convert: (value: LocalDateTime) => convertStdDate(value, rule),
parameterConversion,
...rule
}
},
Expand All @@ -333,13 +345,21 @@ export const rule = Object.freeze({
* @returns {Rule} A new rule for the value
*/
asDateTime (rule?: Rule & { stringify?: boolean, toStandardDate?: boolean }): Rule {
let parameterConversion
if (rule?.stringify === true) {
parameterConversion = (str: string) => DateTime.fromString(str)
}
if (rule?.toStandardDate === true) {
parameterConversion = (standardDate: StandardDate) => DateTime.fromStandardDate(standardDate)
}
return {
validate: (value: any, field: string) => {
if (!isDateTime(value)) {
throw new TypeError(`${field} should be a DateTime but received ${typeof value}`)
}
},
convert: (value: DateTime) => convertStdDate(value, rule),
parameterConversion,
...rule
}
},
Expand All @@ -363,6 +383,12 @@ export const rule = Object.freeze({
}
return list
},
parameterConversion: (list: any[]) => {
if (rule?.apply != null) {
return list.map((value) => optionalParameterConversion(value, rule.apply))
}
return list
},
...rule
}
},
Expand All @@ -386,6 +412,7 @@ export const rule = Object.freeze({
}
return value
},
parameterConversion: rule?.asTypedList === true ? (typedArray: Int16Array | Int32Array | BigInt64Array | Float32Array | Float64Array) => vector(typedArray) : undefined,
...rule
}
}
Expand Down
8 changes: 6 additions & 2 deletions packages/core/src/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,15 @@ import { RecordShape } from './record'
import NotificationFilter from './notification-filter'
import { Logger } from './internal/logger'
import { cacheKey } from './internal/auth-util'
import { Rules } from './mapping.highlevel'

type ConnectionConsumer<T> = (connection: Connection) => Promise<T> | T
type ManagedTransactionWork<T> = (tx: ManagedTransaction) => Promise<T> | T

interface TransactionConfig {
timeout?: NumberOrInteger
metadata?: object
parameterRules?: Rules
}

/**
Expand Down Expand Up @@ -186,11 +188,13 @@ class Session {
run<R extends RecordShape = RecordShape> (
query: Query,
parameters?: any,
transactionConfig?: TransactionConfig
transactionConfig?: TransactionConfig,
parameterRules?: Rules
): Result<R> {
const { validatedQuery, params } = validateQueryAndParameters(
query,
parameters
parameters,
{ parameterRules }
)
const autoCommitTxConfig = (transactionConfig != null)
? new TxConfig(transactionConfig, this._log)
Expand Down
Loading