// blob: 4b4abab764a5c1d8041595a21bd4c34740927c43 [file] [log] [blame]
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 Tamas Hubai
`default_nettype none
// wrapper for neural_network that allows
// - setting initial weights, inputs & ground truth
// - retrieving outputs
// - starting forward or backward propagation
// - detecting when forward or backward propagation finishes
// by using memory i/o within a single virtual address space
// Memory-mapped front end for the neural_network instance.
//
// Every rising clock edge decodes `addr` and registers a response in
// `data_out` (all zeros when the address selects nothing). When `we` is high
// the same access also performs a write with `data_in`.
//
// Address layout:
//   addr[23:20]  region select
//   addr[19:10]  high index
//   addr[9:0]    low index
//
// Regions:
//   0..3  weight port: region[1:0] -> w_layer, high index -> w_i,
//         low index -> w_j. A write latches data_in into w_in and raises wu
//         for one cycle. data_out samples w_out on the same edge that latches
//         the new w_layer/w_i/w_j, i.e. before the network can see them.
//   4     a0[low index], read/write (presumably the network inputs - confirm
//         against neural_network).
//   5     a3[low index], read only (presumably the network outputs).
//   6     g3[low index], read/write (presumably the ground truth).
//   7     control/status, chosen by the low index:
//           0  read: sticky fp_out flag replicated over the whole word;
//              write non-zero: raise fp for one cycle; write zero: clear flag
//           1  same as 0 but for bp / bp_out
//           2  read: argmax_a3 shifted left by `FRAC_WIDTH
//           3  read: argmax_g3 shifted left by `FRAC_WIDTH
module neural_interface (
    input                       clk,
    input      [23:0]           addr,
    input      [`NUM_WIDTH-1:0] data_in,
    input                       we,
    output reg [`NUM_WIDTH-1:0] data_out
);

    // Region codes carried in addr[23:20] (regions 0..3 are the weight port).
    localparam [3:0] REGION_A0   = 4'd4;
    localparam [3:0] REGION_A3   = 4'd5;
    localparam [3:0] REGION_G3   = 4'd6;
    localparam [3:0] REGION_CTRL = 4'd7;

    // Low-index values inside the control region.
    localparam [9:0] CTRL_FP        = 10'd0;
    localparam [9:0] CTRL_BP        = 10'd1;
    localparam [9:0] CTRL_ARGMAX_A3 = 10'd2;
    localparam [9:0] CTRL_ARGMAX_G3 = 10'd3;

    // ---------------------------------------------------------------------
    // Signals exchanged with the network
    // ---------------------------------------------------------------------
    reg                       fp;      // one-cycle strobe driven from the control region
    wire                      fp_out;  // driven by neural_network
    reg  [`NUM_WIDTH-1:0]     a0[`INPUT_SIZE-1:0];
    wire [`NUM_WIDTH-1:0]     a3[`OUTPUT_SIZE-1:0];
    reg                       bp;      // one-cycle strobe driven from the control region
    wire                      bp_out;  // driven by neural_network
    reg  [`NUM_WIDTH-1:0]     g3[`OUTPUT_SIZE-1:0];
    reg                       wu;      // one-cycle strobe accompanying a weight write
    reg  [1:0]                w_layer;
    reg  [`INDEX_WIDTH-1:0]   w_i;
    reg  [`INDEX_WIDTH-1:0]   w_j;
    reg  [`NUM_WIDTH-1:0]     w_in;
    wire [`NUM_WIDTH-1:0]     w_out;

    // Flattened views of the arrays for the module ports.
    wire [`INPUT_SIZE*`NUM_WIDTH-1:0] a0_pk;
    `PACK_ARRAY(`NUM_WIDTH, `INPUT_SIZE, a0, a0_pk)
    wire [`OUTPUT_SIZE*`NUM_WIDTH-1:0] a3_pk;
    `UNPACK_ARRAY(`NUM_WIDTH, `OUTPUT_SIZE, a3, a3_pk)
    wire [`OUTPUT_SIZE*`NUM_WIDTH-1:0] g3_pk;
    `PACK_ARRAY(`NUM_WIDTH, `OUTPUT_SIZE, g3, g3_pk)

    neural_network i_nn (
        .clk     (clk),
        .fp      (fp),
        .fp_out  (fp_out),
        .a0_pk   (a0_pk),
        .a3_pk   (a3_pk),
        .bp      (bp),
        .bp_out  (bp_out),
        .g3_pk   (g3_pk),
        .wu      (wu),
        .w_layer (w_layer),
        .w_i     (w_i),
        .w_j     (w_j),
        .w_in    (w_in),
        .w_out   (w_out)
    );

    // ---------------------------------------------------------------------
    // Position of the maximum of a3 and of g3; the maximum values themselves
    // are not used here.
    // ---------------------------------------------------------------------
    wire [`NUM_WIDTH-1:0]   unused_max_a3;
    wire [`INDEX_WIDTH-1:0] argmax_a3;
    max_comb i_max_a3 (
        .x_pk    (a3_pk),
        .res_val (unused_max_a3),
        .res_pos (argmax_a3)
    );

    wire [`NUM_WIDTH-1:0]   unused_max_g3;
    wire [`INDEX_WIDTH-1:0] argmax_g3;
    max_comb i_max_g3 (
        .x_pk    (g3_pk),
        .res_val (unused_max_g3),
        .res_pos (argmax_g3)
    );

    // ---------------------------------------------------------------------
    // Sticky status flags: set when fp_out / bp_out is seen high, kept until
    // a zero is written to the matching control word. The *_seen wires fold
    // in the live signal so a read in the very cycle of the pulse sees it.
    // ---------------------------------------------------------------------
    reg  fp_out_hold;
    reg  bp_out_hold;
    wire fp_seen = fp_out_hold | fp_out;
    wire bp_seen = bp_out_hold | bp_out;

    // ---------------------------------------------------------------------
    // Address decode
    // ---------------------------------------------------------------------
    wire [3:0] region = addr[23:20];
    wire [9:0] idx_hi = addr[19:10];
    wire [9:0] idx_lo = addr[9:0];

    wire in_weights = (region[3:2] == 2'b00);
    wire in_ctrl    = (region == REGION_CTRL);
    wire at_ctrl_fp = in_ctrl && (idx_lo == CTRL_FP);
    wire at_ctrl_bp = in_ctrl && (idx_lo == CTRL_BP);

    // ---------------------------------------------------------------------
    // Weight port: the coordinates follow every access to regions 0..3,
    // while w_in and the wu strobe are only driven on writes.
    // ---------------------------------------------------------------------
    always @(posedge clk) begin
        wu <= 1'b0;
        if (in_weights) begin
            w_layer <= region[1:0];
            w_i     <= idx_hi;
            w_j     <= idx_lo;
            if (we) begin
                w_in <= data_in;
                wu   <= 1'b1;
            end
        end
    end

    // ---------------------------------------------------------------------
    // Writable storage: a0 and g3 (a3 is driven by the network, read only).
    // ---------------------------------------------------------------------
    always @(posedge clk) begin
        if (we) begin
            case (region)
                REGION_A0: a0[idx_lo] <= data_in;
                REGION_G3: g3[idx_lo] <= data_in;
                default: ;
            endcase
        end
    end

    // ---------------------------------------------------------------------
    // fp / bp handshake. The strobes default low so they last one cycle.
    // The flag clear is the last assignment in program order, so it wins
    // over the unconditional update made earlier in the same edge.
    // ---------------------------------------------------------------------
    always @(posedge clk) begin
        fp          <= 1'b0;
        bp          <= 1'b0;
        fp_out_hold <= fp_seen;
        bp_out_hold <= bp_seen;

        if (at_ctrl_fp && we) begin
            if (|data_in)
                fp <= 1'b1;
            else
                fp_out_hold <= 1'b0;
        end

        if (at_ctrl_bp && we) begin
            if (|data_in)
                bp <= 1'b1;
            else
                bp_out_hold <= 1'b0;
        end
    end

    // ---------------------------------------------------------------------
    // Registered read mux. Anything that is not mapped reads back as zero.
    // ---------------------------------------------------------------------
    always @(posedge clk) begin
        if (in_weights) begin
            data_out <= w_out;
        end else begin
            case (region)
                REGION_A0: data_out <= a0[idx_lo];
                REGION_A3: data_out <= a3[idx_lo];
                REGION_G3: data_out <= g3[idx_lo];
                REGION_CTRL: begin
                    case (idx_lo)
                        CTRL_FP:        data_out <= {(`NUM_WIDTH){fp_seen}};
                        CTRL_BP:        data_out <= {(`NUM_WIDTH){bp_seen}};
                        CTRL_ARGMAX_A3: data_out <= argmax_a3 << `FRAC_WIDTH;
                        CTRL_ARGMAX_G3: data_out <= argmax_g3 << `FRAC_WIDTH;
                        default:        data_out <= {(`NUM_WIDTH){1'b0}};
                    endcase
                end
                default: data_out <= {(`NUM_WIDTH){1'b0}};
            endcase
        end
    end

endmodule