258 lines
9.2 KiB
JavaScript
258 lines
9.2 KiB
JavaScript
import { CompiledQuery, Kysely } from 'kysely';
|
|
import { Connection, EventType, RawQueryFragment, Utils } from '@mikro-orm/core';
|
|
import { NativeQueryBuilder } from './query/NativeQueryBuilder.js';
|
|
/** Base class for SQL database connections, built on top of Kysely. */
|
|
export class AbstractSqlConnection extends Connection {
|
|
#client;
|
|
/** Establishes the database connection and runs the onConnect hook. */
|
|
async connect(options) {
|
|
await this.initClient();
|
|
this.connected = true;
|
|
if (options?.skipOnConnect !== true) {
|
|
await this.onConnect();
|
|
}
|
|
}
|
|
/** Initializes the Kysely client from driver options or a user-provided Kysely instance. */
|
|
createKysely() {
|
|
let driverOptions = this.options.driverOptions ?? this.config.get('driverOptions');
|
|
if (typeof driverOptions === 'function') {
|
|
driverOptions = driverOptions();
|
|
}
|
|
if (driverOptions instanceof Kysely) {
|
|
this.logger.log('info', 'Reusing Kysely client provided via `driverOptions`');
|
|
this.#client = driverOptions;
|
|
} else if ('createDriver' in driverOptions) {
|
|
this.logger.log('info', 'Reusing Kysely dialect provided via `driverOptions`');
|
|
this.#client = new Kysely({ dialect: driverOptions });
|
|
} else {
|
|
const dialect = this.createKyselyDialect(driverOptions);
|
|
if (dialect instanceof Promise) {
|
|
return dialect.then(d => {
|
|
this.#client = new Kysely({ dialect: d });
|
|
});
|
|
}
|
|
this.#client = new Kysely({ dialect });
|
|
}
|
|
}
|
|
/**
|
|
* @inheritDoc
|
|
*/
|
|
async close(force) {
|
|
await super.close(force);
|
|
await this.#client?.destroy();
|
|
this.connected = false;
|
|
this.#client = undefined;
|
|
}
|
|
/**
|
|
* @inheritDoc
|
|
*/
|
|
async isConnected() {
|
|
const check = await this.checkConnection();
|
|
return check.ok;
|
|
}
|
|
/**
|
|
* @inheritDoc
|
|
*/
|
|
async checkConnection() {
|
|
if (!this.connected) {
|
|
return { ok: false, reason: 'Connection not established' };
|
|
}
|
|
try {
|
|
await this.getClient().executeQuery(CompiledQuery.raw('select 1'));
|
|
return { ok: true };
|
|
} catch (error) {
|
|
return { ok: false, reason: error.message, error };
|
|
}
|
|
}
|
|
/** Returns the underlying Kysely client, creating it synchronously if needed. */
|
|
getClient() {
|
|
if (!this.#client) {
|
|
const maybePromise = this.createKysely();
|
|
/* v8 ignore next */
|
|
if (maybePromise instanceof Promise) {
|
|
throw new Error(
|
|
'Current driver requires async initialization, use `MikroORM.init()` instead of the constructor',
|
|
);
|
|
}
|
|
}
|
|
return this.#client;
|
|
}
|
|
/** Ensures the Kysely client is initialized, creating it asynchronously if needed. */
|
|
async initClient() {
|
|
if (!this.#client) {
|
|
await this.createKysely();
|
|
}
|
|
}
|
|
/** Executes a callback within a transaction, committing on success and rolling back on error. */
|
|
async transactional(cb, options = {}) {
|
|
const trx = await this.begin(options);
|
|
try {
|
|
const ret = await cb(trx);
|
|
await this.commit(trx, options.eventBroadcaster, options.loggerContext);
|
|
return ret;
|
|
} catch (error) {
|
|
await this.rollback(trx, options.eventBroadcaster, options.loggerContext);
|
|
throw error;
|
|
}
|
|
}
|
|
/** Begins a new transaction or creates a savepoint if a transaction context already exists. */
|
|
async begin(options = {}) {
|
|
if (options.ctx) {
|
|
const ctx = options.ctx;
|
|
await options.eventBroadcaster?.dispatchEvent(EventType.beforeTransactionStart, ctx);
|
|
ctx.index ??= 0;
|
|
const savepointName = `trx${ctx.index + 1}`;
|
|
const trx = await options.ctx.savepoint(savepointName).execute();
|
|
Reflect.defineProperty(trx, 'index', { value: ctx.index + 1 });
|
|
Reflect.defineProperty(trx, 'savepointName', { value: savepointName });
|
|
this.logQuery(this.platform.getSavepointSQL(savepointName), options.loggerContext);
|
|
await options.eventBroadcaster?.dispatchEvent(EventType.afterTransactionStart, trx);
|
|
return trx;
|
|
}
|
|
await this.ensureConnection();
|
|
await options.eventBroadcaster?.dispatchEvent(EventType.beforeTransactionStart);
|
|
let trxBuilder = this.getClient().startTransaction();
|
|
if (options.isolationLevel) {
|
|
trxBuilder = trxBuilder.setIsolationLevel(options.isolationLevel);
|
|
}
|
|
if (options.readOnly) {
|
|
trxBuilder = trxBuilder.setAccessMode('read only');
|
|
}
|
|
const trx = await trxBuilder.execute();
|
|
if (options.ctx) {
|
|
const ctx = options.ctx;
|
|
ctx.index ??= 0;
|
|
const savepointName = `trx${ctx.index + 1}`;
|
|
Reflect.defineProperty(trx, 'index', { value: ctx.index + 1 });
|
|
Reflect.defineProperty(trx, 'savepointName', { value: savepointName });
|
|
this.logQuery(this.platform.getSavepointSQL(savepointName), options.loggerContext);
|
|
} else {
|
|
for (const query of this.platform.getBeginTransactionSQL(options)) {
|
|
this.logQuery(query, options.loggerContext);
|
|
}
|
|
}
|
|
await options.eventBroadcaster?.dispatchEvent(EventType.afterTransactionStart, trx);
|
|
return trx;
|
|
}
|
|
/** Commits the transaction or releases the savepoint. */
|
|
async commit(ctx, eventBroadcaster, loggerContext) {
|
|
if (ctx.isRolledBack) {
|
|
return;
|
|
}
|
|
await eventBroadcaster?.dispatchEvent(EventType.beforeTransactionCommit, ctx);
|
|
if ('savepointName' in ctx) {
|
|
await ctx.releaseSavepoint(ctx.savepointName).execute();
|
|
this.logQuery(this.platform.getReleaseSavepointSQL(ctx.savepointName), loggerContext);
|
|
} else {
|
|
await ctx.commit().execute();
|
|
this.logQuery(this.platform.getCommitTransactionSQL(), loggerContext);
|
|
}
|
|
await eventBroadcaster?.dispatchEvent(EventType.afterTransactionCommit, ctx);
|
|
}
|
|
/** Rolls back the transaction or rolls back to the savepoint. */
|
|
async rollback(ctx, eventBroadcaster, loggerContext) {
|
|
await eventBroadcaster?.dispatchEvent(EventType.beforeTransactionRollback, ctx);
|
|
if ('savepointName' in ctx) {
|
|
await ctx.rollbackToSavepoint(ctx.savepointName).execute();
|
|
this.logQuery(this.platform.getRollbackToSavepointSQL(ctx.savepointName), loggerContext);
|
|
} else {
|
|
await ctx.rollback().execute();
|
|
this.logQuery(this.platform.getRollbackTransactionSQL(), loggerContext);
|
|
}
|
|
await eventBroadcaster?.dispatchEvent(EventType.afterTransactionRollback, ctx);
|
|
}
|
|
prepareQuery(query, params = []) {
|
|
if (query instanceof NativeQueryBuilder) {
|
|
query = query.toRaw();
|
|
}
|
|
if (query instanceof RawQueryFragment) {
|
|
params = query.params;
|
|
query = query.sql;
|
|
}
|
|
query = this.config.get('onQuery')(query, params);
|
|
const formatted = this.platform.formatQuery(query, params);
|
|
return { query, params, formatted };
|
|
}
|
|
/** Executes a SQL query and returns the result based on the method: `'all'` for rows, `'get'` for single row, `'run'` for affected count. */
|
|
async execute(query, params = [], method = 'all', ctx, loggerContext) {
|
|
await this.ensureConnection();
|
|
const q = this.prepareQuery(query, params);
|
|
const sql = this.getSql(q.query, q.formatted, loggerContext);
|
|
return this.executeQuery(
|
|
sql,
|
|
async () => {
|
|
const compiled = CompiledQuery.raw(q.formatted);
|
|
const res = await (ctx ?? this.#client).executeQuery(compiled);
|
|
return this.transformRawResult(res, method);
|
|
},
|
|
{ ...q, ...loggerContext },
|
|
);
|
|
}
|
|
/** Executes a SQL query and returns an async iterable that yields results row by row. */
|
|
async *stream(query, params = [], ctx, loggerContext) {
|
|
await this.ensureConnection();
|
|
const q = this.prepareQuery(query, params);
|
|
const sql = this.getSql(q.query, q.formatted, loggerContext);
|
|
// construct the compiled query manually with `kind: 'SelectQueryNode'` to avoid sqlite validation for select queries when streaming
|
|
const compiled = {
|
|
query: {
|
|
kind: 'SelectQueryNode',
|
|
},
|
|
sql: q.formatted,
|
|
parameters: [],
|
|
};
|
|
try {
|
|
const res = (ctx ?? this.getClient()).getExecutor().stream(compiled, 1);
|
|
this.logQuery(sql, {
|
|
sql,
|
|
params,
|
|
...loggerContext,
|
|
affected: Utils.isPlainObject(res) ? res.affectedRows : undefined,
|
|
});
|
|
for await (const items of res) {
|
|
for (const row of this.transformRawResult(items, 'all')) {
|
|
yield row;
|
|
}
|
|
}
|
|
} catch (e) {
|
|
this.logQuery(sql, { sql, params, ...loggerContext, level: 'error' });
|
|
throw e;
|
|
}
|
|
}
|
|
/** @inheritDoc */
|
|
async executeDump(dump) {
|
|
await this.ensureConnection();
|
|
try {
|
|
const raw = CompiledQuery.raw(dump);
|
|
await this.getClient().executeQuery(raw);
|
|
} catch (e) {
|
|
/* v8 ignore next */
|
|
throw this.platform.getExceptionConverter().convertException(e);
|
|
}
|
|
}
|
|
getSql(query, formatted, context) {
|
|
const logger = this.config.getLogger();
|
|
if (!logger.isEnabled('query', context)) {
|
|
return query;
|
|
}
|
|
if (logger.isEnabled('query-params', context)) {
|
|
return formatted;
|
|
}
|
|
return query;
|
|
}
|
|
transformRawResult(res, method) {
|
|
if (method === 'get') {
|
|
return res.rows[0];
|
|
}
|
|
if (method === 'all') {
|
|
return res.rows;
|
|
}
|
|
return {
|
|
affectedRows: Number(res.numAffectedRows ?? res.rows.length),
|
|
insertId: res.insertId != null ? Number(res.insertId) : res.insertId,
|
|
row: res.rows[0],
|
|
rows: res.rows,
|
|
};
|
|
}
|
|
}
|