diff --git a/Memola/Canvas/Geometries/Stroke/Stroke.swift b/Memola/Canvas/Geometries/Stroke/Stroke.swift index 32081b2..8018da1 100644 --- a/Memola/Canvas/Geometries/Stroke/Stroke.swift +++ b/Memola/Canvas/Geometries/Stroke/Stroke.swift @@ -56,8 +56,9 @@ final class Stroke: @unchecked Sendable { let movingAverage = MovingAverage(windowSize: 3) - var vertexBuffer: MTLBuffer? var texture: MTLTexture? + var indexBuffer: MTLBuffer? + var vertexBuffer: MTLBuffer? var isEmpty: Bool { quads.isEmpty @@ -160,12 +161,19 @@ extension Stroke: Drawable { } func draw(device: MTLDevice, renderEncoder: MTLRenderCommandEncoder) { - guard !isEmpty else { return } + guard !isEmpty, let indexBuffer else { return } prepare(device: device) renderEncoder.setFragmentTexture(texture, index: 0) renderEncoder.setVertexBuffer(vertexBuffer, offset: 0, index: 0) - renderEncoder.drawPrimitives(type: .triangle, vertexStart: 0, vertexCount: quads.endIndex * 6) - vertexBuffer = nil + renderEncoder.drawIndexedPrimitives( + type: .triangle, + indexCount: quads.endIndex * 6, + indexType: .uint32, + indexBuffer: indexBuffer, + indexBufferOffset: 0 + ) + self.vertexBuffer = nil + self.indexBuffer = nil } } diff --git a/Memola/Canvas/RenderPasses/EraserRenderPass.swift b/Memola/Canvas/RenderPasses/EraserRenderPass.swift index 46bbe54..32f115d 100644 --- a/Memola/Canvas/RenderPasses/EraserRenderPass.swift +++ b/Memola/Canvas/RenderPasses/EraserRenderPass.swift @@ -57,12 +57,15 @@ class EraserRenderPass: RenderPass { let quadCount = stroke.quads.endIndex var quads = stroke.quads let quadBuffer = renderer.device.makeBuffer(bytes: &quads, length: MemoryLayout.stride * quadCount, options: []) - let vertexBuffer = renderer.device.makeBuffer(length: MemoryLayout.stride * quadCount * 6, options: []) + let indexBuffer = renderer.device.makeBuffer(length: MemoryLayout.stride * quadCount * 6, options: []) + let vertexBuffer = renderer.device.makeBuffer(length: MemoryLayout.stride * quadCount * 4, options: []) computeEncoder.setComputePipelineState(quadPipelineState) computeEncoder.setBuffer(quadBuffer, offset: 0, index: 0) - computeEncoder.setBuffer(vertexBuffer, offset: 0, index: 1) + computeEncoder.setBuffer(indexBuffer, offset: 0, index: 1) + computeEncoder.setBuffer(vertexBuffer, offset: 0, index: 2) + stroke.indexBuffer = indexBuffer stroke.vertexBuffer = vertexBuffer let threadsPerGroup = MTLSize(width: 1, height: 1, depth: 1) diff --git a/Memola/Canvas/RenderPasses/StrokeRenderPass.swift b/Memola/Canvas/RenderPasses/StrokeRenderPass.swift index c268735..89945af 100644 --- a/Memola/Canvas/RenderPasses/StrokeRenderPass.swift +++ b/Memola/Canvas/RenderPasses/StrokeRenderPass.swift @@ -70,12 +70,15 @@ class StrokeRenderPass: RenderPass { let quadCount = stroke.quads.endIndex var quads = stroke.quads let quadBuffer = renderer.device.makeBuffer(bytes: &quads, length: MemoryLayout.stride * quadCount, options: []) - let vertexBuffer = renderer.device.makeBuffer(length: MemoryLayout.stride * quadCount * 6, options: []) + let indexBuffer = renderer.device.makeBuffer(length: MemoryLayout.stride * quadCount * 6, options: []) + let vertexBuffer = renderer.device.makeBuffer(length: MemoryLayout.stride * quadCount * 4, options: []) computeEncoder.setComputePipelineState(quadPipelineState) computeEncoder.setBuffer(quadBuffer, offset: 0, index: 0) - computeEncoder.setBuffer(vertexBuffer, offset: 0, index: 1) + computeEncoder.setBuffer(indexBuffer, offset: 0, index: 1) + computeEncoder.setBuffer(vertexBuffer, offset: 0, index: 2) + stroke.indexBuffer = indexBuffer stroke.vertexBuffer = vertexBuffer let threadsPerGroup = MTLSize(width: 1, height: 1, depth: 1) diff --git a/Memola/Canvas/Shaders/Quad.metal b/Memola/Canvas/Shaders/Quad.metal index 8bd9cc8..898c3d0 100644 --- a/Memola/Canvas/Shaders/Quad.metal +++ b/Memola/Canvas/Shaders/Quad.metal @@ -39,16 +39,23 @@ Vertex createVertex(Quad quad, float2 factor, float2 textCoord) { kernel void generate_stroke_vertices( device Quad *quads [[buffer(0)]], - device Vertex *vertices [[buffer(1)]], + device uint *indices [[buffer(1)]], + device Vertex *vertices [[buffer(2)]], uint gid [[thread_position_in_grid]] ) { - uint index = gid * 6; Quad quad = quads[gid]; + uint index = gid * 6; + uint vertexIndex = gid * 4; float halfSize = quad.size * 0.5; - vertices[index] = createVertex(quad, float2(-halfSize, -halfSize), float2(0, 0)); - vertices[index + 1] = createVertex(quad, float2(halfSize, -halfSize), float2(1, 0)); - vertices[index + 2] = createVertex(quad, float2(-halfSize, halfSize), float2(0, 1)); - vertices[index + 3] = createVertex(quad, float2(halfSize, -halfSize), float2(1, 0)); - vertices[index + 4] = createVertex(quad, float2(-halfSize, halfSize), float2(0, 1)); - vertices[index + 5] = createVertex(quad, float2(halfSize, halfSize), float2(1, 1)); + vertices[vertexIndex] = createVertex(quad, float2(-halfSize, -halfSize), float2(0, 0)); + vertices[vertexIndex + 1] = createVertex(quad, float2(halfSize, -halfSize), float2(1, 0)); + vertices[vertexIndex + 2] = createVertex(quad, float2(-halfSize, halfSize), float2(0, 1)); + vertices[vertexIndex + 3] = createVertex(quad, float2(halfSize, halfSize), float2(1, 1)); + + indices[index] = vertexIndex; + indices[index + 1] = vertexIndex + 1; + indices[index + 2] = vertexIndex + 2; + indices[index + 3] = vertexIndex + 1; + indices[index + 4] = vertexIndex + 2; + indices[index + 5] = vertexIndex + 3; }