incremental bindings index

pull/3685/head
David Sheldrick 2024-05-02 16:56:55 +01:00
rodzic da35f2bd75
commit 004945611c
3 zmienionych plików z 418 dodań i 65 usunięć

Wyświetl plik

@ -120,6 +120,7 @@ import { getReorderingShapesChanges } from '../utils/reorderShapes'
import { applyRotationToSnapshotShapes, getRotationSnapshot } from '../utils/rotation'
import { uniqueId } from '../utils/uniqueId'
import { BindingUtil, TLBindingUtilConstructor } from './bindings/BindingUtil'
import { bindingsIndex } from './derivations/bindingsIndex'
import { notVisibleShapes } from './derivations/notVisibleShapes'
import { parentsToChildren } from './derivations/parentsToChildren'
import { deriveShapeIdsInCurrentPage } from './derivations/shapeIdsInCurrentPage'
@ -379,19 +380,21 @@ export class Editor extends EventEmitter<TLEventMap> {
this.sideEffects.register({
shape: {
afterChange: (shapeBefore, shapeAfter) => {
for (const binding of this.getAllBindingsFromShape(shapeAfter)) {
this.getBindingUtil(binding).onAfterChangeFromShape?.({
binding,
shapeBefore,
shapeAfter,
})
}
for (const binding of this.getAllBindingsToShape(shapeAfter)) {
this.getBindingUtil(binding).onAfterChangeToShape?.({
binding,
shapeBefore,
shapeAfter,
})
for (const binding of this.getBindingsInvolvingShape(shapeAfter)) {
if (binding.fromId === shapeAfter.id) {
this.getBindingUtil(binding).onAfterChangeFromShape?.({
binding,
shapeBefore,
shapeAfter,
})
}
if (binding.toId === shapeAfter.id) {
this.getBindingUtil(binding).onAfterChangeToShape?.({
binding,
shapeBefore,
shapeAfter,
})
}
}
// if the shape's parent changed and it has a binding, update the binding
@ -400,19 +403,21 @@ export class Editor extends EventEmitter<TLEventMap> {
const descendantShape = this.getShape(id)
if (!descendantShape) return
for (const binding of this.getAllBindingsFromShape(descendantShape)) {
this.getBindingUtil(binding).onAfterChangeFromShape?.({
binding,
shapeBefore: descendantShape,
shapeAfter: descendantShape,
})
}
for (const binding of this.getAllBindingsToShape(descendantShape)) {
this.getBindingUtil(binding).onAfterChangeToShape?.({
binding,
shapeBefore: descendantShape,
shapeAfter: descendantShape,
})
for (const binding of this.getBindingsInvolvingShape(descendantShape)) {
if (binding.fromId === descendantShape.id) {
this.getBindingUtil(binding).onAfterChangeFromShape?.({
binding,
shapeBefore: descendantShape,
shapeAfter: descendantShape,
})
}
if (binding.toId === descendantShape.id) {
this.getBindingUtil(binding).onAfterChangeToShape?.({
binding,
shapeBefore: descendantShape,
shapeAfter: descendantShape,
})
}
}
}
notifyBindingAncestryChange(shapeAfter.id)
@ -451,13 +456,15 @@ export class Editor extends EventEmitter<TLEventMap> {
}
const deleteBindingIds: TLBindingId[] = []
for (const binding of this.getAllBindingsFromShape(shape)) {
this.getBindingUtil(binding).onBeforeDeleteFromShape?.({ binding, shape })
deleteBindingIds.push(binding.id)
}
for (const binding of this.getAllBindingsToShape(shape)) {
this.getBindingUtil(binding).onBeforeDeleteToShape?.({ binding, shape })
deleteBindingIds.push(binding.id)
for (const binding of this.getBindingsInvolvingShape(shape)) {
if (binding.fromId === shape.id) {
this.getBindingUtil(binding).onBeforeDeleteFromShape?.({ binding, shape })
deleteBindingIds.push(binding.id)
}
if (binding.toId === shape.id) {
this.getBindingUtil(binding).onBeforeDeleteToShape?.({ binding, shape })
deleteBindingIds.push(binding.id)
}
}
this.deleteBindings(deleteBindingIds)
@ -5032,6 +5039,15 @@ export class Editor extends EventEmitter<TLEventMap> {
/* -------------------- Bindings -------------------- */
@computed
private _getBindingsIndex() {
return bindingsIndex(this)
}
private getBindingsIndex() {
return this._getBindingsIndex().get()
}
getBinding(id: TLBindingId): TLBinding | undefined {
return this.store.get(id) as TLBinding | undefined
}
@ -5039,35 +5055,26 @@ export class Editor extends EventEmitter<TLEventMap> {
// TODO(alex) #bindings - cache `allBindings` getters and derive type-specific ones from them
getBindingsFromShape<Binding extends TLUnknownBinding = TLBinding>(
shape: TLShape | TLShapeId,
type: Binding['type']
type?: Binding['type']
): Binding[] {
const id = typeof shape === 'string' ? shape : shape.id
return this.store.query.exec('binding', {
fromId: { eq: id },
type: { eq: type },
}) as Binding[]
return this.getBindingsInvolvingShape(id, type).filter((b) => b.fromId === id) as Binding[]
}
getBindingsToShape<Binding extends TLUnknownBinding = TLBinding>(
shape: TLShape | TLShapeId,
type: Binding['type']
type?: Binding['type']
): Binding[] {
const id = typeof shape === 'string' ? shape : shape.id
return this.store.query.exec('binding', {
toId: { eq: id },
type: { eq: type },
}) as Binding[]
return this.getBindingsInvolvingShape(id, type).filter((b) => b.toId === id) as Binding[]
}
getAllBindingsFromShape(shape: TLShape | TLShapeId): TLBinding[] {
getBindingsInvolvingShape<Binding extends TLUnknownBinding = TLBinding>(
shape: TLShape | TLShapeId,
type?: Binding['type']
): Binding[] {
const id = typeof shape === 'string' ? shape : shape.id
return this.store.query.exec('binding', {
fromId: { eq: id },
})
}
getAllBindingsToShape(shape: TLShape | TLShapeId): TLBinding[] {
const id = typeof shape === 'string' ? shape : shape.id
return this.store.query.exec('binding', {
toId: { eq: id },
})
const result = this.getBindingsIndex()[id] ?? EMPTY_ARRAY
if (!type) return result as Binding[]
return result.filter((b) => b.type === type) as Binding[]
}
createBindings(partials: RequiredKeys<TLBindingPartial, 'type' | 'toId' | 'fromId'>[]) {
@ -8744,22 +8751,19 @@ function withoutBindingsToUnrelatedShapes<T>(
const shape = editor.getShape(shapeId)
if (!shape) continue
for (const binding of editor.getAllBindingsFromShape(shapeId)) {
if (shapeIds.has(binding.toId)) {
// if we have both sides of the binding, we want to recreate it
for (const binding of editor.getBindingsInvolvingShape(shapeId)) {
const hasFrom = shapeIds.has(binding.fromId)
const hasTo = shapeIds.has(binding.toId)
if (hasFrom && hasTo) {
bindingsWithBoth.add(binding.id)
} else {
// otherwise, if we only have one side, we need to record that and duplicate
// the shape as if the one it's bound to has been deleted
bindingsWithoutTo.add(binding.id)
continue
}
}
for (const binding of editor.getAllBindingsToShape(shapeId)) {
if (shapeIds.has(binding.fromId)) {
bindingsWithBoth.add(binding.id)
} else {
if (!hasFrom) {
bindingsWithoutFrom.add(binding.id)
}
if (!hasTo) {
bindingsWithoutTo.add(binding.id)
}
}
}

Wyświetl plik

@ -0,0 +1,89 @@
import { Computed, RESET_VALUE, computed, isUninitialized } from '@tldraw/state'
import { TLBinding, TLShapeId } from '@tldraw/tlschema'
import { objectMapValues } from '@tldraw/utils'
import { Editor } from '../Editor'
type TLBindingsIndex = Record<TLShapeId, undefined | TLBinding[]>
export const bindingsIndex = (editor: Editor): Computed<TLBindingsIndex> => {
const { store } = editor
const bindingsHistory = store.query.filterHistory('binding')
const bindingsQuery = store.query.records('binding')
function fromScratch() {
const allBindings = bindingsQuery.get() as TLBinding[]
const shape2Binding: TLBindingsIndex = {}
for (const binding of allBindings) {
const { fromId, toId } = binding
const bindingsForFromShape = (shape2Binding[fromId] ??= [])
bindingsForFromShape.push(binding)
const bindingsForToShape = (shape2Binding[toId] ??= [])
bindingsForToShape.push(binding)
}
return shape2Binding
}
return computed<TLBindingsIndex>('arrowBindingsIndex', (_lastValue, lastComputedEpoch) => {
if (isUninitialized(_lastValue)) {
return fromScratch()
}
const lastValue = _lastValue
const diff = bindingsHistory.getDiffSince(lastComputedEpoch)
if (diff === RESET_VALUE) {
return fromScratch()
}
let nextValue: TLBindingsIndex | undefined = undefined
function removingBinding(binding: TLBinding) {
nextValue ??= { ...lastValue }
nextValue[binding.fromId] = nextValue[binding.fromId]?.filter((b) => b.id !== binding.id)
if (!nextValue[binding.fromId]?.length) {
delete nextValue[binding.fromId]
}
nextValue[binding.toId] = nextValue[binding.toId]?.filter((b) => b.id !== binding.id)
if (!nextValue[binding.toId]?.length) {
delete nextValue[binding.toId]
}
}
function ensureNewArray(shapeId: TLShapeId) {
nextValue ??= { ...lastValue }
if (!nextValue[shapeId]) {
nextValue[shapeId] = []
} else if (nextValue[shapeId] === lastValue[shapeId]) {
nextValue[shapeId] = nextValue[shapeId]!.slice(0)
}
}
function addBinding(binding: TLBinding) {
ensureNewArray(binding.fromId)
ensureNewArray(binding.toId)
nextValue![binding.fromId]!.push(binding)
nextValue![binding.toId]!.push(binding)
}
for (const changes of diff) {
for (const newBinding of objectMapValues(changes.added)) {
addBinding(newBinding)
}
for (const [prev, next] of objectMapValues(changes.updated)) {
removingBinding(prev)
addBinding(next)
}
for (const prev of objectMapValues(changes.removed)) {
removingBinding(prev)
}
}
// TODO: add diff entries if we need them
return nextValue ?? lastValue
})
}

Wyświetl plik

@ -0,0 +1,260 @@
import { TLArrowBinding, TLGeoShape, TLShapeId, createShapeId } from '@tldraw/editor'
import { TestEditor } from './TestEditor'
import { TL } from './test-jsx'
let editor: TestEditor
beforeEach(() => {
editor = new TestEditor()
})
describe('bindingsIndex', () => {
it('keeps a mapping from bound shapes to their bindings', () => {
const ids = editor.createShapesFromJsx([
<TL.geo ref="box1" x={0} y={0} w={100} h={100} fill="solid" />,
<TL.geo ref="box2" x={200} y={0} w={100} h={100} fill="solid" />,
])
editor.selectNone()
editor.setCurrentTool('arrow')
editor.pointerDown(50, 50)
expect(editor.getOnlySelectedShape()).toBe(null)
expect(editor.getArrowsBoundTo(ids.box1)).toEqual([])
editor.pointerMove(50, 55)
expect(editor.getOnlySelectedShape()).not.toBe(null)
const arrow = editor.getOnlySelectedShape()!
expect(arrow.type).toBe('arrow')
expect(editor.getArrowsBoundTo(ids.box1)).toEqual([arrow])
editor.pointerMove(250, 50)
expect(editor.getArrowsBoundTo(ids.box1)).toEqual([editor.getShape(arrow.id)])
expect(editor.getArrowsBoundTo(ids.box2)).toEqual([editor.getShape(arrow.id)])
})
it('works if there are many arrows', () => {
const ids = {
box1: createShapeId('box1'),
box2: createShapeId('box2'),
}
editor.createShapes([
{ type: 'geo', id: ids.box1, x: 0, y: 0, props: { w: 100, h: 100 } },
{ type: 'geo', id: ids.box2, x: 200, y: 0, props: { w: 100, h: 100 } },
])
editor.setCurrentTool('arrow')
// start at box 1 and end on box 2
editor.pointerDown(50, 50)
expect(editor.getArrowsBoundTo(ids.box1)).toEqual([])
editor.pointerMove(250, 50)
const arrow1 = editor.getOnlySelectedShape()!
expect(arrow1.type).toBe('arrow')
expect(editor.getArrowsBoundTo(ids.box1)).toEqual([arrow1])
expect(editor.getArrowsBoundTo(ids.box2)).toEqual([arrow1])
editor.pointerUp()
expect(editor.getArrowsBoundTo(ids.box1)).toEqual([arrow1])
expect(editor.getArrowsBoundTo(ids.box2)).toEqual([arrow1])
// start at box 1 and end on the page
editor.setCurrentTool('arrow')
editor.pointerMove(50, 50).pointerDown().pointerMove(50, -50).pointerUp()
const arrow2 = editor.getOnlySelectedShape()!
expect(arrow2.type).toBe('arrow')
expect(editor.getArrowsBoundTo(ids.box1)).toEqual([arrow1, arrow2])
// start outside box 1 and end in box 1
editor.setCurrentTool('arrow')
editor.pointerDown(0, -50).pointerMove(50, 50).pointerUp(50, 50)
const arrow3 = editor.getOnlySelectedShape()!
expect(arrow3.type).toBe('arrow')
expect(editor.getArrowsBoundTo(ids.box1)).toEqual([arrow1, arrow2, arrow3])
expect(editor.getArrowsBoundTo(ids.box2)).toEqual([arrow1])
// start at box 2 and end on the page
editor.selectNone()
editor.setCurrentTool('arrow')
editor.pointerDown(250, 50)
editor.expectToBeIn('arrow.pointing')
editor.pointerMove(250, -50)
editor.expectToBeIn('select.dragging_handle')
const arrow4 = editor.getOnlySelectedShape()!
expect(editor.getArrowsBoundTo(ids.box2)).toEqual([arrow1, arrow4])
editor.pointerUp(250, -50)
editor.expectToBeIn('select.idle')
expect(arrow4.type).toBe('arrow')
expect(editor.getArrowsBoundTo(ids.box2)).toEqual([arrow1, arrow4])
// start outside box 2 and enter in box 2
editor.setCurrentTool('arrow')
editor.pointerDown(250, -50).pointerMove(250, 50).pointerUp(250, 50)
const arrow5 = editor.getOnlySelectedShape()!
expect(arrow5.type).toBe('arrow')
expect(editor.getArrowsBoundTo(ids.box1)).toEqual([arrow1, arrow2, arrow3])
expect(editor.getArrowsBoundTo(ids.box2)).toEqual([arrow1, arrow4, arrow5])
})
describe('updating shapes', () => {
// ▲ │ │ ▲
// │ │ │ │
// b c e d
// ┌───┼─┴─┐ ┌──┴──┼─┐
// │ │ ▼ │ │ ▼ │ │
// │ └───┼─────a───┼───► │ │
// │ 1 │ │ 2 │
// └───────┘ └───────┘
let arrowAId: TLShapeId
let arrowBId: TLShapeId
let arrowCId: TLShapeId
let arrowDId: TLShapeId
let arrowEId: TLShapeId
let ids: Record<string, TLShapeId>
beforeEach(() => {
ids = editor.createShapesFromJsx([
<TL.geo ref="box1" x={0} y={0} w={100} h={100} />,
<TL.geo ref="box2" x={200} y={0} w={100} h={100} />,
])
// span both boxes
editor.setCurrentTool('arrow')
editor.pointerDown(50, 50).pointerMove(250, 50).pointerUp(250, 50)
arrowAId = editor.getOnlySelectedShape()!.id
// start at box 1 and leave
editor.setCurrentTool('arrow')
editor.pointerDown(50, 50).pointerMove(50, -50).pointerUp(50, -50)
arrowBId = editor.getOnlySelectedShape()!.id
// start outside box 1 and enter
editor.setCurrentTool('arrow')
editor.pointerDown(50, -50).pointerMove(50, 50).pointerUp(50, 50)
arrowCId = editor.getOnlySelectedShape()!.id
// start at box 2 and leave
editor.setCurrentTool('arrow')
editor.pointerDown(250, 50).pointerMove(250, -50).pointerUp(250, -50)
arrowDId = editor.getOnlySelectedShape()!.id
// start outside box 2 and enter
editor.setCurrentTool('arrow')
editor.pointerDown(250, -50).pointerMove(250, 50).pointerUp(250, 50)
arrowEId = editor.getOnlySelectedShape()!.id
})
it('deletes the entry if you delete the bound shapes', () => {
expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(3)
editor.deleteShapes([ids.box2])
expect(editor.getArrowsBoundTo(ids.box2)).toEqual([])
expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(3)
})
it('deletes the entry if you delete an arrow', () => {
expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(3)
editor.deleteShapes([arrowEId])
expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(2)
expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(3)
editor.deleteShapes([arrowDId])
expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(1)
expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(3)
editor.deleteShapes([arrowCId])
expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(1)
expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(2)
editor.deleteShapes([arrowBId])
expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(1)
expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(1)
editor.deleteShapes([arrowAId])
expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(0)
expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(0)
})
it('deletes the entries in a batch too', () => {
editor.deleteShapes([arrowAId, arrowBId, arrowCId, arrowDId, arrowEId])
expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(0)
expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(0)
})
it('adds new entries after initial creation', () => {
expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(3)
expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(3)
// draw from box 2 to box 1
editor.setCurrentTool('arrow')
editor.pointerDown(250, 50).pointerMove(50, 50).pointerUp(50, 50)
expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(4)
expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(4)
// create a new box
const { box3 } = editor.createShapesFromJsx(
<TL.geo ref="box3" x={400} y={0} w={100} h={100} />
)
// draw from box 2 to box 3
editor.setCurrentTool('arrow')
editor.pointerDown(250, 50).pointerMove(450, 50).pointerUp(450, 50)
expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(5)
expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(4)
expect(editor.getArrowsBoundTo(box3)).toHaveLength(1)
})
it('works when copy pasting', () => {
expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(3)
expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(3)
editor.selectAll()
editor.duplicateShapes(editor.getSelectedShapeIds())
const [box1Clone, box2Clone] = editor
.getSelectedShapes()
.filter((shape) => editor.isShapeOfType<TLGeoShape>(shape, 'geo'))
.sort((a, b) => a.x - b.x)
expect(editor.getArrowsBoundTo(box2Clone.id)).toHaveLength(3)
expect(editor.getArrowsBoundTo(box1Clone.id)).toHaveLength(3)
})
it('allows bound shapes to be moved', () => {
expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(3)
expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(3)
editor.nudgeShapes([ids.box2], { x: 0, y: -1 })
expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(3)
expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(3)
})
it('allows the arrows bound shape to change', () => {
expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(3)
expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(3)
// create another box
const { box3 } = editor.createShapesFromJsx(
<TL.geo ref="box3" x={400} y={0} w={100} h={100} />
)
// move arrowA end from box2 to box3
const binding = editor
.getBindingsInvolvingShape<TLArrowBinding>(ids.box2, 'arrow')
.find((b) => b.props.terminal === 'end')!
editor.updateBinding({ ...binding, toId: box3 } satisfies TLArrowBinding)
expect(editor.getArrowsBoundTo(ids.box2)).toHaveLength(2)
expect(editor.getArrowsBoundTo(ids.box1)).toHaveLength(3)
expect(editor.getArrowsBoundTo(box3)).toHaveLength(1)
})
})
})