Browse Source

Updated CommonTableExpression. Cleaned up unit tests and added comments.

Corey Taylor 6 years ago
parent
commit
9284667ffa

+ 62 - 165
src/Database/Expression/CommonTableExpression.php

@@ -16,11 +16,10 @@ declare(strict_types=1);
  */
 namespace Cake\Database\Expression;
 
-use Cake\Database\Exception as DatabaseException;
 use Cake\Database\ExpressionInterface;
 use Cake\Database\ValueBinder;
 use Closure;
-use InvalidArgumentException;
+use RuntimeException;
 
 /**
  * An expression that represents a common table expression definition.
@@ -37,18 +36,11 @@ class CommonTableExpression implements ExpressionInterface
     /**
      * The field names to use for the CTE.
      *
-     * @var \Cake\Database\ExpressionInterface[]|string[]
+     * @var \Cake\Database\Expression\IdentifierExpression[]
      */
     protected $fields = [];
 
     /**
-     * The modifiers to use for the CTE.
-     *
-     * @var \Cake\Database\ExpressionInterface[]|string[]
-     */
-    protected $modifiers = [];
-
-    /**
      * The CTE query definition.
      *
      * @var \Cake\Database\ExpressionInterface|null
@@ -56,41 +48,36 @@ class CommonTableExpression implements ExpressionInterface
     protected $query;
 
     /**
-     * Whether the CTE operates recursively.
+     * Whether the CTE is materialized or not materialized.
      *
-     * @var bool
+     * @var string|null
      */
-    protected $recursive = false;
+    protected $materialized = null;
 
     /**
      * Constructor.
      *
      * @param string $name The CTE name.
-     * @param \Cake\Database\ExpressionInterface $query The CTE query definition.
+     * @param \Closure|\Cake\Database\ExpressionInterface $query CTE query
      */
-    public function __construct(?string $name = null, ?ExpressionInterface $query = null)
+    public function __construct(?string $name = null, $query = null)
     {
         $this->name = $name;
-        $this->query = $query;
+        if ($query) {
+            $this->query($query);
+        }
     }
 
     /**
-     * Returns the CTE name.
+     * Sets the name of this CTE.
      *
-     * @return string|null
-     */
-    public function getName(): ?string
-    {
-        return $this->name;
-    }
-
-    /**
-     * Sets the CTE name.
+     * This is the named you used to reference the expression
+     * in select, insert, etc queries.
      *
      * @param string $name The CTE name.
      * @return $this
      */
-    public function setName(string $name)
+    public function name(string $name)
     {
         $this->name = $name;
 
@@ -98,176 +85,91 @@ class CommonTableExpression implements ExpressionInterface
     }
 
     /**
-     * Returns the field names to use for the CTE.
-     *
-     * @return \Cake\Database\ExpressionInterface[]|string[]
-     */
-    public function getFields(): array
-    {
-        return $this->fields;
-    }
-
-    /**
-     * Sets the field names to use for the CTE.
+     * Sets the query for this CTE.
      *
-     * @param \Cake\Database\ExpressionInterface[]|string[] $fields The field names to use for the CTE.
+     * @param \Closure|\Cake\Database\ExpressionInterface $query CTE query
      * @return $this
-     * @throws \InvalidArgumentException When one or more fields are of an invalid type.
      */
-    public function setFields(array $fields)
+    public function query($query)
     {
-        foreach ($fields as $index => $field) {
-            if (is_string($field)) {
-                $fields[$index] = $field = new IdentifierExpression($field);
-            }
-
-            if (!($field instanceof ExpressionInterface)) {
-                throw new InvalidArgumentException(sprintf(
-                    'The `$fields` argument must contain only instances of `%s`, or strings, `%s` given at index `%d`.',
-                    ExpressionInterface::class,
-                    getTypeName($field),
-                    $index
-                ));
+        if ($query instanceof Closure) {
+            $query = $query();
+            if (!($query instanceof ExpressionInterface)) {
+                throw new RuntimeException(
+                    'You must return an `ExpressionInterface` from closure passed to `query()`.'
+                );
             }
         }
-
-        $this->fields = $fields;
+        $this->query = $query;
 
         return $this;
     }
 
     /**
-     * Returns the modifiers to use for the CTE.
+     * Adds one or more fields (arguments) to the CTE.
      *
-     * @return \Cake\Database\ExpressionInterface[]|string[]
-     */
-    public function getModifiers(): array
-    {
-        return $this->modifiers;
-    }
-
-    /**
-     * Sets the modifiers to use for the CTE.
-     *
-     * @param \Cake\Database\ExpressionInterface[]|string[] $modifiers The modifiers to use for the CTE.
+     * @param string|string[]|\Cake\Database\Expression\IdentifierExpression|\Cake\Database\Expression\IdentifierExpression[] $fields Field names
      * @return $this
-     * @throws \InvalidArgumentException When one or more modifiers are of an invalid type.
      */
-    public function setModifiers(array $modifiers)
+    public function field($fields)
     {
-        foreach ($modifiers as $index => $modifier) {
-            if (
-                !($modifier instanceof ExpressionInterface) &&
-                !is_string($modifier)
-            ) {
-                throw new InvalidArgumentException(sprintf(
-                    'The `$modifiers` argument must contain only instances of `%s`, or strings, ' .
-                        '`%s` given at index `%d`.',
-                    ExpressionInterface::class,
-                    getTypeName($modifier),
-                    $index
-                ));
+        $fields = (array)$fields;
+        foreach ($fields as &$field) {
+            if (!($field instanceof IdentifierExpression)) {
+                $field = new IdentifierExpression($field);
             }
         }
-
-        $this->modifiers = $modifiers;
+        $this->fields = array_merge($this->fields, $fields);
 
         return $this;
     }
 
     /**
-     * Returns the CTE query definition.
-     *
-     * @return \Cake\Database\ExpressionInterface|null
-     */
-    public function getQuery(): ?ExpressionInterface
-    {
-        return $this->query;
-    }
-
-    /**
-     * Sets the CTE query definition.
+     * Sets this CTE as materialized.
      *
-     * @param \Cake\Database\ExpressionInterface $query The CTE query definition.
      * @return $this
      */
-    public function setQuery(ExpressionInterface $query)
+    public function materialized()
     {
-        $this->query = $query;
+        $this->materialized = 'MATERIALIZED';
 
         return $this;
     }
 
     /**
-     * Returns whether the CTE operates recursively.
+     * Sets this CTE as not materialized.
      *
-     * @return bool
-     */
-    public function isRecursive(): bool
-    {
-        return $this->recursive;
-    }
-
-    /**
-     * Sets whether the CTE operates recursively.
-     *
-     * @param bool $recursive Indicates whether the CTE query operates recursively.
      * @return $this
      */
-    public function setRecursive(bool $recursive)
+    public function notMaterialized()
     {
-        $this->recursive = $recursive;
+        $this->materialized = 'NOT MATERIALIZED';
 
         return $this;
     }
 
     /**
-     * {@inheritDoc}
-     *
-     * @throws \Cake\Database\Exception When not name has been set.
-     * @throws \Cake\Database\Exception When not query has been set.
+     * @inheritDoc
      */
     public function sql(ValueBinder $generator): string
     {
-        if (empty($this->name)) {
-            throw new DatabaseException(
-                'Cannot generate SQL for common table expressions that have no name.'
-            );
-        }
-
-        if (empty($this->query)) {
-            throw new DatabaseException(
-                'Cannot generate SQL for common table expressions that have no query.'
-            );
-        }
-
         $fields = '';
-        if (!empty($this->fields)) {
-            $fields = [];
-            foreach ($this->fields as $field) {
-                if ($field instanceof ExpressionInterface) {
-                    $field = $field->sql($generator);
-                }
-                $fields[] = $field;
-            }
-
-            $fields = sprintf('(%s)', implode(', ', $fields));
+        if ($this->fields) {
+            $expressions = array_map(function (IdentifierExpression $e) use ($generator) {
+                return $e->sql($generator);
+            }, $this->fields);
+            $fields = sprintf('(%s)', implode(', ', $expressions));
         }
 
-        $modifiers = '';
-        if (!empty($this->modifiers)) {
-            $modifiers = [];
-            foreach ($this->modifiers as $modifier) {
-                if ($modifier instanceof ExpressionInterface) {
-                    $modifier = $modifier->sql($generator);
-                }
-                $modifiers[] = $modifier;
-            }
-
-            $modifiers = ' ' . implode(' ', $modifiers);
-        }
+        $suffix = $this->materialized ? $this->materialized . ' ' : '';
 
-        return sprintf('%s%s AS%s (%s)', $this->name, $fields, $modifiers, $this->query->sql($generator));
+        return sprintf(
+            '%s%s AS %s(%s)',
+            (string)$this->name,
+            $fields,
+            $suffix,
+            $this->query ? $this->query->sql($generator) : ''
+        );
     }
 
     /**
@@ -275,11 +177,14 @@ class CommonTableExpression implements ExpressionInterface
      */
     public function traverse(Closure $visitor)
     {
-        foreach (array_merge($this->fields, $this->modifiers, [$this->query]) as $part) {
-            if ($part instanceof ExpressionInterface) {
-                $visitor($part);
-                $part->traverse($visitor);
-            }
+        foreach ($this->fields as $field) {
+            $visitor($field);
+            $field->traverse($visitor);
+        }
+
+        if ($this->query) {
+            $visitor($this->query);
+            $this->query->traverse($visitor);
         }
 
         return $this;
@@ -292,20 +197,12 @@ class CommonTableExpression implements ExpressionInterface
      */
     public function __clone()
     {
-        if ($this->query instanceof ExpressionInterface) {
+        if ($this->query) {
             $this->query = clone $this->query;
         }
 
         foreach ($this->fields as $key => $field) {
-            if ($this->fields[$key] instanceof ExpressionInterface) {
-                $this->fields[$key] = clone $this->fields[$key];
-            }
-        }
-
-        foreach ($this->modifiers as $key => $modifier) {
-            if ($this->modifiers[$key] instanceof ExpressionInterface) {
-                $this->modifiers[$key] = clone $this->modifiers[$key];
-            }
+            $this->fields[$key] = clone $field;
         }
     }
 }

+ 24 - 55
src/Database/Query.php

@@ -408,81 +408,50 @@ class Query implements ExpressionInterface, IteratorAggregate
      *         ->from('articles');
      *
      *     return $cte
-     *         ->setName('cte')
-     *         ->setQuery($cteQuery);
+     *         ->name('cte')
+     *         ->query($cteQuery);
      * });
      * ```
      *
-     * The list of expressions can be reset by overwriting and passing `null` for the
-     * expression:
-     *
-     * ```
-     * $query->with(null, true);
-     * ```
-     *
-     * @param \Cake\Database\Expression\CommonTableExpression|\Closure|null $expression The CTE to add.
+     * @param \Closure|\Cake\Database\Expression\CommonTableExpression $cte The CTE to add.
+     * @param bool $recursive Whether the CTE is recursive.
      * @param bool $overwrite Whether to reset the list of CTEs.
      * @return $this
-     * @throws \InvalidArgumentException When passing `null` for the `$expression` argument but not enabling
-     *  `$overwrite`.
-     * @throws \InvalidArgumentException When an invalid type is passed or returned for the `$expression` argument.
-     * @throws \InvalidArgumentException When the given CTE object has no name set.
-     * @throws \InvalidArgumentException When the given CTE object has no query set.
-     * @throws \InvalidArgumentException When a CTE object with the same name is already attached to this query.
      */
-    public function with($expression, $overwrite = false)
+    public function with($cte, bool $recursive = false, bool $overwrite = false)
     {
         if ($overwrite) {
             $this->_parts['with'] = [];
         }
 
-        if ($expression === null) {
-            if (!$overwrite) {
-                throw new \InvalidArgumentException(
-                    'Resetting the WITH clause only works when overwriting is enabled.'
+        if ($cte instanceof Closure) {
+            $cte = $cte(new CommonTableExpression(), $this);
+            if (!($cte instanceof CommonTableExpression)) {
+                throw new RuntimeException(
+                    'You must return a `CommonTableExpression` from closure passed to `with()`.'
                 );
             }
-
-            return $this;
-        }
-
-        if ($expression instanceof Closure) {
-            $expression = $expression(new CommonTableExpression(), $this);
-        }
-
-        if (!($expression instanceof CommonTableExpression)) {
-            throw new InvalidArgumentException(sprintf(
-                'The common table expression must be an instance of `%s`, `%s` given.',
-                CommonTableExpression::class,
-                getTypeName($expression)
-            ));
-        }
-
-        $name = $expression->getName();
-        if (empty($name)) {
-            throw new InvalidArgumentException('The common table expression must have a name.');
-        }
-
-        if (empty($expression->getQuery())) {
-            throw new InvalidArgumentException('The common table expression must have a query.');
-        }
-
-        foreach ($this->_parts['with'] as $existing) {
-            /** @var \Cake\Database\Expression\CommonTableExpression $existing */
-            if ($existing->getName() === $name) {
-                throw new InvalidArgumentException(sprintf(
-                    'A common table expression with the name `%s` is already attached to this query.',
-                    $name
-                ));
-            }
         }
 
-        $this->_parts['with'][] = $expression;
+        $this->_parts['with'][] = ['cte' => $cte, 'recursive' => $recursive];
+        $this->_dirty();
 
         return $this;
     }
 
     /**
+     * Adds a new recursive common table expression (CTE) to the query.
+     *
+     * @param \Closure|\Cake\Database\Expression\CommonTableExpression $cte The CTE to add.
+     * @param bool $overwrite Whether to reset the list of CTEs.
+     * @return $this
+     */
+    public function withRecursive($cte, bool $overwrite = false)
+    {
+        return $this->with($cte, true, $overwrite);
+    }
+
+    /**
      * Adds new fields to be returned by a `SELECT` statement when this query is
      * executed. Fields can be passed as an array of strings, array of expression
      * objects, a single expression or a single string.

+ 7 - 13
src/Database/QueryCompiler.php

@@ -162,29 +162,23 @@ class QueryCompiler
      * it constructs the CTE definitions list and generates the `RECURSIVE`
      * keyword when required.
      *
-     * @param \Cake\Database\Expression\CommonTableExpression[] $parts List of CTEs to be transformed to string
+     * @param array $parts List of CTEs to be transformed to string
      * @param \Cake\Database\Query $query The query that is being compiled
      * @param \Cake\Database\ValueBinder $generator The placeholder generator to be used in expressions
      * @return string
      */
     protected function _buildWithPart(array $parts, Query $query, ValueBinder $generator): string
     {
-        $hasRecursiveExpressions = false;
-
+        $recursive = false;
         $expressions = [];
-        foreach ($parts as $expression) {
-            if ($expression->isRecursive()) {
-                $hasRecursiveExpressions = true;
-            }
-            $expressions[] = $expression->sql($generator);
+        foreach ($parts as $cte) {
+            $recursive = $recursive || $cte['recursive'];
+            $expressions[] = $cte['cte']->sql($generator);
         }
 
-        $keywords = '';
-        if ($hasRecursiveExpressions) {
-            $keywords = 'RECURSIVE ';
-        }
+        $recursive = $recursive ? 'RECURSIVE ' : '';
 
-        return sprintf('WITH %s%s ', $keywords, implode(', ', $expressions));
+        return sprintf('WITH %s%s ', $recursive, implode(', ', $expressions));
     }
 
     /**

+ 3 - 3
src/Database/SqlserverCompiler.php

@@ -58,7 +58,7 @@ class SqlserverCompiler extends QueryCompiler
      * it constructs the CTE definitions list without generating the `RECURSIVE`
      * keyword that is neither required nor valid.
      *
-     * @param \Cake\Database\Expression\CommonTableExpression[] $parts List of CTEs to be transformed to string
+     * @param array $parts List of CTEs to be transformed to string
      * @param \Cake\Database\Query $query The query that is being compiled
      * @param \Cake\Database\ValueBinder $generator The placeholder generator to be used in expressions
      * @return string
@@ -66,8 +66,8 @@ class SqlserverCompiler extends QueryCompiler
     protected function _buildWithPart(array $parts, Query $query, ValueBinder $generator): string
     {
         $expressions = [];
-        foreach ($parts as $expression) {
-            $expressions[] = $expression->sql($generator);
+        foreach ($parts as $cte) {
+            $expressions[] = $cte['cte']->sql($generator);
         }
 
         return sprintf('WITH %s ', implode(', ', $expressions));

+ 16 - 0
src/TestSuite/TestCase.php

@@ -491,6 +491,22 @@ abstract class TestCase extends BaseTestCase
     }
 
     /**
+     * Assert that a string starts with SQL with db-specific characters like quotes removed.
+     *
+     * @param string $needle The string to compare
+     * @param string $haystack The SQL to filter
+     * @param string $message The message to display on failure
+     * @return void
+     */
+    public function assertStartsWithSql(
+        string $needle,
+        string $haystack,
+        string $message = ''
+    ): void {
+        $this->assertStringStartsWith($needle, preg_replace('/[`"\[\]]/', '', $haystack), $message);
+    }
+
+    /**
      * Asserts HTML tags.
      *
      * Takes an array $expected and generates a regex from it to match the provided $string.

File diff suppressed because it is too large
+ 0 - 1027
tests/TestCase/Database/CommonTableExpressionIntegrationTest.php


+ 450 - 0
tests/TestCase/Database/CommonTableExpressionQueryTests.php

@@ -0,0 +1,450 @@
+<?php
+declare(strict_types=1);
+
+/**
+ * CakePHP(tm) : Rapid Development Framework (https://cakephp.org)
+ * Copyright (c) Cake Software Foundation, Inc. (https://cakefoundation.org)
+ *
+ * Licensed under The MIT License
+ * For full copyright and license information, please see the LICENSE.txt
+ * Redistributions of files must retain the above copyright notice.
+ *
+ * @copyright     Copyright (c) Cake Software Foundation, Inc. (https://cakefoundation.org)
+ * @link          https://cakephp.org CakePHP(tm) Project
+ * @since         4.1.0
+ * @license       https://opensource.org/licenses/mit-license.php MIT License
+ */
+namespace Cake\Test\TestCase\Database;
+
+use Cake\Database\Driver\Mysql;
+use Cake\Database\Driver\Sqlite;
+use Cake\Database\Driver\Sqlserver;
+use Cake\Database\Expression\CommonTableExpression;
+use Cake\Database\Expression\QueryExpression;
+use Cake\Database\Query;
+use Cake\Database\ValueBinder;
+use Cake\Datasource\ConnectionManager;
+use Cake\TestSuite\TestCase;
+
+class CommonTableExpressionQueryTests extends TestCase
+{
+    /**
+     * @inheritDoc
+     */
+    protected $fixtures = [
+        'core.Articles',
+    ];
+
+    /**
+     * @inheritDoc
+     */
+    public $autoFixtures = false;
+
+    /**
+     * @var \Cake\Database\Connection
+     */
+    protected $connection;
+
+    /**
+     * @var bool
+     */
+    protected $autoQuote;
+
+    public function setUp(): void
+    {
+        parent::setUp();
+        $this->connection = ConnectionManager::get('test');
+
+        $this->skipIf(
+            !$this->connection->getDriver()->supportsCTEs(),
+            'The current driver does not support common table expressions.'
+        );
+    }
+
+    public function tearDown(): void
+    {
+        parent::tearDown();
+        unset($this->connection);
+    }
+
+    /**
+     * Tests with() sql generation.
+     *
+     * @return void
+     */
+    public function testWithCte()
+    {
+        $query = $this->connection->newQuery()
+            ->with(new CommonTableExpression('cte', function () {
+                return $this->connection->newQuery()->select(['col' => 1]);
+            }))
+            ->select('col')
+            ->from('cte');
+
+        $this->assertEqualsSql(
+            'WITH cte AS (SELECT 1 AS col) SELECT col FROM cte',
+            $query->sql(new ValueBinder())
+        );
+
+        $expected = [
+            [
+                'col' => '1',
+            ],
+        ];
+
+        $result = $query->execute();
+        $this->assertEquals($expected, $result->fetchAll('assoc'));
+        $result->closeCursor();
+    }
+
+    /**
+     * Tests calling with() with overwrite clears other CTEs.
+     *
+     * @return void
+     */
+    public function testWithCteOverwrite()
+    {
+        $query = $this->connection->newQuery()
+            ->with(new CommonTableExpression('cte', function () {
+                return $this->connection->newQuery()->select(['col' => '1']);
+            }))
+            ->select('col')
+            ->from('cte');
+
+        $this->assertEqualsSql(
+            'WITH cte AS (SELECT 1 AS col) SELECT col FROM cte',
+            $query->sql(new ValueBinder())
+        );
+
+        $query
+            ->with(new CommonTableExpression('cte2', $this->connection->newQuery()), false, true)
+            ->from('cte2', true);
+        $this->assertEqualsSql(
+            'WITH cte2 AS () SELECT col FROM cte2',
+            $query->sql(new ValueBinder())
+        );
+    }
+
+    /**
+     * Tests recursive CTE.
+     *
+     * @return void
+     */
+    public function testWithRecursiveCte()
+    {
+        $query = $this->connection->newQuery()
+            ->withRecursive(function (CommonTableExpression $cte, Query $query) {
+                $anchorQuery = $query->getConnection()
+                    ->newQuery()
+                    ->select(1);
+
+                $recursiveQuery = $query->getConnection()
+                    ->newQuery()
+                    ->select(function (Query $query) {
+                        return $query->newExpr('col + 1');
+                    })
+                    ->from('cte')
+                    ->where(['col !=' => 3], ['col' => 'integer']);
+
+                $cteQuery = $anchorQuery->unionAll($recursiveQuery);
+
+                return $cte
+                    ->name('cte')
+                    ->field(['col'])
+                    ->query($cteQuery);
+            })
+            ->select('col')
+            ->from('cte');
+
+        if ($this->connection->getDriver() instanceof Sqlserver) {
+            $expectedSql =
+                "WITH cte(col) AS " .
+                    "(SELECT 1\nUNION ALL SELECT (col + 1) FROM cte WHERE col != :c0) " .
+                        "SELECT col FROM cte";
+        } elseif ($this->connection->getDriver() instanceof Sqlite) {
+            $expectedSql =
+                "WITH RECURSIVE cte(col) AS " .
+                    "(SELECT 1\nUNION ALL SELECT (col + 1) FROM cte WHERE col != :c0) " .
+                        "SELECT col FROM cte";
+        } else {
+            $expectedSql =
+                "WITH RECURSIVE cte(col) AS " .
+                    "((SELECT 1)\nUNION ALL (SELECT (col + 1) FROM cte WHERE col != :c0)) " .
+                        "SELECT col FROM cte";
+        }
+        $this->assertEqualsSql(
+            $expectedSql,
+            $query->sql(new ValueBinder())
+        );
+
+        $expected = [
+            [
+                'col' => '1',
+            ],
+            [
+                'col' => '2',
+            ],
+            [
+                'col' => '3',
+            ],
+        ];
+
+        $result = $query->execute();
+        $this->assertEquals($expected, $result->fetchAll('assoc'));
+        $result->closeCursor();
+    }
+
+    /**
+     * Test inserting from CTE.
+     *
+     * @return void
+     */
+    public function testWithInsertQuery()
+    {
+        $this->skipIf(
+            ($this->connection->getDriver() instanceof Mysql),
+            '`WITH ... INSERT INTO` syntax is not supported in MySQL.'
+        );
+
+        $this->loadFixtures('Articles');
+
+        // test initial state
+        $result = $this->connection->newQuery()
+            ->select('*')
+            ->from('articles')
+            ->where(['id' => 4])
+            ->execute();
+        $this->assertFalse($result->fetch('assoc'));
+        $result->closeCursor();
+
+        $query = $this->connection
+            ->newQuery()
+            ->with(function (CommonTableExpression $cte, Query $query) {
+                return $cte
+                    ->name('cte')
+                    ->field(['title', 'body'])
+                    ->query($query->newExpr("SELECT 'Fourth Article', 'Fourth Article Body'"));
+            })
+            ->insert(['title', 'body'])
+            ->into('articles')
+            ->values(
+                $this->connection
+                    ->newQuery()
+                    ->select('*')
+                    ->from('cte')
+            );
+
+        $this->assertStartsWithSql(
+            "WITH cte(title, body) AS (SELECT 'Fourth Article', 'Fourth Article Body') " .
+                "INSERT INTO articles (title, body)",
+            $query->sql(new ValueBinder())
+        );
+
+        // run insert
+        $query->execute()->closeCursor();
+
+        $expected = [
+            'id' => '4',
+            'author_id' => null,
+            'title' => 'Fourth Article',
+            'body' => 'Fourth Article Body',
+            'published' => 'N',
+        ];
+
+        // test updated state
+        $result = $this->connection->newQuery()
+            ->select('*')
+            ->from('articles')
+            ->where(['id' => 4])
+            ->execute();
+        $this->assertEquals($expected, $result->fetch('assoc'));
+        $result->closeCursor();
+    }
+
+    /**
+     * Tests inserting from CTE as values list.
+     *
+     * @return void
+     */
+    public function testWithInInsertWithValuesQuery()
+    {
+        $this->skipIf(
+            ($this->connection->getDriver() instanceof Sqlserver),
+            '`INSERT INTO ... WITH` syntax is not supported in SQL Server.'
+        );
+
+        $this->loadFixtures('Articles');
+
+        $query = $this->connection->newQuery()
+            ->insert(['title', 'body'])
+            ->into('articles')
+            ->values(
+                $this->connection->newQuery()
+                    ->with(function (CommonTableExpression $cte, Query $query) {
+                        return $cte
+                            ->name('cte')
+                            ->field(['title', 'body'])
+                            ->query($query->newExpr("SELECT 'Fourth Article', 'Fourth Article Body'"));
+                    })
+                    ->select('*')
+                    ->from('cte')
+            );
+
+        $this->assertStartsWithSql(
+            "INSERT INTO articles (title, body) " .
+                "WITH cte(title, body) AS (SELECT 'Fourth Article', 'Fourth Article Body') SELECT * FROM cte",
+            $query->sql(new ValueBinder())
+        );
+
+        // run insert
+        $query->execute()->closeCursor();
+
+        $expected = [
+            'id' => '4',
+            'author_id' => null,
+            'title' => 'Fourth Article',
+            'body' => 'Fourth Article Body',
+            'published' => 'N',
+        ];
+
+        // test updated state
+        $result = $this->connection->newQuery()
+            ->select('*')
+            ->from('articles')
+            ->where(['id' => 4])
+            ->execute();
+        $this->assertEquals($expected, $result->fetch('assoc'));
+        $result->closeCursor();
+    }
+
+    /**
+     * Tests updating from CTE.
+     *
+     * @return void
+     */
+    public function testWithInUpdateQuery()
+    {
+        $this->loadFixtures('Articles');
+
+        // test initial state
+        $result = $this->connection->newQuery()
+            ->select(['count' => 'COUNT(*)'])
+            ->from('articles')
+            ->where(['published' => 'Y'])
+            ->execute();
+        $this->assertEquals(['count' => '3'], $result->fetch('assoc'));
+        $result->closeCursor();
+
+        $query = $this->connection->newQuery()
+            ->with(function (CommonTableExpression $cte, Query $query) {
+                $cteQuery = $query->getConnection()
+                    ->newQuery()
+                    ->select('articles.id')
+                    ->from('articles')
+                    ->where(['articles.id !=' => 1]);
+
+                return $cte
+                    ->name('cte')
+                    ->query($cteQuery);
+            })
+            ->update('articles')
+            ->set('published', 'N')
+            ->where(function (QueryExpression $exp, Query $query) {
+                return $exp->in(
+                    'articles.id',
+                    $query
+                        ->getConnection()
+                        ->newQuery()
+                        ->select('cte.id')
+                        ->from('cte')
+                );
+            });
+
+        $this->assertEqualsSql(
+            "WITH cte AS (SELECT articles.id FROM articles WHERE articles.id != :c0) " .
+                "UPDATE articles SET published = :c1 WHERE id IN (SELECT cte.id FROM cte)",
+            $query->sql(new ValueBinder())
+        );
+
+        // run update
+        $query->execute()->closeCursor();
+
+        // test updated state
+        $result = $this->connection->newQuery()
+            ->select(['count' => 'COUNT(*)'])
+            ->from('articles')
+            ->where(['published' => 'Y'])
+            ->execute();
+        $this->assertEquals(['count' => '1'], $result->fetch('assoc'));
+        $result->closeCursor();
+    }
+
+    /**
+     * Tests deleting from CTE.
+     *
+     * @return void
+     */
+    public function testWithInDeleteQuery()
+    {
+        $this->loadFixtures('Articles');
+
+        // test initial state
+        $result = $this->connection
+            ->newQuery()
+            ->select(['count' => 'COUNT(*)'])
+            ->from('articles')
+            ->execute();
+        $this->assertEquals(['count' => '3'], $result->fetch('assoc'));
+        $result->closeCursor();
+
+        $query = $this->connection->newQuery()
+            ->with(function (CommonTableExpression $cte, Query $query) {
+                $cteQuery = $query->getConnection()
+                    ->newQuery()
+                    ->select('articles.id')
+                    ->from('articles')
+                    ->where(['articles.id !=' => 1]);
+
+                return $cte
+                    ->name('cte')
+                    ->query($cteQuery);
+            })
+            ->delete()
+            ->from(['a' => 'articles'])
+            ->where(function (QueryExpression $exp, Query $query) {
+                return $exp->in(
+                    'a.id',
+                    $query
+                        ->getConnection()
+                        ->newQuery()
+                        ->select('cte.id')
+                        ->from('cte')
+                );
+            });
+
+        $this->assertEqualsSql(
+            "WITH cte AS (SELECT articles.id FROM articles WHERE articles.id != :c0) " .
+                "DELETE FROM articles WHERE id IN (SELECT cte.id FROM cte)",
+            $query->sql(new ValueBinder())
+        );
+
+        // run delete
+        $query->execute()->closeCursor();
+
+        $expected = [
+            'id' => '1',
+            'author_id' => '1',
+            'title' => 'First Article',
+            'body' => 'First Article Body',
+            'published' => 'Y',
+        ];
+
+        // test updated state
+        $result = $this->connection->newQuery()
+            ->select('*')
+            ->from('articles')
+            ->execute();
+        $this->assertEquals($expected, $result->fetch('assoc'));
+        $result->closeCursor();
+    }
+}

+ 67 - 226
tests/TestCase/Database/Expression/CommonTableExpressionTest.php

@@ -16,15 +16,11 @@ declare(strict_types=1);
  */
 namespace Cake\Test\TestCase\Database\Expression;
 
-use Cake\Database\Exception as DatabaseException;
 use Cake\Database\Expression\CommonTableExpression;
 use Cake\Database\Expression\IdentifierExpression;
-use Cake\Database\Expression\QueryExpression;
-use Cake\Database\Query;
 use Cake\Database\ValueBinder;
 use Cake\Datasource\ConnectionManager;
 use Cake\TestSuite\TestCase;
-use InvalidArgumentException;
 
 class CommonTableExpressionTest extends TestCase
 {
@@ -45,252 +41,97 @@ class CommonTableExpressionTest extends TestCase
         unset($this->connection);
     }
 
-    public function testConstructWithNoArguments()
-    {
-        $expression = new CommonTableExpression();
-
-        $this->assertNull($expression->getName());
-        $this->assertEmpty($expression->getFields());
-        $this->assertEmpty($expression->getModifiers());
-        $this->assertNull($expression->getQuery());
-    }
-
-    public function testGetSetName(): void
-    {
-        $expression = new CommonTableExpression('cte', $this->connection->newQuery()->select(1));
-        $this->assertEquals('cte', $expression->getName());
-
-        $expression->setName('other');
-        $this->assertEquals('other', $expression->getName());
-    }
-
-    public function testGetSetFields(): void
-    {
-        $expression = new CommonTableExpression('cte', $this->connection->newQuery()->select(1));
-        $this->assertEmpty($expression->getFields());
-
-        $expression->setFields(['col1', 'col2']);
-        $this->assertEquals(
-            [new IdentifierExpression('col1'), new IdentifierExpression('col2')],
-            $expression->getFields()
-        );
-    }
-
-    public function testSetFieldsWithInvalidType(): void
-    {
-        $this->expectException(InvalidArgumentException::class);
-        $this->expectExceptionMessage(
-            'The `$fields` argument must contain only instances of `Cake\Database\ExpressionInterface`, ' .
-            'or strings, `integer` given at index `1`.'
-        );
-
-        $expression = new CommonTableExpression('cte', $this->connection->newQuery()->select(1));
-        $expression->setFields(['col1', 123]);
-    }
-
-    public function testGetSetModifiers(): void
-    {
-        $expression = new CommonTableExpression('cte', $this->connection->newQuery()->select(1));
-        $this->assertEmpty($expression->getModifiers());
-
-        $expression->setModifiers(['FOO', 'BAR']);
-        $this->assertEquals(['FOO', 'BAR'], $expression->getModifiers());
-    }
-
-    public function testModifiersFieldsWithInvalidType(): void
-    {
-        $this->expectException(InvalidArgumentException::class);
-        $this->expectExceptionMessage(
-            'The `$modifiers` argument must contain only instances of `Cake\Database\ExpressionInterface`, ' .
-            'or strings, `integer` given at index `1`.'
-        );
-
-        $expression = new CommonTableExpression('cte', $this->connection->newQuery()->select(1));
-        $expression->setModifiers(['FOO', 123]);
-    }
-
-    public function testGetSetQuery(): void
-    {
-        $connection = ConnectionManager::get('test');
-
-        $query = $this->connection->newQuery()->select(1);
-        $expression = new CommonTableExpression('cte', $query);
-        $this->assertSame($query, $expression->getQuery());
-
-        $query = $connection->newQuery()->select([1, 2]);
-        $expression->setQuery($query);
-        $this->assertSame($query, $expression->getQuery());
-    }
-
-    public function testGetSetRecursive(): void
-    {
-        $expression = new CommonTableExpression('cte', $this->connection->newQuery()->select(1));
-        $this->assertFalse($expression->isRecursive());
-
-        $expression->setRecursive(true);
-        $this->assertTrue($expression->isRecursive());
-    }
-
-    public function testSqlWithNoName()
-    {
-        $this->expectException(DatabaseException::class);
-        $this->expectExceptionMessage('Cannot generate SQL for common table expressions that have no name.');
-
-        $expression = new CommonTableExpression();
-        $expression->sql(new ValueBinder());
-    }
-
-    public function testSqlWithNoQuery()
-    {
-        $this->expectException(DatabaseException::class);
-        $this->expectExceptionMessage('Cannot generate SQL for common table expressions that have no query.');
-
-        $expression = new CommonTableExpression('cte');
-        $expression->sql(new ValueBinder());
-    }
-
-    public function testSqlWithQueryAsExpression(): void
-    {
-        $expression = new CommonTableExpression('cte', $this->connection->newQuery()->select(1));
-
-        $this->assertEqualsSql(
-            'cte AS (SELECT 1)',
-            $expression->sql(new ValueBinder())
-        );
-    }
-
-    public function testSqlWithQueryAsCustomExpression(): void
-    {
-        $expression = new CommonTableExpression('cte', new QueryExpression('SELECT 1'));
-
-        $this->assertEqualsSql(
-            'cte AS (SELECT 1)',
-            $expression->sql(new ValueBinder())
-        );
-    }
-
-    public function testSqlWithFieldsAsStrings(): void
+    /**
+     * Tests constructing CommonTableExpressions.
+     *
+     * @return void
+     */
+    public function testCteConstructor()
     {
-        $expression = (new CommonTableExpression('cte', $this->connection->newQuery()->select([1, 2])))
-            ->setFields(['col1', 'col2']);
+        $cte = new CommonTableExpression('test', $this->connection->newQuery());
+        $this->assertEqualsSql('test AS ()', $cte->sql(new ValueBinder()));
 
-        $this->assertEquals(
-            'cte(col1, col2) AS (SELECT 1, 2)',
-            $expression->sql(new ValueBinder())
-        );
+        $cte = (new CommonTableExpression())
+            ->name('test')
+            ->query($this->connection->newQuery());
+        $this->assertEqualsSql('test AS ()', $cte->sql(new ValueBinder()));
     }
 
-    public function testSqlWithFieldsAsExpressions(): void
+    /**
+     * Tests setting fields.
+     *
+     * @return void
+     */
+    public function testFields(): void
     {
-        $expression = (new CommonTableExpression('cte', $this->connection->newQuery()->select([1, 2])))
-            ->setFields([
-                new IdentifierExpression('col1'),
-                new IdentifierExpression('col2'),
-            ]);
-
-        $this->assertEquals(
-            'cte(col1, col2) AS (SELECT 1, 2)',
-            $expression->sql(new ValueBinder())
-        );
+        $cte = (new CommonTableExpression('test', $this->connection->newQuery()))
+            ->field('col1')
+            ->field([new IdentifierExpression('col2')]);
+        $this->assertEqualsSql('test(col1, col2) AS ()', $cte->sql(new ValueBinder()));
     }
 
-    public function testSqlWithModifiersAsStrings(): void
+    /**
+     * Tests setting CTE materialized
+     *
+     * @return void
+     */
+    public function testMaterialized()
     {
-        $expression = (new CommonTableExpression('cte', $this->connection->newQuery()->select(1)))
-            ->setModifiers(['NOT MATERIALIZED']);
+        $cte = (new CommonTableExpression('test', $this->connection->newQuery()))
+            ->materialized();
+        $this->assertEqualsSql('test AS MATERIALIZED ()', $cte->sql(new ValueBinder()));
 
-        $this->assertEquals(
-            'cte AS NOT MATERIALIZED (SELECT 1)',
-            $expression->sql(new ValueBinder())
-        );
+        $cte->notMaterialized();
+        $this->assertEqualsSql('test AS NOT MATERIALIZED ()', $cte->sql(new ValueBinder()));
     }
 
-    public function testSqlWithModifiersAsExpressions(): void
+    /**
+     * Tests setting query using closures.
+     *
+     * @return void
+     */
+    public function testQueryClosures()
     {
-        $expression = (new CommonTableExpression('cte', $this->connection->newQuery()->select(1)))
-            ->setModifiers([new QueryExpression('NOT MATERIALIZED')]);
+        $cte = new CommonTableExpression('test', function () {
+            return $this->connection->newQuery();
+        });
+        $this->assertEqualsSql('test AS ()', $cte->sql(new ValueBinder()));
 
-        $this->assertEquals(
-            'cte AS NOT MATERIALIZED (SELECT 1)',
-            $expression->sql(new ValueBinder())
-        );
+        $cte->query(function () {
+            return $this->connection->newQuery()->select('1');
+        });
+        $this->assertEqualsSql('test AS (SELECT 1)', $cte->sql(new ValueBinder()));
     }
 
-    public function testTraverse(): void
+    /**
+     * Tests traversing CommonTableExpression.
+     *
+     * @return void
+     */
+    public function testTraverse()
     {
-        $query = new QueryExpression('SELECT 1');
-        $identifier = new IdentifierExpression('col');
-        $modifier = new QueryExpression('NOT MATERIALIZED');
-        $modifierWrapper = new QueryExpression($modifier);
-
-        $expression = (new CommonTableExpression('cte', $query))
-            ->setFields([$identifier])
-            ->setModifiers([$modifierWrapper]);
+        $query = $this->connection->newQuery()->select('1');
+        $field = new IdentifierExpression('field');
+        $cte = (new CommonTableExpression('test', $query))->field($field);
 
         $expressions = [];
-        $expression->traverse(function ($expression) use (&$expressions) {
+        $cte->traverse(function ($expression) use (&$expressions) {
             $expressions[] = $expression;
         });
 
-        $this->assertSame(
-            [$identifier, $modifierWrapper, $modifier, $query],
-            $expressions
-        );
+        $this->assertEquals($field, $expressions[0]);
+        $this->assertEquals($query, $expressions[1]);
     }
 
+    /**
+     * Tests cloning CommonTableExpression
+     */
     public function testClone(): void
     {
-        $connection = ConnectionManager::get('test');
-
-        $query = $connection->newQuery()->select(1);
-        $fieldExpression = new IdentifierExpression('col2');
-        $modifierExpression = new QueryExpression('BAR');
-
-        $expression = (new CommonTableExpression('cte', $query))
-            ->setFields([
-                'col1',
-                $fieldExpression,
-            ])
-            ->setModifiers([
-                'FOO',
-                $modifierExpression,
-            ])
-            ->setRecursive(true);
-
-        $clone = clone $expression;
-
-        $this->assertInstanceOf(CommonTableExpression::class, $clone);
-        $this->assertNotSame($clone, $expression);
-
-        $this->assertEquals('cte', $clone->getName());
-
-        $this->assertCount(2, $clone->getFields());
-        $this->assertInstanceOf(IdentifierExpression::class, $clone->getFields()[0]);
-        $this->assertEquals('col1', $clone->getFields()[0]->getIdentifier());
-        $this->assertInstanceOf(IdentifierExpression::class, $clone->getFields()[1]);
-        $this->assertNotSame($fieldExpression, $clone->getFields()[1]);
-        $this->assertEquals('col2', $clone->getFields()[1]->getIdentifier());
-
-        $this->assertCount(2, $clone->getModifiers());
-        $this->assertEquals('FOO', $clone->getModifiers()[0]);
-        $this->assertInstanceOf(QueryExpression::class, $clone->getModifiers()[1]);
-        $this->assertNotSame($fieldExpression, $clone->getModifiers()[1]);
-        $this->assertEquals('BAR', $clone->getModifiers()[1]->sql(new ValueBinder()));
-
-        $this->assertInstanceOf(Query::class, $clone->getQuery());
-        $this->assertNotSame($query, $clone->getQuery());
-        $this->assertEquals('SELECT 1', $clone->getQuery()->sql(new ValueBinder()));
-    }
-
-    public function testCloneEmpty(): void
-    {
-        $expression = new CommonTableExpression();
-        $clone = clone $expression;
-
-        $this->assertNotSame($expression, $clone);
-        $this->assertEquals($expression->getName(), $clone->getName());
-        $this->assertEquals($expression->getFields(), $clone->getFields());
-        $this->assertEquals($expression->getModifiers(), $clone->getModifiers());
-        $this->assertEquals($expression->getQuery(), $clone->getQuery());
+        $cte = new CommonTableExpression('test', function () {
+            return $this->connection->newQuery()->select('1');
+        });
+        $cte2 = (clone $cte)->field('col1');
+        $this->assertNotSame($cte->sql(new ValueBinder()), $cte2->sql(new ValueBinder()));
     }
 }

+ 7 - 2
tests/TestCase/ORM/QueryTest.php

@@ -3889,6 +3889,11 @@ class QueryTest extends TestCase
         $this->assertEquals($expected, $results);
     }
 
+    /**
+     * Tests ORM query using with CTE.
+     *
+     * @return void
+     */
     public function testWith(): void
     {
         $this->skipIf(
@@ -3927,8 +3932,8 @@ class QueryTest extends TestCase
             ->find()
             ->with(function (CommonTableExpression $cte) use ($cteQuery) {
                 return $cte
-                    ->setName('cte')
-                    ->setQuery($cteQuery);
+                    ->name('cte')
+                    ->query($cteQuery);
             })
             ->select(['row_num'])
             ->enableAutoFields()