Author: Adam <git@apiote.xyz>
add bare library
bare/.gitignore | 1 bare/build.gradle | 8 bare/src/main/java/org/nobloat/bare/AggregateBareDecoder.java | 70 + bare/src/main/java/org/nobloat/bare/AggregateBareEncoder.java | 100 ++ bare/src/main/java/org/nobloat/bare/Bare.java | 4 bare/src/main/java/org/nobloat/bare/BareException.java | 8 bare/src/main/java/org/nobloat/bare/Int.java | 16 bare/src/main/java/org/nobloat/bare/PrimitiveBareDecoder.java | 134 +++ bare/src/main/java/org/nobloat/bare/PrimitiveBareEncoder.java | 156 +++ bare/src/main/java/org/nobloat/bare/ReflectiveBareDecoder.java | 165 ++++ bare/src/main/java/org/nobloat/bare/Union.java | 69 +
diff --git a/bare/.gitignore b/bare/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..42afabfd2abebf31384ca7797186a27a4b7dbee8 --- /dev/null +++ b/bare/.gitignore @@ -0,0 +1 @@ +/build \ No newline at end of file diff --git a/bare/build.gradle b/bare/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..e493c42fffdc9ab79b54c37722e922f4053069df --- /dev/null +++ b/bare/build.gradle @@ -0,0 +1,8 @@ +plugins { + id 'java-library' +} + +java { + sourceCompatibility = JavaVersion.VERSION_1_7 + targetCompatibility = JavaVersion.VERSION_1_7 +} \ No newline at end of file diff --git a/bare/src/main/java/org/nobloat/bare/AggregateBareDecoder.java b/bare/src/main/java/org/nobloat/bare/AggregateBareDecoder.java new file mode 100644 index 0000000000000000000000000000000000000000..54a8f28deecb4900ffc542a76cf0693e58bd202f --- /dev/null +++ b/bare/src/main/java/org/nobloat/bare/AggregateBareDecoder.java @@ -0,0 +1,70 @@ +package org.nobloat.bare; + +import java.io.IOException; +import java.io.InputStream; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +public class AggregateBareDecoder extends PrimitiveBareDecoder { + + public int MaxMapLength = 1000000000; + + public AggregateBareDecoder(InputStream inputStream) { + super(inputStream); + } + + public <T> Optional<T> optional(DecodeFunction<T> itemDecoder) throws IOException, BareException { + boolean exists = bool(); + if (exists) { + return Optional.of(itemDecoder.apply(this)); + } + return Optional.empty(); + } + + public <T> List<T> array(int count, DecodeFunction<T> itemDecoder) throws IOException, BareException { + var result = new ArrayList<T>(count); + for (int i=0; i < count; i++) { + result.add(itemDecoder.apply(this)); + } + return result; + } + + public <T> List<T> slice(DecodeFunction<T> itemDecoder) throws IOException, BareException { + var length = variadicUint().intValue(); + if (length > MaxSliceLength) { + throw new BareException(String.format("Decoding slice with entries %d > %d max length", length, MaxSliceLength)); + } + return array(length, itemDecoder); + } + + public <K,V> Map<K,V> map(DecodeFunction<K> keyDecoder, DecodeFunction<V> valueDecoder) throws IOException, BareException { + var length = variadicUint().intValue(); + if (length > MaxMapLength) { + throw new BareException(String.format("Decoding map with entries %d > %d max length", length, MaxSliceLength)); + } + var result = new HashMap<K,V>(); + for (int i=0; i < length; i++) { + result.put(keyDecoder.apply(this), valueDecoder.apply(this)); + } + return result; + } + + public Union union(Map<Integer, DecodeFunction> decodeFunctions) throws IOException, BareException { + int type = variadicUint().intValue(); + var decoder = decodeFunctions.get(type); + + if (decoder == null) { + throw new BareException("Unknown union type: " + type); + } + return new Union(type, decoder.apply(this)); + } + + @FunctionalInterface + public interface DecodeFunction<T> { + T apply(AggregateBareDecoder decoder) throws IOException, BareException; + } + +} diff --git a/bare/src/main/java/org/nobloat/bare/AggregateBareEncoder.java b/bare/src/main/java/org/nobloat/bare/AggregateBareEncoder.java new file mode 100644 index 0000000000000000000000000000000000000000..704dada5d7a9c13ae2ddb42bf122838a0510eb48 --- /dev/null +++ b/bare/src/main/java/org/nobloat/bare/AggregateBareEncoder.java @@ -0,0 +1,100 @@ +package org.nobloat.bare; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +public class AggregateBareEncoder extends PrimitiveBareEncoder { + + public AggregateBareEncoder(OutputStream os) { + super(os); + } + + public <T> void optional(Optional<T> value, EncodeFunction<T> encoder) throws IOException, BareException { + if (value.isPresent()) { + bool(true); + encoder.apply(value.get()); + } else { + bool(false); + } + } + + public <T> void array(T[] value, EncodeFunction<T> itemEncoder) throws IOException, BareException { + for (var item : value) { + itemEncoder.apply(item); + } + } + + public void array(int[] value, EncodeFunction<Integer> itemEncoder) throws IOException, BareException { + for (var item : value) { + itemEncoder.apply(item); + } + } + + public void array(long[] value, EncodeFunction<Long> itemEncoder) throws IOException, BareException { + for (var item : value) { + itemEncoder.apply(item); + } + } + + public void array(short[] value, EncodeFunction<Short> itemEncoder) throws IOException, BareException { + for (var item : value) { + itemEncoder.apply(item); + } + } + + public void array(boolean[] value, EncodeFunction<Boolean> itemEncoder) throws IOException, BareException { + for (var item : value) { + itemEncoder.apply(item); + } + } + + public void array(float[] value, EncodeFunction<Float> itemEncoder) throws IOException, BareException { + for (var item : value) { + itemEncoder.apply(item); + } + } + + public void array(double[] value, EncodeFunction<Double> itemEncoder) throws IOException, BareException { + for (var item : value) { + itemEncoder.apply(item); + } + } + + public void array(byte[] value) throws IOException { + for (var item : value) { + u8(item); + } + } + + public <T> void slice(List<T> value, EncodeFunction<T> itemEncoder) throws IOException, BareException { + variadicUInt(value.size()); + for (var item : value) { + itemEncoder.apply(item); + } + } + + public <K,V> void map(Map<K,V> values, EncodeFunction<K> keyEncoder, EncodeFunction<V> valueEncoder) throws IOException, BareException { + variadicUInt(values.size()); + for (var entry : values.entrySet()) { + keyEncoder.apply(entry.getKey()); + valueEncoder.apply(entry.getValue()); + } + } + + public void union(Union value, Map<Integer, EncodeFunction> encodeFunctions) throws IOException, BareException { + var encoder = encodeFunctions.get((value.type())); + if (encoder == null) { + throw new BareException("Unmapped union type: " + value.type()); + } + variadicUInt(value.type()); + encoder.apply(value.value); + } + + @FunctionalInterface + public interface EncodeFunction<T> { + void apply(T value) throws IOException, BareException; + } +} diff --git a/bare/src/main/java/org/nobloat/bare/Bare.java b/bare/src/main/java/org/nobloat/bare/Bare.java new file mode 100644 index 0000000000000000000000000000000000000000..84295754fa4fa2c44d0922473c5a13fcf784699e --- /dev/null +++ b/bare/src/main/java/org/nobloat/bare/Bare.java @@ -0,0 +1,4 @@ +package org.nobloat.bare; + +public class Bare { +} \ No newline at end of file diff --git a/bare/src/main/java/org/nobloat/bare/BareException.java b/bare/src/main/java/org/nobloat/bare/BareException.java new file mode 100644 index 0000000000000000000000000000000000000000..6b85ef2d119d19c180d15234d395e604bf15fca5 --- /dev/null +++ b/bare/src/main/java/org/nobloat/bare/BareException.java @@ -0,0 +1,8 @@ +package org.nobloat.bare; + +public class BareException extends Exception { + + public BareException(String s) { + super(s); + } +} diff --git a/bare/src/main/java/org/nobloat/bare/Int.java b/bare/src/main/java/org/nobloat/bare/Int.java new file mode 100644 index 0000000000000000000000000000000000000000..b863f4639c80bb269cae9eaac1b67099d1a2a8c5 --- /dev/null +++ b/bare/src/main/java/org/nobloat/bare/Int.java @@ -0,0 +1,16 @@ +package org.nobloat.bare; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.FIELD) +public @interface Int { + Type value() default Type.i32; + + enum Type { + u8, i8, u16, i16, u32, i32, u64, i64, ui, i + } +} diff --git a/bare/src/main/java/org/nobloat/bare/PrimitiveBareDecoder.java b/bare/src/main/java/org/nobloat/bare/PrimitiveBareDecoder.java new file mode 100644 index 0000000000000000000000000000000000000000..c4626c61f66b4fbebec37441347a940250a0e970 --- /dev/null +++ b/bare/src/main/java/org/nobloat/bare/PrimitiveBareDecoder.java @@ -0,0 +1,134 @@ +package org.nobloat.bare; + +import java.io.DataInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.math.BigInteger; +import java.nio.charset.StandardCharsets; + +public class PrimitiveBareDecoder { + + private final DataInputStream is; + private static final BigInteger UNSIGNED_LONG_MASK = BigInteger.ONE.shiftLeft(Long.SIZE).subtract(BigInteger.ONE); + + public int MaxSliceLength = 1000000000; + + public PrimitiveBareDecoder(InputStream is) { + this.is = new DataInputStream(is); + } + + public byte u8() throws IOException { + return is.readByte(); + } + + public int u16() throws IOException { + int byte1 = is.readByte() & 0xff; + int byte2 = is.readByte() & 0xff; + return (byte1 | byte2 << 8); + } + + public long u32() throws IOException { + long byte1 = is.readByte() & 0xff; + long byte2 = is.readByte() & 0xff; + long byte3 = is.readByte() & 0xff; + long byte4 = is.readByte() & 0xff; + return (byte4 << 24) | (byte3 << 16) | (byte2 << 8) | (byte1); + } + + public BigInteger u64() throws IOException { + return BigInteger.valueOf(i64()).and(UNSIGNED_LONG_MASK); + } + + public byte i8() throws IOException { + return is.readByte(); + } + + public short i16() throws IOException { + int byte1 = is.readByte() & 0xff; + int byte2 = is.readByte() & 0xff; + return (short) (byte2 << 8 | byte1); + } + + public int i32() throws IOException { + return (int) u32(); + } + + public long i64() throws IOException { + long byte1 = is.readByte() & 0xff; + long byte2 = is.readByte() & 0xff; + long byte3 = is.readByte() & 0xff; + long byte4 = is.readByte() & 0xff; + long byte5 = is.readByte() & 0xff; + long byte6 = is.readByte() & 0xff; + long byte7 = is.readByte() & 0xff; + long byte8 = is.readByte() & 0xff; + return (byte8 << 56) | (byte7 << 48) | (byte6 << 40) | (byte5 << 32) |(byte4 << 24) | (byte3 << 16) | + (byte2 << 8) | (byte1); + } + + public float f32() throws IOException { + return Float.intBitsToFloat(i32()); + } + + public double f64() throws IOException { + return Double.longBitsToDouble(i64()); + } + + public boolean bool() throws IOException { + return is.readByte() != 0; + } + + public long variadicInt() throws IOException { + BigInteger r = variadicUint(); + if (r.testBit(0)) { + return r.shiftRight(1).not().longValue(); + } + return r.shiftRight(1).longValue(); + } + + public BigInteger variadicUint() throws IOException { + BigInteger result = BigInteger.ZERO; + int shift = 0; + int b; + do { + b = is.readByte() & 0xff; + if (b >= 0x80) { + result = result.or(BigInteger.valueOf(b & 0x7F).shiftLeft(shift)); + shift += 7; + } else { + result = result.or(BigInteger.valueOf(b).shiftLeft(shift)); + } + } while (b >= 0x80); + return result; + } + + public String string() throws IOException, BareException { + int length = variadicUint().intValue(); + if (length > MaxSliceLength) { + throw new BareException(String.format("Decoding slice with length %d > %d max length", length, MaxSliceLength)); + } + var target = new byte[length]; + is.read(target); + return new String(target, StandardCharsets.UTF_8); + } + + public byte[] data(int length) throws IOException { + var result = new byte[length]; + for (int i=0; i < length; i++) { + result[i] = is.readByte(); + } + return result; + } + + public Byte[] data() throws IOException, BareException { + int length = variadicUint().intValue(); + if (length > MaxSliceLength) { + throw new BareException(String.format("Decoding slice with length %d > %d max length", length, MaxSliceLength)); + } + var result = new Byte[length]; + for (int i=0; i < length; i++) { + result[i] = is.readByte(); + } + return result; + } +} diff --git a/bare/src/main/java/org/nobloat/bare/PrimitiveBareEncoder.java b/bare/src/main/java/org/nobloat/bare/PrimitiveBareEncoder.java new file mode 100644 index 0000000000000000000000000000000000000000..e4a9e43ab1bdc54dbe4e39499cd14628c1336a9b --- /dev/null +++ b/bare/src/main/java/org/nobloat/bare/PrimitiveBareEncoder.java @@ -0,0 +1,156 @@ +package org.nobloat.bare; + +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.NotSerializableException; +import java.io.OutputStream; +import java.math.BigInteger; +import java.nio.charset.StandardCharsets; + +public class PrimitiveBareEncoder { + private final DataOutputStream os; + private static final BigInteger UNSIGNED_LONG_MASK = BigInteger.ONE.shiftLeft(Long.SIZE).subtract(BigInteger.ONE); + private final boolean verifyInput; + + public PrimitiveBareEncoder(OutputStream os, boolean verifyInput) { + this.verifyInput = verifyInput; + this.os = new DataOutputStream(os); + } + + public PrimitiveBareEncoder(OutputStream os) { + this(os, true); + } + + public void u8(byte b) throws IOException { + os.writeByte(b); + } + + public void u8(short b) throws IOException, BareException { + if (verifyInput && b > 255) { + throw new BareException("u8 must not exceed value of 255"); + } + os.writeByte(b); + } + + public void u16(int b) throws IOException, BareException { + if (verifyInput && b > 65535) { + throw new BareException("u16 must not exceed value of 65535"); + } + os.write(new byte[]{(byte) b, (byte) (b >> 8)}); + } + + public void u32(long b) throws IOException, BareException { + if (verifyInput && b > 4294967295L) { + throw new BareException("u16 must not exceed value of 4294967295"); + } + os.write(new byte[]{(byte) b, (byte) (b >> 8), (byte) (b >> 16), (byte) (b >> 24)}); + } + + public void u64(BigInteger b) throws IOException, BareException { + if (verifyInput && b.bitLength() > 64) { + throw new BareException("value for variadicUint must not have more than 64 bits, value has " + b.bitLength() + " bits"); + } + i64(b.and(UNSIGNED_LONG_MASK).longValue()); + } + + public void i8(short b) throws IOException, BareException { + if (verifyInput && b > 128 || b < -127) { + throw new BareException("i8 must not exceed range between 255 and -255"); + } + u8(b); + } + + public void i8(byte b) throws IOException { + u8(b); + } + + public void i16(short b) throws IOException, BareException { + u16(b); + } + + public void i32(int b) throws IOException, BareException { + u32(b); + } + + public void i64(long b) throws IOException { + os.write(new byte[]{(byte) b, (byte) (b >> 8), (byte) (b >> 16), (byte) (b >> 24), + (byte) (b >> 32), (byte) (b >> 40), (byte) (b >> 48), (byte) (b >> 56) + }); + } + + public void f32(float b) throws IOException, BareException { + i32(Float.floatToIntBits(b)); + } + + public void f64(double b) throws IOException { + i64(Double.doubleToLongBits(b)); + } + + public void bool(boolean b) throws IOException { + if (b) { + u8((byte) 1); + } else { + u8((byte) 0); + } + } + + public void data(byte[] data) throws IOException, BareException { + variadicUInt(data.length); + os.write(data); + } + + public void data(Byte[] data) throws IOException, BareException { + variadicUInt(data.length); + for(var b : data) { + os.write(b); + } + } + + public void string(String s) throws IOException, BareException { + if (s == null) { + data(new byte[]{}); + } else { + data(s.getBytes(StandardCharsets.UTF_8)); + } + } + + public int variadicUInt(BigInteger value) throws IOException, BareException { + if (verifyInput && value.bitLength() > 64) { + throw new BareException("value for variadicUint must not have more than 64 bits, value has " + value.bitLength() + " bits"); + } + if (verifyInput && value.signum() == -1) { + throw new BareException("value for variadicUint must not be negative: " + value); + } + + int i = 0; + while (value.longValue() >= 0x80) { + os.write((byte) (value.longValue() | 0x80)); + value = value.shiftRight(7); + i++; + } + os.write((byte) value.longValue()); + return i + 1; + } + + public int variadicUInt(long value) throws IOException, BareException { + if (verifyInput && value < 0) { + throw new BareException("value for variadicUint must not be negative: " + value); + } + int i = 0; + while (value >= 0x80) { + os.write((byte) (value | 0x80)); + value >>= 7; + i++; + } + os.write((byte) value); + return i + 1; + } + + public int variadicInt(long value) throws IOException, BareException { + long unsigned = value << 1; + if (unsigned < 0) { + unsigned = ~unsigned; + } + return variadicUInt(unsigned); + } +} diff --git a/bare/src/main/java/org/nobloat/bare/ReflectiveBareDecoder.java b/bare/src/main/java/org/nobloat/bare/ReflectiveBareDecoder.java new file mode 100644 index 0000000000000000000000000000000000000000..141ffefbd36a2b798134381c450022b5bc300e9b --- /dev/null +++ b/bare/src/main/java/org/nobloat/bare/ReflectiveBareDecoder.java @@ -0,0 +1,165 @@ +package org.nobloat.bare; + +import java.io.IOException; +import java.io.InputStream; +import java.io.UnsupportedEncodingException; +import java.lang.reflect.Array; +import java.lang.reflect.Field; +import java.lang.reflect.ParameterizedType; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +public class ReflectiveBareDecoder extends AggregateBareDecoder { + + public static final List<String> INTEGER_TYPES = List.of(new String[]{"java.lang.Long", "java.lang.Integer", "java.lang.BigInteger", "java.lang.Short"}); + public static final List<String> PRIMITIVE_TYPES = List.of(new String[]{"java.lang.String", "java.lang.Boolean", "java.lang.Byte", "java.lang.Float", "java.lang.Double"}); + + public ReflectiveBareDecoder(InputStream inputStream) { + super(inputStream); + } + + public <T> Optional<T> optional(Class<T> c) throws IOException, BareException { + if (bool()) { + return Optional.of(readPrimitiveType(c)); + } + return Optional.empty(); + } + + public <T> List<T> slice(Class<T> c) throws IOException, ReflectiveOperationException, BareException { + var length = variadicUint().intValue(); + if (length > MaxSliceLength) { + throw new BareException(String.format("Decoding slice with entries %d > %d max length", length, MaxSliceLength)); + } + return array(c, length); + } + + public <T> List<T> array(Class<T> c, int length) throws IOException, ReflectiveOperationException, BareException { + var result = new ArrayList<T>(length); + for (int i = 0; i < length; i++) { + result.add(readType(c)); + } + return result; + } + + public <K, V> Map<K, V> map(Class<K> key, Class<V> value) throws IOException, ReflectiveOperationException, BareException { + assert PRIMITIVE_TYPES.contains(key.getName()); + + var length = variadicUint().intValue(); + + if (length > MaxMapLength) { + throw new BareException(String.format("Decoding map with entries %d > %d max length", length, MaxSliceLength)); + } + + var result = new HashMap<K, V>(length); + for(int i=0; i < length; i++) { + result.put(readPrimitiveType(key), readType(value)); + } + return result; + } + + + public Union union(Class<?>... possibleTypes) throws IOException, ReflectiveOperationException, BareException { + var union = new Union(possibleTypes); + int type = variadicUint().intValue(); + var clazz = union.type(type); + union.set(type, readType(clazz)); + return union; + } + + public <T> T enumeration(Class<? extends Enum> c) throws IOException, ReflectiveOperationException, BareException { + var enumValue = variadicUint().intValue(); + var valueField = c.getField("value"); + + for (var constant : c.getEnumConstants()) { + int value = (int) valueField.get(constant); + if (value == enumValue) { + return (T) Enum.valueOf(c, constant.name()); + } + } + + throw new BareException("Unexpected enum value: " + enumValue); + } + + @SuppressWarnings("unchecked") + public <T> T struct(Class<T> c) throws ReflectiveOperationException, IOException, BareException { + var fields = c.getFields(); + var result = c.getConstructor().newInstance(); + for(var f : fields) { + f.setAccessible(true); + if (INTEGER_TYPES.contains(f.getType().getName())) { + f.set(result, readIntegerType(f)); + } else if (PRIMITIVE_TYPES.contains(f.getType().getName())) { + f.set(result, readPrimitiveType(f.getType())); + } else if(f.getType().isArray()) { + var array = (T[])f.get(result); + f.set(result, array(f.getType().getComponentType(), array.length).toArray((Object[])Array.newInstance(f.getType().getComponentType(), array.length))); + } else if(f.getType().getName().equals("java.util.List")) { + ParameterizedType type = (ParameterizedType)f.getGenericType(); + var elementType = type.getActualTypeArguments()[0]; + f.set(result, slice((Class<?>) elementType)); + } else if(f.getType().getName().equals("java.util.Map")) { + ParameterizedType type = (ParameterizedType)f.getGenericType(); + var keyType = type.getActualTypeArguments()[0]; + var valueType = type.getActualTypeArguments()[0]; + f.set(result, map((Class<?>)keyType, (Class<?>) valueType)); + } else{ + f.set(result, readType(f.getType())); + } + } + return result; + } + + @SuppressWarnings("unchecked") + public <T> T readPrimitiveType(Class<?> c) throws IOException, BareException { + switch (c.getName()) { + case "java.lang.Boolean": + return (T) Boolean.valueOf(bool()); + case "java.lang.Byte": + return (T) Byte.valueOf(i8()); + case "java.lang.Float": + return (T) Float.valueOf(f32()); + case "java.lang.Double": + return (T) Double.valueOf(f64()); + case "java.lang.String": + return (T) string(); + default: + throw new UnsupportedOperationException("readType not implemented for " + c.getName()); + } + } + + @SuppressWarnings("unchecked") + public <T> T readIntegerType(Field f) throws IOException { + var annotation = f.getAnnotation(Int.class); + if (annotation == null) { + throw new UnsupportedEncodingException("Missing @Int type annotation on number field: " + f.getName()); + } + switch (annotation.value()) { + case i8: return (T) Byte.valueOf(i8()); + case u8: return (T) Byte.valueOf(u8()); + case i16: return (T) Short.valueOf(i16()); + case u16: return (T) Integer.valueOf(u16()); + case i32: return (T) Integer.valueOf(i32()); + case u32: return (T) Long.valueOf(u32()); + case u64: return (T) u64(); + case i64: return (T) Long.valueOf(i64()); + case i: return (T) Long.valueOf(variadicInt()); + case ui: return (T) variadicUint(); + default: + throw new UnsupportedEncodingException("Unknown Int type: " + annotation.value()); + } + } + + public <T> T readType(Class<T> c) throws IOException, ReflectiveOperationException, BareException { + try { + if (c.isEnum()) { + return enumeration((Class<? extends Enum>) c); + } + return readPrimitiveType(c); + } catch (UnsupportedOperationException | BareException e) { + return struct(c); + } + } +} diff --git a/bare/src/main/java/org/nobloat/bare/Union.java b/bare/src/main/java/org/nobloat/bare/Union.java new file mode 100644 index 0000000000000000000000000000000000000000..d701c29accd7055395f1f46256ddee7ac67ebba7 --- /dev/null +++ b/bare/src/main/java/org/nobloat/bare/Union.java @@ -0,0 +1,69 @@ +package org.nobloat.bare; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; +import java.util.HashMap; +import java.util.Map; + +public class Union { + + Map<Long, Class<?>> types; + Object value; + long type; + + public Union(int type, Object value) { + this.type = type; + this.value = value; + } + + public Union(Class<?> ...allowedTypes) { + this.types = new HashMap<>(); + for (var c : allowedTypes) { + var unionId = c.getAnnotation(Id.class); + if (unionId == null) { + throw new UnsupportedOperationException("Missing annotation @Union.Id on " + c.getName()); + } + this.types.put(unionId.value(), c); + } + } + + public void set(long id, Object object) { + if (types.containsKey(id)) { + this.value = object; + this.type = id; + } else { + throw new UnsupportedOperationException("Could not map union type: " + id); + } + } + + Class<?> type(long id) { + if (types.containsKey(id)) { + return types.get(id); + } + throw new UnsupportedOperationException("Unexpected union type: " + id); + } + + public int type() { + return (int) type; + } + + public <T> T get(Class<T> type) { + return (T)value; + } + + @Override + public String toString() { + return "Union{" + + "value=" + value + + '}'; + } + + @Retention(RetentionPolicy.RUNTIME) + @Target(ElementType.TYPE) + public @interface Id { + long value() default 0; + } + +}