blob: f7e5214d59f70166845699a8a86afacd6ca11dc1 [file] [log] [blame]
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 Tamas Hubai
`default_nettype none
// activation functions used in the neural network
// leaky ReLU
//   x   - input value; its top bit is used as the sign
//   res - x unchanged when the sign bit is clear; otherwise x shifted right by
//         LEAK_SHIFT bits with ones filled in from the top
module leaky_relu_comb (
    input [`NUM_WIDTH-1:0] x,
    output [`NUM_WIDTH-1:0] res
);
    wire is_negative = x[`NUM_WIDTH-1];
    // negative branch: drop the low LEAK_SHIFT bits and refill the top with ones
    wire [`NUM_WIDTH-1:0] leaked = {{(`LEAK_SHIFT){1'b1}}, x[`NUM_WIDTH-1:`LEAK_SHIFT]};
    assign res = is_negative ? leaked : x;
endmodule
// derivative of leaky ReLU
//   x   - input value; only its top (sign) bit is inspected
//   res - constant with a single set bit: at position FRAC_WIDTH when the sign
//         bit is clear, LEAK_SHIFT positions lower when it is set
module leaky_relu_diff_comb (
    input [`NUM_WIDTH-1:0] x,
    output [`NUM_WIDTH-1:0] res
);
    // slope for non-negative inputs
    wire [`NUM_WIDTH-1:0] slope_pos = {{(`INT_WIDTH-1){1'b0}}, 1'b1, {(`FRAC_WIDTH){1'b0}}};
    // slope for negative inputs (same bit moved down by LEAK_SHIFT)
    wire [`NUM_WIDTH-1:0] slope_neg = {{(`INT_WIDTH+`LEAK_SHIFT-1){1'b0}}, 1'b1, {(`FRAC_WIDTH-`LEAK_SHIFT){1'b0}}};
    assign res = x[`NUM_WIDTH-1] ? slope_neg : slope_pos;
endmodule
// very rough approximation of 2^x, used in softmax
//   x   - exponent; the bits above FRAC_WIDTH are taken as a signed integer
//         part, the fractional bits are ignored
//   res - a single set bit at position FRAC_WIDTH + integer_part(x); all zero
//         when that position falls below bit 0; all bits below the top bit set
//         when x is non-negative and the position would reach the top bit
module approx_exp_comb (
    input [`NUM_WIDTH-1:0] x,
    output [`NUM_WIDTH-1:0] res
);
    // Integer part of x as a signed value. Comparing the raw (unsigned)
    // part-select against the signed constant g - FRAC_WIDTH would make the
    // whole comparison unsigned and zero-extend the part-select, so negative
    // exponents could never match any bit below FRAC_WIDTH.
    wire signed [`NUM_WIDTH-`FRAC_WIDTH-1:0] x_int = x[`NUM_WIDTH-1:`FRAC_WIDTH];
    // non-negative exponent too large to be represented below the top bit
    wire saturated = ~x[`NUM_WIDTH-1] & (x[`NUM_WIDTH-1:`FRAC_WIDTH] > `INT_WIDTH - 2);
    // top bit of the result is never set
    assign res[`NUM_WIDTH-1] = 1'b0;
    generate genvar g;
        for (g=0; g<`NUM_WIDTH-1; g=g+1) begin:g_exp
            // both sides signed: the comparison sign-extends x_int
            assign res[g] = saturated | (x_int == $signed(g - `FRAC_WIDTH));
        end
    endgenerate
endmodule
// piecewise linear approximation of 1/x, used in softmax
//
// Locates the most significant set bit of x, extracts the FRAC_WIDTH bits just
// below it as a mantissa fraction m, evaluates one of four linear segments
// (chosen by the top two bits of m), and shifts the segment result back
// according to the position of the most significant set bit.
//   x   - value to invert; every bit including the top one takes part in the
//         leading-one search, so x is treated as unsigned
//   res - approximate reciprocal; 0 when x is 0 (no msb bit is set, so every
//         selected term is zero)
// NOTE(review): the fixed-point interpretation in the comments below assumes
// FRAC_WIDTH fractional bits and NUM_WIDTH = INT_WIDTH + FRAC_WIDTH - confirm
// against the macro definitions.
module approx_inv_comb (
input [`NUM_WIDTH-1:0] x, // assuming x > 0
output [`NUM_WIDTH-1:0] res
);
// bnd[g] is 1 when any bit of x at position g or above is set
wire [`NUM_WIDTH:0] bnd;
// msb is one-hot at the position of the highest set bit of x (all zero for x == 0)
wire [`NUM_WIDTH-1:0] msb;
// bits of x just below its highest set bit, left-aligned to FRAC_WIDTH bits
wire [`FRAC_WIDTH-1:0] m;
assign bnd[`NUM_WIDTH] = 0;
generate genvar g;
// scan from the top bit down, building the prefix-OR and the one-hot marker
for (g=`NUM_WIDTH-1; g>=0; g=g-1) begin:g_msb
assign bnd[g] = bnd[g+1] | x[g];
assign msb[g] = bnd[g] & ~bnd[g+1];
end
// one-hot multiplexer built as an OR chain: only the stage whose msb[g] is set
// contributes a nonzero mc
for (g=0; g<`NUM_WIDTH; g=g+1) begin:g_mant
// x padded with FRAC_WIDTH zeros and shifted right by g puts bit g of x at
// position FRAC_WIDTH; truncating to FRAC_WIDTH bits keeps only the bits below it
wire [`FRAC_WIDTH-1:0] mc = msb[g] ? ({x, {(`FRAC_WIDTH){1'b0}}} >> g) : {(`FRAC_WIDTH){1'b0}};
wire [`FRAC_WIDTH-1:0] ms;
if (g==0) begin:i_mantz
assign ms = mc;
end else begin:i_mantnz
assign ms = g_mant[g-1].ms | mc;
end
end
assign m = g_mant[`NUM_WIDTH-1].ms;
// m contains the input bit-shifted to within [1, 2), with its integer part (i.e. 1) removed
// each segment below is constant - ((coefficient * {1, m}) >> shift); the
// operands are widened so the product fits before the shift
// for 1 <= x < 1.25 we use 1/x ~= 115/64 - 51/64 x
wire [`FRAC_WIDTH:0] minv_a = {7'd115, {(`FRAC_WIDTH-6){1'b0}}} - (({{(`FRAC_WIDTH){1'b0}}, 7'd51} * {7'd1, m}) >> 6);
// for 1.25 <= x < 1.5 we use 1/x ~= 95/64 - 35/64 x
wire [`FRAC_WIDTH:0] minv_b = {7'd95, {(`FRAC_WIDTH-6){1'b0}}} - (({{(`FRAC_WIDTH){1'b0}}, 7'd35} * {7'd1, m}) >> 6);
// for 1.5 <= x < 1.75 we use 1/x ~= 157/128 - 3/8 x
wire [`FRAC_WIDTH:0] minv_c = {8'd157, {(`FRAC_WIDTH-7){1'b0}}} - (({{(`FRAC_WIDTH){1'b0}}, 3'd3} * {3'd1, m}) >> 3);
// for 1.75 <= x < 2 we use 1/x ~= 17/16 - 9/32 x
wire [`FRAC_WIDTH:0] minv_d = {5'd17, {(`FRAC_WIDTH-4){1'b0}}} - (({{(`FRAC_WIDTH){1'b0}}, 5'd9} * {5'd1, m}) >> 5);
// the top two bits of m pick the quarter of [1, 2) and hence the segment
wire [`FRAC_WIDTH:0] minv = m[`FRAC_WIDTH-1] ? (m[`FRAC_WIDTH-2] ? minv_d : minv_c) : (m[`FRAC_WIDTH-2] ? minv_b : minv_a);
// second one-hot OR chain: scale the mantissa reciprocal back by the msb position
for (g=0; g<`NUM_WIDTH; g=g+1) begin:g_mrec
// minv is placed FRAC_WIDTH bits up and shifted right by g; the result is
// cut to NUM_WIDTH bits when assigned to mrc
// NOTE(review): for small g the shifted value can be wider than NUM_WIDTH and
// is truncated, not saturated - confirm callers never pass such small x
wire [`NUM_WIDTH-1:0] mrc = msb[g] ? ({{(`INT_WIDTH){1'b0}}, minv, {(`FRAC_WIDTH){1'b0}}} >> g) : {(`NUM_WIDTH){1'b0}};
wire [`NUM_WIDTH-1:0] mrs;
if (g==0) begin:i_mrecz
assign mrs = mrc;
end else begin:i_mrecnz
assign mrs = g_mrec[g-1].mrs | mrc;
end
end
assign res = g_mrec[`NUM_WIDTH-1].mrs;
endgenerate
endmodule
// softmax using approximated 2^x and 1/x
//   x_pk   - OUTPUT_SIZE inputs of NUM_WIDTH bits each, packed into one vector
//   res_pk - OUTPUT_SIZE outputs of NUM_WIDTH bits each, packed the same way
// Pipeline: subtract the maximum input from every input (sub_sat_comb), pass
// each difference through approx_exp_comb, sum the results, invert the sum with
// approx_inv_comb, and multiply every exponential by that reciprocal
// (mul_sat_comb).
module approx_softmax_comb (
    input [`OUTPUT_SIZE*`NUM_WIDTH-1:0] x_pk,
    output [`OUTPUT_SIZE*`NUM_WIDTH-1:0] res_pk
);
    wire [`NUM_WIDTH-1:0] x[`OUTPUT_SIZE-1:0];
    wire [`NUM_WIDTH-1:0] res[`OUTPUT_SIZE-1:0];
    `UNPACK_ARRAY(`NUM_WIDTH, `OUTPUT_SIZE, x, x_pk)
    `PACK_ARRAY(`NUM_WIDTH, `OUTPUT_SIZE, res, res_pk)

    // largest input value; the position output of max_comb is left unused
    wire [`NUM_WIDTH-1:0] peak;
    wire [`INDEX_WIDTH-1:0] peak_pos_unused;
    max_comb i_max (
        .x_pk(x_pk),
        .res_val(peak),
        .res_pos(peak_pos_unused)
    );

    // per-element exponentials and their total
    wire [`NUM_WIDTH-1:0] pow2[`OUTPUT_SIZE-1:0];
    wire [`NUM_WIDTH-1:0] denom;
    generate genvar g;
        for (g=0; g<`OUTPUT_SIZE; g=g+1) begin:g_expsum
            // input relative to the maximum
            wire [`NUM_WIDTH-1:0] rel;
            sub_sat_comb i_sub (
                .a(x[g]),
                .b(peak),
                .res(rel)
            );
            approx_exp_comb i_exp (
                .x(rel),
                .res(pow2[g])
            );
            // running total, chained through the previous iteration's scope
            wire [`NUM_WIDTH-1:0] acc;
            if (g==0) begin
                assign acc = pow2[g];
            end else begin
                assign acc = g_expsum[g-1].acc + pow2[g];
            end
        end
        assign denom = g_expsum[`OUTPUT_SIZE-1].acc;
    endgenerate

    // reciprocal of the total
    wire [`NUM_WIDTH-1:0] recip;
    approx_inv_comb i_inv (
        .x(denom),
        .res(recip)
    );

    // normalize each exponential by the reciprocal of the total
    generate
        for (g=0; g<`OUTPUT_SIZE; g=g+1) begin:g_div
            mul_sat_comb i_mul (
                .a(pow2[g]),
                .b(recip),
                .res(res[g])
            );
        end
    endgenerate
endmodule
// derivative of softmax
//   x_pk   - OUTPUT_SIZE inputs of NUM_WIDTH bits each, packed into one vector
//   res_pk - per element, s - s*s where s is the matching approx_softmax_comb
//            output; the product and difference use mul_sat_comb / sub_sat_comb
module approx_softmax_diff_comb (
    input [`OUTPUT_SIZE*`NUM_WIDTH-1:0] x_pk,
    output [`OUTPUT_SIZE*`NUM_WIDTH-1:0] res_pk
);
    // softmax of the inputs, still packed
    wire [`OUTPUT_SIZE*`NUM_WIDTH-1:0] sm_pk;
    approx_softmax_comb i_sm (
        .x_pk(x_pk),
        .res_pk(sm_pk)
    );

    wire [`NUM_WIDTH-1:0] sm[`OUTPUT_SIZE-1:0];
    wire [`NUM_WIDTH-1:0] res[`OUTPUT_SIZE-1:0];
    `UNPACK_ARRAY(`NUM_WIDTH, `OUTPUT_SIZE, sm, sm_pk)
    `PACK_ARRAY(`NUM_WIDTH, `OUTPUT_SIZE, res, res_pk)

    generate genvar g;
        for (g=0; g<`OUTPUT_SIZE; g=g+1) begin
            // s*s for this element
            wire [`NUM_WIDTH-1:0] sm_squared;
            mul_sat_comb i_mul (
                .a(sm[g]),
                .b(sm[g]),
                .res(sm_squared)
            );
            // s - s*s
            sub_sat_comb i_sub (
                .a(sm[g]),
                .b(sm_squared),
                .res(res[g])
            );
        end
    endgenerate
endmodule