@@ -112,19 +112,21 @@ public void accept(int messageCount, ByteBuffer buffer) {
if (response.success()) {
if (log.isDebugEnabled()) {
log.debug(
"Successfully sent {} traces {} bytes to the API {}",
"Successfully sent {} traces of size {} bytes to the API {}",
messageCount,
buffer.position(),
sizeInBytes,
mapper.endpoint());
}
healthMetrics.onSend(messageCount, sizeInBytes, response);
} else {
if (log.isDebugEnabled()) {
log.debug(
"Failed to send {} traces of size {} bytes to the API {}",
"Failed to send {} traces of size {} bytes to the API {} status {} response {}",
messageCount,
sizeInBytes,
mapper.endpoint());
mapper.endpoint(),
response.status(),
response.response());
}
healthMetrics.onFailedSend(messageCount, sizeInBytes, response);
}
@@ -2,7 +2,9 @@

import static datadog.communication.http.OkHttpUtils.gzippedMsgpackRequestBodyOf;

import datadog.communication.serialization.GrowableBuffer;
import datadog.communication.serialization.Writable;
import datadog.communication.serialization.msgpack.MsgPackWriter;
import datadog.trace.api.DDTags;
import datadog.trace.api.intake.TrackType;
import datadog.trace.api.llmobs.LLMObs;
@@ -77,15 +79,30 @@ public class LLMObsSpanMapper implements RemoteMapper {

private static final String PARENT_ID_TAG_INTERNAL_FULL = LLMOBS_TAG_PREFIX + "parent_id";

private final LLMObsSpanMapper.MetaWriter metaWriter = new MetaWriter();
private final MetaWriter metaWriter = new MetaWriter();
private final int size;

private final ByteBuffer header;
private int spansWritten;

public LLMObsSpanMapper() {
this(5 << 20);
}

private LLMObsSpanMapper(int size) {
this.size = size;

GrowableBuffer header = new GrowableBuffer(64);
MsgPackWriter headerWriter = new MsgPackWriter(header);

headerWriter.startMap(3);
headerWriter.writeUTF8(EVENT_TYPE);
headerWriter.writeString("span", null);
headerWriter.writeUTF8(STAGE);
headerWriter.writeString("raw", null);
headerWriter.writeUTF8(SPANS);

this.header = header.slice();
}

@Override
@@ -98,16 +115,6 @@ public void map(List<? extends CoreSpan<?>> trace, Writable writable) {
return;
}

writable.startMap(3);

writable.writeUTF8(EVENT_TYPE);
writable.writeString("span", null);

writable.writeUTF8(STAGE);
writable.writeString("raw", null);

writable.writeUTF8(SPANS);
writable.startArray(llmobsSpans.size());
for (CoreSpan<?> span : llmobsSpans) {
writable.startMap(11);
// 1
@@ -148,6 +155,10 @@ public void map(List<? extends CoreSpan<?>> trace, Writable writable) {
/* 9 (metrics), 10 (tags), 11 meta */
span.processTagsAndBaggage(metaWriter.withWritable(writable, getErrorsMap(span)));
}

// Increase only after all spans have been written. This way, if it rolls back because of a
// buffer overflow, the counter won't be skewed.
spansWritten += llmobsSpans.size();
}

private static boolean isLLMObsSpan(CoreSpan<?> span) {
@@ -157,7 +168,7 @@ private static boolean isLLMObsSpan(CoreSpan<?> span) {

@Override
public Payload newPayload() {
return new PayloadV1();
return new PayloadV1(header, spansWritten);
}

@Override
@@ -166,7 +177,10 @@ public int messageBufferSize() {
}

@Override
public void reset() {}
public void reset() {
// Reset the number of spans per message with each flush.
spansWritten = 0;
}

@Override
public String endpoint() {
@@ -206,7 +220,7 @@ private static final class MetaWriter implements MetadataConsumer {
LLMOBS_TAG_PREFIX + LLMObsTags.MODEL_VERSION,
LLMOBS_TAG_PREFIX + LLMObsTags.METADATA)));

LLMObsSpanMapper.MetaWriter withWritable(Writable writable, Map<String, String> errorInfo) {
MetaWriter withWritable(Writable writable, Map<String, String> errorInfo) {
this.writable = writable;
this.errorInfo = errorInfo;
return this;
@@ -348,14 +362,20 @@ public void accept(Metadata metadata) {
}

private static class PayloadV1 extends Payload {
private final ByteBuffer header;
private final int spansWritten;

public PayloadV1(ByteBuffer header, int spansWritten) {
this.spansWritten = spansWritten;
this.header = header;
}

@Override
public int sizeInBytes() {
if (traceCount() == 0) {
return msgpackMapHeaderSize(0);
}

return body.array().length;
return header.remaining() + msgpackArrayHeaderSize(spansWritten) + body.remaining();
}

@Override
Expand All @@ -368,6 +388,8 @@ public void writeTo(WritableByteChannel channel) throws IOException {
}
} else {
while (body.hasRemaining()) {
channel.write(header.slice());
channel.write(msgpackArrayHeader(spansWritten));
channel.write(body);
}
}
@@ -379,9 +401,13 @@ public RequestBody toRequest() {
if (traceCount() == 0) {
buffers = Collections.singletonList(msgpackMapHeader(0));
} else {
buffers = Collections.singletonList(body);
buffers =
Arrays.asList(
header.slice(),
// Third value: the spans array, serialized into the body
msgpackArrayHeader(spansWritten),
body);
}

return gzippedMsgpackRequestBodyOf(buffers);
}
}
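Note on the size computation above: `sizeInBytes()` now sums the pre-built header, a MessagePack array header sized for `spansWritten`, and the remaining body bytes. The sketch below is a minimal illustration (an assumed helper, not the codebase's actual utility) of how that array-header size follows from the MessagePack format spec.

```java
// Minimal sketch of MessagePack array-header sizing, which is what
// msgpackArrayHeaderSize(spansWritten) accounts for in PayloadV1.sizeInBytes().
// Byte counts follow the MessagePack spec's fixarray/array16/array32 encodings.
final class MsgpackArrayHeaderSize {
  static int of(int elementCount) {
    if (elementCount < 16) {
      return 1; // fixarray: a single byte, 0x90 | count
    }
    if (elementCount < 65536) {
      return 3; // array16: 0xdc followed by a 2-byte big-endian length
    }
    return 5; // array32: 0xdd followed by a 4-byte big-endian length
  }
}
```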
@@ -58,6 +58,10 @@ class LLMObsSpanMapperTest extends DDCoreSpecification {
sink.captured != null
def payload = mapper.newPayload()
payload.withBody(1, sink.captured)

// Capture the size before the buffer is written and the body buffer is emptied.
def sizeInBytes = payload.sizeInBytes()

def channel = new ByteArrayOutputStream()
payload.writeTo(new WritableByteChannel() {
@Override
@@ -76,7 +80,10 @@ class LLMObsSpanMapperTest extends DDCoreSpecification {
@Override
void close() throws IOException { }
})
def result = objectMapper.readValue(channel.toByteArray(), Map)

def bytesWritten = channel.toByteArray()
sizeInBytes == bytesWritten.length
def result = objectMapper.readValue(bytesWritten, Map)

then:
result.containsKey("event_type")
@@ -140,6 +147,93 @@ class LLMObsSpanMapperTest extends DDCoreSpecification {
sink.captured == null
}

def "test consecutive packer.format calls accumulate spans from multiple traces"() {
setup:
def mapper = new LLMObsSpanMapper()
def tracer = tracerBuilder().writer(new ListWriter()).build()

// First trace with 2 LLMObs spans
def llmSpan1 = tracer.buildSpan("chat-completion-1")
.withTag("_ml_obs_tag.span.kind", Tags.LLMOBS_LLM_SPAN_KIND)
.withTag("_ml_obs_tag.model_name", "gpt-4")
.withTag("_ml_obs_tag.model_provider", "openai")
.start()
llmSpan1.setSpanType(InternalSpanTypes.LLMOBS)
llmSpan1.finish()

def llmSpan2 = tracer.buildSpan("chat-completion-2")
.withTag("_ml_obs_tag.span.kind", Tags.LLMOBS_LLM_SPAN_KIND)
.withTag("_ml_obs_tag.model_name", "gpt-3.5")
.withTag("_ml_obs_tag.model_provider", "openai")
.start()
llmSpan2.setSpanType(InternalSpanTypes.LLMOBS)
llmSpan2.finish()

// Second trace with 1 LLMObs span
def llmSpan3 = tracer.buildSpan("chat-completion-3")
.withTag("_ml_obs_tag.span.kind", Tags.LLMOBS_LLM_SPAN_KIND)
.withTag("_ml_obs_tag.model_name", "claude-3")
.withTag("_ml_obs_tag.model_provider", "anthropic")
.start()
llmSpan3.setSpanType(InternalSpanTypes.LLMOBS)
llmSpan3.finish()

def trace1 = [llmSpan1, llmSpan2]
def trace2 = [llmSpan3]
CapturingByteBufferConsumer sink = new CapturingByteBufferConsumer()
MsgPackWriter packer = new MsgPackWriter(new FlushingBuffer(1024, sink))

when:
packer.format(trace1, mapper)
packer.format(trace2, mapper)
packer.flush()

then:
sink.captured != null
def payload = mapper.newPayload()
payload.withBody(3, sink.captured)

// Capture the size before the buffer is written and the body buffer is emptied.
def sizeInBytes = payload.sizeInBytes()

def channel = new ByteArrayOutputStream()
payload.writeTo(new WritableByteChannel() {
@Override
int write(ByteBuffer src) throws IOException {
def bytes = new byte[src.remaining()]
src.get(bytes)
channel.write(bytes)
return bytes.length
}

@Override
boolean isOpen() {
return true
}

@Override
void close() throws IOException { }
})

def bytesWritten = channel.toByteArray()
sizeInBytes == bytesWritten.length
def result = objectMapper.readValue(bytesWritten, Map)

then:
result.containsKey("event_type")
result["event_type"] == "span"
result.containsKey("_dd.stage")
result["_dd.stage"] == "raw"
result.containsKey("spans")
result["spans"] instanceof List
result["spans"].size() == 3

def spanNames = result["spans"].collect { it["name"] }
spanNames.contains("chat-completion-1")
spanNames.contains("chat-completion-2")
spanNames.contains("chat-completion-3")
}

static class CapturingByteBufferConsumer implements ByteBufferConsumer {

ByteBuffer captured
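For readers who want to inspect a captured payload outside the Spock test, the following is a minimal sketch of the decoding step the test performs with its `objectMapper`. It assumes the `jackson-dataformat-msgpack` library is on the classpath and that `payloadBytes` holds the uncompressed bytes produced by `PayloadV1.writeTo` (the gzipped form returned by `toRequest()` would need to be inflated first); the class and method names here are hypothetical, while the field names come from the mapper's pre-built header.

```java
import java.io.IOException;
import java.util.List;
import java.util.Map;

import com.fasterxml.jackson.databind.ObjectMapper;
import org.msgpack.jackson.dataformat.MessagePackFactory;

// Sketch: decode the envelope produced by the mapper. The payload is the pre-built
// header map, then the array header, then the serialized spans.
final class LlmObsEnvelopeDump {
  static void dump(byte[] payloadBytes) throws IOException {
    ObjectMapper msgpack = new ObjectMapper(new MessagePackFactory());
    Map<?, ?> envelope = msgpack.readValue(payloadBytes, Map.class);
    System.out.println(envelope.get("event_type")); // "span"
    System.out.println(envelope.get("_dd.stage"));  // "raw"
    List<?> spans = (List<?>) envelope.get("spans");
    System.out.println(spans.size()); // equals spansWritten captured at newPayload() time
  }
}
```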