Yongseok's Blog
Using Styled Components on the Server (feat. React.cache)

Despite the title, this isn’t actually about using the real styled-components library.

dream-css-tool

I came across someone working on a project called dream-css-tool. It was an attempt to make styled-like syntax work in server components.

Let’s start by looking at how it’s used.

  1. First, wrap the root with StyleRegistry. (In this case, since it’s a Next.js App Directory project, it’s wrapped in the layout.)
import type { Metadata } from 'next';

import StyleRegistry from '@/components/StyleRegistry';

export const metadata: Metadata = {
  title: 'Dream CSS tool',
};

export default function RootLayout({
  children,
}: {
  children: React.ReactNode;
}) {
  return (
    <StyleRegistry>
      <html lang="en">
        <body>{children}</body>
      </html>
    </StyleRegistry>
  );
}
  2. Use styled. Write components the same way we’ve always used styled.
import React from 'react';
import styled from '../styled.js';

export default function StaticButton() {
  return <Button>Static Button</Button>;
}

const Button = styled('button')`
  display: block;
  padding: 1rem 2rem;
  border: none;
  border-radius: 4px;
  background: hsl(270deg 100% 30%);
  color: white;
  font-size: 1rem;
  cursor: pointer;
`;

The entire codebase is still only a few lines, so let’s take a quick look. Below is the original code of the key parts. (Comments included as-is.)

// styled.js
import React from 'react';

import { cache } from './components/StyleRegistry';

// TODO: Ideally, this API would use dot notation (styled.div) in
// addition to function calls (styled('div')). We should be able to
// use Proxies for this, like Framer Motion does.
export default function styled(Tag) {
  return (css) => {
    return function StyledComponent(props) {
      let collectedStyles =  cache();

      // Instead of using the filename, I'm using the `useId` hook to
      // generate a unique ID for each styled-component.
      const id = React.useId().replace(/:/g, '');
      const generatedClassName = `styled-${id}`;

      const styleContent = `.${generatedClassName} { ${css} }`;

      collectedStyles.push(styleContent);
      return <Tag className={generatedClassName} {...props} />;
    };
  };
}
// StyleRegistry.js
import React from 'react';

import StyleInserter from './StyleInserter';

export const cache = React.cache(() => {
  return [];
});

function StyleRegistry({ children }) {
  const collectedStyles = cache();

  return (
    <>
      <StyleInserter styles={collectedStyles} />
      {children}
    </>
  );
}

export default StyleRegistry;
// StyleInserter.js
'use client';

import React from 'react';
import { useServerInsertedHTML } from 'next/navigation';

function StyleInserter({ styles }) {
  useServerInsertedHTML(() => {
    return <style>{styles.join('\n')}</style>;
  });

  return null;
}

export default StyleInserter;

Let’s break it down piece by piece.

// StyleRegistry.js
export const cache = React.cache(() => {
  return [];
});
// styled.js
// ...
let collectedStyles =  cache();
// ...

The first thing a styled component does is call cache, the function created with React.cache and exported from StyleRegistry, to get the array of collected styles.

// styled.js
// ...
const id = React.useId().replace(/:/g, '');
const generatedClassName = `styled-${id}`;

const styleContent = `.${generatedClassName} { ${css} }`;

collectedStyles.push(styleContent);
return <Tag className={generatedClassName} {...props} />;

Here, a unique id is generated to create a className, then the CSS is constructed and pushed into the cache. The component is then returned with that className.
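
One detail worth spelling out: because styled('button') is followed directly by a template literal, the inner function is called as a template tag, so css arrives as an array of string chunks rather than a single string. With no ${...} interpolations the array has exactly one chunk, and `${css}` coerces it to the raw CSS text, which is why the code above works. Below is a minimal sketch in plain JavaScript. The tag and buildRule helpers are hypothetical, and the ':r1:' id is only an assumed example of what useId might return (the real format is a React implementation detail).

```javascript
// Hypothetical tag function that simply returns what a template tag receives.
function tag(strings, ...values) {
  return { strings, values };
}

// Hypothetical helper mirroring the class name and rule construction in styled.js.
function buildRule(rawId, css) {
  const id = rawId.replace(/:/g, ''); // strip the colons from the id
  const className = `styled-${id}`;
  return { className, rule: `.${className} { ${css} }` };
}

const received = tag`color: red;`;
// Array coercion joins chunks with ',', so a single chunk coerces to itself.
const { className, rule } = buildRule(':r1:', received.strings);

console.log(className); // styled-r1
console.log(rule);      // .styled-r1 { color: red; }

// With an interpolation there are two chunks, and coercion inserts a comma.
const withValue = tag`color: ${'red'};`;
console.log(String(withValue.strings)); // color: ,;
```

The last line suggests that supporting interpolated values would require stitching the chunks and values back together instead of relying on string coercion.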

// StyleRegistry.js (server)
function StyleRegistry({ children }) {
  const collectedStyles = cache();

  return (
    <>
      <StyleInserter styles={collectedStyles} />
      {children}
    </>
  );
}

In StyleRegistry, the cache is retrieved and the collected CSS is passed to StyleInserter.

// StyleInserter.js
'use client';

import React from 'react';
import { useServerInsertedHTML } from 'next/navigation';

function StyleInserter({ styles }) {

  useServerInsertedHTML(() => {
    return <style>{styles.join('\n')}</style>;
  });

  return null;
}

export default StyleInserter;

During server rendering, useServerInsertedHTML injects the received CSS into the HTML as a <style> tag. This gives us styled-like components that also work in server components.
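
To make the output concrete, here is a rough sketch of what the injected tag would contain after two styled components have rendered. It uses a plain string in place of the JSX returned inside useServerInsertedHTML; the renderStyleTag helper and the class names are made up for illustration.

```javascript
// Rules as they might look after two styled components have rendered
// (class names are assumptions, not real useId output).
const collectedStyles = [
  '.styled-r1 { color: red; }',
  '.styled-r2 { padding: 1rem 2rem; }',
];

// Plain-string stand-in for <style>{styles.join('\n')}</style>.
function renderStyleTag(styles) {
  return `<style>${styles.join('\n')}</style>`;
}

const styleTag = renderStyleTag(collectedStyles);
console.log(styleTag);
```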

However, there is a major issue.

// src/components/CountButton.js
'use client';

import React from 'react';
import styled from '../styled.js';

export default function CountButton() {
  const [count, setCount] = React.useState(0);
  return (
    <Button onClick={() => setCount(count + 1)}>
      Clicks: {count}
    </Button>
  );
}

// Currently, this doesn't work, because `cache()` can't be used in
// Client Components. It throws an error, and none of the styles get
// created.
const Button = styled('button', 'client')`
  padding: 1rem 2rem;
  color: red;
  font-size: 1rem;
`;

cache cannot be used in client components. But cache is the only piece that fails there, so if we skip it or replace its role, the rest should work.

Making It Work in Client Components Too

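Every rendered instance registers its own useServerInsertedHTML callback and emits its own <style> tag, so there may be performance implications. Be cautious when using this in practice.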

Let’s modify it to inject CSS directly without using cache.

// clientStyled.js
import React from 'react';
import { useServerInsertedHTML } from 'next/navigation';

export default function styled(Tag) {
  return (css) => {
    return function StyledComponent(props) {
      const id = React.useId().replace(/:/g, '');
      const generatedClassName = `styled-${id}`;

      const styleContent = `.${generatedClassName} { ${css} }`;
      useServerInsertedHTML(() => {
        return <style>{styleContent}</style>;
      });
      return <Tag className={generatedClassName} {...props} />;
    };
  };
}

Then I simply added another argument to branch on. (serverStyled below is the cache-based implementation from earlier, moved into its own file.) There’s probably a way to detect the origin automatically… but let’s go with this for now.

// styled.js
import React from 'react';

import serverStyled from './serverStyled';
import clientStyled from './clientStyled';


export default function styled(Tag, from = 'server') {
  if (from === 'client'){
    return clientStyled(Tag);
  }
  return serverStyled(Tag);
}

With this, passing 'client' for components rendered inside client components does work. Looking at the end result, though, it’s a bit disappointing: the extra argument is essentially just another 'use client' marker.

React.cache

https://react.dev/reference/react/cache

https://github.com/facebook/react/blob/main/packages/react/src/ReactCacheServer.js

Let’s take a brief look at how cache works internally. You pass it a function, and it returns a wrapped version with a caching layer. When the returned function is called, it checks the cache layer first and returns the cached result if available.

import ReactCurrentCache from './ReactCurrentCache';

const UNTERMINATED = 0; // Constant representing unterminated state
const TERMINATED = 1; // Constant representing terminated state
const ERRORED = 2; // Constant representing errored state

type UnterminatedCacheNode<T> = {
  s: 0, // Status (unterminated)
  v: void, // Value (no value in unterminated state)
  o: null | WeakMap<Function | Object, CacheNode<T>>, // Object cache (uses WeakMap)
  p: null | Map<string | number | null | void | symbol | boolean, CacheNode<T>>, // Primitive cache (uses Map)
};

type TerminatedCacheNode<T> = {
  s: 1, // Status (terminated)
  v: T, // Value (cached result)
  o: null | WeakMap<Function | Object, CacheNode<T>>, // Object cache
  p: null | Map<string | number | null | void | symbol | boolean, CacheNode<T>>, // Primitive cache
};

type ErroredCacheNode<T> = {
  s: 2, // Status (errored)
  v: mixed, // Value (error object)
  o: null | WeakMap<Function | Object, CacheNode<T>>, // Object cache
  p: null | Map<string | number | null | void | symbol | boolean, CacheNode<T>>, // Primitive cache
};

type CacheNode<T> =
  | TerminatedCacheNode<T>
  | UnterminatedCacheNode<T>
  | ErroredCacheNode<T>;

function createCacheRoot<T>(): WeakMap<Function | Object, CacheNode<T>> {
  return new WeakMap(); // Create a new WeakMap as the cache root
}

function createCacheNode<T>(): CacheNode<T> {
  return {
    s: UNTERMINATED, // Default state is unterminated
    v: undefined, // Initial value is undefined
    o: null, // Object cache initialized
    p: null, // Primitive cache initialized
  };
}

/*
  * Function that creates cache nodes and stores them in the cache root
  *
  * @param fn Function to create cache node for
  * @returns Cache node
  * @example
  * const styleCache = cache(() => []);
  */
export function cache<A: Iterable<mixed>, T>(fn: (...A) => T): (...A) => T {
  return function () {
    const dispatcher = ReactCurrentCache.current; // Get the current cache dispatcher
    if (!dispatcher) {
      // If no dispatcher, execute function without caching (when running in client components)
      return fn.apply(null, arguments);
    }
    // Get the cache root
    const fnMap: WeakMap<any, CacheNode<T>> = dispatcher.getCacheForType(
      createCacheRoot,
    ); // Create a new cache root if it doesn't exist


    const fnNode = fnMap.get(fn); // Get the cache node for the function
    let cacheNode: CacheNode<T>; // Cache node
    if (fnNode === undefined) {
      cacheNode = createCacheNode(); // Create a new cache node if it doesn't exist
      fnMap.set(fn, cacheNode); // Set the cache node for the function
    } else {
      cacheNode = fnNode; // Use the existing cache node
    }
    for (let i = 0, l = arguments.length; i < l; i++) { // Iterate over all arguments passed to the function
      const arg = arguments[i]; // Current argument being processed

      if (
        typeof arg === 'function' ||
        (typeof arg === 'object' && arg !== null)
      ) {
        // If the current argument is an object or function
        let objectCache = cacheNode.o; // Cache map for objects (WeakMap)
        if (objectCache === null) {
          cacheNode.o = objectCache = new WeakMap(); // Create a new object cache map if it doesn't exist
        }
        const objectNode = objectCache.get(arg); // Look up cache node for the current object argument
        if (objectNode === undefined) {
          cacheNode = createCacheNode(); // Create a new cache node if it doesn't exist
          objectCache.set(arg, cacheNode); // Add new cache node to the object cache map
        } else {
          cacheNode = objectNode; // Use the existing cache node
        }
      } else {
        // If the current argument is a primitive type
        let primitiveCache = cacheNode.p; // Cache map for primitives (Map)
        if (primitiveCache === null) {
          cacheNode.p = primitiveCache = new Map(); // Create a new primitive cache map if it doesn't exist
        }
        const primitiveNode = primitiveCache.get(arg); // Look up cache node for the current primitive argument
        if (primitiveNode === undefined) {
          cacheNode = createCacheNode(); // Create a new cache node if it doesn't exist
          primitiveCache.set(arg, cacheNode); // Add new cache node to the primitive cache map
        } else {
          cacheNode = primitiveNode; // Use the existing cache node
        }
      }
    }

    if (cacheNode.s === TERMINATED) { // If there is a cached result
      return cacheNode.v; // Return the cached result
    }
    if (cacheNode.s === ERRORED) { // If there is a cached error
      throw cacheNode.v; // Re-throw the cached error
    }
    try {
      // If no cached result, execute the function and cache the result
      const result = fn.apply(null, arguments); // Execute the function
      const terminatedNode: TerminatedCacheNode<T> = (cacheNode: any); // Change cache node to terminated state
      terminatedNode.s = TERMINATED; // Set to terminated state
      terminatedNode.v = result; // Set the cached result
      return result; // Return the result
    } catch (error) {
      // Cache the error if one occurs
      const erroredNode: ErroredCacheNode<T> = (cacheNode: any);
      erroredNode.s = ERRORED;
      erroredNode.v = error;
      throw error;
    }
  };
}
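
The argument-by-argument lookup above can be condensed into a standalone sketch that runs outside React. This is a simplified stand-in, not React’s implementation: the memoize and createNode names are my own, there is no dispatcher, errors are not cached, and results live in one store per wrapped function rather than per request.

```javascript
// Sentinel meaning "no result stored on this node yet".
const EMPTY = Symbol('empty');

function createNode() {
  return { value: EMPTY, objects: null, primitives: null };
}

// Simplified stand-in for cache(): walks one node per argument,
// using a WeakMap for objects/functions and a Map for primitives.
function memoize(fn) {
  const root = createNode();
  return function (...args) {
    let node = root;
    for (const arg of args) {
      const isObject =
        typeof arg === 'function' || (typeof arg === 'object' && arg !== null);
      let map;
      if (isObject) {
        map = node.objects ?? (node.objects = new WeakMap());
      } else {
        map = node.primitives ?? (node.primitives = new Map());
      }
      let next = map.get(arg);
      if (next === undefined) {
        next = createNode();
        map.set(arg, next);
      }
      node = next;
    }
    if (node.value === EMPTY) {
      node.value = fn(...args); // first call with these arguments
    }
    return node.value;
  };
}

let calls = 0;
const add = memoize((x, y) => {
  calls += 1;
  return x + y;
});

const first = add(2, 3);
const second = add(2, 3); // same arguments, served from the stored node
const third = add(4, 5);  // different arguments, computed again
console.log(first, second, third, calls); // 5 5 9 2
```

Each argument selects a child node, so the final node is unique to that exact sequence of arguments, and objects are held weakly so they can still be garbage collected.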

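Here’s a simple example to illustrate. (As the source above shows, memoization only happens when a cache dispatcher is active, so assume these calls run during a server render.)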

import { cache } from 'react';

// An example computation function
function expensiveCalculation(x, y) {
  console.log('Calculating result...');
  return x + y;
}

// Wrap the computation function using cache
const cachedCalculation = cache(expensiveCalculation);

// First call - the computation function is executed
const result1 = cachedCalculation(2, 3); // Logs: "Calculating result..."

// Second call - same arguments, so the cached result is returned
const result2 = cachedCalculation(2, 3); // Result is 5, no log output

// Call with different arguments - the computation function runs again
const result3 = cachedCalculation(4, 5); // Logs: "Calculating result..."

console.log(result1); // 5
console.log(result2); // 5
console.log(result3); // 9

With this example, the role of cache is immediately clear. However, the usage in dream-css-tool is slightly different. Rather than saving computation on each call, it uses cache as a way to maintain a shared reference: the wrapped function takes no arguments, so every call returns the same array.

// StyleRegistry.js
export const cache = React.cache(() => {
  return [];
});
// ...
function StyleRegistry({ children }) {
  const collectedStyles = cache();
...
// styled.js
import { cache } from './components/StyleRegistry';

...
let collectedStyles =  cache();

collectedStyles.push(styleContent);
...

Simply put, it works like this:

const arr1 = [1, 2, 3];

const arr2 = arr1;

// Check if arr1 and arr2 reference the same array
console.log(arr1 === arr2); // true

// Adding an element to arr1 is also reflected in arr2
arr1.push(4);
console.log(arr2); // [1, 2, 3, 4]
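
The same idea can be modeled a little closer to how dream-css-tool uses it. React resets memoized results for each server request, so the shared array is effectively scoped to one request. In the sketch below, createRequestScope is a hypothetical stand-in for that per-request scoping, not a React API.

```javascript
// Hypothetical model: each "request" gets its own lazily created array,
// similar to how cache(() => []) yields one array per server request.
function createRequestScope() {
  let styles; // created on the first call, then reused
  return function getStyles() {
    if (styles === undefined) {
      styles = [];
    }
    return styles;
  };
}

const requestA = createRequestScope();
const fromRegistry = requestA(); // what StyleRegistry would hold
const fromStyled = requestA();   // what a styled component would hold
fromStyled.push('.styled-r1 { color: red; }');

console.log(fromRegistry === fromStyled); // true
console.log(fromRegistry.length);         // 1

// A separate request starts with its own empty array.
const requestB = createRequestScope();
console.log(requestB().length);           // 0
```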

Conclusion

Work in progress.